顺序用法

默认情况下，auto-sklearn 会同时（并行地）进行两项工作：拟合机器学习模型，以及用已拟合的模型构建集成。不过，也可以让这两个过程按顺序依次运行。下面的示例展示了如何先拟合所有模型，然后再单独构建集成。

from pprint import pprint

import sklearn.model_selection
import sklearn.datasets
import sklearn.metrics

import autosklearn.classification

数据加载

# EnsembleSelection is the ensemble strategy handed to fit_ensemble() later on.
from autosklearn.ensembles.ensemble_selection import EnsembleSelection

# Load scikit-learn's bundled breast-cancer dataset as a (features, labels)
# pair and hold out a test split; random_state=1 makes the split reproducible.
X, y = sklearn.datasets.load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(
    X,
    y,
    random_state=1,
)

构建并拟合分类器

# Step 1: model search only. Passing ensemble_class=None switches off ensemble
# construction during fit(), so no ensemble builder runs alongside the search
# and only one core is kept busy at a time.
automl = autosklearn.classification.AutoSklearnClassifier(
    ensemble_class=None,
    time_left_for_this_task=60,
    tmp_folder="/tmp/autosklearn_sequential_example_tmp",
    # Keep the working folder after fit(). Presumably fit_ensemble() below
    # needs the models fit() left there (confirm against auto-sklearn docs).
    delete_tmp_folder_after_terminate=False,
)
automl.fit(X_train, y_train, dataset_name="breast_cancer")

# Step 2: build the ensemble afterwards, from every model trained by the fit()
# call above; automl.predict() will then use this ensemble.
automl.fit_ensemble(y_train, ensemble_class=EnsembleSelection)
RunKey(config_id=1, instance_id='{"task_id": "breast_cancer"}', seed=0, budget=0.0) RunValue(cost=0.028368794326241176, time=2.138148307800293, status=<StatusType.SUCCESS: 1>, starttime=1663665046.729183, endtime=1663665048.8936107, additional_info={'duration': 2.025052070617676, 'num_run': 2, 'train_loss': 0.0, 'configuration_origin': 'Initial design'})
RunKey(config_id=2, instance_id='{"task_id": "breast_cancer"}', seed=0, budget=0.0) RunValue(cost=0.028368794326241176, time=1.1886756420135498, status=<StatusType.SUCCESS: 1>, starttime=1663665048.8985617, endtime=1663665050.113383, additional_info={'duration': 1.0999813079833984, 'num_run': 3, 'train_loss': 0.01754385964912286, 'configuration_origin': 'Initial design'})
RunKey(config_id=3, instance_id='{"task_id": "breast_cancer"}', seed=0, budget=0.0) RunValue(cost=0.05673758865248224, time=1.8330206871032715, status=<StatusType.SUCCESS: 1>, starttime=1663665050.1179628, endtime=1663665051.9798045, additional_info={'duration': 1.750197410583496, 'num_run': 4, 'train_loss': 0.0, 'configuration_origin': 'Initial design'})
RunKey(config_id=4, instance_id='{"task_id": "breast_cancer"}', seed=0, budget=0.0) RunValue(cost=0.03546099290780147, time=2.5171611309051514, status=<StatusType.SUCCESS: 1>, starttime=1663665051.984322, endtime=1663665054.5282989, additional_info={'duration': 2.3936502933502197, 'num_run': 5, 'train_loss': 0.0035087719298245723, 'configuration_origin': 'Initial design'})
RunKey(config_id=5, instance_id='{"task_id": "breast_cancer"}', seed=0, budget=0.0) RunValue(cost=0.028368794326241176, time=1.2775936126708984, status=<StatusType.SUCCESS: 1>, starttime=1663665054.5333724, endtime=1663665055.84306, additional_info={'duration': 1.198927640914917, 'num_run': 6, 'train_loss': 0.024561403508771895, 'configuration_origin': 'Initial design'})
RunKey(config_id=6, instance_id='{"task_id": "breast_cancer"}', seed=0, budget=0.0) RunValue(cost=0.014184397163120588, time=1.840346336364746, status=<StatusType.SUCCESS: 1>, starttime=1663665055.8481805, endtime=1663665057.716139, additional_info={'duration': 1.7313201427459717, 'num_run': 7, 'train_loss': 0.0, 'configuration_origin': 'Initial design'})
RunKey(config_id=7, instance_id='{"task_id": "breast_cancer"}', seed=0, budget=0.0) RunValue(cost=0.03546099290780147, time=2.512295961380005, status=<StatusType.SUCCESS: 1>, starttime=1663665057.7238793, endtime=1663665060.265423, additional_info={'duration': 2.364440441131592, 'num_run': 8, 'train_loss': 0.0035087719298245723, 'configuration_origin': 'Initial design'})
RunKey(config_id=8, instance_id='{"task_id": "breast_cancer"}', seed=0, budget=0.0) RunValue(cost=0.04255319148936165, time=2.1315276622772217, status=<StatusType.SUCCESS: 1>, starttime=1663665060.271228, endtime=1663665062.429951, additional_info={'duration': 2.0131988525390625, 'num_run': 9, 'train_loss': 0.0035087719298245723, 'configuration_origin': 'Initial design'})
RunKey(config_id=9, instance_id='{"task_id": "breast_cancer"}', seed=0, budget=0.0) RunValue(cost=0.028368794326241176, time=2.4673690795898438, status=<StatusType.SUCCESS: 1>, starttime=1663665062.4357092, endtime=1663665064.9292157, additional_info={'duration': 2.3498237133026123, 'num_run': 10, 'train_loss': 0.0, 'configuration_origin': 'Initial design'})
RunKey(config_id=10, instance_id='{"task_id": "breast_cancer"}', seed=0, budget=0.0) RunValue(cost=0.028368794326241176, time=2.5794708728790283, status=<StatusType.SUCCESS: 1>, starttime=1663665064.934585, endtime=1663665067.5437174, additional_info={'duration': 2.4640066623687744, 'num_run': 11, 'train_loss': 0.0035087719298245723, 'configuration_origin': 'Initial design'})
RunKey(config_id=11, instance_id='{"task_id": "breast_cancer"}', seed=0, budget=0.0) RunValue(cost=0.03546099290780147, time=1.5726971626281738, status=<StatusType.SUCCESS: 1>, starttime=1663665067.5497928, endtime=1663665069.1497045, additional_info={'duration': 1.4734737873077393, 'num_run': 12, 'train_loss': 0.0, 'configuration_origin': 'Initial design'})
RunKey(config_id=12, instance_id='{"task_id": "breast_cancer"}', seed=0, budget=0.0) RunValue(cost=0.028368794326241176, time=1.610743522644043, status=<StatusType.SUCCESS: 1>, starttime=1663665069.156014, endtime=1663665070.7938027, additional_info={'duration': 1.5205156803131104, 'num_run': 13, 'train_loss': 0.0, 'configuration_origin': 'Initial design'})
RunKey(config_id=13, instance_id='{"task_id": "breast_cancer"}', seed=0, budget=0.0) RunValue(cost=0.028368794326241176, time=2.3221628665924072, status=<StatusType.SUCCESS: 1>, starttime=1663665070.8001842, endtime=1663665073.1480403, additional_info={'duration': 2.2304294109344482, 'num_run': 14, 'train_loss': 0.010526315789473717, 'configuration_origin': 'Initial design'})
RunKey(config_id=14, instance_id='{"task_id": "breast_cancer"}', seed=0, budget=0.0) RunValue(cost=0.049645390070921946, time=5.315098762512207, status=<StatusType.SUCCESS: 1>, starttime=1663665073.1545255, endtime=1663665078.4978535, additional_info={'duration': 5.217244863510132, 'num_run': 15, 'train_loss': 0.0, 'configuration_origin': 'Initial design'})
RunKey(config_id=15, instance_id='{"task_id": "breast_cancer"}', seed=0, budget=0.0) RunValue(cost=0.021276595744680882, time=1.2758090496063232, status=<StatusType.SUCCESS: 1>, starttime=1663665078.505013, endtime=1663665079.8135793, additional_info={'duration': 1.1897251605987549, 'num_run': 16, 'train_loss': 0.0, 'configuration_origin': 'Initial design'})
RunKey(config_id=16, instance_id='{"task_id": "breast_cancer"}', seed=0, budget=0.0) RunValue(cost=0.03546099290780147, time=2.031766176223755, status=<StatusType.SUCCESS: 1>, starttime=1663665079.820808, endtime=1663665081.878261, additional_info={'duration': 1.929640769958496, 'num_run': 17, 'train_loss': 0.0, 'configuration_origin': 'Initial design'})
RunKey(config_id=17, instance_id='{"task_id": "breast_cancer"}', seed=0, budget=0.0) RunValue(cost=0.03546099290780147, time=2.4669880867004395, status=<StatusType.SUCCESS: 1>, starttime=1663665081.8850935, endtime=1663665084.380212, additional_info={'duration': 2.343514919281006, 'num_run': 18, 'train_loss': 0.0, 'configuration_origin': 'Initial design'})
RunKey(config_id=18, instance_id='{"task_id": "breast_cancer"}', seed=0, budget=0.0) RunValue(cost=0.028368794326241176, time=3.1608927249908447, status=<StatusType.SUCCESS: 1>, starttime=1663665084.3869042, endtime=1663665087.5789628, additional_info={'duration': 3.045881748199463, 'num_run': 19, 'train_loss': 0.0035087719298245723, 'configuration_origin': 'Initial design'})
RunKey(config_id=19, instance_id='{"task_id": "breast_cancer"}', seed=0, budget=0.0) RunValue(cost=0.07801418439716312, time=0.8530073165893555, status=<StatusType.SUCCESS: 1>, starttime=1663665087.5859113, endtime=1663665088.4659152, additional_info={'duration': 0.7710719108581543, 'num_run': 20, 'train_loss': 0.10526315789473684, 'configuration_origin': 'Initial design'})
RunKey(config_id=20, instance_id='{"task_id": "breast_cancer"}', seed=0, budget=0.0) RunValue(cost=0.021276595744680882, time=1.751319169998169, status=<StatusType.SUCCESS: 1>, starttime=1663665088.4732761, endtime=1663665090.2543724, additional_info={'duration': 1.6337840557098389, 'num_run': 21, 'train_loss': 0.007017543859649145, 'configuration_origin': 'Initial design'})
RunKey(config_id=21, instance_id='{"task_id": "breast_cancer"}', seed=0, budget=0.0) RunValue(cost=0.028368794326241176, time=1.2900927066802979, status=<StatusType.SUCCESS: 1>, starttime=1663665090.261764, endtime=1663665091.5791755, additional_info={'duration': 1.198190450668335, 'num_run': 22, 'train_loss': 0.0035087719298245723, 'configuration_origin': 'Initial design'})
RunKey(config_id=22, instance_id='{"task_id": "breast_cancer"}', seed=0, budget=0.0) RunValue(cost=1.0, time=2.0078036785125732, status=<StatusType.TIMEOUT: 2>, starttime=1663665091.5878708, endtime=1663665094.6180243, additional_info={'error': 'Timeout', 'configuration_origin': 'Initial design'})
RunKey(config_id=23, instance_id='{"task_id": "breast_cancer"}', seed=0, budget=0.0) RunValue(cost=1.0, time=0.0, status=<StatusType.STOP: 8>, starttime=1663665094.6260393, endtime=1663665094.6260395, additional_info={})

AutoSklearnClassifier(delete_tmp_folder_after_terminate=False,
                      ensemble_class=<class 'autosklearn.ensembles.ensemble_selection.EnsembleSelection'>,
                      per_run_time_limit=6, time_left_for_this_task=60,
                      tmp_folder='/tmp/autosklearn_sequential_example_tmp')

获取最终集成模型的得分

# Evaluate the final ensemble on the held-out split, then report the search
# summary and the test-set accuracy.
predictions = automl.predict(X_test)
print(automl.sprint_statistics())
print(
    "Accuracy score",
    sklearn.metrics.accuracy_score(y_true=y_test, y_pred=predictions),
)
auto-sklearn results:
  Dataset name: breast_cancer
  Metric: accuracy
  Best validation score: 0.985816
  Number of target algorithm runs: 22
  Number of successful target algorithm runs: 21
  Number of crashed target algorithm runs: 0
  Number of target algorithms that exceeded the time limit: 1
  Number of target algorithms that exceeded the memory limit: 0

Accuracy score 0.9440559440559441

脚本总运行时间：（0 分 58.297 秒）

由 Sphinx-Gallery 生成的画廊