跳到内容

NAS-Bench-301

基类:Graph

此类表示刘等人在以下论文中概述的 CIFAR-10 搜索空间:

Liu et al., 2019. "DARTS: Differentiable Architecture Search"

搜索空间包括一个预定义的、未优化的宏图,以及两种可学习单元:正常单元和归约单元。每条边包含 8 种基本操作。

属性

名称 类型 描述
OPTIMIZER_SCOPE List[str]

优化过程中同一单元不同实例的目标。单元被分为正常/归约单元类型和阶段。这种划分对于在每个阶段设置正确的通道至关重要。架构优化器应平等对待所有实例。

QUERYABLE bool

指示搜索空间是否可查询的标志。

sample_without_replacement = False instance-attribute

使用 init 中指定的参数构建搜索空间。

__init__(n_classes=10, in_channels=3, auxiliary=True)

构造 DARTS 搜索空间的新实例。

参数

名称 类型 描述 默认值
n_classes int

考虑的类别数量。默认为 10。

10
in_channels int

输入通道的数量。默认为 3。

3
auxiliary bool

启用或禁用辅助输出的标志。默认为 True。

True

请注意,由于 networkx 的实现,init 方法不接受参数。要更改类别数量,应在类初始化之前设置静态属性 NUM_CLASSES。CIFAR-10 的默认值为 10。

auxiliary_logits()

从模型图中获取辅助对数几率(logits)。

返回

类型 描述
torch.Tensor

torch.Tensor:从模型图中获取的辅助对数几率(logits)。

encode(encoding_type=EncodingType.ADJACENCY_ONE_HOT)

将架构图编码为指定的编码类型。

参数

名称 类型 描述 默认值
encoding_type EncodingType

要使用的编码类型。默认为 EncodingType.ADJACENCY_ONE_HOT。

EncodingType.ADJACENCY_ONE_HOT

返回

名称 类型 描述
Any

架构的编码表示。

forward_before_global_avg_pool(x)

运行模型的前向传播直到全局平均池化层。

参数

名称 类型 描述 默认值
x torch.Tensor

输入张量。

必需

返回

名称 类型 描述
list list

各层输出张量的列表。

get_arch_iterator(dataset_api)

获取 nasbench301 数据中架构的迭代器。

参数

名称 类型 描述 默认值
dataset_api dict

数据集 API。

必需

返回

名称 类型 描述
迭代器 迭代器

架构的迭代器。

get_compact()

获取架构的紧凑表示。如果模型已实例化且紧凑表示不存在,则将模型转换为紧凑形式。

返回

名称 类型 描述
tuple tuple

架构的紧凑形式。

get_configspace(path_to_configspace_obj=os.path.join(get_project_root(), 'search_spaces/nasbench301/configspace.json')) staticmethod

返回搜索空间的配置空间对象。

参数

名称 类型 描述 默认值
path_to_configspace_obj str

ConfigSpace JSON 编码的路径。

os.path.join(get_project_root(), 'search_spaces/nasbench301/configspace.json')

返回

类型 描述

ConfigSpace.ConfigutationSpace:一个 ConfigSpace 对象。

get_hash()

获取架构的紧凑哈希值。

返回

名称 类型 描述
tuple tuple

架构的哈希值。

get_loss_fn()

获取用于训练架构的损失函数。

返回

名称 类型 描述
Callable Callable

损失函数。

get_nbhd(dataset_api=None)

获取当前架构的所有邻居。

参数

名称 类型 描述 默认值
dataset_api dict

数据集 API。

None

返回

名称 类型 描述
list list

当前架构所有邻居的列表。

get_type()

获取架构的类型。

返回

名称 类型 描述
str str

架构的类型。

load_labeled_architecture(dataset_api=None)

从 NasBench301 训练数据中加载一个随机架构,并更新图对象以匹配该架构。此方法应由尚未离散化的新的 NasBench301SearchSpace() 对象调用。

参数

名称 类型 描述 默认值
dataset_api dict

包含架构信息的数据集 API。

None

mutate(parent, mutation_rate=1, dataset_api=None)

通过改变父架构中的一个操作来变异架构,然后更新 naslib 对象和 op_indices。

参数

名称 类型 描述 默认值
parent Graph

父架构图。

必需
mutation_rate int

变异率。默认为 1。

1
dataset_api dict

数据集 API。

None

prepare_discretization()

准备图以进行离散化。

在此搜索空间中,一个节点最多可以有两条入边。此方法确保满足此条件,为进一步离散化准备图。

prepare_evaluation()

此方法准备模型用于评估。在 DARTS 中,评估模型在 stem 之后有 32 个通道,并且每个阶段包含 3 个正常单元。

query(metric=None, dataset=None, path=None, epoch=-1, full_lc=False, dataset_api=None)

查询 NasBench301 的结果。如果架构是从 NasBench301 训练数据中加载的,则可以查询特定 epoch 的训练损失或验证准确率。否则,只能使用 NasBench301 查询 epoch 100 的验证准确率。

参数

名称 类型 描述 默认值
metric Metric

要查询的所需指标。

None
dataset str

要使用的数据集。目前仅支持 'cifar10' 数据集。

None
path str

保存的模型路径。

None
epoch int

要查询的特定 epoch。默认为 -1。

-1
full_lc bool

指示是否应返回完整学习曲线的标志。默认为 False。

False
dataset_api dict

用于查询模型的数据集 API。

None

返回

类型 描述
Union[float, dict]

Union[float, dict]:查询结果。

抛出异常

类型 描述
NotImplementedError

如果 dataset_api 为 None。

AssertionError

如果数据集不是 'cifar10' 或 None。

sample_random_architecture(dataset_api=None, load_labeled=False)

采样一个随机架构并相应地更新 naslib 对象中的边。

参数

名称 类型 描述 默认值
dataset_api dict

数据集 API。如果 load_labeled 为 True,则必需。

None
load_labeled bool

是否从训练数据中加载架构。

False

sample_random_labeled_architecture()

从标记的架构中采样一个随机架构。

set_compact(compact)

设置架构的紧凑表示。如果模型已实例化且紧凑形式不存在,则将紧凑表示转换为模型。

参数

名称 类型 描述 默认值
compact tuple

架构的紧凑形式。

必需

set_spec(compact, dataset_api=None)

设置架构规范,使其不可变。

参数

名称 类型 描述 默认值
compact tuple

架构的紧凑形式。

必需
dataset_api dict

数据集 API。

None