Skip to content

Commit

Permalink
adapt linearize tests to have lists as input
Browse files Browse the repository at this point in the history
add a case to linearize for supported alldif
  • Loading branch information
Wout4 committed Sep 5, 2023
1 parent 72f4e51 commit 05e463c
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 33 deletions.
2 changes: 2 additions & 0 deletions cpmpy/transformations/linearize.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,8 @@ def linearize_constraint(lst_of_expr, supported={"sum","wsum"}, reified=False):
# supported comparison
newlist.append(eval_comparison(cpm_expr.name, lhs, rhs))

elif cpm_expr.name == "alldifferent" and cpm_expr.name in supported:
newlist.append(cpm_expr)
elif cpm_expr.name == "alldifferent" and cpm_expr.name not in supported:
"""
More efficient implementations possible
Expand Down
69 changes: 36 additions & 33 deletions tests/test_trans_linearize.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,15 @@ def test_linearize(self):
a, b, c = [boolvar(name=var) for var in "abc"]

# and
cons = linearize_constraint(a & b)[0]
cons = linearize_constraint([a & b])[0]
self.assertEqual("(a) + (b) >= 2", str(cons))

# or
cons = linearize_constraint(a | b)[0]
cons = linearize_constraint([a | b])[0]
self.assertEqual("(a) + (b) >= 1", str(cons))

# implies
cons = linearize_constraint(a.implies(b))[0]
cons = linearize_constraint([a.implies(b)])[0]
self.assertEqual("sum([1, -1] * [a, b]) <= 0", str(cons))

def test_bug_168(self):
Expand All @@ -48,31 +48,32 @@ def test_constraint(self):
a,b,c = [cp.boolvar(name=n) for n in "abc"]

# test and
self.assertEqual(str(linearize_constraint(a & b & c)), "[sum([a, b, c]) >= 3]")
self.assertEqual(str(linearize_constraint(a & b & (~c))), "[sum([a, b, ~c]) >= 3]")
self.assertEqual(str(linearize_constraint([a & b & c])), "[sum([a, b, c]) >= 3]")
self.assertEqual(str(linearize_constraint([a & b & (~c)])), "[sum([a, b, ~c]) >= 3]")
# test or
self.assertEqual(str(linearize_constraint(a | b | c)), "[sum([a, b, c]) >= 1]")
self.assertEqual(str(linearize_constraint(a | b | (~c))), "[sum([a, b, ~c]) >= 1]")
self.assertEqual(str(linearize_constraint([a | b | c])), "[sum([a, b, c]) >= 1]")
self.assertEqual(str(linearize_constraint([a | b | (~c)])), "[sum([a, b, ~c]) >= 1]")
# test implies
self.assertEqual(str(linearize_constraint(a.implies(b))), "[sum([1, -1] * [a, b]) <= 0]")
self.assertEqual(str(linearize_constraint(a.implies(~b))), "[sum([1, -1] * [~b, a]) >= 0]")
self.assertEqual(str(linearize_constraint(a.implies(x+y+z >= 0))), "[(a) -> (sum([x, y, z]) >= 0)]")
self.assertEqual(str(linearize_constraint(a.implies(x+y+z > 0))), "[(a) -> (sum([x, y, z]) >= 1)]")
self.assertEqual(str(linearize_constraint([a.implies(b)])), "[sum([1, -1] * [a, b]) <= 0]")
self.assertEqual(str(linearize_constraint([a.implies(~b)])), "[sum([1, -1] * [a, ~b]) <= 0]")
self.assertEqual(str(linearize_constraint([a.implies(x+y+z >= 0)])), "[(a) -> (sum([x, y, z]) >= 0)]")
self.assertEqual(str(linearize_constraint([a.implies(x+y+z > 0)])), "[(a) -> (sum([x, y, z]) >= 1)]")
# test sub
self.assertEqual(str(linearize_constraint(Operator("sub",[x,y]) >= z)), "[sum([1, -1, -1] * [x, y, z]) >= 0]")
self.assertEqual(str(linearize_constraint([Operator("sub",[x,y]) >= z])), "[sum([1, -1, -1] * [x, y, z]) >= 0]")
# test mul
self.assertEqual(str(linearize_constraint(3 * x > 2)), "[sum([3] * [x]) >= 3]")
self.assertEqual(str(linearize_constraint([3 * x > 2])), "[sum([3] * [x]) >= 3]")
# test <
self.assertEqual((str(linearize_constraint(x + y < z))), "[sum([1, 1, -1] * [x, y, z]) <= -1]")
self.assertEqual((str(linearize_constraint([x + y < z]))), "[sum([1, 1, -1] * [x, y, z]) <= -1]")
# test >
self.assertEqual((str(linearize_constraint(x + y > z))), "[sum([1, 1, -1] * [x, y, z]) >= 1]")
self.assertEqual((str(linearize_constraint([x + y > z]))), "[sum([1, 1, -1] * [x, y, z]) >= 1]")
# test !=
c1,c2 = linearize_constraint(x + y != z)
c1,c2 = linearize_constraint([x + y != z])
self.assertEqual(str(c1), "(BV3) -> (sum([1, 1, -1] * [x, y, z]) <= -1)")
self.assertEqual(str(c2), "(~BV3) -> (sum([1, 1, -1] * [x, y, z]) >= 1)")
c1, c2 = linearize_constraint(a.implies(x != y))
c1, c2, c3 = linearize_constraint([a.implies(x != y)])
self.assertEqual(str(c1), "(a) -> (sum([1, -1, -6] * [x, y, BV4]) <= -1)")
self.assertEqual(str(c2), "(a) -> (sum([1, -1, -6] * [x, y, BV4]) >= -5)")
self.assertEqual(str(c3), "sum([1, -1] * [~a, ~BV4]) <= 0")


def test_neq(self):
Expand All @@ -81,13 +82,13 @@ def test_neq(self):
x, y, z = [cp.intvar(0, 5, name=n) for n in "xyz"]
a, b, c = [cp.boolvar(name=n) for n in "abc"]

cons = 2*x + 3*y + 4*z != 10
cons = [2*x + 3*y + 4*z != 10]
self.assertEqual(str(linearize_constraint(cons)),"[(BV3) -> (sum([2, 3, 4] * [x, y, z]) <= 9), (~BV3) -> (sum([2, 3, 4] * [x, y, z]) >= 11)]")

cons = a.implies(x != y)
cons = [a.implies(x != y)]
lin_cons = linearize_constraint(cons)
cons_vals = []
cp.Model(lin_cons).solveAll(solver="ortools", display=lambda : cons_vals.append(cons.value()))
cp.Model(lin_cons).solveAll(solver="ortools", display=lambda : cons_vals.append(cons[0].value()))
print(len(cons_vals))
self.assertTrue(all(cons_vals))
# self.assertEqual(str(linearize_constraint(cons)), "[(a) -> (sum([1, -1, -6] * [x, y, BV4]) <= -1), (a) -> (sum([1, -1, -6] * [x, y, BV4]) >= -5)]")
Expand All @@ -98,34 +99,34 @@ class TestConstRhs(unittest.TestCase):
def test_numvar(self):
a, b = [cp.intvar(0, 10, name=n) for n in "ab"]

cons = linearize_constraint(a <= b)[0]
cons = linearize_constraint([a <= b])[0]
self.assertEqual("sum([1, -1] * [a, b]) <= 0", str(cons))

def test_sum(self):
a,b,c = [cp.intvar(0,10,name=n) for n in "abc"]
rhs = intvar(0,10,name="r")

cons = linearize_constraint(cp.sum([a,b,c]) <= rhs)[0]
cons = linearize_constraint([cp.sum([a,b,c]) <= rhs])[0]
self.assertEqual("sum([1, 1, 1, -1] * [a, b, c, r]) <= 0", str(cons))

def test_wsum(self):
a, b, c = [cp.intvar(0, 10,name=n) for n in "abc"]
rhs = intvar(0, 10, name="r")

cons = 1*a + 2*b + 3*c <= rhs
cons = linearize_constraint(cons)[0]
cons = linearize_constraint([cons])[0]
self.assertEqual("sum([1, 2, 3, -1] * [a, b, c, r]) <= 0", str(cons))

def test_impl(self):
a, b, c = [cp.intvar(0, 10, name=n) for n in "abc"]
rhs = intvar(0, 10, name="r")
cond = cp.boolvar(name="bv")

cons = cond.implies(1 * a + 2 * b + 3 * c <= rhs)
cons = [cond.implies(1 * a + 2 * b + 3 * c <= rhs)]
cons = linearize_constraint(cons)[0]
self.assertEqual("(bv) -> (sum([1, 2, 3, -1] * [a, b, c, r]) <= 0)", str(cons))

cons = (~cond).implies(1 * a + 2 * b + 3 * c <= rhs)
cons = [(~cond).implies(1 * a + 2 * b + 3 * c <= rhs)]
cons = linearize_constraint(cons)[0]
self.assertEqual("(~bv) -> (sum([1, 2, 3, -1] * [a, b, c, r]) <= 0)", str(cons))

Expand All @@ -134,11 +135,13 @@ def test_others(self):
a, b, c = [cp.intvar(0, 10, name=n) for n in "abc"]
rhs = intvar(0, 10, name="r")

cons = cp.max([a,b,c]) <= rhs
cons = [cp.max([a,b,c]) <= rhs]
print(linearize_constraint(cons, supported={"max"}))
cons = linearize_constraint(cons, supported={"max"})[0]
self.assertEqual("(max(a,b,c)) <= (r)", str(cons))

cons = cp.AllDifferent([a,b,c])
cons = [cp.AllDifferent([a,b,c])]
print(linearize_constraint(cons, supported={"alldifferent"}))
cons = linearize_constraint(cons, supported={"alldifferent"})[0]
self.assertEqual("alldifferent(a,b,c)", str(cons))

Expand All @@ -149,14 +152,14 @@ def test_sum(self):
a,b,c = [cp.intvar(0,10,name=n) for n in "abc"]
rhs = 5

cons = linearize_constraint(cp.sum([a,b,c,10]) <= rhs)[0]
cons = linearize_constraint([cp.sum([a,b,c,10]) <= rhs])[0]
self.assertEqual("sum([a, b, c]) <= -5", str(cons))

def test_wsum(self):
a, b, c = [cp.intvar(0, 10,name=n) for n in "abc"]
rhs = 5

cons = Operator("wsum",[[1,2,3,-1],[a,b,c,10]]) <= rhs
cons = [Operator("wsum",[[1,2,3,-1],[a,b,c,10]]) <= rhs]
cons = linearize_constraint(cons)[0]
self.assertEqual("sum([1, 2, 3] * [a, b, c]) <= 15", str(cons))

Expand All @@ -165,11 +168,11 @@ def test_impl(self):
rhs = 5
cond = cp.boolvar(name="bv")

cons = cond.implies(Operator("wsum",[[1,2,3,-1],[a,b,c,10]]) <= rhs)
cons = [cond.implies(Operator("wsum",[[1,2,3,-1],[a,b,c,10]]) <= rhs)]
cons = linearize_constraint(cons)[0]
self.assertEqual("(bv) -> (sum([1, 2, 3] * [a, b, c]) <= 15)", str(cons))

cons = (~cond).implies(Operator("wsum",[[1,2,3,-1],[a,b,c,10]]) <= rhs)
cons = [(~cond).implies(Operator("wsum",[[1,2,3,-1],[a,b,c,10]]) <= rhs)]
cons = linearize_constraint(cons)[0]
self.assertEqual("(~bv) -> (sum([1, 2, 3] * [a, b, c]) <= 15)", str(cons))

Expand All @@ -178,11 +181,11 @@ def test_others(self):
a, b, c = [cp.intvar(0, 10, name=n) for n in "abc"]
rhs = intvar(0, 10, name="r")

cons = cp.max([a,b,c,5]) <= rhs
cons = [cp.max([a,b,c,5]) <= rhs]
cons = linearize_constraint(cons, supported={"max"})[0]
self.assertEqual("(max(a,b,c,5)) <= (r)", str(cons))

cons = cp.AllDifferent([a, b, c])
cons = [cp.AllDifferent([a, b, c])]
cons = linearize_constraint(cons, supported={"alldifferent"})[0]
self.assertEqual("alldifferent(a,b,c)", str(cons))

Expand Down

0 comments on commit 05e463c

Please sign in to comment.