diff --git a/cpmpy/expressions/__init__.py b/cpmpy/expressions/__init__.py index bd71085cf..11a295527 100644 --- a/cpmpy/expressions/__init__.py +++ b/cpmpy/expressions/__init__.py @@ -24,6 +24,6 @@ from .globalconstraints import AllDifferent, AllDifferentExcept0, AllEqual, Circuit, Inverse, Table, Xor, Cumulative, \ IfThenElse, GlobalCardinalityCount, DirectConstraint, InDomain, Increasing, Decreasing, IncreasingStrict, DecreasingStrict from .globalconstraints import alldifferent, allequal, circuit # Old, to be deprecated -from .globalfunctions import Maximum, Minimum, Abs, Element, Count, NValue +from .globalfunctions import Maximum, Minimum, Abs, Element, Count, NValue, NValueExcept from .core import BoolVal from .python_builtins import all, any, max, min, sum diff --git a/cpmpy/expressions/globalfunctions.py b/cpmpy/expressions/globalfunctions.py index 8c1eeecca..1562ee5fa 100644 --- a/cpmpy/expressions/globalfunctions.py +++ b/cpmpy/expressions/globalfunctions.py @@ -362,4 +362,60 @@ def get_bounds(self): """ Returns the bounds of the (numerical) global constraint """ - return 1, len(self.args) \ No newline at end of file + return 1, len(self.args) + + +class NValueExcept(GlobalFunction): + + """ + The NValueExceptN constraint counts the number of distinct values, + not including value N, if any argument is assigned to it. + """ + + def __init__(self, arr, n): + if not is_any_list(arr): + raise ValueError("NValueExcept takes an array as input") + if not is_num(n): + raise ValueError(f"NValueExcept takes an integer as second argument, but got {n} of type {type(n)}") + super().__init__("nvalue_except",[arr, n]) + + def decompose_comparison(self, cmp_op, cpm_rhs): + """ + NValue(arr) can only be decomposed if it's part of a comparison + + Based on "simple decomposition" from: + Bessiere, Christian, et al. "Decomposition of the NValue constraint." + International Conference on Principles and Practice of Constraint Programming. + Berlin, Heidelberg: Springer Berlin Heidelberg, 2010. + """ + from .python_builtins import sum, any + + arr, n = self.args + arr = cpm_array(arr) + lbs, ubs = get_bounds(arr) + lb, ub = min(lbs), max(ubs) + + constraints = [] + + # introduce boolvar for each possible value + bvars = boolvar(shape=(ub + 1 - lb)) + idx_of_n = n - lb + if 0 <= idx_of_n < len(bvars): + count_of_vals = sum(bvars[:idx_of_n]) + sum(bvars[idx_of_n+1:]) + else: + count_of_vals = sum(bvars) + + # bvar is true if the value is taken by any variable + for bv, val in zip(bvars, range(lb, ub + 1)): + constraints += [any(arr == val) == bv] + + return [eval_comparison(cmp_op, count_of_vals, cpm_rhs)], constraints + + def value(self): + return len(set(argval(a) for a in self.args[0]) - {self.args[1]}) + + def get_bounds(self): + """ + Returns the bounds of the (numerical) global constraint + """ + return 0, len(self.args) \ No newline at end of file diff --git a/tests/test_globalconstraints.py b/tests/test_globalconstraints.py index 1f74c80c1..311828afe 100644 --- a/tests/test_globalconstraints.py +++ b/tests/test_globalconstraints.py @@ -655,6 +655,34 @@ def check_true(): self.assertTrue(cons.value()) cp.Model(cons).solveAll(display=check_true) + def test_nvalue_except(self): + + iv = cp.intvar(-8, 8, shape=3) + cnt = cp.intvar(0, 10) + + + self.assertFalse(cp.Model(cp.all(iv == 1), cp.NValueExcept(iv, 6) > 1).solve()) + self.assertTrue(cp.Model(cp.NValueExcept(iv, 10) > 1).solve()) + self.assertTrue(cp.Model(cp.all(iv == 1), cp.NValueExcept(iv, 1) == 0).solve()) + self.assertTrue(cp.Model(cp.all(iv == 1), cp.NValueExcept(iv, 6) > cnt).solve()) + self.assertGreater(len(set(iv.value())), cnt.value()) + + val = 6 + self.assertTrue(cp.Model(cp.NValueExcept(iv, val) != cnt).solve()) + self.assertTrue(cp.Model(cp.NValueExcept(iv, val) >= cnt).solve()) + self.assertTrue(cp.Model(cp.NValueExcept(iv, val) <= cnt).solve()) + self.assertTrue(cp.Model(cp.NValueExcept(iv, val) < cnt).solve()) + self.assertTrue(cp.Model(cp.NValueExcept(iv, val) > cnt).solve()) + + # test nested + bv = cp.boolvar() + cons = bv == (cp.NValueExcept(iv, val) <= 2) + + def check_true(): + self.assertTrue(cons.value()) + + cp.Model(cons).solveAll(display=check_true) + @pytest.mark.skipif(not CPM_minizinc.supported(), reason="Minizinc not installed") def test_nvalue_minizinc(self):