NAS-Bench-101
基础类:Graph
代表 NAS-Bench-101 的搜索空间,这是一个包含神经网络架构及其相关性能的数据集。
此类继承自 Graph 类,并提供了处理架构规范(表示)、将其转换为不同形式、查询性能指标以及采样架构的方法。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
n_classes |
int
|
分类任务的类别数。默认值为 10。 |
10
|
属性
名称 | 类型 | 描述 |
---|---|---|
QUERYABLE |
bool
|
指示此搜索空间是否可查询的标志。对于 NAS-Bench-101,始终为 True。 |
num_classes |
int
|
分类任务的类别数。 |
space_name |
str
|
搜索空间的名称。 |
spec |
dict or None
|
当前架构的字典表示。默认值为 None。 |
labeled_archs |
list
|
用于采样的已标记架构列表。 |
instantiate_model |
bool
|
如果为 True,则在设置新规范时会实例化一个模型。 |
sample_without_replacement |
bool
|
如果为 True,则一旦采样,该架构将从可用架构列表中移除。 |
convert_to_cell(matrix, ops)
将给定的矩阵和操作转换为 NAS-Bench-101 单元,表示为一个字典。
该方法通过始终返回 7x7 矩阵来确保邻接矩阵与 NAS-Bench-101 API 的兼容性。如果输入矩阵小于 7x7,方法将相应地添加空白行/列。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
matrix |
np.ndarray
|
单元的邻接矩阵。 |
必需 |
ops |
list
|
单元中的操作列表。 |
必需 |
返回值
名称 | 类型 | 描述 |
---|---|---|
dict |
dict
|
NAS-Bench-101 单元的字典表示。包含键 'matrix' 和 'ops'。 |
encode(encoding_type=EncodingType.ADJACENCY_ONE_HOT)
使用给定的编码类型对当前架构进行编码。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
encoding_type |
EncodingType
|
要使用的编码类型。默认为 ADJACENCY_ONE_HOT。 |
EncodingType.ADJACENCY_ONE_HOT
|
返回值
类型 | 描述 |
---|---|
Union[List, np.ndarray, dict]
|
编码后的架构。 |
forward_before_global_avg_pool(x)
应用架构的前向传播,直到全局平均池化层。保存并返回中间输出。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
x |
torch.Tensor
|
输入张量。 |
必需 |
返回值
名称 | 类型 | 描述 |
---|---|---|
list |
list
|
前向传播的中间输出。 |
get_arch_iterator(dataset_api)
获取 NAS-Bench-101 数据集中所有架构的迭代器。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
dataset_api |
dict
|
NAS-Bench-101 数据集的 API。 |
必需 |
返回值
名称 | 类型 | 描述 |
---|---|---|
Iterator |
Iterator
|
NAS-Bench-101 数据集中所有架构的迭代器。 |
get_hash()
检索当前架构的哈希值。
返回值
名称 | 类型 | 描述 |
---|---|---|
tuple |
tuple
|
当前架构的哈希值。 |
get_loss_fn()
返回优化期间使用的损失函数。
返回值
名称 | 类型 | 描述 |
---|---|---|
Callable |
Callable
|
PyTorch 框架中的交叉熵损失函数。 |
get_nbhd(dataset_api)
检索当前架构的所有有效邻居。该方法同时考虑操作和边缘邻居。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
dataset_api |
dict
|
NAS-Bench-101 数据集的 API。 |
必需 |
返回值
名称 | 类型 | 描述 |
---|---|---|
list |
list
|
所有有效邻居架构的列表。 |
get_spec()
返回当前架构规范(表示)。
返回值
名称 | 类型 | 描述 |
---|---|---|
dict |
dict
|
当前架构的规范。 |
get_type()
返回搜索空间的类型,在此情况下是 'nasbench101'。
返回值
名称 | 类型 | 描述 |
---|---|---|
str |
str
|
搜索空间的类型。 |
mutate(parent, dataset_api, edits=1)
通过以一定概率翻转边缘和改变操作来变异给定的父架构。结果架构被设置为当前规范。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
parent |
Graph
|
要进行变异的父图。 |
必需 |
dataset_api |
dict
|
NAS-Bench-101 数据集的 API。 |
必需 |
edits |
int
|
要应用的变异次数。默认值为 1。 |
1
|
代码灵感来自 https://github.com/google-research/nasbench
query(metric, dataset='cifar10', path=None, epoch=-1, full_lc=False, dataset_api=None)
从 NAS-Bench-101 数据集查询当前架构的性能指标。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
metric |
Metric
|
要查询的性能指标。 |
必需 |
dataset |
str
|
要查询指标的数据集。目前仅支持 "cifar10"。默认值为 "cifar10"。 |
'cifar10'
|
path |
str
|
NAS-Bench-101 数据集的路径。 |
None
|
epoch |
int
|
要查询指标的 epoch。如果为 -1,则返回所有可用 epoch 的指标。默认值为 -1。 |
-1
|
full_lc |
bool
|
如果为 True,则返回完整的学习曲线。默认值为 False。 |
False
|
dataset_api |
dict
|
NAS-Bench-101 数据集的 API。 |
None
|
返回值
类型 | 描述 |
---|---|
Union[list, float]
|
list 或 float:从 NAS-Bench-101 数据集查询到的指标结果。 |
引发
类型 | 描述 |
---|---|
AssertionError
|
如果数据集未知,或 epoch 不在 NAS-Bench-101 中可用 epoch 之列,或如果架构的规范为 None。 |
NotImplementedError
|
如果未提供 metric 或 dataset_api。 |
sample_random_architecture(dataset_api, load_labeled=False)
采样一个随机架构,并相应地更新 naslib 对象中的边缘。
如果 load_labeled
为 True,则改为调用 sample_random_labeled_architecture()
方法。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
dataset_api |
dict
|
NAS-Bench-101 数据集的 API。 |
必需 |
load_labeled |
bool
|
指示是否加载已标记的架构。默认值为 False。 |
False
|
sample_random_labeled_architecture()
从 NAS-Bench-101 数据集中可用架构列表中采样一个随机的已标记架构。
架构采样后,如果 sample_without_replacement 属性为 True,则将其从池中移除。然后将采样到的架构设置为当前规范。
引发
类型 | 描述 |
---|---|
AssertionError
|
如果未提供已标记的架构。 |
set_spec(spec, dataset_api=None)
使用给定的表示设置架构的规范。
规范可以是字符串(哈希)、包含矩阵和操作的字典或元组(NASLib 表示)。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
spec |
str or dict or tuple
|
要为架构设置的规范。 |
必需 |
dataset_api |
dict
|
NAS-Bench-101 数据集的 API。 |
None
|
引发
类型 | 描述 |
---|---|
AssertionError
|
如果规范的类型不是 str、dict 或 tuple。 |