diff --git a/cpmpy/expressions/core.py b/cpmpy/expressions/core.py index 1b210e680..b2a01ba62 100644 --- a/cpmpy/expressions/core.py +++ b/cpmpy/expressions/core.py @@ -72,7 +72,7 @@ import numpy as np -from .utils import is_num, is_any_list, flatlist, argval, get_bounds, is_boolexpr, is_true_cst, is_false_cst +from .utils import is_num, is_any_list, flatlist, argval, get_bounds, is_boolexpr, is_true_cst, is_false_cst, argvals from ..exceptions import IncompleteFunctionError, TypeError @@ -144,6 +144,7 @@ def is_bool(self): def value(self): return None # default + def get_bounds(self): if self.is_bool(): return 0, 1 #default for boolean expressions @@ -400,7 +401,8 @@ def __repr__(self): # return the value of the expression # optional, default: None def value(self): - arg_vals = [argval(a) for a in self.args] + arg_vals = argvals(self.args) + if any(a is None for a in arg_vals): return None if self.name == "==": return arg_vals[0] == arg_vals[1] elif self.name == "!=": return arg_vals[0] != arg_vals[1] @@ -526,11 +528,12 @@ def wrap_bracket(arg): return "{}({})".format(self.name, self.args) def value(self): + if self.name == "wsum": # wsum: arg0 is list of constants, no .value() use as is - arg_vals = [self.args[0], [argval(arg) for arg in self.args[1]]] + arg_vals = [self.args[0], argvals(self.args[1])] else: - arg_vals = [argval(arg) for arg in self.args] + arg_vals = argvals(self.args) if any(a is None for a in arg_vals): return None @@ -546,7 +549,8 @@ def value(self): try: return arg_vals[0] // arg_vals[1] except ZeroDivisionError: - raise IncompleteFunctionError(f"Division by zero during value computation for expression {self}") + raise IncompleteFunctionError(f"Division by zero during value computation for expression {self}" + + "\n Use argval(expr) to get the value of expr with relational semantics.") # boolean elif self.name == "and": return all(arg_vals) diff --git a/cpmpy/expressions/globalconstraints.py b/cpmpy/expressions/globalconstraints.py index 9c6fe3933..d193567e9 100644 --- a/cpmpy/expressions/globalconstraints.py +++ b/cpmpy/expressions/globalconstraints.py @@ -120,7 +120,7 @@ def my_circuit_decomp(self): from ..exceptions import CPMpyException, IncompleteFunctionError, TypeError from .core import Expression, Operator, Comparison from .variables import boolvar, intvar, cpm_array, _NumVarImpl, _IntVarImpl -from .utils import flatlist, all_pairs, argval, is_num, eval_comparison, is_any_list, is_boolexpr, get_bounds +from .utils import flatlist, all_pairs, argval, is_num, eval_comparison, is_any_list, is_boolexpr, get_bounds, argvals from .globalfunctions import * # XXX make this file backwards compatible @@ -178,7 +178,7 @@ def decompose(self): return [var1 != var2 for var1, var2 in all_pairs(self.args)], [] def value(self): - return len(set(a.value() for a in self.args)) == len(self.args) + return len(set(argvals(self.args))) == len(self.args) class AllDifferentExcept0(GlobalConstraint): @@ -193,10 +193,9 @@ def decompose(self): return [(var1 == var2).implies(var1 == 0) for var1, var2 in all_pairs(self.args)], [] def value(self): - vals = [a.value() for a in self.args if a.value() != 0] + vals = [argval(a) for a in self.args if argval(a) != 0] return len(set(vals)) == len(vals) - def allequal(args): warnings.warn("Deprecated, use AllEqual(v1,v2,...,vn) instead, will be removed in stable version", DeprecationWarning) return AllEqual(*args) # unfold list as individual arguments @@ -215,8 +214,7 @@ def decompose(self): return [var1 == var2 for var1, var2 in zip(self.args[:-1], self.args[1:])], [] def value(self): - return len(set(a.value() for a in self.args)) == 1 - + return len(set(argvals(self.args))) == 1 def circuit(args): warnings.warn("Deprecated, use Circuit(v1,v2,...,vn) instead, will be removed in stable version", DeprecationWarning) @@ -259,7 +257,8 @@ def value(self): pathlen = 0 idx = 0 visited = set() - arr = [argval(a) for a in self.args] + arr = argvals(self.args) + while idx not in visited: if idx is None: return False @@ -294,9 +293,13 @@ def decompose(self): return [all(rev[x] == i for i, x in enumerate(fwd))], [] def value(self): - fwd = [argval(a) for a in self.args[0]] - rev = [argval(a) for a in self.args[1]] - return all(rev[x] == i for i, x in enumerate(fwd)) + fwd = argvals(self.args[0]) + rev = argvals(self.args[1]) + # args are fine, now evaluate actual inverse cons + try: + return all(rev[x] == i for i, x in enumerate(fwd)) + except IndexError: # partiality of Element constraint + return False class Table(GlobalConstraint): @@ -315,10 +318,11 @@ def decompose(self): def value(self): arr, tab = self.args - arrval = [argval(a) for a in arr] + arrval = argvals(arr) return arrval in tab + # syntax of the form 'if b then x == 9 else x == 0' is not supported (no override possible) # same semantic as CPLEX IfThenElse constraint # https://www.ibm.com/docs/en/icos/12.9.0?topic=methods-ifthenelse-method @@ -379,7 +383,7 @@ def decompose(self): def value(self): - return argval(self.args[0]) in argval(self.args[1]) + return argval(self.args[0]) in argvals(self.args[1]) def __repr__(self): return "{} in {}".format(self.args[0], self.args[1]) @@ -409,7 +413,7 @@ def decompose(self): return decomp, [] def value(self): - return sum(argval(a) for a in self.args) % 2 == 1 + return sum(argvals(self.args)) % 2 == 1 def __repr__(self): if len(self.args) == 2: @@ -478,14 +482,14 @@ def decompose(self): return cons, [] def value(self): - argvals = [np.array([argval(a) for a in arg]) if is_any_list(arg) + arg_vals = [np.array(argvals(arg)) if is_any_list(arg) else argval(arg) for arg in self.args] - if any(a is None for a in argvals): + if any(a is None for a in arg_vals): return None # start, dur, end are np arrays - start, dur, end, demand, capacity = argvals + start, dur, end, demand, capacity = arg_vals # start and end seperated by duration if not (start + dur == end).all(): return False diff --git a/cpmpy/expressions/globalfunctions.py b/cpmpy/expressions/globalfunctions.py index 1562ee5fa..60cd15c74 100644 --- a/cpmpy/expressions/globalfunctions.py +++ b/cpmpy/expressions/globalfunctions.py @@ -257,7 +257,8 @@ def value(self): if idxval is not None: if idxval >= 0 and idxval < len(arr): return argval(arr[idxval]) - raise IncompleteFunctionError(f"Index {idxval} out of range for array of length {len(arr)} while calculating value for expression {self}") + raise IncompleteFunctionError(f"Index {idxval} out of range for array of length {len(arr)} while calculating value for expression {self}" + + "\n Use argval(expr) to get the value of expr with relational semantics.") return None # default def decompose_comparison(self, cpm_op, cpm_rhs): diff --git a/cpmpy/expressions/utils.py b/cpmpy/expressions/utils.py index a42e88c82..e5a78ed95 100644 --- a/cpmpy/expressions/utils.py +++ b/cpmpy/expressions/utils.py @@ -121,11 +121,19 @@ def argval(a): We check with hasattr instead of isinstance to avoid circular dependency """ - try: - return a.value() if hasattr(a, "value") else a - except IncompleteFunctionError as e: - if a.is_bool(): return False - raise e + if hasattr(a, "value"): + try: + return a.value() + except IncompleteFunctionError as e: + if a.is_bool(): + return False + else: + raise e + return a + + +def argvals(arr): + return [argval(a) for a in arr] def eval_comparison(str_op, lhs, rhs): diff --git a/cpmpy/solvers/choco.py b/cpmpy/solvers/choco.py index d2d502794..3869df8be 100644 --- a/cpmpy/solvers/choco.py +++ b/cpmpy/solvers/choco.py @@ -29,7 +29,7 @@ from ..expressions.globalconstraints import DirectConstraint from ..expressions.variables import _NumVarImpl, _IntVarImpl, _BoolVarImpl, NegBoolView, intvar from ..expressions.globalconstraints import GlobalConstraint -from ..expressions.utils import is_num, is_int, is_boolexpr, is_any_list, get_bounds +from ..expressions.utils import is_num, is_int, is_boolexpr, is_any_list, get_bounds, argval, argvals from ..transformations.decompose_global import decompose_in_tree from ..transformations.get_variables import get_variables from ..transformations.flatten_model import flatten_constraint, flatten_objective @@ -208,9 +208,9 @@ def solveAll(self, display=None, time_limit=None, solution_limit=None, call_from cpm_var._value = value # print the desired display if isinstance(display, Expression): - print(display.value()) + print(argval(display)) elif isinstance(display, list): - print([v.value() for v in display]) + print(argvals(display)) else: display() # callback @@ -349,6 +349,7 @@ def __add__(self, cpm_expr): """ # add new user vars to the set get_variables(cpm_expr, collect=self.user_vars) + # ensure all vars are known to solver # transform and post the constraints for con in self.transform(cpm_expr): diff --git a/cpmpy/solvers/exact.py b/cpmpy/solvers/exact.py index 2d1c37ffd..5f0598cfb 100644 --- a/cpmpy/solvers/exact.py +++ b/cpmpy/solvers/exact.py @@ -37,7 +37,7 @@ from ..transformations.normalize import toplevel_list from ..expressions.globalconstraints import DirectConstraint from ..exceptions import NotSupportedError -from ..expressions.utils import flatlist +from ..expressions.utils import flatlist, argvals import numpy as np import numbers @@ -263,9 +263,9 @@ def solveAll(self, display=None, time_limit=None, solution_limit=None, call_from if display is not None: self._fillObjAndVars() if isinstance(display, Expression): - print(display.value()) + print(argval(display)) elif isinstance(display, list): - print([v.value() for v in display]) + print(argvals(display)) else: display() # callback elif my_status == 2: # found inconsistency diff --git a/cpmpy/solvers/gurobi.py b/cpmpy/solvers/gurobi.py index 7223ad33f..dbe3618bf 100644 --- a/cpmpy/solvers/gurobi.py +++ b/cpmpy/solvers/gurobi.py @@ -28,6 +28,7 @@ from .solver_interface import SolverInterface, SolverStatus, ExitStatus from ..expressions.core import * +from ..expressions.utils import argvals from ..expressions.variables import _BoolVarImpl, NegBoolView, _IntVarImpl, _NumVarImpl, intvar from ..expressions.globalconstraints import DirectConstraint from ..transformations.comparison import only_numexpr_equality @@ -461,9 +462,9 @@ def solveAll(self, display=None, time_limit=None, solution_limit=None, call_from if display is not None: if isinstance(display, Expression): - print(display.value()) + print(argval(display)) elif isinstance(display, list): - print([v.value() for v in display]) + print(argvals(display)) else: display() # callback diff --git a/cpmpy/solvers/minizinc.py b/cpmpy/solvers/minizinc.py index 0bb6bd9fd..d7335cf7c 100644 --- a/cpmpy/solvers/minizinc.py +++ b/cpmpy/solvers/minizinc.py @@ -36,7 +36,7 @@ from ..expressions.core import Expression, Comparison, Operator, BoolVal from ..expressions.variables import _NumVarImpl, _IntVarImpl, _BoolVarImpl, NegBoolView, intvar from ..expressions.globalconstraints import DirectConstraint -from ..expressions.utils import is_num, is_any_list, eval_comparison +from ..expressions.utils import is_num, is_any_list, eval_comparison, argvals, argval from ..transformations.decompose_global import decompose_in_tree from ..transformations.get_variables import get_variables from ..exceptions import MinizincPathException, NotSupportedError @@ -323,9 +323,9 @@ async def _solveAll(self, display=None, time_limit=None, solution_limit=None, ** # and the actual displaying if isinstance(display, Expression): - print(display.value()) + print(argval(display)) elif isinstance(display, list): - print([v.value() for v in display]) + print(argvals(display)) else: display() # callback diff --git a/cpmpy/solvers/ortools.py b/cpmpy/solvers/ortools.py index 175440770..c70dcbe8f 100644 --- a/cpmpy/solvers/ortools.py +++ b/cpmpy/solvers/ortools.py @@ -32,7 +32,7 @@ from ..expressions.globalconstraints import DirectConstraint from ..expressions.variables import _NumVarImpl, _IntVarImpl, _BoolVarImpl, NegBoolView, boolvar from ..expressions.globalconstraints import GlobalConstraint -from ..expressions.utils import is_num, is_any_list, eval_comparison, flatlist +from ..expressions.utils import is_num, is_any_list, eval_comparison, flatlist, argval, argvals from ..transformations.decompose_global import decompose_in_tree from ..transformations.get_variables import get_variables from ..transformations.flatten_model import flatten_constraint, flatten_objective @@ -689,10 +689,10 @@ def on_solution_callback(self): cpm_var._value = self.Value(self._varmap[cpm_var]) if isinstance(self._display, Expression): - print(self._display.value()) + print(argval(self._display)) elif isinstance(self._display, list): # explicit list of expressions to display - print([v.value() for v in self._display]) + print(argvals(self._display)) else: # callable self._display() diff --git a/cpmpy/solvers/pysat.py b/cpmpy/solvers/pysat.py index 32c0b1edf..889e48aa8 100644 --- a/cpmpy/solvers/pysat.py +++ b/cpmpy/solvers/pysat.py @@ -38,7 +38,7 @@ from ..transformations.decompose_global import decompose_in_tree from ..transformations.get_variables import get_variables from ..transformations.flatten_model import flatten_constraint -from ..transformations.normalize import toplevel_list +from ..transformations.normalize import toplevel_list, simplify_boolean from ..transformations.reification import only_implies, only_bv_reifies @@ -233,6 +233,7 @@ def transform(self, cpm_expr): """ cpm_cons = toplevel_list(cpm_expr) cpm_cons = decompose_in_tree(cpm_cons) + cpm_cons = simplify_boolean(cpm_cons) cpm_cons = flatten_constraint(cpm_cons) cpm_cons = only_bv_reifies(cpm_cons) cpm_cons = only_implies(cpm_cons) diff --git a/cpmpy/solvers/pysdd.py b/cpmpy/solvers/pysdd.py index 03003256d..a00800121 100644 --- a/cpmpy/solvers/pysdd.py +++ b/cpmpy/solvers/pysdd.py @@ -26,7 +26,7 @@ from ..expressions.core import Expression, Comparison, Operator, BoolVal from ..expressions.variables import _BoolVarImpl, NegBoolView, boolvar from ..expressions.globalconstraints import DirectConstraint -from ..expressions.utils import is_any_list, is_bool +from ..expressions.utils import is_any_list, is_bool, argval, argvals from ..transformations.decompose_global import decompose_in_tree from ..transformations.get_variables import get_variables from ..transformations.normalize import toplevel_list, simplify_boolean @@ -121,7 +121,8 @@ def solve(self, time_limit=None, assumptions=None): if lit in sol: cpm_var._value = bool(sol[lit]) else: - raise ValueError(f"Var {cpm_var} is unknown to the PySDD solver, this is unexpected - please report on github...") + cpm_var._value = cpm_var.get_bounds()[0] # dummy value - TODO: ensure Pysdd assigns an actual value + # cpm_var._value = None # not specified... return has_sol @@ -177,9 +178,9 @@ def solveAll(self, display=None, time_limit=None, solution_limit=None, call_from # display is not None: if isinstance(display, Expression): - print(display.value()) + print(argval(display)) elif isinstance(display, list): - print([v.value() for v in display]) + print(argvals(display)) else: display() # callback return solution_count diff --git a/cpmpy/transformations/normalize.py b/cpmpy/transformations/normalize.py index ec0c5e756..c69ac77cb 100644 --- a/cpmpy/transformations/normalize.py +++ b/cpmpy/transformations/normalize.py @@ -132,15 +132,15 @@ def simplify_boolean(lst_of_expr, num_context=False): if name == "==" or name == "<=": newlist.append(recurse_negation(lhs)) if name == "<": - newlist.append(BoolVal(False)) + newlist.append(0 if num_context else BoolVal(False)) if name == ">=": - newlist.append(BoolVal(True)) + newlist.append(1 if num_context else BoolVal(True)) elif 0 < rhs < 1: # a floating point value if name == "==": - newlist.append(BoolVal(False)) + newlist.append(0 if num_context else BoolVal(False)) if name == "!=": - newlist.append(BoolVal(True)) + newlist.append(1 if num_context else BoolVal(True)) if name == "<" or name == "<=": newlist.append(recurse_negation(lhs)) if name == ">" or name == ">=": @@ -151,9 +151,9 @@ def simplify_boolean(lst_of_expr, num_context=False): if name == "!=" or name == "<": newlist.append(recurse_negation(lhs)) if name == ">": - newlist.append(BoolVal(False)) + newlist.append(0 if num_context else BoolVal(False)) if name == "<=": - newlist.append(BoolVal(True)) + newlist.append(1 if num_context else BoolVal(True)) elif rhs > 1: newlist.append(BoolVal(name in {"!=", "<", "<="})) # all other operators evaluate to False else: diff --git a/tests/test_constraints.py b/tests/test_constraints.py index 062d727ca..b777ce4e7 100644 --- a/tests/test_constraints.py +++ b/tests/test_constraints.py @@ -1,3 +1,6 @@ +import inspect + +import cpmpy from cpmpy import Model, SolverLookup, BoolVal from cpmpy.expressions.globalconstraints import * from cpmpy.expressions.globalfunctions import * @@ -8,16 +11,27 @@ # make sure that `SolverLookup.get(solver)` works # also add exclusions to the 3 EXCLUDE_* below as needed SOLVERNAMES = [name for name, solver in SolverLookup.base_solvers() if solver.supported()] +ALL_SOLS = False # test wheter all solutions returned by the solver satisfy the constraint # Exclude some global constraints for solvers -# Can be used when .value() method is not implemented/contains bugs -EXCLUDE_GLOBAL = {"ortools": {}, - "gurobi": {}, - "minizinc": {"circuit"}, - "pysat": {"circuit", "element","min","max","count", "nvalue", "allequal","alldifferent","cumulative","increasing","decreasing","strictly_increasing","strictly_decreasing"}, - "pysdd": {"circuit", "element","min","max","count", "nvalue", "allequal","alldifferent","cumulative","xor","increasing","decreasing","strictly_increasing","strictly_decreasing"}, - "exact": {}, - "choco": {} + +NUM_GLOBAL = { + "AllEqual", "AllDifferent", "AllDifferentExcept0", "Cumulative", "GlobalCardinalityCount", "InDomain", "Inverse", "Table", "Circuit", + "Increasing", "IncreasingStrict", "Decreasing", "DecreasingStrict", + # also global functions + "Abs", "Element", "Minimum", "Maximum", "Count", "NValue", "NValueExcept" +} + +# Solvers not supporting arithmetic constraints +SAT_SOLVERS = {"pysat", "pysdd"} + +EXCLUDE_GLOBAL = {"pysat": NUM_GLOBAL, + "pysdd": NUM_GLOBAL | {"Xor"}, + "z3": {"Inverse"}, + "choco": {"Inverse"}, + "ortools":{"Inverse"}, + "exact": {"Inverse"}, + "minizinc": {"IncreasingStrict"} # bug #813 reported on libminizinc } # Exclude certain operators for solvers. @@ -28,16 +42,6 @@ "exact": {"mod","pow","div","mul"}, } -# Some solvers only support a subset of operators in imply-constraints -# This subset can differ between left and right hand side of the implication -EXCLUDE_IMPL = {"ortools": {}, - "minizinc": {}, - "z3": {}, - "pysat": {}, - "pysdd": {}, - "exact": {"mod","pow","div","mul"}, - } - # Variables to use in the rest of the test script NUM_ARGS = [intvar(-3, 5, name=n) for n in "xyz"] # Numerical variables NN_VAR = intvar(0, 10, name="n_neg") # Non-negative variable, needed in power functions @@ -60,8 +64,7 @@ def numexprs(solver): Numexpr: - Operator (non-Boolean) with all args Var/constant (examples: +,*,/,mod,wsum) (CPMpy class 'Operator', not is_bool()) - - Global constraint (non-Boolean) (examples: Max,Min,Element) - (CPMpy class 'GlobalConstraint', not is_bool())) + - Global functions (examples: Max,Min,Element) (CPMpy class 'GlobalFunction') """ names = [(name, arity) for name, (arity, is_bool) in Operator.allowed.items() if not is_bool] if solver in EXCLUDE_OPERATORS: @@ -80,6 +83,33 @@ def numexprs(solver): yield Operator(name, operator_args) + # boolexprs are also numeric + for expr in bool_exprs(solver): + yield expr + + # also global functions + classes = inspect.getmembers(cpmpy.expressions.globalfunctions, inspect.isclass) + classes = [(name, cls) for name, cls in classes if issubclass(cls, GlobalFunction) and name != "GlobalFunction"] + classes = [(name, cls) for name, cls in classes if name not in EXCLUDE_GLOBAL.get(solver, {})] + + for name, cls in classes: + if name == "Abs": + expr = cls(NUM_ARGS[0]) + elif name == "Count": + expr = cls(NUM_ARGS, NUM_VAR) + elif name == "Element": + expr = cls(NUM_ARGS, POS_VAR) + elif name == "NValueExcept": + expr = cls(NUM_ARGS, 3) + else: + expr = cls(NUM_ARGS) + + if solver in EXCLUDE_GLOBAL and expr.name in EXCLUDE_GLOBAL[solver]: + continue + else: + yield expr + + # Generate all possible comparison constraints def comp_constraints(solver): @@ -92,30 +122,19 @@ def comp_constraints(solver): - Numeric inequality (>=,>,<,<=): Numexpr >=< Var (CPMpy class 'Comparison') """ for comp_name in Comparison.allowed: + for numexpr in numexprs(solver): - for rhs in [NUM_VAR, BOOL_VAR, 1, BoolVal(True)]: + # numeric vs bool/num var/val (incl global func) + lb, ub = get_bounds(numexpr) + for rhs in [NUM_VAR, BOOL_VAR, BoolVal(True), 1]: + if solver in SAT_SOLVERS and not is_num(rhs): + continue + if comp_name == ">" and ub <= get_bounds(rhs)[1]: + continue + if comp_name == "<" and lb >= get_bounds(rhs)[0]: + continue yield Comparison(comp_name, numexpr, rhs) - for comp_name in Comparison.allowed: - for glob_expr in global_constraints(solver): - if not glob_expr.is_bool(): - for rhs in [NUM_VAR, BOOL_VAR, 1, BoolVal(True)]: - if comp_name == "<" and get_bounds(glob_expr)[0] >= get_bounds(rhs)[1]: - continue - yield Comparison(comp_name, glob_expr, rhs) - - if solver == "z3": - for comp_name in Comparison.allowed: - for boolexpr in bool_exprs(solver): - for rhs in [NUM_VAR, BOOL_VAR, 1, BoolVal(True)]: - if comp_name == '>': - # >1 is unsat for boolean expressions, so change it to 0 - if isinstance(rhs, int) and rhs == 1: - rhs = 0 - if isinstance(rhs, BoolVal) and rhs.args[0] == True: - rhs = BoolVal(False) - yield Comparison(comp_name, boolexpr, rhs) - # Generate all possible boolean expressions def bool_exprs(solver): @@ -123,6 +142,7 @@ def bool_exprs(solver): Generate all boolean expressions: - Boolean operators: and([Var]), or([Var]) (CPMpy class 'Operator', is_bool()) - Boolean equality: Var == Var (CPMpy class 'Comparison') + - Global constraints """ names = [(name, arity) for name, (arity, is_bool) in Operator.allowed.items() if is_bool] @@ -143,8 +163,7 @@ def bool_exprs(solver): yield Comparison(eq_name, *BOOL_ARGS[:2]) for cpm_cons in global_constraints(solver): - if cpm_cons.is_bool(): - yield cpm_cons + yield cpm_cons def global_constraints(solver): """ @@ -152,30 +171,40 @@ def global_constraints(solver): - AllDifferent, AllEqual, Circuit, Minimum, Maximum, Element, Xor, Cumulative, NValue, Count """ - global_cons = [AllDifferent, AllEqual, Minimum, Maximum, NValue, Increasing, Decreasing, IncreasingStrict, DecreasingStrict] - for global_type in global_cons: - cons = global_type(NUM_ARGS) - if solver not in EXCLUDE_GLOBAL or cons.name not in EXCLUDE_GLOBAL[solver]: - yield cons - - # "special" constructors - if solver not in EXCLUDE_GLOBAL or "element" not in EXCLUDE_GLOBAL[solver]: - yield cpm_array(NUM_ARGS)[NUM_VAR] - - if solver not in EXCLUDE_GLOBAL or "xor" not in EXCLUDE_GLOBAL[solver]: - yield Xor(BOOL_ARGS) - - if solver not in EXCLUDE_GLOBAL or "count" not in EXCLUDE_GLOBAL[solver]: - yield Count(NUM_ARGS, NUM_VAR) - - if solver not in EXCLUDE_GLOBAL or "cumulative" not in EXCLUDE_GLOBAL[solver]: - s = intvar(0,10,shape=3,name="start") - e = intvar(0,10,shape=3,name="end") - dur = [1,4,3] - demand = [4,5,7] - cap = 10 - yield Cumulative(s, dur, e, demand, cap) - + classes = inspect.getmembers(cpmpy.expressions.globalconstraints, inspect.isclass) + classes = [(name, cls) for name, cls in classes if issubclass(cls, GlobalConstraint) and name != "GlobalConstraint"] + classes = [(name, cls) for name, cls in classes if name not in EXCLUDE_GLOBAL.get(solver, {})] + + for name, cls in classes: + + if name == "Xor": + expr = cls(BOOL_ARGS) + elif name == "Inverse": + expr = cls(NUM_ARGS, [1,0,2]) + elif name == "Table": + expr = cls(NUM_ARGS, [[0,1,2],[1,2,0],[1,0,2]]) + elif name == "IfThenElse": + expr = cls(*BOOL_ARGS) + elif name == "InDomain": + expr = cls(NUM_VAR, [0,1,6]) + elif name == "Cumulative": + s = intvar(0, 10, shape=3, name="start") + e = intvar(0, 10, shape=3, name="end") + dur = [1, 4, 3] + demand = [4, 5, 7] + cap = 10 + expr = Cumulative(s, dur, e, demand, cap) + elif name == "GlobalCardinalityCount": + vals = [1, 2, 3] + cnts = intvar(0,10,shape=3) + expr = cls(NUM_ARGS, vals, cnts) + else: # default constructor, list of numvars + expr= cls(NUM_ARGS) + + if solver in EXCLUDE_GLOBAL and name in EXCLUDE_GLOBAL[solver]: + continue + else: + yield expr def reify_imply_exprs(solver): """ @@ -184,44 +213,59 @@ def reify_imply_exprs(solver): Var -> Boolexpr (CPMpy class 'Operator', is_bool()) """ for bool_expr in bool_exprs(solver): - if solver not in EXCLUDE_IMPL or \ - bool_expr.name not in EXCLUDE_IMPL[solver]: - yield bool_expr.implies(BOOL_VAR) - yield BOOL_VAR.implies(bool_expr) - yield bool_expr == BOOL_VAR + yield bool_expr.implies(BOOL_VAR) + yield BOOL_VAR.implies(bool_expr) + yield bool_expr == BOOL_VAR for comp_expr in comp_constraints(solver): lhs, rhs = comp_expr.args - if solver not in EXCLUDE_IMPL or \ - (not isinstance(lhs, Expression) or lhs.name not in EXCLUDE_IMPL[solver]) and \ - (not isinstance(rhs, Expression) or rhs.name not in EXCLUDE_IMPL[solver]): - yield comp_expr.implies(BOOL_VAR) - yield BOOL_VAR.implies(comp_expr) - yield comp_expr == BOOL_VAR + yield comp_expr.implies(BOOL_VAR) + yield BOOL_VAR.implies(comp_expr) + yield comp_expr == BOOL_VAR + + +def verify(cons): + assert argval(cons) + assert cons.value() -@pytest.mark.parametrize(("solver","constraint"),_generate_inputs(bool_exprs), ids=str) +@pytest.mark.parametrize(("solver","constraint"),list(_generate_inputs(bool_exprs)), ids=str) def test_bool_constaints(solver, constraint): """ Tests boolean constraint by posting it to the solver and checking the value after solve. """ - assert SolverLookup.get(solver, Model(constraint)).solve() - assert constraint.value() + if ALL_SOLS: + n_sols = SolverLookup.get(solver, Model(constraint)).solveAll(display=lambda: verify(constraint)) + assert n_sols >= 1 + else: + assert SolverLookup.get(solver, Model(constraint)).solve() + assert argval(constraint) + assert constraint.value() -@pytest.mark.parametrize(("solver","constraint"), _generate_inputs(comp_constraints), ids=str) +@pytest.mark.parametrize(("solver","constraint"), list(_generate_inputs(comp_constraints)), ids=str) def test_comparison_constraints(solver, constraint): """ Tests comparison constraint by posting it to the solver and checking the value after solve. """ - assert SolverLookup.get(solver,Model(constraint)).solve() - assert constraint.value() + if ALL_SOLS: + n_sols = SolverLookup.get(solver, Model(constraint)).solveAll(display= lambda: verify(constraint)) + assert n_sols >= 1 + else: + assert SolverLookup.get(solver,Model(constraint)).solve() + assert argval(constraint) + assert constraint.value() -@pytest.mark.parametrize(("solver","constraint"), _generate_inputs(reify_imply_exprs), ids=str) +@pytest.mark.parametrize(("solver","constraint"), list(_generate_inputs(reify_imply_exprs)), ids=str) def test_reify_imply_constraints(solver, constraint): """ Tests boolean expression by posting it to solver and checking the value after solve. """ - assert SolverLookup.get(solver, Model(constraint)).solve() - assert constraint.value() + if ALL_SOLS: + n_sols = SolverLookup.get(solver, Model(constraint)).solveAll(display=lambda: verify(constraint)) + assert n_sols >= 1 + else: + assert SolverLookup.get(solver, Model(constraint)).solve() + assert argval(constraint) + assert constraint.value() diff --git a/tests/test_expressions.py b/tests/test_expressions.py index 8921aa5ea..7ff22e88e 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -6,7 +6,7 @@ from cpmpy.expressions import * from cpmpy.expressions.variables import NDVarArray from cpmpy.expressions.core import Operator, Expression -from cpmpy.expressions.utils import get_bounds +from cpmpy.expressions.utils import get_bounds, argval class TestComparison(unittest.TestCase): def test_comps(self): @@ -437,6 +437,7 @@ def test_incomplete_func(self): if cp.SolverLookup.lookup("z3").supported(): self.assertTrue(m.solve(solver="z3")) # ortools does not support divisor spanning 0 work here self.assertRaises(IncompleteFunctionError, cons.value) + self.assertFalse(argval(cons)) # mayhem cons = (arr[10 // (a - b)] == 1).implies(p) diff --git a/tests/test_solvers.py b/tests/test_solvers.py index 38136db1d..55108c91d 100644 --- a/tests/test_solvers.py +++ b/tests/test_solvers.py @@ -693,7 +693,7 @@ def test_vars_not_removed(self): #test unique sols, should be same number self.assertEqual(len(sols),8) - + @pytest.mark.skipif(not CPM_minizinc.supported(), reason="Minizinc not installed") def test_count_mzn(self): @@ -707,4 +707,3 @@ def test_count_mzn(self): m = cp.Model([x + y == 2, wsum == 9]) self.assertTrue(m.solve(solver="minizinc")) -