DARTS
基类: MetaOptimizer
Liu 等人 2019 年论文 DARTS: Differentiable Architecture Search 的实现。
属性
名称 | 类型 | 描述 |
---|---|---|
learning_rate |
float
|
操作优化的学习率。 |
momentum |
float
|
动量因子。 |
weight_decay |
float
|
权重衰减 (L2 正则化)。 |
grad_clip |
int
|
梯度裁剪值。 |
unrolled |
bool
|
是否使用展开式反向传播。 |
arch_learning_rate |
float
|
架构优化的学习率。 |
arch_weight_decay |
float
|
架构优化的权重衰减。 |
op_optimizer |
str
|
操作权重的优化器 ('SGD', 'Adam' 等)。 |
arch_optimizer |
str
|
架构权重的优化器 ('SGD', 'Adam' 等)。 |
loss |
str
|
损失准则 ('CrossEntropyLoss' 等)。 |
architectural_weights |
torch.nn.ParameterList
|
架构权重列表。 |
device |
torch.device
|
运行模型的设备。 |
search_space |
obj
|
架构搜索空间。 |
graph |
obj
|
计算图。 |
scope |
str
|
操作范围。 |
dataset |
str
|
用于搜索的数据集。 |
arch_optimizer |
obj
|
架构的 Torch 优化器。 |
op_optimizer |
obj
|
操作的 Torch 优化器。 |
loss |
obj
|
Torch 损失函数。 |
__init__(learning_rate=0.025, momentum=0.9, weight_decay=0.0003, grad_clip=5, unrolled=False, arch_learning_rate=0.0003, arch_weight_decay=0.001, op_optimizer='SGD', arch_optimizer='Adam', loss_criteria='CrossEntropyLoss', **kwargs)
初始化 DARTSOptimizer 的新实例。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
learning_rate |
float
|
操作优化的学习率。默认为 0.025。 |
0.025
|
momentum |
float
|
动量因子。默认为 0.9。 |
0.9
|
weight_decay |
float
|
权重衰减 (L2 正则化)。默认为 0.0003。 |
0.0003
|
grad_clip |
int
|
梯度裁剪值。默认为 5。 |
5
|
unrolled |
bool
|
是否使用展开式反向传播。默认为 False。 |
False
|
arch_learning_rate |
float
|
架构优化的学习率。默认为 0.0003。 |
0.0003
|
arch_weight_decay |
float
|
架构优化的权重衰减。默认为 0.001。 |
0.001
|
op_optimizer |
str
|
操作权重的优化器 ('SGD', 'Adam' 等)。默认为 'SGD'。 |
'SGD'
|
arch_optimizer |
str
|
架构权重的优化器 ('SGD', 'Adam' 等)。默认为 'Adam'。 |
'Adam'
|
loss_criteria |
str
|
损失准则 ('CrossEntropyLoss' 等)。默认为 'CrossEntropyLoss'。 |
'CrossEntropyLoss'
|
adapt_search_space(search_space, dataset, scope=None, **kwargs)
调整用于架构优化的搜索空间。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
search_space |
图
|
初始搜索空间对象。 |
必需 |
dataset |
数据集
|
用于训练/验证的数据集。 |
必需 |
scope |
str
|
应用图修改的范围。默认为 |
None
|
**kwargs |
附加关键字参数。 |
{}
|
add_alphas(edge)
静态方法
将架构权重 (alphas) 添加到计算图的边上。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
edge |
obj
|
计算图中要添加 alpha 的边。 |
必需 |
返回值
类型 | 描述 |
---|---|
None |
before_training()
准备模型进行训练。这将图和架构权重移至设备内存。
get_checkpointables()
获取模型中可检查点的元素,用于保存或加载。
返回值
名称 | 类型 | 描述 |
---|---|---|
dict |
包含所有待检查点元素的字典。 |
get_final_architecture()
根据当前的架构权重获取最终的离散化架构。
返回值
名称 | 类型 | 描述 |
---|---|---|
图 |
作为图对象的最终架构。 |
get_model_size()
获取模型的参数量大小。
返回值
名称 | 类型 | 描述 |
---|---|---|
float |
模型的大小(以 MB 为单位)。 |
get_op_optimizer()
获取操作优化器的类。
返回值
名称 | 类型 | 描述 |
---|---|---|
type |
操作优化器的类类型。 |
new_epoch(epoch)
在每个新 epoch 开始时记录架构权重。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
epoch |
int
|
当前 epoch 数。 |
必需 |
step(data_train, data_val)
执行单步优化。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
data_train |
tuple
|
包含训练输入和标签的元组。 |
必需 |
data_val |
tuple
|
包含验证输入和标签的元组。 |
必需 |
返回值
名称 | 类型 | 描述 |
---|---|---|
tuple |
包含训练集 logits、验证集 logits、 |
|
训练集损失和验证集损失的元组。 |
test_statistics()
根据当前架构和数据集检索测试统计信息。
返回值
名称 | 类型 | 描述 |
---|---|---|
float |
测试准确率,如果图可查询。否则返回 None。 |
update_ops(edge)
静态方法
使用 MixedOp 更新每条边上的操作。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
edge |
obj
|
计算图中要更新操作的边。 |
必需 |
返回值
类型 | 描述 |
---|---|
None |