From f343c19073365d7482685092b61ce24e58a1088a Mon Sep 17 00:00:00 2001 From: Ignace Bleukx Date: Thu, 3 Oct 2024 11:31:37 +0200 Subject: [PATCH] Ensure value of expressions is Python int (#517) * add tests for dtype * fix objective value in OR-Tools * fix objective value in Gurobi * ensure result of wsum is int * fix __pow__ --- cpmpy/expressions/core.py | 11 ++++++----- cpmpy/solvers/gurobi.py | 6 +++++- cpmpy/solvers/ortools.py | 4 +++- tests/test_expressions.py | 29 +++++++++++++++++++++++++++++ 4 files changed, 43 insertions(+), 7 deletions(-) diff --git a/cpmpy/expressions/core.py b/cpmpy/expressions/core.py index 47d03be41..2cbdecf75 100644 --- a/cpmpy/expressions/core.py +++ b/cpmpy/expressions/core.py @@ -329,10 +329,11 @@ def __rmod__(self, other): def __pow__(self, other, modulo=None): assert (modulo is None), "Power operator: modulo not supported" - if other == 0: - return 1 - elif other == 1: - return self + if is_num(other): + if other == 0: + return 1 + if other == 1: + return self return Operator("pow", [self, other]) def __rpow__(self, other, modulo=None): @@ -541,7 +542,7 @@ def value(self): if any(a is None for a in arg_vals): return None # non-boolean elif self.name == "sum": return sum(arg_vals) - elif self.name == "wsum": return sum(arg_vals[0]*np.array(arg_vals[1])) + elif self.name == "wsum": return int(sum(arg_vals[0]*np.array(arg_vals[1]))) elif self.name == "mul": return arg_vals[0] * arg_vals[1] elif self.name == "sub": return arg_vals[0] - arg_vals[1] elif self.name == "mod": return arg_vals[0] % arg_vals[1] diff --git a/cpmpy/solvers/gurobi.py b/cpmpy/solvers/gurobi.py index d58e93a8c..fd560fe92 100644 --- a/cpmpy/solvers/gurobi.py +++ b/cpmpy/solvers/gurobi.py @@ -180,7 +180,11 @@ def solve(self, time_limit=None, solution_callback=None, **kwargs): cpm_var._value = int(solver_val) # set _objective_value if self.has_objective(): - self.objective_value_ = grb_objective.getValue() + grb_obj_val = grb_objective.getValue() + if grb_obj_val != int(grb_obj_val): + self.objective_value_ = grb_obj_val # can happen with DirectVar using floats + else: + self.objective_value_ = int(grb_obj_val) return has_sol diff --git a/cpmpy/solvers/ortools.py b/cpmpy/solvers/ortools.py index 3805bae3c..a03bebe8c 100644 --- a/cpmpy/solvers/ortools.py +++ b/cpmpy/solvers/ortools.py @@ -217,7 +217,9 @@ def solve(self, time_limit=None, assumptions=None, solution_callback=None, **kwa # translate objective if self.has_objective(): - self.objective_value_ = self.ort_solver.ObjectiveValue() + ort_obj_val = self.ort_solver.ObjectiveValue() + assert int(ort_obj_val) == ort_obj_val, "Objective value should be integer, please report on github" + self.objective_value_ = int(ort_obj_val) # ensure it is an integer return has_sol diff --git a/tests/test_expressions.py b/tests/test_expressions.py index 75a5bb1ba..f1d9d561f 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -531,6 +531,35 @@ def test_description(self): self.assertEqual(str(cons), "either a or b should be true, but not both -- (a) or (b)") + def test_dtype(self): + + x = cp.intvar(1,10,shape=(3,3), name="x") + self.assertTrue(cp.Model(cp.sum(x) >= 10).solve()) + self.assertIsNotNone(x.value()) + print(x.value()) + # test all types of expressions + self.assertEqual(int, type(x[0,0].value())) # just the var + for v in x[0]: + self.assertEqual(int, type(v.value())) # array of var + self.assertEqual(int, type(cp.sum(x[0]).value())) + self.assertEqual(int, type(cp.sum(x).value())) + self.assertEqual(int, type(cp.sum([1,2,3] * x[0]).value())) + self.assertEqual(int, type(cp.sum(np.array([1, 2, 3]) * x[0]).value())) + a,b = x[0,[0,1]] + self.assertEqual(int, type((-a).value())) + self.assertEqual(int, type((a - b).value())) + self.assertEqual(int, type((a * b).value())) + self.assertEqual(int, type((a // b).value())) + self.assertEqual(int, type((a ** b).value())) + self.assertEqual(int, type((a % b).value())) + + + + + + + + class TestBuildIns(unittest.TestCase): def setUp(self):