Skip to content

Commit

Permalink
Utility functions for NDVarArrays (#492)
Browse files Browse the repository at this point in the history
* ensure axis element works for high dimensional arrays

* add tests

* add global constraints with axis argument

* expand tests

* fix prod

* ensure result is NDVarArray

* fix copy-paste errors

* remove generation of constraints from object

* update tests
  • Loading branch information
IgnaceBleukx authored Sep 26, 2024
1 parent 3edecee commit aff0011
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 70 deletions.
98 changes: 29 additions & 69 deletions cpmpy/expressions/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,132 +491,92 @@ def sum(self, axis=None, out=None):
"""
overwrite np.sum(NDVarArray) as people might use it
"""
from .python_builtins import sum as cpm_sum

if out is not None:
raise NotImplementedError()

if axis is None: # simple case where we want the sum over the whole array
arr = self.flatten()
return Operator("sum", arr)

arr = self.__axis(axis=axis)

out = []
for i in range(0, arr.shape[0]):
out.append(Operator("sum", arr[i, ...]))
return cpm_sum(self)

# return the NDVarArray that contains the sum constraints
return out
return cpm_array(np.apply_along_axis(cpm_sum, axis=axis, arr=self))


def prod(self, axis=None, out=None):
"""
overwrite np.prod(NDVarArray) as people might use it
"""

if out is not None:
raise NotImplementedError()

if axis is None: # simple case where we want the product over the whole array
arr = self.flatten()
return reduce(lambda a, b: a * b, arr)

arr = self.__axis(axis=axis)
if axis is None: # simple case where we want the product over the whole array
return reduce(lambda a, b: a * b, self.flatten())

out = []
for i in range(0, arr.shape[0]):
out.append(reduce(lambda a, b: a * b, arr[i, ...]))

# return the NDVarArray that contains the sum constraints
return out
# TODO: is there a better way? This does pairwise multiplication still
return cpm_array(np.multiply.reduce(self, axis=axis))

def max(self, axis=None, out=None):
"""
overwrite np.max(NDVarArray) as people might use it
"""
from .globalfunctions import Maximum
from .python_builtins import max as cpm_max
if out is not None:
raise NotImplementedError()

if axis is None: # simple case where we want the maximum over the whole array
arr = self.flatten()
return Maximum(arr)

arr = self.__axis(axis=axis)
return cpm_max(self)

out = []
for i in range(0, arr.shape[0]):
out.append(Maximum(arr[i, ...]))

# return the NDVarArray that contains the Maximum global constraints
return out
return cpm_array(np.apply_along_axis(cpm_max, axis=axis, arr=self))

def min(self, axis=None, out=None):
"""
overwrite np.min(NDVarArray) as people might use it
"""
from .globalfunctions import Minimum
from .python_builtins import min as cpm_min
if out is not None:
raise NotImplementedError()

if axis is None: # simple case where we want the Minimum over the whole array
arr = self.flatten()
return Minimum(arr)

arr = self.__axis(axis=axis)
if axis is None: # simple case where we want the minimum over the whole array
return cpm_min(self)

out = []
for i in range(0, arr.shape[0]):
out.append(Minimum(arr[i, ...]))

# return the NDVarArray that contains the Minimum global constraints
return out
return cpm_array(np.apply_along_axis(cpm_min, axis=axis, arr=self))

def any(self, axis=None, out=None):
"""
overwrite np.any(NDVarArray)
"""
from .python_builtins import any
from .python_builtins import any as cpm_any

if any(not is_boolexpr(x) for x in self.flatten()):
if any(not is_boolexpr(x) for x in self.flat):
raise TypeError("Cannot call .any() in an array not consisting only of bools")

if out is not None:
raise NotImplementedError()

if axis is None: # simple case where we want the .any() over the whole array
arr = self.flatten()
return any(arr)

arr = self.__axis(axis=axis)
if axis is None: # simple case where we want a disjunction over the whole array
return cpm_any(self)

out = []
for i in range(0, arr.shape[0]):
out.append(any(arr[i, ...]))
return cpm_array(np.apply_along_axis(cpm_any, axis=axis, arr=self))

# return the NDVarArray that contains the any() constraints
return out

def all(self, axis=None, out=None):
"""
overwrite np.any(NDVarArray)
"""

from .python_builtins import all
if out is not None:
raise NotImplementedError()

if axis is None: # simple case where we want the .all() over the whole array
arr = self.flatten()
return all(arr)
from .python_builtins import all as cpm_all

arr = self.__axis(axis=axis)
if any(not is_boolexpr(x) for x in self.flat):
raise TypeError("Cannot call .any() in an array not consisting only of bools")

out = []
for i in range(0, arr.shape[0]):
out.append(all(arr[i, ...]))
if out is not None:
raise NotImplementedError()

# return the NDVarArray that contains the all() constraints
return out
if axis is None: # simple case where we want a conjunction over the whole array
return cpm_all(self)

return cpm_array(np.apply_along_axis(cpm_all, axis=axis, arr=self))

def get_bounds(self):
lbs, ubs = zip(*[get_bounds(e) for e in self])
Expand Down
33 changes: 32 additions & 1 deletion tests/test_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,9 +312,22 @@ def test_all(self):
res = np.array([all(x[i, ...].value()) for i in range(len(y))])
self.assertTrue(all(y.value() == res))


def test_multidim(self):

functions = ["all", "any", "max", "min", "sum", "prod"]
bv = cp.boolvar(shape=(5,4,3,2)) # high dimensional tensor
arr = np.zeros(shape=bv.shape) # numpy "ground truth"

for axis in range(len(bv.shape)):
np_res = arr.sum(axis=axis)
for func in functions:
cpm_res = getattr(bv, func)(axis=axis)
self.assertIsInstance(cpm_res, NDVarArray)
self.assertEqual(cpm_res.shape, np_res.shape)

def inclusive_range(lb,ub):
return range(lb,ub+1)
return range(lb,ub+1)

class TestBounds(unittest.TestCase):
def test_bounds_mul_sub_sum(self):
Expand Down Expand Up @@ -518,8 +531,26 @@ def test_description(self):
self.assertEqual(str(cons), "either a or b should be true, but not both -- (a) or (b)")


class TestBuildIns(unittest.TestCase):

def setUp(self):
self.x = cp.intvar(0,10,shape=3)

def test_sum(self):
gt = Operator("sum", list(self.x))

self.assertEqual(str(gt), str(cp.sum(self.x)))
self.assertEqual(str(gt), str(cp.sum(list(self.x))))
self.assertEqual(str(gt), str(cp.sum(v for v in self.x)))
self.assertEqual(str(gt), str(cp.sum(self.x[0], self.x[1], self.x[2])))

def test_max(self):
gt = Maximum(self.x)

self.assertEqual(str(gt), str(cp.max(self.x)))
self.assertEqual(str(gt), str(cp.max(list(self.x))))
self.assertEqual(str(gt), str(cp.max(v for v in self.x)))
self.assertEqual(str(gt), str(cp.max(self.x[0], self.x[1], self.x[2])))

if __name__ == '__main__':
unittest.main()

0 comments on commit aff0011

Please sign in to comment.