Skip to content

Commit

Permalink
AllDiffExceptN, AllEqualExceptN (#473)
Browse files Browse the repository at this point in the history
AlldifferentExceptN and AllEqualExceptN Global Constraints

Co-authored-by: Dimos Tsouros
Co-authored-by: Ignace Bleukx
  • Loading branch information
Wout4 authored May 28, 2024
1 parent eabc97b commit 9064896
Show file tree
Hide file tree
Showing 4 changed files with 153 additions and 23 deletions.
4 changes: 2 additions & 2 deletions cpmpy/expressions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
List of submodules
==================
.. autosummary::
python_builtins
:nosignatures:
variables
core
globalconstraints
globalfunctions
python_builtins
utils
Expand All @@ -21,7 +21,7 @@
# others need to be imported by the developer explicitely
from .variables import boolvar, intvar, cpm_array
from .variables import BoolVar, IntVar, cparray # Old, to be deprecated
from .globalconstraints import AllDifferent, AllDifferentExcept0, AllDifferentLists, AllEqual, Circuit, Inverse, Table, Xor, Cumulative, \
from .globalconstraints import AllDifferent, AllDifferentExcept0, AllDifferentExceptN, AllDifferentLists, AllEqual, AllEqualExceptN, Circuit, Inverse, Table, Xor, Cumulative, \
IfThenElse, GlobalCardinalityCount, DirectConstraint, InDomain, Increasing, Decreasing, IncreasingStrict, DecreasingStrict, \
LexLess, LexLessEq, LexChainLess, LexChainLessEq, Precedence, NoOverlap
from .globalconstraints import alldifferent, allequal, circuit # Old, to be deprecated
Expand Down
50 changes: 42 additions & 8 deletions cpmpy/expressions/globalconstraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,10 @@ def my_circuit_decomp(self):
AllDifferent
AllDifferentExcept0
AllDifferentExceptN
AllDifferentLists
AllEqual
AllEqualExceptN
Circuit
Inverse
Table
Expand Down Expand Up @@ -181,22 +183,33 @@ def decompose(self):
def value(self):
return len(set(argvals(self.args))) == len(self.args)


class AllDifferentExcept0(GlobalConstraint):
class AllDifferentExceptN(GlobalConstraint):
"""
All nonzero arguments have a distinct value
All arguments except those equal to a value in n have a distinct value.
"""
def __init__(self, *args):
super().__init__("alldifferent_except0", flatlist(args))
def __init__(self, arr, n):
flatarr = flatlist(arr)
if not is_any_list(n):
n = [n]
super().__init__("alldifferent_except_n", [flatarr, n])

def decompose(self):
# equivalent to (var1 == 0) | (var2 == 0) | (var1 != var2)
return [(var1 == var2).implies(var1 == 0) for var1, var2 in all_pairs(self.args)], []
from .python_builtins import any as cpm_any
# equivalent to (var1 == n) | (var2 == n) | (var1 != var2)
return [(var1 == var2).implies(cpm_any(var1 == a for a in self.args[1])) for var1, var2 in all_pairs(self.args[0])], []

def value(self):
vals = [argval(a) for a in self.args if argval(a) != 0]
vals = [argval(a) for a in self.args[0] if argval(a) not in argvals(self.args[1])]
return len(set(vals)) == len(vals)

class AllDifferentExcept0(AllDifferentExceptN):
"""
All nonzero arguments have a distinct value
"""
def __init__(self, *arr):
flatarr = flatlist(arr)
super().__init__(arr, 0)


class AllDifferentLists(GlobalConstraint):
"""
Expand Down Expand Up @@ -224,6 +237,7 @@ def value(self):
lst_vals = [tuple(argvals(a)) for a in self.args]
return len(set(lst_vals)) == len(self.args)


def allequal(args):
warnings.warn("Deprecated, use AllEqual(v1,v2,...,vn) instead, will be removed in stable version", DeprecationWarning)
return AllEqual(*args) # unfold list as individual arguments
Expand All @@ -244,6 +258,26 @@ def decompose(self):
def value(self):
return len(set(argvals(self.args))) == 1

class AllEqualExceptN(GlobalConstraint):
"""
All arguments except those equal to a value in n have the same value.
"""

def __init__(self, arr, n):
flatarr = flatlist(arr)
if not is_any_list(n):
n = [n]
super().__init__("allequal_except_n", [flatarr, n])

def decompose(self):
from .python_builtins import any as cpm_any
return [(cpm_any(var1 == a for a in self.args[1]) | (var1 == var2) | cpm_any(var2 == a for a in self.args[1])) for var1, var2 in all_pairs(self.args[0])], []

def value(self):
vals = [argval(a) for a in self.args[0] if argval(a) not in argvals(self.args[1])]
return len(set(vals)) == 1 or len(set(vals)) == 0


def circuit(args):
warnings.warn("Deprecated, use Circuit(v1,v2,...,vn) instead, will be removed in stable version", DeprecationWarning)
return Circuit(*args) # unfold list as individual arguments
Expand Down
18 changes: 13 additions & 5 deletions tests/test_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# Exclude some global constraints for solvers
NUM_GLOBAL = {
"AllEqual", "AllDifferent", "AllDifferentLists", "AllDifferentExcept0",
"AllDifferentExceptN", "AllEqualExceptN",
"GlobalCardinalityCount", "InDomain", "Inverse", "Table", "Circuit",
"Increasing", "IncreasingStrict", "Decreasing", "DecreasingStrict",
"Precedence", "Cumulative", "NoOverlap",
Expand Down Expand Up @@ -180,6 +181,8 @@ def global_constraints(solver):
classes = [(name, cls) for name, cls in classes if name not in EXCLUDE_GLOBAL.get(solver, {})]

for name, cls in classes:
if solver in EXCLUDE_GLOBAL and name in EXCLUDE_GLOBAL[solver]:
continue

if name == "Xor":
expr = cls(BOOL_ARGS)
Expand All @@ -198,6 +201,14 @@ def global_constraints(solver):
demand = [4, 5, 7]
cap = 10
expr = Cumulative(s, dur, e, demand, cap)
elif name == "GlobalCardinalityCount":
vals = [1, 2, 3]
cnts = intvar(0,10,shape=3)
expr = cls(NUM_ARGS, vals, cnts)
elif name == "AllDifferentExceptN":
expr = cls(NUM_ARGS, NUM_VAR)
elif name == "AllEqualExceptN":
expr = cls(NUM_ARGS, NUM_VAR)
elif name == "Precedence":
x = intvar(0,5, shape=3, name="x")
expr = cls(x, [3,1,0])
Expand All @@ -214,16 +225,13 @@ def global_constraints(solver):
X = intvar(0, 3, shape=3)
Y = intvar(0, 3, shape=3)
expr = LexLessEq(X, Y)

elif name == "LexLess":
X = intvar(0, 3, shape=3)
Y = intvar(0, 3, shape=3)
expr = LexLess(X, Y)

elif name == "LexChainLess":
X = intvar(0, 3, shape=(3,3))
expr = LexChainLess(X)

expr = LexChainLess(X)
elif name == "LexChainLessEq":
X = intvar(0, 3, shape=(3,3))
expr = LexChainLess(X)
Expand All @@ -237,7 +245,7 @@ def global_constraints(solver):
continue
else:
yield expr


def reify_imply_exprs(solver):
"""
Expand Down
104 changes: 96 additions & 8 deletions tests/test_globalconstraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,42 @@ def test_alldifferent_except0(self):
bv = cp.boolvar()
self.assertTrue(cp.Model([cp.AllDifferentExcept0(iv[0], bv)]).solve())

def test_alldifferent_except_n(self):
# test known input/outputs
tuples = [
((1, 2, 3), True),
((1, 2, 1), False),
((0, 1, 2), True),
((2, 0, 3), True),
((2, 0, 2), True),
((0, 0, 2), False),
]
iv = cp.intvar(0, 4, shape=3)
c = cp.AllDifferentExceptN(iv, 2)
for (vals, oracle) in tuples:
ret = cp.Model(c, iv == vals).solve()
assert (ret == oracle), f"Mismatch solve for {vals, oracle}"
# don't try this at home, forcibly overwrite variable values (so even when ret=false)
for (var, val) in zip(iv, vals):
var._value = val
assert (c.value() == oracle), f"Wrong value function for {vals, oracle}"

# and some more
iv = cp.intvar(-8, 8, shape=3)
self.assertTrue(cp.Model([cp.AllDifferentExceptN(iv,2)]).solve())
self.assertTrue(cp.AllDifferentExceptN(iv,4).value())
self.assertTrue(cp.Model([cp.AllDifferentExceptN(iv,7), iv == [7, 7, 1]]).solve())
self.assertTrue(cp.AllDifferentExceptN(iv,7).value())

# test with mixed types
bv = cp.boolvar()
self.assertTrue(cp.Model([cp.AllDifferentExceptN([iv[0], bv],4)]).solve())

# test with list of n
iv = cp.intvar(0, 4, shape=7)
self.assertFalse(cp.Model([cp.AllDifferentExceptN([iv], [7,8])]).solve())
self.assertTrue(cp.Model([cp.AllDifferentExceptN([iv], [4, 1])]).solve())

def test_not_alldifferentexcept0(self):
iv = cp.intvar(-8, 8, shape=3)
self.assertTrue(cp.Model([~cp.AllDifferentExcept0(iv)]).solve())
Expand Down Expand Up @@ -169,25 +205,25 @@ def test_not_circuit(self):

self.assertFalse(cp.Model([circuit, ~circuit]).solve())

circuit_sols = set()
not_circuit_sols = set()
all_sols = set()
not_all_sols = set()

circuit_models = cp.Model(circuit).solveAll(display=lambda : circuit_sols.add(tuple(x.value())))
not_circuit_models = cp.Model(~circuit).solveAll(display=lambda : not_circuit_sols.add(tuple(x.value())))
circuit_models = cp.Model(circuit).solveAll(display=lambda : all_sols.add(tuple(x.value())))
not_circuit_models = cp.Model(~circuit).solveAll(display=lambda : not_all_sols.add(tuple(x.value())))

total = cp.Model(x == x).solveAll()

for sol in circuit_sols:
for sol in all_sols:
for var,val in zip(x, sol):
var._value = val
self.assertTrue(circuit.value())

for sol in not_circuit_sols:
for sol in not_all_sols:
for var,val in zip(x, sol):
var._value = val
self.assertFalse(circuit.value())

self.assertEqual(total, len(circuit_sols) + len(not_circuit_sols))
self.assertEqual(total, len(all_sols) + len(not_all_sols))


def test_inverse(self):
Expand Down Expand Up @@ -914,7 +950,59 @@ def test_allEqual(self):
a = cp.boolvar()
self.assertTrue(cp.Model([cp.AllEqual(x,y,-1)]).solve())
self.assertTrue(cp.Model([cp.AllEqual(a,b,False, a | b)]).solve())
#self.assertTrue(cp.Model([cp.AllEqual(x,y,b)]).solve())
self.assertFalse(cp.Model([cp.AllEqual(x,y,b)]).solve())

def test_allEqualExceptn(self):
x = cp.intvar(-8, 8)
y = cp.intvar(-7, -1)
b = cp.boolvar()
a = cp.boolvar()
self.assertTrue(cp.Model([cp.AllEqualExceptN([x,y,-1],211)]).solve())
self.assertTrue(cp.Model([cp.AllEqualExceptN([x,y,-1,4],4)]).solve())
self.assertTrue(cp.Model([cp.AllEqualExceptN([x,y,-1,4],-1)]).solve())
self.assertTrue(cp.Model([cp.AllEqualExceptN([a,b,False, a | b], 4)]).solve())
self.assertTrue(cp.Model([cp.AllEqualExceptN([a,b,False, a | b], 0)]).solve())
self.assertTrue(cp.Model([cp.AllEqualExceptN([a,b,False, a | b, y], -1)]).solve())

# test with list of n
iv = cp.intvar(0, 4, shape=7)
self.assertFalse(cp.Model([cp.AllEqualExceptN([iv], [7,8]), iv[0] != iv[1]]).solve())
self.assertTrue(cp.Model([cp.AllEqualExceptN([iv], [4, 1]), iv[0] != iv[1]]).solve())

def test_not_allEqualExceptn(self):
x = cp.intvar(lb=0, ub=3, shape=3)
n = 2
constr = cp.AllEqualExceptN(x,n)

model = cp.Model([~constr, x == [1, 2, 1]])
self.assertFalse(model.solve())

model = cp.Model([~constr])
self.assertTrue(model.solve())
self.assertFalse(constr.value())

self.assertFalse(cp.Model([constr, ~constr]).solve())

all_sols = set()
not_all_sols = set()

circuit_models = cp.Model(constr).solveAll(display=lambda: all_sols.add(tuple(x.value())))
not_circuit_models = cp.Model(~constr).solveAll(display=lambda: not_all_sols.add(tuple(x.value())))

total = cp.Model(x == x).solveAll()

for sol in all_sols:
for var, val in zip(x, sol):
var._value = val
self.assertTrue(constr.value())

for sol in not_all_sols:
for var, val in zip(x, sol):
var._value = val
self.assertFalse(constr.value())

self.assertEqual(total, len(all_sols) + len(not_all_sols))


def test_increasing(self):
x = cp.intvar(-8, 8)
Expand Down

0 comments on commit 9064896

Please sign in to comment.