NAS-Bench-301
基类:Graph
此类表示刘等人在以下论文中概述的 CIFAR-10 搜索空间:
Liu et al., 2019. "DARTS: Differentiable Architecture Search"
搜索空间包括一个预定义的、未优化的宏图,以及两种可学习单元:正常单元和归约单元。每条边包含 8 种基本操作。
属性
名称 | 类型 | 描述 |
---|---|---|
OPTIMIZER_SCOPE |
List[str]
|
优化过程中同一单元不同实例的目标。单元被分为正常/归约单元类型和阶段。这种划分对于在每个阶段设置正确的通道至关重要。架构优化器应平等对待所有实例。 |
QUERYABLE |
bool
|
指示搜索空间是否可查询的标志。 |
sample_without_replacement = False
instance-attribute
使用 init 中指定的参数构建搜索空间。
__init__(n_classes=10, in_channels=3, auxiliary=True)
构造 DARTS 搜索空间的新实例。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
n_classes |
int
|
考虑的类别数量。默认为 10。 |
10
|
in_channels |
int
|
输入通道的数量。默认为 3。 |
3
|
auxiliary |
bool
|
启用或禁用辅助输出的标志。默认为 True。 |
True
|
请注意,由于 networkx 的实现,init 方法不接受参数。要更改类别数量,应在类初始化之前设置静态属性 NUM_CLASSES
。CIFAR-10 的默认值为 10。
auxiliary_logits()
从模型图中获取辅助对数几率(logits)。
返回
类型 | 描述 |
---|---|
torch.Tensor
|
torch.Tensor:从模型图中获取的辅助对数几率(logits)。 |
encode(encoding_type=EncodingType.ADJACENCY_ONE_HOT)
将架构图编码为指定的编码类型。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
encoding_type |
EncodingType
|
要使用的编码类型。默认为 EncodingType.ADJACENCY_ONE_HOT。 |
EncodingType.ADJACENCY_ONE_HOT
|
返回
名称 | 类型 | 描述 |
---|---|---|
Any |
架构的编码表示。 |
forward_before_global_avg_pool(x)
运行模型的前向传播直到全局平均池化层。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
x |
torch.Tensor
|
输入张量。 |
必需 |
返回
名称 | 类型 | 描述 |
---|---|---|
list |
list
|
各层输出张量的列表。 |
get_arch_iterator(dataset_api)
获取 nasbench301 数据中架构的迭代器。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
dataset_api |
dict
|
数据集 API。 |
必需 |
返回
名称 | 类型 | 描述 |
---|---|---|
迭代器 |
迭代器
|
架构的迭代器。 |
get_compact()
获取架构的紧凑表示。如果模型已实例化且紧凑表示不存在,则将模型转换为紧凑形式。
返回
名称 | 类型 | 描述 |
---|---|---|
tuple |
tuple
|
架构的紧凑形式。 |
get_configspace(path_to_configspace_obj=os.path.join(get_project_root(), 'search_spaces/nasbench301/configspace.json'))
staticmethod
返回搜索空间的配置空间对象。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
path_to_configspace_obj |
str
|
ConfigSpace JSON 编码的路径。 |
os.path.join(get_project_root(), 'search_spaces/nasbench301/configspace.json')
|
返回
类型 | 描述 |
---|---|
ConfigSpace.ConfigutationSpace:一个 ConfigSpace 对象。 |
get_hash()
获取架构的紧凑哈希值。
返回
名称 | 类型 | 描述 |
---|---|---|
tuple |
tuple
|
架构的哈希值。 |
get_loss_fn()
获取用于训练架构的损失函数。
返回
名称 | 类型 | 描述 |
---|---|---|
Callable |
Callable
|
损失函数。 |
get_nbhd(dataset_api=None)
获取当前架构的所有邻居。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
dataset_api |
dict
|
数据集 API。 |
None
|
返回
名称 | 类型 | 描述 |
---|---|---|
list |
list
|
当前架构所有邻居的列表。 |
get_type()
获取架构的类型。
返回
名称 | 类型 | 描述 |
---|---|---|
str |
str
|
架构的类型。 |
load_labeled_architecture(dataset_api=None)
从 NasBench301 训练数据中加载一个随机架构,并更新图对象以匹配该架构。此方法应由尚未离散化的新的 NasBench301SearchSpace() 对象调用。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
dataset_api |
dict
|
包含架构信息的数据集 API。 |
None
|
mutate(parent, mutation_rate=1, dataset_api=None)
通过改变父架构中的一个操作来变异架构,然后更新 naslib 对象和 op_indices。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
parent |
Graph
|
父架构图。 |
必需 |
mutation_rate |
int
|
变异率。默认为 1。 |
1
|
dataset_api |
dict
|
数据集 API。 |
None
|
prepare_discretization()
准备图以进行离散化。
在此搜索空间中,一个节点最多可以有两条入边。此方法确保满足此条件,为进一步离散化准备图。
prepare_evaluation()
此方法准备模型用于评估。在 DARTS 中,评估模型在 stem 之后有 32 个通道,并且每个阶段包含 3 个正常单元。
query(metric=None, dataset=None, path=None, epoch=-1, full_lc=False, dataset_api=None)
查询 NasBench301 的结果。如果架构是从 NasBench301 训练数据中加载的,则可以查询特定 epoch 的训练损失或验证准确率。否则,只能使用 NasBench301 查询 epoch 100 的验证准确率。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
metric |
Metric
|
要查询的所需指标。 |
None
|
dataset |
str
|
要使用的数据集。目前仅支持 'cifar10' 数据集。 |
None
|
path |
str
|
保存的模型路径。 |
None
|
epoch |
int
|
要查询的特定 epoch。默认为 -1。 |
-1
|
full_lc |
bool
|
指示是否应返回完整学习曲线的标志。默认为 False。 |
False
|
dataset_api |
dict
|
用于查询模型的数据集 API。 |
None
|
返回
类型 | 描述 |
---|---|
Union[float, dict]
|
Union[float, dict]:查询结果。 |
抛出异常
类型 | 描述 |
---|---|
NotImplementedError
|
如果 dataset_api 为 None。 |
AssertionError
|
如果数据集不是 'cifar10' 或 None。 |
sample_random_architecture(dataset_api=None, load_labeled=False)
采样一个随机架构并相应地更新 naslib 对象中的边。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
dataset_api |
dict
|
数据集 API。如果 load_labeled 为 True,则必需。 |
None
|
load_labeled |
bool
|
是否从训练数据中加载架构。 |
False
|
sample_random_labeled_architecture()
从标记的架构中采样一个随机架构。
set_compact(compact)
设置架构的紧凑表示。如果模型已实例化且紧凑形式不存在,则将紧凑表示转换为模型。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
compact |
tuple
|
架构的紧凑形式。 |
必需 |
set_spec(compact, dataset_api=None)
设置架构规范,使其不可变。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
compact |
tuple
|
架构的紧凑形式。 |
必需 |
dataset_api |
dict
|
数据集 API。 |
None
|