Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
IgnaceBleukx committed Jan 16, 2024
1 parent 74e76e3 commit edf7a42
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 1 deletion.
4 changes: 3 additions & 1 deletion tests/test_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ def comp_constraints(solver):
for glob_expr in global_constraints(solver):
if not glob_expr.is_bool():
for rhs in [NUM_VAR, BOOL_VAR, 1, BoolVal(True)]:
if comp_name == "<" and get_bounds(glob_expr)[0] >= get_bounds(rhs)[1]:
continue
yield Comparison(comp_name, glob_expr, rhs)

if solver == "z3":
Expand Down Expand Up @@ -149,7 +151,7 @@ def global_constraints(solver):
- AllDifferent, AllEqual, Circuit, Minimum, Maximum, Element,
Xor, Cumulative
"""
global_cons = [AllDifferent, AllEqual, Minimum, Maximum]
global_cons = [AllDifferent, AllEqual, Minimum, Maximum, NValue]
for global_type in global_cons:
cons = global_type(NUM_ARGS)
if solver not in EXCLUDE_GLOBAL or cons.name not in EXCLUDE_GLOBAL[solver]:
Expand Down
23 changes: 23 additions & 0 deletions tests/test_globalconstraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,6 +614,29 @@ def test_count_onearg(self):
except (NotImplementedError, NotSupportedError):
pass


def test_nvalue(self):

iv = cp.intvar(-8, 8, shape=3)
cnt = cp.intvar(0,10)

self.assertFalse(cp.Model(cp.all(iv == 1), cp.NValue(iv) > 1).solve())
self.assertTrue(cp.Model(cp.all(iv == 1), cp.NValue(iv) > cnt).solve())
self.assertGreater(len(set(iv.value())), cnt.value())

self.assertTrue(cp.Model(cp.NValue(iv) != cnt).solve())
self.assertTrue(cp.Model(cp.NValue(iv) >= cnt).solve())
self.assertTrue(cp.Model(cp.NValue(iv) <= cnt).solve())
self.assertTrue(cp.Model(cp.NValue(iv) < cnt).solve())
self.assertTrue(cp.Model(cp.NValue(iv) > cnt).solve())

# test nested
bv = cp.boolvar()
cons = bv == (cp.NValue(iv) <= 2)
def check_true():
self.assertTrue(cons.value())
cp.Model(cons).solveAll(display=check_true)

class TestBounds(unittest.TestCase):
def test_bounds_minimum(self):
x = cp.intvar(-8, 8)
Expand Down

0 comments on commit edf7a42

Please sign in to comment.