"""
########################
Saving and Loading
########################
Functions for saving and loading your experiment to
stop and restart.
"""
import json
from typing import Any, Callable, Dict, Optional, Type
from ax.service.scheduler import Scheduler, SchedulerOptions
from ax.storage.json_store.decoder import (
generation_strategy_from_json,
object_from_json,
)
from ax.storage.json_store.encoder import object_to_json
from ax.storage.json_store.registry import (
CORE_CLASS_DECODER_REGISTRY,
CORE_CLASS_ENCODER_REGISTRY,
CORE_DECODER_REGISTRY,
CORE_ENCODER_REGISTRY,
)
from boa.definitions import PathLike
from boa.logger import get_logger
from boa.metrics.modular_metric import ModularMetric
from boa.runner import WrappedJobRunner
from boa.wrappers.wrapper_utils import initialize_wrapper
logger = get_logger(__name__)
[docs]def scheduler_to_json_file(scheduler, filepath: PathLike = "scheduler_snapshot.json") -> None:
"""Save a JSON-serialized snapshot of this `Scheduler`'s settings and state
to a .json file by the given path.
"""
with open(filepath, "w+") as file: # pragma: no cover
file.write(json.dumps(scheduler_to_json_snapshot(scheduler)))
logger.info(f"Saved JSON-serialized state of optimization to `{filepath}`.")
[docs]def scheduler_from_json_file(filepath: PathLike = "scheduler.json", wrapper=None, **kwargs) -> Scheduler:
"""Restore an `Scheduler` and its state from a JSON-serialized snapshot,
residing in a .json file by the given path.
"""
with open(filepath, "r") as file: # pragma: no cover
serialized = json.loads(file.read())
scheduler = scheduler_from_json_snapshot(serialized=serialized, **kwargs)
wrapper_dict = serialized.pop("wrapper", {})
wrapper_dict = object_from_json(wrapper_dict)
if not wrapper and "path" in wrapper_dict:
wrapper = initialize_wrapper(wrapper=wrapper_dict["path"], wrapper_name=wrapper_dict["name"], **wrapper_dict)
wrapper.config = wrapper_dict.get("config", {})
wrapper.experiment_dir = wrapper_dict.get("experiment_dir")
wrapper.working_dir = wrapper_dict.get("working_dir")
wrapper.metric_names = wrapper_dict.get("metric_names")
for trial in scheduler.running_trials:
wrapper.set_trial_status(trial) # try and complete or fail and leftover trials
for trial in scheduler.running_trials: # any trial that was marked above is no longer here
trial.mark_failed() # fail anything leftover from above
if wrapper is not None:
if isinstance(scheduler.experiment.runner, WrappedJobRunner):
scheduler.experiment.runner.wrapper = wrapper
for metric in scheduler.experiment.metrics.values():
if isinstance(metric, ModularMetric):
metric.wrapper = wrapper
return scheduler
[docs]def scheduler_to_json_snapshot(
scheduler: Scheduler,
encoder_registry: Optional[Dict[Type, Callable[[Any], Dict[str, Any]]]] = None,
class_encoder_registry: Optional[Dict[Type, Callable[[Any], Dict[str, Any]]]] = None,
) -> Dict[str, Any]:
"""Serialize this `Scheduler` to JSON to be able to interrupt and restart
optimization and save it to file by the provided path.
Returns:
A JSON-safe dict representation of this `Scheduler`.
"""
if encoder_registry is None:
encoder_registry = CORE_ENCODER_REGISTRY
if class_encoder_registry is None:
class_encoder_registry = CORE_CLASS_ENCODER_REGISTRY
return {
"_type": scheduler.__class__.__name__,
"experiment": object_to_json(
scheduler.experiment,
encoder_registry=encoder_registry,
class_encoder_registry=class_encoder_registry,
),
"generation_strategy": object_to_json(
scheduler.generation_strategy,
encoder_registry=encoder_registry,
class_encoder_registry=class_encoder_registry,
),
"options": object_to_json(
scheduler.options,
encoder_registry=encoder_registry,
class_encoder_registry=class_encoder_registry,
),
"wrapper": scheduler.experiment.runner.wrapper.to_dict(),
}
[docs]def scheduler_from_json_snapshot(
serialized: Dict[str, Any],
decoder_registry: Optional[Dict[str, Type]] = None,
class_decoder_registry: Optional[Dict[str, Callable[[Dict[str, Any]], Any]]] = None,
**kwargs,
) -> Scheduler:
"""Recreate an `Scheduler` from a JSON snapshot."""
if decoder_registry is None:
decoder_registry = CORE_DECODER_REGISTRY
if class_decoder_registry is None:
class_decoder_registry = CORE_CLASS_DECODER_REGISTRY
if "options" in serialized:
options = object_from_json(
serialized.pop("options"),
decoder_registry=decoder_registry,
class_decoder_registry=class_decoder_registry,
)
else:
options = SchedulerOptions()
experiment = object_from_json(
serialized.pop("experiment"),
decoder_registry=decoder_registry,
class_decoder_registry=class_decoder_registry,
)
serialized_generation_strategy = serialized.pop("generation_strategy")
generation_strategy = generation_strategy_from_json(
generation_strategy_json=serialized_generation_strategy, experiment=experiment
)
scheduler = Scheduler(generation_strategy=generation_strategy, experiment=experiment, options=options, **kwargs)
scheduler._experiment = experiment
return scheduler