From c8f384cb41a1d2fecc8d01b9e20f72589563b42e Mon Sep 17 00:00:00 2001 From: wout4 Date: Fri, 10 May 2024 15:55:28 +0200 Subject: [PATCH] AllDifferentLists --- cpmpy/expressions/__init__.py | 2 +- cpmpy/expressions/globalconstraints.py | 16 ++++++++++++++++ tests/test_globalconstraints.py | 20 ++++++++++++++++++++ 3 files changed, 37 insertions(+), 1 deletion(-) diff --git a/cpmpy/expressions/__init__.py b/cpmpy/expressions/__init__.py index bd71085cf..1d74a1380 100644 --- a/cpmpy/expressions/__init__.py +++ b/cpmpy/expressions/__init__.py @@ -21,7 +21,7 @@ # others need to be imported by the developer explicitely from .variables import boolvar, intvar, cpm_array from .variables import BoolVar, IntVar, cparray # Old, to be deprecated -from .globalconstraints import AllDifferent, AllDifferentExcept0, AllEqual, Circuit, Inverse, Table, Xor, Cumulative, \ +from .globalconstraints import AllDifferent, AllDifferentExcept0, AllDifferentLists, AllEqual, Circuit, Inverse, Table, Xor, Cumulative, \ IfThenElse, GlobalCardinalityCount, DirectConstraint, InDomain, Increasing, Decreasing, IncreasingStrict, DecreasingStrict from .globalconstraints import alldifferent, allequal, circuit # Old, to be deprecated from .globalfunctions import Maximum, Minimum, Abs, Element, Count, NValue diff --git a/cpmpy/expressions/globalconstraints.py b/cpmpy/expressions/globalconstraints.py index 9c6fe3933..deb292bf7 100644 --- a/cpmpy/expressions/globalconstraints.py +++ b/cpmpy/expressions/globalconstraints.py @@ -98,6 +98,7 @@ def my_circuit_decomp(self): AllDifferent AllDifferentExcept0 + AllDifferentLists AllEqual Circuit Inverse @@ -197,6 +198,21 @@ def value(self): return len(set(vals)) == len(vals) +class AllDifferentLists(GlobalConstraint): + def __init__(self, lists): + super().__init__("alldifferent_lists", [flatlist(lst) for lst in lists]) + + def decompose(self): + """Returns the decomposition + """ + from .python_builtins import any + return [any([var1 != var2 for (var1, var2) in zip(vars1, vars2)]) for vars1, vars2 in all_pairs(self.args)], [] + + def value(self): + x = set([tuple([argval(a) for a in arr]) for arr in self.args]) + y = [tuple([argval(a) for a in arr]) for arr in self.args] + return len(set([tuple([argval(a) for a in arr]) for arr in self.args])) == len(self.args) + def allequal(args): warnings.warn("Deprecated, use AllEqual(v1,v2,...,vn) instead, will be removed in stable version", DeprecationWarning) return AllEqual(*args) # unfold list as individual arguments diff --git a/tests/test_globalconstraints.py b/tests/test_globalconstraints.py index 1f74c80c1..e8ab2ea2f 100644 --- a/tests/test_globalconstraints.py +++ b/tests/test_globalconstraints.py @@ -55,6 +55,26 @@ def test_alldifferent2(self): var._value = val assert (c.value() == oracle), f"Wrong value function for {vals,oracle}" + def test_alldifferent_lists(self): + # test known input/outputs + tuples = [ + ([(1,2,3),(1,3,3),(1,2,4)], True), + ([(1,2,3),(1,3,3),(1,2,3)], False), + ([(0,0,3),(1,3,3),(1,2,4)], True), + ([(1,2,3),(1,3,3),(3,3,3)], True) + ] + iv = cp.intvar(0,4, shape=(3,3)) + c = cp.AllDifferentLists(iv) + for (vals, oracle) in tuples: + ret = cp.Model(c, iv == vals).solve() + assert (ret == oracle), f"Mismatch solve for {vals,oracle}" + # don't try this at home, forcibly overwrite variable values (so even when ret=false) + for (var,val) in zip(iv,vals): + for (vr, vl) in zip(var, val): + vr._value = vl + assert (c.value() == oracle), f"Wrong value function for {vals,oracle}" + + def test_not_alldifferent(self): # from fuzztester of Ruben Kindt, #143 pos = cp.intvar(lb=0, ub=5, shape=3, name="positions")