From bc26a0feb139370ee1449f32839f1494cbe42c69 Mon Sep 17 00:00:00 2001
From: Thomas Morris
Date: Thu, 19 Dec 2024 12:36:58 -0500
Subject: [PATCH] add logging

---
 src/blop/__init__.py        | 10 ++++++++++
 src/blop/agent.py           |  8 +++++---
 src/blop/bayesian/models.py | 21 +++++++++++++++------
 src/blop/plotting.py        |  2 +-
 4 files changed, 31 insertions(+), 10 deletions(-)

diff --git a/src/blop/__init__.py b/src/blop/__init__.py
index d2fc954..c771e78 100644
--- a/src/blop/__init__.py
+++ b/src/blop/__init__.py
@@ -1,5 +1,15 @@
+import logging
+
 from . import utils  # noqa F401
 from ._version import __version__, __version_tuple__  # noqa: F401
 from .agent import Agent  # noqa F401
 from .dofs import DOF  # noqa F401
 from .objectives import Objective  # noqa F401
+
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s.%(msecs)03d %(levelname)s: %(message)s",
+    datefmt="%Y-%m-%d %H:%M:%S",
+)
+
+logger = logging.getLogger("blop")
diff --git a/src/blop/agent.py b/src/blop/agent.py
index df04907..9178a8a 100644
--- a/src/blop/agent.py
+++ b/src/blop/agent.py
@@ -35,6 +35,8 @@
 from .objectives import Objective, ObjectiveList
 from .plans import default_acquisition_plan
 
+logger = logging.getLogger("blop")
+
 warnings.filterwarnings("ignore", category=botorch.exceptions.warnings.InputDataWarning)
 
 mpl.rc("image", cmap="coolwarm")
@@ -382,7 +384,7 @@ def tell(
                 t0 = ttime.monotonic()
                 train_model(obj.model)
                 if self.verbose:
-                    print(f"trained model '{obj.name}' in {1e3 * (ttime.monotonic() - t0):.00f} ms")
+                    logger.debug(f"trained model '{obj.name}' in {1e3 * (ttime.monotonic() - t0):.00f} ms")
 
             else:
                 train_model(obj.model, hypers=cached_hypers)
@@ -432,7 +434,7 @@ def learn(
 
         for i in range(iterations):
            if self.verbose:
-                print(f"running iteration {i + 1} / {iterations}")
+                logger.info(f"running iteration {i + 1} / {iterations}")
             for single_acqf in np.atleast_1d(acqf):
                 res = self.ask(n=n, acqf=single_acqf, upsample=upsample, route=route, **acqf_kwargs)
                 new_table = yield from self.acquire(res["points"])
@@ -761,7 +763,7 @@ def _train_all_models(self, **kwargs):
                 train_model(obj.validity_conjugate_model)
 
         if self.verbose:
-            print(f"trained models in {ttime.monotonic() - t0:.01f} seconds")
+            logger.info(f"trained models in {ttime.monotonic() - t0:.01f} seconds")
 
         self.n_last_trained = len(self._table)
 
diff --git a/src/blop/bayesian/models.py b/src/blop/bayesian/models.py
index c125a1d..a76e29d 100644
--- a/src/blop/bayesian/models.py
+++ b/src/blop/bayesian/models.py
@@ -6,13 +6,22 @@
 from . import kernels
 
 
-def train_model(model, hypers=None, **kwargs):
+def train_model(model, hypers=None, max_fails=4, **kwargs):
     """Fit all of the agent's models. All kwargs are passed to `botorch.fit.fit_gpytorch_mll`."""
-    if hypers is not None:
-        model.load_state_dict(hypers)
-    else:
-        botorch.fit.fit_gpytorch_mll(gpytorch.mlls.ExactMarginalLogLikelihood(model.likelihood, model), **kwargs)
-    model.trained = True
+    fails = 0
+    while True:
+        try:
+            if hypers is not None:
+                model.load_state_dict(hypers)
+            else:
+                botorch.fit.fit_gpytorch_mll(gpytorch.mlls.ExactMarginalLogLikelihood(model.likelihood, model), **kwargs)
+            model.trained = True
+            return
+        except Exception as e:
+            if fails < max_fails:
+                fails += 1
+            else:
+                raise e
 
 
 def construct_single_task_model(X, y, skew_dims=None, min_noise=1e-6, max_noise=1e0):
diff --git a/src/blop/plotting.py b/src/blop/plotting.py
index 39fb5bd..241d0dd 100644
--- a/src/blop/plotting.py
+++ b/src/blop/plotting.py
@@ -33,7 +33,7 @@ def _plot_fitness_objs_one_dof(agent, size=16, lw=1e0):
     test_model_inputs = agent.dofs(active=True).transform(test_inputs)
 
     for obj_index, obj in enumerate(fitness_objs):
-        obj_values = agent.train_targets()(obj.name).numpy()
+        obj_values = agent.train_targets()[obj.name].numpy()
 
         color = DEFAULT_COLOR_LIST[obj_index]
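
A minimal usage sketch (not part of the patch; the snippet below is hypothetical) showing how a downstream script can surface the new DEBUG-level messages, such as the per-model fit times now logged in Agent.tell. Importing blop runs the basicConfig() call added in src/blop/__init__.py, which defaults to INFO, so only the package logger's threshold needs lowering:

    import logging

    import blop  # noqa: F401 -- importing the package applies the logging configuration added above

    # The name matches the logger registered in src/blop/__init__.py; DEBUG makes the
    # per-model training-time messages visible alongside the INFO-level ones.
    logging.getLogger("blop").setLevel(logging.DEBUG)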