Skip to content

Commit

Permalink
Nvalues except (#474)
Browse files Browse the repository at this point in the history
  • Loading branch information
IgnaceBleukx authored May 10, 2024
1 parent 3a94691 commit b3b30c1
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 2 deletions.
2 changes: 1 addition & 1 deletion cpmpy/expressions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,6 @@
from .globalconstraints import AllDifferent, AllDifferentExcept0, AllEqual, Circuit, Inverse, Table, Xor, Cumulative, \
IfThenElse, GlobalCardinalityCount, DirectConstraint, InDomain, Increasing, Decreasing, IncreasingStrict, DecreasingStrict
from .globalconstraints import alldifferent, allequal, circuit # Old, to be deprecated
from .globalfunctions import Maximum, Minimum, Abs, Element, Count, NValue
from .globalfunctions import Maximum, Minimum, Abs, Element, Count, NValue, NValueExcept
from .core import BoolVal
from .python_builtins import all, any, max, min, sum
58 changes: 57 additions & 1 deletion cpmpy/expressions/globalfunctions.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,4 +362,60 @@ def get_bounds(self):
"""
Returns the bounds of the (numerical) global constraint
"""
return 1, len(self.args)
return 1, len(self.args)


class NValueExcept(GlobalFunction):

"""
The NValueExceptN constraint counts the number of distinct values,
not including value N, if any argument is assigned to it.
"""

def __init__(self, arr, n):
if not is_any_list(arr):
raise ValueError("NValueExcept takes an array as input")
if not is_num(n):
raise ValueError(f"NValueExcept takes an integer as second argument, but got {n} of type {type(n)}")
super().__init__("nvalue_except",[arr, n])

def decompose_comparison(self, cmp_op, cpm_rhs):
"""
NValue(arr) can only be decomposed if it's part of a comparison
Based on "simple decomposition" from:
Bessiere, Christian, et al. "Decomposition of the NValue constraint."
International Conference on Principles and Practice of Constraint Programming.
Berlin, Heidelberg: Springer Berlin Heidelberg, 2010.
"""
from .python_builtins import sum, any

arr, n = self.args
arr = cpm_array(arr)
lbs, ubs = get_bounds(arr)
lb, ub = min(lbs), max(ubs)

constraints = []

# introduce boolvar for each possible value
bvars = boolvar(shape=(ub + 1 - lb))
idx_of_n = n - lb
if 0 <= idx_of_n < len(bvars):
count_of_vals = sum(bvars[:idx_of_n]) + sum(bvars[idx_of_n+1:])
else:
count_of_vals = sum(bvars)

# bvar is true if the value is taken by any variable
for bv, val in zip(bvars, range(lb, ub + 1)):
constraints += [any(arr == val) == bv]

return [eval_comparison(cmp_op, count_of_vals, cpm_rhs)], constraints

def value(self):
return len(set(argval(a) for a in self.args[0]) - {self.args[1]})

def get_bounds(self):
"""
Returns the bounds of the (numerical) global constraint
"""
return 0, len(self.args)
28 changes: 28 additions & 0 deletions tests/test_globalconstraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,6 +655,34 @@ def check_true():
self.assertTrue(cons.value())
cp.Model(cons).solveAll(display=check_true)

def test_nvalue_except(self):

iv = cp.intvar(-8, 8, shape=3)
cnt = cp.intvar(0, 10)


self.assertFalse(cp.Model(cp.all(iv == 1), cp.NValueExcept(iv, 6) > 1).solve())
self.assertTrue(cp.Model(cp.NValueExcept(iv, 10) > 1).solve())
self.assertTrue(cp.Model(cp.all(iv == 1), cp.NValueExcept(iv, 1) == 0).solve())
self.assertTrue(cp.Model(cp.all(iv == 1), cp.NValueExcept(iv, 6) > cnt).solve())
self.assertGreater(len(set(iv.value())), cnt.value())

val = 6
self.assertTrue(cp.Model(cp.NValueExcept(iv, val) != cnt).solve())
self.assertTrue(cp.Model(cp.NValueExcept(iv, val) >= cnt).solve())
self.assertTrue(cp.Model(cp.NValueExcept(iv, val) <= cnt).solve())
self.assertTrue(cp.Model(cp.NValueExcept(iv, val) < cnt).solve())
self.assertTrue(cp.Model(cp.NValueExcept(iv, val) > cnt).solve())

# test nested
bv = cp.boolvar()
cons = bv == (cp.NValueExcept(iv, val) <= 2)

def check_true():
self.assertTrue(cons.value())

cp.Model(cons).solveAll(display=check_true)

@pytest.mark.skipif(not CPM_minizinc.supported(),
reason="Minizinc not installed")
def test_nvalue_minizinc(self):
Expand Down

0 comments on commit b3b30c1

Please sign in to comment.