From 906489695a91ef55de65db63d8c77b5b9b4a9e74 Mon Sep 17 00:00:00 2001 From: Wout Date: Tue, 28 May 2024 12:30:23 +0200 Subject: [PATCH] AllDiffExceptN, AllEqualExceptN (#473) AlldifferentExceptN and AllEqualExceptN Global Constraints Co-authored-by: Dimos Tsouros Co-authored-by: Ignace Bleukx --- cpmpy/expressions/__init__.py | 4 +- cpmpy/expressions/globalconstraints.py | 50 ++++++++++-- tests/test_constraints.py | 18 +++-- tests/test_globalconstraints.py | 104 +++++++++++++++++++++++-- 4 files changed, 153 insertions(+), 23 deletions(-) diff --git a/cpmpy/expressions/__init__.py b/cpmpy/expressions/__init__.py index 664b97204..435c5f04c 100644 --- a/cpmpy/expressions/__init__.py +++ b/cpmpy/expressions/__init__.py @@ -5,13 +5,13 @@ List of submodules ================== .. autosummary:: - python_builtins :nosignatures: variables core globalconstraints globalfunctions + python_builtins utils @@ -21,7 +21,7 @@ # others need to be imported by the developer explicitely from .variables import boolvar, intvar, cpm_array from .variables import BoolVar, IntVar, cparray # Old, to be deprecated -from .globalconstraints import AllDifferent, AllDifferentExcept0, AllDifferentLists, AllEqual, Circuit, Inverse, Table, Xor, Cumulative, \ +from .globalconstraints import AllDifferent, AllDifferentExcept0, AllDifferentExceptN, AllDifferentLists, AllEqual, AllEqualExceptN, Circuit, Inverse, Table, Xor, Cumulative, \ IfThenElse, GlobalCardinalityCount, DirectConstraint, InDomain, Increasing, Decreasing, IncreasingStrict, DecreasingStrict, \ LexLess, LexLessEq, LexChainLess, LexChainLessEq, Precedence, NoOverlap from .globalconstraints import alldifferent, allequal, circuit # Old, to be deprecated diff --git a/cpmpy/expressions/globalconstraints.py b/cpmpy/expressions/globalconstraints.py index 2620df725..06976331f 100644 --- a/cpmpy/expressions/globalconstraints.py +++ b/cpmpy/expressions/globalconstraints.py @@ -98,8 +98,10 @@ def my_circuit_decomp(self): AllDifferent AllDifferentExcept0 + AllDifferentExceptN AllDifferentLists AllEqual + AllEqualExceptN Circuit Inverse Table @@ -181,22 +183,33 @@ def decompose(self): def value(self): return len(set(argvals(self.args))) == len(self.args) - -class AllDifferentExcept0(GlobalConstraint): +class AllDifferentExceptN(GlobalConstraint): """ - All nonzero arguments have a distinct value + All arguments except those equal to a value in n have a distinct value. """ - def __init__(self, *args): - super().__init__("alldifferent_except0", flatlist(args)) + def __init__(self, arr, n): + flatarr = flatlist(arr) + if not is_any_list(n): + n = [n] + super().__init__("alldifferent_except_n", [flatarr, n]) def decompose(self): - # equivalent to (var1 == 0) | (var2 == 0) | (var1 != var2) - return [(var1 == var2).implies(var1 == 0) for var1, var2 in all_pairs(self.args)], [] + from .python_builtins import any as cpm_any + # equivalent to (var1 == n) | (var2 == n) | (var1 != var2) + return [(var1 == var2).implies(cpm_any(var1 == a for a in self.args[1])) for var1, var2 in all_pairs(self.args[0])], [] def value(self): - vals = [argval(a) for a in self.args if argval(a) != 0] + vals = [argval(a) for a in self.args[0] if argval(a) not in argvals(self.args[1])] return len(set(vals)) == len(vals) +class AllDifferentExcept0(AllDifferentExceptN): + """ + All nonzero arguments have a distinct value + """ + def __init__(self, *arr): + flatarr = flatlist(arr) + super().__init__(arr, 0) + class AllDifferentLists(GlobalConstraint): """ @@ -224,6 +237,7 @@ def value(self): lst_vals = [tuple(argvals(a)) for a in self.args] return len(set(lst_vals)) == len(self.args) + def allequal(args): warnings.warn("Deprecated, use AllEqual(v1,v2,...,vn) instead, will be removed in stable version", DeprecationWarning) return AllEqual(*args) # unfold list as individual arguments @@ -244,6 +258,26 @@ def decompose(self): def value(self): return len(set(argvals(self.args))) == 1 +class AllEqualExceptN(GlobalConstraint): + """ + All arguments except those equal to a value in n have the same value. + """ + + def __init__(self, arr, n): + flatarr = flatlist(arr) + if not is_any_list(n): + n = [n] + super().__init__("allequal_except_n", [flatarr, n]) + + def decompose(self): + from .python_builtins import any as cpm_any + return [(cpm_any(var1 == a for a in self.args[1]) | (var1 == var2) | cpm_any(var2 == a for a in self.args[1])) for var1, var2 in all_pairs(self.args[0])], [] + + def value(self): + vals = [argval(a) for a in self.args[0] if argval(a) not in argvals(self.args[1])] + return len(set(vals)) == 1 or len(set(vals)) == 0 + + def circuit(args): warnings.warn("Deprecated, use Circuit(v1,v2,...,vn) instead, will be removed in stable version", DeprecationWarning) return Circuit(*args) # unfold list as individual arguments diff --git a/tests/test_constraints.py b/tests/test_constraints.py index dd59b6ffd..cad4c7fde 100644 --- a/tests/test_constraints.py +++ b/tests/test_constraints.py @@ -16,6 +16,7 @@ # Exclude some global constraints for solvers NUM_GLOBAL = { "AllEqual", "AllDifferent", "AllDifferentLists", "AllDifferentExcept0", + "AllDifferentExceptN", "AllEqualExceptN", "GlobalCardinalityCount", "InDomain", "Inverse", "Table", "Circuit", "Increasing", "IncreasingStrict", "Decreasing", "DecreasingStrict", "Precedence", "Cumulative", "NoOverlap", @@ -180,6 +181,8 @@ def global_constraints(solver): classes = [(name, cls) for name, cls in classes if name not in EXCLUDE_GLOBAL.get(solver, {})] for name, cls in classes: + if solver in EXCLUDE_GLOBAL and name in EXCLUDE_GLOBAL[solver]: + continue if name == "Xor": expr = cls(BOOL_ARGS) @@ -198,6 +201,14 @@ def global_constraints(solver): demand = [4, 5, 7] cap = 10 expr = Cumulative(s, dur, e, demand, cap) + elif name == "GlobalCardinalityCount": + vals = [1, 2, 3] + cnts = intvar(0,10,shape=3) + expr = cls(NUM_ARGS, vals, cnts) + elif name == "AllDifferentExceptN": + expr = cls(NUM_ARGS, NUM_VAR) + elif name == "AllEqualExceptN": + expr = cls(NUM_ARGS, NUM_VAR) elif name == "Precedence": x = intvar(0,5, shape=3, name="x") expr = cls(x, [3,1,0]) @@ -214,16 +225,13 @@ def global_constraints(solver): X = intvar(0, 3, shape=3) Y = intvar(0, 3, shape=3) expr = LexLessEq(X, Y) - elif name == "LexLess": X = intvar(0, 3, shape=3) Y = intvar(0, 3, shape=3) expr = LexLess(X, Y) - elif name == "LexChainLess": X = intvar(0, 3, shape=(3,3)) - expr = LexChainLess(X) - + expr = LexChainLess(X) elif name == "LexChainLessEq": X = intvar(0, 3, shape=(3,3)) expr = LexChainLess(X) @@ -237,7 +245,7 @@ def global_constraints(solver): continue else: yield expr - + def reify_imply_exprs(solver): """ diff --git a/tests/test_globalconstraints.py b/tests/test_globalconstraints.py index 3c0d20f6a..b060717e9 100644 --- a/tests/test_globalconstraints.py +++ b/tests/test_globalconstraints.py @@ -114,6 +114,42 @@ def test_alldifferent_except0(self): bv = cp.boolvar() self.assertTrue(cp.Model([cp.AllDifferentExcept0(iv[0], bv)]).solve()) + def test_alldifferent_except_n(self): + # test known input/outputs + tuples = [ + ((1, 2, 3), True), + ((1, 2, 1), False), + ((0, 1, 2), True), + ((2, 0, 3), True), + ((2, 0, 2), True), + ((0, 0, 2), False), + ] + iv = cp.intvar(0, 4, shape=3) + c = cp.AllDifferentExceptN(iv, 2) + for (vals, oracle) in tuples: + ret = cp.Model(c, iv == vals).solve() + assert (ret == oracle), f"Mismatch solve for {vals, oracle}" + # don't try this at home, forcibly overwrite variable values (so even when ret=false) + for (var, val) in zip(iv, vals): + var._value = val + assert (c.value() == oracle), f"Wrong value function for {vals, oracle}" + + # and some more + iv = cp.intvar(-8, 8, shape=3) + self.assertTrue(cp.Model([cp.AllDifferentExceptN(iv,2)]).solve()) + self.assertTrue(cp.AllDifferentExceptN(iv,4).value()) + self.assertTrue(cp.Model([cp.AllDifferentExceptN(iv,7), iv == [7, 7, 1]]).solve()) + self.assertTrue(cp.AllDifferentExceptN(iv,7).value()) + + # test with mixed types + bv = cp.boolvar() + self.assertTrue(cp.Model([cp.AllDifferentExceptN([iv[0], bv],4)]).solve()) + + # test with list of n + iv = cp.intvar(0, 4, shape=7) + self.assertFalse(cp.Model([cp.AllDifferentExceptN([iv], [7,8])]).solve()) + self.assertTrue(cp.Model([cp.AllDifferentExceptN([iv], [4, 1])]).solve()) + def test_not_alldifferentexcept0(self): iv = cp.intvar(-8, 8, shape=3) self.assertTrue(cp.Model([~cp.AllDifferentExcept0(iv)]).solve()) @@ -169,25 +205,25 @@ def test_not_circuit(self): self.assertFalse(cp.Model([circuit, ~circuit]).solve()) - circuit_sols = set() - not_circuit_sols = set() + all_sols = set() + not_all_sols = set() - circuit_models = cp.Model(circuit).solveAll(display=lambda : circuit_sols.add(tuple(x.value()))) - not_circuit_models = cp.Model(~circuit).solveAll(display=lambda : not_circuit_sols.add(tuple(x.value()))) + circuit_models = cp.Model(circuit).solveAll(display=lambda : all_sols.add(tuple(x.value()))) + not_circuit_models = cp.Model(~circuit).solveAll(display=lambda : not_all_sols.add(tuple(x.value()))) total = cp.Model(x == x).solveAll() - for sol in circuit_sols: + for sol in all_sols: for var,val in zip(x, sol): var._value = val self.assertTrue(circuit.value()) - for sol in not_circuit_sols: + for sol in not_all_sols: for var,val in zip(x, sol): var._value = val self.assertFalse(circuit.value()) - self.assertEqual(total, len(circuit_sols) + len(not_circuit_sols)) + self.assertEqual(total, len(all_sols) + len(not_all_sols)) def test_inverse(self): @@ -914,7 +950,59 @@ def test_allEqual(self): a = cp.boolvar() self.assertTrue(cp.Model([cp.AllEqual(x,y,-1)]).solve()) self.assertTrue(cp.Model([cp.AllEqual(a,b,False, a | b)]).solve()) - #self.assertTrue(cp.Model([cp.AllEqual(x,y,b)]).solve()) + self.assertFalse(cp.Model([cp.AllEqual(x,y,b)]).solve()) + + def test_allEqualExceptn(self): + x = cp.intvar(-8, 8) + y = cp.intvar(-7, -1) + b = cp.boolvar() + a = cp.boolvar() + self.assertTrue(cp.Model([cp.AllEqualExceptN([x,y,-1],211)]).solve()) + self.assertTrue(cp.Model([cp.AllEqualExceptN([x,y,-1,4],4)]).solve()) + self.assertTrue(cp.Model([cp.AllEqualExceptN([x,y,-1,4],-1)]).solve()) + self.assertTrue(cp.Model([cp.AllEqualExceptN([a,b,False, a | b], 4)]).solve()) + self.assertTrue(cp.Model([cp.AllEqualExceptN([a,b,False, a | b], 0)]).solve()) + self.assertTrue(cp.Model([cp.AllEqualExceptN([a,b,False, a | b, y], -1)]).solve()) + + # test with list of n + iv = cp.intvar(0, 4, shape=7) + self.assertFalse(cp.Model([cp.AllEqualExceptN([iv], [7,8]), iv[0] != iv[1]]).solve()) + self.assertTrue(cp.Model([cp.AllEqualExceptN([iv], [4, 1]), iv[0] != iv[1]]).solve()) + + def test_not_allEqualExceptn(self): + x = cp.intvar(lb=0, ub=3, shape=3) + n = 2 + constr = cp.AllEqualExceptN(x,n) + + model = cp.Model([~constr, x == [1, 2, 1]]) + self.assertFalse(model.solve()) + + model = cp.Model([~constr]) + self.assertTrue(model.solve()) + self.assertFalse(constr.value()) + + self.assertFalse(cp.Model([constr, ~constr]).solve()) + + all_sols = set() + not_all_sols = set() + + circuit_models = cp.Model(constr).solveAll(display=lambda: all_sols.add(tuple(x.value()))) + not_circuit_models = cp.Model(~constr).solveAll(display=lambda: not_all_sols.add(tuple(x.value()))) + + total = cp.Model(x == x).solveAll() + + for sol in all_sols: + for var, val in zip(x, sol): + var._value = val + self.assertTrue(constr.value()) + + for sol in not_all_sols: + for var, val in zip(x, sol): + var._value = val + self.assertFalse(constr.value()) + + self.assertEqual(total, len(all_sols) + len(not_all_sols)) + def test_increasing(self): x = cp.intvar(-8, 8)