跳到内容

Transbench101

基类: Graph

Transbench 101 搜索空间的实现。

此类创建表示为图的神经网络架构。它专为 Transbench 101 基准设计,提供搜索空间和任务特定配置的接口。

属性

名称 类型 描述
OPTIMIZER_SCOPE List[str]

优化器要考虑的阶段列表。

QUERYABLE bool

指示类是否可查询的布尔值。

__init__(dataset='jigsaw', use_small_model=True, create_graph=False, n_classes=10, in_channels=3)

初始化 TransBench101SearchSpaceMicro 类。

参数

名称 类型 描述 默认值
dataset str

数据集名称。默认为 'jigsaw'。

'jigsaw'
use_small_model bool

如果为 True,使用小型模型。默认为 True。

True
create_graph bool

初始化时是否创建图。默认为 False。

False
n_classes int

类别数。默认为 10。

10
in_channels int

输入通道数。默认为 3。

3

encode(encoding_type='adjacency_one_hot')

编码当前架构。

此函数根据指定的编码类型对当前架构进行编码。

参数

名称 类型 描述 默认值
encoding_type str

要执行的编码类型。默认为 "adjacency_one_hot"。

'adjacency_one_hot'

返回

名称 类型 描述
多种类型

架构的编码表示。

encode_spec(encoding_type='adjacency_one_hot')

根据指定的编码类型编码架构。

此函数专门根据提供的编码类型编码 'TransBench101SearchSpaceMicro' 的架构。

参数

名称 类型 描述 默认值
encoding_type str

要执行的编码类型。默认为 "adjacency_one_hot"。

'adjacency_one_hot'

返回

名称 类型 描述
多种类型

架构的编码表示。

抛出

类型 描述
NotImplementedError

如果当前搜索空间不支持此编码类型。

forward_before_global_avg_pool(x)

在全局平均池化操作之前的正向传播方法。

此函数根据数据集和图创建状态确定要调用的正向方法,并相应地返回输出。

参数

名称 类型 描述 默认值
x torch.Tensor

要正向传播的输入张量。

必需

返回

类型 描述

torch.Tensor: 全局平均池化操作之前的输出张量。

抛出

类型 描述
异常

如果当前数据集和 NASLib 图设置未实现此方法。

get_arch_iterator(dataset_api=None)

获取所有可能架构的迭代器。

此函数返回一个迭代器,用于生成操作索引的所有可能组合。

参数

名称 类型 描述 默认值
dataset_api dict

用于查询数据集相关信息的 API。默认为 None。

None

返回

名称 类型 描述
迭代器

所有可能架构上的迭代器。

get_hash()

获取当前架构的可哈希表示。

此函数返回一个操作索引的元组,用作架构的唯一标识符。

返回

名称 类型 描述
元组

操作索引的元组。

get_loss_fn()

根据数据集获取适当的损失函数。

此函数根据数据集返回应用于训练的损失函数。

返回

名称 类型 描述
函数

适用于数据集的损失函数。

get_nbhd(dataset_api=None)

获取当前架构的所有邻居。

此函数通过改变每条边上的单个操作索引来返回所有可能的邻居架构列表。

参数

名称 类型 描述 默认值
dataset_api

用于查询数据集相关信息的 API。默认为 None。

None

返回

名称 类型 描述
列表

邻居架构列表。

get_op_indices()

获取图的操作索引。

返回

名称 类型 描述
列表

操作索引列表。

抛出

类型 描述
NotImplementedError

如果 op_indices 和模型都未设置。

get_type()

获取搜索空间的类型。

此函数返回表示搜索空间类型的字符串。

返回

名称 类型 描述
str

搜索空间的类型,在此例中为 'transbench101_micro'。

mutate(parent, dataset_api=None)

从父架构变异单个操作索引。

此函数通过随机选择一条边并更改其操作索引来对父架构执行变异操作。然后将变异后的架构更新到 NASlib 对象中。

参数

名称 类型 描述 默认值
parent

父架构对象。

必需
dataset_api

用于查询数据集相关信息的 API。默认为 None。

None

返回

名称 类型 描述
None

就地更新对象。

query(metric=None, dataset=None, path=None, epoch=-1, full_lc=False, dataset_api=None)

根据指定的指标、数据集和其他参数查询 transbench 101 的结果。

参数

名称 类型 描述 默认值
metric 指标

要查询的指标。

None
dataset str

要查询的数据集。

None
path str

从中加载结果的路径。

None
epoch int

要查询的 epoch 数。默认为 -1。

-1
full_lc bool

用于检索完整学习曲线的标志。默认为 False。

False
dataset_api dict

用于查询的数据集 API。

None

返回

名称 类型 描述
Any

基于指标和数据集的查询结果。

抛出

类型 描述
NotImplementedError

如果查询 Metric.ALL 或未传入数据集 API。

sample_random_architecture(dataset_api=None, load_labeled=False)

采样一个随机有效架构。

此函数采样一个随机架构,并在相应更新对象之前确保其有效性。

参数

名称 类型 描述 默认值
dataset_api dict

用于查询数据集相关信息的 API。默认为 None。

None
load_labeled bool

是否加载标注架构。默认为 False。

False

返回

名称 类型 描述
None

就地更新对象。

sample_random_labeled_architecture()

采样一个随机标注架构。

此函数从标注架构列表中采样一个随机架构,并相应更新对象。

返回

名称 类型 描述
None

就地更新对象。

set_op_indices(op_indices)

设置操作索引并相应更新架构。

此函数根据给定的 op_indices 更新 NASlib 对象中的操作索引和边。

参数

名称 类型 描述 默认值
op_indices 列表

要设置的操作索引列表。

必需

返回

名称 类型 描述
None

就地更新对象。

set_spec(op_indices, dataset_api=None)

统一不同搜索空间的操作索引设置器。

此函数仅调用 set_op_indices 来设置操作索引。用于保持不同搜索空间代码的一致性。

参数

名称 类型 描述 默认值
op_indices 列表

要设置的操作索引列表。

必需
dataset_api

用于查询数据集相关信息的 API。默认为 None。

None

返回

名称 类型 描述
None

就地更新对象。

基类: Graph

TransBench 101 搜索空间的实现,提供了与 TransBench 101 表格基准的接口。

属性

名称 类型 描述
OPTIMIZER_SCOPE List[str]

定义优化范围。

QUERYABLE bool

定义类对象是否可查询。

__init__(dataset='jigsaw', *arg, **kwargs)

初始化 TransBench101SearchSpaceMacro 类。

参数

名称 类型 描述 默认值
dataset str

搜索空间中使用的数据集。默认为 'jigsaw'。

'jigsaw'
*arg

可变长度参数。

()
**kwargs

任意关键字参数。

{}

encode(encoding_type=EncodingType.ADJACENCY_ONE_HOT)

根据指定的编码类型编码架构。

参数

名称 类型 描述 默认值
encoding_type EncodingType

要使用的编码类型。默认为 EncodingType.ADJACENCY_ONE_HOT。

EncodingType.ADJACENCY_ONE_HOT

返回

类型 描述

编码后的架构。

forward_before_global_avg_pool(x)

执行正向传播直到全局平均池化之前的层或最后一个卷积层之前的层。

参数

名称 类型 描述 默认值
x torch.Tensor

输入张量。

必需

返回

类型 描述

torch.Tensor: 指定层之前的输出张量。

get_hash()

获取操作索引的元组哈希值。

返回

名称 类型 描述
Tuple

包含操作索引的元组。

get_loss_fn()

根据数据集属性获取适当的损失函数。

返回

类型 描述

一个 PyTorch 损失函数。

get_nbhd(dataset_api=None)

获取架构的所有邻居。

参数

名称 类型 描述 默认值
dataset_api

数据集的 API。

None

返回

类型 描述

邻居架构列表,每个都包裹在 PyTorch 模块中。

get_op_indices()

获取操作索引。

返回

名称 类型 描述
Any

操作索引。

抛出

类型 描述
ValueError

如果未设置 op_indices。

get_type()

获取搜索空间的类型。

返回

名称 类型 描述
str

搜索空间的类型,在此例中为 'transbench101_macro'。

mutate(parent, dataset_api=None)

从父操作索引中变异一个操作并更新 naslib 对象。

参数

名称 类型 描述 默认值
parent TransBench101SearchSpaceMacro

从中变异的父架构。

必需
dataset_api

数据集的 API。

None

返回

名称 类型 描述
None

就地用变异后的操作更新对象。

query(metric=None, dataset=None, path=None, epoch=-1, full_lc=False, dataset_api=None)

查询 TransBench 101 的结果。

参数

名称 类型 描述 默认值
metric 指标

要查询的指标类型。

None
dataset str

要查询的数据集。

None
path str

查询路径。

None
epoch int

Epoch 数。默认为 -1。

-1
full_lc bool

完整学习曲线的标志。默认为 False。

False
dataset_api dict

数据集的 API。必须提供。

None

返回

名称 类型 描述
Any

基于指标和其他可选参数的查询结果。

抛出

类型 描述
NotImplementedError

对于不受支持的指标或缺少数据集 API。

sample_random_architecture(dataset_api=None, load_labeled=False)

采样一个随机架构并相应更新 naslib 对象中的边。

参数

名称 类型 描述 默认值
dataset_api

数据集的 API。

None
load_labeled

是否加载标注架构。默认为 False。

False

返回

类型 描述

如果 load_labeled 为 True,则为采样架构,否则就地更新对象。

sample_random_labeled_architecture()

采样一个随机标注架构。

返回

类型 描述

采样架构。

抛出

类型 描述
AssertionError

如果未提供标注架构。

set_op_indices(op_indices)

设置操作索引并相应更新 naslib 对象中的边。

参数

名称 类型 描述 默认值
op_indices

新的操作索引。

必需

set_spec(op_indices, dataset_api=None)

设置搜索空间的规范。

参数

名称 类型 描述 默认值
op_indices

新的操作索引。

必需
dataset_api

数据集的 API。

None