Creating model wrappers

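To create a model wrapper, create a child class of boa’s BaseWrapper class. BaseWrapper defines the core functions that must be defined in your model wrapper: write_configs(), run_model(), set_trial_status(), and fetch_trial_data(). Each one is covered in the examples below.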

Apart from these core functions, your model wrapper can define additional functions as needed (for example, to help with formatting or scaling model outputs).

See FETCH3’s Wrapper for an example.
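The overall shape of a wrapper can be sketched without boa installed. In the sketch below, `BaseWrapperSketch` is a hypothetical stand-in for boa’s BaseWrapper that only mirrors the four core methods named above, and `MyModelWrapper` and its `scale_outputs` helper are also hypothetical. In boa the methods receive an Ax Trial; here a plain dict of parameters stands in for it, so check boa’s API reference for the real constructor and signatures.

```python
from abc import ABC, abstractmethod


class BaseWrapperSketch(ABC):
    """Hypothetical stand-in for boa's BaseWrapper; it is NOT the real class."""

    @abstractmethod
    def write_configs(self, trial): ...

    @abstractmethod
    def run_model(self, trial): ...

    @abstractmethod
    def set_trial_status(self, trial): ...

    @abstractmethod
    def fetch_trial_data(self, trial, *args, **kwargs): ...


class MyModelWrapper(BaseWrapperSketch):
    """Hypothetical wrapper; `trial` is a plain dict of parameters, not an Ax Trial."""

    def __init__(self):
        self.config = None
        self.status = "running"

    def write_configs(self, trial):
        # A real wrapper would write config files; this one keeps them in memory
        self.config = {"parameters": dict(trial)}

    def run_model(self, trial):
        self.write_configs(trial)

    def set_trial_status(self, trial):
        # A real wrapper would inspect log files or job state here
        self.status = "completed"

    def fetch_trial_data(self, trial, *args, **kwargs):
        return {"y_pred": self.scale_outputs([1.0, 2.0])}

    # An additional, non-core helper, e.g. for scaling model outputs
    @staticmethod
    def scale_outputs(values, factor=10.0):
        return [v * factor for v in values]
```

Because the sketch uses `abc`, instantiating a subclass that leaves one of the four methods undefined raises `TypeError`; whether boa’s own BaseWrapper enforces this the same way is not stated here.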

Example wrapper functions

The BaseWrapper.write_configs() function

This function writes out the configuration files for an individual optimization trial run (i.e., your model’s configuration files), or dynamically writes a run script that starts an optimization trial run.

This function is how boa hands your model the new set of parameters to run in each trial.

FETCH3’s wrapper provides a simple example of this function for a case where the model’s parameters simply need to be written to a YAML file:

import yaml


def write_configs(trial_dir, parameters, model_options):
    """
    Write model configuration file for each trial (model run). This is the config file used by FETCH3
    for the model run.

    The config file is written as ``config.yml`` inside the trial directory.

    Parameters
    ----------
    trial_dir : Path
        Trial directory where the config file will be written
    parameters : dict
        Model parameters for the trial, generated by the Ax client
    model_options : dict
        Model options loaded from the experiment config yml file.

    Returns
    -------
    str
        Path for the config file
    """
    with open(trial_dir / "config.yml", "w") as f:
        # Combine the model options from the experiment config with the
        # parameters Ax generated for this trial
        config_dict = {"model_options": model_options, "parameters": parameters}
        yaml.dump(config_dict, f)
        return f.name
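To see the hand-off from optimizer to model end to end without PyYAML, the sketch below swaps YAML for the standard-library `json` module (an assumption for illustration only; FETCH3 itself uses YAML). `write_trial_config` plays the role of write_configs, and `read_trial_config` stands in for what the model’s own entry point would do with the file path it is given. Both function names and the zero-padded directory naming are hypothetical.

```python
import json
import tempfile
from pathlib import Path


def write_trial_config(experiment_dir: Path, trial_index: int,
                       parameters: dict, model_options: dict) -> Path:
    """Write one trial's config into its own directory and return the file path."""
    # One directory per trial keeps runs from overwriting each other's files
    trial_dir = experiment_dir / str(trial_index).zfill(6)
    trial_dir.mkdir(parents=True, exist_ok=True)
    config_path = trial_dir / "config.json"
    config = {"model_options": model_options, "parameters": parameters}
    config_path.write_text(json.dumps(config, indent=2))
    return config_path


def read_trial_config(config_path: Path) -> dict:
    """Model side: load the options and the parameters chosen for this trial."""
    return json.loads(config_path.read_text())
```

Each trial gets a fresh file, so the model never needs to know it is being driven by an optimizer; it just reads whatever configuration it is pointed at.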

The palm_wrapper, which wraps the PALM model, is an example of a model with more complicated configuration requirements. Here, the parameters are written to a YAML file, and a batch job script must also be written for each optimization trial run.

    def write_configs(self, trial: Trial) -> None:
        """
        This function is usually used to write out the configuration files used
        in an individual optimization trial run, or to dynamically write a run
        script to start an optimization trial run.

        Parameters
        ----------
        trial : BaseTrial
        """
        trial_config = copy.deepcopy(self.config)
        job_name = zfilled_trial_index(trial.index)
        job_output_dir = self.experiment_dir / job_name
        job_output_dir.mkdir(parents=True)

        run_time = (
            trial_config["model_options"]["output_end_time"]
            * trial_config["model_options"]["palmrun_walltime_scalar"]
        )
        data_analyses_time = (
            trial_config["model_options"]["output_end_time"]
            - trial_config["model_options"]["output_start_time"]
        ) * trial_config["model_options"]["data_analyse_walltime_scalar"]
        trial_config_path = job_output_dir / "trial_config.yaml"
        job_script_path = (job_output_dir / "slurm_job.sh").resolve()

        trial_config["parameters"] = trial.arm.parameters
        trial_config["model_options"]["config_path"] = str(trial_config_path)
        trial_config["model_options"]["job_script_path"] = str(job_script_path)
        trial_config["model_options"]["job_name"] = job_name
        trial_config["model_options"]["log_file"] = str(job_output_dir / f"{job_name}_%j.log")
        trial_config["model_options"]["job_output_dir"] = str(job_output_dir)
        trial_config["model_options"]["run_time"] = int(run_time)
        trial_config["model_options"]["data_analyses_time"] = data_analyses_time
        trial_config["model_options"]["batch_time"] = int(run_time + data_analyses_time)

        self.paths_by_trial[trial.index] = dict(trial_config_path=trial_config_path,
                                                job_script_path=job_script_path)

        with open(JOB_SCRIPT_PATH) as template:
            job_script = template.read()
        jinja_env = jinja2.Environment(
            loader=jinja2.BaseLoader(),
        )
        template = jinja_env.from_string(job_script)
        job_script = template.render(**trial_config["model_options"])

        with open(job_script_path, "w") as f:
            f.write(job_script)

        with open(trial_config_path, 'w') as f:
            yaml.dump(trial_config, f)
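The walltime arithmetic and template rendering in this method can be tried out with the standard library alone. The sketch below substitutes `string.Template` for jinja2 (so placeholders are written `$name` rather than `{{ name }}`); the template text, the `run_model` command inside it, and the option values are hypothetical, and the real template lives at JOB_SCRIPT_PATH in the palm_wrapper repository. The time units are whatever your model options use; the sketch only mirrors the formulas above.

```python
from string import Template

# Hypothetical job-script template. --job-name and --output are real sbatch
# options; the command on the last line is illustrative only.
JOB_TEMPLATE = Template(
    "#!/bin/bash\n"
    "#SBATCH --job-name=$job_name\n"
    "#SBATCH --output=$log_file\n"
    "run_model --config $config_path\n"
)


def batch_time(options: dict) -> int:
    """Model run time plus post-run data analysis time (same formulas as above)."""
    run_time = options["output_end_time"] * options["palmrun_walltime_scalar"]
    data_analyses_time = (
        options["output_end_time"] - options["output_start_time"]
    ) * options["data_analyse_walltime_scalar"]
    return int(run_time + data_analyses_time)


def render_job_script(options: dict) -> str:
    # substitute() raises KeyError if the template uses an option that is missing
    return JOB_TEMPLATE.substitute(options)
```

For example, with output_end_time = 3600, output_start_time = 1800, palmrun_walltime_scalar = 2.0 and data_analyse_walltime_scalar = 0.5, the batch time is 3600 × 2.0 + (3600 − 1800) × 0.5 = 7200 + 900 = 8100.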

The BaseWrapper.run_model() function

This function defines how to start a run of your model. In most cases, it can be as simple as launching a Python or shell script that starts the model run.

FETCH3’s wrapper provides an example of this function for the case where the model run is started by running a Python script with command-line arguments:

    def run_model(self, trial: Trial):

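        # Create this trial's output directory and write its config file
        trial_dir = make_trial_dir(self.experiment_dir, trial.index)
        config_path = write_configs(trial_dir, trial.arm.parameters, self.model_settings)

        model_dir = self.ex_settings["model_dir"]

        # Pass the arguments as a list so paths containing spaces stay intact
        args = [
            "python", "main.py",
            "--config_path", str(config_path),
            "--data_path", str(self.ex_settings["data_path"]),
            "--output_path", str(trial_dir),
        ]
        # cwd= runs the model from model_dir without changing the optimizer's
        # own working directory (os.chdir would affect every later trial)
        popen = subprocess.Popen(
            args, cwd=model_dir, stdout=subprocess.PIPE, universal_newlines=True
        )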
        self._processes.append(popen)
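Because run_model only starts the model (Popen returns immediately) and the trial’s status is checked later by set_trial_status, the launch step can be sketched on its own. The hypothetical `launch_model` below runs a stand-in "model" with the current interpreter, sets the working directory with `cwd=` instead of calling `os.chdir`, and sends output to a log file rather than an unread `subprocess.PIPE`; with an unread pipe, a chatty model can block once the operating system’s pipe buffer fills. The log file name `model.log` is an assumption.

```python
import subprocess
import sys
from pathlib import Path


def launch_model(trial_dir: Path, model_dir: Path, model_args: list) -> subprocess.Popen:
    """Start a model run in the background and return the process handle without waiting."""
    log_path = trial_dir / "model.log"
    with open(log_path, "w") as log:
        # The child gets its own copy of the file descriptor, so closing `log`
        # when this block exits does not stop the child from writing to it
        return subprocess.Popen(
            [sys.executable, *model_args],
            cwd=model_dir,             # affects only the child process
            stdout=log,
            stderr=subprocess.STDOUT,  # keep errors in the same log
        )
```

A set_trial_status implementation can then read `model.log` on each poll, as the log-based examples further down do.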

The palm_wrapper takes the batch job script written by write_configs and submits it with sbatch, starting a job on an HPC cluster. The job script also uses the YAML file written by write_configs.

    def run_model(self, trial: Trial):
        """
        Runs a model by deploying a given trial.

        Parameters
        ----------
        trial : BaseTrial
        """
        trial_config = load_yaml(self.paths_by_trial[trial.index]["trial_config_path"], normalize=False)

        job_script_path = trial_config["model_options"]["job_script_path"]
        cmd = f"sbatch {job_script_path}"

        args = cmd.split()
        subprocess.Popen(
            args, stdout=subprocess.PIPE, universal_newlines=True
        )
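As written, run_model discards sbatch’s output. If you want to keep the SLURM job ID (for example, to cancel the job with scancel or to locate its log), one option is to capture and parse it. By default sbatch prints `Submitted batch job <id>` on success, and SLURM replaces `%j` in an `--output` file name with the job ID, so a log_file such as `000001_%j.log` appears on disk as `000001_<id>.log`. The helpers below are hypothetical, assume the job script passes log_file to `#SBATCH --output`, and only parse text; they do not call SLURM.

```python
import re
from pathlib import Path

# sbatch's default success message, e.g. "Submitted batch job 123456"
_SBATCH_RE = re.compile(r"Submitted batch job (\d+)")


def parse_job_id(sbatch_stdout: str) -> str:
    """Extract the job ID from sbatch's stdout; raise if none is found."""
    match = _SBATCH_RE.search(sbatch_stdout)
    if match is None:
        raise RuntimeError(f"No job ID in sbatch output: {sbatch_stdout!r}")
    return match.group(1)


def resolve_log_file(log_file_pattern: str, job_id: str) -> Path:
    """Replace SLURM's %j filename placeholder with the actual job ID."""
    return Path(log_file_pattern.replace("%j", job_id))
```

In run_model, this could be fed by `result = subprocess.run(["sbatch", job_script_path], capture_output=True, text=True, check=True)` followed by `parse_job_id(result.stdout)`; sbatch returns as soon as the job is queued, so waiting for it does not wait for the model. sbatch’s `--parsable` option, which prints only the ID, is an alternative to the regular expression.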

The BaseWrapper.set_trial_status() function

Marks the status of a trial to reflect the status of the model run associated with that trial.

This function defines the criteria for determining the status of the model run for a trial (e.g., whether the model run has completed, is still running, or has failed). Each trial is polled periodically to determine its status.

The approach for determining the trial status depends on the structure of the particular model and its outputs. One way to do this is to check the model’s log files.

In both examples below, the first from FETCH3’s wrapper and the second from the palm_wrapper, the trial status is determined by searching the model’s log file for specific text:

    def set_trial_status(self, trial: Trial) -> None:
        """
        Mark the trial as failed or completed based on markers in the FETCH3 log file.
        If neither marker is present yet, the model is still running and the trial is left as is.
        """
        log_file = get_trial_dir(self.experiment_dir, trial.index) / "fetch3.log"

        if log_file.exists():
            with open(log_file, "r") as f:
                contents = f.read()
            if "Error completing Run! Reason:" in contents:
                trial.mark_failed()
            elif "run complete" in contents:
                trial.mark_completed()

    def set_trial_status(self, trial: Trial) -> None:
        """
        The trial gets polled from time to time to see if it is completed, failed, still running,
        etc. This marks the trial as one of those options based on some criteria of the model.
        If the model is still running, don't do anything with the trial.

        Parameters
        ----------
        trial : BaseTrial

        Examples
        --------
        trial.mark_completed()
        trial.mark_failed()
        trial.mark_abandoned()
        trial.mark_early_stopped()

        See Also
        --------
        ax.core.base_trial.TrialStatus

        """
        trial_config = load_yaml(self.paths_by_trial[trial.index]["trial_config_path"], normalize=False)

        log_file = Path(trial_config["model_options"]["log_file"])

        if log_file.exists():
            with open(log_file, "r") as f:
                contents = f.read()
            if "palmrun crashed" in contents:
                trial.mark_abandoned()
            elif "error:" in contents:
                trial.mark_failed()
            elif "all OUTPUT-files saved" in contents:
                trial.mark_completed()
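When a log can contain more than one marker, the order of the checks decides which status wins, and a trial should be marked at most once per poll. One way to make that explicit is an ordered table of markers that stops at the first match. The sketch below uses the PALM markers shown above; `classify_log` and `PALM_RULES` are hypothetical names, and the function only returns a status string, leaving the actual `trial.mark_*()` call to the wrapper.

```python
from typing import Optional, Sequence, Tuple

# Ordered (marker, status) rules: the first marker found in the log wins.
# The marker strings mirror the PALM example above.
PALM_RULES: Sequence[Tuple[str, str]] = (
    ("palmrun crashed", "abandoned"),
    ("error:", "failed"),
    ("all OUTPUT-files saved", "completed"),
)


def classify_log(contents: str,
                 rules: Sequence[Tuple[str, str]] = PALM_RULES) -> Optional[str]:
    """Return the status for the first matching marker, or None if still running."""
    for marker, status in rules:
        if marker in contents:
            return status
    return None  # no marker yet: leave the trial untouched
```

Inside set_trial_status, a non-None result can be mapped onto the Ax methods listed in the docstring above, e.g. `getattr(trial, f"mark_{status}")()`, which calls `mark_abandoned`, `mark_failed`, or `mark_completed`.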

The BaseWrapper.fetch_trial_data() function

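Retrieves the trial data (i.e., model outputs) and prepares it for the metric(s) used in the objective function. The return value must be a dictionary whose keys match the keys (argument names) expected by the metric function used in the objective function, for example y_pred and y_true in the FETCH3 example. The first example below is from FETCH3’s wrapper, the second from the palm_wrapper.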

    def fetch_trial_data(self, trial: Trial, *args, **kwargs):

        modelfile = (
            get_trial_dir(self.experiment_dir, trial.index) / self.ex_settings["output_fname"]
        )

        y_pred, y_true = get_model_obs(
            modelfile,
            self.ex_settings["obsfile"],
            self.ex_settings,
            self.model_settings,
            trial.arm.parameters,
        )
        return dict(y_pred=y_pred, y_true=y_true)

    def fetch_trial_data(self, trial: Trial, metric_properties: dict, metric_name: str,  *args, **kwargs):
        """
        Retrieves the trial data and prepares it for the metric(s) used in the objective
        function.

        For example, for a case where you are minimizing the error between a model and observations, using RMSE as a
        metric, this function would load the model output and the corresponding observation data that will be passed to
        the RMSE metric.

        The return value of this function is a dictionary, with keys that match the keys
        of the metric used in the objective function.

        Parameters
        ----------
        trial : Trial
        metric_properties: dict
        metric_name: str

        Returns
        -------
        dict
            A dictionary whose keys match the keys of the metric function
            used in the objective
        """
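        # Load the config saved by write_configs for this trial; job_output_dir
        # was stored as a str, so wrap it in Path before joining
        trial_config = load_yaml(self.paths_by_trial[trial.index]["trial_config_path"], normalize=False)
        job_output_dir = Path(trial_config["model_options"]["job_output_dir"])
        data_filepath = job_output_dir / "output.json"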

        with open(data_filepath, 'r') as f:
            data = json.load(f)
        output = np.array(data["output"])
        return dict(a=output)
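The requirement that the returned keys match the metric’s keys amounts to the dictionary being unpacked as keyword arguments into the metric. The sketch below assumes that calling convention (`metric(**data)`) for illustration; it is not boa’s internal code, and `evaluate` is a hypothetical helper. RMSE is written out by hand so the example needs only the standard library.

```python
import math
from typing import Callable, Sequence


def rmse(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """Root-mean-square error between observations and model output."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    return math.sqrt(sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true))


def evaluate(metric: Callable[..., float], trial_data: dict) -> float:
    # The dict from fetch_trial_data is unpacked as keyword arguments, so its
    # keys must match the metric's parameter names (here y_true and y_pred)
    return metric(**trial_data)
```

With FETCH3’s `dict(y_pred=..., y_true=...)` this works directly; with the PALM example’s `dict(a=output)`, the metric would need a parameter named `a`, and passing it to `rmse` raises `TypeError`.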

Full examples

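import subprocess

# BaseWrapper, Trial, make_trial_dir, get_trial_dir, write_configs and
# get_model_obs come from boa, Ax and the fetch3_nhl repository (see the
# source link below).


class Fetch3Wrapper(BaseWrapper):
    def __init__(self, ex_settings, model_settings, experiment_dir):
        self.ex_settings = ex_settings
        self.model_settings = model_settings
        self.experiment_dir = experiment_dir
        # Per-instance list of launched model processes (a class-level list
        # would be shared by every wrapper instance)
        self._processes = []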

    def run_model(self, trial: Trial):

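        # Create this trial's output directory and write its config file
        trial_dir = make_trial_dir(self.experiment_dir, trial.index)
        config_path = write_configs(trial_dir, trial.arm.parameters, self.model_settings)

        model_dir = self.ex_settings["model_dir"]

        # Pass the arguments as a list so paths containing spaces stay intact
        args = [
            "python", "main.py",
            "--config_path", str(config_path),
            "--data_path", str(self.ex_settings["data_path"]),
            "--output_path", str(trial_dir),
        ]
        # cwd= runs the model from model_dir without changing the optimizer's
        # own working directory (os.chdir would affect every later trial)
        popen = subprocess.Popen(
            args, cwd=model_dir, stdout=subprocess.PIPE, universal_newlines=True
        )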
        self._processes.append(popen)

    def set_trial_status(self, trial: Trial) -> None:
        """
        Mark the trial as failed or completed based on markers in the FETCH3 log file.
        If neither marker is present yet, the model is still running and the trial is left as is.
        """
        log_file = get_trial_dir(self.experiment_dir, trial.index) / "fetch3.log"

        if log_file.exists():
            with open(log_file, "r") as f:
                contents = f.read()
            if "Error completing Run! Reason:" in contents:
                trial.mark_failed()
            elif "run complete" in contents:
                trial.mark_completed()

    def fetch_trial_data(self, trial: Trial, *args, **kwargs):

        modelfile = (
            get_trial_dir(self.experiment_dir, trial.index) / self.ex_settings["output_fname"]
        )

        y_pred, y_true = get_model_obs(
            modelfile,
            self.ex_settings["obsfile"],
            self.ex_settings,
            self.model_settings,
            trial.arm.parameters,
        )
        return dict(y_pred=y_pred, y_true=y_true)

link to source: jemissik/fetch3_nhl

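import copy
import json
import subprocess
from pathlib import Path

import jinja2
import numpy as np
import yaml

# BaseWrapper, Trial, load_yaml, zfilled_trial_index and JOB_SCRIPT_PATH come
# from boa, Ax and the palm_wrapper repository (see the source link below).


class Wrapper(BaseWrapper):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.paths_by_trial = {}
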
    def write_configs(self, trial: Trial) -> None:
        """
        This function is usually used to write out the configuration files used
        in an individual optimization trial run, or to dynamically write a run
        script to start an optimization trial run.

        Parameters
        ----------
        trial : BaseTrial
        """
        trial_config = copy.deepcopy(self.config)
        job_name = zfilled_trial_index(trial.index)
        job_output_dir = self.experiment_dir / job_name
        job_output_dir.mkdir(parents=True)

        run_time = (
            trial_config["model_options"]["output_end_time"]
            * trial_config["model_options"]["palmrun_walltime_scalar"]
        )
        data_analyses_time = (
            trial_config["model_options"]["output_end_time"]
            - trial_config["model_options"]["output_start_time"]
        ) * trial_config["model_options"]["data_analyse_walltime_scalar"]
        trial_config_path = job_output_dir / "trial_config.yaml"
        job_script_path = (job_output_dir / "slurm_job.sh").resolve()

        trial_config["parameters"] = trial.arm.parameters
        trial_config["model_options"]["config_path"] = str(trial_config_path)
        trial_config["model_options"]["job_script_path"] = str(job_script_path)
        trial_config["model_options"]["job_name"] = job_name
        trial_config["model_options"]["log_file"] = str(job_output_dir / f"{job_name}_%j.log")
        trial_config["model_options"]["job_output_dir"] = str(job_output_dir)
        trial_config["model_options"]["run_time"] = int(run_time)
        trial_config["model_options"]["data_analyses_time"] = data_analyses_time
        trial_config["model_options"]["batch_time"] = int(run_time + data_analyses_time)

        self.paths_by_trial[trial.index] = dict(trial_config_path=trial_config_path,
                                                job_script_path=job_script_path)

        with open(JOB_SCRIPT_PATH) as template:
            job_script = template.read()
        jinja_env = jinja2.Environment(
            loader=jinja2.BaseLoader(),
        )
        template = jinja_env.from_string(job_script)
        job_script = template.render(**trial_config["model_options"])

        with open(job_script_path, "w") as f:
            f.write(job_script)

        with open(trial_config_path, 'w') as f:
            yaml.dump(trial_config, f)

    def run_model(self, trial: Trial):
        """
        Runs a model by deploying a given trial.

        Parameters
        ----------
        trial : BaseTrial
        """
        trial_config = load_yaml(self.paths_by_trial[trial.index]["trial_config_path"], normalize=False)

        job_script_path = trial_config["model_options"]["job_script_path"]
        cmd = f"sbatch {job_script_path}"

        args = cmd.split()
        subprocess.Popen(
            args, stdout=subprocess.PIPE, universal_newlines=True
        )

    def set_trial_status(self, trial: Trial) -> None:
        """
        The trial gets polled from time to time to see if it is completed, failed, still running,
        etc. This marks the trial as one of those options based on some criteria of the model.
        If the model is still running, don't do anything with the trial.

        Parameters
        ----------
        trial : BaseTrial

        Examples
        --------
        trial.mark_completed()
        trial.mark_failed()
        trial.mark_abandoned()
        trial.mark_early_stopped()

        See Also
        --------
        ax.core.base_trial.TrialStatus

        """
        trial_config = load_yaml(self.paths_by_trial[trial.index]["trial_config_path"], normalize=False)

        log_file = Path(trial_config["model_options"]["log_file"])

        if log_file.exists():
            with open(log_file, "r") as f:
                contents = f.read()
            if "palmrun crashed" in contents:
                trial.mark_abandoned()
            elif "error:" in contents:
                trial.mark_failed()
            elif "all OUTPUT-files saved" in contents:
                trial.mark_completed()

    def fetch_trial_data(self, trial: Trial, metric_properties: dict, metric_name: str,  *args, **kwargs):
        """
        Retrieves the trial data and prepares it for the metric(s) used in the objective
        function.

        For example, for a case where you are minimizing the error between a model and observations, using RMSE as a
        metric, this function would load the model output and the corresponding observation data that will be passed to
        the RMSE metric.

        The return value of this function is a dictionary, with keys that match the keys
        of the metric used in the objective function.

        Parameters
        ----------
        trial : Trial
        metric_properties: dict
        metric_name: str

        Returns
        -------
        dict
            A dictionary whose keys match the keys of the metric function
            used in the objective
        """
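        # Load the config saved by write_configs for this trial; job_output_dir
        # was stored as a str, so wrap it in Path before joining
        trial_config = load_yaml(self.paths_by_trial[trial.index]["trial_config_path"], normalize=False)
        job_output_dir = Path(trial_config["model_options"]["job_output_dir"])
        data_filepath = job_output_dir / "output.json"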

        with open(data_filepath, 'r') as f:
            data = json.load(f)
        output = np.array(data["output"])
        return dict(a=output)

link to source: madeline-scyphers/palm_wrapper