Skip to content

Commit

Permalink
pass expr_dict object
Browse files Browse the repository at this point in the history
  • Loading branch information
Wout4 committed Aug 25, 2023
1 parent 7126f1b commit 5d39d3f
Showing 1 changed file with 33 additions and 33 deletions.
66 changes: 33 additions & 33 deletions cpmpy/transformations/flatten_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,10 +195,10 @@ def flatten_constraint(expr,expr_dict={}):
elif isinstance(expr.args[0], _BoolVarImpl):
# LHS is var, ensure RHS is normalized 'Boolexpr'
lhs,lcons = expr.args[0], ()
rhs,rcons = normalized_boolexpr(expr.args[1])
rhs,rcons = normalized_boolexpr(expr.args[1],expr_dict)
else:
# make LHS normalized 'Boolexpr', RHS must be a var
lhs,lcons = normalized_boolexpr(expr.args[0])
lhs,lcons = normalized_boolexpr(expr.args[0],expr_dict)
rhs,rcons = get_or_make_var(expr.args[1],expr_dict)

newlist.append(Operator(expr.name, (lhs,rhs)))
Expand All @@ -210,7 +210,7 @@ def flatten_constraint(expr,expr_dict={}):

# if none of the above cases + continue matched:
# a normalizable boolexpr
(con, flatcons) = normalized_boolexpr(expr)
(con, flatcons) = normalized_boolexpr(expr,expr_dict)
newlist.append(con)
newlist.extend(flatcons)

Expand Down Expand Up @@ -258,15 +258,15 @@ def flatten_constraint(expr,expr_dict={}):
# shortcut, full original one is normalizable BoolExpr
# such as And(v1,v2,v3) == 0
# TODO: should be normalized away in earlier transform
(con, flatcons) = normalized_boolexpr(expr)
(con, flatcons) = normalized_boolexpr(expr,expr_dict)
newlist.append(con)
newlist.extend(flatcons)
continue
else:
(lhs, lcons) = normalized_boolexpr(lexpr)
(lhs, lcons) = normalized_boolexpr(lexpr,expr_dict)
else:
# other cases: LHS is numexpr
(lhs, lcons) = normalized_numexpr(lexpr)
(lhs, lcons) = normalized_numexpr(lexpr,expr_dict)

newlist.append(Comparison(exprname, lhs, rvar))
newlist.extend(lcons)
Expand All @@ -276,7 +276,7 @@ def flatten_constraint(expr,expr_dict={}):
"""
- Global constraint: global([Var]*) (CPMpy class 'GlobalConstraint')
"""
(con, flatcons) = normalized_boolexpr(expr)
(con, flatcons) = normalized_boolexpr(expr,expr_dict)
newlist.append(con)
newlist.extend(flatcons)

Expand All @@ -287,7 +287,7 @@ def flatten_constraint(expr,expr_dict={}):
return newlist


def flatten_objective(expr, supported=frozenset(["sum","wsum"])):
def flatten_objective(expr, supported=frozenset(["sum","wsum"]),expr_dict={}):
"""
- Decision variable: Var
- Linear: sum([Var]) (CPMpy class 'Operator', name 'sum')
Expand All @@ -299,12 +299,12 @@ def flatten_objective(expr, supported=frozenset(["sum","wsum"])):
raise Exception(f"Objective expects a single variable/expression, not a list of expressions")

expr = simplify_boolean([expr])[0]
(flatexpr, flatcons) = normalized_numexpr(expr) # might rewrite expr into a (w)sum
(flatexpr, flatcons) = normalized_numexpr(expr,expr_dict) # might rewrite expr into a (w)sum
if isinstance(flatexpr, Expression) and flatexpr.name in supported:
return (flatexpr, flatcons)
else:
# any other numeric expression,
var, cons = get_or_make_var(flatexpr)
var, cons = get_or_make_var(flatexpr,expr_dict)
return (var, cons+flatcons)


Expand Down Expand Up @@ -337,7 +337,7 @@ def get_or_make_var(expr,expr_dict={}):

if expr.is_bool():
# normalize expr into a boolexpr LHS, reify LHS == bvar
(flatexpr, flatcons) = normalized_boolexpr(expr)
(flatexpr, flatcons) = normalized_boolexpr(expr,expr_dict)

if isinstance(flatexpr,_BoolVarImpl):
#avoids unnecessary bv == bv or bv == ~bv assignments
Expand All @@ -353,7 +353,7 @@ def get_or_make_var(expr,expr_dict={}):
else:
# normalize expr into a numexpr LHS,
# then compute bounds and return (newintvar, LHS == newintvar)
(flatexpr, flatcons) = normalized_numexpr(expr)
(flatexpr, flatcons) = normalized_numexpr(expr,expr_dict)

lb, ub = flatexpr.get_bounds()
ivar = _IntVarImpl(lb, ub)
Expand All @@ -363,20 +363,20 @@ def get_or_make_var(expr,expr_dict={}):
expr_dict[str(flatexpr)] = ivar
return (ivar, [flatexpr == ivar]+flatcons)

def get_or_make_var_or_list(expr):
def get_or_make_var_or_list(expr,expr_dict={}):
""" Like get_or_make_var() but also accepts and recursively transforms lists
Used to convert arguments of globals
"""
if __is_flat_var_or_list(expr):
return (expr,[])
elif is_any_list(expr):
flatvars, flatcons = zip(*[get_or_make_var(arg) for arg in expr])
flatvars, flatcons = zip(*[get_or_make_var(arg,expr_dict) for arg in expr])
return (flatvars, [c for con in flatcons for c in con])
else:
return get_or_make_var(expr)
return get_or_make_var(expr,expr_dict)


def normalized_boolexpr(expr):
def normalized_boolexpr(expr,expr_dict={}):
"""
input is any Boolean (is_bool()) expression
output are all 'flat normal form' Boolean expressions that can be 'reified', meaning that
Expand Down Expand Up @@ -404,18 +404,18 @@ def normalized_boolexpr(expr):
# apply De Morgan's transform for "implies"
if expr.name == '->':
# TODO, optimisation if args0 is an 'and'?
(lhs,lcons) = get_or_make_var(expr.args[0])
(lhs,lcons) = get_or_make_var(expr.args[0],expr_dict)
# TODO, optimisation if args1 is an 'or'?
(rhs,rcons) = get_or_make_var(expr.args[1])
(rhs,rcons) = get_or_make_var(expr.args[1],expr_dict)
return ((~lhs | rhs), lcons+rcons)
if expr.name == 'not':
flatvar, flatcons = get_or_make_var(expr.args[0])
flatvar, flatcons = get_or_make_var(expr.args[0],expr_dict)
return (~flatvar, flatcons)
if all(__is_flat_var(arg) for arg in expr.args):
return (expr, [])
else:
# one of the arguments is not flat, flatten all
flatvars, flatcons = zip(*[get_or_make_var(arg) for arg in expr.args])
flatvars, flatcons = zip(*[get_or_make_var(arg,expr_dict) for arg in expr.args])
newexpr = Operator(expr.name, flatvars)
return (newexpr, [c for con in flatcons for c in con])

Expand All @@ -434,7 +434,7 @@ def normalized_boolexpr(expr):
lexpr, rexpr = rexpr, lexpr

# ensure rhs is var
(rvar, rcons) = get_or_make_var(rexpr)
(rvar, rcons) = get_or_make_var(rexpr,expr_dict)

# LHS: check if Boolexpr == smth:
if (exprname == '==' or exprname == '!=') and lexpr.is_bool():
Expand All @@ -443,19 +443,19 @@ def normalized_boolexpr(expr):
assert (not rexpr), f"should be false: {rexpr}" # 'true' is preprocessed away
if exprname == '==':
nnexpr = recurse_negation(lexpr)
return normalized_boolexpr(nnexpr)
return normalized_boolexpr(nnexpr,expr_dict)
else: # !=, should only be possible in dubble negation
return normalized_boolexpr(lexpr)
return normalized_boolexpr(lexpr,expr_dict)

# this is a reified constraint, so lhs must be var too to be in normal form
(lhs, lcons) = get_or_make_var(lexpr)
(lhs, lcons) = get_or_make_var(lexpr,expr_dict)
if expr.name == '!=' and rvar.is_bool():
# != not needed, negate RHS variable
rvar = ~rvar
exprname = '=='
else:
# other cases: LHS is numexpr
(lhs, lcons) = normalized_numexpr(lexpr)
(lhs, lcons) = normalized_numexpr(lexpr,expr_dict)

return (Comparison(exprname, lhs, rvar), lcons+rcons)

Expand All @@ -468,15 +468,15 @@ def normalized_boolexpr(expr):
return (expr, [])
else:
# recursively flatten all children
flatargs, flatcons = zip(*[get_or_make_var_or_list(arg) for arg in expr.args])
flatargs, flatcons = zip(*[get_or_make_var_or_list(arg,expr_dict) for arg in expr.args])

# take copy, replace args
newexpr = copy.copy(expr) # shallow or deep? currently shallow
newexpr.args = flatargs
return (newexpr, [c for con in flatcons for c in con])


def normalized_numexpr(expr):
def normalized_numexpr(expr,expr_dict={}):
"""
all 'flat normal form' numeric expressions...
Expand All @@ -498,12 +498,12 @@ def normalized_numexpr(expr):
elif expr.is_bool():
# unusual case, but its truth-value is a valid numexpr
# so reify and return the boolvar
return get_or_make_var(expr)
return get_or_make_var(expr,expr_dict)

elif isinstance(expr, Operator):
# rewrite -a, const*a and a*const into a weighted sum, so it can be used as objective
if expr.name == '-' or (expr.name == 'mul' and _wsum_should(expr)):
return normalized_numexpr(Operator("wsum", _wsum_make(expr)))
return normalized_numexpr(Operator("wsum", _wsum_make(expr)),expr_dict)

if all(__is_flat_var(arg) for arg in expr.args):
return (expr, [])
Expand All @@ -515,7 +515,7 @@ def normalized_numexpr(expr):
we = [_wsum_make(a) for a in expr.args]
w = [wi for w,_ in we for wi in w]
e = [ei for _,e in we for ei in e]
return normalized_numexpr(Operator("wsum", (w,e)))
return normalized_numexpr(Operator("wsum", (w,e)),expr_dict)

# wsum needs special handling because expr.args is a tuple of which only 2nd one has exprs
if expr.name == 'wsum':
Expand All @@ -536,13 +536,13 @@ def normalized_numexpr(expr):
i = i+1

# now flatten the resulting subexprs
flatvars, flatcons = map(list, zip(*[get_or_make_var(arg) for arg in sub_exprs])) # also bool, reified...
flatvars, flatcons = map(list, zip(*[get_or_make_var(arg,expr_dict) for arg in sub_exprs])) # also bool, reified...
newexpr = Operator(expr.name, (weights, flatvars))
return (newexpr, [c for con in flatcons for c in con])

else: # generic operator
# recursively flatten all children
flatvars, flatcons = zip(*[get_or_make_var(arg) for arg in expr.args])
flatvars, flatcons = zip(*[get_or_make_var(arg,expr_dict) for arg in expr.args])

newexpr = Operator(expr.name, flatvars)
return (newexpr, [c for con in flatcons for c in con])
Expand All @@ -554,7 +554,7 @@ def normalized_numexpr(expr):
return (expr, [])
else:
# recursively flatten all children
flatvars, flatcons = zip(*[get_or_make_var_or_list(arg) for arg in expr.args])
flatvars, flatcons = zip(*[get_or_make_var_or_list(arg,expr_dict) for arg in expr.args])

# take copy, replace args
newexpr = copy.copy(expr) # shallow or deep? currently shallow
Expand Down

0 comments on commit 5d39d3f

Please sign in to comment.