Skip to content

Commit

Permalink
feat: add new attribute for when divisions are known (#326)
Browse files Browse the repository at this point in the history
  • Loading branch information
douglasdavis authored Jul 20, 2023
1 parent d10fb42 commit d159390
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 21 deletions.
55 changes: 36 additions & 19 deletions src/dask_awkward/lib/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,9 +232,9 @@ def __getitem__(self, where: Any) -> Any:
hlg = HighLevelGraph.from_collections(name, task, dependencies=(d,))
return Delayed(name, hlg)

def __getattr__(self, where: str) -> Any:
def __getattr__(self, attr: str) -> Any:
d = self.to_delayed(optimize_graph=True)
return getattr(d, where)
return getattr(d, attr)

@property
def known_value(self) -> Any | None:
Expand Down Expand Up @@ -561,25 +561,34 @@ def reset_meta(self) -> None:
"""Assign an empty typetracer array as the collection metadata."""
self._meta = empty_typetracer()

def repartition(self, npartitions=None, divisions=None, rows_per_partition=None):
def repartition(
self,
npartitions: int | None = None,
divisions: tuple[int, ...] | None = None,
rows_per_partition: int | None = None,
) -> Array:
from dask_awkward.layers import AwkwardMaterializedLayer
from dask_awkward.lib.structure import repartition_layer

if sum(bool(_) for _ in [npartitions, divisions, rows_per_partition]) != 1:
raise ValueError("Please specify exactly one of the inputs")
if not self.known_divisions:
self.eager_compute_divisions()
nrows = self.divisions[-1]
if npartitions:
nrows = self.defined_divisions[-1]
new_divisions: tuple[int, ...] = tuple()
if divisions:
new_divisions = divisions
elif npartitions:
rows_per_partition = math.ceil(nrows / npartitions)
if rows_per_partition:
divisions = list(range(0, nrows, rows_per_partition))
divisions.append(nrows)
new_divs = list(range(0, nrows, rows_per_partition))
new_divs.append(nrows)
new_divisions = tuple(new_divs)

token = tokenize(self, divisions)
key = f"repartition-{token}"

new_layer_raw = repartition_layer(self, key, divisions)
new_layer_raw = repartition_layer(self, key, new_divisions)
new_layer = AwkwardMaterializedLayer(
new_layer_raw,
previous_layer_names=[self.name],
Expand All @@ -592,7 +601,7 @@ def repartition(self, npartitions=None, divisions=None, rows_per_partition=None)
key,
meta=self._meta,
behavior=self.behavior,
divisions=divisions,
divisions=tuple(new_divisions),
)

def __len__(self) -> int:
Expand Down Expand Up @@ -688,6 +697,12 @@ def known_divisions(self) -> bool:
"""True if the divisions are known (absence of ``None`` in the tuple)."""
return len(self.divisions) > 0 and None not in self.divisions

@property
def defined_divisions(self) -> tuple[int, ...]:
    """Divisions of the collection, available only when they are known.

    Unlike reading the divisions directly, this accessor never hands
    back a tuple for a collection whose ``known_divisions`` is false;
    it raises instead.

    Returns
    -------
    tuple[int, ...]
        The stored ``_divisions`` tuple.

    Raises
    ------
    ValueError
        If ``known_divisions`` is false.

    """
    if not self.known_divisions:
        raise ValueError("defined_divisions only works when divisions are known.")
    # NOTE(review): the ignore presumably exists because ``_divisions`` is
    # declared with optional entries; confirm against the attribute's
    # declaration.
    return self._divisions  # type: ignore

@property
def npartitions(self) -> int:
"""Total number of partitions."""
Expand Down Expand Up @@ -936,9 +951,9 @@ def _getitem_slice_on_zero(self, where: tuple[slice, ...]):

if not self.known_divisions:
self.eager_compute_divisions()
stop = sl.stop or self.divisions[-1]
start = start if start >= 0 else self.divisions[-1] + start
stop = stop if stop >= 0 else self.divisions[-1] + stop
stop = sl.stop or self.defined_divisions[-1]
start = start if start >= 0 else self.defined_divisions[-1] + start
stop = stop if stop >= 0 else self.defined_divisions[-1] + stop
if step < 0:
raise DaskAwkwardNotImplemented("negative step slice on zeroth dimension")

Expand All @@ -951,21 +966,22 @@ def _getitem_slice_on_zero(self, where: tuple[slice, ...]):
dask = {}
# make low-level graph
for i in range(self.npartitions):
if start > self.divisions[i + 1]:
if start > self.defined_divisions[i + 1]:
# first partition not yet found
continue
if stop < self.divisions[i] and dask:
if stop < self.defined_divisions[i] and dask:
# no more partitions with valid rows
# does **NOT** exit if there are no partitions yet, to make sure there is always
# at least one, needed to get metadata of empty output right
break
slice_start = max(start - self.divisions[i], 0 + remainder)
slice_start = max(start - self.defined_divisions[i], 0 + remainder)
slice_end = min(
stop - self.divisions[i], self.divisions[i + 1] - self.divisions[i]
stop - self.defined_divisions[i],
self.defined_divisions[i + 1] - self.defined_divisions[i],
)
if (
slice_end == slice_start
and (self.divisions[i + 1] - self.divisions[i])
and (self.defined_divisions[i + 1] - self.defined_divisions[i])
and dask
):
# in case of zero-row last partition (if not only partition)
Expand All @@ -978,7 +994,8 @@ def _getitem_slice_on_zero(self, where: tuple[slice, ...]):
)
outpart += 1
remainder = (
(self.divisions[i] + slice_start) - self.divisions[i + 1]
(self.defined_divisions[i] + slice_start)
- self.defined_divisions[i + 1]
) % step
remainder = step - remainder if remainder < 0 else remainder
nextdiv = math.ceil((slice_end - slice_start) / step)
Expand Down Expand Up @@ -1343,7 +1360,7 @@ def head(self, nrow=10, compute=True):
if compute:
return out.compute()
if self.known_divisions:
out._divisions = (0, min(nrow, self.divisions[1]))
out._divisions = (0, min(nrow, self.defined_divisions[1]))
return out


Expand Down
4 changes: 2 additions & 2 deletions src/dask_awkward/lib/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -1116,10 +1116,10 @@ def _repartition_func(*stuff):
return ak.concatenate(data)


def repartition_layer(arr: Array, key: str, divisions: list[int, ...]):
def repartition_layer(arr: Array, key: str, divisions: tuple[int, ...]):
layer = {}

indivs = arr.divisions
indivs = arr.defined_divisions
i = 0
for index, (start, end) in enumerate(builtins.zip(divisions[:-1], divisions[1:])):
pp = []
Expand Down
6 changes: 6 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,12 @@ def test_scalar_to_delayed(daa: Array, optimize_graph: bool) -> None:
assert d1.compute() == s1c


def test_defined_divisions_exception(ndjson_points1) -> None:
    """Reading ``defined_divisions`` must raise ``ValueError`` here."""
    # NOTE(review): this presumably relies on ``dak.from_json`` yielding a
    # collection whose divisions are not yet known -- confirm.
    jsds = dak.from_json([ndjson_points1] * 3)
    with pytest.raises(ValueError, match="defined_divisions only works"):
        # Bare attribute access is intentional: evaluating the property is
        # what triggers the exception.
        jsds.defined_divisions


def test_compatible_partitions(ndjson_points_file: str) -> None:
daa1 = dak.from_json([ndjson_points_file] * 5)
daa2 = dak.from_awkward(daa1.compute(), npartitions=4)
Expand Down

0 comments on commit d159390

Please sign in to comment.