NAS-Bench-NLP
基类: Graph
表示 NLP(自然语言处理)环境下的 NAS(神经网络架构搜索)的搜索空间。
注意
当前不支持为 nas-bench-nlp 架构构建 NASLib 对象。
属性
名称 | 类型 | 描述 |
---|---|---|
QUERYABLE |
bool
|
指定此类是否支持查询架构。 |
__init__()
初始化搜索空间的新实例。
encode(encoding_type=EncodingType.ADJACENCY_ONE_HOT)
将架构编码为特定格式。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
encoding_type |
EncodingType
|
使用的编码类型。默认为 |
EncodingType.ADJACENCY_ONE_HOT
|
返回值
类型 | 描述 |
---|---|
架构的编码表示。 |
get_arch_iterator(dataset_api=None)
获取一个迭代器,用于迭代数据集 API 中的架构。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
dataset_api |
dict
|
包含架构信息的数据集 API。 |
None
|
返回值
类型 | 描述 |
---|---|
np.array: 用于迭代的架构表示数组。 |
get_compact()
获取架构的紧凑表示。
返回值
类型 | 描述 |
---|---|
架构的紧凑表示。 |
抛出
类型 | 描述 |
---|---|
AssertionError
|
如果未设置紧凑表示。 |
get_hash()
根据架构的紧凑表示获取哈希值。
返回值
类型 | 描述 |
---|---|
架构的哈希表示。 |
get_max_epochs()
获取训练的最大 epoch 数。
返回值
名称 | 类型 | 描述 |
---|---|---|
int |
最大 epoch 数 (49)。 |
get_nbhd(dataset_api=None)
根据当前架构获取邻域架构。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
dataset_api |
object
|
用于查询架构的数据集 API。默认为 None。 |
None
|
返回值
名称 | 类型 | 描述 |
---|---|---|
list |
邻域架构列表。 |
目前与 mutate() 中有相同的待办事项。
get_type()
获取搜索空间的类型。
返回值
名称 | 类型 | 描述 |
---|---|---|
str |
搜索空间的类型(“nlp”)。 |
load_labeled_architecture(dataset_api=None, max_nodes=12)
将标记的架构加载到搜索空间中。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
dataset_api |
包含可用架构信息的数据集 API。 |
None
|
|
max_nodes |
int
|
架构的最大节点数。 |
12
|
返回值
名称 | 类型 | 描述 |
---|---|---|
None |
架构被加载到实例中。 |
mutate(parent, mutation_rate=1, dataset_api=None)
通过改变边或操作来变异架构。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
parent |
object
|
父架构。 |
必需 |
mutation_rate |
int
|
要执行的变异次数。默认为 1。 |
1
|
dataset_api |
object
|
用于查询架构的数据集 API。默认为 None。 |
None
|
返回值
名称 | 类型 | 描述 |
---|---|---|
None |
就地修改架构。 |
待办事项:通过添加/移除节点进行变异。待办事项:变异隐藏节点列表。待办事项:初始隐藏节点之间的边未变异。
query(metric=None, dataset=None, path=None, epoch=-1, full_lc=False, dataset_api=None)
查询架构的性能指标。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
metric |
指标
|
要查询的指标。 |
None
|
dataset |
str
|
用于查询的数据集。 |
None
|
path |
str
|
保存的架构的文件路径。 |
None
|
epoch |
int
|
查询指标时的 epoch。 |
-1
|
full_lc |
bool
|
是否查询完整的学习曲线。 |
False
|
dataset_api |
用于查询的数据集 API。 |
None
|
返回值
类型 | 描述 |
---|---|
Union[int, float, dict, list]: 查询结果。 |
抛出
类型 | 描述 |
---|---|
NotImplementedError
|
如果尝试查询额外训练 epoch 的指标。 |
sample_random_architecture(dataset_api)
采样一个满足约束的随机架构。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
dataset_api |
用于查询架构的数据集 API。 |
必需 |
返回值
类型 | 描述 |
---|---|
采样架构的紧凑表示。 |
set_compact(compact)
设置架构的紧凑表示。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
compact |
架构的紧凑表示。 |
必需 |
返回值
名称 | 类型 | 描述 |
---|---|---|
None |
就地更新架构。 |
set_spec(compact, dataset_api=None)
设置架构规范。此函数用于统一不同搜索空间的接口。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
compact |
架构的紧凑表示。 |
必需 | |
dataset_api |
数据集 API。 |
None
|
返回值
名称 | 类型 | 描述 |
---|---|---|
None |
架构规范已设置。 |