From 9771993edea812e0d261fb550c4d9c99bbd5ff34 Mon Sep 17 00:00:00 2001 From: Thomas Morris Date: Fri, 20 Dec 2024 11:28:33 -0500 Subject: [PATCH] add total constraint --- src/blop/agent.py | 12 ++++-------- src/blop/objectives.py | 13 +++++++++++++ 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/src/blop/agent.py b/src/blop/agent.py index 9178a8a..c3fcdce 100644 --- a/src/blop/agent.py +++ b/src/blop/agent.py @@ -863,15 +863,11 @@ def _set_hypers(self, hypers): self.validity_constraint.load_state_dict(hypers["validity_constraint"]) def constraint(self, x): - p = torch.ones(x.shape[:-1]) + log_p = torch.zeros(x.shape[:-1]) for obj in self.objectives(active=True): - # if the constraint is non-trivial - if obj.constraint is not None: - p *= obj.constraint_probability(x) - # if the validity constaint is non-trivial - if obj.validity_conjugate_model is not None: - p *= obj.validity_constraint(x) - return p # + 1e-6 * normalize(x, self.sample_domain).square().sum(axis=-1) + log_p += obj.log_total_constraint(x) + + return log_p.exp() # + 1e-6 * normalize(x, self.sample_domain).square().sum(axis=-1) @property def hypers(self) -> dict: diff --git a/src/blop/objectives.py b/src/blop/objectives.py index 5479767..0101d3c 100644 --- a/src/blop/objectives.py +++ b/src/blop/objectives.py @@ -155,6 +155,19 @@ def constrain(self, y): else: return np.array([value in self.constraint for value in np.atleast_1d(y)]) + def log_total_constraint(self, x): + + log_p = 0 + # if you have a constraint + if self.constraint is not None: + log_p += self.constraint_probability(x).log() + + # if the validity constaint is non-trivial + if self.validity_conjugate_model is not None: + log_p += self.validity_constraint(x).log() + + return log_p + @property def _trust_domain(self): if self.trust_domain is None: