Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

flatten: avoid unnecessary sum decompositions #222

Merged
merged 1 commit into from
Feb 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 15 additions & 5 deletions cpmpy/expressions/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,11 +437,9 @@ def __init__(self, name, arg_list):
if name == 'sum' and \
all(not is_num(a) for a in arg_list) and \
any(_wsum_should(a) for a in arg_list):
w,e = [], []
for a in arg_list:
w1,e1 = _wsum_make(a)
w += w1
e += e1
we = [_wsum_make(a) for a in arg_list]
w = [wi for w,_ in we for wi in w]
e = [ei for _,e in we for ei in e]
name = 'wsum'
arg_list = [w,e]

Expand All @@ -462,6 +460,16 @@ def __init__(self, name, arg_list):
i += l
i += 1

# another cleanup, translate -(v*c) to v*-c
if name == '-' and arg_list[0].name == 'mul' and len(arg_list[0].args)==2:
mul_args = arg_list[0].args
if is_num(mul_args[0]):
name = 'mul'
arg_list = (-mul_args[0], mul_args[1])
elif is_num(mul_args[1]):
name = 'mul'
arg_list = (mul_args[0], -mul_args[1])

super().__init__(name, arg_list)

def is_bool(self):
Expand Down Expand Up @@ -556,6 +564,8 @@ def _wsum_make(arg):
"""
if arg.name == 'wsum':
return arg.args
elif arg.name == 'sum':
return [1]*len(arg.args), arg.args
elif arg.name == 'mul':
if is_num(arg.args[0]):
return [arg.args[0]], [arg.args[1]]
Expand Down
79 changes: 33 additions & 46 deletions cpmpy/transformations/flatten_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@
import math
import numpy as np
from ..expressions.core import *
from ..expressions.core import _wsum_should, _wsum_make
from ..expressions.variables import _NumVarImpl, _IntVarImpl, _BoolVarImpl, NegBoolView
from ..expressions.utils import is_num, is_any_list

Expand Down Expand Up @@ -237,11 +238,13 @@ def flatten_objective(expr, supported=frozenset(["sum","wsum"])):
# one source of errors is sum(v) where v is a matrix, use v.sum() instead
raise Exception(f"Objective expects a single variable/expression, not a list of expressions")

if isinstance(expr, Expression) and expr.name in supported:
return normalized_numexpr(expr)
(flatexpr, flatcons) = normalized_numexpr(expr) # might rewrite expr into a (w)sum
if isinstance(flatexpr, Expression) and flatexpr.name in supported:
return (flatexpr, flatcons)
else:
# any other numeric expression
return get_or_make_var(expr)
# any other numeric expression,
var, cons = get_or_make_var(flatexpr)
return (var, cons+flatcons)


def __is_flat_var(arg):
Expand Down Expand Up @@ -281,7 +284,7 @@ def get_or_make_var(expr):
# then compute bounds and return (newintvar, LHS == newintvar)
(flatexpr, flatcons) = normalized_numexpr(expr)

if isinstance(flatexpr, Operator) and expr.name == "wsum":
if isinstance(flatexpr, Operator) and flatexpr.name == "wsum":
# more complex args, and weights can be negative, so more complex lbs/ubs
weights, flatvars = flatexpr.args
bounds = np.array([[w * fvar.lb for w, fvar in zip(weights, flatvars)],
Expand Down Expand Up @@ -419,43 +422,6 @@ def normalized_boolexpr(expr):
else:
# LHS can be numexpr, RHS has to be variable

# TODO: optimisations that swap directions instead when it can avoid to create vars
"""
if expr.name == '==' or expr.name == '!=':
# RHS has to be variable, LHS can be more
if __is_flat_var(lexpr) and not __is_flat_var(rexpr):
# LHS is var and RHS not, swap for new expression
lexpr, rexpr = rexpr, lexpr

if __is_numexpr(lexpr) and __is_numexpr(rexpr):
# numeric case
(lrich, lcons) = flatten_objective(lexpr)
(rvar, rcons) = flatten_numexpr(rexpr)
else:
# Boolean case
# make LHS reify_ready, RHS var
(lrich, lcons) = reify_ready_boolexpr(lexpr)
(rvar, rcons) = flatten_boolexpr(rexpr)
flatcons += [Comparison(expr.name, lrich, rvar)]+lcons+rcons

else: # inequalities '<=', '<', '>=', '>'
newname = expr.name
# LHS can be linexpr, RHS a var
if __is_flat_var(lexpr) and not __is_flat_var(rexpr):
# LHS is var and RHS not, swap for new expression (incl. operator name)
lexpr, rexpr = rexpr, lexpr
if expr.name == "<=": newname = ">="
elif expr.name == "<": newname = ">"
elif expr.name == ">=": newname = "<="
elif expr.name == ">": newname = "<"

# make LHS like objective, RHS var
(lrich, lcons) = flatten_objective(lexpr)
(rvar, rcons) = flatten_numexpr(rexpr)
flatcons += [Comparison(newname, lrich, rvar)]+lcons+rcons
"""

# RHS must be var (or const)
lexpr, rexpr = expr.args
exprname = expr.name

Expand Down Expand Up @@ -532,15 +498,36 @@ def normalized_numexpr(expr):
return get_or_make_var(expr)

elif isinstance(expr, Operator):
# special case, -var, turn into -1*args[0]
if expr.name == '-': # unary
return normalized_numexpr(-1*expr.args[0])
# rewrite -a, const*a and a*const into a weighted sum, so it can be used as objective
if expr.name == '-' or (expr.name == 'mul' and _wsum_should(expr)):
return normalized_numexpr(Operator("wsum", _wsum_make(expr)))

if all(__is_flat_var(arg) for arg in expr.args):
return (expr, [])

elif expr.name == 'wsum': # unary
# pre-process sum, to fold in nested subtractions and const*Exprs, e.g. x - y + 2*(z+r)
if expr.name == "sum" and \
any(a.name == "-" or _wsum_should(expr) for a in expr.args):
we = [_wsum_make(a) for a in expr.args]
w = [wi for w,_ in we for wi in w]
e = [ei for _,e in we for ei in e]
return normalized_numexpr(Operator("wsum", (w,e)))

# wsum needs special handling because expr.args is a tuple of which only 2nd one has exprs
if expr.name == 'wsum':
weights, sub_exprs = expr.args
# while here, avoid creation of auxiliary variables for compatible operators -/sum/wsum
i = 0
while(i < len(sub_exprs)): # can dynamically change
if sub_exprs[i].name in ['-', 'sum', 'wsum']:
w,e = _wsum_make(sub_exprs[i])
# insert in place, and next iteration over same 'i' again
weights[i:i+1] = [weights[i]*wj for wj in w]
sub_exprs[i:i+1] = e
else:
i = i+1

# now flatten the resulting subexprs
flatvars, flatcons = map(list, zip(*[get_or_make_var(arg) for arg in sub_exprs])) # also bool, reified...
newexpr = Operator(expr.name, (weights, flatvars))
return (newexpr, [c for con in flatcons for c in con])
Expand Down
20 changes: 15 additions & 5 deletions tests/test_flatten.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,11 +117,10 @@ def test_get_or_make_var__bool(self):

def test_get_or_make_var__num(self):
(a,b,c,d,e) = self.ivars[:5]
(x,y,z) = self.bvars[:3]

self.assertEqual( str(get_or_make_var( a+b )), "(IV5, [((IV0) + (IV1)) == (IV5)])" )
self.assertEqual( str(get_or_make_var( a+b+c )), "(IV6, [(sum([IV0, IV1, IV2])) == (IV6)])" )
self.assertEqual( str(get_or_make_var( 2*a )), "(IV7, [(2 * (IV0)) == (IV7)])" )
self.assertEqual( str(get_or_make_var( 2*a )), "(IV7, [(sum([2] * [IV0])) == (IV7)])" )
self.assertEqual( str(get_or_make_var( a*b )), "(IV8, [((IV0) * (IV1)) == (IV8)])" )
self.assertEqual( str(get_or_make_var( a//b )), "(IV9, [((IV0) // (IV1)) == (IV9)])" )
self.assertEqual( str(get_or_make_var( 1//b )), "(IV10, [(1 // (IV1)) == (IV10)])" )
Expand All @@ -130,19 +129,30 @@ def test_get_or_make_var__num(self):
self.assertEqual( str(get_or_make_var( 1*a + 2*b + 3*c )), "(IV12, [(sum([1, 2, 3] * [IV0, IV1, IV2])) == (IV12)])")
self.assertEqual( str(get_or_make_var( cp.cpm_array([1,2,3])[a] )), "(IV13, [([1 2 3][IV0]) == (IV13)])" )
self.assertEqual( str(get_or_make_var( cp.cpm_array([b+c,2,3])[a] )), "(IV15, [((IV14, 2, 3)[IV0]) == (IV15), ((IV1) + (IV2)) == (IV14)])" )
self.assertEqual( str(get_or_make_var( a*2 )), "(IV16, [(sum([2] * [IV0])) == (IV16)])" )

def test_objective(self):
(a,b,c,d,e) = self.ivars[:5]
(x,y,z) = self.bvars[:3]

self.assertEqual( str(flatten_objective( a )), f"({str(a)}, [])" )
self.assertEqual( str(flatten_objective( -a )), '(sum([-1] * [IV0]), [])' )
self.assertEqual( str(flatten_objective( -2*a )), '(sum([-2] * [IV0]), [])' )
self.assertEqual( str(flatten_objective( a+b )), f"(({str(a)}) + ({str(b)}), [])" )
self.assertEqual( str(flatten_objective( a-b )), '(sum([1, -1] * [IV0, IV1]), [])' )
self.assertEqual( str(flatten_objective( -a+b )), '(sum([-1, 1] * [IV0, IV1]), [])' )
self.assertEqual( str(flatten_objective( a+b-c )), "(sum([1, 1, -1] * [IV0, IV1, IV2]), [])" )
self.assertEqual( str(flatten_objective( 2*a+3*b )), "(sum([2, 3] * [IV0, IV1]), [])" )
self.assertEqual( str(flatten_objective( 2*a+3*(b + c) )), "(sum([2, 3] * [IV0, IV5]), [((IV1) + (IV2)) == (IV5)])" )
self.assertEqual( str(flatten_objective( 2*a+b*3 )), "(sum([2, 3] * [IV0, IV1]), [])" )
self.assertEqual( str(flatten_objective( 2*a-b*3 )), "(sum([2, -3] * [IV0, IV1]), [])" )
self.assertEqual( str(flatten_objective( 2*a-3*b+4*c )), "(sum([2, -3, 4] * [IV0, IV1, IV2]), [])" )
self.assertEqual( str(flatten_objective( 2*a+3*(b + c) )), "(sum([2, 3, 3] * [IV0, IV1, IV2]), [])" )
self.assertEqual( str(flatten_objective( 2*a-3*(b + 2*c) )), "(sum([2, -3, -6] * [IV0, IV1, IV2]), [])" )
self.assertEqual( str(flatten_objective( 2*a-3*(b - c*2) )), '(sum([2, -3, 6] * [IV0, IV1, IV2]), [])' )
cp.intvar(0,2) # increase counter
self.assertEqual( str(flatten_objective( a//b+c )), f"((IV6) + ({str(c)}), [(({str(a)}) // ({str(b)})) == (IV6)])" )
self.assertEqual( str(flatten_objective( cp.cpm_array([1,2,3])[a] )), "(IV7, [([1 2 3][IV0]) == (IV7)])" )
self.assertEqual( str(flatten_objective( cp.cpm_array([1,2,3])[a]+b )), "((IV8) + (IV1), [([1 2 3][IV0]) == (IV8)])" )
self.assertEqual( str(flatten_objective( a+b-c )), "(sum([IV0, IV1, IV9]), [(-1 * (IV2)) == (IV9)])" )


def test_constraint(self):
(a,b,c,d,e) = self.ivars[:5]
Expand Down