From 085d9191e633742e7617ff36564ca8e54124c251 Mon Sep 17 00:00:00 2001 From: Dimos Tsouros Date: Thu, 3 Oct 2024 14:36:33 +0200 Subject: [PATCH] cover cp.abs() using the builtin (#513) * cover cp.abs() using the builtin * abs implementation * actually abs for only 1 element * tests for builtins * support vectorized * recursive for iterable * fix error in test for all_diff_except_n * add testcase * add testcase * fix missing 'not' remove redundant 'abs' code * still overwrite __abs__ in expressions * Update python_builtins.py bit more explicit... --------- Co-authored-by: wout4 Co-authored-by: Tias Guns --- cpmpy/expressions/__init__.py | 2 +- cpmpy/expressions/python_builtins.py | 27 ++++++++++++++-- tests/test_builtins.py | 47 ++++++++++++++++++++++++++++ 3 files changed, 72 insertions(+), 4 deletions(-) create mode 100644 tests/test_builtins.py diff --git a/cpmpy/expressions/__init__.py b/cpmpy/expressions/__init__.py index e86de9032..b695ff926 100644 --- a/cpmpy/expressions/__init__.py +++ b/cpmpy/expressions/__init__.py @@ -28,4 +28,4 @@ from .globalconstraints import alldifferent, allequal, circuit # Old, to be deprecated from .globalfunctions import Maximum, Minimum, Abs, Element, Count, NValue, NValueExcept, Among from .core import BoolVal -from .python_builtins import all, any, max, min, sum +from .python_builtins import all, any, max, min, sum, abs diff --git a/cpmpy/expressions/python_builtins.py b/cpmpy/expressions/python_builtins.py index 7fcaa805c..140669a09 100644 --- a/cpmpy/expressions/python_builtins.py +++ b/cpmpy/expressions/python_builtins.py @@ -17,13 +17,15 @@ max min sum + abs """ import builtins # to use the original Python-builtins -from .utils import is_false_cst, is_true_cst -from .variables import NDVarArray +from .utils import is_false_cst, is_true_cst, is_any_list +from .variables import NDVarArray, cpm_array from .core import Expression, Operator -from .globalfunctions import Minimum, Maximum +from .globalfunctions import Minimum, Maximum, Abs +from ..exceptions import CPMpyException # Overwriting all/any python built-ins @@ -125,3 +127,22 @@ def sum(*iterable, **kwargs): assert len(kwargs)==0, "sum over decision variables does not support keyword arguments" return Operator("sum", iterable) + + +def abs(element): + """ + abs() overwrites the python built-in to support decision variables. + + if the element given is not a CPMpy expression, the built-in is called + else an Absolute functional global constraint is constructed. + """ + if is_any_list(element): # compat: not allowed by builtins.abs(), but allowed by numpy.abs() + return cpm_array([abs(elem) for elem in element]) + + if isinstance(element, Expression): + # create global + return Abs(element) + + return builtins.abs(element) + + diff --git a/tests/test_builtins.py b/tests/test_builtins.py new file mode 100644 index 000000000..5c95b13aa --- /dev/null +++ b/tests/test_builtins.py @@ -0,0 +1,47 @@ +import unittest + +import cpmpy as cp +from cpmpy.exceptions import CPMpyException + +iv = cp.intvar(-8, 8, shape=5) + + +class TestBuiltin(unittest.TestCase): + + def test_max(self): + constraints = [cp.max(iv) + 9 <= 8] + model = cp.Model(constraints) + self.assertTrue(model.solve()) + self.assertTrue(cp.max(iv.value()) <= -1) + + model = cp.Model(cp.max(iv).decompose_comparison('!=', 4)) + self.assertTrue(model.solve()) + self.assertNotEqual(str(cp.max(iv.value())), '4') + + def test_min(self): + constraints = [cp.min(iv) + 9 == 8] + model = cp.Model(constraints) + self.assertTrue(model.solve()) + self.assertEqual(str(cp.min(iv.value())), '-1') + + model = cp.Model(cp.min(iv).decompose_comparison('==', 4)) + self.assertTrue(model.solve()) + self.assertEqual(str(cp.min(iv.value())), '4') + + def test_abs(self): + constraints = [cp.abs(iv[0]) + 9 <= 8] + model = cp.Model(constraints) + self.assertFalse(model.solve()) + + #with list + constraints = [cp.abs(iv+2) <= 8, iv < 0] + model = cp.Model(constraints) + self.assertTrue(model.solve()) + + constraints = [cp.abs([iv[0], iv[2], iv[1], -8]) <= 8, iv < 0] + model = cp.Model(constraints) + self.assertTrue(model.solve()) + + model = cp.Model(cp.abs(iv[0]).decompose_comparison('!=', 4)) + self.assertTrue(model.solve()) + self.assertNotEqual(str(cp.abs(iv[0].value())), '4') \ No newline at end of file