From aff001114e53acbb5128b2c4bf0cd65f371d7653 Mon Sep 17 00:00:00 2001 From: Ignace Bleukx Date: Thu, 26 Sep 2024 14:55:58 +0200 Subject: [PATCH] Utility functions for NDVarArrays (#492) * ensure axis element works for high dimensional arrays * add tests * add global constraints with axis argument * expand tests * fix prod * ensure result is NDVarArray * fix copy-paste errors * remove generation of constraints from object * update tests --- cpmpy/expressions/variables.py | 98 ++++++++++------------------------ tests/test_expressions.py | 33 +++++++++++- 2 files changed, 61 insertions(+), 70 deletions(-) diff --git a/cpmpy/expressions/variables.py b/cpmpy/expressions/variables.py index 6c70e3eb0..0770e8ebe 100644 --- a/cpmpy/expressions/variables.py +++ b/cpmpy/expressions/variables.py @@ -491,132 +491,92 @@ def sum(self, axis=None, out=None): """ overwrite np.sum(NDVarArray) as people might use it """ + from .python_builtins import sum as cpm_sum + if out is not None: raise NotImplementedError() if axis is None: # simple case where we want the sum over the whole array - arr = self.flatten() - return Operator("sum", arr) - - arr = self.__axis(axis=axis) - - out = [] - for i in range(0, arr.shape[0]): - out.append(Operator("sum", arr[i, ...])) + return cpm_sum(self) - # return the NDVarArray that contains the sum constraints - return out + return cpm_array(np.apply_along_axis(cpm_sum, axis=axis, arr=self)) def prod(self, axis=None, out=None): """ overwrite np.prod(NDVarArray) as people might use it """ + if out is not None: raise NotImplementedError() - if axis is None: # simple case where we want the product over the whole array - arr = self.flatten() - return reduce(lambda a, b: a * b, arr) - - arr = self.__axis(axis=axis) + if axis is None: # simple case where we want the product over the whole array + return reduce(lambda a, b: a * b, self.flatten()) - out = [] - for i in range(0, arr.shape[0]): - out.append(reduce(lambda a, b: a * b, arr[i, ...])) - - # return the NDVarArray that contains the sum constraints - return out + # TODO: is there a better way? This does pairwise multiplication still + return cpm_array(np.multiply.reduce(self, axis=axis)) def max(self, axis=None, out=None): """ overwrite np.max(NDVarArray) as people might use it """ - from .globalfunctions import Maximum + from .python_builtins import max as cpm_max if out is not None: raise NotImplementedError() if axis is None: # simple case where we want the maximum over the whole array - arr = self.flatten() - return Maximum(arr) - - arr = self.__axis(axis=axis) + return cpm_max(self) - out = [] - for i in range(0, arr.shape[0]): - out.append(Maximum(arr[i, ...])) - - # return the NDVarArray that contains the Maximum global constraints - return out + return cpm_array(np.apply_along_axis(cpm_max, axis=axis, arr=self)) def min(self, axis=None, out=None): """ overwrite np.min(NDVarArray) as people might use it """ - from .globalfunctions import Minimum + from .python_builtins import min as cpm_min if out is not None: raise NotImplementedError() - if axis is None: # simple case where we want the Minimum over the whole array - arr = self.flatten() - return Minimum(arr) - - arr = self.__axis(axis=axis) + if axis is None: # simple case where we want the minimum over the whole array + return cpm_min(self) - out = [] - for i in range(0, arr.shape[0]): - out.append(Minimum(arr[i, ...])) - - # return the NDVarArray that contains the Minimum global constraints - return out + return cpm_array(np.apply_along_axis(cpm_min, axis=axis, arr=self)) def any(self, axis=None, out=None): """ overwrite np.any(NDVarArray) """ - from .python_builtins import any + from .python_builtins import any as cpm_any - if any(not is_boolexpr(x) for x in self.flatten()): + if any(not is_boolexpr(x) for x in self.flat): raise TypeError("Cannot call .any() in an array not consisting only of bools") if out is not None: raise NotImplementedError() - if axis is None: # simple case where we want the .any() over the whole array - arr = self.flatten() - return any(arr) - - arr = self.__axis(axis=axis) + if axis is None: # simple case where we want a disjunction over the whole array + return cpm_any(self) - out = [] - for i in range(0, arr.shape[0]): - out.append(any(arr[i, ...])) + return cpm_array(np.apply_along_axis(cpm_any, axis=axis, arr=self)) - # return the NDVarArray that contains the any() constraints - return out def all(self, axis=None, out=None): """ overwrite np.any(NDVarArray) """ - from .python_builtins import all - if out is not None: - raise NotImplementedError() - - if axis is None: # simple case where we want the .all() over the whole array - arr = self.flatten() - return all(arr) + from .python_builtins import all as cpm_all - arr = self.__axis(axis=axis) + if any(not is_boolexpr(x) for x in self.flat): + raise TypeError("Cannot call .any() in an array not consisting only of bools") - out = [] - for i in range(0, arr.shape[0]): - out.append(all(arr[i, ...])) + if out is not None: + raise NotImplementedError() - # return the NDVarArray that contains the all() constraints - return out + if axis is None: # simple case where we want a conjunction over the whole array + return cpm_all(self) + return cpm_array(np.apply_along_axis(cpm_all, axis=axis, arr=self)) def get_bounds(self): lbs, ubs = zip(*[get_bounds(e) for e in self]) diff --git a/tests/test_expressions.py b/tests/test_expressions.py index 7ff22e88e..75a5bb1ba 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -312,9 +312,22 @@ def test_all(self): res = np.array([all(x[i, ...].value()) for i in range(len(y))]) self.assertTrue(all(y.value() == res)) + + def test_multidim(self): + + functions = ["all", "any", "max", "min", "sum", "prod"] + bv = cp.boolvar(shape=(5,4,3,2)) # high dimensional tensor + arr = np.zeros(shape=bv.shape) # numpy "ground truth" + + for axis in range(len(bv.shape)): + np_res = arr.sum(axis=axis) + for func in functions: + cpm_res = getattr(bv, func)(axis=axis) + self.assertIsInstance(cpm_res, NDVarArray) + self.assertEqual(cpm_res.shape, np_res.shape) def inclusive_range(lb,ub): - return range(lb,ub+1) + return range(lb,ub+1) class TestBounds(unittest.TestCase): def test_bounds_mul_sub_sum(self): @@ -518,8 +531,26 @@ def test_description(self): self.assertEqual(str(cons), "either a or b should be true, but not both -- (a) or (b)") +class TestBuildIns(unittest.TestCase): + + def setUp(self): + self.x = cp.intvar(0,10,shape=3) + def test_sum(self): + gt = Operator("sum", list(self.x)) + + self.assertEqual(str(gt), str(cp.sum(self.x))) + self.assertEqual(str(gt), str(cp.sum(list(self.x)))) + self.assertEqual(str(gt), str(cp.sum(v for v in self.x))) + self.assertEqual(str(gt), str(cp.sum(self.x[0], self.x[1], self.x[2]))) + + def test_max(self): + gt = Maximum(self.x) + self.assertEqual(str(gt), str(cp.max(self.x))) + self.assertEqual(str(gt), str(cp.max(list(self.x)))) + self.assertEqual(str(gt), str(cp.max(v for v in self.x))) + self.assertEqual(str(gt), str(cp.max(self.x[0], self.x[1], self.x[2]))) if __name__ == '__main__': unittest.main()