Skip to content

Commit

Permalink
feat: support variable rebinning (#913)
Browse files Browse the repository at this point in the history
* feat: support full UHI for rebinning

* One type of axis at time

* style: pre-commit fixes

* Revert accidental deletion

* Better code quality

* fix: partial fix

Signed-off-by: Henry Schreiner <henryschreineriii@gmail.com>

* Refactor and move the rebinning logic in the loop

* Fix repr test

* Update src/boost_histogram/tag.py

* Fix rebinning logic

* Add logic for updating bin contents

* make it work for nd hists

* fix: support callable, add validation

Signed-off-by: Henry Schreiner <henryschreineriii@gmail.com>

* fix: the result of group_mapping() should be checked for None

Signed-off-by: Henry Schreiner <henryschreineriii@gmail.com>

---------

Signed-off-by: Henry Schreiner <henryschreineriii@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Henry Schreiner <henryschreineriii@gmail.com>
  • Loading branch information
3 people authored Aug 23, 2024
1 parent 5d04fa0 commit 92df5a6
Show file tree
Hide file tree
Showing 5 changed files with 153 additions and 24 deletions.
5 changes: 1 addition & 4 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
[submodule "pybind11"]
path = extern/pybind11
url = ../../pybind/pybind11.git
[submodule "extern/boosthistogram"]
[submodule "extern/histogram"]
path = extern/histogram
url = ../../boostorg/histogram.git
[submodule "extern/core"]
Expand Down
61 changes: 53 additions & 8 deletions src/boost_histogram/_internal/hist.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from boost_histogram import _core

from .axestuple import AxesTuple
from .axis import Axis
from .axis import Axis, Variable
from .enum import Kind
from .storage import Double, Storage
from .typing import Accumulator, ArrayLike, CppHistogram
Expand Down Expand Up @@ -827,6 +827,7 @@ def __getitem__(self: H, index: IndexingExpr) -> H | float | Accumulator:
slices: list[_core.algorithm.reduce_command] = []
pick_each: dict[int, int] = {}
pick_set: dict[int, list[int]] = {}
reduced: CppHistogram | None = None

# Compute needed slices and projections
for i, ind in enumerate(indexes):
Expand Down Expand Up @@ -855,38 +856,82 @@ def __getitem__(self: H, index: IndexingExpr) -> H | float | Accumulator:
# This ensures that callable start/stop are handled
start, stop = self.axes[i]._process_loc(ind.start, ind.stop)

groups = []
if ind != slice(None):
merge = 1
if ind.step is not None:
if hasattr(ind.step, "factor"):
if getattr(ind.step, "factor", None) is not None:
merge = ind.step.factor
elif (
hasattr(ind.step, "group_mapping")
and (tmp_groups := ind.step.group_mapping(self.axes[i]))
is not None
):
groups = tmp_groups
elif callable(ind.step):
if ind.step is sum:
integrations.add(i)
else:
raise RuntimeError("Full UHI not supported yet")
raise NotImplementedError

if ind.start is not None or ind.stop is not None:
slices.append(
_core.algorithm.slice(
i, start, stop, _core.algorithm.slice_mode.crop
)
)
continue
if len(groups) == 0:
continue
else:
raise IndexError(
"The third argument to a slice must be rebin or projection"
)

assert isinstance(start, int)
assert isinstance(stop, int)
slices.append(_core.algorithm.slice_and_rebin(i, start, stop, merge))
# rebinning with factor
if len(groups) == 0:
slices.append(
_core.algorithm.slice_and_rebin(i, start, stop, merge)
)
# rebinning with groups
elif len(groups) != 0:
if not reduced:
reduced = self._hist
axes = [reduced.axis(x) for x in range(reduced.rank())]
reduced_view = reduced.view(flow=True)
new_axes_indices = [axes[i].edges[0]]

j = 0
for group in groups:
new_axes_indices += [axes[i].edges[j + group]]
j += group

variable_axis = Variable(
new_axes_indices, metadata=axes[i].metadata
)
axes[i] = variable_axis._ax

logger.debug("Axes: %s", axes)

new_reduced = reduced.__class__(axes)
new_view = new_reduced.view(flow=True)

j = 1
for new_j, group in enumerate(groups):
for _ in range(group):
pos = [slice(None)] * (i)
new_view[(*pos, new_j + 1, ...)] += reduced_view[ # type: ignore[arg-type]
(*pos, j, ...) # type: ignore[arg-type]
]
j += 1

reduced = new_reduced

# Will be updated below
if slices or pick_set or pick_each or integrations:
if (slices or pick_set or pick_each or integrations) and not reduced:
reduced = self._hist
else:
logger.debug("Reduce actions are all empty, just making a copy")
elif not reduced:
reduced = copy.copy(self._hist)

if pick_each:
Expand Down
47 changes: 39 additions & 8 deletions src/boost_histogram/tag.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@

import copy
from builtins import sum
from typing import TypeVar
from typing import TYPE_CHECKING, Sequence, TypeVar

if TYPE_CHECKING:
from uhi.typing.plottable import PlottableAxis

from ._internal.typing import AxisLike

Expand Down Expand Up @@ -108,12 +111,40 @@ def __call__(self, axis: AxisLike) -> int: # noqa: ARG002


class rebin:
__slots__ = ("factor",)

def __init__(self, value: int) -> None:
self.factor = value
__slots__ = (
"factor",
"groups",
)

def __init__(
self,
factor: int | None = None,
*,
groups: Sequence[int] | None = None,
) -> None:
if not sum(i is None for i in [factor, groups]) == 1:
raise ValueError("Exactly one, a factor or groups should be provided")
self.factor = factor
self.groups = groups

def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.factor})"

# TODO: Add __call__ to support UHI
repr_str = f"{self.__class__.__name__}"
args: dict[str, int | Sequence[int] | None] = {
"factor": self.factor,
"groups": self.groups,
}
for k, v in args.items():
if v is not None:
return_str = f"{repr_str}({k}={v})"
break
return return_str

def group_mapping(self, axis: PlottableAxis) -> Sequence[int]:
if self.groups is not None:
if sum(self.groups) != len(axis):
msg = f"The sum of the groups ({sum(self.groups)}) must be equal to the number of bins in the axis ({len(axis)})"
raise ValueError(msg)
return self.groups
if self.factor is not None:
return [self.factor] * len(axis)
raise ValueError("No rebinning factor or groups provided")
62 changes: 59 additions & 3 deletions tests/test_histogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,13 +632,17 @@ def test_shrink_1d():

def test_rebin_1d():
h = bh.Histogram(bh.axis.Regular(20, 1, 5))
h.fill(1.1)
h.fill([1.1, 2.2, 3.3, 4.4])

hs = h[{0: slice(None, None, bh.rebin(4))}]
assert_array_equal(hs.view(), [1, 0, 0, 0, 0])
assert_array_equal(hs.view(), [1, 1, 1, 0, 1])

hs = h[{0: bh.rebin(4)}]
assert_array_equal(hs.view(), [1, 0, 0, 0, 0])
assert_array_equal(hs.view(), [1, 1, 1, 0, 1])

hs = h[{0: bh.rebin(groups=[1, 2, 3, 14])}]
assert_array_equal(hs.view(), [1, 0, 0, 3])
assert_array_equal(hs.axes.edges[0], [1.0, 1.2, 1.6, 2.2, 5.0])


def test_shrink_rebin_1d():
Expand All @@ -659,8 +663,60 @@ def test_rebin_nd():
assert h[{1: s[:: bh.rebin(2)]}].axes.size == (20, 15, 40)
assert h[{2: s[:: bh.rebin(2)]}].axes.size == (20, 30, 20)

assert h[{0: s[:: bh.rebin(groups=[1, 2, 17])]}].axes.size == (3, 30, 40)
assert h[{1: s[:: bh.rebin(groups=[1, 2, 27])]}].axes.size == (20, 3, 40)
assert h[{2: s[:: bh.rebin(groups=[1, 2, 37])]}].axes.size == (20, 30, 3)
assert np.all(
np.isclose(
h[{0: s[:: bh.rebin(groups=[1, 2, 17])]}].axes[0].edges,
[1.0, 1.1, 1.3, 3.0],
)
)
assert np.all(
np.isclose(
h[{1: s[:: bh.rebin(groups=[1, 2, 27])]}].axes[1].edges,
[1.0, 1.06666667, 1.2, 3.0],
)
)
assert np.all(
np.isclose(
h[{2: s[:: bh.rebin(groups=[1, 2, 37])]}].axes[2].edges,
[1.0, 1.05, 1.15, 3.0],
)
)

assert h[{0: s[:: bh.rebin(2)], 2: s[:: bh.rebin(2)]}].axes.size == (10, 30, 20)

assert h[
{0: s[:: bh.rebin(groups=[1, 2, 17])], 2: s[:: bh.rebin(groups=[1, 2, 37])]}
].axes.size == (3, 30, 3)
assert np.all(
np.isclose(
h[
{
0: s[:: bh.rebin(groups=[1, 2, 17])],
2: s[:: bh.rebin(groups=[1, 2, 37])],
}
]
.axes[0]
.edges,
[1.0, 1.1, 1.3, 3],
)
)
assert np.all(
np.isclose(
h[
{
0: s[:: bh.rebin(groups=[1, 2, 17])],
2: s[:: bh.rebin(groups=[1, 2, 37])],
}
]
.axes[2]
.edges,
[1.0, 1.05, 1.15, 3.0],
)
)

assert h[{1: s[:: bh.sum]}].axes.size == (20, 40)
assert h[{1: bh.sum}].axes.size == (20, 40)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_histogram_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ def test_repr():
assert repr(bh.overflow + 1) == "overflow + 1"
assert repr(bh.overflow - 1) == "overflow - 1"

assert repr(bh.rebin(2)) == "rebin(2)"
assert repr(bh.rebin(2)) == "rebin(factor=2)"


# Was broken in 0.6.1
Expand Down

0 comments on commit 92df5a6

Please sign in to comment.