跳到内容

DARTS

基类: MetaOptimizer

Liu 等人 2019 年论文 DARTS: Differentiable Architecture Search 的实现。

属性

名称 类型 描述
learning_rate float

操作优化的学习率。

momentum float

动量因子。

weight_decay float

权重衰减 (L2 正则化)。

grad_clip int

梯度裁剪值。

unrolled bool

是否使用展开式反向传播。

arch_learning_rate float

架构优化的学习率。

arch_weight_decay float

架构优化的权重衰减。

op_optimizer str

操作权重的优化器 ('SGD', 'Adam' 等)。

arch_optimizer str

架构权重的优化器 ('SGD', 'Adam' 等)。

loss str

损失准则 ('CrossEntropyLoss' 等)。

architectural_weights torch.nn.ParameterList

架构权重列表。

device torch.device

运行模型的设备。

search_space obj

架构搜索空间。

graph obj

计算图。

scope str

操作范围。

dataset str

用于搜索的数据集。

arch_optimizer obj

架构的 Torch 优化器。

op_optimizer obj

操作的 Torch 优化器。

loss obj

Torch 损失函数。

__init__(learning_rate=0.025, momentum=0.9, weight_decay=0.0003, grad_clip=5, unrolled=False, arch_learning_rate=0.0003, arch_weight_decay=0.001, op_optimizer='SGD', arch_optimizer='Adam', loss_criteria='CrossEntropyLoss', **kwargs)

初始化 DARTSOptimizer 的新实例。

参数

名称 类型 描述 默认值
learning_rate float

操作优化的学习率。默认为 0.025。

0.025
momentum float

动量因子。默认为 0.9。

0.9
weight_decay float

权重衰减 (L2 正则化)。默认为 0.0003。

0.0003
grad_clip int

梯度裁剪值。默认为 5。

5
unrolled bool

是否使用展开式反向传播。默认为 False。

False
arch_learning_rate float

架构优化的学习率。默认为 0.0003。

0.0003
arch_weight_decay float

架构优化的权重衰减。默认为 0.001。

0.001
op_optimizer str

操作权重的优化器 ('SGD', 'Adam' 等)。默认为 'SGD'。

'SGD'
arch_optimizer str

架构权重的优化器 ('SGD', 'Adam' 等)。默认为 'Adam'。

'Adam'
loss_criteria str

损失准则 ('CrossEntropyLoss' 等)。默认为 'CrossEntropyLoss'。

'CrossEntropyLoss'

adapt_search_space(search_space, dataset, scope=None, **kwargs)

调整用于架构优化的搜索空间。

参数

名称 类型 描述 默认值
search_space

初始搜索空间对象。

必需
dataset 数据集

用于训练/验证的数据集。

必需
scope str

应用图修改的范围。默认为 None

None
**kwargs

附加关键字参数。

{}

add_alphas(edge) 静态方法

将架构权重 (alphas) 添加到计算图的边上。

参数

名称 类型 描述 默认值
edge obj

计算图中要添加 alpha 的边。

必需

返回值

类型 描述

None

before_training()

准备模型进行训练。这将图和架构权重移至设备内存。

get_checkpointables()

获取模型中可检查点的元素,用于保存或加载。

返回值

名称 类型 描述
dict

包含所有待检查点元素的字典。

get_final_architecture()

根据当前的架构权重获取最终的离散化架构。

返回值

名称 类型 描述

作为图对象的最终架构。

get_model_size()

获取模型的参数量大小。

返回值

名称 类型 描述
float

模型的大小(以 MB 为单位)。

get_op_optimizer()

获取操作优化器的类。

返回值

名称 类型 描述
type

操作优化器的类类型。

new_epoch(epoch)

在每个新 epoch 开始时记录架构权重。

参数

名称 类型 描述 默认值
epoch int

当前 epoch 数。

必需

step(data_train, data_val)

执行单步优化。

参数

名称 类型 描述 默认值
data_train tuple

包含训练输入和标签的元组。

必需
data_val tuple

包含验证输入和标签的元组。

必需

返回值

名称 类型 描述
tuple

包含训练集 logits、验证集 logits、

训练集损失和验证集损失的元组。

test_statistics()

根据当前架构和数据集检索测试统计信息。

返回值

名称 类型 描述
float

测试准确率,如果图可查询。否则返回 None。

update_ops(edge) 静态方法

使用 MixedOp 更新每条边上的操作。

参数

名称 类型 描述 默认值
edge obj

计算图中要更新操作的边。

必需

返回值

类型 描述

None