Skip to content

Commit

Permalink
add total constraint
Browse files Browse the repository at this point in the history
  • Loading branch information
Thomas Morris committed Dec 20, 2024
1 parent bc26a0f commit 9771993
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 8 deletions.
12 changes: 4 additions & 8 deletions src/blop/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -863,15 +863,11 @@ def _set_hypers(self, hypers):
self.validity_constraint.load_state_dict(hypers["validity_constraint"])

def constraint(self, x):
p = torch.ones(x.shape[:-1])
log_p = torch.zeros(x.shape[:-1])
for obj in self.objectives(active=True):
# if the constraint is non-trivial
if obj.constraint is not None:
p *= obj.constraint_probability(x)
# if the validity constaint is non-trivial
if obj.validity_conjugate_model is not None:
p *= obj.validity_constraint(x)
return p # + 1e-6 * normalize(x, self.sample_domain).square().sum(axis=-1)
log_p += obj.log_total_constraint(x)

return log_p.exp() # + 1e-6 * normalize(x, self.sample_domain).square().sum(axis=-1)

@property
def hypers(self) -> dict:
Expand Down
13 changes: 13 additions & 0 deletions src/blop/objectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,19 @@ def constrain(self, y):
else:
return np.array([value in self.constraint for value in np.atleast_1d(y)])

def log_total_constraint(self, x):

log_p = 0
# if you have a constraint
if self.constraint is not None:
log_p += self.constraint_probability(x).log()

# if the validity constaint is non-trivial
if self.validity_conjugate_model is not None:
log_p += self.validity_constraint(x).log()

return log_p

@property
def _trust_domain(self):
if self.trust_domain is None:
Expand Down

0 comments on commit 9771993

Please sign in to comment.