From 184b8352d701eccc6e4ee884afbdc3900b1b9479 Mon Sep 17 00:00:00 2001 From: Ignace Bleukx Date: Tue, 7 Nov 2023 14:29:30 +0100 Subject: [PATCH] get_bounds() for arrays (#434) * add get_bounds() for (nested) lists --------- Co-authored-by: wout4 --- cpmpy/expressions/core.py | 2 +- cpmpy/expressions/utils.py | 6 ++++++ cpmpy/expressions/variables.py | 7 ++++++- tests/test_expressions.py | 22 ++++++++++++++++++++++ 4 files changed, 35 insertions(+), 2 deletions(-) diff --git a/cpmpy/expressions/core.py b/cpmpy/expressions/core.py index 709d00811..1b210e680 100644 --- a/cpmpy/expressions/core.py +++ b/cpmpy/expressions/core.py @@ -570,7 +570,7 @@ def get_bounds(self): bounds = [lb1 * lb2, lb1 * ub2, ub1 * lb2, ub1 * ub2] lowerbound, upperbound = min(bounds), max(bounds) elif self.name == 'sum': - lbs, ubs = zip(*[get_bounds(x) for x in self.args]) + lbs, ubs = get_bounds(self.args) lowerbound, upperbound = sum(lbs), sum(ubs) elif self.name == 'wsum': weights, vars = self.args diff --git a/cpmpy/expressions/utils.py b/cpmpy/expressions/utils.py index 9cded85ca..9b0310773 100644 --- a/cpmpy/expressions/utils.py +++ b/cpmpy/expressions/utils.py @@ -163,9 +163,15 @@ def get_bounds(expr): returns appropriately rounded integers """ + # import here to avoid circular import from cpmpy.expressions.core import Expression + from cpmpy.expressions.variables import cpm_array + if isinstance(expr, Expression): return expr.get_bounds() + elif is_any_list(expr): + lbs, ubs = zip(*[get_bounds(e) for e in expr]) + return list(lbs), list(ubs) # return list as NDVarArray is covered above else: assert is_num(expr), f"All Expressions should have a get_bounds function, `{expr}`" if is_bool(expr): diff --git a/cpmpy/expressions/variables.py b/cpmpy/expressions/variables.py index b8a0f88d3..31c1ab33b 100644 --- a/cpmpy/expressions/variables.py +++ b/cpmpy/expressions/variables.py @@ -53,7 +53,7 @@ import numpy as np from .core import Expression, Operator -from .utils import is_num, is_int, flatlist, is_boolexpr, is_true_cst, is_false_cst +from .utils import is_num, is_int, flatlist, is_boolexpr, is_true_cst, is_false_cst, get_bounds def BoolVar(shape=1, name=None): @@ -594,6 +594,11 @@ def all(self, axis=None, out=None): # return the NDVarArray that contains the all() constraints return out + + def get_bounds(self): + lbs, ubs = zip(*[get_bounds(e) for e in self]) + return cpm_array(lbs), cpm_array(ubs) + # VECTORIZED master function (delegate) def _vectorized(self, other, attr): if not isinstance(other, Iterable): diff --git a/tests/test_expressions.py b/tests/test_expressions.py index 86b8db426..8921aa5ea 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -6,6 +6,7 @@ from cpmpy.expressions import * from cpmpy.expressions.variables import NDVarArray from cpmpy.expressions.core import Operator, Expression +from cpmpy.expressions.utils import get_bounds class TestComparison(unittest.TestCase): def test_comps(self): @@ -444,6 +445,27 @@ def test_incomplete_func(self): self.assertTrue(m.solve(solver="z3")) self.assertTrue(cons.value()) + + def test_list(self): + + # cpm_array + iv = cp.intvar(0,10,shape=3) + lbs, ubs = iv.get_bounds() + self.assertListEqual([0,0,0], lbs.tolist()) + self.assertListEqual([10,10,10], ubs.tolist()) + # list + iv = [cp.intvar(0,10) for _ in range(3)] + lbs, ubs = get_bounds(iv) + self.assertListEqual([0, 0, 0], lbs) + self.assertListEqual([10, 10, 10], ubs) + # nested list + exprs = [intvar(0,1), [intvar(2,3), intvar(4,5)], [intvar(5,6)]] + lbs, ubs = get_bounds(exprs) + self.assertListEqual([0,[2,4],[5]], lbs) + self.assertListEqual([1,[3,5],[6]], ubs) + + + def test_not_operator(self): p = boolvar() q = boolvar()