From eccb8f26c0c52866218f73406f679d0b5fa317cf Mon Sep 17 00:00:00 2001 From: Lindsey Gray Date: Thu, 22 Feb 2024 13:13:15 -0600 Subject: [PATCH 1/3] split map_partitions into user-facing and internal/unsafe --- src/dask_awkward/lib/core.py | 128 +++++++++++++++++++++++------------ 1 file changed, 84 insertions(+), 44 deletions(-) diff --git a/src/dask_awkward/lib/core.py b/src/dask_awkward/lib/core.py index a6111867..f970928c 100644 --- a/src/dask_awkward/lib/core.py +++ b/src/dask_awkward/lib/core.py @@ -1539,6 +1539,15 @@ def wrapper(*args, **kwargs): except (IndexError, KeyError): raise AttributeError(f"{attr} not in fields.") + def _map_partitions( + self, + func: Callable, + *args: Any, + **kwargs: Any, + ) -> Array: + """Maps a function to partitions without flattening the function inputs.""" + return _map_partitions(func, self, *args, **kwargs) + def map_partitions( self, func: Callable, @@ -1890,6 +1899,72 @@ def __call__(self, *args_deps_expanded): return self.fn(*args, **kwargs) +def _map_partitions( + fn: Callable, + *args: Any, + label: str | None = None, + token: str | None = None, + meta: Any | None = None, + output_divisions: int | None = None, + **kwargs: Any, +) -> Array: + """Map a callable across all partitions of any number of collections. + No wrapper is used to flatten the function arguments. 
+ """ + token = token or tokenize(fn, *args, meta, **kwargs) + label = hyphenize(label or funcname(fn)) + name = f"{label}-{token}" + + deps = [a for a in args if is_dask_collection(a)] + [ + v for v in kwargs.values() if is_dask_collection(v) + ] + + dak_arrays = tuple(filter(lambda x: isinstance(x, Array), deps)) + + lay = partitionwise_layer( + fn, + name, + *args, + **kwargs, + ) + + if meta is None: + meta = map_meta(fn, *args, **kwargs) + + hlg = HighLevelGraph.from_collections( + name, + lay, + dependencies=deps, + ) + + if len(dak_arrays) == 0: + raise TypeError( + "at least one argument passed to map_partitions " + "should be a dask_awkward.Array collection." + ) + in_npartitions = dak_arrays[0].npartitions + in_divisions = dak_arrays[0].divisions + + if output_divisions is not None: + if output_divisions == 1: + new_divisions = dak_arrays[0].divisions + else: + new_divisions = tuple(map(lambda x: x * output_divisions, in_divisions)) + return new_array_object( + hlg, + name=name, + meta=meta, + divisions=new_divisions, + ) + else: + return new_array_object( + hlg, + name=name, + meta=meta, + npartitions=in_npartitions, + ) + + def map_partitions( base_fn: Callable, *args: Any, @@ -1971,6 +2046,9 @@ def map_partitions( This is effectively the same as `d = c * a` """ + token = token or tokenize(base_fn, *args, meta, **kwargs) + label = hyphenize(label or funcname(base_fn)) + opt_touch_all = kwargs.pop("opt_touch_all", None) if opt_touch_all is not None: warnings.warn( @@ -1979,9 +2057,6 @@ def map_partitions( "and the function call will likely fail." 
) - token = token or tokenize(base_fn, *args, meta, **kwargs) - label = hyphenize(label or funcname(base_fn)) - name = f"{label}-{token}" kwarg_flat_deps, kwarg_repacker = unpack_collections(kwargs, traverse=traverse) flat_deps, _ = unpack_collections(*args, *kwargs.values(), traverse=traverse) @@ -2017,51 +2092,16 @@ def map_partitions( kwarg_repacker, arg_lens_for_repackers, ) - - lay = partitionwise_layer( + return _map_partitions( fn, - name, *arg_flat_deps_expanded, *kwarg_flat_deps, + label=label, + token=token, + meta=meta, + output_divisions=output_divisions, ) - if meta is None: - meta = map_meta(fn, *arg_flat_deps_expanded, *kwarg_flat_deps) - - hlg = HighLevelGraph.from_collections( - name, - lay, - dependencies=flat_deps, - ) - - dak_arrays = tuple(filter(lambda x: isinstance(x, Array), flat_deps)) - if len(dak_arrays) == 0: - raise TypeError( - "at least one argument passed to map_partitions " - "should be a dask_awkward.Array collection." - ) - in_npartitions = dak_arrays[0].npartitions - in_divisions = dak_arrays[0].divisions - - if output_divisions is not None: - if output_divisions == 1: - new_divisions = flat_deps[0].divisions - else: - new_divisions = tuple(map(lambda x: x * output_divisions, in_divisions)) - return new_array_object( - hlg, - name=name, - meta=meta, - divisions=new_divisions, - ) - else: - return new_array_object( - hlg, - name=name, - meta=meta, - npartitions=in_npartitions, - ) - def _chunk_reducer_non_positional( chunk: ak.Array, @@ -2408,7 +2448,7 @@ def to_length_zero_arrays(objects: Sequence[Any]) -> tuple[Any, ...]: return tuple(map(length_zero_array_or_identity, objects)) -def map_meta(fn: ArgsKwargsPackedFunction, *deps: Any) -> ak.Array | None: +def map_meta(fn: Callable | ArgsKwargsPackedFunction, *deps: Any) -> ak.Array | None: # NOTE: fn is assumed to be a *packed* function # as defined up in map_partitions. be careful! 
try: From e1f59428ce24f48aca17d07f06fda266bfd1e9a3 Mon Sep 17 00:00:00 2001 From: Lindsey Gray Date: Thu, 22 Feb 2024 13:25:26 -0600 Subject: [PATCH 2/3] use _map_partitions in some hotspots --- src/dask_awkward/lib/core.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/dask_awkward/lib/core.py b/src/dask_awkward/lib/core.py index f970928c..9714e873 100644 --- a/src/dask_awkward/lib/core.py +++ b/src/dask_awkward/lib/core.py @@ -1212,7 +1212,7 @@ def _getitem_trivial_map_partitions( else: m = to_meta([where])[0] meta = self._meta[m] - return map_partitions( + return _map_partitions( operator.getitem, self, where, @@ -1232,7 +1232,7 @@ def _getitem_outer_bool_or_int_lazy_array(self, where): ) new_meta = self._meta[where._meta] - return self.map_partitions( + return self._map_partitions( operator.getitem, where, meta=new_meta, From dd319b6f26f26e3e3e38b749a69ee3f060e8839c Mon Sep 17 00:00:00 2001 From: Lindsey Gray Date: Thu, 22 Feb 2024 15:15:38 -0600 Subject: [PATCH 3/3] add documentation / context for _map_partitions --- src/dask_awkward/lib/core.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/dask_awkward/lib/core.py b/src/dask_awkward/lib/core.py index 9714e873..47cf13da 100644 --- a/src/dask_awkward/lib/core.py +++ b/src/dask_awkward/lib/core.py @@ -1909,7 +1909,13 @@ def _map_partitions( **kwargs: Any, ) -> Array: """Map a callable across all partitions of any number of collections. - No wrapper is used to flatten the function arguments. + No wrapper is used to flatten the function arguments. This is meant for + dask-awkward internal use or in situations where input data are sanitized. + + The parameters of this function are otherwise the same as map_partitions, + with the limitation that args and kwargs must be non-nested and flat. They + will not be traversed to extract all dask collections, except those at + the top level of args or kwargs. 
""" token = token or tokenize(fn, *args, meta, **kwargs) label = hyphenize(label or funcname(fn))