跳到内容

NAS-Bench-ASR

基类: Graph

面向 nas-bench-asr 架构的表格基准接口。

该类扩展了 Graph 类,为 ASR 神经网络架构提供了结构。它包括创建宏图、单元块、单元以及用于搜索最优架构的查询方法。

属性

名称 类型 描述
QUERYABLE bool

架构是否可查询。

OPTIMIZER_SCOPE list of str

单元阶段名称列表。

注意

目前,不支持为 nas-bench-asr 架构构建 NASLib 对象。

__init__()

初始化 NasBenchASRSearchSpace 对象。

将属性设置为默认值,这些值将用于创建神经网络架构。

encode(encoding_type=EncodingType.ADJACENCY_ONE_HOT)

根据指定的编码类型对架构进行编码。

参数

名称 类型 描述 默认
encoding_type EncodingType

要使用的编码类型。默认为 EncodingType.ADJACENCY_ONE_HOT。

EncodingType.ADJACENCY_ONE_HOT

返回值

名称 类型 描述
object

编码后的架构。

get_compact()

获取架构的紧凑表示。

返回值

类型 描述

架构的紧凑表示。

引发

类型 描述
AssertionError

如果紧凑表示未设置。

get_hash()

根据其紧凑表示获取架构的哈希值。

返回值

类型 描述

架构的哈希值。

get_max_epochs()

获取训练的最大 epoch 数。

返回值

名称 类型 描述
int

最大 epoch 数。

get_nbhd(dataset_api=None)

获取当前架构的所有邻居。

参数

名称 类型 描述 默认
dataset_api 可选

用于获取邻居的数据集 API 实例。默认为 None。

None

返回值

名称 类型 描述
list

所有邻居架构的列表。

get_type()

获取搜索空间的类型。

返回值

名称 类型 描述
str

搜索空间的类型,在本例中为 'asr'。

mutate(parent, mutation_rate=1, dataset_api=None)

变异架构。

参数

名称 类型 描述 默认
parent NasBenchASRSearchSpace

父架构。

必需
mutation_rate int

变异率。默认为 1。

1
dataset_api DatasetAPI

用于变异的数据集 API 实例。默认为 None。

None

返回值

名称 类型 描述
None

架构会原地变异。

注意

这将以两种方式之一变异单元:改变一条边;改变一个操作。

待办事项:通过添加/删除节点进行变异。待办事项:变异隐藏节点列表。待办事项:初始隐藏节点之间的边不进行变异。

query(metric=None, dataset=None, path=None, epoch=-1, full_lc=False, dataset_api=None)

查询 nas-bench-asr 基准的结果。

参数

名称 类型 描述 默认
metric Metric

要查询的性能指标。

None
dataset str

要在其上查询的数据集。

None
path str

保存结果的文件路径。

None
epoch int

查询性能指标的 epoch 数。

-1
full_lc bool

是否返回完整的学习曲线。

False
dataset_api dict

用于查询的数据集 API。

None

返回值

类型 描述

float 或 list:查询指标的值。

sample_random_architecture(dataset_api)

根据数据集 API 采样随机架构。

参数

名称 类型 描述 默认
dataset_api

用于架构采样的数据集 API 实例。

必需

返回值

类型 描述

采样架构的紧凑表示。

set_compact(compact)

设置架构的紧凑表示。

参数

名称 类型 描述 默认
compact

架构的新紧凑表示。

必需