diff --git a/numpy_groupies/aggregate_numpy.py b/numpy_groupies/aggregate_numpy.py index 9dd82a8..0299312 100644 --- a/numpy_groupies/aggregate_numpy.py +++ b/numpy_groupies/aggregate_numpy.py @@ -272,6 +272,7 @@ def _nancumsum(group_idx, a, size, fill_value=None, dtype=None): generic=_generic_callable, ) _impl_dict.update(("nan" + k, v) for k, v in list(_impl_dict.items()) if k not in funcs_no_separate_nan) +_impl_dict["nancumsum"] = _nancumsum def _aggregate_base( @@ -308,6 +309,8 @@ def _aggregate_base( if "nan" in func: if "arg" in func: kwargs["_nansqueeze"] = True + elif "cum" in func: + pass else: good = ~np.isnan(a) if "len" not in func or is_pandas: diff --git a/numpy_groupies/tests/test_compare.py b/numpy_groupies/tests/test_compare.py index 4261541..6e45d63 100644 --- a/numpy_groupies/tests/test_compare.py +++ b/numpy_groupies/tests/test_compare.py @@ -4,7 +4,6 @@ may throw NotImplementedError in order to show missing functionality without throwing test errors. """ -import sys from itertools import product import numpy as np diff --git a/numpy_groupies/tests/test_generic.py b/numpy_groupies/tests/test_generic.py index 9cf99fa..a6eac40 100644 --- a/numpy_groupies/tests/test_generic.py +++ b/numpy_groupies/tests/test_generic.py @@ -24,6 +24,12 @@ def _deselect_purepy(aggregate_all, *args, **kwargs): return aggregate_all.__name__.endswith("purepy") +def _deselect_purepy_and_pandas(aggregate_all, *args, **kwargs): + # purepy and pandas implementation handle some nan cases differently. + # So they need to be excluded from several tests.""" + return aggregate_all.__name__.endswith(("pandas", "purepy")) + + def _deselect_purepy_and_invalid_axis(aggregate_all, size, axis, *args, **kwargs): if axis >= len(size): return True @@ -358,6 +364,17 @@ def test_cumsum(aggregate_all): np.testing.assert_array_equal(res, ref) +@pytest.mark.deselect_if(func=_deselect_purepy_and_pandas) +def test_nancumsum(aggregate_all): + # https://github.com/ml31415/numpy-groupies/issues/79 + group_idx = [0, 0, 0, 1, 1, 0, 0] + a = [2, 2, np.nan, 2, 2, 2, 2] + ref = [2., 4., 4., 2., 4., 6., 8.] + + res = aggregate_all(group_idx, a, func="nancumsum") + np.testing.assert_array_equal(res, ref) + + def test_cummax(aggregate_all): group_idx = np.array([4, 3, 3, 4, 4, 1, 1, 1, 7, 8, 7, 4, 3, 3, 1, 1]) a = np.array([3, 4, 1, 3, 9, 9, 6, 7, 7, 0, 8, 2, 1, 8, 9, 8]) diff --git a/numpy_groupies/utils.py b/numpy_groupies/utils.py index e2af7a3..56befb5 100644 --- a/numpy_groupies/utils.py +++ b/numpy_groupies/utils.py @@ -117,6 +117,8 @@ np.array: "array", np.asarray: "array", np.sort: "sort", + np.cumsum: "cumsum", + np.cumprod: "cumprod", np.nansum: "nansum", np.nanprod: "nanprod", np.nanmean: "nanmean", @@ -126,8 +128,7 @@ np.nanstd: "nanstd", np.nanargmax: "nanargmax", np.nanargmin: "nanargmin", - np.cumsum: "cumsum", - np.cumprod: "cumprod", + np.nancumsum: "nancumsum", } @@ -150,7 +151,7 @@ def get_aliasing(*extra): alias.update((k, k) for k in set(alias.values())) # Treat nan-functions as firstclass member and add them directly for key in set(alias.values()): - if key not in funcs_no_separate_nan: + if key not in funcs_no_separate_nan and not key.startswith("nan"): key = "nan" + key alias[key] = key return alias