Creating model wrappers#
To create a model wrapper, subclass boa’s BaseWrapper class. BaseWrapper
declares the core functions that your model wrapper must provide:
- BaseWrapper.load_config(): Defines how to load the configuration file for the experiment. There is a default version of this function that can be used.
- BaseWrapper.write_configs(): Defines how to write configurations for the model. This is needed so that BOA can pass the parameters it generates in each trial to the model.
- BaseWrapper.run_model(): Defines how to run the model.
- BaseWrapper.set_trial_status(): Defines how to determine the status of a trial (i.e., if the model run is completed, still running, failed, etc.).
- BaseWrapper.fetch_trial_data(): Retrieves the trial data (i.e., model outputs) and prepares it for the metric(s) used in the objective function.
Apart from these core functions, your model wrapper can have additional functions as needed (for example, to help with formatting or scaling model outputs).
See FETCH3’s Wrapper for an example.
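The way the core functions listed above fit together can be pictured with a toy skeleton. This is only a sketch under stated assumptions: `BaseWrapper` here is a local stand-in written for the example (the real class comes from boa and its signatures may differ), and `FakeTrial`, `MyModelWrapper`, and the file names are hypothetical. `load_config()` is left out because a default version exists.

```python
from pathlib import Path
from tempfile import TemporaryDirectory


class BaseWrapper:
    """Local stand-in for boa's BaseWrapper (assumption: the real class does more)."""

    def __init__(self, experiment_dir):
        self.experiment_dir = Path(experiment_dir)


class FakeTrial:
    """Hypothetical minimal trial object with only the attributes used below."""

    def __init__(self, index, parameters):
        self.index = index
        self.parameters = parameters
        self.status = "running"

    def mark_completed(self):
        self.status = "completed"


class MyModelWrapper(BaseWrapper):
    def _trial_dir(self, trial):
        return self.experiment_dir / f"trial_{trial.index}"

    def write_configs(self, trial):
        # Write the trial's parameters where the model can find them.
        trial_dir = self._trial_dir(trial)
        trial_dir.mkdir(parents=True, exist_ok=True)
        lines = [f"{k}={v}" for k, v in trial.parameters.items()]
        (trial_dir / "params.txt").write_text("\n".join(lines))
        return trial_dir

    def run_model(self, trial):
        # A real wrapper would launch the model here; this toy "model"
        # just sums the parameters and writes the result to a file.
        trial_dir = self.write_configs(trial)
        total = sum(trial.parameters.values())
        (trial_dir / "output.txt").write_text(str(total))

    def set_trial_status(self, trial):
        # The run counts as completed once the output file exists.
        if (self._trial_dir(trial) / "output.txt").exists():
            trial.mark_completed()

    def fetch_trial_data(self, trial):
        # Return a dict whose keys match what the metric expects.
        output = (self._trial_dir(trial) / "output.txt").read_text()
        return {"y_pred": float(output)}


with TemporaryDirectory() as tmp:
    wrapper = MyModelWrapper(tmp)
    trial = FakeTrial(index=0, parameters={"a": 1.5, "b": 2.5})
    wrapper.run_model(trial)
    wrapper.set_trial_status(trial)
    result = wrapper.fetch_trial_data(trial)
    final_status = trial.status
```

In a real experiment, boa drives these calls and the trial objects come from Ax; the toy driver at the bottom only shows the order in which the pieces are typically exercised.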
Example wrapper functions#
The BaseWrapper.write_configs() function#
This function writes the configuration files used in an individual optimization trial run (i.e., your model’s configuration files), or dynamically writes a run script that starts the trial run.
It is how boa passes a new set of parameters to your model for each trial.
FETCH3’s wrapper provides a simple example of this function for a case where the model’s parameters simply need to be written to a YAML file:
def write_configs(trial_dir, parameters, model_options):
    """
    Write model configuration file for each trial (model run). This is the config file used by FETCH3
    for the model run.

    The config file is written as ``config.yml`` inside the trial directory.

    Parameters
    ----------
    trial_dir : Path
        Trial directory where the config file will be written
    parameters : dict
        Model parameters for the trial, generated by the ax client
    model_options : dict
        Model options loaded from the experiment config yml file.

    Returns
    -------
    str
        Path for the config file
    """
    with open(trial_dir / "config.yml", "w") as f:
        # Write model options from loaded config
        # Parameters for the trial from Ax
        config_dict = {"model_options": model_options, "parameters": parameters}
        yaml.dump(config_dict, f)
    return f.name
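To see what ends up on disk, here is a minimal sketch of the same idea. It uses JSON instead of YAML so that it only needs the standard library; `write_trial_config` and the parameter and option names are made up for illustration and are not part of FETCH3 or boa.

```python
import json
from pathlib import Path
from tempfile import TemporaryDirectory


def write_trial_config(trial_dir, parameters, model_options):
    """Hypothetical JSON counterpart of write_configs; returns the file path."""
    config_path = Path(trial_dir) / "config.json"
    config = {"model_options": model_options, "parameters": parameters}
    with open(config_path, "w") as f:
        json.dump(config, f, indent=2)
    return str(config_path)


with TemporaryDirectory() as tmp:
    path = write_trial_config(
        tmp,
        parameters={"leaf_area": 2.0},      # made-up parameter name
        model_options={"time_step": 1800},  # made-up option name
    )
    # Read the file back, the way the model would at the start of a run.
    with open(path) as f:
        loaded = json.load(f)
```

The model then only has to read one file per trial to obtain both its fixed options and the parameters proposed for that trial.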
The palm_wrapper used to wrap PALM provides an example of a model with more complicated configuration requirements. Here, the parameters are written to a YAML file, but then a batch job script must also be written for each optimization trial run.
def write_configs(self, trial: Trial) -> None:
    """
    This function is usually used to write out the configuration files used
    in an individual optimization trial run, or to dynamically write a run
    script to start an optimization trial run.

    Parameters
    ----------
    trial : BaseTrial
    """
    trial_config = copy.deepcopy(self.config)
    job_name = zfilled_trial_index(trial.index)
    job_output_dir = self.experiment_dir / job_name
    job_output_dir.mkdir(parents=True)

    # Estimate wall times for the model run and for the data analysis step
    run_time = (
        trial_config["model_options"]["output_end_time"]
        * trial_config["model_options"]["palmrun_walltime_scalar"]
    )
    data_analyses_time = (
        trial_config["model_options"]["output_end_time"]
        - trial_config["model_options"]["output_start_time"]
    ) * trial_config["model_options"]["data_analyse_walltime_scalar"]

    trial_config_path = job_output_dir / "trial_config.yaml"
    job_script_path = (job_output_dir / "slurm_job.sh").resolve()

    trial_config["parameters"] = trial.arm.parameters
    trial_config["model_options"]["config_path"] = str(trial_config_path)
    trial_config["model_options"]["job_script_path"] = str(job_script_path)
    trial_config["model_options"]["job_name"] = job_name
    trial_config["model_options"]["log_file"] = str(job_output_dir / f"{job_name}_%j.log")
    trial_config["model_options"]["job_output_dir"] = str(job_output_dir)
    trial_config["model_options"]["run_time"] = int(run_time)
    trial_config["model_options"]["data_analyses_time"] = data_analyses_time
    trial_config["model_options"]["batch_time"] = int(run_time + data_analyses_time)

    # Remember where this trial's files live so later steps can find them
    self.paths_by_trial[trial.index] = dict(
        trial_config_path=trial_config_path, job_script_path=job_script_path
    )

    # Render the batch job script from the template
    with open(JOB_SCRIPT_PATH) as template:
        job_script = template.read()
    jinja_env = jinja2.Environment(
        loader=jinja2.BaseLoader(),
    )
    template = jinja_env.from_string(job_script)
    job_script = template.render(**trial_config["model_options"])

    with open(job_script_path, "w") as f:
        f.write(job_script)
    with open(trial_config_path, "w") as f:
        yaml.dump(trial_config, f)
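The wall-time arithmetic and the templating step can be tried out without Jinja or a SLURM cluster. The sketch below uses `string.Template` from the standard library as a stand-in for the Jinja template; the template text and all option values are made up, since PALM's real job script template is not shown here.

```python
from string import Template

# Made-up job script template (stand-in for the Jinja template file).
JOB_TEMPLATE = Template(
    "#!/bin/bash\n"
    "#SBATCH --job-name=$job_name\n"
    "#SBATCH --time=$batch_time\n"
    "echo running with $config_path\n"
)

# Made-up option values, using the option names from write_configs.
model_options = {
    "output_start_time": 600,
    "output_end_time": 3600,
    "palmrun_walltime_scalar": 2,
    "data_analyse_walltime_scalar": 0.5,
}

# Same arithmetic as in write_configs.
run_time = model_options["output_end_time"] * model_options["palmrun_walltime_scalar"]
data_analyses_time = (
    model_options["output_end_time"] - model_options["output_start_time"]
) * model_options["data_analyse_walltime_scalar"]

values = {
    "job_name": "000001",  # made-up job name
    "config_path": "000001/trial_config.yaml",
    "batch_time": int(run_time + data_analyses_time),
}

# Fill the placeholders to get the text of the job script for this trial.
job_script = JOB_TEMPLATE.substitute(values)
```

Keeping the job script as a template means that only the per-trial values (job name, paths, wall time) change between trials, while the rest of the script stays identical.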
The BaseWrapper.run_model() function#
This function defines how to start a run of your model. In most cases, it can be as simple as launching a Python or shell script to start a model run.
FETCH3’s wrapper provides an example of this function for the case where the model run is started by running a Python script with command-line arguments:
def run_model(self, trial: Trial):
    trial_dir = make_trial_dir(self.experiment_dir, trial.index)
    config_dir = write_configs(trial_dir, trial.arm.parameters, self.model_settings)

    model_dir = self.ex_settings["model_dir"]

    # with cd_and_cd_back(model_dir):
    os.chdir(model_dir)

    cmd = (
        f"python main.py --config_path {config_dir} --data_path"
        f" {self.ex_settings['data_path']} --output_path {trial_dir}"
    )

    args = cmd.split()
    popen = subprocess.Popen(args, stdout=subprocess.PIPE, universal_newlines=True)
    self._processes.append(popen)
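Note that `os.chdir()` changes the working directory of the whole optimization process. One alternative, sketched below, is to pass `cwd` to `subprocess.Popen` so that only the child process runs in the model directory. The `launch` helper and the toy command are hypothetical; they stand in for the `python main.py ...` call above.

```python
import os
import shlex
import subprocess
import sys
from tempfile import TemporaryDirectory


def launch(cmd, cwd):
    """Start cmd in directory cwd without changing the caller's working directory."""
    args = shlex.split(cmd)  # unlike str.split, this respects quoted arguments
    return subprocess.Popen(
        args, cwd=cwd, stdout=subprocess.PIPE, universal_newlines=True
    )


with TemporaryDirectory() as model_dir:
    # Toy command standing in for "python main.py --config_path ...":
    # the child prints the name of the directory it is running in.
    cmd = f'"{sys.executable}" -c "import os; print(os.path.basename(os.getcwd()))"'
    process = launch(cmd, cwd=model_dir)
    output, _ = process.communicate()  # wait for the child and collect stdout
    expected = os.path.basename(model_dir)
```

A real `run_model()` would usually not wait for the process to finish; waiting here only makes the example's result easy to inspect.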
The palm_wrapper used to wrap PALM takes the batch job script written in write_configs and runs it, starting a job on an HPC.
The job script also uses the YAML file written above.
def run_model(self, trial: Trial):
    """
    Runs a model by deploying a given trial.

    Parameters
    ----------
    trial : BaseTrial
    """
    trial_config = load_yaml(self.paths_by_trial[trial.index]["trial_config_path"], normalize=False)
    job_script_path = trial_config["model_options"]["job_script_path"]

    cmd = f"sbatch {job_script_path}"
    args = cmd.split()
    subprocess.Popen(
        args, stdout=subprocess.PIPE, universal_newlines=True
    )
The BaseWrapper.set_trial_status() function#
Marks the status of a trial to reflect the status of the model run associated with that trial.
This function defines the criteria for determining the status of the model run for a trial (e.g., whether the run has completed, is still running, or has failed). Each trial is polled periodically to determine its status.
The approach for determining the trial status will depend on the structure of the particular model and its outputs. One way to do this is to check the model’s log files.
In these two examples, the trial status is determined by searching the model’s log file for specific messages:
def set_trial_status(self, trial: Trial) -> None:
    """Get status of the job by a given ID. For simplicity of the example,
    return an Ax `TrialStatus`.
    """
    log_file = get_trial_dir(self.experiment_dir, trial.index) / "fetch3.log"

    if log_file.exists():
        with open(log_file, "r") as f:
            contents = f.read()
        if "Error completing Run! Reason:" in contents:
            trial.mark_failed()
        elif "run complete" in contents:
            trial.mark_completed()
def set_trial_status(self, trial: Trial) -> None:
    """
    The trial gets polled from time to time to see if it is completed, failed, still running,
    etc. This marks the trial as one of those options based on some criteria of the model.
    If the model is still running, don't do anything with the trial.

    Parameters
    ----------
    trial : BaseTrial

    Examples
    --------
    trial.mark_completed()
    trial.mark_failed()
    trial.mark_abandoned()
    trial.mark_early_stopped()

    See Also
    --------
    # TODO add sphinx link to ax trial status
    """
    trial_config = load_yaml(self.paths_by_trial[trial.index]["trial_config_path"], normalize=False)
    log_file = Path(trial_config["model_options"]["log_file"])

    if log_file.exists():
        with open(log_file, "r") as f:
            contents = f.read()
        if "palmrun crashed" in contents:
            trial.mark_abandoned()
        elif "error:" in contents:
            trial.mark_failed()
        if "all OUTPUT-files saved" in contents:
            trial.mark_completed()
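The log-checking logic can be isolated in a small helper that is easy to test on its own. The sketch below is not taken from either wrapper: `status_from_log`, the marker strings, and the returned status names are hypothetical, and a real wrapper would call the trial's `mark_...()` methods instead of returning a string.

```python
from pathlib import Path
from tempfile import TemporaryDirectory


def status_from_log(log_file, failure_marker="ERROR", success_marker="run complete"):
    """Return "failed", "completed", or "running" based on the log contents."""
    log_file = Path(log_file)
    if not log_file.exists():
        return "running"  # the model has not written anything yet
    contents = log_file.read_text()
    if failure_marker in contents:  # check failures first so they take priority
        return "failed"
    if success_marker in contents:
        return "completed"
    return "running"  # no marker found yet, so leave the trial alone


with TemporaryDirectory() as tmp:
    log = Path(tmp) / "model.log"
    before = status_from_log(log)  # no log file yet
    log.write_text("step 1 done\n")
    during = status_from_log(log)  # log exists, no marker
    log.write_text("step 1 done\nrun complete\n")
    after = status_from_log(log)   # success marker present
    log.write_text("step 1 done\nERROR: diverged\nrun complete\n")
    crashed = status_from_log(log)  # failure marker wins over success marker
```

Checking the failure marker before the success marker, and returning as soon as one matches, guarantees that each poll assigns at most one status to a trial.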
The BaseWrapper.fetch_trial_data() function#
Retrieves the trial data (i.e., model outputs) and prepares it for the metric(s) used in the objective function. The return value needs to be a dictionary with the keys matching the keys of the metric function used in the objective function.
def fetch_trial_data(self, trial: Trial, *args, **kwargs):
    modelfile = (
        get_trial_dir(self.experiment_dir, trial.index) / self.ex_settings["output_fname"]
    )

    y_pred, y_true = get_model_obs(
        modelfile,
        self.ex_settings["obsfile"],
        self.ex_settings,
        self.model_settings,
        trial.arm.parameters,
    )
    return dict(y_pred=y_pred, y_true=y_true)
def fetch_trial_data(self, trial: Trial, metric_properties: dict, metric_name: str, *args, **kwargs):
    """
    Retrieves the trial data and prepares it for the metric(s) used in the objective
    function.

    For example, for a case where you are minimizing the error between a model and observations, using RMSE as a
    metric, this function would load the model output and the corresponding observation data that will be passed to
    the RMSE metric.

    The return value of this function is a dictionary, with keys that match the keys
    of the metric used in the objective function.
    # TODO work on this description

    Parameters
    ----------
    trial : Trial
    metric_properties: dict
    metric_name: str

    Returns
    -------
    dict
        A dictionary with the keys matching the keys of the metric function
        used in the objective
    """
    # Load the trial config written by write_configs, as run_model and set_trial_status do
    trial_config = load_yaml(self.paths_by_trial[trial.index]["trial_config_path"], normalize=False)
    # job_output_dir is stored as a string in the config, so convert it back to a Path
    job_output_dir = Path(trial_config["model_options"]["job_output_dir"])
    data_filepath = job_output_dir / "output.json"
    with open(data_filepath, "r") as f:
        data = json.load(f)

    output = np.array(data["output"])
    return dict(a=output)
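The requirement that the dictionary keys match the metric's keys can be illustrated with a small sketch. It assumes that the returned dictionary is unpacked into the metric function as keyword arguments; the `rmse` function and the stand-in `fetch_trial_data` below are hypothetical and use hard-coded data instead of reading model output files.

```python
import math


def rmse(y_true, y_pred):
    """Hypothetical metric function; its argument names define the expected keys."""
    squared = [(t - p) ** 2 for t, p in zip(y_true, y_pred)]
    return math.sqrt(sum(squared) / len(squared))


def fetch_trial_data():
    """Stand-in for a wrapper method; a real one would read model output files."""
    return dict(y_true=[1.0, 2.0, 3.0], y_pred=[1.0, 2.0, 5.0])


trial_data = fetch_trial_data()
# Unpacking works only because the dict keys match the metric's argument names.
score = rmse(**trial_data)
```

If `fetch_trial_data()` returned a key such as `prediction` instead of `y_pred`, the call would fail with a `TypeError`, which is why the key names have to be chosen with the metric in mind.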
Full Examples#
import os
import subprocess

# BaseWrapper, Trial, the trial-directory helpers, write_configs, and get_model_obs
# are imported elsewhere (those imports are omitted in this excerpt).


class Fetch3Wrapper(BaseWrapper):
    _processes = []

    def __init__(self, ex_settings, model_settings, experiment_dir):
        self.ex_settings = ex_settings
        self.model_settings = model_settings
        self.experiment_dir = experiment_dir

    def run_model(self, trial: Trial):
        trial_dir = make_trial_dir(self.experiment_dir, trial.index)
        config_dir = write_configs(trial_dir, trial.arm.parameters, self.model_settings)

        model_dir = self.ex_settings["model_dir"]

        # with cd_and_cd_back(model_dir):
        os.chdir(model_dir)

        cmd = (
            f"python main.py --config_path {config_dir} --data_path"
            f" {self.ex_settings['data_path']} --output_path {trial_dir}"
        )

        args = cmd.split()
        popen = subprocess.Popen(args, stdout=subprocess.PIPE, universal_newlines=True)
        self._processes.append(popen)

    def set_trial_status(self, trial: Trial) -> None:
        """Get status of the job by a given ID. For simplicity of the example,
        return an Ax `TrialStatus`.
        """
        log_file = get_trial_dir(self.experiment_dir, trial.index) / "fetch3.log"

        if log_file.exists():
            with open(log_file, "r") as f:
                contents = f.read()
            if "Error completing Run! Reason:" in contents:
                trial.mark_failed()
            elif "run complete" in contents:
                trial.mark_completed()

    def fetch_trial_data(self, trial: Trial, *args, **kwargs):
        modelfile = (
            get_trial_dir(self.experiment_dir, trial.index) / self.ex_settings["output_fname"]
        )

        y_pred, y_true = get_model_obs(
            modelfile,
            self.ex_settings["obsfile"],
            self.ex_settings,
            self.model_settings,
            trial.arm.parameters,
        )
        return dict(y_pred=y_pred, y_true=y_true)
link to source: jemissik/fetch3_nhl
import copy
import json
import subprocess
from pathlib import Path

import jinja2
import numpy as np
import yaml

# BaseWrapper, Trial, load_yaml, zfilled_trial_index, and JOB_SCRIPT_PATH
# are imported or defined elsewhere (omitted in this excerpt).


class Wrapper(BaseWrapper):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.paths_by_trial = {}

    def write_configs(self, trial: Trial) -> None:
        """
        This function is usually used to write out the configuration files used
        in an individual optimization trial run, or to dynamically write a run
        script to start an optimization trial run.

        Parameters
        ----------
        trial : BaseTrial
        """
        trial_config = copy.deepcopy(self.config)
        job_name = zfilled_trial_index(trial.index)
        job_output_dir = self.experiment_dir / job_name
        job_output_dir.mkdir(parents=True)

        # Estimate wall times for the model run and for the data analysis step
        run_time = (
            trial_config["model_options"]["output_end_time"]
            * trial_config["model_options"]["palmrun_walltime_scalar"]
        )
        data_analyses_time = (
            trial_config["model_options"]["output_end_time"]
            - trial_config["model_options"]["output_start_time"]
        ) * trial_config["model_options"]["data_analyse_walltime_scalar"]

        trial_config_path = job_output_dir / "trial_config.yaml"
        job_script_path = (job_output_dir / "slurm_job.sh").resolve()

        trial_config["parameters"] = trial.arm.parameters
        trial_config["model_options"]["config_path"] = str(trial_config_path)
        trial_config["model_options"]["job_script_path"] = str(job_script_path)
        trial_config["model_options"]["job_name"] = job_name
        trial_config["model_options"]["log_file"] = str(job_output_dir / f"{job_name}_%j.log")
        trial_config["model_options"]["job_output_dir"] = str(job_output_dir)
        trial_config["model_options"]["run_time"] = int(run_time)
        trial_config["model_options"]["data_analyses_time"] = data_analyses_time
        trial_config["model_options"]["batch_time"] = int(run_time + data_analyses_time)

        # Remember where this trial's files live so later steps can find them
        self.paths_by_trial[trial.index] = dict(
            trial_config_path=trial_config_path, job_script_path=job_script_path
        )

        # Render the batch job script from the template
        with open(JOB_SCRIPT_PATH) as template:
            job_script = template.read()
        jinja_env = jinja2.Environment(
            loader=jinja2.BaseLoader(),
        )
        template = jinja_env.from_string(job_script)
        job_script = template.render(**trial_config["model_options"])

        with open(job_script_path, "w") as f:
            f.write(job_script)
        with open(trial_config_path, "w") as f:
            yaml.dump(trial_config, f)

    def run_model(self, trial: Trial):
        """
        Runs a model by deploying a given trial.

        Parameters
        ----------
        trial : BaseTrial
        """
        trial_config = load_yaml(self.paths_by_trial[trial.index]["trial_config_path"], normalize=False)
        job_script_path = trial_config["model_options"]["job_script_path"]

        cmd = f"sbatch {job_script_path}"
        args = cmd.split()
        subprocess.Popen(
            args, stdout=subprocess.PIPE, universal_newlines=True
        )

    def set_trial_status(self, trial: Trial) -> None:
        """
        The trial gets polled from time to time to see if it is completed, failed, still running,
        etc. This marks the trial as one of those options based on some criteria of the model.
        If the model is still running, don't do anything with the trial.

        Parameters
        ----------
        trial : BaseTrial

        Examples
        --------
        trial.mark_completed()
        trial.mark_failed()
        trial.mark_abandoned()
        trial.mark_early_stopped()

        See Also
        --------
        # TODO add sphinx link to ax trial status
        """
        trial_config = load_yaml(self.paths_by_trial[trial.index]["trial_config_path"], normalize=False)
        log_file = Path(trial_config["model_options"]["log_file"])

        if log_file.exists():
            with open(log_file, "r") as f:
                contents = f.read()
            if "palmrun crashed" in contents:
                trial.mark_abandoned()
            elif "error:" in contents:
                trial.mark_failed()
            if "all OUTPUT-files saved" in contents:
                trial.mark_completed()

    def fetch_trial_data(self, trial: Trial, metric_properties: dict, metric_name: str, *args, **kwargs):
        """
        Retrieves the trial data and prepares it for the metric(s) used in the objective
        function.

        For example, for a case where you are minimizing the error between a model and observations, using RMSE as a
        metric, this function would load the model output and the corresponding observation data that will be passed to
        the RMSE metric.

        The return value of this function is a dictionary, with keys that match the keys
        of the metric used in the objective function.
        # TODO work on this description

        Parameters
        ----------
        trial : Trial
        metric_properties: dict
        metric_name: str

        Returns
        -------
        dict
            A dictionary with the keys matching the keys of the metric function
            used in the objective
        """
        # Load the trial config written by write_configs, as run_model and set_trial_status do
        trial_config = load_yaml(self.paths_by_trial[trial.index]["trial_config_path"], normalize=False)
        # job_output_dir is stored as a string in the config, so convert it back to a Path
        job_output_dir = Path(trial_config["model_options"]["job_output_dir"])
        data_filepath = job_output_dir / "output.json"
        with open(data_filepath, "r") as f:
            data = json.load(f)

        output = np.array(data["output"])
        return dict(a=output)
link to source: madeline-scyphers/palm_wrapper
- Base Wrapper
  - BaseWrapper
- Wrapper Utility Tools
  - cd_and_cd_back()
  - cd_and_cd_back_dec()
  - initialize_wrapper()
  - split_shell_command()
  - load_json()
  - load_yaml()
  - load_jsonlike()
  - normalize_config()
  - wpr_params_to_boa()
  - boa_params_to_wpr()
  - get_dt_now_as_str()
  - make_experiment_dir()
  - zfilled_trial_index()
  - get_trial_dir()
  - make_trial_dir()
  - save_trial_data()