Skip to content

Commit

Permalink
fix linearization of modulo
Browse files Browse the repository at this point in the history
add tests
  • Loading branch information
Wout4 committed Oct 8, 2024
1 parent 6513ecc commit e6eb2a1
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 28 deletions.
39 changes: 11 additions & 28 deletions cpmpy/transformations/linearize.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@
"""
import copy
import numpy as np
from cpmpy.transformations.reification import only_implies, only_bv_reifies

from cpmpy.transformations.normalize import toplevel_list
from .decompose_global import decompose_in_tree

Expand Down Expand Up @@ -108,6 +110,7 @@ def _linearize_constraint_helper(lst_of_expr, supported={"sum","wsum"}, reified=
elif isinstance(cond, _BoolVarImpl):
lin_sub, new_vars = _linearize_constraint_helper([sub_expr], supported=supported, reified=True)
newlist += [cond.implies(lin) for lin in lin_sub]
# ensure no new solutions are created
a, b = _linearize_constraint_helper([(~cond).implies(nv == nv.lb) for nv in new_vars], reified=reified)
newlist += a
newvars += b
Expand All @@ -130,7 +133,7 @@ def _linearize_constraint_helper(lst_of_expr, supported={"sum","wsum"}, reified=

elif lhs.name == "mod" and "mod" not in supported:
if "mul" not in supported:
raise NotImplementedError("Cannot linearize modulo withtout multiplication")
raise NotImplementedError("Cannot linearize modulo without multiplication")

if cpm_expr.name != "==":
new_rhs, newcons = get_or_make_var(lhs)
Expand All @@ -148,35 +151,21 @@ def _linearize_constraint_helper(lst_of_expr, supported={"sum","wsum"}, reified=
raise NotImplementedError("Modulo with a divisor domain containing 0 is not supported. Please safen the expression first.")
k = intvar(*get_bounds((x - rhs) // y))
mult_res, newcons = get_or_make_var(k * y)
newlist += linearize_constraint([rhs < abs(y)]+newcons, supported, reified=reified)
# (abs of) modulo rhs is smaller than (abs of) the divisor y, but also needs to be of same sign as x.
newlist += linearize_constraint(only_implies(only_bv_reifies(flatten_constraint([Abs(rhs) < Abs(y), (x > 0).implies(rhs >= 0), (x < 0).implies(rhs <= 0)]))) + newcons, supported, reified=reified)

cpm_expr = (mult_res + rhs) == x
elif lhs.name == 'div':
a, b = lhs.args
# if division is total, b is either strictly negative or strictly positive!
lb, ub = get_bounds(b)
if not ((lb < 0 and ub < 0) or (lb > 0 and ub > 0)):
raise TypeError(
f"Can't divide by a domain containing 0, safen the expression first")
r = intvar(0, max(abs(lb) - 1, abs(ub) - 1)) # remainder is always positive for floordivision.
cpm_expr = [eval_comparison(cpm_expr.name, a, b * rhs + r)]
cond = [r < Abs(b)] # decomposition of Abs + flatten_constraint will twice flip the order around when b is a constant (a bit wastefull)
decomp = toplevel_list(flatten_constraint(decompose_in_tree(cond))) # decompose abs
cpm_exprs = toplevel_list(decomp + cpm_expr)
exprs = linearize_constraint(flatten_constraint(cpm_exprs), supported=supported)
newlist.extend(exprs)
continue
#newrhs = lhs.args[0]
#lhs = lhs.args[1] * rhs #operator is actually always '==' here due to only_numexpr_equality
#cpm_expr = eval_comparison(cpm_expr.name, lhs, newrhs)
elif lhs.name == 'idiv':

elif lhs.name == 'div' and 'div' not in supported:
if "mul" not in supported:
raise NotImplementedError("Cannot linearize modulo without multiplication")
a, b = lhs.args
# if division is total, b is either strictly negative or strictly positive!
lb, ub = get_bounds(b)
if not ((lb < 0 and ub < 0) or (lb > 0 and ub > 0)):
raise TypeError(
f"Can't divide by a domain containing 0, safen the expression first")
r = intvar(-(max(abs(lb) - 1, abs(ub) - 1)), max(abs(lb) - 1, abs(ub) - 1)) # remainder can be both positive and negative
r = intvar(-(max(abs(lb) - 1, abs(ub) - 1)), max(abs(lb) - 1, abs(ub) - 1)) # remainder can be both positive and negative (round towards 0, so negative r if a and b are both negative)
cpm_expr = [eval_comparison(cpm_expr.name, a, b * rhs + r)]
cond = [Abs(r) < Abs(b), Abs(b * rhs) < Abs(a)]
decomp = toplevel_list(decompose_in_tree(cond)) # decompose abs
Expand All @@ -185,12 +174,6 @@ def _linearize_constraint_helper(lst_of_expr, supported={"sum","wsum"}, reified=
newlist.extend(exprs)
continue


elif lhs.name == 'mod': # x mod y == x - (x//y) * y
# gets handles in the solver interface
# We should never get here, since both Gurobi and Exact have "faked support" for "Mod"
newlist.append(cpm_expr)
continue
else:
raise TransformationNotImplementedError(f"lhs of constraint {cpm_expr} cannot be linearized, should be any of {supported | set(['sub'])} but is {lhs}. Please report on github")

Expand Down
45 changes: 45 additions & 0 deletions tests/test_trans_linearize.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,51 @@ def test_neq(self):
# self.assertEqual(str(linearize_constraint(cons)), "[(a) -> (sum([1, -1, -6] * [x, y, BV4]) <= -1), (a) -> (sum([1, -1, -6] * [x, y, BV4]) >= -5)]")


def test_linearize_modulo(self):
x, y, z = cp.intvar(0, 5, shape=3, name=['x', 'y', 'z'])
a, b, c = cp.intvar(-5, 0, shape=3, name=['a', 'b', 'c'])
g, h, i = cp.intvar(-5, 5, shape=3, name=['g', 'h', 'i'])
s_pos = cp.intvar(1, 5, name='s_pos')
s_neg = cp.intvar(-5, -1, name='s_neg')

constraint = [g % s_pos == i]
lin = linearize_constraint(constraint, supported={'sum', 'wsum', 'mul'})

all_sols = set()
lin_all_sols = set()
cons_models = cp.Model(constraint).solveAll(display=lambda: all_sols.add(tuple([x.value() for x in [g, s_pos, i]])))
lin_models = cp.Model(lin).solveAll(display=lambda: lin_all_sols.add(tuple([x.value() for x in [g, s_pos, i]])))
self.assertEqual(cons_models,lin_models)

# ortools only accepts strictly positive modulo argument.

def test_linearize_division(self):
x, y, z = cp.intvar(0, 5, shape=3, name=['x', 'y', 'z'])
a, b, c = cp.intvar(-5, 0, shape=3, name=['a', 'b', 'c'])
g, h, i = cp.intvar(-5, 5, shape=3, name=['g', 'h', 'i'])
s_pos = cp.intvar(1, 5, name='s_pos')
s_neg = cp.intvar(-5, -1, name='s_neg')

constraint = [g / s_pos == i]
lin = linearize_constraint(constraint, supported={'sum', 'wsum', 'mul'})

all_sols = set()
lin_all_sols = set()
cons_models = cp.Model(constraint).solveAll(display=lambda: all_sols.add(tuple([x.value() for x in [g, s_pos, i]])))
lin_models = cp.Model(lin).solveAll(display=lambda: lin_all_sols.add(tuple([x.value() for x in [g, s_pos, i]])))
self.assertEqual(cons_models,lin_models)

# Duplicate test with s_neg instead of s_pos
constraint_neg = [g / s_neg == i]
lin_neg = linearize_constraint(constraint_neg, supported={'sum', 'wsum', 'mul'})

all_sols_neg = set()
lin_all_sols_neg = set()
cons_models_neg = cp.Model(constraint_neg).solveAll(display=lambda: all_sols_neg.add(tuple([x.value() for x in [g, s_neg, i]])))
lin_models_neg = cp.Model(lin_neg).solveAll(display=lambda: lin_all_sols_neg.add(tuple([x.value() for x in [g, s_neg, i]])))
self.assertEqual(cons_models_neg,lin_models_neg)


class TestConstRhs(unittest.TestCase):

def test_numvar(self):
Expand Down

0 comments on commit e6eb2a1

Please sign in to comment.