Skip to content

Commit

Permalink
Ensure value of expressions is Python int (#517)
Browse files Browse the repository at this point in the history
* add tests for dtype

* fix objective value in OR-Tools

* fix objective value in Gurobi

* ensure result of wsum is int

* fix __pow__
  • Loading branch information
IgnaceBleukx authored Oct 3, 2024
1 parent aff0011 commit f343c19
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 7 deletions.
11 changes: 6 additions & 5 deletions cpmpy/expressions/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,10 +329,11 @@ def __rmod__(self, other):

def __pow__(self, other, modulo=None):
assert (modulo is None), "Power operator: modulo not supported"
if other == 0:
return 1
elif other == 1:
return self
if is_num(other):
if other == 0:
return 1
if other == 1:
return self
return Operator("pow", [self, other])

def __rpow__(self, other, modulo=None):
Expand Down Expand Up @@ -541,7 +542,7 @@ def value(self):
if any(a is None for a in arg_vals): return None
# non-boolean
elif self.name == "sum": return sum(arg_vals)
elif self.name == "wsum": return sum(arg_vals[0]*np.array(arg_vals[1]))
elif self.name == "wsum": return int(sum(arg_vals[0]*np.array(arg_vals[1])))
elif self.name == "mul": return arg_vals[0] * arg_vals[1]
elif self.name == "sub": return arg_vals[0] - arg_vals[1]
elif self.name == "mod": return arg_vals[0] % arg_vals[1]
Expand Down
6 changes: 5 additions & 1 deletion cpmpy/solvers/gurobi.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,11 @@ def solve(self, time_limit=None, solution_callback=None, **kwargs):
cpm_var._value = int(solver_val)
# set _objective_value
if self.has_objective():
self.objective_value_ = grb_objective.getValue()
grb_obj_val = grb_objective.getValue()
if grb_obj_val != int(grb_obj_val):
self.objective_value_ = grb_obj_val # can happen with DirectVar using floats
else:
self.objective_value_ = int(grb_obj_val)

return has_sol

Expand Down
4 changes: 3 additions & 1 deletion cpmpy/solvers/ortools.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,9 @@ def solve(self, time_limit=None, assumptions=None, solution_callback=None, **kwa

# translate objective
if self.has_objective():
self.objective_value_ = self.ort_solver.ObjectiveValue()
ort_obj_val = self.ort_solver.ObjectiveValue()
assert int(ort_obj_val) == ort_obj_val, "Objective value should be integer, please report on github"
self.objective_value_ = int(ort_obj_val) # ensure it is an integer

return has_sol

Expand Down
29 changes: 29 additions & 0 deletions tests/test_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,35 @@ def test_description(self):
self.assertEqual(str(cons), "either a or b should be true, but not both -- (a) or (b)")


def test_dtype(self):

x = cp.intvar(1,10,shape=(3,3), name="x")
self.assertTrue(cp.Model(cp.sum(x) >= 10).solve())
self.assertIsNotNone(x.value())
print(x.value())
# test all types of expressions
self.assertEqual(int, type(x[0,0].value())) # just the var
for v in x[0]:
self.assertEqual(int, type(v.value())) # array of var
self.assertEqual(int, type(cp.sum(x[0]).value()))
self.assertEqual(int, type(cp.sum(x).value()))
self.assertEqual(int, type(cp.sum([1,2,3] * x[0]).value()))
self.assertEqual(int, type(cp.sum(np.array([1, 2, 3]) * x[0]).value()))
a,b = x[0,[0,1]]
self.assertEqual(int, type((-a).value()))
self.assertEqual(int, type((a - b).value()))
self.assertEqual(int, type((a * b).value()))
self.assertEqual(int, type((a // b).value()))
self.assertEqual(int, type((a ** b).value()))
self.assertEqual(int, type((a % b).value()))








class TestBuildIns(unittest.TestCase):

def setUp(self):
Expand Down

0 comments on commit f343c19

Please sign in to comment.