Source code for boa.runner

from collections import defaultdict
from typing import Any, Dict, Iterable, Optional, Set

from ax.core.base_trial import BaseTrial, TrialStatus
from ax.core.runner import Runner
from ax.core.trial import Trial

from boa.metaclasses import RunnerRegister
from boa.utils import serialize_init_args
from boa.wrapper import BaseWrapper


class WrappedJobRunner(Runner, metaclass=RunnerRegister):
    """Ax ``Runner`` that delegates trial deployment and polling to a wrapper.

    Deployment (``run``) asks the wrapper to write the trial's configs and run
    the model; polling (``poll_trial_status``) asks the wrapper to set each
    trial's status and then groups the trial indices by that status.
    """

    def __init__(self, wrapper: Optional[BaseWrapper] = None, *args, **kwargs):
        """Create a runner bound to ``wrapper``.

        Args:
            wrapper: Object providing ``write_configs``, ``run_model`` and
                ``set_trial_status``. When ``None``, a default
                ``BaseWrapper()`` is created.
            *args: Positional arguments forwarded to ``Runner.__init__``.
            **kwargs: Keyword arguments forwarded to ``Runner.__init__``.
        """
        # Compare against None explicitly: a truthiness test (`wrapper or ...`)
        # would silently discard a caller-supplied wrapper that happens to be
        # falsy (e.g. one defining `__len__` or `__bool__`).
        self.wrapper = BaseWrapper() if wrapper is None else wrapper
        super().__init__(*args, **kwargs)

    def run(self, trial: BaseTrial) -> Dict[str, Any]:
        """Deploy a trial through the wrapper.

        Args:
            trial: The trial to deploy. Must be a ``Trial`` instance.

        Returns:
            Dict of run metadata from the deployment process; here a single
            ``"job_id"`` entry holding the trial index.

        Raises:
            ValueError: If ``trial`` is not an instance of ``Trial``.
        """
        if not isinstance(trial, Trial):
            raise ValueError("This runner only handles `Trial`.")

        self.wrapper.write_configs(trial)
        self.wrapper.run_model(trial)

        # This run metadata will be attached to trial as `trial.run_metadata`
        # by the base `Scheduler`.
        return {"job_id": trial.index}

    def poll_trial_status(self, trials: Iterable[BaseTrial]) -> Dict[TrialStatus, Set[int]]:
        """Poll the given trials once and group their indices by status.

        Required for runners used with Ax ``Scheduler``.

        NOTE: Does not need to handle waiting between polling calls while
        trials are running; this function performs a single poll.

        Args:
            trials: Trials to poll. Every trial passed in is handed to
                ``wrapper.set_trial_status``.

        Returns:
            A dictionary mapping ``TrialStatus`` to the set of trial indices
            that have the respective status at the time of the polling. Only
            statuses observed on at least one polled trial appear as keys.
        """
        status_dict = defaultdict(set)
        for trial in trials:
            # Let the wrapper update the trial first; the status is read from
            # `trial.status` only after this call.
            self.wrapper.set_trial_status(trial)
            status_dict[trial.status].add(trial.index)
        return status_dict

    def to_dict(self) -> dict:
        """Convert this runner to a dictionary.

        The ``wrapper`` field is excluded from the serialized properties, and
        a ``"__type"`` entry holding this class's name is added.

        Returns:
            The dictionary produced by ``serialize_init_args`` plus the
            ``"__type"`` entry.
        """
        parents = self.__class__.mro()[1:]  # index 0 is the class itself

        properties = serialize_init_args(self, parents=parents, match_private=True, exclude_fields=["wrapper"])

        properties["__type"] = self.__class__.__name__
        return properties