Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

create vectorized typecheck #447

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 25 additions & 14 deletions cpmpy/expressions/globalconstraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,8 +221,9 @@ class Circuit(GlobalConstraint):
"""
def __init__(self, *args):
flatargs = flatlist(args)
if any(is_boolexpr(arg) for arg in flatargs):
raise TypeError("Circuit global constraint only takes arithmetic arguments: {}".format(flatargs))
check = vectorized_check(flatargs, lambda arg: is_boolexpr(arg))
if check is not None:
raise TypeError("Circuit global constraint only takes arithmetic arguments: {} is boolean".format(check))
super().__init__("circuit", flatargs)
if len(flatargs) < 2:
raise CPMpyException('Circuit constraint must be given a minimum of 2 variables')
Expand Down Expand Up @@ -275,8 +276,9 @@ class Inverse(GlobalConstraint):
"""
def __init__(self, fwd, rev):
flatargs = flatlist([fwd,rev])
if any(is_boolexpr(arg) for arg in flatargs):
raise TypeError("Only integer arguments allowed for global constraint Inverse: {}".format(flatargs))
check = vectorized_check(flatargs, lambda arg: is_boolexpr(arg))
if check is not None:
raise TypeError("Only integer arguments allowed for global constraint Inverse: {} is boolean".format(check))
assert len(fwd) == len(rev)
super().__init__("inverse", [fwd, rev])

Expand All @@ -297,8 +299,9 @@ class Table(GlobalConstraint):
"""
def __init__(self, array, table):
array = flatlist(array)
if not all(isinstance(x, Expression) for x in array):
raise TypeError("the first argument of a Table constraint should only contain variables/expressions")
check = vectorized_check(array, lambda x: not isinstance(x, Expression))
if check is not None:
raise TypeError("the first argument of a Table constraint should only contain variables/expressions, not {}".format(check))
super().__init__("table", [array, table])

def decompose(self):
Expand All @@ -317,8 +320,9 @@ def value(self):
# https://www.ibm.com/docs/en/icos/12.9.0?topic=methods-ifthenelse-method
class IfThenElse(GlobalConstraint):
def __init__(self, condition, if_true, if_false):
if not is_boolexpr(condition) or not is_boolexpr(if_true) or not is_boolexpr(if_false):
raise TypeError("only boolean expression allowed in IfThenElse")
check = vectorized_check([condition,if_true,if_false], lambda x: not is_boolexpr(x))
if check is not None:
raise TypeError("only boolean expression allowed in IfThenElse, not {}".format(check))
super().__init__("ite", [condition, if_true, if_false])

def value(self):
Expand All @@ -344,8 +348,13 @@ class InDomain(GlobalConstraint):
"""

def __init__(self, expr, arr):
assert not (is_boolexpr(expr) or any(is_boolexpr(a) for a in arr)), \
"The expressions in the InDomain constraint should not be boolean"
check = vectorized_check(arr, lambda x: is_boolexpr(x))
if check is not None:
raise TypeError("No boolean expressions allowed in the domain. {} is boolean".format(check))
if is_boolexpr(expr):
raise TypeError("The expression in the InDomain constraint should not be boolean: {} is".format(expr))
if is_any_list(expr):
raise TypeError("no lists allowed for the first argument: {}".format(expr))
super().__init__("InDomain", [expr, arr])

def decompose(self):
Expand Down Expand Up @@ -387,8 +396,9 @@ class Xor(GlobalConstraint):

def __init__(self, arg_list):
flatargs = flatlist(arg_list)
if not (all(is_boolexpr(arg) for arg in flatargs)):
raise TypeError("Only Boolean arguments allowed in Xor global constraint: {}".format(flatargs))
check = vectorized_check(flatargs, lambda x: not is_boolexpr(x))
if check is not None:
raise TypeError("Only Boolean arguments allowed in Xor global constraint: {} is not".format(check))
# convention for commutative binary operators:
# swap if right is constant and left is not
if len(arg_list) == 2 and is_num(arg_list[1]):
Expand Down Expand Up @@ -502,8 +512,9 @@ class GlobalCardinalityCount(GlobalConstraint):

def __init__(self, vars, vals, occ):
flatargs = flatlist([vars, vals, occ])
if any(is_boolexpr(arg) for arg in flatargs):
raise TypeError("Only numerical arguments allowed for gcc global constraint: {}".format(flatargs))
check = vectorized_check(flatargs, lambda x: is_boolexpr(x))
if check is not None:
raise TypeError("Only numerical arguments allowed for gcc global constraint: {} is boolean".format(check))
super().__init__("gcc", [vars,vals,occ])

def decompose(self):
Expand Down
2 changes: 1 addition & 1 deletion cpmpy/expressions/globalfunctions.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def decompose_comparison(self):
from ..exceptions import CPMpyException, IncompleteFunctionError, TypeError
from .core import Expression, Operator, Comparison
from .variables import boolvar, intvar, cpm_array, _NumVarImpl
from .utils import flatlist, all_pairs, argval, is_num, eval_comparison, is_any_list, is_boolexpr, get_bounds
from .utils import flatlist, all_pairs, argval, is_num, eval_comparison, is_any_list, is_boolexpr, get_bounds, vectorized_check


class GlobalFunction(Expression):
Expand Down
12 changes: 12 additions & 0 deletions cpmpy/expressions/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,18 @@ def argval(a):
raise e


def vectorized_check(arr, condition):
'''
check the condition for every element in arr
arr should be flat
Returns the first object that meets the condition, or None otherwise
'''
for a in arr:
if condition(a):
return a
return None


def eval_comparison(str_op, lhs, rhs):
"""
Internal function: evaluates the textual `str_op` comparison operator
Expand Down
11 changes: 11 additions & 0 deletions tests/test_globalconstraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -728,6 +728,17 @@ def test_inverse(self):
self.assertRaises(TypeError,cp.Inverse,[a,b],[x,y])
self.assertRaises(TypeError,cp.Inverse,[a,b],[b,False])

def test_inDomain(self):
x = cp.intvar(-8, 8)
y = cp.intvar(-7, -1)
b = cp.boolvar()
a = cp.boolvar()
self.assertTrue(cp.Model(cp.InDomain(x, [x, y, x])).solve())
self.assertRaises(TypeError, cp.InDomain, [x, y, x], [x, y, x])
self.assertRaises(TypeError, cp.InDomain, a, [x, y])
self.assertRaises(TypeError, cp.InDomain, x, [a, b])
self.assertRaises(TypeError, cp.InDomain, x, [y, False])

def test_ITE(self):
x = cp.intvar(-8, 8)
y = cp.intvar(-7, -1)
Expand Down