Skip to content

Commit

Permalink
fix special case in canonical comparison
Browse files Browse the repository at this point in the history
  • Loading branch information
IgnaceBleukx committed Jan 12, 2024
1 parent 7fdc2b1 commit b5bcc7a
Showing 1 changed file with 24 additions and 18 deletions.
42 changes: 24 additions & 18 deletions cpmpy/transformations/linearize.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,24 +299,30 @@ def canonical_comparison(lst_of_expr):
assert not is_num(lhs), "lhs cannot be an integer at this point!"

# bring all const to rhs
if lhs.name == "sum":
new_args = []
for i, arg in enumerate(lhs.args):
if is_num(arg):
rhs -= arg
else:
new_args.append(arg)
lhs = Operator("sum", new_args)

elif lhs.name == "wsum":
new_weights, new_args = [], []
for i, (w, arg) in enumerate(zip(*lhs.args)):
if is_num(arg):
rhs -= w * arg
else:
new_weights.append(w)
new_args.append(arg)
lhs = Operator("wsum", [new_weights, new_args])
if isinstance(lhs, Operator):
if lhs.name == "sum":
new_args = []
for i, arg in enumerate(lhs.args):
if is_num(arg):
rhs -= arg
else:
new_args.append(arg)
lhs = Operator("sum", new_args)

elif lhs.name == "wsum":
new_weights, new_args = [], []
for i, (w, arg) in enumerate(zip(*lhs.args)):
if is_num(arg):
rhs -= w * arg
else:
new_weights.append(w)
new_args.append(arg)
lhs = Operator("wsum", [new_weights, new_args])
else:
raise ValueError(f"lhs should be sum or wsum, but got {lhs}")
else:
assert isinstance(lhs, _NumVarImpl)
lhs = Operator("sum", [lhs])

newlist.append(eval_comparison(cpm_expr.name, lhs, rhs))
else: # rest of expressions
Expand Down

0 comments on commit b5bcc7a

Please sign in to comment.