Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add new dask-behavior protocol #409

Merged
merged 10 commits into from
Nov 30, 2023
23 changes: 23 additions & 0 deletions docs/api/behavior.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
Behavior
--------

Utilities to implement array behaviors for dask-awkward arrays.


.. currentmodule:: dask_awkward


.. autosummary::
:toctree: generated/

dask_property

.. autosummary::
:toctree: generated/

dask_method

.. raw:: html

<script data-goatcounter="https://dask-awkward.goatcounter.com/count"
async src="//gc.zgo.at/count.js"></script>
1 change: 1 addition & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ Table of Contents
api/io.rst
api/reducers.rst
api/structure.rst
api/behavior.rst

.. toctree::
:maxdepth: 1
Expand Down
2 changes: 2 additions & 0 deletions src/dask_awkward/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
from dask_awkward.lib.core import _type as type
from dask_awkward.lib.core import (
compatible_partitions,
dask_method,
dask_property,
map_partitions,
partition_compatibility,
)
Expand Down
297 changes: 204 additions & 93 deletions src/dask_awkward/lib/core.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import inspect
import keyword
import logging
import math
Expand All @@ -10,6 +9,7 @@
from collections.abc import Callable, Hashable, Mapping, Sequence
from enum import IntEnum
from functools import cached_property, partial, wraps
from inspect import getattr_static
from numbers import Number
from typing import TYPE_CHECKING, Any, Literal, TypeVar, Union, overload

Expand Down Expand Up @@ -69,6 +69,184 @@
# Module-level logger, named after this module via ``__name__``.
log = logging.getLogger(__name__)


def _make_dask_descriptor(func: Callable) -> Callable[[T, type[T], Array], Any]:
"""Adapt a function accepting a `dask_array` into a dask-awkward descriptor
that invokes and returns the user function when invoked.

Parameters
----------
func : Callable dask-awkward descriptor body

Returns
-------
Callable
The callable dask-awkward descriptor
"""

def descriptor(instance: T, owner: type[T], dask_array: Array) -> Any:
impl = func.__get__(instance, owner)
return impl(dask_array)

return descriptor


def _make_dask_method(func: Callable) -> Callable[[T, type[T], Array], Callable]:
"""Adapt a function accepting a `dask_array` and additional arguments into
a dask-awkward descriptor that invokes and returns the bound user function.

Parameters
----------
func : Callable
The dask-awkward descriptor body.

Returns
-------
Callable
The callable dask-awkward descriptor.
"""

def descriptor(instance: T, owner: type[T], dask_array: Array) -> Any:
def impl(*args, **kwargs):
impl = func.__get__(instance, owner)
return impl(dask_array, *args, **kwargs)

return impl

return descriptor


# Callable-bound type variables used in the decorator annotations of this
# module: ``F`` annotates a decorated function that is passed in and handed
# back, ``G`` annotates a separately registered dask implementation.
F = TypeVar("F", bound=Callable)
G = TypeVar("G", bound=Callable)


class _DaskProperty(property):
"""A property descriptor that exposes a `.dask` method for registering
dask-awkward descriptor implementations.
"""

_dask_get: Callable | None = None

def dask(self, func: F) -> _DaskProperty:
assert self._dask_get is None
self._dask_get = _make_dask_descriptor(func)
return self


def _adapt_naive_dask_get(func: Callable) -> Callable:
"""Adapt a non-dask-awkward user-defined descriptor function into
a dask-awkward aware descriptor that invokes the original function.

Parameters
----------
func : Callable
The non-dask-awkward descriptor body.

Returns
-------
Callable
The callable dask-awkward aware descriptor body.
"""

def wrapper(self, dask_array, *args, **kwargs):
return func(self, *args, **kwargs)

return wrapper


@overload
def dask_property(maybe_func: Callable, *, no_dispatch: bool = False) -> _DaskProperty:
    """Bare-decorator form: the getter is passed directly and the property
    descriptor is returned.
    """
    ...


@overload
def dask_property(
    maybe_func: None = None, *, no_dispatch: bool = False
) -> Callable[[Callable], _DaskProperty]:
    """Factory form: no getter is passed and a decorator that builds the
    property descriptor is returned.
    """
    ...


def dask_property(maybe_func=None, *, no_dispatch=False):
    """An extension of Python's built-in `property` that supports registration
    of a dask getter via `.dask`.

    Usable both as ``@dask_property`` and as ``@dask_property(...)``.

    Parameters
    ----------
    maybe_func : Callable, optional
        The property getter function. When omitted, a decorator expecting
        the getter is returned instead.
    no_dispatch : bool
        If True, re-use the main getter function as the Dask implementation.

    Returns
    -------
    Callable
        The callable dask-awkward aware descriptor factory or the descriptor
        itself.
    """

    def decorate(func: Callable) -> _DaskProperty:
        prop = _DaskProperty(func)
        # With ``no_dispatch`` the ordinary getter doubles as the dask getter,
        # adapted so that it ignores the extra dask-array argument.
        return prop.dask(_adapt_naive_dask_get(func)) if no_dispatch else prop

    return decorate if maybe_func is None else decorate(maybe_func)


def dask_method(func: F) -> F:
    """Decorate an instance method so that a dask-awkward specific
    implementation can be attached to it via `.dask`.

    The decorated function is returned unchanged apart from a new ``dask``
    attribute. Calling ``func.dask(impl)`` stores the adapted ``impl`` on
    ``func._dask_get`` and returns ``func`` again.

    Parameters
    ----------
    func : Callable
        The method implementation to decorate.

    Returns
    -------
    Callable
        The callable dask-awkward aware method.
    """

    def register_dask_impl(dask_func_impl: G) -> F:
        # Attach the dask-aware descriptor to the original function object.
        func._dask_get = _make_dask_method(dask_func_impl)  # type: ignore
        return func

    func.dask = register_dask_impl  # type: ignore
    return func


class Scalar(DaskMethodsMixin, DaskOperatorMethodMixin):
"""Single partition Dask collection representing a lazy Scalar.

Expand Down Expand Up @@ -1218,101 +1396,34 @@ def __getitem__(self, where):

return self._getitem_single(where)

def _call_behavior_method(self, method_name: str, *args: Any, **kwargs: Any) -> Any:
"""Call a behavior method for an awkward array.
If the function signature has __dunder__ parameters it is assumed that the
user wants to do the map_partitions dispatch themselves and the _meta's
behavior is called.
If there are no __dunder__ parameters in the function call then the function
is wrapped in map_partitions automatically.
"""
if hasattr(self._meta, method_name):
themethod = getattr(self._meta, method_name)
thesig = inspect.signature(themethod)
if "_dask_array_" in thesig.parameters:
if "_dask_array_" not in kwargs:
kwargs["_dask_array_"] = self
return themethod(*args, **kwargs)
return self.map_partitions(
_BehaviorMethodFn(method_name, **kwargs),
*args,
label=hyphenize(method_name),
)

raise AttributeError(
f"Method {method_name} is not available to this collection."
)

def _call_behavior_property(self, property_name: str) -> Any:
"""Call a property for an awkward array.
This also allows for some internal state to be tracked via behaviors
if a user follows the pattern:

class SomeMixin:

@property
def the_property(self):
...

@property
def a_property(array_context=None) # note: this can be any name

This pattern is caught if the property has an argument that single
argument is assumed to be the array context (i.e. self) so that self-
referenced re-indexing operations can be hidden in properties. The
user must do the appropriate dispatch of map_partitions.

If there is no argument the property call is wrapped in map_partitions.
"""
if hasattr(self._meta.__class__, property_name):
thegetter = getattr(self._meta.__class__, property_name).fget.__get__(
self._meta
)
thesig = inspect.signature(thegetter)

if len(thesig.parameters) == 1:
binding = thesig.bind(self)
return thegetter(*binding.args, **binding.kwargs)
elif len(thesig.parameters) > 1:
raise RuntimeError(
"Parametrized property cannot have more than one argument, the array context!"
)
return self.map_partitions(
_BehaviorPropertyFn(property_name),
label=hyphenize(property_name),
)
raise AttributeError(
f"Property {property_name} is not available to this collection."
)

def _maybe_behavior_method(self, attr: str) -> bool:
try:
res = getattr(self._meta.__class__, attr)
return (not isinstance(res, property)) and callable(res)
except AttributeError:
return False

def _maybe_behavior_property(self, attr: str) -> bool:
try:
res = getattr(self._meta.__class__, attr)
return isinstance(res, property)
except AttributeError:
return False
def _is_method_heuristic(self, resolved: Any) -> bool:
    """Guess whether a resolved attribute should be treated as a method.

    The heuristic is simply whether ``resolved`` is callable.
    """
    return callable(resolved)

def __getattr__(self, attr: str) -> Any:
if attr not in (self.fields or []):
# check for possible behavior method
if self._maybe_behavior_method(attr):

def wrapper(*args, **kwargs):
return self._call_behavior_method(attr, *args, **kwargs)

return wrapper
# check for possible behavior property
elif self._maybe_behavior_property(attr):
return self._call_behavior_property(attr)

raise AttributeError(f"{attr} not in fields.")
try:
cls_method = getattr_static(self._meta, attr)
except AttributeError:
raise AttributeError(f"{attr} not in fields.")
else:
if hasattr(cls_method, "_dask_get"):
return cls_method._dask_get(self._meta, type(self._meta), self)
elif self._is_method_heuristic(cls_method):

@wraps(cls_method)
def wrapper(*args, **kwargs):
return self.map_partitions(
_BehaviorMethodFn(attr, **kwargs),
*args,
label=hyphenize(attr),
)

return wrapper
else:
return self.map_partitions(
_BehaviorPropertyFn(attr),
label=hyphenize(attr),
)
try:
# at this point attr is either a field or we'll have to
# raise an exception.
Expand Down
Loading
Loading