Skip to content

Commit

Permalink
Improve precision for mean, std, var, cumsum. (#90)
Browse files Browse the repository at this point in the history
* Improve precision for mean, std, var.

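np.bincount always accumulates in float64, regardless of the requested dtype.
So cast the result to the requested dtype only after the division, to avoid losing precision early.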
  • Loading branch information
dcherian authored Jul 29, 2024
1 parent 12405c2 commit 6c499ff
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 9 deletions.
28 changes: 19 additions & 9 deletions numpy_groupies/aggregate_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,23 +149,27 @@ def _mean(group_idx, a, size, fill_value, dtype=np.dtype(np.float64)):
sums.real = np.bincount(group_idx, weights=a.real, minlength=size)
sums.imag = np.bincount(group_idx, weights=a.imag, minlength=size)
else:
sums = np.bincount(group_idx, weights=a, minlength=size).astype(
dtype, copy=False
)
sums = np.bincount(group_idx, weights=a, minlength=size)

with np.errstate(divide="ignore", invalid="ignore"):
ret = sums.astype(dtype, copy=False) / counts
ret = sums / counts
if not np.isnan(fill_value):
ret[counts == 0] = fill_value
return ret
if iscomplexobj(a):
return ret
else:
return ret.astype(dtype, copy=False)


def _sum_of_squres(group_idx, a, size, fill_value, dtype=np.dtype(np.float64)):
ret = np.bincount(group_idx, weights=a * a, minlength=size)
if fill_value != 0:
counts = np.bincount(group_idx, minlength=size)
ret[counts == 0] = fill_value
return ret
if iscomplexobj(a):
return ret
else:
return ret.astype(dtype, copy=False)


def _var(
Expand All @@ -176,7 +180,7 @@ def _var(
counts = np.bincount(group_idx, minlength=size)
sums = np.bincount(group_idx, weights=a, minlength=size)
with np.errstate(divide="ignore", invalid="ignore"):
means = sums.astype(dtype, copy=False) / counts
means = sums / counts
counts = np.where(counts > ddof, counts - ddof, 0)
ret = (
np.bincount(group_idx, (a - means[group_idx]) ** 2, minlength=size) / counts
Expand All @@ -185,7 +189,10 @@ def _var(
ret = np.sqrt(ret) # this is now std not var
if not np.isnan(fill_value):
ret[counts == 0] = fill_value
return ret
if iscomplexobj(a):
return ret
else:
return ret.astype(dtype, copy=False)


def _std(group_idx, a, size, fill_value, dtype=np.dtype(np.float64), ddof=0):
Expand Down Expand Up @@ -252,7 +259,10 @@ def _cumsum(group_idx, a, size, fill_value=None, dtype=None):

increasing = np.arange(len(a), dtype=int)
group_starts = _min(group_idx_srt, increasing, size, fill_value=0)[group_idx_srt]
a_srt_cumsum += -a_srt_cumsum[group_starts] + a_srt[group_starts]
# First subtract large numbers
a_srt_cumsum -= a_srt_cumsum[group_starts]
# Then add potentially small numbers
a_srt_cumsum += a_srt[group_starts]
return a_srt_cumsum[invsortidx]


Expand Down
11 changes: 11 additions & 0 deletions numpy_groupies/tests/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,3 +570,14 @@ def test_var_with_nan_fill_value(aggregate_all, ddof, nan_inds, func):
group_idx, a, axis=-1, fill_value=np.nan, func=func, ddof=ddof
)
np.testing.assert_equal(actual, expected)


def test_cumsum_accuracy(aggregate_all):
array = np.array(
[0.00000000e00, 0.00000000e00, 0.00000000e00, 3.27680000e04, 9.99999975e-06]
)
group_idx = np.array([0, 0, 0, 0, 1])

actual = aggregate_all(group_idx, array, axis=-1, func="cumsum")
expected = array
np.testing.assert_allclose(actual, expected)

0 comments on commit 6c499ff

Please sign in to comment.