Creating a model wrapper
To create a model wrapper, create a subclass of boa’s BaseWrapper class. BaseWrapper defines the core functions that your model wrapper must implement:
- load_config(): Defines how to load the configuration file for the experiment. A default version of this function is provided and can be used as-is.
- write_configs(): Defines how to write the configuration for each model run. This is needed so that boa can pass the parameters it generates in each trial to the model.
- run_model(): Defines how to start a model run.
- set_trial_status(): Defines how to determine the status of a trial (i.e., whether the model run has completed, is still running, has failed, etc.).
- fetch_trial_data(): Retrieves the trial data and prepares it for the metric(s) used in the objective function.
Apart from these core functions, your model wrapper can define additional functions as needed (for example, to help format or scale model outputs).
See FETCH3’s wrapper for an example.
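To see how the five core functions fit together, here is a minimal, self-contained sketch for a toy model. It imports neither boa nor Ax: `Arm`, `ToyTrial`, `ToyWrapper` and the `MODEL_SCRIPT` "model" are hypothetical stand-ins. In a real wrapper the class would subclass BaseWrapper and receive Ax Trial objects, and load_config would read the experiment's configuration file rather than take a dict.

```python
import json
import subprocess
import sys
import tempfile
from dataclasses import dataclass
from pathlib import Path


# Hypothetical stand-ins for the Ax objects a real wrapper receives.
@dataclass
class Arm:
    parameters: dict


@dataclass
class ToyTrial:
    index: int
    arm: Arm
    status: str = "RUNNING"

    def mark_completed(self):
        self.status = "COMPLETED"

    def mark_failed(self):
        self.status = "FAILED"


# Hypothetical toy "model": reads x from the trial config, writes y = x**2,
# then writes a log line announcing completion.
MODEL_SCRIPT = """
import json, pathlib, sys
cfg = json.loads(pathlib.Path(sys.argv[1]).read_text())
out_dir = pathlib.Path(sys.argv[2])
x = cfg["parameters"]["x"]
(out_dir / "output.json").write_text(json.dumps({"y": x ** 2}))
(out_dir / "model.log").write_text("run complete")
"""


class ToyWrapper:  # a real wrapper would subclass boa's BaseWrapper
    def __init__(self, experiment_dir):
        self.experiment_dir = Path(experiment_dir)
        self._processes = []

    def trial_dir(self, trial):
        return self.experiment_dir / f"{trial.index:06d}"

    def load_config(self, config: dict):
        # Hypothetical: takes a dict instead of reading a config file
        self.config = config

    def write_configs(self, trial):
        trial_dir = self.trial_dir(trial)
        trial_dir.mkdir(parents=True, exist_ok=True)
        config_path = trial_dir / "config.json"
        config_path.write_text(json.dumps({"parameters": trial.arm.parameters}))
        return config_path

    def run_model(self, trial):
        config_path = self.write_configs(trial)
        args = [sys.executable, "-c", MODEL_SCRIPT, str(config_path), str(config_path.parent)]
        self._processes.append(subprocess.Popen(args))  # returns immediately

    def set_trial_status(self, trial):
        log_file = self.trial_dir(trial) / "model.log"
        if log_file.exists() and "run complete" in log_file.read_text():
            trial.mark_completed()  # otherwise the trial is left as running

    def fetch_trial_data(self, trial):
        output = json.loads((self.trial_dir(trial) / "output.json").read_text())
        return dict(y_pred=output["y"], y_true=self.config["observed_y"])


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        wrapper = ToyWrapper(tmp)
        wrapper.load_config({"observed_y": 9.0})
        trial = ToyTrial(index=0, arm=Arm(parameters={"x": 3.0}))
        wrapper.run_model(trial)
        for proc in wrapper._processes:
            proc.wait()  # boa would poll set_trial_status instead of waiting
        wrapper.set_trial_status(trial)
        print(trial.status, wrapper.fetch_trial_data(trial))
```

The model writes its log only after its output, so once set_trial_status sees "run complete" the output file is guaranteed to exist when fetch_trial_data reads it.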
Example wrapper functions
The write_configs function
This function is usually used to write out the configuration files for an individual optimization trial run, or to dynamically write a run script that starts the trial run.
FETCH3’s wrapper provides a simple example of this function for the case where the parameters only need to be written to a YAML file:
import yaml


def write_configs(trial_dir, parameters, model_options):
    """
    Write model configuration file for each trial (model run). This is the config file used by FETCH3
    for the model run.

    The config file is written as ``config.yml`` inside the trial directory.

    Parameters
    ----------
    trial_dir : Path
        Trial directory where the config file will be written
    parameters : dict
        Model parameters for the trial, generated by the Ax client
    model_options : dict
        Model options loaded from the experiment config yml file.

    Returns
    -------
    str
        Path for the config file
    """
    # Combine the model options from the loaded experiment config with the
    # parameter values Ax generated for this trial
    config_dict = {"model_options": model_options, "parameters": parameters}
    with open(trial_dir / "config.yml", "w") as f:
        yaml.dump(config_dict, f)
    return f.name
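Not every model reads YAML; if yours expects another format, write_configs is where that conversion happens. The sketch below is a hypothetical variant (write_ini_config is not part of boa or FETCH3) for a model that reads an INI file, using Python's standard configparser module. Setting optionxform = str keeps parameter names case-sensitive, since configparser lowercases them by default, and interpolation=None stops a stray "%" in a value from being treated as interpolation syntax.

```python
import configparser
from pathlib import Path


def write_ini_config(trial_dir: Path, parameters: dict, model_options: dict) -> Path:
    """Hypothetical write_configs variant for a model that reads an INI file."""
    config = configparser.ConfigParser(interpolation=None)
    config.optionxform = str  # keep parameter names case-sensitive
    # configparser stores strings, so every value is converted explicitly
    config["model_options"] = {key: str(value) for key, value in model_options.items()}
    config["parameters"] = {key: str(value) for key, value in parameters.items()}
    config_path = Path(trial_dir) / "config.ini"
    with open(config_path, "w") as f:
        config.write(f)
    return config_path
```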
The palm_wrapper used to wrap PALM provides an example where the parameters are written to a YAML file and a batch job script is also written for each optimization trial run:
def write_configs(self, trial: Trial) -> None:
    """
    This function is usually used to write out the configuration files used
    in an individual optimization trial run, or to dynamically write a run
    script to start an optimization trial run.

    Parameters
    ----------
    trial : BaseTrial
    """
    trial_config = copy.deepcopy(self.config)
    model_options = trial_config["model_options"]  # same dict, shorter name

    job_name = zfilled_trial_index(trial.index)
    job_output_dir = self.experiment_dir / job_name
    job_output_dir.mkdir(parents=True)

    # Estimated wall time for the PALM run and for the data analysis that follows it
    run_time = model_options["output_end_time"] * model_options["palmrun_walltime_scalar"]
    data_analyses_time = (
        model_options["output_end_time"] - model_options["output_start_time"]
    ) * model_options["data_analyse_walltime_scalar"]

    trial_config_path = job_output_dir / "trial_config.yaml"
    job_script_path = (job_output_dir / "slurm_job.sh").resolve()

    # Store the trial's parameters and per-trial settings in its config
    trial_config["parameters"] = trial.arm.parameters
    model_options["config_path"] = trial_config_path
    model_options["job_script_path"] = job_script_path
    model_options["job_name"] = job_name
    model_options["log_file"] = job_output_dir / f"{job_name}_%j.log"
    model_options["job_output_dir"] = job_output_dir
    model_options["run_time"] = run_time
    model_options["data_analyses_time"] = data_analyses_time
    model_options["batch_time"] = run_time + data_analyses_time

    # Record where this trial's files are so the other wrapper methods can find them
    trial.update_run_metadata(
        dict(trial_config_path=trial_config_path, job_script_path=job_script_path)
    )

    # Fill in the job script template; str.format returns a new string,
    # so its result has to be assigned
    with open(JOB_SCRIPT_PATH) as template:
        job_script = template.read().format(**model_options)
    with open(job_script_path, "w") as f:
        f.write(job_script)

    with open(trial_config_path, "w") as f:
        yaml.dump(trial_config, f)
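When filling in a job script template, keep in mind that str.format returns a new string rather than modifying the template in place, so its result must be assigned. The sketch below shows that step in isolation with a hypothetical SLURM template (the PALM wrapper reads its real template from JOB_SCRIPT_PATH). The helper seconds_to_hms is also hypothetical and assumes batch_time is in seconds; sbatch's --time option accepts an HH:MM:SS string.

```python
# Hypothetical SLURM job script template; placeholders match model_options keys.
TEMPLATE = """#!/bin/bash
#SBATCH --job-name={job_name}
#SBATCH --output={log_file}
#SBATCH --time={batch_time_hms}
echo "Running trial {job_name} with config {config_path}"
"""


def seconds_to_hms(seconds: float) -> str:
    """Format a duration in seconds as HH:MM:SS, a format sbatch --time accepts."""
    total = int(round(seconds))
    hours, remainder = divmod(total, 3600)
    minutes, secs = divmod(remainder, 60)
    return f"{hours:02d}:{minutes:02d}:{secs:02d}"


def render_job_script(template: str, model_options: dict) -> str:
    # str.format returns a new string; unused keys in model_options are ignored,
    # while a placeholder with no matching key raises KeyError.
    options = dict(model_options, batch_time_hms=seconds_to_hms(model_options["batch_time"]))
    return template.format(**options)
```

Because format ignores extra keyword arguments, the whole model_options dictionary can be passed without filtering it first, and a KeyError flags a template placeholder that the config does not provide.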
The run_model function
This function starts a model run; in the simplest case, it launches a Python or shell script.
FETCH3’s wrapper provides an example of this function for the case where the model run is started by running a Python script with command-line arguments:
def run_model(self, trial: Trial):
    trial_dir = make_trial_dir(self.experiment_dir, trial.index)
    config_path = write_configs(trial_dir, trial.arm.parameters, self.model_settings)

    # FETCH3's main.py is run from the model directory
    model_dir = self.ex_settings["model_dir"]
    os.chdir(model_dir)

    # Build the argument list directly so paths containing spaces stay intact
    args = [
        "python", "main.py",
        "--config_path", str(config_path),
        "--data_path", str(self.ex_settings["data_path"]),
        "--output_path", str(trial_dir),
    ]
    # Start the run without waiting for it, and keep the process handle
    popen = subprocess.Popen(args, stdout=subprocess.PIPE, universal_newlines=True)
    self._processes.append(popen)
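The FETCH3 example changes the wrapper's own working directory with os.chdir and sends the model's stdout to a pipe that is never read. Two standard-library alternatives are worth knowing: Popen's cwd argument changes the directory of the child process only, and redirecting stdout to a file avoids the child blocking once the OS pipe buffer fills (the Python subprocess documentation warns about this with stdout=PIPE). The helper launch_model below is a hypothetical sketch of both ideas, not part of FETCH3 or boa.

```python
import subprocess
import sys
from pathlib import Path


def launch_model(script_args: list, model_dir: Path, log_path: Path) -> subprocess.Popen:
    """Start a Python model run in model_dir without blocking (hypothetical helper)."""
    with open(log_path, "w") as log:
        # cwd= affects only the child; the optimizer keeps its own working directory.
        # The child inherits the log file handle, so closing ours on return is safe.
        return subprocess.Popen(
            [sys.executable, *script_args],
            cwd=model_dir,
            stdout=log,
            stderr=subprocess.STDOUT,
        )
```

The returned Popen object's poll() method returns None while the run is still in progress and the exit code once it has finished, which set_trial_status can use alongside the log file.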
The palm_wrapper used to wrap PALM takes the batch job script written in write_configs and submits it with sbatch, starting a SLURM job. The job script also uses the trial YAML file written by write_configs.
def run_model(self, trial: Trial):
    """
    Runs a model by deploying a given trial.

    Parameters
    ----------
    trial : BaseTrial
    """
    trial_config = self._load_trial_config(trial)
    job_script_path = trial_config["model_options"]["job_script_path"]
    # Submit the job script written by write_configs to SLURM
    subprocess.Popen(
        ["sbatch", str(job_script_path)], stdout=subprocess.PIPE, universal_newlines=True
    )
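The PALM example submits the job and discards sbatch's output. If you also want the SLURM job ID (for example, to store with trial.update_run_metadata, or to find the log file whose name contains %j), sbatch's --parsable option prints only the job ID, followed by ";cluster_name" when a cluster name is present. The helpers below are a hypothetical sketch: submit_job needs sbatch on the PATH, while the command building and parsing can be checked without SLURM.

```python
import subprocess
from pathlib import Path


def sbatch_command(job_script_path: Path) -> list:
    # --parsable: sbatch prints "jobid" or "jobid;cluster" instead of a sentence
    return ["sbatch", "--parsable", str(job_script_path)]


def parse_job_id(sbatch_stdout: str) -> str:
    return sbatch_stdout.strip().split(";")[0]


def submit_job(job_script_path: Path) -> str:
    """Submit a job script and return its SLURM job ID (requires sbatch on PATH)."""
    result = subprocess.run(
        sbatch_command(job_script_path), capture_output=True, text=True, check=True
    )
    return parse_job_id(result.stdout)
```

subprocess.run waits only for sbatch itself, which returns as soon as the job is queued; the model run then continues under SLURM.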
The set_trial_status function
This function marks the status of a trial to reflect the status of the model run for that trial.
Each trial is polled periodically. On each poll, this function applies the criteria you define to decide whether the model run has completed, has failed, or is still running, and updates the trial status accordingly.
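boa and Ax perform this polling for you. Purely to illustrate the contract, here is a hypothetical, simplified polling loop (not boa's implementation) that uses plain status strings rather than Ax's TrialStatus: set_trial_status is called repeatedly, and it should leave a trial untouched while its model run is still in progress.

```python
import time

TERMINAL_STATUSES = {"COMPLETED", "FAILED", "ABANDONED", "EARLY_STOPPED"}


def poll_until_done(trials, set_trial_status, interval_s=1.0, max_polls=100) -> bool:
    """Poll unfinished trials until all are terminal or max_polls is reached."""
    for _ in range(max_polls):
        pending = [t for t in trials if t.status not in TERMINAL_STATUSES]
        if not pending:
            return True
        for trial in pending:
            set_trial_status(trial)  # may mark the trial; otherwise a no-op
        time.sleep(interval_s)
    # One last check so a trial finished on the final poll is counted
    return all(t.status in TERMINAL_STATUSES for t in trials)
```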
How to determine the trial status depends on the structure of the particular model and its outputs; a common approach is to check the model’s log files.
In the two examples below, from FETCH3 and PALM respectively, the trial status is determined by searching the model’s log file for specific messages:
def set_trial_status(self, trial: Trial) -> None:
    """
    Mark the trial as failed or completed based on the contents of the FETCH3
    log file. If neither message is in the log yet, the run is still in
    progress and the trial is left unchanged.
    """
    log_file = get_trial_dir(self.experiment_dir, trial.index) / "fetch3.log"
    if log_file.exists():
        with open(log_file, "r") as f:
            contents = f.read()
        if "Error completing Run! Reason:" in contents:
            trial.mark_failed()
        elif "run complete" in contents:
            trial.mark_completed()
def set_trial_status(self, trial: Trial) -> None:
    """
    The trial gets polled from time to time to see if it is completed, failed, still running,
    etc. This marks the trial as one of those options based on some criteria of the model.
    If the model is still running, don't do anything with the trial.

    Parameters
    ----------
    trial : BaseTrial

    Examples
    --------
    trial.mark_completed()
    trial.mark_failed()
    trial.mark_abandoned()
    trial.mark_early_stopped()

    See Also
    --------
    ax.core.base_trial.TrialStatus
    """
    trial_config = self._load_trial_config(trial)
    log_file = Path(trial_config["model_options"]["log_file"])
    if log_file.exists():
        with open(log_file, "r") as f:
            contents = f.read()
        # Failure messages are checked first, and elif ensures the trial
        # is marked at most once per poll
        if "palmrun crashed" in contents:
            trial.mark_abandoned()
        elif "error:" in contents:
            trial.mark_failed()
        elif "all OUTPUT-files saved" in contents:
            trial.mark_completed()
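Because several of these messages can appear in the same log, the order of the checks matters. One way to make that precedence explicit, and testable without a real trial, is to keep the markers in an ordered table, as in this hypothetical sketch built from the PALM messages above. The name classify_log is illustrative; apply_outcome calls the matching Ax Trial method (mark_abandoned, mark_failed or mark_completed).

```python
from typing import Optional

# Ordered (marker, outcome) pairs: the first marker found wins, so failure
# messages take precedence over the success message.
PALM_LOG_RULES = [
    ("palmrun crashed", "abandoned"),
    ("error:", "failed"),
    ("all OUTPUT-files saved", "completed"),
]


def classify_log(contents: str, rules=PALM_LOG_RULES) -> Optional[str]:
    """Return the outcome for the log contents, or None if the run is still in progress."""
    for marker, outcome in rules:
        if marker in contents:
            return outcome
    return None


def apply_outcome(trial, outcome: Optional[str]) -> None:
    # Calls trial.mark_abandoned(), trial.mark_failed() or trial.mark_completed()
    if outcome is not None:
        getattr(trial, f"mark_{outcome}")()
```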
The fetch_trial_data function
Retrieves the trial data and prepares it for the metric(s) used in the objective function. The return value must be a dictionary whose keys match the argument names of the metric function used in the objective function (for example, y_pred and y_true in the FETCH3 example).
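Why the keys must match can be seen in a small sketch that assumes the metric is called with the dictionary unpacked as keyword arguments. The unpacking illustrates the requirement and is not boa's internal code; rmse and fake_fetch_trial_data are hypothetical. A key that does not match a parameter name makes the call fail with a TypeError.

```python
import math


def rmse(y_true, y_pred):
    """Root-mean-square error between observations and model predictions."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    return math.sqrt(sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true))


def fake_fetch_trial_data():
    # Hypothetical stand-in for loading model output and matching observations
    return dict(y_pred=[1.0, 2.0, 4.0], y_true=[1.0, 2.0, 2.0])


# The keys y_pred and y_true line up with rmse's parameter names
score = rmse(**fake_fetch_trial_data())
```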
def fetch_trial_data(self, trial: Trial, *args, **kwargs):
    modelfile = (
        get_trial_dir(self.experiment_dir, trial.index) / self.ex_settings["output_fname"]
    )
    # Pair the model output with the observations for the metric
    y_pred, y_true = get_model_obs(
        modelfile,
        self.ex_settings["obsfile"],
        self.ex_settings,
        self.model_settings,
        trial.arm.parameters,
    )
    return dict(y_pred=y_pred, y_true=y_true)
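In the FETCH3 example, the get_model_obs helper pairs the model output with the observations; its internals are not shown here. One common way to do that pairing is to match the two series on shared timestamps and drop missing observations. The function below is a hypothetical, standard-library sketch of such a step, not FETCH3's implementation; align_model_obs and its dict-of-timestamps inputs are assumptions.

```python
import math


def align_model_obs(model: dict, obs: dict):
    """Pair model predictions with observations at timestamps present in both.

    Missing observations (None or NaN) are dropped, so the two returned
    lists always have the same length.
    """
    y_pred, y_true = [], []
    for timestamp in sorted(model.keys() & obs.keys()):
        observed = obs[timestamp]
        if observed is None or (isinstance(observed, float) and math.isnan(observed)):
            continue
        y_pred.append(model[timestamp])
        y_true.append(observed)
    return y_pred, y_true
```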
def fetch_trial_data(self, trial: Trial, metric_properties: dict, metric_name: str, *args, **kwargs):
    """
    Retrieves the trial data and prepares it for the metric(s) used in the objective
    function.

    For example, for a case where you are minimizing the error between a model and observations, using RMSE as a
    metric, this function would load the model output and the corresponding observation data that will be passed to
    the RMSE metric.

    The return value of this function is a dictionary, with keys that match the keys
    of the metric used in the objective function.

    Parameters
    ----------
    trial : Trial
    metric_properties : dict
    metric_name : str

    Returns
    -------
    dict
        A dictionary with the keys matching the keys of the metric function
        used in the objective
    """
    # run_metadata only stores the config path, so load the trial config itself
    trial_config = self._load_trial_config(trial)
    job_output_dir = Path(trial_config["model_options"]["job_output_dir"])
    data_filepath = job_output_dir / "r_ca.json"
    with open(data_filepath, "r") as f:
        data = json.load(f)
    r_ca = np.array(data["1"])
    return dict(a=r_ca)
Full Examples
import os
import subprocess

from ax.core.trial import Trial

# BaseWrapper, get_trial_dir, make_trial_dir, write_configs and get_model_obs
# are imported or defined in the linked source.


class Fetch3Wrapper(BaseWrapper):
    _processes = []

    def __init__(self, ex_settings, model_settings, experiment_dir):
        self.ex_settings = ex_settings
        self.model_settings = model_settings
        self.experiment_dir = experiment_dir

    def run_model(self, trial: Trial):
        trial_dir = make_trial_dir(self.experiment_dir, trial.index)
        config_path = write_configs(trial_dir, trial.arm.parameters, self.model_settings)

        # FETCH3's main.py is run from the model directory
        model_dir = self.ex_settings["model_dir"]
        os.chdir(model_dir)

        # Build the argument list directly so paths containing spaces stay intact
        args = [
            "python", "main.py",
            "--config_path", str(config_path),
            "--data_path", str(self.ex_settings["data_path"]),
            "--output_path", str(trial_dir),
        ]
        # Start the run without waiting for it, and keep the process handle
        popen = subprocess.Popen(args, stdout=subprocess.PIPE, universal_newlines=True)
        self._processes.append(popen)

    def set_trial_status(self, trial: Trial) -> None:
        """
        Mark the trial as failed or completed based on the contents of the FETCH3
        log file. If neither message is in the log yet, the run is still in
        progress and the trial is left unchanged.
        """
        log_file = get_trial_dir(self.experiment_dir, trial.index) / "fetch3.log"
        if log_file.exists():
            with open(log_file, "r") as f:
                contents = f.read()
            if "Error completing Run! Reason:" in contents:
                trial.mark_failed()
            elif "run complete" in contents:
                trial.mark_completed()

    def fetch_trial_data(self, trial: Trial, *args, **kwargs):
        modelfile = (
            get_trial_dir(self.experiment_dir, trial.index) / self.ex_settings["output_fname"]
        )
        # Pair the model output with the observations for the metric
        y_pred, y_true = get_model_obs(
            modelfile,
            self.ex_settings["obsfile"],
            self.ex_settings,
            self.model_settings,
            trial.arm.parameters,
        )
        return dict(y_pred=y_pred, y_true=y_true)


Link to source: https://github.com/jemissik/fetch3_nhl/blob/develop/fetch3/optimize/fetch_wrapper.py
import copy
import json
import os
import subprocess
from pathlib import Path

import numpy as np
import yaml
from ax.core.trial import Trial

# BaseWrapper, load_yaml, get_dt_now_as_str, zfilled_trial_index and JOB_SCRIPT_PATH
# are imported or defined in the linked source.


class Wrapper(BaseWrapper):
    def __init__(self):
        self.config = None
        self.model_settings = None
        self.ex_settings = None
        self.experiment_dir = None

    def load_config(self, config_file: os.PathLike):
        """
        Load the experiment config file and store it on the wrapper, along with the
        model settings, the optimization settings and the experiment output directory.

        Parameters
        ----------
        config_file : os.PathLike
            File path for the experiment configuration file
        """
        config = load_yaml(config_file)
        # Name the experiment after the current date and time
        experiment_name = get_dt_now_as_str()
        config["optimization_options"]["experiment"]["name"] = experiment_name
        self.config = config
        self.model_settings = self.config["model_options"]
        self.ex_settings = self.config["optimization_options"]
        self.experiment_dir = (
            Path(self.model_settings["optimization_output_dir"]).expanduser() / experiment_name
        )

    def write_configs(self, trial: Trial) -> None:
        """
        This function is usually used to write out the configuration files used
        in an individual optimization trial run, or to dynamically write a run
        script to start an optimization trial run.

        Parameters
        ----------
        trial : BaseTrial
        """
        trial_config = copy.deepcopy(self.config)
        model_options = trial_config["model_options"]  # same dict, shorter name

        job_name = zfilled_trial_index(trial.index)
        job_output_dir = self.experiment_dir / job_name
        job_output_dir.mkdir(parents=True)

        # Estimated wall time for the PALM run and for the data analysis that follows it
        run_time = model_options["output_end_time"] * model_options["palmrun_walltime_scalar"]
        data_analyses_time = (
            model_options["output_end_time"] - model_options["output_start_time"]
        ) * model_options["data_analyse_walltime_scalar"]

        trial_config_path = job_output_dir / "trial_config.yaml"
        job_script_path = (job_output_dir / "slurm_job.sh").resolve()

        # Store the trial's parameters and per-trial settings in its config
        trial_config["parameters"] = trial.arm.parameters
        model_options["config_path"] = trial_config_path
        model_options["job_script_path"] = job_script_path
        model_options["job_name"] = job_name
        model_options["log_file"] = job_output_dir / f"{job_name}_%j.log"
        model_options["job_output_dir"] = job_output_dir
        model_options["run_time"] = run_time
        model_options["data_analyses_time"] = data_analyses_time
        model_options["batch_time"] = run_time + data_analyses_time

        # Record where this trial's files are so the other wrapper methods can find them
        trial.update_run_metadata(
            dict(trial_config_path=trial_config_path, job_script_path=job_script_path)
        )

        # Fill in the job script template; str.format returns a new string,
        # so its result has to be assigned
        with open(JOB_SCRIPT_PATH) as template:
            job_script = template.read().format(**model_options)
        with open(job_script_path, "w") as f:
            f.write(job_script)

        with open(trial_config_path, "w") as f:
            yaml.dump(trial_config, f)

    def run_model(self, trial: Trial):
        """
        Runs a model by deploying a given trial.

        Parameters
        ----------
        trial : BaseTrial
        """
        trial_config = self._load_trial_config(trial)
        job_script_path = trial_config["model_options"]["job_script_path"]
        # Submit the job script written by write_configs to SLURM
        subprocess.Popen(
            ["sbatch", str(job_script_path)], stdout=subprocess.PIPE, universal_newlines=True
        )

    def set_trial_status(self, trial: Trial) -> None:
        """
        The trial gets polled from time to time to see if it is completed, failed, still running,
        etc. This marks the trial as one of those options based on some criteria of the model.
        If the model is still running, don't do anything with the trial.

        Parameters
        ----------
        trial : BaseTrial

        Examples
        --------
        trial.mark_completed()
        trial.mark_failed()
        trial.mark_abandoned()
        trial.mark_early_stopped()

        See Also
        --------
        ax.core.base_trial.TrialStatus
        """
        trial_config = self._load_trial_config(trial)
        log_file = Path(trial_config["model_options"]["log_file"])
        if log_file.exists():
            with open(log_file, "r") as f:
                contents = f.read()
            # Failure messages are checked first, and elif ensures the trial
            # is marked at most once per poll
            if "palmrun crashed" in contents:
                trial.mark_abandoned()
            elif "error:" in contents:
                trial.mark_failed()
            elif "all OUTPUT-files saved" in contents:
                trial.mark_completed()

    def fetch_trial_data(self, trial: Trial, metric_properties: dict, metric_name: str, *args, **kwargs):
        """
        Retrieves the trial data and prepares it for the metric(s) used in the objective
        function.

        For example, for a case where you are minimizing the error between a model and observations, using RMSE as a
        metric, this function would load the model output and the corresponding observation data that will be passed to
        the RMSE metric.

        The return value of this function is a dictionary, with keys that match the keys
        of the metric used in the objective function.

        Parameters
        ----------
        trial : Trial
        metric_properties : dict
        metric_name : str

        Returns
        -------
        dict
            A dictionary with the keys matching the keys of the metric function
            used in the objective
        """
        # run_metadata only stores the config path, so load the trial config itself
        trial_config = self._load_trial_config(trial)
        job_output_dir = Path(trial_config["model_options"]["job_output_dir"])
        data_filepath = job_output_dir / "r_ca.json"
        with open(data_filepath, "r") as f:
            data = json.load(f)
        r_ca = np.array(data["1"])
        return dict(a=r_ca)

    @staticmethod
    def _load_trial_config(trial: Trial):
        # Load the per-trial config that write_configs saved
        trial_config_path = trial.run_metadata["trial_config_path"]
        trial_config = load_yaml(trial_config_path, normalize=False)
        return trial_config


Link to source: https://github.com/madeline-scyphers/palm_wrapper/blob/main/palm_wrapper/optimize/wrapper.py