扩展 auto-sklearn

auto-sklearn 可以轻松地通过新的分类、回归和特征预处理方法进行扩展。为此,用户需要实现一个包装器类并将其注册到 auto-sklearn。本手册将引导您完成此过程。

编写组件

根据用途,组件必须是以下基类之一的子类:

通常,这些类是现有机器学习模型的包装器,仅添加 auto-sklearn 所需的功能。当然,您也可以直接在组件内部实现机器学习算法。

每个组件必须实现一个返回其配置空间的方法、一个查询组件属性的方法,以及根据组件任务实现的 fit()predict()transform() 等方法。这些方法在以下子节中描述:get_hyperparameter_search_space()get_properties()

编写组件类后,您需要告知 auto-sklearn 它的存在。根据组件类型,您需要使用以下函数调用来添加它:

autosklearn.pipeline.components.classification.add_classifier(classifier: Type[autosklearn.pipeline.components.base.AutoSklearnClassificationAlgorithm]) None[source]
autosklearn.pipeline.components.regression.add_regressor(regressor: Type[autosklearn.pipeline.components.base.AutoSklearnRegressionAlgorithm]) None[source]
autosklearn.pipeline.components.feature_preprocessing.add_preprocessor(preprocessor: Type[autosklearn.pipeline.components.base.AutoSklearnPreprocessingAlgorithm]) None[source]

get_hyperparameter_search_space()

返回 ConfigSpace.configuration_space.ConfigurationSpace 的实例。

另请参阅抽象定义: AutoSklearnClassificationAlgorithm.get_hyperparameter_search_space() AutoSklearnRegressionAlgorithm.get_hyperparameter_search_space() AutoSklearnPreprocessingAlgorithm.get_hyperparameter_search_space()

要了解如何创建 ConfigurationSpace 对象,请查看 github.com 上的源代码。

get_properties()

返回一个字典,它定义了在构建机器学习管道时如何使用该组件。必须指定以下字段:

  • shortnamestr

    组件的缩写

  • namestr

    组件的全称

  • handles_regressionbool

    组件是否可以处理回归数据

  • handles_classificationbool

    组件是否可以处理分类数据

  • handles_multiclassbool

    组件是否可以处理多分类数据

  • handles_multilabelbool

    组件是否可以处理多标签分类数据

  • is_deterministicbool

    组件在多次使用时,使用相同的随机种子是否给出相同的结果

  • inputtuple

    组件可以处理的输入数据类型,可以有多个值

    • autosklearn.constants.DENSE

      密集数据数组,与 autosklearn.constants.SPARSE 互斥

    • autosklearn.constants.SPARSE

      稀疏数据矩阵,与 autosklearn.constants.DENSE 互斥

    • autosklearn.constants.UNSIGNED_DATA

      无符号数据数组,表示仅有正输入,与 autosklearn.constants.SIGNED_DATA 互斥

    • autosklearn.constants.SIGNED_DATA

      有符号数据数组,表示可以有正负输入值,与 autosklearn.constants.UNSIGNED_DATA 互斥

  • outputtuple

    组件产生的输出数据类型

    • autosklearn.constants.PREDICTIONS

      预测结果,例如分类器产生的

    • autosklearn.constants.INPUT

      与输入形式相同的数据

    • autosklearn.constants.DENSE

      密集数据数组,与 autosklearn.constants.SPARSE 互斥。这意味着稀疏数据将被转换为密集表示。

    • autosklearn.constants.SPARSE

      稀疏数据矩阵,与 autosklearn.constants.DENSE 互斥。这意味着密集数据将被转换为稀疏表示

    • autosklearn.constants.UNSIGNED_DATA

      无符号数据数组,表示仅有正输入,与 autosklearn.constants.SIGNED_DATA 互斥。这允许使用只能处理正数据的算法。

    • autosklearn.constants.SIGNED_DATA

      有符号数据数组,表示可以有正负输入值,与 autosklearn.constants.UNSIGNED_DATA 互斥

分类

除了 get_properties()get_hyperparameter_search_space() 之外,您还必须实现 AutoSklearnClassificationAlgorithm.fit()AutoSklearnClassificationAlgorithm.predict() 方法。这些是 scikit-learn 预测器 API 的实现。

回归

除了 get_properties()get_hyperparameter_search_space() 之外,您还必须实现 AutoSklearnRegressionAlgorithm.fit()AutoSklearnRegressionAlgorithm.predict() 方法。这些是 scikit-learn 预测器 API 的实现。

特征预处理

除了 get_properties()get_hyperparameter_search_space() 之外,您还必须实现 AutoSklearnPreprocessingAlgorithm.fit()AutoSklearnPreprocessingAlgorithm.transform() 方法。这些是 scikit-learn 预测器 API 的实现。