跳到内容

NASLib 快速入门

在本指南中,我们将演示如何开始使用神经架构搜索库(NASLib)。NASLib 提供了广泛的工具来促进神经架构搜索和优化。我们将向您展示如何利用各种优化器和搜索空间,例如,结合 NasBench301 搜索空间的 DARTS 优化器,结合 NasBench201 的正则化演进,以及使用 NasBench201 探索零成本代理。让我们深入了解细节。

DARTS 和 正则化演进

加载配置和日志记录

config = utils.get_config_from_args(config_type='nas')
logger = setup_logger(config.save + "/log.log")
logger.setLevel(logging.INFO)
utils.log_args(config)

配置参数通过 utils.get_config_from_args(config_type='nas') 加载。如果未提供其他文件,则默认从 darts_defaults.yaml 加载。此配置可能包含与您的数据集、架构、优化过程等相关的设置。这种灵活性使得可以根据实验要求轻松切换不同的优化器和搜索空间。日志记录器会记录此过程,以便进行调试或审计。

定义可用优化器和搜索空间的子集

supported_optimizers = {
    're': RegularizedEvolution(config),
    'darts': DARTSOptimizer(**config),
}
supported_search_spaces = {
    'nasbench201': NasBench201SearchSpace(),
    'nasbench301': NasBench301SearchSpace(),
}

虽然 NASLib 支持广泛的优化器和搜索空间,但在此处,我们实例化了两个特定的优化器(正则化演进和 DARTS)和两个搜索空间(NasBench201 和 NasBench301)。此演示展示了如何根据您的需求从库中选择特定的工具。

准备搜索空间和优化器

dataset_api = get_dataset_api(config.search_space, config.dataset)
utils.set_seed(config.seed)
search_space = supported_search_spaces[config.search_space]
optimizer = supported_optimizers[config.optimizer]
optimizer.adapt_search_space(search_space, dataset=config.dataset, dataset_api=dataset_api)

本节设置随机种子以确保可复现性,根据配置选择优化器和搜索空间,并将优化器调整以适应所选的搜索空间。它使用数据集 API 加载与您选择的搜索空间对应的数据集。

运行优化过程

trainer = Trainer(optimizer, config, lightweight_output=True)
trainer.search(resume_from="")
trainer.evaluate(resume_from="", dataset_api=dataset_api)

使用配置好的优化器创建 Trainer 对象,并开始架构搜索过程。搜索完成后,trainer 会评估找到的最佳架构。搜索和评估都可以从先前的检查点恢复(指定为相应函数的参数)。如果 lightweight_output 参数设置为 True,则会减少每个训练周期的输出量。

如需完整代码,请参阅 getting_started

零成本

NAS 通常是计算密集型的,因为在选择最佳模型之前需要评估多个模型。为了减少所需的计算能力和时间,通常使用代理任务来评估每个模型,而不是进行完整的训练。我们可以利用零成本代理以及零成本基准来查询 NAS 实验的分数。

我们已经下载了 NAS-Bench-201 的零成本基准 API,它包含了对所有 15625 个模型在所有三个数据集(CIFAR-10、CIFAR-100 和 ImageNet16-120)上使用所有 13 个零成本代理评估后的分数。

设置实验

config = utils.get_config_from_args(config_type='zc')
logger = setup_logger(config.save + "/log.log")
logger.setLevel(logging.INFO)
utils.log_args(config)
utils.set_seed(config.seed)

supported_optimizers = {
    'bananas': Bananas,
    'npenas': Npenas,
}
supported_search_spaces = {
    'nasbench201': NasBench201SearchSpace,
    'nasbench301': NasBench301SearchSpace,
}

实验设置与之前相同。但现在我们将选择 config_type = 'zc' 来加载与零成本实验相关的配置。自定义配置可以按照编写 zc_config.yaml 的相同方式进行。我们选择 BananasNpenas 作为支持的优化器,因为这是仅有的 2 个支持零成本 API 的优化器。还可以通过将从 NAS 基准查询替换为零成本 API,将零成本 API 集成到自定义优化器的查询函数中。

准备加载器和零成本 API

graph = supported_search_spaces[config.search_space]()
train_loader, val_loader, test_loader, train_transform, valid_transform = get_train_val_loaders(config)
zc_api = get_zc_benchmark_api(config.search_space, config.dataset)

使用零成本预测器进行查询

zc_pred = config.predictor 
graph.sample_random_architecture(dataset_api=dataset_api)
graph.parse()
zc_predictor = ZeroCost(method_type=zc_pred)
zc_score = zc_predictor.query(graph=graph, dataloader=train_loader)

您可以使用您选择的零成本预测器来查询零成本分数。

零成本分数与验证准确率之间的相关性

val_accs = []
zc_scores = []
for _ in range(10):
    graph = supported_search_spaces[config.search_space]()
    graph.sample_random_architecture()
    graph.parse()
    acc = graph.query(metric=Metric.VAL_ACCURACY, dataset='cifar10',
                      dataset_api=dataset_api)
    val_accs.append(acc)
    zc_score = zc_predictor.query(graph=graph, dataloader=train_loader)
    zc_scores.append(zc_score)

corr_score = compute_scores(ytest=val_accs, test_pred=zc_scores)

print("The kendall-tau score is: ", corr_score["kendalltau"])

在自定义优化器中添加一个便捷的功能是检查零成本分数与验证准确率之间的相关性。我们提供了一个便捷的实用函数 utils.compute_scores,它可以提供 9 种类型的相关性分数。

使用零成本 API

zc_predictor = 'jacov'
spec = graph.get_hash()
zc_score = zc_api[str(spec)][zc_predictor]['score']
time_to_compute = zc_api[str(spec)][zc_predictor]['time'] 

通过直接查询零成本 API 而不是通过零成本预测器查询分数,我们可以进一步节省计算时间。

使用零成本进行训练

graph = supported_search_spaces[config.search_space]()
optimizer = supported_optimizers[config.optimizer](config, zc_api=zc_api) 
optimizer.adapt_search_space(search_space, dataset_api=dataset_api)

trainer = Trainer(optimizer, config, lightweight_output=True)
trainer.search(resume_from="")
trainer.evaluate(resume_from="", dataset_api=dataset_api)

两个支持的优化器 - BananasNpenas 接受 zc_api 参数,该参数在使用零成本 API 查询架构分数时使用。请查阅 zc_config.yaml 以了解使用零成本分数所需的其他相关参数。

更多示例请参阅 naslib tutorial搜索空间介绍预测器介绍