Skip to content

Commit

Permalink
AllDifferentLists
Browse files Browse the repository at this point in the history
  • Loading branch information
Wout4 committed May 10, 2024
1 parent 3a94691 commit c8f384c
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 1 deletion.
2 changes: 1 addition & 1 deletion cpmpy/expressions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
# others need to be imported by the developer explicitely
from .variables import boolvar, intvar, cpm_array
from .variables import BoolVar, IntVar, cparray # Old, to be deprecated
from .globalconstraints import AllDifferent, AllDifferentExcept0, AllEqual, Circuit, Inverse, Table, Xor, Cumulative, \
from .globalconstraints import AllDifferent, AllDifferentExcept0, AllDifferentLists, AllEqual, Circuit, Inverse, Table, Xor, Cumulative, \
IfThenElse, GlobalCardinalityCount, DirectConstraint, InDomain, Increasing, Decreasing, IncreasingStrict, DecreasingStrict
from .globalconstraints import alldifferent, allequal, circuit # Old, to be deprecated
from .globalfunctions import Maximum, Minimum, Abs, Element, Count, NValue
Expand Down
16 changes: 16 additions & 0 deletions cpmpy/expressions/globalconstraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def my_circuit_decomp(self):
AllDifferent
AllDifferentExcept0
AllDifferentLists
AllEqual
Circuit
Inverse
Expand Down Expand Up @@ -197,6 +198,21 @@ def value(self):
return len(set(vals)) == len(vals)


class AllDifferentLists(GlobalConstraint):
def __init__(self, lists):
super().__init__("alldifferent_lists", [flatlist(lst) for lst in lists])

def decompose(self):
"""Returns the decomposition
"""
from .python_builtins import any
return [any([var1 != var2 for (var1, var2) in zip(vars1, vars2)]) for vars1, vars2 in all_pairs(self.args)], []

def value(self):
x = set([tuple([argval(a) for a in arr]) for arr in self.args])
y = [tuple([argval(a) for a in arr]) for arr in self.args]
return len(set([tuple([argval(a) for a in arr]) for arr in self.args])) == len(self.args)

def allequal(args):
warnings.warn("Deprecated, use AllEqual(v1,v2,...,vn) instead, will be removed in stable version", DeprecationWarning)
return AllEqual(*args) # unfold list as individual arguments
Expand Down
20 changes: 20 additions & 0 deletions tests/test_globalconstraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,26 @@ def test_alldifferent2(self):
var._value = val
assert (c.value() == oracle), f"Wrong value function for {vals,oracle}"

def test_alldifferent_lists(self):
# test known input/outputs
tuples = [
([(1,2,3),(1,3,3),(1,2,4)], True),
([(1,2,3),(1,3,3),(1,2,3)], False),
([(0,0,3),(1,3,3),(1,2,4)], True),
([(1,2,3),(1,3,3),(3,3,3)], True)
]
iv = cp.intvar(0,4, shape=(3,3))
c = cp.AllDifferentLists(iv)
for (vals, oracle) in tuples:
ret = cp.Model(c, iv == vals).solve()
assert (ret == oracle), f"Mismatch solve for {vals,oracle}"
# don't try this at home, forcibly overwrite variable values (so even when ret=false)
for (var,val) in zip(iv,vals):
for (vr, vl) in zip(var, val):
vr._value = vl
assert (c.value() == oracle), f"Wrong value function for {vals,oracle}"


def test_not_alldifferent(self):
# from fuzztester of Ruben Kindt, #143
pos = cp.intvar(lb=0, ub=5, shape=3, name="positions")
Expand Down

0 comments on commit c8f384c

Please sign in to comment.