diff --git a/src/blop/__init__.py b/src/blop/__init__.py index d2fc954..c771e78 100644 --- a/src/blop/__init__.py +++ b/src/blop/__init__.py @@ -1,5 +1,15 @@ +import logging + from . import utils # noqa F401 from ._version import __version__, __version_tuple__ # noqa: F401 from .agent import Agent # noqa F401 from .dofs import DOF # noqa F401 from .objectives import Objective # noqa F401 + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s.%(msecs)03d %(levelname)s: %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", +) + +logger = logging.getLogger("blop") diff --git a/src/blop/agent.py b/src/blop/agent.py index df04907..c3fcdce 100644 --- a/src/blop/agent.py +++ b/src/blop/agent.py @@ -35,6 +35,8 @@ from .objectives import Objective, ObjectiveList from .plans import default_acquisition_plan +logger = logging.getLogger("blop") + warnings.filterwarnings("ignore", category=botorch.exceptions.warnings.InputDataWarning) mpl.rc("image", cmap="coolwarm") @@ -382,7 +384,7 @@ def tell( t0 = ttime.monotonic() train_model(obj.model) if self.verbose: - print(f"trained model '{obj.name}' in {1e3 * (ttime.monotonic() - t0):.00f} ms") + logger.debug(f"trained model '{obj.name}' in {1e3 * (ttime.monotonic() - t0):.00f} ms") else: train_model(obj.model, hypers=cached_hypers) @@ -432,7 +434,7 @@ def learn( for i in range(iterations): if self.verbose: - print(f"running iteration {i + 1} / {iterations}") + logger.info(f"running iteration {i + 1} / {iterations}") for single_acqf in np.atleast_1d(acqf): res = self.ask(n=n, acqf=single_acqf, upsample=upsample, route=route, **acqf_kwargs) new_table = yield from self.acquire(res["points"]) @@ -761,7 +763,7 @@ def _train_all_models(self, **kwargs): train_model(obj.validity_conjugate_model) if self.verbose: - print(f"trained models in {ttime.monotonic() - t0:.01f} seconds") + logger.info(f"trained models in {ttime.monotonic() - t0:.01f} seconds") self.n_last_trained = len(self._table) @@ -861,15 +863,11 @@ def _set_hypers(self, 
hypers): self.validity_constraint.load_state_dict(hypers["validity_constraint"]) def constraint(self, x): - p = torch.ones(x.shape[:-1]) + log_p = torch.zeros(x.shape[:-1]) for obj in self.objectives(active=True): - # if the constraint is non-trivial - if obj.constraint is not None: - p *= obj.constraint_probability(x) - # if the validity constaint is non-trivial - if obj.validity_conjugate_model is not None: - p *= obj.validity_constraint(x) - return p # + 1e-6 * normalize(x, self.sample_domain).square().sum(axis=-1) + log_p += obj.log_total_constraint(x) + + return log_p.exp() # + 1e-6 * normalize(x, self.sample_domain).square().sum(axis=-1) @property def hypers(self) -> dict: diff --git a/src/blop/bayesian/models.py b/src/blop/bayesian/models.py index c125a1d..a76e29d 100644 --- a/src/blop/bayesian/models.py +++ b/src/blop/bayesian/models.py @@ -6,13 +6,22 @@ from . import kernels -def train_model(model, hypers=None, **kwargs): +def train_model(model, hypers=None, max_fails=4, **kwargs): """Fit all of the agent's models. 
All kwargs are passed to `botorch.fit.fit_gpytorch_mll`.""" - if hypers is not None: - model.load_state_dict(hypers) - else: - botorch.fit.fit_gpytorch_mll(gpytorch.mlls.ExactMarginalLogLikelihood(model.likelihood, model), **kwargs) - model.trained = True + fails = 0 + while True: + try: + if hypers is not None: + model.load_state_dict(hypers) + else: + botorch.fit.fit_gpytorch_mll(gpytorch.mlls.ExactMarginalLogLikelihood(model.likelihood, model), **kwargs) + model.trained = True + return + except Exception as e: + if fails < max_fails: + fails += 1 + else: + raise e def construct_single_task_model(X, y, skew_dims=None, min_noise=1e-6, max_noise=1e0): diff --git a/src/blop/objectives.py b/src/blop/objectives.py index 5479767..0101d3c 100644 --- a/src/blop/objectives.py +++ b/src/blop/objectives.py @@ -155,6 +155,19 @@ def constrain(self, y): else: return np.array([value in self.constraint for value in np.atleast_1d(y)]) + def log_total_constraint(self, x): + + log_p = 0 + # if you have a constraint + if self.constraint is not None: + log_p += self.constraint_probability(x).log() + + # if the validity constraint is non-trivial + if self.validity_conjugate_model is not None: + log_p += self.validity_constraint(x).log() + + return log_p + @property def _trust_domain(self): if self.trust_domain is None: diff --git a/src/blop/plotting.py b/src/blop/plotting.py index 39fb5bd..241d0dd 100644 --- a/src/blop/plotting.py +++ b/src/blop/plotting.py @@ -33,7 +33,7 @@ def _plot_fitness_objs_one_dof(agent, size=16, lw=1e0): test_model_inputs = agent.dofs(active=True).transform(test_inputs) for obj_index, obj in enumerate(fitness_objs): - obj_values = agent.train_targets()(obj.name).numpy() + obj_values = agent.train_targets()[obj.name].numpy() color = DEFAULT_COLOR_LIST[obj_index]