Skip to content

Commit

Permalink
everything that uses get_or_make_var
Browse files Browse the repository at this point in the history
  • Loading branch information
Wout4 committed Aug 25, 2023
1 parent 5d39d3f commit fc6ecd5
Show file tree
Hide file tree
Showing 9 changed files with 49 additions and 48 deletions.
14 changes: 7 additions & 7 deletions cpmpy/solvers/exact.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ def objective(self, expr, minimize):
self.objective_minimize = minimize

# make objective function non-nested
(flat_obj, flat_cons) = flatten_objective(expr)
(flat_obj, flat_cons) = flatten_objective(expr,expr_dict=self.expr_dict)
self += flat_cons # add potentially created constraints
self.user_vars.update(get_variables(flat_obj)) # add objvars to vars

Expand Down Expand Up @@ -404,12 +404,12 @@ def transform(self, cpm_expr):
# expressions have to be linearized to fit in MIP model. See /transformations/linearize
cpm_cons = toplevel_list(cpm_expr)
cpm_cons = decompose_in_tree(cpm_cons, supported=frozenset({'alldifferent'})) # Alldiff has a specialzed MIP decomp
cpm_cons = flatten_constraint(cpm_cons) # flat normal form
cpm_cons = reify_rewrite(cpm_cons, supported=frozenset(['sum', 'wsum'])) # constraints that support reification
cpm_cons = only_numexpr_equality(cpm_cons, supported=frozenset(["sum", "wsum"])) # supports >, <, !=
cpm_cons = only_bv_implies(cpm_cons) # anything that can create full reif should go above...
cpm_cons = linearize_constraint(cpm_cons, supported=frozenset({"sum","wsum"})) # the core of the MIP-linearization
cpm_cons = only_positive_bv(cpm_cons) # after linearisation, rewrite ~bv into 1-bv
cpm_cons = flatten_constraint(cpm_cons,expr_dict=self.expr_dict) # flat normal form
cpm_cons = reify_rewrite(cpm_cons, supported=frozenset(['sum', 'wsum']),expr_dict=self.expr_dict) # constraints that support reification
cpm_cons = only_numexpr_equality(cpm_cons, supported=frozenset(["sum", "wsum"]),expr_dict=self.expr_dict) # supports >, <, !=
cpm_cons = only_bv_implies(cpm_cons,expr_dict=self.expr_dict) # anything that can create full reif should go above...
cpm_cons = linearize_constraint(cpm_cons, supported=frozenset({"sum","wsum"}),expr_dict=self.expr_dict) # the core of the MIP-linearization
cpm_cons = only_positive_bv(cpm_cons,expr_dict=self.expr_dict) # after linearisation, rewrite ~bv into 1-bv
return cpm_cons

# NOTE: the transformations that are still done specifically for Exact are two-fold:
Expand Down
14 changes: 7 additions & 7 deletions cpmpy/solvers/gurobi.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def objective(self, expr, minimize=True):
from gurobipy import GRB

# make objective function non-nested
(flat_obj, flat_cons) = (flatten_objective(expr))
(flat_obj, flat_cons) = (flatten_objective(expr,expr_dict=self.expr_dict))
self += flat_cons
get_variables(flat_obj, collect=self.user_vars) # add potentially created constraints

Expand Down Expand Up @@ -270,12 +270,12 @@ def transform(self, cpm_expr):
cpm_cons = toplevel_list(cpm_expr)
supported = {"min", "max", "abs", "alldifferent"} # alldiff has a specialized MIP decomp in linearize
cpm_cons = decompose_in_tree(cpm_cons, supported)
cpm_cons = flatten_constraint(cpm_cons) # flat normal form
cpm_cons = reify_rewrite(cpm_cons, supported=frozenset(['sum', 'wsum'])) # constraints that support reification
cpm_cons = only_numexpr_equality(cpm_cons, supported=frozenset(["sum", "wsum", "sub"])) # supports >, <, !=
cpm_cons = only_bv_implies(cpm_cons) # anything that can create full reif should go above...
cpm_cons = linearize_constraint(cpm_cons, supported=frozenset({"sum", "wsum","sub","min","max","mul","abs","pow","div"})) # the core of the MIP-linearization
cpm_cons = only_positive_bv(cpm_cons) # after linearization, rewrite ~bv into 1-bv
cpm_cons = flatten_constraint(cpm_cons,expr_dict=self.expr_dict) # flat normal form
cpm_cons = reify_rewrite(cpm_cons, supported=frozenset(['sum', 'wsum']),expr_dict=self.expr_dict) # constraints that support reification
cpm_cons = only_numexpr_equality(cpm_cons, supported=frozenset(["sum", "wsum", "sub"]),expr_dict=self.expr_dict) # supports >, <, !=
cpm_cons = only_bv_implies(cpm_cons,expr_dict=self.expr_dict) # anything that can create full reif should go above...
cpm_cons = linearize_constraint(cpm_cons, supported=frozenset({"sum", "wsum","sub","min","max","mul","abs","pow","div"}),expr_dict=self.expr_dict) # the core of the MIP-linearization
cpm_cons = only_positive_bv(cpm_cons,expr_dict=self.expr_dict) # after linearization, rewrite ~bv into 1-bv
return cpm_cons

def __add__(self, cpm_expr_orig):
Expand Down
10 changes: 5 additions & 5 deletions cpmpy/solvers/ortools.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ def objective(self, expr, minimize):
are premanently posted to the solver)
"""
# make objective function non-nested
(flat_obj, flat_cons) = flatten_objective(expr)
(flat_obj, flat_cons) = flatten_objective(expr,expr_dict=self.expr_dict)
self += flat_cons # add potentially created constraints
get_variables(flat_obj, collect=self.user_vars) # add objvars to vars

Expand Down Expand Up @@ -331,10 +331,10 @@ def transform(self, cpm_expr):
cpm_cons = toplevel_list(cpm_expr)
supported = {"min", "max", "abs", "element", "alldifferent", "xor", "table", "cumulative", "circuit", "inverse"}
cpm_cons = decompose_in_tree(cpm_cons, supported)
cpm_cons = flatten_constraint(cpm_cons) # flat normal form
cpm_cons = reify_rewrite(cpm_cons, supported=frozenset(['sum', 'wsum'])) # constraints that support reification
cpm_cons = only_numexpr_equality(cpm_cons, supported=frozenset(["sum", "wsum", "sub"])) # supports >, <, !=
cpm_cons = only_bv_implies(cpm_cons) # everything that can create
cpm_cons = flatten_constraint(cpm_cons,expr_dict=self.expr_dict) # flat normal form
cpm_cons = reify_rewrite(cpm_cons, supported=frozenset(['sum', 'wsum']),expr_dict=self.expr_dict) # constraints that support reification
cpm_cons = only_numexpr_equality(cpm_cons, supported=frozenset(["sum", "wsum", "sub"]),expr_dict=self.expr_dict) # supports >, <, !=
cpm_cons = only_bv_implies(cpm_cons,expr_dict=self.expr_dict) # everything that can create
# reified expr must go before this
return cpm_cons

Expand Down
4 changes: 2 additions & 2 deletions cpmpy/solvers/pysat.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,8 +229,8 @@ def transform(self, cpm_expr):
"""
cpm_cons = toplevel_list(cpm_expr)
cpm_cons = decompose_in_tree(cpm_cons)
cpm_cons = flatten_constraint(cpm_cons)
cpm_cons = only_bv_implies(cpm_cons)
cpm_cons = flatten_constraint(cpm_cons,expr_dict=self.expr_dict)
cpm_cons = only_bv_implies(cpm_cons,expr_dict=self.expr_dict)
return cpm_cons

def __add__(self, cpm_expr_orig):
Expand Down
1 change: 1 addition & 0 deletions cpmpy/solvers/solver_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def __init__(self, name="dummy", cpm_model=None, subsolver=None):
# initialise variable handling
self.user_vars = set() # variables in the original (non-transformed) model
self._varmap = dict() # maps cpmpy variables to native solver variables
self.expr_dict = dict() #maps expressions to cpmpy variables for cse purposes

# rest uses own API
if cpm_model is not None:
Expand Down
4 changes: 2 additions & 2 deletions cpmpy/transformations/comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
- only_numexpr_equality(): transforms `NumExpr <op> IV` to `(NumExpr == A) & (A <op> IV)` if not supported
"""

def only_numexpr_equality(constraints, supported=frozenset()):
def only_numexpr_equality(constraints, supported=frozenset(),expr_dict={}):
"""
transforms `NumExpr <op> IV` to `(NumExpr == A) & (A <op> IV)` if not supported
Expand All @@ -34,7 +34,7 @@ def only_numexpr_equality(constraints, supported=frozenset()):
lhs = con.args[0]
if not isinstance(lhs, _NumVarImpl) and not lhs.name in supported:
# LHS is unsupported for LHS <op> IV, rewrite to `(LHS == A) & (A <op> IV)`
(lhsvar, lhscons) = get_or_make_var(lhs)
(lhsvar, lhscons) = get_or_make_var(lhs,expr_dict=expr_dict)
# replace comparison by A <op> IV
newcons[i] = Comparison(con.name, lhsvar, con.args[1])
# add lhscon(s), which will be [(LHS == A)]
Expand Down
30 changes: 15 additions & 15 deletions cpmpy/transformations/linearize.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@

from ..expressions.variables import _BoolVarImpl, boolvar, NegBoolView, _NumVarImpl

def linearize_constraint(cpm_expr, supported={"sum","wsum"}, reified=False):
def linearize_constraint(cpm_expr, supported={"sum","wsum"}, reified=False, expr_dict={}):
"""
Transforms all constraints to a linear form.
This function assumes all constraints are in 'flat normal form' with only boolean variables on the lhs of an implication.
Expand All @@ -61,7 +61,7 @@ def linearize_constraint(cpm_expr, supported={"sum","wsum"}, reified=False):
"""

if is_any_list(cpm_expr):
lin_cons = [linearize_constraint(expr, supported=supported, reified=reified) for expr in cpm_expr]
lin_cons = [linearize_constraint(expr, supported=supported, reified=reified,expr_dict=expr_dict) for expr in cpm_expr]
return [c for l in lin_cons for c in l]

# boolvar
Expand Down Expand Up @@ -90,11 +90,11 @@ def linearize_constraint(cpm_expr, supported={"sum","wsum"}, reified=False):

if isinstance(cond, _BoolVarImpl) and isinstance(sub_expr, _BoolVarImpl):
# shortcut for BV -> BV, convert to disjunction and apply linearize on it
return linearize_constraint(cond <= sub_expr)
return linearize_constraint(cond <= sub_expr,expr_dict=expr_dict)

# BV -> LinExpr
if isinstance(cond, _BoolVarImpl):
lin_sub = linearize_constraint(sub_expr, supported=supported, reified=True)
lin_sub = linearize_constraint(sub_expr, supported=supported, reified=True,expr_dict=expr_dict)
return [cond.implies(lin) for lin in lin_sub]

# comparisons
Expand Down Expand Up @@ -153,11 +153,11 @@ def linearize_constraint(cpm_expr, supported={"sum","wsum"}, reified=False):

# now fix the comparisons themselves
if cpm_expr.name == "<":
new_rhs, cons = get_or_make_var(rhs - 1) # if rhs is constant, will return new constant
return [lhs <= new_rhs] + linearize_constraint(cons)
new_rhs, cons = get_or_make_var(rhs - 1,expr_dict=expr_dict) # if rhs is constant, will return new constant
return [lhs <= new_rhs] + linearize_constraint(cons,expr_dict=expr_dict)
if cpm_expr.name == ">":
new_rhs, cons = get_or_make_var(rhs + 1) # if rhs is constant, will return new constant
return [lhs >= new_rhs] + linearize_constraint(cons)
new_rhs, cons = get_or_make_var(rhs + 1,expr_dict=expr_dict) # if rhs is constant, will return new constant
return [lhs >= new_rhs] + linearize_constraint(cons,expr_dict=expr_dict)
if cpm_expr.name == "!=":
# Special case: BV != BV
if isinstance(lhs, _BoolVarImpl) and isinstance(rhs, _BoolVarImpl):
Expand All @@ -177,13 +177,13 @@ def linearize_constraint(cpm_expr, supported={"sum","wsum"}, reified=False):
_, M1 = (lhs - rhs + 1).get_bounds()
_, M2 = (rhs - lhs + 1).get_bounds()
cons = [lhs + -M1*z <= rhs-1, lhs + -M2*z >= rhs-M2+1]
return linearize_constraint(flatten_constraint(cons), supported=supported, reified=reified)
return linearize_constraint(flatten_constraint(cons), supported=supported, reified=reified,expr_dict=expr_dict)

else:
# introduce new indicator constraints
z = boolvar()
constraints = [z.implies(lhs < rhs), (~z).implies(lhs > rhs)]
return linearize_constraint(constraints, supported=supported, reified=reified)
return linearize_constraint(constraints, supported=supported, reified=reified,expr_dict=expr_dict)


return [Comparison(cpm_expr.name, lhs, rhs)]
Expand Down Expand Up @@ -215,15 +215,15 @@ def linearize_constraint(cpm_expr, supported={"sum","wsum"}, reified=False):
return [cpm_expr]


def only_positive_bv(cpm_expr):
def only_positive_bv(cpm_expr,expr_dict={}):
"""
Replaces constraints containing NegBoolView with equivalent expression using only BoolVar.
cpm_expr is expected to be linearized. Only apply after applying linearize_constraint(cpm_expr)
Resulting expression is linear.
"""
if is_any_list(cpm_expr):
nn_cons = [only_positive_bv(expr) for expr in cpm_expr]
nn_cons = [only_positive_bv(expr,expr_dict=expr_dict) for expr in cpm_expr]
return [c for l in nn_cons for c in l]

if isinstance(cpm_expr, Comparison):
Expand All @@ -249,18 +249,18 @@ def only_positive_bv(cpm_expr):
lhs = copy.copy(lhs)
for i,arg in enumerate(list(lhs.args)):
if isinstance(arg, NegBoolView):
new_arg, cons = get_or_make_var(1 - arg)
new_arg, cons = get_or_make_var(1 - arg,expr_dict=expr_dict)
lhs.args[i] = new_arg
new_cons += cons

return [Comparison(cpm_expr.name, lhs, rhs)] + linearize_constraint(new_cons)
return [Comparison(cpm_expr.name, lhs, rhs)] + linearize_constraint(new_cons,expr_dict=expr_dict)

# reification
if cpm_expr.name == "->":
cond, subexpr = cpm_expr.args
assert isinstance(cond, _BoolVarImpl), f"{cpm_expr} is not a supported linear expression. Apply `linearize_constraint` before calling `only_positive_bv`"
if isinstance(cond, _BoolVarImpl): # BV -> Expr
subexpr = only_positive_bv(subexpr)
subexpr = only_positive_bv(subexpr,expr_dict=expr_dict)
return[cond.implies(expr) for expr in subexpr]

if isinstance(cpm_expr, (GlobalConstraint, BoolVal, DirectConstraint)):
Expand Down
14 changes: 7 additions & 7 deletions cpmpy/transformations/reification.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
- reify_rewrite(): rewrites reifications not supported by a solver to ones that are
"""

def only_bv_implies(constraints):
def only_bv_implies(constraints,expr_dict={}):
"""
Transforms all reifications to BV -> BE form
Expand All @@ -44,15 +44,15 @@ def only_bv_implies(constraints):
# BE -> BV :: ~BV -> ~BE
newexpr = (~a1).implies(recurse_negation(a0))
#newexpr = (~a1).implies(~a0) # XXX when push_down_neg is separate, negated_normal no longer needed separately
newcons.extend(only_bv_implies(flatten_constraint(newexpr)))
newcons.extend(only_bv_implies(flatten_constraint(newexpr,expr_dict=expr_dict),expr_dict=expr_dict))
elif isinstance(a1, Comparison) and \
a1.name == '==' and a1.args[0].is_bool() and a1.args[1].is_bool():
# BV0 -> BV2 == BV3 :: BV0 -> (BV2->BV3 & BV3->BV2)
# :: BV0 -> (BV2->BV3) & BV0 -> (BV3->BV2)
# :: BV0 -> (~BV2|BV3) & BV0 -> (~BV3|BV2)
bv2,bv3 = a1.args
newexpr = [a0.implies(~bv2|bv3), a0.implies(~bv3|bv2)]
newcons.extend(only_bv_implies(flatten_constraint(newexpr)))
newcons.extend(only_bv_implies(flatten_constraint(newexpr,expr_dict=expr_dict),expr_dict=expr_dict))
else:
newcons.append(cpm_expr)

Expand All @@ -71,15 +71,15 @@ def only_bv_implies(constraints):
# BE0 == BVar1 :: ~BVar1 -> ~BE0, BVar1 -> BE0
newexprs = ((~a1).implies(recurse_negation(a0)), a1.implies(a0))
#newexprs = ((~a1).implies(~a0), a1.implies(a0)) # XXX when push_down_neg is separate, negated_normal no longer needed separately
newcons.extend(only_bv_implies(flatten_constraint(newexprs)))
newcons.extend(only_bv_implies(flatten_constraint(newexprs,expr_dict=expr_dict),expr_dict=expr_dict))
else:
# all other flat normal form expressions are fine
newcons.append(cpm_expr)

return newcons


def reify_rewrite(constraints, supported=frozenset()):
def reify_rewrite(constraints, supported=frozenset(),expr_dict={}):
"""
Rewrites reified constraints not natively supported by a solver,
to a version that uses standard constraints and reification over equalities between variables.
Expand Down Expand Up @@ -154,11 +154,11 @@ def reify_rewrite(constraints, supported=frozenset()):
# use IV < IV.lb which will be false...
decomp = (lhs.args[1] < lhs.args[1].lb)
reifexpr.args[boolexpr_index] = decomp
newcons += flatten_constraint(reifexpr)
newcons += flatten_constraint(reifexpr,expr_dict=expr_dict)
else: # other cases (assuming LHS is a total function):
# (AUX,c) = get_or_make_var(LHS)
# return c+[Comp(OP,AUX,RHS) == BV] or +[Comp(OP,AUX,RHS) -> BV] or +[Comp(OP,AUX,RHS) <- BV]
(auxvar, cons) = get_or_make_var(lhs)
(auxvar, cons) = get_or_make_var(lhs,expr_dict=expr_dict)
newcons += cons
reifexpr = copy.copy(cpm_expr)
reifexpr.args[boolexpr_index] = Comparison(op, auxvar, rhs) # Comp(OP,AUX,RHS)
Expand Down
6 changes: 3 additions & 3 deletions cpmpy/transformations/to_cnf.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,16 @@
- BV -> BE
"""

def to_cnf(constraints):
def to_cnf(constraints,expr_dict={}):
"""
Converts all logical constraints into Conjunctive Normal Form
Arguments:
- constraints: list[Expression] or Operator
- supported: (frozen)set of global constraint names that do not need to be decomposed
"""
fnf = flatten_constraint(constraints)
fnf = only_bv_implies(fnf)
fnf = flatten_constraint(constraints,expr_dict=expr_dict)
fnf = only_bv_implies(fnf,expr_dict=expr_dict)
return flat2cnf(fnf)

def flat2cnf(constraints):
Expand Down

0 comments on commit fc6ecd5

Please sign in to comment.