Skip to content

Commit

Permalink
Fix as_subchunks and num_subchunks with array indices
Browse files Browse the repository at this point in the history
The logic was not correct for multiple array indices (or multidimensional
boolean array indices).
  • Loading branch information
asmeurer committed Feb 15, 2024
1 parent da64742 commit e7a35fd
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 6 deletions.
56 changes: 50 additions & 6 deletions ndindex/chunking.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections.abc import Sequence
from itertools import product
from itertools import chain, product

from .ndindex import ImmutableObject, operator_index, ndindex
from .tuple import Tuple
Expand Down Expand Up @@ -240,8 +240,29 @@ def _fallback():
if isinstance(i, Integer):
iters.append([i.raw//n])
elif isinstance(i, IntegerArray):
from numpy import unique
iters.append(unique(i.array//n).flat)
# All arrays will be together after calling expand() (Tuple does not support arrays
# separated by non-integer indices). Collect them all together
# at once.
arrs = []
while True:
try:
if isinstance(i, IntegerArray):
arrs.append(i.raw.flatten()//n)
else:
idx_args = chain([i], idx_args)
self_ = chain([n], self_)
break
i = next(idx_args)
n = next(self_)
except StopIteration:
break

import numpy as np
a = np.unique(np.stack(arrs), axis=-1)
def _array_iter(a):
for i in range(a.shape[-1]):
yield tuple(a[..., i].flat)
iters.append(_array_iter(a))
elif isinstance(i, Slice) and i.step > 0:
def _slice_iter(s, n):
a, N, m = s.args
Expand All @@ -257,8 +278,16 @@ def _slice_iter(s, n):
yield from _fallback()
return # pragma: no cover

def _flatten(l):
for element in l:
if isinstance(element, tuple):
yield from element
else:
yield element

def _indices(iters):
for p in product(*iters):
for _p in product(*iters):
p = _flatten(_p)
# p = (0, 0, 0), (0, 0, 1), ...
yield Tuple(*[Slice(chunk_size*i, min(chunk_size*(i + 1), n), 1)
for n, chunk_size, i in zip(shape, self, p)])
Expand Down Expand Up @@ -318,8 +347,23 @@ def num_subchunks(self, idx, shape):
if isinstance(i, Integer):
continue
elif isinstance(i, IntegerArray):
from numpy import unique
res *= unique(i.array//n).size
arrs = []
# see as_subchunks
while True:
try:
if isinstance(i, IntegerArray):
arrs.append(i.raw.flatten()//n)
else:
idx_args = chain([i], idx_args)
self_ = chain([n], self_)
break
i = next(idx_args)
n = next(self_)
except StopIteration:
break

import numpy as np
res *= np.unique(np.stack(arrs), axis=-1).shape[-1]
elif isinstance(i, Slice):
if i.step < 0:
raise NotImplementedError("num_subchunks() is not implemented for slices with negative step")
Expand Down
4 changes: 4 additions & 0 deletions ndindex/tests/test_chunking.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ def test_indices(chunk_size, shape):
elements = [i for x in subarrays for i in x.flatten()]
assert sorted(elements) == list(range(size))

@example(chunk_size=(1, 1), idx=[[False, True], [True, True]],
shape=(2, 2))
@example(chunk_size=(1,), idx=slice(None, None, -1), shape=(2,))
@example((1,), True, (1,))
@example(chunk_size=(1, 1), idx=slice(1, None, 2), shape=(4, 1))
Expand Down Expand Up @@ -172,6 +174,8 @@ def test_as_subchunks(chunk_size, idx, shape):
def test_as_subchunks_error():
raises(ValueError, lambda: next(ChunkSize((1, 2)).as_subchunks(..., (1, 2, 3))))

@example(chunk_size=(1, 1), idx=[[False, True], [True, True]],
shape=(2, 2))
@example(chunk_size=(1,), idx=None, shape=(1,))
@example((1,), True, (1,))
@example(chunk_size=(1, 1), idx=slice(1, None, 2), shape=(4, 1))
Expand Down

0 comments on commit e7a35fd

Please sign in to comment.