基础预测器
Bases: MetaOptimizer
BasePredictor 是基于预测的神经架构搜索 (NAS) 方法的基类。它为使用机器学习模型预测神经架构性能而无需训练它们的方法提供了基础。派生自 MetaOptimizer 类。
属性
名称 | 类型 | 描述 |
---|---|---|
using_step_function |
bool
|
表示此优化器没有步进函数(step function)的标志。 |
config |
CfgNode
|
搜索过程的配置设置。 |
epochs |
int
|
搜索过程的 epoch 数量。 |
performance_metric |
Metric
|
用于评估架构的性能指标。 |
dataset |
str
|
用于评估的数据集。 |
k |
int
|
每个周期要评估的架构数量。 |
num_init |
int
|
初始随机架构的数量。 |
test_size |
int
|
用于评估预测器的测试集大小。 |
predictor_type |
str
|
要使用的预测器类型(例如,“LGB”,“MLP”)。 |
num_ensemble |
int
|
集成模型中的模型数量。 |
encoding_type |
str
|
架构的编码类型(例如,“adjacency_one_hot”)。 |
debug_predictor |
bool
|
如果为 True,将打印调试信息。 |
train_data |
list
|
一个列表,用于存储训练数据(架构-性能对)。 |
choices |
list
|
一个列表,用于存储选择的架构。 |
history |
torch.nn.ModuleList
|
一个列表,用于存储架构的历史记录。 |
__init__(config)
使用配置设置初始化 BasePredictor 类。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
config |
CfgNode
|
搜索过程的配置设置。 |
必需 |
adapt_search_space(search_space, scope=None, dataset_api=None)
调整搜索空间。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
search_space |
Graph
|
要调整的搜索空间。 |
必需 |
scope |
str
|
搜索的范围。默认为 None。 |
None
|
dataset_api |
dict
|
数据集的 API。默认为 None。 |
None
|
evaluate_predictor(xtrain, ytrain, xtest, test_pred, slice_size=4)
评估预测器以进行调试。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
xtrain |
训练数据(架构)。 |
必需 | |
ytrain |
训练标签(性能)。 |
必需 | |
xtest |
测试数据(架构)。 |
必需 | |
test_pred |
测试数据的预测性能。 |
必需 | |
slice_size |
int
|
每个切片中要打印的项目数量。默认为 4。 |
4
|
get_checkpointables()
获取可检查点(checkpoint)的模型。
返回
名称 | 类型 | 描述 |
---|---|---|
dict |
一个字典,其中键为“model”,值为架构的历史记录。 |
get_final_architecture()
获取搜索中最终(最佳)架构。
返回
名称 | 类型 | 描述 |
---|---|---|
Graph |
搜索期间找到的最佳架构。 |
get_model_size()
获取模型大小。
返回
名称 | 类型 | 描述 |
---|---|---|
float |
模型的兆字节(MB)大小。 |
get_op_optimizer()
获取操作的优化器。此方法在此类中未实现,调用时将引发错误。
引发
类型 | 描述 |
---|---|
NotImplementedError
|
始终,因为此方法在此类中未实现。 |
new_epoch(epoch)
在搜索过程中开始一个新的 epoch,采样一个新的架构进行训练,或使用预测器选择一个架构。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
epoch |
int
|
当前的 epoch 编号。 |
必需 |
test_statistics()
报告测试统计信息。
返回
名称 | 类型 | 描述 |
---|---|---|
float |
最佳架构的原始性能指标。 |
train_statistics(report_incumbent=True)
报告训练后的统计信息。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
report_incumbent |
bool
|
是报告当前最佳架构还是最新架构。默认为 True。 |
True
|
返回
名称 | 类型 | 描述 |
---|---|---|
tuple |
包含训练准确率、验证准确率和测试准确率的元组。 |