Source code for boa.instantiation_base

from __future__ import annotations

from ax import Objective as AxObjective
from ax.core.objective import ScalarizedObjective
from ax.service.utils.instantiation import InstantiationBase

from boa.config import BOAMetric, BOAObjective
from boa.metrics.metrics import get_metric_from_config
from boa.metrics.modular_metric import ModularMetric


class BoaInstantiationBase(InstantiationBase):
    """Ax ``InstantiationBase`` extension that builds Ax objects from BOA configs.

    The classmethods translate BOA's ``BOAObjective`` / ``BOAMetric`` configuration
    objects into BOA ``ModularMetric`` instances, Ax objectives, and an Ax
    optimization config.
    """

    @classmethod
    def make_optimization_config(
        cls,
        objective: BOAObjective,
        status_quo_defined: bool = False,
        **kwargs,
    ):
        """Build an Ax optimization config from a BOA objective config.

        Parameters
        ----------
        objective
            BOA objective configuration. Its metrics become the Ax objective(s) via
            :meth:`make_objectives`. Its ``objective_thresholds`` and
            ``outcome_constraints`` are parsed by the inherited Ax helpers.
        status_quo_defined
            Forwarded unchanged to ``make_objective_thresholds`` and
            ``make_outcome_constraints``.
        **kwargs
            Forwarded to :meth:`make_objectives`, and from there to metric construction.

        Returns
        -------
        The value returned by the inherited ``optimization_config_from_objectives``.
        """
        return cls.optimization_config_from_objectives(
            cls.make_objectives(objective, **kwargs),
            cls.make_objective_thresholds(objective.objective_thresholds, status_quo_defined),
            cls.make_outcome_constraints(objective.outcome_constraints, status_quo_defined),
        )

    @classmethod
    def get_metric_from_metric_config(cls, metric_conf: BOAMetric, **kwargs) -> ModularMetric:
        """Instantiate the metric described by a single ``BOAMetric`` config.

        Thin wrapper around ``boa.metrics.metrics.get_metric_from_config``. Callers
        in this class go through ``cls``, so a subclass can override metric
        construction here.
        """
        return get_metric_from_config(metric_conf, **kwargs)

    @classmethod
    def get_metrics_from_obj_config(
        cls,
        objective: BOAObjective,
        info_only: bool | None = False,
        **kwargs,
    ) -> list[ModularMetric]:
        """Instantiate the metrics listed in ``objective.metrics`` and filter them.

        Parameters
        ----------
        objective
            BOA objective config whose ``metrics`` entries are instantiated in order.
        info_only
            Selection mode:

            * ``None``: return every metric.
            * truthy (``True``): return only metrics whose config has ``info_only``
              set (tracking metrics).
            * falsy (``False``, the default): return only metrics whose config does
              not have ``info_only`` set.
        **kwargs
            Forwarded to :meth:`get_metric_from_metric_config`.

        Returns
        -------
        list[ModularMetric]
            The selected metrics, in the same order as ``objective.metrics``.
        """
        metrics = []
        for metric_conf in objective.metrics:
            tracker = metric_conf.info_only
            # Every configured metric is instantiated, including ones filtered out below.
            metric = cls.get_metric_from_metric_config(metric_conf, **kwargs)
            # Compare truthiness rather than identity with True/False, so a
            # non-bool flag (e.g. 1 / 0) does not silently select nothing.
            if info_only is None or bool(tracker) == bool(info_only):
                metrics.append(metric)
        return metrics

    @classmethod
    def make_objectives(cls, objective: BOAObjective, **kwargs) -> list[AxObjective]:
        """Build the Ax objective(s) for the non-info-only metrics of ``objective``.

        If any selected metric has a truthy ``weight``, the selected metrics are
        combined into a single ``ScalarizedObjective`` with those weights. In that
        case ``objective.minimize`` is passed through when it is not ``None``.
        Otherwise one Ax ``Objective`` is created per metric, with ``minimize`` taken
        from that metric's ``lower_is_better``.

        Parameters
        ----------
        objective
            BOA objective config.
        **kwargs
            Forwarded to :meth:`get_metrics_from_obj_config`.

        Returns
        -------
        list[AxObjective]
            A one-element list holding the scalarized objective, or one objective
            per metric.

        Raises
        ------
        ValueError
            If weights are used (at least one is truthy) but some selected metric has
            no weight (``None``). This is raised instead of passing ``None`` weights
            through to Ax.
        """
        metrics = cls.get_metrics_from_obj_config(objective, **kwargs)
        weights = [metric.weight for metric in metrics]

        if not any(weights):
            # No usable weights: each metric is optimized as its own objective.
            return [AxObjective(metric=metric, minimize=metric.lower_is_better) for metric in metrics]

        missing = [metric.name for metric, weight in zip(metrics, weights) if weight is None]
        if missing:
            raise ValueError(
                "Metric weights must be given for all objective metrics or for none; "
                f"missing weight for: {', '.join(missing)}"
            )

        kw = {"weights": weights}
        if objective.minimize is not None:
            kw["minimize"] = objective.minimize
        return [ScalarizedObjective(metrics=metrics, **kw)]