Skip to content

Commit

Permalink
feat: add repartition method (#253)
Browse files Browse the repository at this point in the history
---

Co-authored-by: Doug Davis <ddavis@ddavis.io>
  • Loading branch information
martindurant and douglasdavis authored Jun 28, 2023
1 parent 868aa90 commit 552734b
Show file tree
Hide file tree
Showing 5 changed files with 113 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/dask_awkward/layers/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class AwkwardBlockwiseLayer(Blockwise):
"""Just like upstream Blockwise, except we override pickling"""

@classmethod
def from_blockwise(cls, layer) -> AwkwardBlockwiseLayer:
def from_blockwise(cls, layer: Blockwise) -> AwkwardBlockwiseLayer:
    """Build an instance of ``cls`` carrying the attributes of ``layer``.

    The new object is allocated with ``object.__new__`` so no ``__init__``
    runs; its ``__dict__`` is then filled from ``layer.__dict__`` (a shallow
    copy: attribute values are shared with ``layer``, not duplicated).
    """
    instance = object.__new__(cls)
    state = layer.__dict__
    instance.__dict__.update(state)
    return instance
Expand Down
38 changes: 37 additions & 1 deletion src/dask_awkward/lib/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def _finalize_array(results: Sequence[Any]) -> Any:

# sometimes we just check the length of partitions so all results
# will be integers, just make an array out of that.
elif isinstance(results, tuple) and all(
elif isinstance(results, (tuple, list)) and all(
isinstance(r, (int, np.integer)) for r in results
):
return ak.Array(list(results))
Expand Down Expand Up @@ -560,6 +560,40 @@ def reset_meta(self) -> None:
"""Assign an empty typetracer array as the collection metadata."""
self._meta = empty_typetracer()

def repartition(
    self,
    npartitions: int | None = None,
    divisions: tuple[int, ...] | None = None,
    rows_per_partition: int | None = None,
) -> Array:
    """Return a collection with the same rows arranged in new partitions.

    Exactly one of the three arguments must be given. If the divisions
    of this collection are not yet known they are computed eagerly,
    because the new partition boundaries are expressed in row numbers.

    Parameters
    ----------
    npartitions : int, optional
        Requested number of output partitions. Each partition receives
        ``ceil(nrows / npartitions)`` rows (the last one possibly
        fewer), so fewer partitions than requested may be produced.
    divisions : sequence of int, optional
        Explicit row boundaries of the output partitions. The final
        boundary may be smaller than the total number of rows, in which
        case the trailing rows are not part of the result.
    rows_per_partition : int, optional
        Number of rows in every output partition (the last one possibly
        fewer).

    Returns
    -------
    Array
        New collection whose ``divisions`` are the requested boundaries
        (always stored as a tuple).

    Raises
    ------
    ValueError
        If zero or more than one of the arguments is given, or if
        ``npartitions`` / ``rows_per_partition`` is negative.
    """
    requested = [npartitions, divisions, rows_per_partition]
    if sum(bool(opt) for opt in requested) != 1:
        raise ValueError("Please specify exactly one of the inputs")
    # A negative step would make ``range`` below silently yield nothing.
    if npartitions and npartitions < 0:
        raise ValueError("npartitions must be a positive integer")
    if rows_per_partition and rows_per_partition < 0:
        raise ValueError("rows_per_partition must be a positive integer")

    if not self.known_divisions:
        self.eager_compute_divisions()
    nrows = self.divisions[-1]
    if npartitions:
        # Clamp to 1 so that a collection with zero rows does not turn the
        # step into 0 (falsy), which would leave ``divisions`` unset.
        rows_per_partition = max(1, math.ceil(nrows / npartitions))
    if rows_per_partition:
        divisions = [*range(0, nrows, rows_per_partition), nrows]
    # Normalise so that token, layer and resulting collection all see the
    # same immutable type regardless of how the boundaries were obtained.
    divisions = tuple(divisions)

    # Imported here rather than at module level, as in the original code.
    from dask_awkward.layers import AwkwardMaterializedLayer
    from dask_awkward.lib.structure import repartition_layer

    token = tokenize(self, divisions)
    key = f"repartition-{token}"

    new_layer = AwkwardMaterializedLayer(
        repartition_layer(self, key, divisions),
        previous_layer_names=[self.name],
    )
    new_graph = HighLevelGraph.from_collections(
        key, new_layer, dependencies=(self,)
    )
    return new_array_object(
        new_graph,
        key,
        meta=self._meta,
        behavior=self.behavior,
        divisions=divisions,
    )

def __len__(self) -> int:
if not self.known_divisions:
self.eager_compute_divisions()
Expand Down Expand Up @@ -700,6 +734,7 @@ def keys_array(self) -> np.ndarray:
return np.array(self.__dask_keys__(), dtype=object)

def _partitions(self, index: Any) -> Array:
# TODO: this produces a materialized layer, but could work like repartition() and slice()
if not isinstance(index, tuple):
index = (index,)
token = tokenize(self, index)
Expand All @@ -718,6 +753,7 @@ def _partitions(self, index: Any) -> Array:

# if a single partition was requested we trivially know the new divisions.
if len(raw) == 1 and isinstance(raw[0], int) and self.known_divisions:
# TODO: don't we always know the divisions?
new_divisions = (
0,
self.divisions[raw[0] + 1] - self.divisions[raw[0]], # type: ignore
Expand Down
48 changes: 48 additions & 0 deletions src/dask_awkward/lib/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -1089,3 +1089,51 @@ def zip(
raise DaskAwkwardNotImplemented(
"only sized iterables are supported by dak.zip (dict, list, or tuple)"
)


def _repartition_func(*stuff):
    """Assemble one output partition from pieces of input partitions.

    The positional arguments are the input arrays followed by one final
    argument holding a slice specification per input array. A specification
    of ``None`` keeps the whole array; otherwise its first two items are
    used as ``array[start:stop]``. The pieces are joined, in order, with
    ``ak.concatenate``.
    """
    import builtins

    import awkward as ak

    *arrays, bounds = stuff
    pieces = []
    for chunk, span in builtins.zip(arrays, bounds):
        if span is None:
            pieces.append(chunk)
        else:
            pieces.append(chunk[span[0] : span[1]])
    return ak.concatenate(pieces)


def repartition_layer(
    arr: Array, key: str, divisions: tuple[int, ...] | list[int]
):
    """Build the graph layer that maps ``arr``'s partitions onto new divisions.

    Output partition ``index`` covers rows ``divisions[index]`` up to, but
    not including, ``divisions[index + 1]``. For each one the input
    partitions overlapping that row range are located and a task is emitted
    that slices them and concatenates the pieces.

    Parameters
    ----------
    arr : Array
        Input collection. Only ``arr.divisions`` and ``arr.name`` are read;
        the divisions are assumed to be known (no ``None`` entries).
    key : str
        Name of the layer; output keys are ``(key, index)``.
    divisions : sequence of int
        Row boundaries of the output partitions, in non-decreasing order
        and within the row range of ``arr``.

    Returns
    -------
    dict
        Mapping of ``(key, index)`` to the task
        ``(_repartition_func, (arr.name, k), ..., slices)`` where
        ``slices`` is a list with one ``(start, stop)`` pair per input
        partition; ``None`` in a pair means "open ended".

    Raises
    ------
    ValueError
        If ``divisions`` is not sorted, reaches outside the rows of
        ``arr``, or contains a partition starting at the very end of
        ``arr`` (which no input partition could supply).
    """
    indivs = arr.divisions
    bounds = list(divisions)
    pairs = list(builtins.zip(bounds[:-1], bounds[1:]))

    # Validate up front: the scans below walk ``indivs`` without bounds
    # checks and would otherwise fail with an IndexError or emit tasks
    # that have no input partitions.
    if pairs:
        if any(lo > hi for lo, hi in pairs):
            raise ValueError("divisions must be in non-decreasing order")
        if bounds[0] < indivs[0] or bounds[-1] > indivs[-1]:
            raise ValueError(
                f"divisions must lie within [{indivs[0]}, {indivs[-1]}]"
            )
        if bounds[-2] >= indivs[-1]:
            raise ValueError(
                "every output partition must start before the last input row"
            )

    layer = {}
    # Index of the input partition containing the current ``start``. It only
    # ever moves forward because the requested boundaries are sorted.
    first = 0
    for index, (start, end) in enumerate(pairs):
        while indivs[first + 1] <= start:
            first += 1
        # One past the last input partition holding rows below ``end``.
        stop = first + 1
        while indivs[stop] < end:
            stop += 1

        parts = []
        slices = []
        for k in range(first, stop):
            lo, hi = indivs[k], indivs[k + 1]
            # ``None`` means the piece runs from/to the partition edge.
            st = None if start < lo else start - lo
            en = end - lo if end < hi else None
            parts.append((arr.name, k))
            slices.append((st, en))
        layer[(key, index)] = (_repartition_func, *parts, slices)
    return layer
2 changes: 1 addition & 1 deletion src/dask_awkward/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class LazyInputsDict(Mapping):
Parameters
----------
inputs : list[Any]
The list of dicionary values.
The list of dictionary values.
"""

Expand Down
26 changes: 26 additions & 0 deletions tests/test_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,3 +506,29 @@ def test_values_astype(daa, caa):
dak.values_astype(daa, np.float32),
ak.values_astype(caa, np.float32),
)


def test_repartition_whole(daa):
    """Merging everything into a single partition keeps the data intact."""
    merged = daa.repartition(npartitions=1)
    assert merged.npartitions == 1
    assert_eq(daa, merged, check_divisions=False)


def test_repartition_no_change(daa):
    """Explicit divisions (0, 5, 10, 15) yield three partitions, same data."""
    boundaries = (0, 5, 10, 15)
    same = daa.repartition(divisions=boundaries)
    assert same.npartitions == 3
    assert_eq(daa, same, check_divisions=False)


def test_repartition_split_all(daa):
    """One row per partition gives as many partitions as there are rows."""
    per_row = daa.repartition(rows_per_partition=1)
    assert per_row.npartitions == len(daa)
    result = per_row.compute().tolist()
    assert result == daa.compute().tolist()


def test_repartition_uneven(daa):
    """Uneven divisions ending at row 12 return exactly the first 12 rows."""
    boundaries = (0, 7, 8, 11, 12)
    uneven = daa.repartition(divisions=boundaries)
    assert uneven.npartitions == 4
    result = uneven.compute().tolist()
    assert result == daa.compute()[:12].tolist()

0 comments on commit 552734b

Please sign in to comment.