From 390a215e0ef207d8e12b2cdb50327618f2d1325b Mon Sep 17 00:00:00 2001 From: Tias Guns Date: Tue, 7 Feb 2023 23:39:54 +0100 Subject: [PATCH] flatten: avoid unnecessary sum decompositions --- cpmpy/expressions/core.py | 20 +++++-- cpmpy/transformations/flatten_model.py | 79 +++++++++++--------------- tests/test_flatten.py | 20 +++++-- 3 files changed, 63 insertions(+), 56 deletions(-) diff --git a/cpmpy/expressions/core.py b/cpmpy/expressions/core.py index addafcf41..8910df488 100644 --- a/cpmpy/expressions/core.py +++ b/cpmpy/expressions/core.py @@ -437,11 +437,9 @@ def __init__(self, name, arg_list): if name == 'sum' and \ all(not is_num(a) for a in arg_list) and \ any(_wsum_should(a) for a in arg_list): - w,e = [], [] - for a in arg_list: - w1,e1 = _wsum_make(a) - w += w1 - e += e1 + we = [_wsum_make(a) for a in arg_list] + w = [wi for w,_ in we for wi in w] + e = [ei for _,e in we for ei in e] name = 'wsum' arg_list = [w,e] @@ -462,6 +460,16 @@ def __init__(self, name, arg_list): i += l i += 1 + # another cleanup, translate -(v*c) to v*-c + if name == '-' and arg_list[0].name == 'mul' and len(arg_list[0].args)==2: + mul_args = arg_list[0].args + if is_num(mul_args[0]): + name = 'mul' + arg_list = (-mul_args[0], mul_args[1]) + elif is_num(mul_args[1]): + name = 'mul' + arg_list = (mul_args[0], -mul_args[1]) + super().__init__(name, arg_list) def is_bool(self): @@ -556,6 +564,8 @@ def _wsum_make(arg): """ if arg.name == 'wsum': return arg.args + elif arg.name == 'sum': + return [1]*len(arg.args), arg.args elif arg.name == 'mul': if is_num(arg.args[0]): return [arg.args[0]], [arg.args[1]] diff --git a/cpmpy/transformations/flatten_model.py b/cpmpy/transformations/flatten_model.py index 343ff06d6..9eb08cac4 100644 --- a/cpmpy/transformations/flatten_model.py +++ b/cpmpy/transformations/flatten_model.py @@ -81,6 +81,7 @@ import math import numpy as np from ..expressions.core import * +from ..expressions.core import _wsum_should, _wsum_make from ..expressions.variables import _NumVarImpl, _IntVarImpl, _BoolVarImpl, NegBoolView from ..expressions.utils import is_num, is_any_list @@ -237,11 +238,13 @@ def flatten_objective(expr, supported=frozenset(["sum","wsum"])): # one source of errors is sum(v) where v is a matrix, use v.sum() instead raise Exception(f"Objective expects a single variable/expression, not a list of expressions") - if isinstance(expr, Expression) and expr.name in supported: - return normalized_numexpr(expr) + (flatexpr, flatcons) = normalized_numexpr(expr) # might rewrite expr into a (w)sum + if isinstance(flatexpr, Expression) and flatexpr.name in supported: + return (flatexpr, flatcons) else: - # any other numeric expression - return get_or_make_var(expr) + # any other numeric expression, + var, cons = get_or_make_var(flatexpr) + return (var, cons+flatcons) def __is_flat_var(arg): @@ -281,7 +284,7 @@ def get_or_make_var(expr): # then compute bounds and return (newintvar, LHS == newintvar) (flatexpr, flatcons) = normalized_numexpr(expr) - if isinstance(flatexpr, Operator) and expr.name == "wsum": + if isinstance(flatexpr, Operator) and flatexpr.name == "wsum": # more complex args, and weights can be negative, so more complex lbs/ubs weights, flatvars = flatexpr.args bounds = np.array([[w * fvar.lb for w, fvar in zip(weights, flatvars)], @@ -419,43 +422,6 @@ def normalized_boolexpr(expr): else: # LHS can be numexpr, RHS has to be variable - # TODO: optimisations that swap directions instead when it can avoid to create vars - """ - if expr.name == '==' or expr.name == '!=': - # RHS has to be variable, LHS can be more - if __is_flat_var(lexpr) and not __is_flat_var(rexpr): - # LHS is var and RHS not, swap for new expression - lexpr, rexpr = rexpr, lexpr - - if __is_numexpr(lexpr) and __is_numexpr(rexpr): - # numeric case - (lrich, lcons) = flatten_objective(lexpr) - (rvar, rcons) = flatten_numexpr(rexpr) - else: - # Boolean case - # make LHS reify_ready, RHS var - (lrich, lcons) = reify_ready_boolexpr(lexpr) - (rvar, rcons) = flatten_boolexpr(rexpr) - flatcons += [Comparison(expr.name, lrich, rvar)]+lcons+rcons - - else: # inequalities '<=', '<', '>=', '>' - newname = expr.name - # LHS can be linexpr, RHS a var - if __is_flat_var(lexpr) and not __is_flat_var(rexpr): - # LHS is var and RHS not, swap for new expression (incl. operator name) - lexpr, rexpr = rexpr, lexpr - if expr.name == "<=": newname = ">=" - elif expr.name == "<": newname = ">" - elif expr.name == ">=": newname = "<=" - elif expr.name == ">": newname = "<" - - # make LHS like objective, RHS var - (lrich, lcons) = flatten_objective(lexpr) - (rvar, rcons) = flatten_numexpr(rexpr) - flatcons += [Comparison(newname, lrich, rvar)]+lcons+rcons - """ - - # RHS must be var (or const) lexpr, rexpr = expr.args exprname = expr.name @@ -532,15 +498,36 @@ def normalized_numexpr(expr): return get_or_make_var(expr) elif isinstance(expr, Operator): - # special case, -var, turn into -1*args[0] - if expr.name == '-': # unary - return normalized_numexpr(-1*expr.args[0]) + # rewrite -a, const*a and a*const into a weighted sum, so it can be used as objective + if expr.name == '-' or (expr.name == 'mul' and _wsum_should(expr)): + return normalized_numexpr(Operator("wsum", _wsum_make(expr))) if all(__is_flat_var(arg) for arg in expr.args): return (expr, []) - elif expr.name == 'wsum': # unary + # pre-process sum, to fold in nested subtractions and const*Exprs, e.g. x - y + 2*(z+r) + if expr.name == "sum" and \ + any(a.name == "-" or _wsum_should(expr) for a in expr.args): + we = [_wsum_make(a) for a in expr.args] + w = [wi for w,_ in we for wi in w] + e = [ei for _,e in we for ei in e] + return normalized_numexpr(Operator("wsum", (w,e))) + + # wsum needs special handling because expr.args is a tuple of which only 2nd one has exprs + if expr.name == 'wsum': weights, sub_exprs = expr.args + # while here, avoid creation of auxiliary variables for compatible operators -/sum/wsum + i = 0 + while(i < len(sub_exprs)): # can dynamically change + if sub_exprs[i].name in ['-', 'sum', 'wsum']: + w,e = _wsum_make(sub_exprs[i]) + # insert in place, and next iteration over same 'i' again + weights[i:i+1] = [weights[i]*wj for wj in w] + sub_exprs[i:i+1] = e + else: + i = i+1 + + # now flatten the resulting subexprs flatvars, flatcons = map(list, zip(*[get_or_make_var(arg) for arg in sub_exprs])) # also bool, reified... newexpr = Operator(expr.name, (weights, flatvars)) return (newexpr, [c for con in flatcons for c in con]) diff --git a/tests/test_flatten.py b/tests/test_flatten.py index 677a18512..c587ad6da 100644 --- a/tests/test_flatten.py +++ b/tests/test_flatten.py @@ -117,11 +117,10 @@ def test_get_or_make_var__bool(self): def test_get_or_make_var__num(self): (a,b,c,d,e) = self.ivars[:5] - (x,y,z) = self.bvars[:3] self.assertEqual( str(get_or_make_var( a+b )), "(IV5, [((IV0) + (IV1)) == (IV5)])" ) self.assertEqual( str(get_or_make_var( a+b+c )), "(IV6, [(sum([IV0, IV1, IV2])) == (IV6)])" ) - self.assertEqual( str(get_or_make_var( 2*a )), "(IV7, [(2 * (IV0)) == (IV7)])" ) + self.assertEqual( str(get_or_make_var( 2*a )), "(IV7, [(sum([2] * [IV0])) == (IV7)])" ) self.assertEqual( str(get_or_make_var( a*b )), "(IV8, [((IV0) * (IV1)) == (IV8)])" ) self.assertEqual( str(get_or_make_var( a//b )), "(IV9, [((IV0) // (IV1)) == (IV9)])" ) self.assertEqual( str(get_or_make_var( 1//b )), "(IV10, [(1 // (IV1)) == (IV10)])" ) @@ -130,19 +129,30 @@ def test_get_or_make_var__num(self): self.assertEqual( str(get_or_make_var( 1*a + 2*b + 3*c )), "(IV12, [(sum([1, 2, 3] * [IV0, IV1, IV2])) == (IV12)])") self.assertEqual( str(get_or_make_var( cp.cpm_array([1,2,3])[a] )), "(IV13, [([1 2 3][IV0]) == (IV13)])" ) self.assertEqual( str(get_or_make_var( cp.cpm_array([b+c,2,3])[a] )), "(IV15, [((IV14, 2, 3)[IV0]) == (IV15), ((IV1) + (IV2)) == (IV14)])" ) + self.assertEqual( str(get_or_make_var( a*2 )), "(IV16, [(sum([2] * [IV0])) == (IV16)])" ) def test_objective(self): (a,b,c,d,e) = self.ivars[:5] - (x,y,z) = self.bvars[:3] self.assertEqual( str(flatten_objective( a )), f"({str(a)}, [])" ) + self.assertEqual( str(flatten_objective( -a )), '(sum([-1] * [IV0]), [])' ) + self.assertEqual( str(flatten_objective( -2*a )), '(sum([-2] * [IV0]), [])' ) self.assertEqual( str(flatten_objective( a+b )), f"(({str(a)}) + ({str(b)}), [])" ) + self.assertEqual( str(flatten_objective( a-b )), '(sum([1, -1] * [IV0, IV1]), [])' ) + self.assertEqual( str(flatten_objective( -a+b )), '(sum([-1, 1] * [IV0, IV1]), [])' ) + self.assertEqual( str(flatten_objective( a+b-c )), "(sum([1, 1, -1] * [IV0, IV1, IV2]), [])" ) self.assertEqual( str(flatten_objective( 2*a+3*b )), "(sum([2, 3] * [IV0, IV1]), [])" ) - self.assertEqual( str(flatten_objective( 2*a+3*(b + c) )), "(sum([2, 3] * [IV0, IV5]), [((IV1) + (IV2)) == (IV5)])" ) + self.assertEqual( str(flatten_objective( 2*a+b*3 )), "(sum([2, 3] * [IV0, IV1]), [])" ) + self.assertEqual( str(flatten_objective( 2*a-b*3 )), "(sum([2, -3] * [IV0, IV1]), [])" ) + self.assertEqual( str(flatten_objective( 2*a-3*b+4*c )), "(sum([2, -3, 4] * [IV0, IV1, IV2]), [])" ) + self.assertEqual( str(flatten_objective( 2*a+3*(b + c) )), "(sum([2, 3, 3] * [IV0, IV1, IV2]), [])" ) + self.assertEqual( str(flatten_objective( 2*a-3*(b + 2*c) )), "(sum([2, -3, -6] * [IV0, IV1, IV2]), [])" ) + self.assertEqual( str(flatten_objective( 2*a-3*(b - c*2) )), '(sum([2, -3, 6] * [IV0, IV1, IV2]), [])' ) + cp.intvar(0,2) # increase counter self.assertEqual( str(flatten_objective( a//b+c )), f"((IV6) + ({str(c)}), [(({str(a)}) // ({str(b)})) == (IV6)])" ) self.assertEqual( str(flatten_objective( cp.cpm_array([1,2,3])[a] )), "(IV7, [([1 2 3][IV0]) == (IV7)])" ) self.assertEqual( str(flatten_objective( cp.cpm_array([1,2,3])[a]+b )), "((IV8) + (IV1), [([1 2 3][IV0]) == (IV8)])" ) - self.assertEqual( str(flatten_objective( a+b-c )), "(sum([IV0, IV1, IV9]), [(-1 * (IV2)) == (IV9)])" ) + def test_constraint(self): (a,b,c,d,e) = self.ivars[:5]