From 552734b331e4d70cee5cefce3fbb8a31d5a074b9 Mon Sep 17 00:00:00 2001 From: Martin Durant Date: Wed, 28 Jun 2023 16:31:34 -0400 Subject: [PATCH] feat: add `repartition` method (#253) --- Co-authored-by: Doug Davis --- src/dask_awkward/layers/layers.py | 2 +- src/dask_awkward/lib/core.py | 38 +++++++++++++++++++++++- src/dask_awkward/lib/structure.py | 48 +++++++++++++++++++++++++++++++ src/dask_awkward/utils.py | 2 +- tests/test_structure.py | 26 +++++++++++++++++ 5 files changed, 113 insertions(+), 3 deletions(-) diff --git a/src/dask_awkward/layers/layers.py b/src/dask_awkward/layers/layers.py index 42e5217e..eb9c1805 100644 --- a/src/dask_awkward/layers/layers.py +++ b/src/dask_awkward/layers/layers.py @@ -19,7 +19,7 @@ class AwkwardBlockwiseLayer(Blockwise): """Just like upstream Blockwise, except we override pickling""" @classmethod - def from_blockwise(cls, layer) -> AwkwardBlockwiseLayer: + def from_blockwise(cls, layer: Blockwise) -> AwkwardBlockwiseLayer: ob = object.__new__(cls) ob.__dict__.update(layer.__dict__) return ob diff --git a/src/dask_awkward/lib/core.py b/src/dask_awkward/lib/core.py index 3e4af401..85155c8d 100644 --- a/src/dask_awkward/lib/core.py +++ b/src/dask_awkward/lib/core.py @@ -76,7 +76,7 @@ def _finalize_array(results: Sequence[Any]) -> Any: # sometimes we just check the length of partitions so all results # will be integers, just make an array out of that. - elif isinstance(results, tuple) and all( + elif isinstance(results, (tuple, list)) and all( isinstance(r, (int, np.integer)) for r in results ): return ak.Array(list(results)) @@ -560,6 +560,40 @@ def reset_meta(self) -> None: """Assign an empty typetracer array as the collection metadata.""" self._meta = empty_typetracer() + def repartition(self, npartitions=None, divisions=None, rows_per_partition=None): + from dask_awkward.layers import AwkwardMaterializedLayer + from dask_awkward.lib.structure import repartition_layer + + if sum(bool(_) for _ in [npartitions, divisions, rows_per_partition]) != 1: + raise ValueError("Please specify exactly one of the inputs") + if not self.known_divisions: + self.eager_compute_divisions() + nrows = self.divisions[-1] + if npartitions: + rows_per_partition = math.ceil(nrows / npartitions) + if rows_per_partition: + divisions = list(range(0, nrows, rows_per_partition)) + divisions.append(nrows) + + token = tokenize(self, divisions) + key = f"repartition-{token}" + + new_layer_raw = repartition_layer(self, key, divisions) + new_layer = AwkwardMaterializedLayer( + new_layer_raw, + previous_layer_names=[self.name], + ) + new_graph = HighLevelGraph.from_collections( + key, new_layer, dependencies=(self,) + ) + return new_array_object( + new_graph, + key, + meta=self._meta, + behavior=self.behavior, + divisions=divisions, + ) + def __len__(self) -> int: if not self.known_divisions: self.eager_compute_divisions() @@ -700,6 +734,7 @@ def keys_array(self) -> np.ndarray: return np.array(self.__dask_keys__(), dtype=object) def _partitions(self, index: Any) -> Array: + # TODO: this produces a materialized layer, but could work like repartition() and slice() if not isinstance(index, tuple): index = (index,) token = tokenize(self, index) @@ -718,6 +753,7 @@ def _partitions(self, index: Any) -> Array: # if a single partition was requested we trivially know the new divisions. if len(raw) == 1 and isinstance(raw[0], int) and self.known_divisions: + # TODO: don't we always know the divisions? new_divisions = ( 0, self.divisions[raw[0] + 1] - self.divisions[raw[0]], # type: ignore diff --git a/src/dask_awkward/lib/structure.py b/src/dask_awkward/lib/structure.py index 63453d9c..15e2423b 100644 --- a/src/dask_awkward/lib/structure.py +++ b/src/dask_awkward/lib/structure.py @@ -1089,3 +1089,51 @@ def zip( raise DaskAwkwardNotImplemented( "only sized iterables are supported by dak.zip (dict, list, or tuple)" ) + + +def _repartition_func(*stuff): + import builtins + + import awkward as ak + + *data, slices = stuff + data = [ + d[sl[0] : sl[1]] if sl is not None else d + for d, sl in builtins.zip(data, slices) + ] + return ak.concatenate(data) + + +def repartition_layer(arr: Array, key: str, divisions: list[int, ...]): + layer = {} + + indivs = arr.divisions + i = 0 + for index, (start, end) in enumerate(builtins.zip(divisions[:-1], divisions[1:])): + pp = [] + ss = [] + while indivs[i] <= start: + i += 1 + j = i + i -= 1 + while indivs[j] < end: + j += 1 + for k in range(i, j): + if start < indivs[k]: + st = None + elif start < indivs[k + 1]: + st = start - indivs[k] + else: + continue + if end < indivs[k]: + continue + elif end < indivs[k + 1]: + en = end - indivs[k] + else: + en = None + pp.append(k) + ss.append((st, en)) + layer[(key, index)] = ( + (_repartition_func,) + tuple((arr.name, part) for part in pp) + (ss,) + ) + return layer diff --git a/src/dask_awkward/utils.py b/src/dask_awkward/utils.py index c0642fa9..4f5c32db 100644 --- a/src/dask_awkward/utils.py +++ b/src/dask_awkward/utils.py @@ -37,7 +37,7 @@ class LazyInputsDict(Mapping): Parameters ---------- inputs : list[Any] - The list of dicionary values. + The list of dictionary values. """ diff --git a/tests/test_structure.py b/tests/test_structure.py index 009f33d9..296d84c5 100644 --- a/tests/test_structure.py +++ b/tests/test_structure.py @@ -506,3 +506,29 @@ def test_values_astype(daa, caa): dak.values_astype(daa, np.float32), ak.values_astype(caa, np.float32), ) + + +def test_repartition_whole(daa): + daa1 = daa.repartition(npartitions=1) + assert daa1.npartitions == 1 + assert_eq(daa, daa1, check_divisions=False) + + +def test_repartition_no_change(daa): + daa1 = daa.repartition(divisions=(0, 5, 10, 15)) + assert daa1.npartitions == 3 + assert_eq(daa, daa1, check_divisions=False) + + +def test_repartition_split_all(daa): + daa1 = daa.repartition(rows_per_partition=1) + assert daa1.npartitions == len(daa) + out = daa1.compute() + assert out.tolist() == daa.compute().tolist() + + +def test_repartition_uneven(daa): + daa1 = daa.repartition(divisions=(0, 7, 8, 11, 12)) + assert daa1.npartitions == 4 + out = daa1.compute() + assert out.tolist() == daa.compute()[:12].tolist()