diff --git a/cpmpy/solvers/exact.py b/cpmpy/solvers/exact.py index a90d73c40..0e080f49c 100644 --- a/cpmpy/solvers/exact.py +++ b/cpmpy/solvers/exact.py @@ -322,7 +322,7 @@ def objective(self, expr, minimize): self.objective_minimize = minimize # make objective function non-nested - (flat_obj, flat_cons) = flatten_objective(expr) + (flat_obj, flat_cons) = flatten_objective(expr,expr_dict=self.expr_dict) self += flat_cons # add potentially created constraints self.user_vars.update(get_variables(flat_obj)) # add objvars to vars @@ -404,12 +404,12 @@ def transform(self, cpm_expr): # expressions have to be linearized to fit in MIP model. See /transformations/linearize cpm_cons = toplevel_list(cpm_expr) cpm_cons = decompose_in_tree(cpm_cons, supported=frozenset({'alldifferent'})) # Alldiff has a specialzed MIP decomp - cpm_cons = flatten_constraint(cpm_cons) # flat normal form - cpm_cons = reify_rewrite(cpm_cons, supported=frozenset(['sum', 'wsum'])) # constraints that support reification - cpm_cons = only_numexpr_equality(cpm_cons, supported=frozenset(["sum", "wsum"])) # supports >, <, != - cpm_cons = only_bv_implies(cpm_cons) # anything that can create full reif should go above... - cpm_cons = linearize_constraint(cpm_cons, supported=frozenset({"sum","wsum"})) # the core of the MIP-linearization - cpm_cons = only_positive_bv(cpm_cons) # after linearisation, rewrite ~bv into 1-bv + cpm_cons = flatten_constraint(cpm_cons,expr_dict=self.expr_dict) # flat normal form + cpm_cons = reify_rewrite(cpm_cons, supported=frozenset(['sum', 'wsum']),expr_dict=self.expr_dict) # constraints that support reification + cpm_cons = only_numexpr_equality(cpm_cons, supported=frozenset(["sum", "wsum"]),expr_dict=self.expr_dict) # supports >, <, != + cpm_cons = only_bv_implies(cpm_cons,expr_dict=self.expr_dict) # anything that can create full reif should go above... + cpm_cons = linearize_constraint(cpm_cons, supported=frozenset({"sum","wsum"}),expr_dict=self.expr_dict) # the core of the MIP-linearization + cpm_cons = only_positive_bv(cpm_cons,expr_dict=self.expr_dict) # after linearisation, rewrite ~bv into 1-bv return cpm_cons # NOTE: the transformations that are still done specifically for Exact are two-fold: diff --git a/cpmpy/solvers/gurobi.py b/cpmpy/solvers/gurobi.py index 59b0ac012..8c1cf5e1e 100644 --- a/cpmpy/solvers/gurobi.py +++ b/cpmpy/solvers/gurobi.py @@ -208,7 +208,7 @@ def objective(self, expr, minimize=True): from gurobipy import GRB # make objective function non-nested - (flat_obj, flat_cons) = (flatten_objective(expr)) + (flat_obj, flat_cons) = (flatten_objective(expr,expr_dict=self.expr_dict)) self += flat_cons get_variables(flat_obj, collect=self.user_vars) # add potentially created constraints @@ -270,12 +270,12 @@ def transform(self, cpm_expr): cpm_cons = toplevel_list(cpm_expr) supported = {"min", "max", "abs", "alldifferent"} # alldiff has a specialized MIP decomp in linearize cpm_cons = decompose_in_tree(cpm_cons, supported) - cpm_cons = flatten_constraint(cpm_cons) # flat normal form - cpm_cons = reify_rewrite(cpm_cons, supported=frozenset(['sum', 'wsum'])) # constraints that support reification - cpm_cons = only_numexpr_equality(cpm_cons, supported=frozenset(["sum", "wsum", "sub"])) # supports >, <, != - cpm_cons = only_bv_implies(cpm_cons) # anything that can create full reif should go above... - cpm_cons = linearize_constraint(cpm_cons, supported=frozenset({"sum", "wsum","sub","min","max","mul","abs","pow","div"})) # the core of the MIP-linearization - cpm_cons = only_positive_bv(cpm_cons) # after linearization, rewrite ~bv into 1-bv + cpm_cons = flatten_constraint(cpm_cons,expr_dict=self.expr_dict) # flat normal form + cpm_cons = reify_rewrite(cpm_cons, supported=frozenset(['sum', 'wsum']),expr_dict=self.expr_dict) # constraints that support reification + cpm_cons = only_numexpr_equality(cpm_cons, supported=frozenset(["sum", "wsum", "sub"]),expr_dict=self.expr_dict) # supports >, <, != + cpm_cons = only_bv_implies(cpm_cons,expr_dict=self.expr_dict) # anything that can create full reif should go above... + cpm_cons = linearize_constraint(cpm_cons, supported=frozenset({"sum", "wsum","sub","min","max","mul","abs","pow","div"}),expr_dict=self.expr_dict) # the core of the MIP-linearization + cpm_cons = only_positive_bv(cpm_cons,expr_dict=self.expr_dict) # after linearization, rewrite ~bv into 1-bv return cpm_cons def __add__(self, cpm_expr_orig): diff --git a/cpmpy/solvers/ortools.py b/cpmpy/solvers/ortools.py index 30a8af276..85a8b2df3 100644 --- a/cpmpy/solvers/ortools.py +++ b/cpmpy/solvers/ortools.py @@ -266,7 +266,7 @@ def objective(self, expr, minimize): are premanently posted to the solver) """ # make objective function non-nested - (flat_obj, flat_cons) = flatten_objective(expr) + (flat_obj, flat_cons) = flatten_objective(expr,expr_dict=self.expr_dict) self += flat_cons # add potentially created constraints get_variables(flat_obj, collect=self.user_vars) # add objvars to vars @@ -331,10 +331,10 @@ def transform(self, cpm_expr): cpm_cons = toplevel_list(cpm_expr) supported = {"min", "max", "abs", "element", "alldifferent", "xor", "table", "cumulative", "circuit", "inverse"} cpm_cons = decompose_in_tree(cpm_cons, supported) - cpm_cons = flatten_constraint(cpm_cons) # flat normal form - cpm_cons = reify_rewrite(cpm_cons, supported=frozenset(['sum', 'wsum'])) # constraints that support reification - cpm_cons = only_numexpr_equality(cpm_cons, supported=frozenset(["sum", "wsum", "sub"])) # supports >, <, != - cpm_cons = only_bv_implies(cpm_cons) # everything that can create + cpm_cons = flatten_constraint(cpm_cons,expr_dict=self.expr_dict) # flat normal form + cpm_cons = reify_rewrite(cpm_cons, supported=frozenset(['sum', 'wsum']),expr_dict=self.expr_dict) # constraints that support reification + cpm_cons = only_numexpr_equality(cpm_cons, supported=frozenset(["sum", "wsum", "sub"]),expr_dict=self.expr_dict) # supports >, <, != + cpm_cons = only_bv_implies(cpm_cons,expr_dict=self.expr_dict) # everything that can create # reified expr must go before this return cpm_cons diff --git a/cpmpy/solvers/pysat.py b/cpmpy/solvers/pysat.py index 3047881d6..feeca2880 100644 --- a/cpmpy/solvers/pysat.py +++ b/cpmpy/solvers/pysat.py @@ -229,8 +229,8 @@ def transform(self, cpm_expr): """ cpm_cons = toplevel_list(cpm_expr) cpm_cons = decompose_in_tree(cpm_cons) - cpm_cons = flatten_constraint(cpm_cons) - cpm_cons = only_bv_implies(cpm_cons) + cpm_cons = flatten_constraint(cpm_cons,expr_dict=self.expr_dict) + cpm_cons = only_bv_implies(cpm_cons,expr_dict=self.expr_dict) return cpm_cons def __add__(self, cpm_expr_orig): diff --git a/cpmpy/solvers/solver_interface.py b/cpmpy/solvers/solver_interface.py index 9e905fc17..65186d3b0 100644 --- a/cpmpy/solvers/solver_interface.py +++ b/cpmpy/solvers/solver_interface.py @@ -75,6 +75,7 @@ def __init__(self, name="dummy", cpm_model=None, subsolver=None): # initialise variable handling self.user_vars = set() # variables in the original (non-transformed) model self._varmap = dict() # maps cpmpy variables to native solver variables + self.expr_dict = dict() #maps expressions to cpmpy variables for cse purposes # rest uses own API if cpm_model is not None: diff --git a/cpmpy/transformations/comparison.py b/cpmpy/transformations/comparison.py index 6f4f40b91..20d9482db 100644 --- a/cpmpy/transformations/comparison.py +++ b/cpmpy/transformations/comparison.py @@ -18,7 +18,7 @@ - only_numexpr_equality(): transforms `NumExpr IV` to `(NumExpr == A) & (A IV)` if not supported """ -def only_numexpr_equality(constraints, supported=frozenset()): +def only_numexpr_equality(constraints, supported=frozenset(),expr_dict={}): """ transforms `NumExpr IV` to `(NumExpr == A) & (A IV)` if not supported @@ -34,7 +34,7 @@ def only_numexpr_equality(constraints, supported=frozenset()): lhs = con.args[0] if not isinstance(lhs, _NumVarImpl) and not lhs.name in supported: # LHS is unsupported for LHS IV, rewrite to `(LHS == A) & (A IV)` - (lhsvar, lhscons) = get_or_make_var(lhs) + (lhsvar, lhscons) = get_or_make_var(lhs,expr_dict=expr_dict) # replace comparison by A IV newcons[i] = Comparison(con.name, lhsvar, con.args[1]) # add lhscon(s), which will be [(LHS == A)] diff --git a/cpmpy/transformations/linearize.py b/cpmpy/transformations/linearize.py index 48a675a66..56396836f 100644 --- a/cpmpy/transformations/linearize.py +++ b/cpmpy/transformations/linearize.py @@ -49,7 +49,7 @@ from ..expressions.variables import _BoolVarImpl, boolvar, NegBoolView, _NumVarImpl -def linearize_constraint(cpm_expr, supported={"sum","wsum"}, reified=False): +def linearize_constraint(cpm_expr, supported={"sum","wsum"}, reified=False, expr_dict={}): """ Transforms all constraints to a linear form. This function assumes all constraints are in 'flat normal form' with only boolean variables on the lhs of an implication. @@ -61,7 +61,7 @@ def linearize_constraint(cpm_expr, supported={"sum","wsum"}, reified=False): """ if is_any_list(cpm_expr): - lin_cons = [linearize_constraint(expr, supported=supported, reified=reified) for expr in cpm_expr] + lin_cons = [linearize_constraint(expr, supported=supported, reified=reified,expr_dict=expr_dict) for expr in cpm_expr] return [c for l in lin_cons for c in l] # boolvar @@ -90,11 +90,11 @@ def linearize_constraint(cpm_expr, supported={"sum","wsum"}, reified=False): if isinstance(cond, _BoolVarImpl) and isinstance(sub_expr, _BoolVarImpl): # shortcut for BV -> BV, convert to disjunction and apply linearize on it - return linearize_constraint(cond <= sub_expr) + return linearize_constraint(cond <= sub_expr,expr_dict=expr_dict) # BV -> LinExpr if isinstance(cond, _BoolVarImpl): - lin_sub = linearize_constraint(sub_expr, supported=supported, reified=True) + lin_sub = linearize_constraint(sub_expr, supported=supported, reified=True,expr_dict=expr_dict) return [cond.implies(lin) for lin in lin_sub] # comparisons @@ -153,11 +153,11 @@ def linearize_constraint(cpm_expr, supported={"sum","wsum"}, reified=False): # now fix the comparisons themselves if cpm_expr.name == "<": - new_rhs, cons = get_or_make_var(rhs - 1) # if rhs is constant, will return new constant - return [lhs <= new_rhs] + linearize_constraint(cons) + new_rhs, cons = get_or_make_var(rhs - 1,expr_dict=expr_dict) # if rhs is constant, will return new constant + return [lhs <= new_rhs] + linearize_constraint(cons,expr_dict=expr_dict) if cpm_expr.name == ">": - new_rhs, cons = get_or_make_var(rhs + 1) # if rhs is constant, will return new constant - return [lhs >= new_rhs] + linearize_constraint(cons) + new_rhs, cons = get_or_make_var(rhs + 1,expr_dict=expr_dict) # if rhs is constant, will return new constant + return [lhs >= new_rhs] + linearize_constraint(cons,expr_dict=expr_dict) if cpm_expr.name == "!=": # Special case: BV != BV if isinstance(lhs, _BoolVarImpl) and isinstance(rhs, _BoolVarImpl): @@ -177,13 +177,13 @@ def linearize_constraint(cpm_expr, supported={"sum","wsum"}, reified=False): _, M1 = (lhs - rhs + 1).get_bounds() _, M2 = (rhs - lhs + 1).get_bounds() cons = [lhs + -M1*z <= rhs-1, lhs + -M2*z >= rhs-M2+1] - return linearize_constraint(flatten_constraint(cons), supported=supported, reified=reified) + return linearize_constraint(flatten_constraint(cons), supported=supported, reified=reified,expr_dict=expr_dict) else: # introduce new indicator constraints z = boolvar() constraints = [z.implies(lhs < rhs), (~z).implies(lhs > rhs)] - return linearize_constraint(constraints, supported=supported, reified=reified) + return linearize_constraint(constraints, supported=supported, reified=reified,expr_dict=expr_dict) return [Comparison(cpm_expr.name, lhs, rhs)] @@ -215,7 +215,7 @@ def linearize_constraint(cpm_expr, supported={"sum","wsum"}, reified=False): return [cpm_expr] -def only_positive_bv(cpm_expr): +def only_positive_bv(cpm_expr,expr_dict={}): """ Replaces constraints containing NegBoolView with equivalent expression using only BoolVar. cpm_expr is expected to be linearized. Only apply after applying linearize_constraint(cpm_expr) @@ -223,7 +223,7 @@ def only_positive_bv(cpm_expr): Resulting expression is linear. """ if is_any_list(cpm_expr): - nn_cons = [only_positive_bv(expr) for expr in cpm_expr] + nn_cons = [only_positive_bv(expr,expr_dict=expr_dict) for expr in cpm_expr] return [c for l in nn_cons for c in l] if isinstance(cpm_expr, Comparison): @@ -249,18 +249,18 @@ def only_positive_bv(cpm_expr): lhs = copy.copy(lhs) for i,arg in enumerate(list(lhs.args)): if isinstance(arg, NegBoolView): - new_arg, cons = get_or_make_var(1 - arg) + new_arg, cons = get_or_make_var(1 - arg,expr_dict=expr_dict) lhs.args[i] = new_arg new_cons += cons - return [Comparison(cpm_expr.name, lhs, rhs)] + linearize_constraint(new_cons) + return [Comparison(cpm_expr.name, lhs, rhs)] + linearize_constraint(new_cons,expr_dict=expr_dict) # reification if cpm_expr.name == "->": cond, subexpr = cpm_expr.args assert isinstance(cond, _BoolVarImpl), f"{cpm_expr} is not a supported linear expression. Apply `linearize_constraint` before calling `only_positive_bv`" if isinstance(cond, _BoolVarImpl): # BV -> Expr - subexpr = only_positive_bv(subexpr) + subexpr = only_positive_bv(subexpr,expr_dict=expr_dict) return[cond.implies(expr) for expr in subexpr] if isinstance(cpm_expr, (GlobalConstraint, BoolVal, DirectConstraint)): diff --git a/cpmpy/transformations/reification.py b/cpmpy/transformations/reification.py index 57820e71e..142eb9fac 100644 --- a/cpmpy/transformations/reification.py +++ b/cpmpy/transformations/reification.py @@ -23,7 +23,7 @@ - reify_rewrite(): rewrites reifications not supported by a solver to ones that are """ -def only_bv_implies(constraints): +def only_bv_implies(constraints,expr_dict={}): """ Transforms all reifications to BV -> BE form @@ -44,7 +44,7 @@ def only_bv_implies(constraints): # BE -> BV :: ~BV -> ~BE newexpr = (~a1).implies(recurse_negation(a0)) #newexpr = (~a1).implies(~a0) # XXX when push_down_neg is separate, negated_normal no longer needed separately - newcons.extend(only_bv_implies(flatten_constraint(newexpr))) + newcons.extend(only_bv_implies(flatten_constraint(newexpr,expr_dict=expr_dict),expr_dict=expr_dict)) elif isinstance(a1, Comparison) and \ a1.name == '==' and a1.args[0].is_bool() and a1.args[1].is_bool(): # BV0 -> BV2 == BV3 :: BV0 -> (BV2->BV3 & BV3->BV2) @@ -52,7 +52,7 @@ def only_bv_implies(constraints): # :: BV0 -> (~BV2|BV3) & BV0 -> (~BV3|BV2) bv2,bv3 = a1.args newexpr = [a0.implies(~bv2|bv3), a0.implies(~bv3|bv2)] - newcons.extend(only_bv_implies(flatten_constraint(newexpr))) + newcons.extend(only_bv_implies(flatten_constraint(newexpr,expr_dict=expr_dict),expr_dict=expr_dict)) else: newcons.append(cpm_expr) @@ -71,7 +71,7 @@ def only_bv_implies(constraints): # BE0 == BVar1 :: ~BVar1 -> ~BE0, BVar1 -> BE0 newexprs = ((~a1).implies(recurse_negation(a0)), a1.implies(a0)) #newexprs = ((~a1).implies(~a0), a1.implies(a0)) # XXX when push_down_neg is separate, negated_normal no longer needed separately - newcons.extend(only_bv_implies(flatten_constraint(newexprs))) + newcons.extend(only_bv_implies(flatten_constraint(newexprs,expr_dict=expr_dict),expr_dict=expr_dict)) else: # all other flat normal form expressions are fine newcons.append(cpm_expr) @@ -79,7 +79,7 @@ def only_bv_implies(constraints): return newcons -def reify_rewrite(constraints, supported=frozenset()): +def reify_rewrite(constraints, supported=frozenset(),expr_dict={}): """ Rewrites reified constraints not natively supported by a solver, to a version that uses standard constraints and reification over equalities between variables. @@ -154,11 +154,11 @@ def reify_rewrite(constraints, supported=frozenset()): # use IV < IV.lb which will be false... decomp = (lhs.args[1] < lhs.args[1].lb) reifexpr.args[boolexpr_index] = decomp - newcons += flatten_constraint(reifexpr) + newcons += flatten_constraint(reifexpr,expr_dict=expr_dict) else: # other cases (assuming LHS is a total function): # (AUX,c) = get_or_make_var(LHS) # return c+[Comp(OP,AUX,RHS) == BV] or +[Comp(OP,AUX,RHS) -> BV] or +[Comp(OP,AUX,RHS) <- BV] - (auxvar, cons) = get_or_make_var(lhs) + (auxvar, cons) = get_or_make_var(lhs,expr_dict=expr_dict) newcons += cons reifexpr = copy.copy(cpm_expr) reifexpr.args[boolexpr_index] = Comparison(op, auxvar, rhs) # Comp(OP,AUX,RHS) diff --git a/cpmpy/transformations/to_cnf.py b/cpmpy/transformations/to_cnf.py index 4db6ce6c4..29d0bc4c7 100644 --- a/cpmpy/transformations/to_cnf.py +++ b/cpmpy/transformations/to_cnf.py @@ -23,7 +23,7 @@ - BV -> BE """ -def to_cnf(constraints): +def to_cnf(constraints,expr_dict={}): """ Converts all logical constraints into Conjunctive Normal Form @@ -31,8 +31,8 @@ def to_cnf(constraints): - constraints: list[Expression] or Operator - supported: (frozen)set of global constraint names that do not need to be decomposed """ - fnf = flatten_constraint(constraints) - fnf = only_bv_implies(fnf) + fnf = flatten_constraint(constraints,expr_dict=expr_dict) + fnf = only_bv_implies(fnf,expr_dict=expr_dict) return flat2cnf(fnf) def flat2cnf(constraints):