From 914153a0e60962fcba785af8f3af964547f6ab2e Mon Sep 17 00:00:00 2001 From: Ignace Bleukx Date: Fri, 8 Sep 2023 08:39:12 -0400 Subject: [PATCH] Allow description for a constraint (#367) * allow a constraint to have a description * add test * swap MRO for NDVarArray so __str__ of ndarray is called instead of Expression * add case for negboolview * add extra arguments for set_description * remove custom for _NegBoolView --- cpmpy/expressions/core.py | 16 ++++++++++++++- cpmpy/expressions/variables.py | 2 +- tests/test_expressions.py | 36 ++++++++++++++++++++++++++++++++++ 3 files changed, 52 insertions(+), 2 deletions(-) diff --git a/cpmpy/expressions/core.py b/cpmpy/expressions/core.py index 6f7e32140..709d00811 100644 --- a/cpmpy/expressions/core.py +++ b/cpmpy/expressions/core.py @@ -107,6 +107,20 @@ def __init__(self, name, arg_list): assert (is_any_list(arg_list)), "_list_ of arguments required, even if of length one e.g. [arg]" self.args = arg_list + def set_description(self, txt, override_print=True, full_print=False): + self.desc = txt + self._override_print = override_print + self._full_print = full_print + + def __str__(self): + if not hasattr(self, "desc") or self._override_print is False: + return self.__repr__() + out = self.desc + if self._full_print: + out += " -- "+self.__repr__() + return out + + def __repr__(self): strargs = [] for arg in self.args: @@ -119,7 +133,7 @@ def __repr__(self): return "{}({})".format(self.name, ",".join(strargs)) def __hash__(self): - return hash(self.__str__()) + return hash(self.__repr__()) def is_bool(self): """ is it a Boolean (return type) Operator? diff --git a/cpmpy/expressions/variables.py b/cpmpy/expressions/variables.py index 5225cf22d..b8a0f88d3 100644 --- a/cpmpy/expressions/variables.py +++ b/cpmpy/expressions/variables.py @@ -366,7 +366,7 @@ def __invert__(self): # subclass numericexpression for operators (first), ndarray for all the rest -class NDVarArray(Expression, np.ndarray): +class NDVarArray(np.ndarray, Expression): """ N-dimensional numpy array of variables. diff --git a/tests/test_expressions.py b/tests/test_expressions.py index f476f39cd..0b51133dd 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -462,5 +462,41 @@ def test_not_operator(self): self.assertTrue(cp.Model([~~p == ~q]).solve()) self.assertTrue(cp.Model([Operator('not',[p]) == q]).solve()) self.assertTrue(cp.Model([Operator('not',[p])]).solve()) + + def test_description(self): + + a,b = cp.boolvar(name="a"), cp.boolvar(name="b") + cons = a ^ b + cons.set_description("either a or b should be true, but not both") + + self.assertEqual(repr(cons), "a xor b") + self.assertEqual(str(cons), "either a or b should be true, but not both") + + # ensure nothing goes wrong due to calling __str__ on a constraint with a custom description + for solver,cls in cp.SolverLookup.base_solvers(): + if not cls.supported(): + continue + print("Testing", solver) + self.assertTrue(cp.Model(cons).solve(solver=solver)) + + ## test extra attributes of set_description + cons = a ^ b + cons.set_description("either a or b should be true, but not both", + override_print=False) + + self.assertEqual(repr(cons), "a xor b") + self.assertEqual(str(cons), "a xor b") + + cons = a ^ b + cons.set_description("either a or b should be true, but not both", + full_print=True) + + self.assertEqual(repr(cons), "a xor b") + self.assertEqual(str(cons), "either a or b should be true, but not both -- a xor b") + + + + + if __name__ == '__main__': unittest.main()