Skip to content

Commit

Permalink
new transformation used in linearize
Browse files Browse the repository at this point in the history
  • Loading branch information
Dimosts committed Sep 13, 2023
1 parent 562b193 commit c7a7c5a
Showing 1 changed file with 9 additions and 37 deletions.
46 changes: 9 additions & 37 deletions cpmpy/transformations/linearize.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,48 +108,22 @@ def linearize_constraint(lst_of_expr, supported={"sum","wsum"}, reified=False):
if lhs.name == "sub":
# convert to wsum
lhs = sum([1 * lhs.args[0] + -1 * lhs.args[1]])
cpm_expr = eval_comparison(cpm_expr.name, lhs, rhs)

# linearize unsupported operators
elif isinstance(lhs, Operator) and lhs.name not in supported: # TODO: add mul, (abs?), (mod?), (pow?)

if lhs.name == "mul" and is_num(lhs.args[0]):
lhs = Operator("wsum",[[lhs.args[0]], [lhs.args[1]]])
cpm_expr = eval_comparison(cpm_expr.name, lhs, rhs)
else:
raise TransformationNotImplementedError(f"lhs of constraint {cpm_expr} cannot be linearized, should be any of {supported | set(['sub'])} but is {lhs}. Please report on github")

elif isinstance(lhs, GlobalConstraint) and lhs.name not in supported:
raise ValueError("Linearization of `lhs` not supported, run `cpmpy.transformations.decompose_global.decompose_global() first")

if is_num(lhs) or isinstance(lhs, _NumVarImpl) or (isinstance(lhs, Operator) and lhs.name in {"sum","wsum"}):
# bring all vars to lhs
if isinstance(rhs, _NumVarImpl):
if isinstance(lhs, Operator) and lhs.name == "sum":
lhs, rhs = sum([1 * a for a in lhs.args]+[-1 * rhs]), 0
elif isinstance(lhs, _NumVarImpl) or (isinstance(lhs, Operator) and lhs.name == "wsum"):
lhs, rhs = lhs + -1*rhs, 0
else:
raise ValueError(f"unexpected expression on lhs of expression, should be sum,wsum or intvar but got {lhs}")

assert not is_num(lhs), "lhs cannot be an integer at this point!"
# bring all const to rhs
if lhs.name == "sum":
new_args = []
for i, arg in enumerate(lhs.args):
if is_num(arg):
rhs -= arg
else:
new_args.append(arg)
lhs = Operator("sum", new_args)

elif lhs.name == "wsum":
new_weights, new_args = [],[]
for i, (w, arg) in enumerate(zip(*lhs.args)):
if is_num(arg):
rhs -= w * arg
else:
new_weights.append(w)
new_args.append(arg)
lhs = Operator("wsum",[new_weights, new_args])
[cpm_expr] = canonical_comparison([cpm_expr]) # just transforms the constraint, not introducing new ones
lhs, rhs = cpm_expr.args

# now fix the comparisons themselves
if cpm_expr.name == "<":
Expand Down Expand Up @@ -284,9 +258,9 @@ def canonical_comparison(lst_of_expr):
newlist = []
for cpm_expr in lst_of_expr:

lhs, rhs = cpm_expr.args

if isinstance(cpm_expr, Comparison):
lhs, rhs = cpm_expr.args

if is_num(lhs) or isinstance(lhs, _NumVarImpl) or (isinstance(lhs, Operator) and lhs.name in {"sum", "wsum"}):
# bring all vars to lhs
if isinstance(rhs, _NumVarImpl):
Expand Down Expand Up @@ -320,10 +294,8 @@ def canonical_comparison(lst_of_expr):
new_args.append(arg)
lhs = Operator("wsum", [new_weights, new_args])

if isinstance(lhs, Operator) and lhs.name == "mul" and len(lhs.args) == 2 and is_num(lhs.args[0]):
# convert to wsum
lhs = Operator("wsum", [[lhs.args[0]], [lhs.args[1]]])

newlist.append(eval_comparison(cpm_expr.name, lhs, rhs))
newlist.append(eval_comparison(cpm_expr.name, lhs, rhs))
else:
newlist.append(cpm_expr)

return newlist

0 comments on commit c7a7c5a

Please sign in to comment.