Skip to content

Commit

Permalink
make objective models read-only
Browse files Browse the repository at this point in the history
  • Loading branch information
Thomas Morris committed Jul 31, 2024
1 parent c1257a7 commit 76ddd5a
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 7 deletions.
14 changes: 7 additions & 7 deletions src/blop/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def ask(self, acqf="qei", n=1, route=True, sequential=True, upsample=1, **acqf_k

else:
# check that all the objectives have models
if not all(hasattr(obj, "model") for obj in active_objs):
if not all(hasattr(obj, "_model") for obj in active_objs):
raise RuntimeError(
f"Can't construct non-trivial acquisition function '{acqf}' as the agent is not initialized."
)
Expand Down Expand Up @@ -367,7 +367,7 @@ def tell(
for obj in objectives_to_model:
t0 = ttime.monotonic()

cached_hypers = obj.model.state_dict() if hasattr(obj, "model") else None
cached_hypers = obj.model.state_dict() if hasattr(obj, "_model") else None
n_before_tell = obj.n_valid
self._construct_model(obj)
n_after_tell = obj.n_valid
Expand Down Expand Up @@ -538,8 +538,8 @@ def reset(self):
self._table = pd.DataFrame()

for obj in self.objectives(active=True):
if hasattr(obj, "model"):
del obj.model
if hasattr(obj, "_model"):
del obj._model

self.n_last_trained = 0

Expand Down Expand Up @@ -573,7 +573,7 @@ def benchmark(
def model(self):
"""A model encompassing all the fitnesses and constraints."""
active_objs = self.objectives(active=True)
if all(hasattr(obj, "model") for obj in active_objs):
if all(hasattr(obj, "_model") for obj in active_objs):
return ModelListGP(*[obj.model for obj in active_objs]) if len(active_objs) > 1 else active_objs[0].model
raise ValueError("Not all active objectives have models.")

Expand Down Expand Up @@ -689,7 +689,7 @@ def _construct_model(self, obj, skew_dims=None):

trusted = inputs_are_trusted & targets_are_trusted

obj.model = construct_single_task_model(
obj._model = construct_single_task_model(
X=train_inputs[trusted],
y=train_targets[trusted],
min_noise=obj.min_noise,
Expand Down Expand Up @@ -731,7 +731,7 @@ def _train_all_models(self, **kwargs):
t0 = ttime.monotonic()
objectives_to_train = self.objectives if self.model_inactive_objectives else self.objectives(active=True)
for obj in objectives_to_train:
train_model(obj.model)
train_model(obj._model)
if obj.validity_conjugate_model is not None:
train_model(obj.validity_conjugate_model)

Expand Down
4 changes: 4 additions & 0 deletions src/blop/objectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,10 @@ def fitness_prediction(self, X):
if isinstance(self.target, tuple):
return self.targeting_constraint(X).log().clamp(min=-16)

@property
def model(self):
return self._model.eval()


class ObjectiveList(Sequence):
def __init__(self, objectives: list = []):
Expand Down

0 comments on commit 76ddd5a

Please sign in to comment.