diff --git a/cpmpy/transformations/flatten_model.py b/cpmpy/transformations/flatten_model.py index b1c6dbc3a..40b8bf8e1 100644 --- a/cpmpy/transformations/flatten_model.py +++ b/cpmpy/transformations/flatten_model.py @@ -195,10 +195,10 @@ def flatten_constraint(expr,expr_dict={}): elif isinstance(expr.args[0], _BoolVarImpl): # LHS is var, ensure RHS is normalized 'Boolexpr' lhs,lcons = expr.args[0], () - rhs,rcons = normalized_boolexpr(expr.args[1]) + rhs,rcons = normalized_boolexpr(expr.args[1],expr_dict) else: # make LHS normalized 'Boolexpr', RHS must be a var - lhs,lcons = normalized_boolexpr(expr.args[0]) + lhs,lcons = normalized_boolexpr(expr.args[0],expr_dict) rhs,rcons = get_or_make_var(expr.args[1],expr_dict) newlist.append(Operator(expr.name, (lhs,rhs))) @@ -210,7 +210,7 @@ def flatten_constraint(expr,expr_dict={}): # if none of the above cases + continue matched: # a normalizable boolexpr - (con, flatcons) = normalized_boolexpr(expr) + (con, flatcons) = normalized_boolexpr(expr,expr_dict) newlist.append(con) newlist.extend(flatcons) @@ -258,15 +258,15 @@ def flatten_constraint(expr,expr_dict={}): # shortcut, full original one is normalizable BoolExpr # such as And(v1,v2,v3) == 0 # TODO: should be normalized away in earlier transform - (con, flatcons) = normalized_boolexpr(expr) + (con, flatcons) = normalized_boolexpr(expr,expr_dict) newlist.append(con) newlist.extend(flatcons) continue else: - (lhs, lcons) = normalized_boolexpr(lexpr) + (lhs, lcons) = normalized_boolexpr(lexpr,expr_dict) else: # other cases: LHS is numexpr - (lhs, lcons) = normalized_numexpr(lexpr) + (lhs, lcons) = normalized_numexpr(lexpr,expr_dict) newlist.append(Comparison(exprname, lhs, rvar)) newlist.extend(lcons) @@ -276,7 +276,7 @@ def flatten_constraint(expr,expr_dict={}): """ - Global constraint: global([Var]*) (CPMpy class 'GlobalConstraint') """ - (con, flatcons) = normalized_boolexpr(expr) + (con, flatcons) = normalized_boolexpr(expr,expr_dict) newlist.append(con) newlist.extend(flatcons) @@ -287,7 +287,7 @@ def flatten_constraint(expr,expr_dict={}): return newlist -def flatten_objective(expr, supported=frozenset(["sum","wsum"])): +def flatten_objective(expr, supported=frozenset(["sum","wsum"]),expr_dict={}): """ - Decision variable: Var - Linear: sum([Var]) (CPMpy class 'Operator', name 'sum') @@ -299,12 +299,12 @@ def flatten_objective(expr, supported=frozenset(["sum","wsum"])): raise Exception(f"Objective expects a single variable/expression, not a list of expressions") expr = simplify_boolean([expr])[0] - (flatexpr, flatcons) = normalized_numexpr(expr) # might rewrite expr into a (w)sum + (flatexpr, flatcons) = normalized_numexpr(expr,expr_dict) # might rewrite expr into a (w)sum if isinstance(flatexpr, Expression) and flatexpr.name in supported: return (flatexpr, flatcons) else: # any other numeric expression, - var, cons = get_or_make_var(flatexpr) + var, cons = get_or_make_var(flatexpr,expr_dict) return (var, cons+flatcons) @@ -337,7 +337,7 @@ def get_or_make_var(expr,expr_dict={}): if expr.is_bool(): # normalize expr into a boolexpr LHS, reify LHS == bvar - (flatexpr, flatcons) = normalized_boolexpr(expr) + (flatexpr, flatcons) = normalized_boolexpr(expr,expr_dict) if isinstance(flatexpr,_BoolVarImpl): #avoids unnecessary bv == bv or bv == ~bv assignments @@ -353,7 +353,7 @@ def get_or_make_var(expr,expr_dict={}): else: # normalize expr into a numexpr LHS, # then compute bounds and return (newintvar, LHS == newintvar) - (flatexpr, flatcons) = normalized_numexpr(expr) + (flatexpr, flatcons) = normalized_numexpr(expr,expr_dict) lb, ub = flatexpr.get_bounds() ivar = _IntVarImpl(lb, ub) @@ -363,20 +363,20 @@ def get_or_make_var(expr,expr_dict={}): expr_dict[str(flatexpr)] = ivar return (ivar, [flatexpr == ivar]+flatcons) -def get_or_make_var_or_list(expr): +def get_or_make_var_or_list(expr,expr_dict={}): """ Like get_or_make_var() but also accepts and recursively transforms lists Used to convert arguments of globals """ if __is_flat_var_or_list(expr): return (expr,[]) elif is_any_list(expr): - flatvars, flatcons = zip(*[get_or_make_var(arg) for arg in expr]) + flatvars, flatcons = zip(*[get_or_make_var(arg,expr_dict) for arg in expr]) return (flatvars, [c for con in flatcons for c in con]) else: - return get_or_make_var(expr) + return get_or_make_var(expr,expr_dict) -def normalized_boolexpr(expr): +def normalized_boolexpr(expr,expr_dict={}): """ input is any Boolean (is_bool()) expression output are all 'flat normal form' Boolean expressions that can be 'reified', meaning that @@ -404,18 +404,18 @@ def normalized_boolexpr(expr): # apply De Morgan's transform for "implies" if expr.name == '->': # TODO, optimisation if args0 is an 'and'? - (lhs,lcons) = get_or_make_var(expr.args[0]) + (lhs,lcons) = get_or_make_var(expr.args[0],expr_dict) # TODO, optimisation if args1 is an 'or'? - (rhs,rcons) = get_or_make_var(expr.args[1]) + (rhs,rcons) = get_or_make_var(expr.args[1],expr_dict) return ((~lhs | rhs), lcons+rcons) if expr.name == 'not': - flatvar, flatcons = get_or_make_var(expr.args[0]) + flatvar, flatcons = get_or_make_var(expr.args[0],expr_dict) return (~flatvar, flatcons) if all(__is_flat_var(arg) for arg in expr.args): return (expr, []) else: # one of the arguments is not flat, flatten all - flatvars, flatcons = zip(*[get_or_make_var(arg) for arg in expr.args]) + flatvars, flatcons = zip(*[get_or_make_var(arg,expr_dict) for arg in expr.args]) newexpr = Operator(expr.name, flatvars) return (newexpr, [c for con in flatcons for c in con]) @@ -434,7 +434,7 @@ def normalized_boolexpr(expr): lexpr, rexpr = rexpr, lexpr # ensure rhs is var - (rvar, rcons) = get_or_make_var(rexpr) + (rvar, rcons) = get_or_make_var(rexpr,expr_dict) # LHS: check if Boolexpr == smth: if (exprname == '==' or exprname == '!=') and lexpr.is_bool(): @@ -443,19 +443,19 @@ def normalized_boolexpr(expr): assert (not rexpr), f"should be false: {rexpr}" # 'true' is preprocessed away if exprname == '==': nnexpr = recurse_negation(lexpr) - return normalized_boolexpr(nnexpr) + return normalized_boolexpr(nnexpr,expr_dict) else: # !=, should only be possible in dubble negation - return normalized_boolexpr(lexpr) + return normalized_boolexpr(lexpr,expr_dict) # this is a reified constraint, so lhs must be var too to be in normal form - (lhs, lcons) = get_or_make_var(lexpr) + (lhs, lcons) = get_or_make_var(lexpr,expr_dict) if expr.name == '!=' and rvar.is_bool(): # != not needed, negate RHS variable rvar = ~rvar exprname = '==' else: # other cases: LHS is numexpr - (lhs, lcons) = normalized_numexpr(lexpr) + (lhs, lcons) = normalized_numexpr(lexpr,expr_dict) return (Comparison(exprname, lhs, rvar), lcons+rcons) @@ -468,7 +468,7 @@ def normalized_boolexpr(expr): return (expr, []) else: # recursively flatten all children - flatargs, flatcons = zip(*[get_or_make_var_or_list(arg) for arg in expr.args]) + flatargs, flatcons = zip(*[get_or_make_var_or_list(arg,expr_dict) for arg in expr.args]) # take copy, replace args newexpr = copy.copy(expr) # shallow or deep? currently shallow @@ -476,7 +476,7 @@ def normalized_boolexpr(expr): return (newexpr, [c for con in flatcons for c in con]) -def normalized_numexpr(expr): +def normalized_numexpr(expr,expr_dict={}): """ all 'flat normal form' numeric expressions... @@ -498,12 +498,12 @@ def normalized_numexpr(expr): elif expr.is_bool(): # unusual case, but its truth-value is a valid numexpr # so reify and return the boolvar - return get_or_make_var(expr) + return get_or_make_var(expr,expr_dict) elif isinstance(expr, Operator): # rewrite -a, const*a and a*const into a weighted sum, so it can be used as objective if expr.name == '-' or (expr.name == 'mul' and _wsum_should(expr)): - return normalized_numexpr(Operator("wsum", _wsum_make(expr))) + return normalized_numexpr(Operator("wsum", _wsum_make(expr)),expr_dict) if all(__is_flat_var(arg) for arg in expr.args): return (expr, []) @@ -515,7 +515,7 @@ def normalized_numexpr(expr): we = [_wsum_make(a) for a in expr.args] w = [wi for w,_ in we for wi in w] e = [ei for _,e in we for ei in e] - return normalized_numexpr(Operator("wsum", (w,e))) + return normalized_numexpr(Operator("wsum", (w,e)),expr_dict) # wsum needs special handling because expr.args is a tuple of which only 2nd one has exprs if expr.name == 'wsum': @@ -536,13 +536,13 @@ def normalized_numexpr(expr): i = i+1 # now flatten the resulting subexprs - flatvars, flatcons = map(list, zip(*[get_or_make_var(arg) for arg in sub_exprs])) # also bool, reified... + flatvars, flatcons = map(list, zip(*[get_or_make_var(arg,expr_dict) for arg in sub_exprs])) # also bool, reified... newexpr = Operator(expr.name, (weights, flatvars)) return (newexpr, [c for con in flatcons for c in con]) else: # generic operator # recursively flatten all children - flatvars, flatcons = zip(*[get_or_make_var(arg) for arg in expr.args]) + flatvars, flatcons = zip(*[get_or_make_var(arg,expr_dict) for arg in expr.args]) newexpr = Operator(expr.name, flatvars) return (newexpr, [c for con in flatcons for c in con]) @@ -554,7 +554,7 @@ def normalized_numexpr(expr): return (expr, []) else: # recursively flatten all children - flatvars, flatcons = zip(*[get_or_make_var_or_list(arg) for arg in expr.args]) + flatvars, flatcons = zip(*[get_or_make_var_or_list(arg,expr_dict) for arg in expr.args]) # take copy, replace args newexpr = copy.copy(expr) # shallow or deep? currently shallow