跳到内容

NAS-Bench-101

基础类:Graph

代表 NAS-Bench-101 的搜索空间,这是一个包含神经网络架构及其相关性能的数据集。

此类继承自 Graph 类,并提供了处理架构规范(表示)、将其转换为不同形式、查询性能指标以及采样架构的方法。

参数

名称 类型 描述 默认值
n_classes int

分类任务的类别数。默认值为 10。

10

属性

名称 类型 描述
QUERYABLE bool

指示此搜索空间是否可查询的标志。对于 NAS-Bench-101,始终为 True。

num_classes int

分类任务的类别数。

space_name str

搜索空间的名称。

spec dict or None

当前架构的字典表示。默认值为 None。

labeled_archs list

用于采样的已标记架构列表。

instantiate_model bool

如果为 True,则在设置新规范时会实例化一个模型。

sample_without_replacement bool

如果为 True,则一旦采样,该架构将从可用架构列表中移除。

convert_to_cell(matrix, ops)

将给定的矩阵和操作转换为 NAS-Bench-101 单元,表示为一个字典。

该方法通过始终返回 7x7 矩阵来确保邻接矩阵与 NAS-Bench-101 API 的兼容性。如果输入矩阵小于 7x7,方法将相应地添加空白行/列。

参数

名称 类型 描述 默认值
matrix np.ndarray

单元的邻接矩阵。

必需
ops list

单元中的操作列表。

必需

返回值

名称 类型 描述
dict dict

NAS-Bench-101 单元的字典表示。包含键 'matrix' 和 'ops'。

encode(encoding_type=EncodingType.ADJACENCY_ONE_HOT)

使用给定的编码类型对当前架构进行编码。

参数

名称 类型 描述 默认值
encoding_type EncodingType

要使用的编码类型。默认为 ADJACENCY_ONE_HOT。

EncodingType.ADJACENCY_ONE_HOT

返回值

类型 描述
Union[List, np.ndarray, dict]

编码后的架构。

forward_before_global_avg_pool(x)

应用架构的前向传播,直到全局平均池化层。保存并返回中间输出。

参数

名称 类型 描述 默认值
x torch.Tensor

输入张量。

必需

返回值

名称 类型 描述
list list

前向传播的中间输出。

get_arch_iterator(dataset_api)

获取 NAS-Bench-101 数据集中所有架构的迭代器。

参数

名称 类型 描述 默认值
dataset_api dict

NAS-Bench-101 数据集的 API。

必需

返回值

名称 类型 描述
Iterator Iterator

NAS-Bench-101 数据集中所有架构的迭代器。

get_hash()

检索当前架构的哈希值。

返回值

名称 类型 描述
tuple tuple

当前架构的哈希值。

get_loss_fn()

返回优化期间使用的损失函数。

返回值

名称 类型 描述
Callable Callable

PyTorch 框架中的交叉熵损失函数。

get_nbhd(dataset_api)

检索当前架构的所有有效邻居。该方法同时考虑操作和边缘邻居。

参数

名称 类型 描述 默认值
dataset_api dict

NAS-Bench-101 数据集的 API。

必需

返回值

名称 类型 描述
list list

所有有效邻居架构的列表。

get_spec()

返回当前架构规范(表示)。

返回值

名称 类型 描述
dict dict

当前架构的规范。

get_type()

返回搜索空间的类型,在此情况下是 'nasbench101'。

返回值

名称 类型 描述
str str

搜索空间的类型。

mutate(parent, dataset_api, edits=1)

通过以一定概率翻转边缘和改变操作来变异给定的父架构。结果架构被设置为当前规范。

参数

名称 类型 描述 默认值
parent Graph

要进行变异的父图。

必需
dataset_api dict

NAS-Bench-101 数据集的 API。

必需
edits int

要应用的变异次数。默认值为 1。

1

代码灵感来自 https://github.com/google-research/nasbench

query(metric, dataset='cifar10', path=None, epoch=-1, full_lc=False, dataset_api=None)

从 NAS-Bench-101 数据集查询当前架构的性能指标。

参数

名称 类型 描述 默认值
metric Metric

要查询的性能指标。

必需
dataset str

要查询指标的数据集。目前仅支持 "cifar10"。默认值为 "cifar10"。

'cifar10'
path str

NAS-Bench-101 数据集的路径。

None
epoch int

要查询指标的 epoch。如果为 -1,则返回所有可用 epoch 的指标。默认值为 -1。

-1
full_lc bool

如果为 True,则返回完整的学习曲线。默认值为 False。

False
dataset_api dict

NAS-Bench-101 数据集的 API。

None

返回值

类型 描述
Union[list, float]

list 或 float:从 NAS-Bench-101 数据集查询到的指标结果。

引发

类型 描述
AssertionError

如果数据集未知,或 epoch 不在 NAS-Bench-101 中可用 epoch 之列,或如果架构的规范为 None。

NotImplementedError

如果未提供 metric 或 dataset_api。

sample_random_architecture(dataset_api, load_labeled=False)

采样一个随机架构,并相应地更新 naslib 对象中的边缘。

如果 load_labeled 为 True,则改为调用 sample_random_labeled_architecture() 方法。

参数

名称 类型 描述 默认值
dataset_api dict

NAS-Bench-101 数据集的 API。

必需
load_labeled bool

指示是否加载已标记的架构。默认值为 False。

False

sample_random_labeled_architecture()

从 NAS-Bench-101 数据集中可用架构列表中采样一个随机的已标记架构。

架构采样后,如果 sample_without_replacement 属性为 True,则将其从池中移除。然后将采样到的架构设置为当前规范。

引发

类型 描述
AssertionError

如果未提供已标记的架构。

set_spec(spec, dataset_api=None)

使用给定的表示设置架构的规范。

规范可以是字符串(哈希)、包含矩阵和操作的字典或元组(NASLib 表示)。

参数

名称 类型 描述 默认值
spec str or dict or tuple

要为架构设置的规范。

必需
dataset_api dict

NAS-Bench-101 数据集的 API。

None

引发

类型 描述
AssertionError

如果规范的类型不是 str、dict 或 tuple。