DrNAS
基类: DARTSOptimizer
DrNAS优化器的实现,该优化器在论文《DrNAS: Dirichlet Neural Architecture Search》(ICLR2021)中提出。
注意:许多函数与DARTS优化器相似,因此此类直接继承自DARTSOptimizer,而非MetaOptimizer。
__init__(learning_rate=0.025, momentum=0.9, weight_decay=0.0003, grad_clip=5, unrolled=False, arch_learning_rate=0.0003, arch_weight_decay=0.001, epochs=50, op_optimizer='SGD', arch_optimizer='Adam', loss_criteria='CrossEntropyLoss', **kwargs)
初始化DrNASOptimizer类的新实例。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
learning_rate |
float
|
操作权重的学习率。 |
0.025
|
momentum |
float
|
优化器的动量。 |
0.9
|
weight_decay |
float
|
操作权重的权重衰减。 |
0.0003
|
grad_clip |
int
|
梯度裁剪阈值。 |
5
|
unrolled |
bool
|
是否使用展开优化。 |
False
|
arch_learning_rate |
float
|
架构权重的学习率。 |
0.0003
|
arch_weight_decay |
float
|
架构权重的权重衰减。 |
0.001
|
epochs |
int
|
总训练轮数。 |
50
|
op_optimizer |
str
|
操作权重的优化器类型。例如,'SGD' |
'SGD'
|
arch_optimizer |
str
|
架构权重的优化器类型。例如,'Adam' |
'Adam'
|
loss_criteria |
str
|
损失准则。例如,'CrossEntropyLoss' |
'CrossEntropyLoss'
|
**kwargs |
附加关键字参数。 |
{}
|
adapt_search_space(search_space, dataset, scope=None)
为架构搜索调整搜索空间。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
search_space |
初始搜索空间。 |
必需 | |
dataset |
用于训练/验证的数据集。 |
必需 | |
scope |
要在搜索空间中更新的范围。默认值为 None。 |
None
|
get_final_architecture()
根据当前的架构权重检索最终的离散化架构。
返回
名称 | 类型 | 描述 |
---|---|---|
Graph |
图表示的最终架构。 |
new_epoch(epoch)
在每个新epoch开始时执行所需的任何操作。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
epoch |
int
|
当前epoch编号。 |
必需 |
remove_sampled_alphas(edge)
静态方法
从边的data中移除采样的架构权重 (alphas)。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
edge |
计算图中要移除采样架构权重的边。 |
必需 |
sample_alphas(edge)
静态方法
使用由beta参数化的狄利克雷分布采样架构权重 (alphas)。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
edge |
计算图中设置采样架构权重的边。 |
必需 |
step(data_train, data_val)
对架构和操作权重执行单个优化步骤。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
data_train |
tuple
|
训练数据,以输入和标签的元组形式。 |
必需 |
data_val |
tuple
|
验证数据,以输入和标签的元组形式。 |
必需 |
返回
名称 | 类型 | 描述 |
---|---|---|
tuple |
训练数据的Logits,验证数据的Logits,训练数据的损失,验证数据的损失。 |
update_ops(edge)
静态方法
将边上的原始操作替换为DrNAS特有的DrNASMixedOp。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
edge |
计算图中要替换操作的边。 |
必需 |