
Dask Runner

smac.runner.dask_runner #

DaskParallelRunner #

DaskParallelRunner(
    single_worker: AbstractRunner,
    patience: int = 5,
    dask_client: Client | None = None,
)

Bases: AbstractRunner

Interface to submit and collect jobs in a distributed fashion. DaskParallelRunner is intended to comply with the bridge design pattern. Nevertheless, to reduce the amount of code shared between the single-threaded and the parallel implementation, DaskParallelRunner wraps a BaseRunner object, which is then executed in parallel on n_workers.

This class is then constructed by passing an AbstractRunner that implements a run method and is capable of executing it in a serial fashion. Next, this wrapper class uses dask to initialize N AbstractRunner instances, which actively wait for a TrialInfo to produce a RunInfo object.

More precisely, the work model is as follows:

  1. The intensifier dictates "what" to run (configuration/instance/seed) via a TrialInfo object.
  2. An abstract runner takes this TrialInfo object and launches the task via submit_trial. In the case of DaskParallelRunner, the n_workers receive a pickled copy of DaskParallelRunner.single_worker, each with a run method coming from DaskParallelRunner.single_worker.run().
  3. The TrialInfo objects are run in a distributed fashion, and their results are available locally on each worker. The results are collected by iter_results and then passed on to the SMBO.
  4. Exceptions are also available locally on each worker and need to be collected.

Dask works with Future objects, which are managed via DaskParallelRunner.client.
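
The cycle above can be sketched as a small driver loop: trials go in through submit_trial and finished (TrialInfo, TrialValue) pairs come back through iter_results. This is only an illustrative sketch based on the description on this page; how the TrialInfo objects and the runner itself are constructed is omitted and assumed to happen elsewhere.

def run_trials(runner, trial_infos):
    # 1. + 2. The intensifier proposes trials; submit_trial ships each one to a dask worker.
    for trial_info in trial_infos:
        runner.submit_trial(trial_info)

    # 3. Results are produced locally on the workers and collected by the main node.
    for trial_info, trial_value in runner.iter_results():
        print(trial_info.config, trial_value.cost, trial_value.status)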

Parameters #

single_worker : AbstractRunner
    A runner to run in a distributed fashion. Will be distributed using n_workers.
patience : int, defaults to 5
    How long to wait (in seconds) for a worker to become available if one fails.
dask_client : Client | None, defaults to None
    A user-created dask client, which can be used to start a dask cluster and then attach SMAC to it. If provided explicitly, this client will not be closed automatically and has to be closed manually. If none is provided (default), a local client will be created for you and closed upon completion.

Source code in smac/runner/dask_runner.py
def __init__(
    self,
    single_worker: AbstractRunner,
    patience: int = 5,
    dask_client: Client | None = None,
):
    super().__init__(
        scenario=single_worker._scenario,
        required_arguments=single_worker._required_arguments,
    )

    # The single worker to hold on to and call run on
    self._single_worker = single_worker

    # The list of futures that dask will use to indicate in progress runs
    self._pending_trials: list[Future] = []

    # Dask related variables
    self._scheduler_file: Path | None = None
    self._patience = patience

    self._client: Client
    self._close_client_at_del: bool

    if dask_client is None:
        dask.config.set({"distributed.worker.daemon": False})
        self._close_client_at_del = True
        self._client = Client(
            n_workers=self._scenario.n_workers,
            processes=True,
            threads_per_worker=1,
            local_directory=str(self._scenario.output_directory),
        )

        if self._scenario.output_directory is not None:
            self._scheduler_file = Path(self._scenario.output_directory, ".dask_scheduler_file")
            self._client.write_scheduler_file(scheduler_file=str(self._scheduler_file))
    else:
        # We just use their set up
        self._client = dask_client
        self._close_client_at_del = False
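
As a hedged usage sketch of the dask_client parameter described above: when a client is passed in explicitly, SMAC attaches to the existing cluster and leaves closing the client to the caller. The scheduler address is a placeholder and my_runner stands for any already constructed AbstractRunner; neither is defined on this page.

from dask.distributed import Client

from smac.runner.dask_runner import DaskParallelRunner

# Connect to a cluster that is managed outside of SMAC (address is a placeholder).
client = Client("tcp://scheduler-address:8786")

runner = DaskParallelRunner(
    single_worker=my_runner,  # any AbstractRunner instance, constructed elsewhere
    patience=5,               # seconds to wait for a worker if none is available
    dask_client=client,       # provided explicitly, so SMAC will not close it
)

# ... submit trials and collect results ...

# Since the client was provided explicitly, closing it is the caller's responsibility.
client.close()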

meta property #

meta: dict[str, Any]

Returns the metadata of the created object.

__del__ #

__del__() -> None

Makes sure that when this object gets deleted, the client is terminated. This is only done if the client was created by the dask runner.

Source code in smac/runner/dask_runner.py
def __del__(self) -> None:
    """Makes sure that when this object gets deleted, the client is terminated. This
    is only done if the client was created by the dask runner.
    """
    if self._close_client_at_del:
        self.close()

close #

close(force: bool = False) -> None

Closes the client.

Source code in smac/runner/dask_runner.py
def close(self, force: bool = False) -> None:
    """Closes the client."""
    if self._close_client_at_del or force:
        self._client.close()

count_available_workers #

count_available_workers() -> int

Total number of workers available. This number is dynamic as more resources can be allocated.

Source code in smac/runner/dask_runner.py
def count_available_workers(self) -> int:
    """Total number of workers available. This number is dynamic as more resources
    can be allocated.
    """
    return sum(self._client.nthreads().values()) - len(self._pending_trials)

run_wrapper #

run_wrapper(
    trial_info: TrialInfo,
    **dask_data_to_scatter: dict[str, Any]
) -> tuple[TrialInfo, TrialValue]

Wrapper around run() to execute and check the execution of a given configuration. This function encapsulates common handling/processing, so that the implementation of run() is simplified.

Parameters #

trial_info : RunInfo
    Object that contains enough information to execute a configuration run in isolation.
dask_data_to_scatter : dict[str, Any]
    When a user scatters data from their local process to the distributed network, this data is distributed in a round-robin fashion, grouped by number of cores. Roughly speaking, this data can be kept in memory so it does not have to be (de-)serialized every time the target function is executed with a big dataset. This argument is useful, for example, when the target function uses a large dataset shared across all target function calls.

Returns #

info : TrialInfo
    An object containing the configuration launched.
value : TrialValue
    Contains information about the status/performance of the configuration.

Source code in smac/runner/abstract_runner.py
def run_wrapper(
    self, trial_info: TrialInfo, **dask_data_to_scatter: dict[str, Any]
) -> tuple[TrialInfo, TrialValue]:
    """Wrapper around run() to execute and check the execution of a given config.
    This function encapsulates common
    handling/processing, so that run() implementation is simplified.

    Parameters
    ----------
    trial_info : RunInfo
        Object that contains enough information to execute a configuration run in isolation.
    dask_data_to_scatter: dict[str, Any]
        When a user scatters data from their local process to the distributed network,
        this data is distributed in a round-robin fashion grouping by number of cores.
        Roughly speaking, we can keep this data in memory and then we do not have to (de-)serialize the data
        every time we would like to execute a target function with a big dataset.
        For example, when your target function has a big dataset shared across all the target function,
        this argument is very useful.

    Returns
    -------
    info : TrialInfo
        An object containing the configuration launched.
    value : TrialValue
        Contains information about the status/performance of config.
    """
    start = time.time()
    cpu_time = time.process_time()
    try:
        status, cost, runtime, cpu_time, additional_info = self.run(
            config=trial_info.config,
            instance=trial_info.instance,
            budget=trial_info.budget,
            seed=trial_info.seed,
            **dask_data_to_scatter,
        )
    except Exception as e:
        status = StatusType.CRASHED
        cost = self._crash_cost
        cpu_time = time.process_time() - cpu_time
        runtime = time.time() - start

        # Add context information to the error message
        exception_traceback = traceback.format_exc()
        error_message = repr(e)
        additional_info = {
            "traceback": exception_traceback,
            "error": error_message,
        }

    end = time.time()

    # Catch NaN or inf
    if not np.all(np.isfinite(cost)):
        logger.warning(
            "Target function returned infinity or nothing at all. Result is treated as CRASHED"
            f" and cost is set to {self._crash_cost}."
        )

        if "traceback" in additional_info:
            logger.warning(f"Traceback: {additional_info['traceback']}\n")

        status = StatusType.CRASHED

    if status == StatusType.CRASHED:
        cost = self._crash_cost

    trial_value = TrialValue(
        status=status,
        cost=cost,
        time=runtime,
        cpu_time=cpu_time,
        additional_info=additional_info,
        starttime=start,
        endtime=end,
    )

    return trial_info, trial_value
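
A brief, illustrative sketch of consuming the tuple returned by run_wrapper; the fields read here mirror the TrialValue constructed in the source above, while runner and trial_info are assumed to exist already.

# Execute a single trial and inspect its outcome.
info, value = runner.run_wrapper(trial_info)

# status, cost, time and cpu_time correspond to the TrialValue fields built above.
print(f"status:   {value.status}")
print(f"cost:     {value.cost}")
print(f"walltime: {value.time}s, cpu time: {value.cpu_time}s")

# On a crash, the traceback is preserved under additional_info (see the except branch above).
if "traceback" in value.additional_info:
    print(value.additional_info["traceback"])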

submit_trial #

submit_trial(
    trial_info: TrialInfo,
    **dask_data_to_scatter: dict[str, Any]
) -> None

This function submits a configuration embedded in a trial_info object and uses one of the workers to produce a result locally on each worker.

The execution of a configuration follows this procedure:

  1. The SMBO/intensifier generates a TrialInfo.
  2. SMBO calls submit_trial so that a worker launches the trial_info.
  3. submit_trial internally calls self.run(). It does so via a call to run_wrapper, which contains common code that any run method would otherwise have to implement.

All results are only available locally on each worker, so the main node needs to collect them.

Parameters #

trial_info : TrialInfo
    An object containing the configuration launched.
dask_data_to_scatter : dict[str, Any]
    When a user scatters data from their local process to the distributed network, this data is distributed in a round-robin fashion, grouped by number of cores. Roughly speaking, this data can be kept in memory so it does not have to be (de-)serialized every time the target function is executed with a big dataset. This argument is useful, for example, when the target function uses a large dataset shared across all target function calls.

Source code in smac/runner/dask_runner.py
def submit_trial(self, trial_info: TrialInfo, **dask_data_to_scatter: dict[str, Any]) -> None:
    """This function submits a configuration embedded in a ``trial_info`` object, and uses one of
    the workers to produce a result locally to each worker.

    The execution of a configuration follows this procedure:

    #. The SMBO/intensifier generates a `TrialInfo`.
    #. SMBO calls `submit_trial` so that a worker launches the `trial_info`.
    #. `submit_trial` internally calls ``self.run()``. It does so via a call to `run_wrapper` which contains common
       code that any `run` method will otherwise have to implement.

    All results will be only available locally to each worker, so the main node needs to collect them.

    Parameters
    ----------
    trial_info : TrialInfo
        An object containing the configuration launched.

    dask_data_to_scatter: dict[str, Any]
        When a user scatters data from their local process to the distributed network,
        this data is distributed in a round-robin fashion grouping by number of cores.
        Roughly speaking, we can keep this data in memory and then we do not have to (de-)serialize the data
        every time we would like to execute a target function with a big dataset.
        For example, when your target function has a big dataset shared across all the target function,
        this argument is very useful.
    """
    # Check for resources or block till one is available
    if self.count_available_workers() <= 0:
        logger.debug("No worker available. Waiting for one to be available...")
        wait(self._pending_trials, return_when="FIRST_COMPLETED")
        self._process_pending_trials()

    # Check again to make sure that there are resources
    if self.count_available_workers() <= 0:
        logger.warning("No workers are available. This could mean workers crashed. Waiting for new workers...")
        time.sleep(self._patience)
        if self.count_available_workers() <= 0:
            raise RuntimeError(
                "Tried to execute a job, but no worker was ever available."
                "This likely means that a worker crashed or no workers were properly configured."
            )

    # At this point we can submit the job
    trial = self._client.submit(self._single_worker.run_wrapper, trial_info=trial_info, **dask_data_to_scatter)
    self._pending_trials.append(trial)
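
To illustrate dask_data_to_scatter, the following hedged sketch scatters a large array once and hands the resulting future to submit_trial as a keyword argument. The keyword name dataset, and the runner, client and trial_info objects, are placeholders for illustration and are not defined on this page; the target function is assumed to accept a matching keyword.

import numpy as np

# `client` is assumed to be the same dask Client the DaskParallelRunner was built with.
big_dataset = np.random.rand(100_000, 128)

# Scatter once; broadcast=True places a copy on every worker, so repeated trials
# reuse the in-memory data instead of re-serializing it for each submission.
scattered = client.scatter(big_dataset, broadcast=True)

# The keyword is forwarded via client.submit to single_worker.run_wrapper (see the source above).
runner.submit_trial(trial_info, dataset=scattered)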