跳到内容

Bananas

基类: MetaOptimizer

贝叶斯优化 NAS (BANANAS) 实现作为元优化器。它结合了贝叶斯优化和神经网络架构搜索的元素。

属性

名称 类型 描述
using_step_function bool

优化器是否使用步长函数。默认为 False。

config object

包含各种设置的配置对象。

epochs int

训练的 epoch 数量。

performance_metric str

用于评估的性能指标。

dataset str

用于训练的数据集。

k int

用于调优的超参数。

num_init int

初始化的数量。

num_ensemble int

集成数量。

predictor_type str

要使用的预测器类型。

acq_fn_type str

要使用的采集函数类型。

acq_fn_optimization str

要使用的采集函数优化类型。

encoding_type str

使用的编码类型。

num_arches_to_mutate int

要变异的架构数量。

max_mutations int

最大变异数量。

num_candidates int

候选架构数量。

max_zerocost int

最大零成本。

train_data list

用于训练的数据列表。

next_batch list

用于下一批数据的数据列表。

history torch.nn.ModuleList

模型历史。

zc bool

零成本选项。

semi bool

半监督学习选项。

zc_api API

零成本预测器的 API。

use_zc_api bool

是否使用零成本 API。

zc_names list

零成本预测器的名称。

zc_only bool

是否只使用零成本预测器。

adapt_search_space(search_space, scope=None, dataset_api=None)

为元优化器调整提供的搜索空间。

参数

名称 类型 描述 默认值
search_space SearchSpace

要使用的搜索空间。

必填
scope str

要使用的优化器作用域。默认为搜索空间提供的作用域。

None
dataset_api API

要使用的数据集的 API。

None

引发

类型 描述
AssertionError

如果搜索空间不可查询。

get_arch_as_string(arch)

将架构转换为字符串。

参数

名称 类型 描述 默认值
arch dict

要转换的架构。

必填

返回

名称 类型 描述
str

架构的字符串表示形式。

get_checkpointables()

检索模型的检查点对象。

返回

名称 类型 描述
dict

模型的检查点对象。

get_final_architecture()

检索最终(最佳)架构。

返回

名称 类型 描述
dict

最终架构。

get_model_size()

检索模型的 MB 大小。

返回

名称 类型 描述
float

模型的 MB 大小。

get_op_optimizer()

检索操作优化器。

引发

类型 描述
NotImplementedError

此方法应在子类中实现。

get_zero_cost_predictors()

为 self.zc_names 中的每个方法生成零成本预测器。

返回

名称 类型 描述
dict

零成本预测器的字典。

new_epoch(epoch)

执行新 epoch 的操作。

参数

名称 类型 描述 默认值
epoch int

epoch 号。

必填

query_zc_scores(arch)

计算给定架构的零成本分数。

参数

名称 类型 描述 默认值
arch dict

要计算零成本分数的架构。

必填

返回

名称 类型 描述
dict

提供的架构的零成本分数字典。

test_statistics()

计算测试统计信息。

返回

名称 类型 描述
float

测试统计信息。

train_statistics(report_incumbent=True)

计算训练统计信息。

参数

名称 类型 描述 默认值
report_incumbent bool

是否报告当前最佳架构。默认为 True。

True

返回

名称 类型 描述
tuple

包含各种训练统计信息的元组。