跳到内容

NAS-Bench-201

基类: Graph

表示 NASBench201 搜索空间。

该类提供用于查询和操作搜索空间内架构的方法,包括变异和随机采样的方法。

属性

名称 类型 描述
num_classes int

分类任务的类别数量。

in_channels int

输入通道的数量。

max_epoch int

训练的最大 epoch 数。

space_name str

搜索空间的名称。

labeled_archs list

已标注架构的列表。

instantiate_model bool

布尔值,指示是否在初始化期间实例化模型。

sample_without_replacement bool

布尔值,指示是否进行无放回采样架构。

channels list

架构不同阶段的通道数量。

op_indices list

操作的索引。

OPTIMIZER_SCOPE list

架构中阶段的列表,在优化期间进行范围界定时很有用。

QUERYABLE bool

布尔值,指示搜索空间是否可查询。

__init__(n_classes=10, in_channels=3)

构造方法。

这会使用提供的类别数量和输入通道初始化 NasBench201SearchSpace 对象。

参数

名称 类型 描述 默认
n_classes int

分类任务的类别数量。默认为 10。

10
in_channels int

输入通道的数量。默认为 3。

3

encode(encoding_type=EncodingType.ADJACENCY_ONE_HOT)

根据给定的编码类型对当前架构进行编码。

参数

名称 类型 描述 默认
encoding_type EncodingType

架构的编码类型。

EncodingType.ADJACENCY_ONE_HOT

返回值

名称 类型 描述
Any Union[List, np.ndarray, dict]

编码后的架构。返回类型取决于选择的编码类型。

抛出

类型 描述
NotImplementedError

如果给定的编码类型尚未被 nb201 支持作为架构编码。

forward_before_global_avg_pool(x)

执行直到全局平均池化层的正向传播并返回输出。

参数

名称 类型 描述 默认
x torch.Tensor

输入张量。

必需

返回值

名称 类型 描述
list list

正向传播的输出列表。

get_arch_iterator(dataset_api=None)

返回搜索空间中所有可能架构的迭代器。该迭代器是图中每条边操作数量的乘积。

参数

名称 类型 描述 默认
dataset_api 可选

数据集 API。默认为 None。

None

返回值

名称 类型 描述
Iterator Iterator

所有可能架构的迭代器。

get_hash()

获取架构的哈希表示。哈希是操作索引的元组。

返回值

名称 类型 描述
tuple tuple

架构的哈希。

get_loss_fn()

返回用于此架构的损失函数。

返回值

名称 类型 描述
Callable Callable

可用作损失函数的可调用对象(交叉熵损失函数)。

get_nbhd(dataset_api=None)

返回架构的所有邻居。

参数

名称 类型 描述 默认
dataset_api dict

包含 nasbench201 数据的 API。默认为 None。

None

返回值

名称 类型 描述
list list

邻居模型的列表。

get_op_indices()

获取架构的操作索引。如果尚未定义,它将把 naslib 对象转换为操作索引并保存它们。

返回值

名称 类型 描述
list list

架构的操作索引。

get_type()

返回搜索空间的类型。

返回值

名称 类型 描述
str str

搜索空间的类型,在此例中为 "nasbench201"。

mutate(parent, dataset_api=None)

从父操作索引中变异一个操作,并将其设置为当前对象的操作索引。

参数

名称 类型 描述 默认
parent Graph

用于变异的父 Graph 对象。

必需
dataset_api dict

包含 nasbench201 数据的 API。默认为 None。

None

query(metric, dataset, path=None, epoch=-1, full_lc=False, dataset_api=None)

根据指定的指标和数据集从 nasbench201 数据库查询结果。

参数

名称 类型 描述 默认
metric Metric

要查询的性能指标。

必需
dataset str

要查询的数据集。

必需
path str

nasbench201 数据库的路径。默认为 None。

None
epoch int

要查询的训练 epoch。默认为 -1,表示最后一个 epoch。

-1
full_lc bool

如果为 True,返回完整的学习曲线。默认为 False。

False
dataset_api dict

包含 nasbench201 数据的 API。默认为 None。

None

抛出

类型 描述
NotImplementedError

如果 metric 是 Metric.ALL 或未提供 dataset_api

返回值

名称 类型 描述
float float

查询到的结果。

sample_random_architecture(dataset_api=None, load_labeled=False)

采样一个随机架构并将其设置为当前架构。

参数

名称 类型 描述 默认
dataset_api dict

包含 nasbench201 数据的 API。默认为 None。

None
load_labeled bool

如果为 True,则改为采样一个随机已标注架构。默认为 False。

False

sample_random_labeled_architecture()

采样一个随机已标注架构并将其设置为当前架构。

set_op_indices(op_indices)

设置当前架构的操作索引。如果模型应被实例化,它将把操作索引转换为 naslib 对象。

参数

名称 类型 描述 默认
op_indices list

要设置的操作索引列表。

必需

set_spec(op_indices, dataset_api=None)

设置架构的规格。

参数

名称 类型 描述 默认
op_indices list

要设置的操作索引列表。

必需
dataset_api 可选

数据集 API。默认为 None。

None