Skip to content

Commit

Permalink
ensure result is NDVarArray
Browse files Browse the repository at this point in the history
  • Loading branch information
IgnaceBleukx committed Jun 3, 2024
1 parent 3094e15 commit e159fd6
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 15 deletions.
28 changes: 14 additions & 14 deletions cpmpy/expressions/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,7 @@ def sum(self, axis=None, out=None):
if axis is None: # simple case where we want the sum over the whole array
return cpm_sum(self)

return np.apply_along_axis(cpm_sum, axis=axis, arr=self)
return cpm_array(np.apply_along_axis(cpm_sum, axis=axis, arr=self))


def prod(self, axis=None, out=None):
Expand All @@ -491,7 +491,7 @@ def prod(self, axis=None, out=None):
return reduce(lambda a, b: a * b, self.flatten())

# TODO: is there a better way? This does pairwise multiplication still
return np.multiply.reduce(self, axis=axis)
return cpm_array(np.multiply.reduce(self, axis=axis))

def max(self, axis=None, out=None):
"""
Expand All @@ -504,7 +504,7 @@ def max(self, axis=None, out=None):
if axis is None: # simple case where we want the maximum over the whole array
return cpm_max(self)

return np.apply_along_axis(cpm_max, axis=axis, arr=self)
return cpm_array(np.apply_along_axis(cpm_max, axis=axis, arr=self))

def min(self, axis=None, out=None):
"""
Expand All @@ -517,7 +517,7 @@ def min(self, axis=None, out=None):
if axis is None: # simple case where we want the maximum over the whole array
return cpm_min(self)

return np.apply_along_axis(cpm_min, axis=axis, arr=self)
return cpm_array(np.apply_along_axis(cpm_min, axis=axis, arr=self))

def any(self, axis=None, out=None):
"""
Expand All @@ -534,7 +534,7 @@ def any(self, axis=None, out=None):
if axis is None: # simple case where we want the maximum over the whole array
return cpm_any(self)

return np.apply_along_axis(cpm_any, axis=axis, arr=self)
return cpm_array(np.apply_along_axis(cpm_any, axis=axis, arr=self))


def all(self, axis=None, out=None):
Expand All @@ -553,39 +553,39 @@ def all(self, axis=None, out=None):
if axis is None: # simple case where we want the maximum over the whole array
return cpm_all(self)

return np.apply_along_axis(cpm_all, axis=axis, arr=self)
return cpm_array(np.apply_along_axis(cpm_all, axis=axis, arr=self))

def alldifferent(self, axis=None):
from .globalconstraints import AllDifferent
return np.apply_along_axis(lambda arr : AllDifferent(arr), axis=axis, arr=self)
return cpm_array(np.apply_along_axis(lambda arr : AllDifferent(arr), axis=axis, arr=self))

def alldifferent_except0(self, axis=None):
from .globalconstraints import AllDifferentExcept0
return np.apply_along_axis(lambda arr: AllDifferentExcept0(arr), axis=axis, arr=self)
return cpm_array(np.apply_along_axis(lambda arr: AllDifferentExcept0(arr), axis=axis, arr=self))

def allequal(self, axis=None):
from .globalconstraints import AllEqual
return np.apply_along_axis(lambda arr: AllEqual(arr), axis=axis, arr=self)
return cpm_array(np.apply_along_axis(lambda arr: AllEqual(arr), axis=axis, arr=self))

def circuit(self, axis=None):
from .globalconstraints import Circuit
return np.apply_along_axis(lambda arr: Circuit(arr), axis=axis, arr=self)
return cpm_array(np.apply_along_axis(lambda arr: Circuit(arr), axis=axis, arr=self))

def increasing(self, axis=None):
from .globalconstraints import Increasing
return np.apply_along_axis(lambda arr: Increasing(arr), axis=axis, arr=self)
return cpm_array(np.apply_along_axis(lambda arr: Increasing(arr), axis=axis, arr=self))

def increasing_strict(self, axis=None):
from .globalconstraints import IncreasingStrict
return np.apply_along_axis(lambda arr: IncreasingStrict(arr), axis=axis, arr=self)
return cpm_array(np.apply_along_axis(lambda arr: IncreasingStrict(arr), axis=axis, arr=self))

def decreasing(self, axis=None):
from .globalconstraints import Decreasing
return np.apply_along_axis(lambda arr: Decreasing(arr), axis=axis, arr=self)
return cpm_array(np.apply_along_axis(lambda arr: Decreasing(arr), axis=axis, arr=self))

def decreasing_strict(self, axis=None):
from .globalconstraints import DecreasingStrict
return np.apply_along_axis(lambda arr: DecreasingStrict(arr), axis=axis, arr=self)
return cpm_array(np.apply_along_axis(lambda arr: DecreasingStrict(arr), axis=axis, arr=self))



Expand Down
4 changes: 3 additions & 1 deletion tests/test_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,9 +326,11 @@ def test_multidim(self):
np_res = arr.sum(axis=axis)
for func in functions:
cpm_res = getattr(bv, func)(axis=axis)
self.assertIsInstance(cpm_res, NDVarArray)
self.assertEqual(cpm_res.shape, np_res.shape)
for cons in constraints:
cpm_res = getattr(iv, func)(axis=axis)
cpm_res = getattr(iv, cons)(axis=axis)
self.assertIsInstance(cpm_res, NDVarArray)
self.assertEqual(cpm_res.shape, np_res.shape)


Expand Down

0 comments on commit e159fd6

Please sign in to comment.