Skip to content

Commit

Permalink
add logging
Browse files Browse the repository at this point in the history
  • Loading branch information
Thomas Morris committed Dec 19, 2024
1 parent 1e747ea commit bc26a0f
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 10 deletions.
10 changes: 10 additions & 0 deletions src/blop/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,15 @@
import logging

from . import utils # noqa F401
from ._version import __version__, __version_tuple__ # noqa: F401
from .agent import Agent # noqa F401
from .dofs import DOF # noqa F401
from .objectives import Objective # noqa F401

logging.basicConfig(
level=logging.INFO,
format="%(asctime)s.%(msecs)03d %(levelname)s: %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)

logger = logging.getLogger("maria")
8 changes: 5 additions & 3 deletions src/blop/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
from .objectives import Objective, ObjectiveList
from .plans import default_acquisition_plan

logger = logging.getLogger("maria")

warnings.filterwarnings("ignore", category=botorch.exceptions.warnings.InputDataWarning)

mpl.rc("image", cmap="coolwarm")
Expand Down Expand Up @@ -382,7 +384,7 @@ def tell(
t0 = ttime.monotonic()
train_model(obj.model)
if self.verbose:
print(f"trained model '{obj.name}' in {1e3 * (ttime.monotonic() - t0):.00f} ms")
logger.debug(f"trained model '{obj.name}' in {1e3 * (ttime.monotonic() - t0):.00f} ms")

else:
train_model(obj.model, hypers=cached_hypers)
Expand Down Expand Up @@ -432,7 +434,7 @@ def learn(

for i in range(iterations):
if self.verbose:
print(f"running iteration {i + 1} / {iterations}")
logger.info(f"running iteration {i + 1} / {iterations}")
for single_acqf in np.atleast_1d(acqf):
res = self.ask(n=n, acqf=single_acqf, upsample=upsample, route=route, **acqf_kwargs)
new_table = yield from self.acquire(res["points"])
Expand Down Expand Up @@ -761,7 +763,7 @@ def _train_all_models(self, **kwargs):
train_model(obj.validity_conjugate_model)

if self.verbose:
print(f"trained models in {ttime.monotonic() - t0:.01f} seconds")
logger.info(f"trained models in {ttime.monotonic() - t0:.01f} seconds")

self.n_last_trained = len(self._table)

Expand Down
21 changes: 15 additions & 6 deletions src/blop/bayesian/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,22 @@
from . import kernels


def train_model(model, hypers=None, **kwargs):
def train_model(model, hypers=None, max_fails=4, **kwargs):
"""Fit all of the agent's models. All kwargs are passed to `botorch.fit.fit_gpytorch_mll`."""
if hypers is not None:
model.load_state_dict(hypers)
else:
botorch.fit.fit_gpytorch_mll(gpytorch.mlls.ExactMarginalLogLikelihood(model.likelihood, model), **kwargs)
model.trained = True
fails = 0
while True:
try:
if hypers is not None:
model.load_state_dict(hypers)
else:
botorch.fit.fit_gpytorch_mll(gpytorch.mlls.ExactMarginalLogLikelihood(model.likelihood, model), **kwargs)
model.trained = True
return
except Exception as e:
if fails < max_fails:
fails += 1
else:
raise e


def construct_single_task_model(X, y, skew_dims=None, min_noise=1e-6, max_noise=1e0):
Expand Down
2 changes: 1 addition & 1 deletion src/blop/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def _plot_fitness_objs_one_dof(agent, size=16, lw=1e0):
test_model_inputs = agent.dofs(active=True).transform(test_inputs)

for obj_index, obj in enumerate(fitness_objs):
obj_values = agent.train_targets()(obj.name).numpy()
obj_values = agent.train_targets()[obj.name].numpy()

color = DEFAULT_COLOR_LIST[obj_index]

Expand Down

0 comments on commit bc26a0f

Please sign in to comment.