Skip to content

Commit

Permalink
feat: support for dask.array collections in map_partitions (#311)
Browse files Browse the repository at this point in the history
  • Loading branch information
douglasdavis authored Jul 6, 2023
1 parent 0a1389e commit a677760
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 22 deletions.
31 changes: 9 additions & 22 deletions src/dask_awkward/lib/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from dask.delayed import Delayed
from dask.highlevelgraph import HighLevelGraph
from dask.threaded import get as threaded_get
from dask.utils import IndexCallable, funcname, key_split
from dask.utils import IndexCallable, funcname, is_arraylike, key_split
from tlz import first

from dask_awkward.layers import AwkwardBlockwiseLayer, AwkwardMaterializedLayer
Expand Down Expand Up @@ -1259,29 +1259,9 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
if method != "__call__":
raise RuntimeError("Array ufunc supports only method == '__call__'")

new_meta = None

# divisions need to be compat. (identical for now?)

inputs_meta = []
for inp in inputs:
# if input is a Dask Awkward Array collection, grab its meta
if isinstance(inp, Array):
inputs_meta.append(inp._meta)
# if input is a concrete Awkward Array, grab its typetracer
elif isinstance(inp, ak.Array):
inputs_meta.append(typetracer_array(inp))
# otherwise pass along
else:
inputs_meta.append(inp)

# compute new meta from inputs
new_meta = ufunc(*inputs_meta)

return map_partitions(
ufunc,
*inputs,
meta=new_meta,
output_divisions=1,
**kwargs,
)
Expand Down Expand Up @@ -1486,6 +1466,9 @@ def partitionwise_layer(
pairs.extend([arg, "i"])
elif len(arg.numblocks) == 2:
pairs.extend([arg, "ij"])
elif is_arraylike(arg) and is_dask_collection(arg) and arg.ndim == 1:
pairs.extend([arg.name, "i"])
numblocks[arg.name] = arg.numblocks
elif is_dask_collection(arg):
raise DaskAwkwardNotImplemented(
"Use of Array with other Dask collections is currently unsupported."
Expand Down Expand Up @@ -1994,6 +1977,10 @@ def meta_or_identity(obj: Any) -> Any:
"""
if is_awkward_collection(obj):
return obj._meta
elif is_dask_collection(obj) and is_arraylike(obj):
return ak.Array(
ak.from_numpy(obj._meta).layout.to_typetracer(forget_length=True)
)
return obj


Expand Down Expand Up @@ -2040,7 +2027,7 @@ def to_length_zero_arrays(objects: Sequence[Any]) -> tuple[Any, ...]:
return tuple(map(length_zero_array_or_identity, objects))


def map_meta(fn: Callable, *deps: Any) -> ak.Array | None:
def map_meta(fn: ArgsKwargsPackedFunction, *deps: Any) -> ak.Array | None:
# NOTE: fn is assumed to be a *packed* function
# as defined up in map_partitions. be careful!
try:
Expand Down
11 changes: 11 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import TYPE_CHECKING, Any

import awkward as ak
import dask.array as da
import fsspec
import numpy as np
import pytest
Expand Down Expand Up @@ -727,3 +728,13 @@ def test_map_partitions_args_and_kwargs_have_collection():
ddd=dd,
)
assert_eq(res1, res2)


def test_dask_array_in_map_partitions(daa, caa):
    """Check that a dask.array operand can be added to a dask-awkward array.

    The lazy sum (dask-awkward zeros + dask.array ones) is compared against
    the eager sum (awkward zeros + numpy ones) with ``assert_eq``.

    NOTE(review): ``daa`` and ``caa`` are fixtures defined outside this view;
    presumably the lazy and concrete forms of the same dataset -- confirm.
    """
    # Lazy side: the dask.array is chunked using the awkward collection's
    # second division value so that it lines up with the awkward partitions.
    lazy_zeros = dak.zeros_like(daa.points.x)
    lazy_ones = da.ones(len(lazy_zeros), chunks=lazy_zeros.divisions[1])
    lazy_sum = lazy_zeros + lazy_ones

    # Eager side: the same arithmetic with concrete awkward and numpy arrays.
    eager_zeros = ak.zeros_like(caa.points.x)
    eager_ones = np.ones(len(eager_zeros))
    eager_sum = eager_zeros + eager_ones

    assert_eq(lazy_sum, eager_sum)

0 comments on commit a677760

Please sign in to comment.