Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

avoid numpy sum where we are at risk of overflow #390

Merged
merged 6 commits into from
Sep 8, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 21 additions & 12 deletions cpmpy/expressions/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,26 +554,31 @@ def get_bounds(self):
lb1, ub1 = get_bounds(self.args[0])
lb2, ub2 = get_bounds(self.args[1])
bounds = [lb1 * lb2, lb1 * ub2, ub1 * lb2, ub1 * ub2]
return min(bounds), max(bounds)
lowerbound, upperbound = min(bounds), max(bounds)
elif self.name == 'sum':
lbs, ubs = zip(*[get_bounds(x) for x in self.args])
return sum(lbs), sum(ubs)
lowerbound, upperbound = sum(lbs), sum(ubs)
elif self.name == 'wsum':
weights, vars = self.args
var_bounds = np.array([get_bounds(arg) for arg in vars]).T
bounds = var_bounds * weights
return bounds.min(axis=0).sum(), bounds.max(axis=0).sum() # for every column is axis=0...
bounds = []
#this may seem like too many lines, but avoiding np.sum avoids overflowing things at int32 bounds
for i, varbounds in enumerate([get_bounds(arg) for arg in vars]):
sortbounds = (list(weights[i] * x for x in varbounds))
sortbounds.sort()
bounds += [sortbounds]
lbs, ubs = (zip(*bounds))
lowerbound, upperbound = sum(lbs), sum(ubs) #this is builtins sum, not numpy sum
elif self.name == 'sub':
lb1, ub1 = get_bounds(self.args[0])
lb2, ub2 = get_bounds(self.args[1])
return lb1-ub2, ub1-lb2
lowerbound, upperbound = lb1-ub2, ub1-lb2
elif self.name == 'div':
lb1, ub1 = get_bounds(self.args[0])
lb2, ub2 = get_bounds(self.args[1])
if lb2 <= 0 <= ub2:
raise ZeroDivisionError("division by domain containing 0 is not supported")
bounds = [lb1 // lb2, lb1 // ub2, ub1 // lb2, ub1 // ub2]
return min(bounds), max(bounds)
lowerbound, upperbound = min(bounds), max(bounds)
elif self.name == 'mod':
lb1, ub1 = get_bounds(self.args[0])
lb2, ub2 = get_bounds(self.args[1])
Expand All @@ -595,14 +600,18 @@ def get_bounds(self):
# E.g., (-2)^2 is positive, but (-2)^1 is negative, so for (-2)^[0,2] we also need to add (-2)^1.
bounds += [lb1 ** (ub2 - 1), ub1 ** (ub2 - 1)]
# This approach is safe but not tight (e.g., [-2,-1]^2 will give (-2,4) as range instead of [1,4]).
return min(bounds), max(bounds)
lowerbound, upperbound = min(bounds), max(bounds)

elif self.name == '-':
lb1, ub1 = get_bounds(self.args[0])
return -ub1, -lb1

raise ValueError(f"Bound requested for unknown expression {self}, please report bug on github")

lowerbound, upperbound = -ub1, -lb1

if lowerbound == None:
raise ValueError(f"Bound requested for unknown expression {self}, please report bug on github")
if lowerbound > upperbound:
#overflow happened
raise OverflowError('Overflow when calculating bounds, your expression exceeds integer bounds.')
return lowerbound, upperbound
def _wsum_should(arg):
""" Internal helper: should the arg be in a wsum instead of sum

Expand Down