perf: private map partitions #477

Merged
merged 3 commits on Feb 27, 2024
Changes from all commits
src/dask_awkward/lib/core.py: 138 changes (92 additions, 46 deletions)
@@ -1212,7 +1212,7 @@ def _getitem_trivial_map_partitions(
else:
m = to_meta([where])[0]
meta = self._meta[m]
return map_partitions(
return _map_partitions(
operator.getitem,
self,
where,
@@ -1232,7 +1232,7 @@ def _getitem_outer_bool_or_int_lazy_array(self, where):
)

new_meta = self._meta[where._meta]
return self.map_partitions(
return self._map_partitions(
operator.getitem,
where,
meta=new_meta,
@@ -1539,6 +1539,15 @@ def wrapper(*args, **kwargs):
except (IndexError, KeyError):
raise AttributeError(f"{attr} not in fields.")

def _map_partitions(
self,
func: Callable,
*args: Any,
**kwargs: Any,
) -> Array:
"""Maps a function to partitions without flattening the function inputs."""
return _map_partitions(func, self, *args, **kwargs)

def map_partitions(
self,
func: Callable,
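The new private method simply forwards to the module-level `_map_partitions` below, skipping the argument-flattening machinery of the public path. A minimal sketch of the difference, assuming a `dask_awkward.Array` named `arr` with a field `x` (the names here are illustrative, not part of the PR):

```python
import operator

import awkward as ak
import dask_awkward as dak

arr = dak.from_awkward(
    ak.Array([{"x": 1}, {"x": 2}, {"x": 3}, {"x": 4}]), npartitions=2
)

# Public path: args/kwargs are traversed with unpack_collections, so dask
# collections nested inside containers are discovered and repacked.
sel_public = arr.map_partitions(operator.getitem, "x")

# Private fast path added in this PR: args/kwargs are passed through as-is,
# so only top-level arguments may be dask collections.
sel_private = arr._map_partitions(operator.getitem, "x")
```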
@@ -1890,6 +1899,78 @@ def __call__(self, *args_deps_expanded):
return self.fn(*args, **kwargs)


def _map_partitions(
fn: Callable,
*args: Any,
label: str | None = None,
token: str | None = None,
meta: Any | None = None,
output_divisions: int | None = None,
**kwargs: Any,
) -> Array:
"""Map a callable across all partitions of any number of collections.
No wrapper is used to flatten the function arguments. This is meant for
dask-awkward internal use or in situations where input data are sanitized.

The parameters of this function are otherwise the same as map_partitions,
but with the limitation that args and kwargs must be non-nested and flat:
they will not be traversed to extract all dask collections, except those
in the first dimension of args or kwargs.
"""
token = token or tokenize(fn, *args, meta, **kwargs)
label = hyphenize(label or funcname(fn))
name = f"{label}-{token}"

deps = [a for a in args if is_dask_collection(a)] + [
v for v in kwargs.values() if is_dask_collection(v)
]

dak_arrays = tuple(filter(lambda x: isinstance(x, Array), deps))

lay = partitionwise_layer(
fn,
name,
*args,
**kwargs,
)

if meta is None:
meta = map_meta(fn, *args, **kwargs)

hlg = HighLevelGraph.from_collections(
name,
lay,
dependencies=deps,
)

if len(dak_arrays) == 0:
raise TypeError(
"at least one argument passed to map_partitions "
"should be a dask_awkward.Array collection."
)
in_npartitions = dak_arrays[0].npartitions
in_divisions = dak_arrays[0].divisions

if output_divisions is not None:
if output_divisions == 1:
new_divisions = dak_arrays[0].divisions
else:
new_divisions = tuple(map(lambda x: x * output_divisions, in_divisions))
return new_array_object(
hlg,
name=name,
meta=meta,
divisions=new_divisions,
)
else:
return new_array_object(
hlg,
name=name,
meta=meta,
npartitions=in_npartitions,
)
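
The division handling mirrors the public behaviour: `output_divisions=1` signals that the mapped function preserves partition lengths, while any larger integer scales each division boundary. A hedged, worked illustration (the values are made up):

```python
# Suppose the first dask_awkward.Array dependency carries known divisions:
in_divisions = (0, 10, 20, 30)

# output_divisions=1 keeps them as-is (partition lengths are preserved),
# whereas e.g. output_divisions=2 scales every boundary, matching the
# tuple(map(...)) expression above.
scaled = tuple(x * 2 for x in in_divisions)
assert scaled == (0, 20, 40, 60)
```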


def map_partitions(
base_fn: Callable,
*args: Any,
@@ -1971,6 +2052,9 @@ def map_partitions(
This is effectively the same as `d = c * a`

"""
token = token or tokenize(base_fn, *args, meta, **kwargs)
label = hyphenize(label or funcname(base_fn))

opt_touch_all = kwargs.pop("opt_touch_all", None)
if opt_touch_all is not None:
warnings.warn(
@@ -1979,9 +2063,6 @@
"and the function call will likely fail."
)

token = token or tokenize(base_fn, *args, meta, **kwargs)
label = hyphenize(label or funcname(base_fn))
name = f"{label}-{token}"
kwarg_flat_deps, kwarg_repacker = unpack_collections(kwargs, traverse=traverse)
flat_deps, _ = unpack_collections(*args, *kwargs.values(), traverse=traverse)

@@ -2017,51 +2098,16 @@
kwarg_repacker,
arg_lens_for_repackers,
)

lay = partitionwise_layer(
return _map_partitions(
fn,
name,
*arg_flat_deps_expanded,
*kwarg_flat_deps,
label=label,
token=token,
meta=meta,
output_divisions=output_divisions,
)

if meta is None:
meta = map_meta(fn, *arg_flat_deps_expanded, *kwarg_flat_deps)

hlg = HighLevelGraph.from_collections(
name,
lay,
dependencies=flat_deps,
)

dak_arrays = tuple(filter(lambda x: isinstance(x, Array), flat_deps))
if len(dak_arrays) == 0:
raise TypeError(
"at least one argument passed to map_partitions "
"should be a dask_awkward.Array collection."
)
in_npartitions = dak_arrays[0].npartitions
in_divisions = dak_arrays[0].divisions

if output_divisions is not None:
if output_divisions == 1:
new_divisions = flat_deps[0].divisions
else:
new_divisions = tuple(map(lambda x: x * output_divisions, in_divisions))
return new_array_object(
hlg,
name=name,
meta=meta,
divisions=new_divisions,
)
else:
return new_array_object(
hlg,
name=name,
meta=meta,
npartitions=in_npartitions,
)
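
After this refactor the public `map_partitions` keeps its collection-discovery behaviour (unpacking nested args/kwargs and repacking them through `ArgsKwargsPackedFunction`) and then hands graph construction to `_map_partitions`. A short usage sketch under that assumption; the array and function names are illustrative only:

```python
import dask_awkward as dak

# Two partitions, one per inner list.
a = dak.from_lists([[{"x": 1}, {"x": 2}], [{"x": 3}, {"x": 4}]])

def add_offset(part, offset):
    # Both arguments arrive as concrete per-partition awkward arrays.
    return part.x + offset.x

# A keyword argument that is itself a dask collection is still discovered
# by the public wrapper and expanded partition-by-partition.
result = dak.map_partitions(add_offset, a, offset=a)
```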


def _chunk_reducer_non_positional(
chunk: ak.Array,
@@ -2408,7 +2454,7 @@ def to_length_zero_arrays(objects: Sequence[Any]) -> tuple[Any, ...]:
return tuple(map(length_zero_array_or_identity, objects))


def map_meta(fn: ArgsKwargsPackedFunction, *deps: Any) -> ak.Array | None:
def map_meta(fn: Callable | ArgsKwargsPackedFunction, *deps: Any) -> ak.Array | None:
# NOTE: fn is assumed to be a *packed* function
# as defined up in map_partitions. be careful!
try: