"""
########################
Saving and Loading
########################
Functions for saving your experiment to a JSON file and loading it back,
so that an optimization can be stopped and later restarted.
"""
import json
import pathlib
from copy import deepcopy
from dataclasses import asdict
from typing import Any, Callable, Dict, Optional, Type
from ax.exceptions.storage import JSONDecodeError as AXJSONDecodeError
from ax.exceptions.storage import JSONEncodeError as AXJSONEncodeError
from ax.service.scheduler import SchedulerOptions
from ax.storage.json_store.decoder import (
generation_strategy_from_json,
object_from_json,
)
from ax.storage.json_store.encoder import object_to_json
from ax.storage.json_store.registry import (
CORE_CLASS_DECODER_REGISTRY,
CORE_CLASS_ENCODER_REGISTRY,
CORE_DECODER_REGISTRY,
CORE_ENCODER_REGISTRY,
)
from boa.__version__ import __version__
from boa.definitions import PathLike
from boa.logger import get_logger
from boa.metrics.modular_metric import ModularMetric
from boa.runner import WrappedJobRunner
from boa.scheduler import Scheduler
from boa.utils import _load_attr_from_module, _load_module_from_path
from boa.wrappers.base_wrapper import BaseWrapper
# Module-level logger shared by the save/load functions in this file.
logger = get_logger()
def scheduler_to_json_file(scheduler, filepath: PathLike = "scheduler_snapshot.json") -> None:
    """Save a JSON-serialized snapshot of this `Scheduler`'s settings and state
    to a .json file by the given path.

    The snapshot is built by ``scheduler_to_json_snapshot`` and can be read back
    with ``scheduler_from_json_file``.

    Args:
        scheduler: the scheduler to snapshot.
        filepath: destination .json file; an existing file is overwritten.
    """
    # Serialize *before* opening the file: opening in write mode truncates it,
    # so a serialization error would otherwise destroy the previous snapshot.
    payload = json.dumps(scheduler_to_json_snapshot(scheduler))
    with open(filepath, "w") as file:  # pragma: no cover
        file.write(payload)
    logger.info(f"Saved JSON-serialized state of optimization to `{filepath}`." f"\nBoa version: {__version__}")
def scheduler_from_json_file(filepath: PathLike = "scheduler.json", wrapper=None, **kwargs) -> Scheduler:
    """Restore a `Scheduler` and its state from a JSON-serialized snapshot,
    residing in a .json file by the given path.

    After restoring, trials that were still running in the snapshot are
    reconciled. The restored wrapper is asked to set each one's status, and any
    trial still running afterwards is marked as failed.

    Args:
        filepath: path of the .json snapshot. Note that the default
            (``"scheduler.json"``) differs from the default file name used by
            ``scheduler_to_json_file`` (``"scheduler_snapshot.json"``).
        wrapper: currently unused; the wrapper restored from the snapshot (the
            runner's ``wrapper`` attribute) is used instead.
            NOTE(review): presumably accepted so it is not forwarded through
            ``**kwargs``; confirm whether it should override the restored wrapper.
        **kwargs: forwarded to ``scheduler_from_json_snapshot`` (e.g.
            ``wrapper_path``, registries, or extra ``Scheduler`` arguments).

    Returns:
        The restored scheduler.
    """
    with open(filepath, "r") as file:  # pragma: no cover
        serialized = json.load(file)
    scheduler = scheduler_from_json_snapshot(serialized=serialized, **kwargs)

    # Runners that are not wrapper-based have no ``wrapper`` attribute; treat
    # them like a missing wrapper instead of raising AttributeError.
    restored_wrapper = getattr(scheduler.experiment.runner, "wrapper", None)
    if restored_wrapper is not None:
        for trial in scheduler.running_trials:
            restored_wrapper.set_trial_status(trial)  # try and complete or fail any leftover trials
        # Trials whose status was updated above are no longer "running";
        # whatever is still running could not be resolved, so fail it.
        for trial in scheduler.running_trials:
            trial.mark_failed()
    return scheduler
def scheduler_to_json_snapshot(
    scheduler: Scheduler,
    encoder_registry: Optional[Dict[Type, Callable[[Any], Dict[str, Any]]]] = None,
    class_encoder_registry: Optional[Dict[Type, Callable[[Any], Dict[str, Any]]]] = None,
) -> Dict[str, Any]:
    """Serialize this `Scheduler` to JSON to be able to interrupt and restart
    optimization and save it to file by the provided path.

    The snapshot holds:
    - the scheduler class name,
    - the experiment,
    - the generation strategy,
    - the scheduler options (without the global stopping strategy),
    - the boa version, and
    - under ``"wrapper"``, the runner's wrapper, when its serialization is truthy.

    Args:
        scheduler: the scheduler to serialize.
        encoder_registry: type-to-encoder map passed to ``object_to_json``;
            defaults to ``CORE_ENCODER_REGISTRY``.
        class_encoder_registry: class-encoder map passed to ``object_to_json``;
            defaults to ``CORE_CLASS_ENCODER_REGISTRY``.

    Returns:
        A JSON-safe dict representation of this `Scheduler`.
    """
    if encoder_registry is None:
        encoder_registry = CORE_ENCODER_REGISTRY
    if class_encoder_registry is None:
        class_encoder_registry = CORE_CLASS_ENCODER_REGISTRY

    def encode(obj: Any) -> Any:
        # Every object in the snapshot is encoded with the same registries.
        return object_to_json(
            obj,
            encoder_registry=encoder_registry,
            class_encoder_registry=class_encoder_registry,
        )

    # Rebuild the options without the global stopping strategy.
    # NOTE(review): presumably because it cannot be JSON-encoded; confirm.
    options = asdict(scheduler.options)
    options.pop("global_stopping_strategy", None)
    options = SchedulerOptions(**options)

    # Runners without a ``wrapper`` attribute are treated as having no wrapper.
    wrapper = getattr(scheduler.experiment.runner, "wrapper", None)
    try:
        # No trailing comma here: ``(encode(wrapper),)`` would build a 1-tuple,
        # which JSON stores as a list and which then decodes to a list (even
        # ``(None,)`` is truthy) instead of the wrapper itself.
        wrapper_serialization = encode(wrapper)
    except AXJSONEncodeError as e:
        # Presumably the wrapper type has no registered Ax encoder; fall back to
        # the wrapper's own dict form.
        logger.error(e)
        wrapper_serialization = wrapper.to_dict()

    serialization = {
        "_type": scheduler.__class__.__name__,
        "experiment": encode(scheduler.experiment),
        "generation_strategy": encode(scheduler.generation_strategy),
        "options": encode(options),
        "boa_version": __version__,
    }
    if wrapper_serialization:
        serialization["wrapper"] = wrapper_serialization
    return serialization
def _wrapper_from_json(
    wrapper_json: Any,
    decoder_registry: Dict[str, Type],
    class_decoder_registry: Dict[str, Callable[[Dict[str, Any]], Any]],
    wrapper_path=None,
) -> Any:
    """Decode the ``"wrapper"`` entry of a scheduler snapshot.

    Decoding proceeds in order:
    1. Try Ax's ``object_from_json`` first.
    2. If that raises, decode nested values with ``recursive_deserialize``.
    3. Import the class named by the entry's ``"name"`` key from ``wrapper_path``
       if that exists, otherwise from the entry's recorded ``"path"`` if that
       exists, and build it with ``from_dict``.
    4. If neither path exists, log the failure and return a plain ``BaseWrapper``.
    """
    # A wrapper serialized inside a one-element tuple comes back from JSON as a
    # one-element list; unwrap it so the wrapper itself is decoded, not a list.
    if isinstance(wrapper_json, list) and len(wrapper_json) == 1:
        wrapper_json = wrapper_json[0]
    try:
        # Decode a copy. Presumably this ensures a failed attempt cannot leave a
        # half-modified ``wrapper_json`` behind for the fallback below.
        return object_from_json(
            deepcopy(wrapper_json),
            decoder_registry=decoder_registry,
            class_decoder_registry=class_decoder_registry,
        )
    except Exception:
        deserialized = recursive_deserialize(
            wrapper_json,
            decoder_registry=decoder_registry,
            class_decoder_registry=class_decoder_registry,
        )
        # Prefer an explicitly supplied ``wrapper_path``, but only if it exists;
        # otherwise fall back to the path recorded in the snapshot.
        candidates = []
        if wrapper_path is not None:
            candidates.append(pathlib.Path(wrapper_path))
        if "path" in wrapper_json:
            candidates.append(pathlib.Path(deserialized["path"]))
        path = next((candidate for candidate in candidates if candidate.exists()), None)
        if path is None:
            logger.exception("Failed to deserialize wrapper.\n\n\n\n\n")
            return BaseWrapper()
        module = _load_module_from_path(path, "user_wrapper")
        WrapperCls: Type[BaseWrapper] = _load_attr_from_module(module, wrapper_json["name"])
        return WrapperCls.from_dict(**wrapper_json)


def scheduler_from_json_snapshot(
    serialized: Dict[str, Any],
    decoder_registry: Optional[Dict[str, Type]] = None,
    class_decoder_registry: Optional[Dict[str, Callable[[Dict[str, Any]], Any]]] = None,
    wrapper_path=None,
    **kwargs,
) -> Scheduler:
    """Recreate a `Scheduler` from a JSON snapshot.

    Note that ``serialized`` is modified: the ``"options"``, ``"experiment"``,
    ``"wrapper"`` and ``"generation_strategy"`` entries are popped from it.

    Args:
        serialized: snapshot dict, e.g. as produced by ``scheduler_to_json_snapshot``.
        decoder_registry: decoder map passed to ``object_from_json``; defaults
            to ``CORE_DECODER_REGISTRY``.
        class_decoder_registry: class-decoder map passed to ``object_from_json``;
            defaults to ``CORE_CLASS_DECODER_REGISTRY``.
        wrapper_path: optional path of the module defining the wrapper class.
            It is used, when it exists, if the wrapper cannot be decoded directly.
        **kwargs: extra keyword arguments forwarded to the ``Scheduler`` constructor.

    Returns:
        The restored scheduler. If a wrapper was restored, it is attached to a
        ``WrappedJobRunner`` runner and to every ``ModularMetric``.
    """
    if decoder_registry is None:
        decoder_registry = CORE_DECODER_REGISTRY
    if class_decoder_registry is None:
        class_decoder_registry = CORE_CLASS_DECODER_REGISTRY

    if "options" in serialized:
        options = object_from_json(
            serialized.pop("options"),
            decoder_registry=decoder_registry,
            class_decoder_registry=class_decoder_registry,
        )
    else:
        options = SchedulerOptions()
    experiment = object_from_json(
        serialized.pop("experiment"),
        decoder_registry=decoder_registry,
        class_decoder_registry=class_decoder_registry,
    )

    wrapper = None
    if "wrapper" in serialized:
        wrapper = _wrapper_from_json(
            serialized.pop("wrapper"),
            decoder_registry,
            class_decoder_registry,
            wrapper_path,
        )

    generation_strategy = generation_strategy_from_json(
        generation_strategy_json=serialized.pop("generation_strategy"), experiment=experiment
    )
    scheduler = Scheduler(generation_strategy=generation_strategy, experiment=experiment, options=options, **kwargs)
    scheduler._experiment = experiment

    # Re-attach the restored wrapper to every component that uses it.
    if wrapper:
        if isinstance(scheduler.experiment.runner, WrappedJobRunner):
            scheduler.experiment.runner.wrapper = wrapper
        for metric in scheduler.experiment.metrics.values():
            if isinstance(metric, ModularMetric):
                metric.wrapper = wrapper
    return scheduler
def recursive_deserialize(obj, **kwargs):
    """Best-effort decode of a possibly partially Ax-JSON-serialized structure.

    Dicts carrying an Ax ``"__type"`` tag are passed to ``object_from_json``,
    and the decoded object is returned on success. Dicts without the tag, or
    whose decoding raises ``AXJSONDecodeError``, instead have each of their
    values decoded recursively. All non-dict values, including lists, are
    returned unchanged.

    Note:
        Dicts that are not decoded as a whole are updated in place, so the
        returned dict is the same object as ``obj``.

    Args:
        obj: value to decode.
        **kwargs: forwarded to ``object_from_json`` (e.g. ``decoder_registry``,
            ``class_decoder_registry``).

    Returns:
        The decoded object, ``obj`` with its values decoded, or ``obj`` unchanged.
    """
    if not isinstance(obj, dict):
        return obj
    if "__type" in obj:
        try:
            return object_from_json(obj, **kwargs)
        except AXJSONDecodeError:
            # Not decodable as a whole; fall through and decode whatever
            # nested values can be decoded.
            pass
    for key, value in obj.items():
        obj[key] = recursive_deserialize(value, **kwargs)
    return obj