Source code for boa.storage

import json
import logging
import os
import pickle
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Type

from ax import Experiment
from ax.exceptions.storage import JSONEncodeError
from ax.service.scheduler import Scheduler, SchedulerOptions
from ax.storage.json_store.decoder import (
    generation_strategy_from_json,
    object_from_json,
)
from ax.storage.json_store.encoder import object_to_json
from ax.storage.json_store.load import load_experiment as ax_load_experiment
from ax.storage.json_store.registry import (
    CORE_CLASS_DECODER_REGISTRY,
    CORE_CLASS_ENCODER_REGISTRY,
    CORE_DECODER_REGISTRY,
    CORE_ENCODER_REGISTRY,
)
from ax.storage.json_store.save import save_experiment as ax_save_experiment

from boa.metrics.metrics import ModularMetric
from boa.runner import WrappedJobRunner

logger = logging.getLogger(__name__)


[docs]def pickle_obj(obj, filepath): with open(Path(str(filepath) + ".pickle"), "wb") as f: pickle.dump(obj, f)
[docs]def unpickle_obj(filepath): with open(filepath, "rb") as f: return pickle.load(f)
[docs]def save_experiment(experiment: Experiment, filepath: os.PathLike, pickle=True, *args, **kwargs): try: ax_save_experiment(experiment, filepath, *args, **kwargs) except JSONEncodeError: logger.warning("Failed to serialize experiment to JSON, attempting to pickle") with open(Path(str(filepath) + ".pickle"), "wb") as f: pickle.dump(experiment, f) logger.warning("Pickling succeeded. Experiment Pickled to %s" % Path(str(filepath) + ".pickle"))
[docs]def load_experiment(filepath: os.PathLike, *args, **kwargs) -> Experiment: """Load experiment from file. 1) Read file. 2) Convert dictionary to Ax experiment instance. """ try: with open(filepath, "rb") as f: return pickle.load(f) except pickle.UnpicklingError: return ax_load_experiment(filepath, *args, **kwargs)
[docs]def scheduler_to_json_file(scheduler, filepath: os.PathLike = "scheduler_snapshot.json") -> None: """Save a JSON-serialized snapshot of this `AxClient`'s settings and state to a .json file by the given path. """ with open(filepath, "w+") as file: # pragma: no cover file.write(json.dumps(scheduler_to_json_snapshot(scheduler))) logger.info(f"Saved JSON-serialized state of optimization to `{filepath}`.")
[docs]def scheduler_from_json_file(filepath: os.PathLike = "scheduler_snapshot.json", wrapper=None, **kwargs) -> Scheduler: """Restore an `AxClient` and its state from a JSON-serialized snapshot, residing in a .json file by the given path. """ with open(filepath, "r") as file: # pragma: no cover serialized = json.loads(file.read()) scheduler = scheduler_from_json_snapshot(serialized=serialized, **kwargs) if wrapper is not None: if isinstance(scheduler.experiment.runner, WrappedJobRunner): scheduler.experiment.runner.wrapper = wrapper for metric in scheduler.experiment.metrics.values(): if isinstance(metric, ModularMetric): metric.wrapper = wrapper return scheduler
[docs]def scheduler_to_json_snapshot( scheduler: Scheduler, encoder_registry: Optional[Dict[Type, Callable[[Any], Dict[str, Any]]]] = None, class_encoder_registry: Optional[Dict[Type, Callable[[Any], Dict[str, Any]]]] = None, ) -> Dict[str, Any]: """Serialize this `AxClient` to JSON to be able to interrupt and restart optimization and save it to file by the provided path. Returns: A JSON-safe dict representation of this `AxClient`. """ if encoder_registry is None: encoder_registry = CORE_ENCODER_REGISTRY if class_encoder_registry is None: class_encoder_registry = CORE_CLASS_ENCODER_REGISTRY return { "_type": scheduler.__class__.__name__, "experiment": object_to_json( scheduler.experiment, encoder_registry=encoder_registry, class_encoder_registry=class_encoder_registry, ), "generation_strategy": object_to_json( scheduler.generation_strategy, encoder_registry=encoder_registry, class_encoder_registry=class_encoder_registry, ), "options": object_to_json( scheduler.options, encoder_registry=encoder_registry, class_encoder_registry=class_encoder_registry, ), }
[docs]def scheduler_from_json_snapshot( serialized: Dict[str, Any], decoder_registry: Optional[Dict[str, Type]] = None, class_decoder_registry: Optional[Dict[str, Callable[[Dict[str, Any]], Any]]] = None, **kwargs, ) -> Scheduler: """Recreate an `AxClient` from a JSON snapshot.""" if decoder_registry is None: decoder_registry = CORE_DECODER_REGISTRY if class_decoder_registry is None: class_decoder_registry = CORE_CLASS_DECODER_REGISTRY if "options" in serialized: options = object_from_json( serialized.pop("options"), decoder_registry=decoder_registry, class_decoder_registry=class_decoder_registry, ) else: options = SchedulerOptions() experiment = object_from_json( serialized.pop("experiment"), decoder_registry=decoder_registry, class_decoder_registry=class_decoder_registry, ) serialized_generation_strategy = serialized.pop("generation_strategy") generation_strategy = generation_strategy_from_json( generation_strategy_json=serialized_generation_strategy, experiment=experiment ) ax_client = Scheduler(generation_strategy=generation_strategy, experiment=experiment, options=options, **kwargs) ax_client._experiment = experiment return ax_client