diff --git a/ndindex/chunking.py b/ndindex/chunking.py index 5e3e9517..b1570a2b 100644 --- a/ndindex/chunking.py +++ b/ndindex/chunking.py @@ -1,5 +1,5 @@ from collections.abc import Sequence -from itertools import product +from itertools import chain, product from .ndindex import ImmutableObject, operator_index, ndindex from .tuple import Tuple @@ -240,8 +240,29 @@ def _fallback(): if isinstance(i, Integer): iters.append([i.raw//n]) elif isinstance(i, IntegerArray): - from numpy import unique - iters.append(unique(i.array//n).flat) + # All arrays will be together after calling expand() (Tuple does not support arrays + # separated by non-integer indices). Collect them all together + # at once. + arrs = [] + while True: + try: + if isinstance(i, IntegerArray): + arrs.append(i.raw.flatten()//n) + else: + idx_args = chain([i], idx_args) + self_ = chain([n], self_) + break + i = next(idx_args) + n = next(self_) + except StopIteration: + break + + import numpy as np + a = np.unique(np.stack(arrs), axis=-1) + def _array_iter(a): + for i in range(a.shape[-1]): + yield tuple(a[..., i].flat) + iters.append(_array_iter(a)) elif isinstance(i, Slice) and i.step > 0: def _slice_iter(s, n): a, N, m = s.args @@ -257,8 +278,16 @@ def _slice_iter(s, n): yield from _fallback() return # pragma: no cover + def _flatten(l): + for element in l: + if isinstance(element, tuple): + yield from element + else: + yield element + def _indices(iters): - for p in product(*iters): + for _p in product(*iters): + p = _flatten(_p) # p = (0, 0, 0), (0, 0, 1), ... yield Tuple(*[Slice(chunk_size*i, min(chunk_size*(i + 1), n), 1) for n, chunk_size, i in zip(shape, self, p)]) @@ -318,8 +347,23 @@ def num_subchunks(self, idx, shape): if isinstance(i, Integer): continue elif isinstance(i, IntegerArray): - from numpy import unique - res *= unique(i.array//n).size + arrs = [] + # see as_subchunks + while True: + try: + if isinstance(i, IntegerArray): + arrs.append(i.raw.flatten()//n) + else: + idx_args = chain([i], idx_args) + self_ = chain([n], self_) + break + i = next(idx_args) + n = next(self_) + except StopIteration: + break + + import numpy as np + res *= np.unique(np.stack(arrs), axis=-1).shape[-1] elif isinstance(i, Slice): if i.step < 0: raise NotImplementedError("num_subchunks() is not implemented for slices with negative step") diff --git a/ndindex/tests/helpers.py b/ndindex/tests/helpers.py index 1cceaff9..958ffca4 100644 --- a/ndindex/tests/helpers.py +++ b/ndindex/tests/helpers.py @@ -483,7 +483,7 @@ def iterslice(start_range=(-10, 10), yield (start, stop, step) -chunk_shapes = shared(shapes) +chunk_shapes = short_shapes @composite def chunk_sizes(draw, shapes=chunk_shapes): diff --git a/ndindex/tests/test_chunking.py b/ndindex/tests/test_chunking.py index 6e6eaccc..9c02cf04 100644 --- a/ndindex/tests/test_chunking.py +++ b/ndindex/tests/test_chunking.py @@ -104,6 +104,8 @@ def test_indices(chunk_size, shape): elements = [i for x in subarrays for i in x.flatten()] assert sorted(elements) == list(range(size)) +@example(chunk_size=(1, 1), idx=[[False, True], [True, True]], + shape=(2, 2)) @example(chunk_size=(1,), idx=slice(None, None, -1), shape=(2,)) @example((1,), True, (1,)) @example(chunk_size=(1, 1), idx=slice(1, None, 2), shape=(4, 1)) @@ -172,6 +174,8 @@ def test_as_subchunks(chunk_size, idx, shape): def test_as_subchunks_error(): raises(ValueError, lambda: next(ChunkSize((1, 2)).as_subchunks(..., (1, 2, 3)))) +@example(chunk_size=(1, 1), idx=[[False, True], [True, True]], + shape=(2, 2)) @example(chunk_size=(1,), idx=None, shape=(1,)) @example((1,), True, (1,)) @example(chunk_size=(1, 1), idx=slice(1, None, 2), shape=(4, 1))