Transbench101
基类: Graph
Transbench 101 搜索空间的实现。
此类创建表示为图的神经网络架构。它专为 Transbench 101 基准设计,提供搜索空间和任务特定配置的接口。
属性
名称 | 类型 | 描述 |
---|---|---|
OPTIMIZER_SCOPE |
List[str]
|
优化器要考虑的阶段列表。 |
QUERYABLE |
bool
|
指示类是否可查询的布尔值。 |
__init__(dataset='jigsaw', use_small_model=True, create_graph=False, n_classes=10, in_channels=3)
初始化 TransBench101SearchSpaceMicro 类。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
dataset |
str
|
数据集名称。默认为 'jigsaw'。 |
'jigsaw'
|
use_small_model |
bool
|
如果为 True,使用小型模型。默认为 True。 |
True
|
create_graph |
bool
|
初始化时是否创建图。默认为 False。 |
False
|
n_classes |
int
|
类别数。默认为 10。 |
10
|
in_channels |
int
|
输入通道数。默认为 3。 |
3
|
encode(encoding_type='adjacency_one_hot')
编码当前架构。
此函数根据指定的编码类型对当前架构进行编码。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
encoding_type |
str
|
要执行的编码类型。默认为 "adjacency_one_hot"。 |
'adjacency_one_hot'
|
返回
名称 | 类型 | 描述 |
---|---|---|
多种类型 |
架构的编码表示。 |
encode_spec(encoding_type='adjacency_one_hot')
根据指定的编码类型编码架构。
此函数专门根据提供的编码类型编码 'TransBench101SearchSpaceMicro' 的架构。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
encoding_type |
str
|
要执行的编码类型。默认为 "adjacency_one_hot"。 |
'adjacency_one_hot'
|
返回
名称 | 类型 | 描述 |
---|---|---|
多种类型 |
架构的编码表示。 |
抛出
类型 | 描述 |
---|---|
NotImplementedError
|
如果当前搜索空间不支持此编码类型。 |
forward_before_global_avg_pool(x)
在全局平均池化操作之前的正向传播方法。
此函数根据数据集和图创建状态确定要调用的正向方法,并相应地返回输出。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
x |
torch.Tensor
|
要正向传播的输入张量。 |
必需 |
返回
类型 | 描述 |
---|---|
torch.Tensor: 全局平均池化操作之前的输出张量。 |
抛出
类型 | 描述 |
---|---|
异常
|
如果当前数据集和 NASLib 图设置未实现此方法。 |
get_arch_iterator(dataset_api=None)
获取所有可能架构的迭代器。
此函数返回一个迭代器,用于生成操作索引的所有可能组合。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
dataset_api |
dict
|
用于查询数据集相关信息的 API。默认为 None。 |
None
|
返回
名称 | 类型 | 描述 |
---|---|---|
迭代器 |
所有可能架构上的迭代器。 |
get_hash()
获取当前架构的可哈希表示。
此函数返回一个操作索引的元组,用作架构的唯一标识符。
返回
名称 | 类型 | 描述 |
---|---|---|
元组 |
操作索引的元组。 |
get_loss_fn()
根据数据集获取适当的损失函数。
此函数根据数据集返回应用于训练的损失函数。
返回
名称 | 类型 | 描述 |
---|---|---|
函数 |
适用于数据集的损失函数。 |
get_nbhd(dataset_api=None)
获取当前架构的所有邻居。
此函数通过改变每条边上的单个操作索引来返回所有可能的邻居架构列表。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
dataset_api |
用于查询数据集相关信息的 API。默认为 None。 |
None
|
返回
名称 | 类型 | 描述 |
---|---|---|
列表 |
邻居架构列表。 |
get_op_indices()
获取图的操作索引。
返回
名称 | 类型 | 描述 |
---|---|---|
列表 |
操作索引列表。 |
抛出
类型 | 描述 |
---|---|
NotImplementedError
|
如果 op_indices 和模型都未设置。 |
get_type()
获取搜索空间的类型。
此函数返回表示搜索空间类型的字符串。
返回
名称 | 类型 | 描述 |
---|---|---|
str |
搜索空间的类型,在此例中为 'transbench101_micro'。 |
mutate(parent, dataset_api=None)
从父架构变异单个操作索引。
此函数通过随机选择一条边并更改其操作索引来对父架构执行变异操作。然后将变异后的架构更新到 NASlib 对象中。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
parent |
父架构对象。 |
必需 | |
dataset_api |
用于查询数据集相关信息的 API。默认为 None。 |
None
|
返回
名称 | 类型 | 描述 |
---|---|---|
None |
就地更新对象。 |
query(metric=None, dataset=None, path=None, epoch=-1, full_lc=False, dataset_api=None)
根据指定的指标、数据集和其他参数查询 transbench 101 的结果。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
metric |
指标
|
要查询的指标。 |
None
|
dataset |
str
|
要查询的数据集。 |
None
|
path |
str
|
从中加载结果的路径。 |
None
|
epoch |
int
|
要查询的 epoch 数。默认为 -1。 |
-1
|
full_lc |
bool
|
用于检索完整学习曲线的标志。默认为 False。 |
False
|
dataset_api |
dict
|
用于查询的数据集 API。 |
None
|
返回
名称 | 类型 | 描述 |
---|---|---|
Any |
基于指标和数据集的查询结果。 |
抛出
类型 | 描述 |
---|---|
NotImplementedError
|
如果查询 Metric.ALL 或未传入数据集 API。 |
sample_random_architecture(dataset_api=None, load_labeled=False)
采样一个随机有效架构。
此函数采样一个随机架构,并在相应更新对象之前确保其有效性。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
dataset_api |
dict
|
用于查询数据集相关信息的 API。默认为 None。 |
None
|
load_labeled |
bool
|
是否加载标注架构。默认为 False。 |
False
|
返回
名称 | 类型 | 描述 |
---|---|---|
None |
就地更新对象。 |
sample_random_labeled_architecture()
采样一个随机标注架构。
此函数从标注架构列表中采样一个随机架构,并相应更新对象。
返回
名称 | 类型 | 描述 |
---|---|---|
None |
就地更新对象。 |
set_op_indices(op_indices)
设置操作索引并相应更新架构。
此函数根据给定的 op_indices 更新 NASlib 对象中的操作索引和边。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
op_indices |
列表
|
要设置的操作索引列表。 |
必需 |
返回
名称 | 类型 | 描述 |
---|---|---|
None |
就地更新对象。 |
set_spec(op_indices, dataset_api=None)
统一不同搜索空间的操作索引设置器。
此函数仅调用 set_op_indices
来设置操作索引。用于保持不同搜索空间代码的一致性。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
op_indices |
列表
|
要设置的操作索引列表。 |
必需 |
dataset_api |
用于查询数据集相关信息的 API。默认为 None。 |
None
|
返回
名称 | 类型 | 描述 |
---|---|---|
None |
就地更新对象。 |
基类: Graph
TransBench 101 搜索空间的实现,提供了与 TransBench 101 表格基准的接口。
属性
名称 | 类型 | 描述 |
---|---|---|
OPTIMIZER_SCOPE |
List[str]
|
定义优化范围。 |
QUERYABLE |
bool
|
定义类对象是否可查询。 |
__init__(dataset='jigsaw', *arg, **kwargs)
初始化 TransBench101SearchSpaceMacro 类。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
dataset |
str
|
搜索空间中使用的数据集。默认为 'jigsaw'。 |
'jigsaw'
|
*arg |
可变长度参数。 |
()
|
|
**kwargs |
任意关键字参数。 |
{}
|
encode(encoding_type=EncodingType.ADJACENCY_ONE_HOT)
根据指定的编码类型编码架构。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
encoding_type |
EncodingType
|
要使用的编码类型。默认为 EncodingType.ADJACENCY_ONE_HOT。 |
EncodingType.ADJACENCY_ONE_HOT
|
返回
类型 | 描述 |
---|---|
编码后的架构。 |
forward_before_global_avg_pool(x)
执行正向传播直到全局平均池化之前的层或最后一个卷积层之前的层。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
x |
torch.Tensor
|
输入张量。 |
必需 |
返回
类型 | 描述 |
---|---|
torch.Tensor: 指定层之前的输出张量。 |
get_hash()
获取操作索引的元组哈希值。
返回
名称 | 类型 | 描述 |
---|---|---|
Tuple |
包含操作索引的元组。 |
get_loss_fn()
根据数据集属性获取适当的损失函数。
返回
类型 | 描述 |
---|---|
一个 PyTorch 损失函数。 |
get_nbhd(dataset_api=None)
获取架构的所有邻居。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
dataset_api |
数据集的 API。 |
None
|
返回
类型 | 描述 |
---|---|
邻居架构列表,每个都包裹在 PyTorch 模块中。 |
get_op_indices()
获取操作索引。
返回
名称 | 类型 | 描述 |
---|---|---|
Any |
操作索引。 |
抛出
类型 | 描述 |
---|---|
ValueError
|
如果未设置 op_indices。 |
get_type()
获取搜索空间的类型。
返回
名称 | 类型 | 描述 |
---|---|---|
str |
搜索空间的类型,在此例中为 'transbench101_macro'。 |
mutate(parent, dataset_api=None)
从父操作索引中变异一个操作并更新 naslib 对象。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
parent |
TransBench101SearchSpaceMacro
|
从中变异的父架构。 |
必需 |
dataset_api |
数据集的 API。 |
None
|
返回
名称 | 类型 | 描述 |
---|---|---|
None |
就地用变异后的操作更新对象。 |
query(metric=None, dataset=None, path=None, epoch=-1, full_lc=False, dataset_api=None)
查询 TransBench 101 的结果。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
metric |
指标
|
要查询的指标类型。 |
None
|
dataset |
str
|
要查询的数据集。 |
None
|
path |
str
|
查询路径。 |
None
|
epoch |
int
|
Epoch 数。默认为 -1。 |
-1
|
full_lc |
bool
|
完整学习曲线的标志。默认为 False。 |
False
|
dataset_api |
dict
|
数据集的 API。必须提供。 |
None
|
返回
名称 | 类型 | 描述 |
---|---|---|
Any |
基于指标和其他可选参数的查询结果。 |
抛出
类型 | 描述 |
---|---|
NotImplementedError
|
对于不受支持的指标或缺少数据集 API。 |
sample_random_architecture(dataset_api=None, load_labeled=False)
采样一个随机架构并相应更新 naslib 对象中的边。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
dataset_api |
数据集的 API。 |
None
|
|
load_labeled |
是否加载标注架构。默认为 False。 |
False
|
返回
类型 | 描述 |
---|---|
如果 load_labeled 为 True,则为采样架构,否则就地更新对象。 |
sample_random_labeled_architecture()
采样一个随机标注架构。
返回
类型 | 描述 |
---|---|
采样架构。 |
抛出
类型 | 描述 |
---|---|
AssertionError
|
如果未提供标注架构。 |
set_op_indices(op_indices)
设置操作索引并相应更新 naslib 对象中的边。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
op_indices |
新的操作索引。 |
必需 |
set_spec(op_indices, dataset_api=None)
设置搜索空间的规范。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
op_indices |
新的操作索引。 |
必需 | |
dataset_api |
数据集的 API。 |
None
|