From d11259f03a6a03d77076be1c25d186a3195a181e Mon Sep 17 00:00:00 2001 From: wout4 Date: Fri, 19 Jan 2024 16:57:25 +0100 Subject: [PATCH] create vectorized typecheck --- cpmpy/expressions/globalconstraints.py | 39 +++++++++++++++++--------- cpmpy/expressions/globalfunctions.py | 2 +- cpmpy/expressions/utils.py | 12 ++++++++ tests/test_globalconstraints.py | 11 ++++++++ 4 files changed, 49 insertions(+), 15 deletions(-) diff --git a/cpmpy/expressions/globalconstraints.py b/cpmpy/expressions/globalconstraints.py index cb792bd89..a8e850a7c 100644 --- a/cpmpy/expressions/globalconstraints.py +++ b/cpmpy/expressions/globalconstraints.py @@ -221,8 +221,9 @@ class Circuit(GlobalConstraint): """ def __init__(self, *args): flatargs = flatlist(args) - if any(is_boolexpr(arg) for arg in flatargs): - raise TypeError("Circuit global constraint only takes arithmetic arguments: {}".format(flatargs)) + check = vectorized_check(flatargs, lambda arg: is_boolexpr(arg)) + if check is not None: + raise TypeError("Circuit global constraint only takes arithmetic arguments: {} is boolean".format(check)) super().__init__("circuit", flatargs) if len(flatargs) < 2: raise CPMpyException('Circuit constraint must be given a minimum of 2 variables') @@ -275,8 +276,9 @@ class Inverse(GlobalConstraint): """ def __init__(self, fwd, rev): flatargs = flatlist([fwd,rev]) - if any(is_boolexpr(arg) for arg in flatargs): - raise TypeError("Only integer arguments allowed for global constraint Inverse: {}".format(flatargs)) + check = vectorized_check(flatargs, lambda arg: is_boolexpr(arg)) + if check is not None: + raise TypeError("Only integer arguments allowed for global constraint Inverse: {} is boolean".format(check)) assert len(fwd) == len(rev) super().__init__("inverse", [fwd, rev]) @@ -297,8 +299,9 @@ class Table(GlobalConstraint): """ def __init__(self, array, table): array = flatlist(array) - if not all(isinstance(x, Expression) for x in array): - raise TypeError("the first argument of a Table constraint should only contain variables/expressions") + check = vectorized_check(array, lambda x: not isinstance(x, Expression)) + if check is not None: + raise TypeError("the first argument of a Table constraint should only contain variables/expressions, not {}".format(check)) super().__init__("table", [array, table]) def decompose(self): @@ -317,8 +320,9 @@ def value(self): # https://www.ibm.com/docs/en/icos/12.9.0?topic=methods-ifthenelse-method class IfThenElse(GlobalConstraint): def __init__(self, condition, if_true, if_false): - if not is_boolexpr(condition) or not is_boolexpr(if_true) or not is_boolexpr(if_false): - raise TypeError("only boolean expression allowed in IfThenElse") + check = vectorized_check([condition,if_true,if_false], lambda x: not is_boolexpr(x)) + if check is not None: + raise TypeError("only boolean expression allowed in IfThenElse, not {}".format(check)) super().__init__("ite", [condition, if_true, if_false]) def value(self): @@ -344,8 +348,13 @@ class InDomain(GlobalConstraint): """ def __init__(self, expr, arr): - assert not (is_boolexpr(expr) or any(is_boolexpr(a) for a in arr)), \ - "The expressions in the InDomain constraint should not be boolean" + check = vectorized_check(arr, lambda x: is_boolexpr(x)) + if check is not None: + raise TypeError("No boolean expressions allowed in the domain. {} is boolean".format(check)) + if is_boolexpr(expr): + raise TypeError("The expression in the InDomain constraint should not be boolean: {} is".format(expr)) + if is_any_list(expr): + raise TypeError("no lists allowed for the first argument: {}".format(expr)) super().__init__("InDomain", [expr, arr]) def decompose(self): @@ -387,8 +396,9 @@ class Xor(GlobalConstraint): def __init__(self, arg_list): flatargs = flatlist(arg_list) - if not (all(is_boolexpr(arg) for arg in flatargs)): - raise TypeError("Only Boolean arguments allowed in Xor global constraint: {}".format(flatargs)) + check = vectorized_check(flatargs, lambda x: not is_boolexpr(x)) + if check is not None: + raise TypeError("Only Boolean arguments allowed in Xor global constraint: {} is not".format(check)) # convention for commutative binary operators: # swap if right is constant and left is not if len(arg_list) == 2 and is_num(arg_list[1]): @@ -502,8 +512,9 @@ class GlobalCardinalityCount(GlobalConstraint): def __init__(self, vars, vals, occ): flatargs = flatlist([vars, vals, occ]) - if any(is_boolexpr(arg) for arg in flatargs): - raise TypeError("Only numerical arguments allowed for gcc global constraint: {}".format(flatargs)) + check = vectorized_check(flatargs, lambda x: is_boolexpr(x)) + if check is not None: + raise TypeError("Only numerical arguments allowed for gcc global constraint: {} is boolean".format(check)) super().__init__("gcc", [vars,vals,occ]) def decompose(self): diff --git a/cpmpy/expressions/globalfunctions.py b/cpmpy/expressions/globalfunctions.py index 16c050ce8..15bcfb7c3 100644 --- a/cpmpy/expressions/globalfunctions.py +++ b/cpmpy/expressions/globalfunctions.py @@ -66,7 +66,7 @@ def decompose_comparison(self): from ..exceptions import CPMpyException, IncompleteFunctionError, TypeError from .core import Expression, Operator, Comparison from .variables import boolvar, intvar, cpm_array, _NumVarImpl -from .utils import flatlist, all_pairs, argval, is_num, eval_comparison, is_any_list, is_boolexpr, get_bounds +from .utils import flatlist, all_pairs, argval, is_num, eval_comparison, is_any_list, is_boolexpr, get_bounds, vectorized_check class GlobalFunction(Expression): diff --git a/cpmpy/expressions/utils.py b/cpmpy/expressions/utils.py index 9b0310773..92e46087e 100644 --- a/cpmpy/expressions/utils.py +++ b/cpmpy/expressions/utils.py @@ -127,6 +127,18 @@ def argval(a): raise e +def vectorized_check(arr, condition): + ''' + check the condition for every element in arr + arr should be flat + Returns the first object that meets the condition, or None otherwise + ''' + for a in arr: + if condition(a): + return a + return None + + def eval_comparison(str_op, lhs, rhs): """ Internal function: evaluates the textual `str_op` comparison operator diff --git a/tests/test_globalconstraints.py b/tests/test_globalconstraints.py index 3598ea798..999f75423 100644 --- a/tests/test_globalconstraints.py +++ b/tests/test_globalconstraints.py @@ -728,6 +728,17 @@ def test_inverse(self): self.assertRaises(TypeError,cp.Inverse,[a,b],[x,y]) self.assertRaises(TypeError,cp.Inverse,[a,b],[b,False]) + def test_inDomain(self): + x = cp.intvar(-8, 8) + y = cp.intvar(-7, -1) + b = cp.boolvar() + a = cp.boolvar() + self.assertTrue(cp.Model(cp.InDomain(x, [x, y, x])).solve()) + self.assertRaises(TypeError, cp.InDomain, [x, y, x], [x, y, x]) + self.assertRaises(TypeError, cp.InDomain, a, [x, y]) + self.assertRaises(TypeError, cp.InDomain, x, [a, b]) + self.assertRaises(TypeError, cp.InDomain, x, [y, False]) + def test_ITE(self): x = cp.intvar(-8, 8) y = cp.intvar(-7, -1)