diff --git a/cpmpy/solvers/pysat.py b/cpmpy/solvers/pysat.py index 66003b7e1..341528341 100644 --- a/cpmpy/solvers/pysat.py +++ b/cpmpy/solvers/pysat.py @@ -35,11 +35,13 @@ from ..expressions.variables import _BoolVarImpl, NegBoolView, boolvar from ..expressions.globalconstraints import DirectConstraint from ..expressions.utils import is_int, flatlist +from ..transformations.comparison import only_numexpr_equality from ..transformations.decompose_global import decompose_in_tree from ..transformations.get_variables import get_variables from ..transformations.flatten_model import flatten_constraint +from ..transformations.linearize import linearize_constraint, only_positive_bv from ..transformations.normalize import toplevel_list -from ..transformations.reification import only_implies, only_bv_reifies +from ..transformations.reification import only_implies, only_bv_reifies, reify_rewrite class CPM_pysat(SolverInterface): @@ -72,6 +74,15 @@ def supported(): except ImportError as e: return False + @staticmethod + def pb_supported(): + try: + from pypblib import pblib + from pysat.pb import PBEnc + import pysat + return True + except ImportError as e: + return False @staticmethod def solvernames(): @@ -230,9 +241,16 @@ def transform(self, cpm_expr): """ cpm_cons = toplevel_list(cpm_expr) cpm_cons = decompose_in_tree(cpm_cons) - cpm_cons = flatten_constraint(cpm_cons) + cpm_cons = flatten_constraint(cpm_cons) # flat normal form + # the next two only needed if the model contains cardinality/pseudo-boolean constraints + #?cpm_cons = reify_rewrite(cpm_cons, supported=frozenset(['sum', 'wsum'])) # constraints that support reification + #?cpm_cons = only_numexpr_equality(cpm_cons, supported=frozenset(["sum", "wsum"])) # supports >, <, != + cpm_cons = only_bv_reifies(cpm_cons) - cpm_cons = only_implies(cpm_cons) + cpm_cons = only_implies(cpm_cons) # anything that can create full reif should go above... + # the next only needed if the model contains cardinality/pseudo-boolean constraints + cpm_cons = linearize_constraint(cpm_cons, supported=frozenset({"sum","wsum", "and", "or", "bv"})) # the core of the MIP-linearization + return cpm_cons def __add__(self, cpm_expr_orig): @@ -263,16 +281,19 @@ def __add__(self, cpm_expr_orig): elif cpm_expr.name == '->': # BV -> BE only thanks to only_bv_reifies a0,a1 = cpm_expr.args - # BoolVar() -> BoolVar() if isinstance(a1, _BoolVarImpl): + # BoolVar() -> BoolVar() args = [~a0, a1] self.pysat_solver.add_clause(self.solver_vars(args)) elif isinstance(a1, Operator) and a1.name == 'or': + # BoolVar() -> or(...) args = [~a0]+a1.args self.pysat_solver.add_clause(self.solver_vars(args)) elif hasattr(a1, 'decompose'): # implied global constraint + # TODO @wout I think we decompose in transformation now? self += a0.implies(cpm_expr.decompose()) - elif isinstance(a1, Comparison) and a1.args[0].name == "sum": # implied sum comparison (a0->sum(bvs)<>val) + elif isinstance(a1, Comparison) and a1.args[0].name == "sum": + # implied sum comparison (a0->sum(bvs)<>val) # convert sum to clauses sum_clauses = self._pysat_cardinality(a1) # implication of conjunction is conjunction of individual implications @@ -280,12 +301,31 @@ def __add__(self, cpm_expr_orig): clauses = [nimplvar+c for c in sum_clauses] self.pysat_solver.append_formula(clauses) - elif isinstance(cpm_expr, Comparison): - # only handle cardinality encodings (for now) - if isinstance(cpm_expr.args[0], Operator) and cpm_expr.args[0].name == "sum": - # convert to clauses and post - clauses = self._pysat_cardinality(cpm_expr) + elif isinstance(a1, Comparison) and (a1.args[0].name == "wsum" or a1.args[0].name == "mul"): # implied pseudo-boolean comparison (a0->wsum(ws,bvs)<>val) + # implied sum comparison (a0->wsum([w,bvs])<>val or a0->(w*bv<>val)) + # convert wsum to clauses + wsum_clauses = self._pysat_pseudoboolean(a1) + # implication of conjunction is conjunction of individual implications + nimplvar = [self.solver_var(~a0)] + clauses = [nimplvar+c for c in wsum_clauses] self.pysat_solver.append_formula(clauses) + else: + raise NotSupportedError(f"Implication: {cpm_expr} not supported by CPM_pysat") + + elif isinstance(cpm_expr, Comparison): + # comparisons between Booleans will have been transformed out + # check if comparison of cardinality/pseudo-boolean constraint + if isinstance(cpm_expr.args[0], Operator): + if cpm_expr.args[0].name == "sum": + # convert to clauses and post + clauses = self._pysat_cardinality(cpm_expr) + self.pysat_solver.append_formula(clauses) + elif (cpm_expr.args[0].name == "wsum" or cpm_expr.args[0].name == "mul"): + # convert to clauses and post + clauses = self._pysat_pseudoboolean(cpm_expr) + self.pysat_solver.append_formula(clauses) + else: + raise NotImplementedError(f"Operator constraint {cpm_expr} not supported by CPM_pysat") else: raise NotImplementedError(f"Non-operator constraint {cpm_expr} not supported by CPM_pysat") @@ -357,9 +397,11 @@ def _pysat_cardinality(self, cpm_compsum): if not isinstance(cpm_compsum, Comparison): raise NotSupportedError(f"PySAT card: input constraint must be Comparison -- {cpm_compsum}") if not is_int(cpm_compsum.args[1]): - raise NotSupportedError(f"PySAT card: sum must have constant at rhs not {cpm_compsum.args[1]} -- {cpm_compsum}") + raise NotSupportedError( + f"PySAT card: sum must have constant at rhs not {cpm_compsum.args[1]} -- {cpm_compsum}") if not cpm_compsum.args[0].name == "sum": - raise NotSupportedError(f"PySAT card: input constraint must be sum, got {cpm_compsum.args[0].name} -- {cpm_compsum}") + raise NotSupportedError( + f"PySAT card: input constraint must be sum, got {cpm_compsum.args[0].name} -- {cpm_compsum}") if not (all(isinstance(v, _BoolVarImpl) for v in cpm_compsum.args[0].args)): raise NotSupportedError(f"PySAT card: sum must be over Boolvars only -- {cpm_compsum.args[0]}") @@ -368,31 +410,76 @@ def _pysat_cardinality(self, cpm_compsum): lits = self.solver_vars(cpm_compsum.args[0].args) bound = cpm_compsum.args[1] - if cpm_compsum.name == "<": - return CardEnc.atmost(lits=lits, bound=bound-1, vpool=self.pysat_vpool).clauses - elif cpm_compsum.name == "<=": + # if cpm_compsum.name == "<": + # return CardEnc.atmost(lits=lits, bound=bound-1, vpool=self.pysat_vpool).clauses + if cpm_compsum.name == "<=": return CardEnc.atmost(lits=lits, bound=bound, vpool=self.pysat_vpool).clauses elif cpm_compsum.name == ">=": return CardEnc.atleast(lits=lits, bound=bound, vpool=self.pysat_vpool).clauses - elif cpm_compsum.name == ">": - return CardEnc.atleast(lits=lits, bound=bound+1, vpool=self.pysat_vpool).clauses + # elif cpm_compsum.name == ">": + # return CardEnc.atleast(lits=lits, bound=bound+1, vpool=self.pysat_vpool).clauses elif cpm_compsum.name == "==": return CardEnc.equals(lits=lits, bound=bound, vpool=self.pysat_vpool).clauses - elif cpm_compsum.name == "!=": - # special cases with bounding 'hardcoded' for clarity - if bound <= 0: - return CardEnc.atleast(lits=lits, bound=bound+1, vpool=self.pysat_vpool).clauses - elif bound >= len(lits): - return CardEnc.atmost(lits=lits, bound=bound-1, vpool=self.pysat_vpool).clauses - else: - ## add implication literals for (strict) atleast/atmost, one must be true - is_atleast = self.solver_var(boolvar()) - is_atmost = self.solver_var(boolvar()) - clauses = [[is_atleast, is_atmost]] - clauses += [atl + [-is_atleast] for atl in - CardEnc.atleast(lits=lits, bound=bound+1, vpool=self.pysat_vpool).clauses] - clauses += [atm + [-is_atmost] for atm in - CardEnc.atmost(lits=lits, bound=bound-1, vpool=self.pysat_vpool).clauses] - return clauses + # elif cpm_compsum.name == "!=": + # # special cases with bounding 'hardcoded' for clarity + # if bound <= 0: + # return CardEnc.atleast(lits=lits, bound=bound+1, vpool=self.pysat_vpool).clauses + # elif bound >= len(lits): + # return CardEnc.atmost(lits=lits, bound=bound-1, vpool=self.pysat_vpool).clauses + # else: + # ## add implication literals for (strict) atleast/atmost, one must be true + # is_atleast = self.solver_var(boolvar()) + # is_atmost = self.solver_var(boolvar()) + # clauses = [[is_atleast, is_atmost]] + # clauses += [atl + [-is_atleast] for atl in + # CardEnc.atleast(lits=lits, bound=bound+1, vpool=self.pysat_vpool).clauses] + # clauses += [atm + [-is_atmost] for atm in + # CardEnc.atmost(lits=lits, bound=bound-1, vpool=self.pysat_vpool).clauses] + # return clauses raise NotImplementedError(f"Non-operator constraint {cpm_compsum} not supported by CPM_pysat") + + def _pysat_pseudoboolean(self, cpm_expr): + if not CPM_pysat.pb_supported(): + raise ImportError("Please install PyPBLib: pip install pypblib") + from pysat.pb import PBEnc + + left = cpm_expr.args[0] # left-hand side, sum/wsum/mul + bound = cpm_expr.args[1] # right-hand side, constant + assert (is_int(bound)), f"PySAT PB: pseudo-Boolean must have constant at rhs not {left.args[1]} -- {left}" + + if left.name == "mul": + if not is_int(left.args[0]): + raise NotSupportedError(f"CPM_pysat: multiplication of variable with non-integer not supported: {left} in {cpm_expr}") + # single weight,value pair, in list + weights = [left.args[0]] + lits = [self.solver_var(left.args[1])] + else: # wsum + weights = left.args[0] + lits = self.solver_vars(left.args[1]) + + # if cpm_expr.name == "<": # edge case? or (cpm_expr.name == "!=" and bound >= sum(max(0,weights))): + # return PBEnc.leq(lits=lits, weights=weights, bound=bound-1, vpool=self.pysat_vpool).clauses + if cpm_expr.name == "<=": + return PBEnc.leq(lits=lits, weights=weights, bound=bound,vpool=self.pysat_vpool).clauses + # elif cpm_expr.name == ">": # edge case? or (cpm_expr.name == "!=" and bound <= sum(min(0,weights))): + # return PBEnc.geq(lits=lits, weights=weights, bound=bound+1, vpool=self.pysat_vpool).clauses + elif cpm_expr.name == ">=": + return PBEnc.geq(lits=lits, weights=weights, bound=bound ,vpool=self.pysat_vpool).clauses + elif cpm_expr.name == "==": + return PBEnc.equals(lits=lits, weights=weights, bound=bound, vpool=self.pysat_vpool) + + # elif cpm_expr.name == "!=": + # # XXX This case already covered by linearize (which uses just 1 literal is_atleast=-is_atmost) + # # BUG with pblib solved in Pysat dev 0.1.7.dev12 + # ## add implication literals for (strict) atleast/atmost, one must be true + # is_atleast = self.solver_var(boolvar()) + # is_atmost = self.solver_var(boolvar()) + # clauses = [[is_atleast, is_atmost]] + # clauses += [atl + [-is_atleast] for atl in + # PBEnc.geq(lits=lits, weights=weights, bound=bound+1, vpool=self.pysat_vpool).clauses] + # clauses += [atm + [-is_atmost] for atm in + # PBEnc.leq(lits=lits, weights=weights, bound=bound-1, vpool=self.pysat_vpool).clauses] + # return clauses + + raise NotImplementedError(f"Comparison: {cpm_expr} not supported by CPM_pysat") diff --git a/cpmpy/transformations/linearize.py b/cpmpy/transformations/linearize.py index 63ad3bb68..14c7ad007 100644 --- a/cpmpy/transformations/linearize.py +++ b/cpmpy/transformations/linearize.py @@ -67,16 +67,19 @@ def linearize_constraint(lst_of_expr, supported={"sum","wsum"}, reified=False): # boolvar if isinstance(cpm_expr, _BoolVarImpl): - newlist.append(sum([cpm_expr]) >= 1) + if "bv" in supported: + newlist.append(cpm_expr) + else: + newlist.append(sum([cpm_expr]) >= 1) # Boolean operators elif isinstance(cpm_expr, Operator) and cpm_expr.is_bool(): # conjunction - if cpm_expr.name == "and": + if cpm_expr.name == "and" and cpm_expr.name not in supported: newlist.append(sum(cpm_expr.args) >= len(cpm_expr.args)) # disjunction - elif cpm_expr.name == "or": + elif cpm_expr.name == "or" and cpm_expr.name not in supported: newlist.append(sum(cpm_expr.args) >= 1) # xor @@ -101,6 +104,9 @@ def linearize_constraint(lst_of_expr, supported={"sum","wsum"}, reified=False): new_vars = set(get_variables(lin_sub)) - set(get_variables(sub_expr)) newlist += linearize_constraint([(~cond).implies(nv == nv.lb) for nv in new_vars], reified=reified) + else: # supported operator + newlist.append(cpm_expr) + # comparisons elif isinstance(cpm_expr, Comparison): diff --git a/tests/test_constraints.py b/tests/test_constraints.py index 895d675bf..ad5f92a6c 100644 --- a/tests/test_constraints.py +++ b/tests/test_constraints.py @@ -8,6 +8,7 @@ # make sure that `SolverLookup.get(solver)` works # also add exclusions to the 3 EXCLUDE_* below as needed SOLVERNAMES = [name for name, solver in SolverLookup.base_solvers() if solver.supported()] +SOLVERNAMES = ["pysat"] # Exclude some global constraints for solvers # Can be used when .value() method is not implemented/contains bugs diff --git a/tests/test_pysat_cardinality.py b/tests/test_pysat_cardinality.py index 1b193d2a5..43cc1c9bd 100644 --- a/tests/test_pysat_cardinality.py +++ b/tests/test_pysat_cardinality.py @@ -94,6 +94,40 @@ def test_pysat_atleast_equals(self): self.assertGreaterEqual(sum(self.bvs.value()), 2) + def test_pysat_linear_other(self): + expressions = [ + self.bvs[0] + self.bvs[1] + self.bvs[2] > 0, + # now with var/expr on RHS + self.bvs[0] + self.bvs[1] > self.bvs[2], + self.bvs[0] > self.bvs[1] + self.bvs[2], + self.bvs[0] > (self.bvs[1] | self.bvs[2]), + ] + + ## check all types of linear constraints are handled + for expression in expressions: + cp.Model(expression).solve("pysat") + + def test_pysat_oob(self): + + def test_encode_pb_oob(self): + self.assertTrue(len(self.bv) == 3) + # test out of bounds (meaningless) thresholds + expressions = [ + sum(self.bv) <= 5, # true + sum(self.bv) <= 3, # true + sum(self.bv) <= -2, # false + sum(self.bv) <= 0, # undecided + + sum(self.bv) >= -2, # true + sum(self.bv) >= 0, # true + sum(self.bv) >= 5, # false + sum(self.bv) >= 3, # undecided + ] + + ## check all types of linear constraints are handled + for expression in expressions: + cp.Model(expression).solve("pysat") + def test_pysat_different(self): differrent = cp.Model( diff --git a/tests/test_pysat_wsum.py b/tests/test_pysat_wsum.py new file mode 100644 index 000000000..391f62a06 --- /dev/null +++ b/tests/test_pysat_wsum.py @@ -0,0 +1,74 @@ +import unittest +import cpmpy as cp +from cpmpy import * +from cpmpy.solvers.pysat import CPM_pysat +from cpmpy.transformations.to_cnf import to_cnf + +class TestEncodePseudoBooleanConstraint(unittest.TestCase): + def setUp(self): + self.bv = boolvar(shape=3) + + def test_pysat_simple_atmost(self): + + atmost = cp.Model( + ## < + - 2 * self.bv[0] < 3, + ## <= + - 3 * self.bv[1] <= 3, + ## > + 2 * self.bv[2] > 1, + ## >= + 4 * self.bv[2] >= 3, + ) + ps = CPM_pysat(atmost) + ps.solve() + + def test_pysat_unsat(self): + ls = cp.Model( + 2 * self.bv[0] + 3 * self.bv[1] <= 3, + self.bv[0] == 1, + self.bv[1] == 1 + ) + + ps = CPM_pysat(ls) + solved = ps.solve() + self.assertFalse(solved) + + def test_encode_pb_expressions(self): + expressions = [ + - self.bv[2] == -1, + - 2 * self.bv[2] == -2, + self.bv[0] - self.bv[2] > 0, + -self.bv[0] + self.bv[2] > 0, + 2 * self.bv[0] + 3 * self.bv[2] > 0, + 2 * self.bv[0] - 3 * self.bv[2] + 2 * self.bv[1] > 0, + self.bv[0] - 3 * self.bv[2] > 0, + self.bv[0] - 3 * (self.bv[2] + 2 * self.bv[1])> 0, + # now with var on RHS + self.bv[0] - 3 * self.bv[1] > self.bv[2], + ] + + ## check all types of linear constraints are handled + for expression in expressions: + Model(expression).solve("pysat") + + def test_encode_pb_oob(self): + # test out of bounds (meaningless) thresholds + expressions = [ + sum(self.bv*[2,2,2]) <= 10, # true + sum(self.bv*[2,2,2]) <= 6, # true + sum(self.bv*[2,2,2]) >= 10, # false + sum(self.bv*[2,2,2]) >= 6, # undecided + sum(self.bv*[2,-2,2]) <= 10, # true + sum(self.bv*[2,-2,2]) <= 4, # true + sum(self.bv*[2,-2,2]) >= 10, # false + sum(self.bv*[2,-2,2]) >= 4, # undecided + ] + + ## check all types of linear constraints are handled + for expression in expressions: + Model(expression).solve("pysat") + +if __name__ == '__main__': + unittest.main() + diff --git a/tests/test_solvers.py b/tests/test_solvers.py index e8f06849b..307fd2476 100644 --- a/tests/test_solvers.py +++ b/tests/test_solvers.py @@ -343,6 +343,18 @@ def test_pysat(self): self.assertFalse(ps2.solve(assumptions=[mayo]+[v for v in inds])) self.assertEqual(ps2.get_core(), [mayo,inds[6],inds[9]]) + @pytest.mark.skipif(not CPM_pysat.supported(), + reason="PySAT not installed") + def test_pysat_card(self): + b = cp.boolvar() + x = cp.boolvar(shape=5) + + cons = [sum(x) > 3, sum(x) <= 2, sum(x) == 4, (sum(x) <= 1) & (sum(x) != 2), + b.implies(sum(x) > 3), b == (sum(x) != 2), (sum(x) >= 3).implies(b)] + for c in cons: + cp.Model(c).solve("pysat") + self.assertTrue(c.value()) + @pytest.mark.skipif(not CPM_minizinc.supported(), reason="MiniZinc not installed")