Bananas
基类: MetaOptimizer
贝叶斯优化 NAS (BANANAS) 实现作为元优化器。它结合了贝叶斯优化和神经网络架构搜索的元素。
属性
名称 | 类型 | 描述 |
---|---|---|
using_step_function |
bool
|
优化器是否使用步长函数。默认为 False。 |
config |
object
|
包含各种设置的配置对象。 |
epochs |
int
|
训练的 epoch 数量。 |
performance_metric |
str
|
用于评估的性能指标。 |
dataset |
str
|
用于训练的数据集。 |
k |
int
|
用于调优的超参数。 |
num_init |
int
|
初始化的数量。 |
num_ensemble |
int
|
集成数量。 |
predictor_type |
str
|
要使用的预测器类型。 |
acq_fn_type |
str
|
要使用的采集函数类型。 |
acq_fn_optimization |
str
|
要使用的采集函数优化类型。 |
encoding_type |
str
|
使用的编码类型。 |
num_arches_to_mutate |
int
|
要变异的架构数量。 |
max_mutations |
int
|
最大变异数量。 |
num_candidates |
int
|
候选架构数量。 |
max_zerocost |
int
|
最大零成本。 |
train_data |
list
|
用于训练的数据列表。 |
next_batch |
list
|
用于下一批数据的数据列表。 |
history |
torch.nn.ModuleList
|
模型历史。 |
zc |
bool
|
零成本选项。 |
semi |
bool
|
半监督学习选项。 |
zc_api |
API
|
零成本预测器的 API。 |
use_zc_api |
bool
|
是否使用零成本 API。 |
zc_names |
list
|
零成本预测器的名称。 |
zc_only |
bool
|
是否只使用零成本预测器。 |
adapt_search_space(search_space, scope=None, dataset_api=None)
为元优化器调整提供的搜索空间。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
search_space |
SearchSpace
|
要使用的搜索空间。 |
必填 |
scope |
str
|
要使用的优化器作用域。默认为搜索空间提供的作用域。 |
None
|
dataset_api |
API
|
要使用的数据集的 API。 |
None
|
引发
类型 | 描述 |
---|---|
AssertionError
|
如果搜索空间不可查询。 |
get_arch_as_string(arch)
将架构转换为字符串。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
arch |
dict
|
要转换的架构。 |
必填 |
返回
名称 | 类型 | 描述 |
---|---|---|
str |
架构的字符串表示形式。 |
get_checkpointables()
检索模型的检查点对象。
返回
名称 | 类型 | 描述 |
---|---|---|
dict |
模型的检查点对象。 |
get_final_architecture()
检索最终(最佳)架构。
返回
名称 | 类型 | 描述 |
---|---|---|
dict |
最终架构。 |
get_model_size()
检索模型的 MB 大小。
返回
名称 | 类型 | 描述 |
---|---|---|
float |
模型的 MB 大小。 |
get_op_optimizer()
检索操作优化器。
引发
类型 | 描述 |
---|---|
NotImplementedError
|
此方法应在子类中实现。 |
get_zero_cost_predictors()
为 self.zc_names 中的每个方法生成零成本预测器。
返回
名称 | 类型 | 描述 |
---|---|---|
dict |
零成本预测器的字典。 |
new_epoch(epoch)
执行新 epoch 的操作。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
epoch |
int
|
epoch 号。 |
必填 |
query_zc_scores(arch)
计算给定架构的零成本分数。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
arch |
dict
|
要计算零成本分数的架构。 |
必填 |
返回
名称 | 类型 | 描述 |
---|---|---|
dict |
提供的架构的零成本分数字典。 |
test_statistics()
计算测试统计信息。
返回
名称 | 类型 | 描述 |
---|---|---|
float |
测试统计信息。 |
train_statistics(report_incumbent=True)
计算训练统计信息。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
report_incumbent |
bool
|
是否报告当前最佳架构。默认为 True。 |
True
|
返回
名称 | 类型 | 描述 |
---|---|---|
tuple |
包含各种训练统计信息的元组。 |