Skip to content

Commit

Permalink
avoid numpy sum where we are at risk of overflow (#390)
Browse files Browse the repository at this point in the history
* avoid numpy sum where we are at risk of overflow
check for overflow in get_bounds
  • Loading branch information
Wout4 authored Sep 8, 2023
1 parent e759044 commit 0e2e160
Showing 1 changed file with 21 additions and 12 deletions.
33 changes: 21 additions & 12 deletions cpmpy/expressions/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,26 +554,31 @@ def get_bounds(self):
lb1, ub1 = get_bounds(self.args[0])
lb2, ub2 = get_bounds(self.args[1])
bounds = [lb1 * lb2, lb1 * ub2, ub1 * lb2, ub1 * ub2]
return min(bounds), max(bounds)
lowerbound, upperbound = min(bounds), max(bounds)
elif self.name == 'sum':
lbs, ubs = zip(*[get_bounds(x) for x in self.args])
return sum(lbs), sum(ubs)
lowerbound, upperbound = sum(lbs), sum(ubs)
elif self.name == 'wsum':
weights, vars = self.args
var_bounds = np.array([get_bounds(arg) for arg in vars]).T
bounds = var_bounds * weights
return bounds.min(axis=0).sum(), bounds.max(axis=0).sum() # for every column is axis=0...
bounds = []
#this may seem like too many lines, but avoiding np.sum avoids overflowing things at int32 bounds
for i, varbounds in enumerate([get_bounds(arg) for arg in vars]):
sortbounds = (list(weights[i] * x for x in varbounds))
sortbounds.sort()
bounds += [sortbounds]
lbs, ubs = (zip(*bounds))
lowerbound, upperbound = sum(lbs), sum(ubs) #this is builtins sum, not numpy sum
elif self.name == 'sub':
lb1, ub1 = get_bounds(self.args[0])
lb2, ub2 = get_bounds(self.args[1])
return lb1-ub2, ub1-lb2
lowerbound, upperbound = lb1-ub2, ub1-lb2
elif self.name == 'div':
lb1, ub1 = get_bounds(self.args[0])
lb2, ub2 = get_bounds(self.args[1])
if lb2 <= 0 <= ub2:
raise ZeroDivisionError("division by domain containing 0 is not supported")
bounds = [lb1 // lb2, lb1 // ub2, ub1 // lb2, ub1 // ub2]
return min(bounds), max(bounds)
lowerbound, upperbound = min(bounds), max(bounds)
elif self.name == 'mod':
lb1, ub1 = get_bounds(self.args[0])
lb2, ub2 = get_bounds(self.args[1])
Expand All @@ -595,14 +600,18 @@ def get_bounds(self):
# E.g., (-2)^2 is positive, but (-2)^1 is negative, so for (-2)^[0,2] we also need to add (-2)^1.
bounds += [lb1 ** (ub2 - 1), ub1 ** (ub2 - 1)]
# This approach is safe but not tight (e.g., [-2,-1]^2 will give (-2,4) as range instead of [1,4]).
return min(bounds), max(bounds)
lowerbound, upperbound = min(bounds), max(bounds)

elif self.name == '-':
lb1, ub1 = get_bounds(self.args[0])
return -ub1, -lb1

raise ValueError(f"Bound requested for unknown expression {self}, please report bug on github")

lowerbound, upperbound = -ub1, -lb1

if lowerbound == None:
raise ValueError(f"Bound requested for unknown expression {self}, please report bug on github")
if lowerbound > upperbound:
#overflow happened
raise OverflowError('Overflow when calculating bounds, your expression exceeds integer bounds.')
return lowerbound, upperbound
def _wsum_should(arg):
""" Internal helper: should the arg be in a wsum instead of sum
Expand Down

0 comments on commit 0e2e160

Please sign in to comment.