Skip to content

Commit

Permalink
Merge pull request #79 from thomaswmorris/logging
Browse files Browse the repository at this point in the history
Add logging
  • Loading branch information
jennmald authored Dec 20, 2024
2 parents 1e747ea + 9771993 commit 9fc28ea
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 18 deletions.
10 changes: 10 additions & 0 deletions src/blop/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,15 @@
import logging

from . import utils # noqa F401
from ._version import __version__, __version_tuple__ # noqa: F401
from .agent import Agent # noqa F401
from .dofs import DOF # noqa F401
from .objectives import Objective # noqa F401

logging.basicConfig(
level=logging.INFO,
format="%(asctime)s.%(msecs)03d %(levelname)s: %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)

logger = logging.getLogger("maria")
20 changes: 9 additions & 11 deletions src/blop/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
from .objectives import Objective, ObjectiveList
from .plans import default_acquisition_plan

logger = logging.getLogger("maria")

warnings.filterwarnings("ignore", category=botorch.exceptions.warnings.InputDataWarning)

mpl.rc("image", cmap="coolwarm")
Expand Down Expand Up @@ -382,7 +384,7 @@ def tell(
t0 = ttime.monotonic()
train_model(obj.model)
if self.verbose:
print(f"trained model '{obj.name}' in {1e3 * (ttime.monotonic() - t0):.00f} ms")
logger.debug(f"trained model '{obj.name}' in {1e3 * (ttime.monotonic() - t0):.00f} ms")

else:
train_model(obj.model, hypers=cached_hypers)
Expand Down Expand Up @@ -432,7 +434,7 @@ def learn(

for i in range(iterations):
if self.verbose:
print(f"running iteration {i + 1} / {iterations}")
logger.info(f"running iteration {i + 1} / {iterations}")
for single_acqf in np.atleast_1d(acqf):
res = self.ask(n=n, acqf=single_acqf, upsample=upsample, route=route, **acqf_kwargs)
new_table = yield from self.acquire(res["points"])
Expand Down Expand Up @@ -761,7 +763,7 @@ def _train_all_models(self, **kwargs):
train_model(obj.validity_conjugate_model)

if self.verbose:
print(f"trained models in {ttime.monotonic() - t0:.01f} seconds")
logger.info(f"trained models in {ttime.monotonic() - t0:.01f} seconds")

self.n_last_trained = len(self._table)

Expand Down Expand Up @@ -861,15 +863,11 @@ def _set_hypers(self, hypers):
self.validity_constraint.load_state_dict(hypers["validity_constraint"])

def constraint(self, x):
p = torch.ones(x.shape[:-1])
log_p = torch.zeros(x.shape[:-1])
for obj in self.objectives(active=True):
# if the constraint is non-trivial
if obj.constraint is not None:
p *= obj.constraint_probability(x)
# if the validity constaint is non-trivial
if obj.validity_conjugate_model is not None:
p *= obj.validity_constraint(x)
return p # + 1e-6 * normalize(x, self.sample_domain).square().sum(axis=-1)
log_p += obj.log_total_constraint(x)

return log_p.exp() # + 1e-6 * normalize(x, self.sample_domain).square().sum(axis=-1)

@property
def hypers(self) -> dict:
Expand Down
21 changes: 15 additions & 6 deletions src/blop/bayesian/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,22 @@
from . import kernels


def train_model(model, hypers=None, **kwargs):
def train_model(model, hypers=None, max_fails=4, **kwargs):
"""Fit all of the agent's models. All kwargs are passed to `botorch.fit.fit_gpytorch_mll`."""
if hypers is not None:
model.load_state_dict(hypers)
else:
botorch.fit.fit_gpytorch_mll(gpytorch.mlls.ExactMarginalLogLikelihood(model.likelihood, model), **kwargs)
model.trained = True
fails = 0
while True:
try:
if hypers is not None:
model.load_state_dict(hypers)
else:
botorch.fit.fit_gpytorch_mll(gpytorch.mlls.ExactMarginalLogLikelihood(model.likelihood, model), **kwargs)
model.trained = True
return
except Exception as e:
if fails < max_fails:
fails += 1
else:
raise e


def construct_single_task_model(X, y, skew_dims=None, min_noise=1e-6, max_noise=1e0):
Expand Down
13 changes: 13 additions & 0 deletions src/blop/objectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,19 @@ def constrain(self, y):
else:
return np.array([value in self.constraint for value in np.atleast_1d(y)])

def log_total_constraint(self, x):

log_p = 0
# if you have a constraint
if self.constraint is not None:
log_p += self.constraint_probability(x).log()

# if the validity constaint is non-trivial
if self.validity_conjugate_model is not None:
log_p += self.validity_constraint(x).log()

return log_p

@property
def _trust_domain(self):
if self.trust_domain is None:
Expand Down
2 changes: 1 addition & 1 deletion src/blop/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def _plot_fitness_objs_one_dof(agent, size=16, lw=1e0):
test_model_inputs = agent.dofs(active=True).transform(test_inputs)

for obj_index, obj in enumerate(fitness_objs):
obj_values = agent.train_targets()(obj.name).numpy()
obj_values = agent.train_targets()[obj.name].numpy()

color = DEFAULT_COLOR_LIST[obj_index]

Expand Down

0 comments on commit 9fc28ea

Please sign in to comment.