From a677760aef8220873bb6cf2856a4e8e538b757b5 Mon Sep 17 00:00:00 2001 From: Doug Davis Date: Thu, 6 Jul 2023 11:50:56 -0500 Subject: [PATCH] feat: support for `dask.array` collections in `map_partitions` (#311) --- src/dask_awkward/lib/core.py | 31 +++++++++---------------------- tests/test_core.py | 11 +++++++++++ 2 files changed, 20 insertions(+), 22 deletions(-) diff --git a/src/dask_awkward/lib/core.py b/src/dask_awkward/lib/core.py index 85155c8d..6962b992 100644 --- a/src/dask_awkward/lib/core.py +++ b/src/dask_awkward/lib/core.py @@ -35,7 +35,7 @@ from dask.delayed import Delayed from dask.highlevelgraph import HighLevelGraph from dask.threaded import get as threaded_get -from dask.utils import IndexCallable, funcname, key_split +from dask.utils import IndexCallable, funcname, is_arraylike, key_split from tlz import first from dask_awkward.layers import AwkwardBlockwiseLayer, AwkwardMaterializedLayer @@ -1259,29 +1259,9 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): if method != "__call__": raise RuntimeError("Array ufunc supports only method == '__call__'") - new_meta = None - - # divisions need to be compat. (identical for now?) - - inputs_meta = [] - for inp in inputs: - # if input is a Dask Awkward Array collection, grab it's meta - if isinstance(inp, Array): - inputs_meta.append(inp._meta) - # if input is a concrete Awkward Array, grab it's typetracer - elif isinstance(inp, ak.Array): - inputs_meta.append(typetracer_array(inp)) - # otherwise pass along - else: - inputs_meta.append(inp) - - # compute new meta from inputs - new_meta = ufunc(*inputs_meta) - return map_partitions( ufunc, *inputs, - meta=new_meta, output_divisions=1, **kwargs, ) @@ -1486,6 +1466,9 @@ def partitionwise_layer( pairs.extend([arg, "i"]) elif len(arg.numblocks) == 2: pairs.extend([arg, "ij"]) + elif is_arraylike(arg) and is_dask_collection(arg) and arg.ndim == 1: + pairs.extend([arg.name, "i"]) + numblocks[arg.name] = arg.numblocks elif is_dask_collection(arg): raise DaskAwkwardNotImplemented( "Use of Array with other Dask collections is currently unsupported." @@ -1994,6 +1977,10 @@ def meta_or_identity(obj: Any) -> Any: """ if is_awkward_collection(obj): return obj._meta + elif is_dask_collection(obj) and is_arraylike(obj): + return ak.Array( + ak.from_numpy(obj._meta).layout.to_typetracer(forget_length=True) + ) return obj @@ -2040,7 +2027,7 @@ def to_length_zero_arrays(objects: Sequence[Any]) -> tuple[Any, ...]: return tuple(map(length_zero_array_or_identity, objects)) -def map_meta(fn: Callable, *deps: Any) -> ak.Array | None: +def map_meta(fn: ArgsKwargsPackedFunction, *deps: Any) -> ak.Array | None: # NOTE: fn is assumed to be a *packed* function # as defined up in map_partitions. be careful! try: diff --git a/tests/test_core.py b/tests/test_core.py index c50d2d22..3730325e 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Any import awkward as ak +import dask.array as da import fsspec import numpy as np import pytest @@ -727,3 +728,13 @@ def test_map_partitions_args_and_kwargs_have_collection(): ddd=dd, ) assert_eq(res1, res2) + + +def test_dask_array_in_map_partitions(daa, caa): + x1 = dak.zeros_like(daa.points.x) + y1 = da.ones(len(x1), chunks=x1.divisions[1]) + z1 = x1 + y1 + x2 = ak.zeros_like(caa.points.x) + y2 = np.ones(len(x2)) + z2 = x2 + y2 + assert_eq(z1, z2)