diff --git a/docs/api/behavior.rst b/docs/api/behavior.rst new file mode 100644 index 00000000..a239af61 --- /dev/null +++ b/docs/api/behavior.rst @@ -0,0 +1,23 @@ +Behavior +-------- + +Utilities to implement array behaviors for dask-awkward arrays. + + +.. currentmodule:: dask_awkward + + +.. autosummary:: + :toctree: generated/ + + dask_property + +.. autosummary:: + :toctree: generated/ + + dask_method + +.. raw:: html + + diff --git a/docs/index.rst b/docs/index.rst index 2626eb37..9bb01088 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -49,6 +49,7 @@ Table of Contents api/io.rst api/reducers.rst api/structure.rst + api/behavior.rst .. toctree:: :maxdepth: 1 diff --git a/src/dask_awkward/__init__.py b/src/dask_awkward/__init__.py index 0bcbcf52..a9726941 100644 --- a/src/dask_awkward/__init__.py +++ b/src/dask_awkward/__init__.py @@ -13,6 +13,8 @@ from dask_awkward.lib.core import _type as type from dask_awkward.lib.core import ( compatible_partitions, + dask_method, + dask_property, map_partitions, partition_compatibility, ) diff --git a/src/dask_awkward/lib/core.py b/src/dask_awkward/lib/core.py index 0babd032..8b871690 100644 --- a/src/dask_awkward/lib/core.py +++ b/src/dask_awkward/lib/core.py @@ -1,6 +1,5 @@ from __future__ import annotations -import inspect import keyword import logging import math @@ -10,6 +9,7 @@ from collections.abc import Callable, Hashable, Mapping, Sequence from enum import IntEnum from functools import cached_property, partial, wraps +from inspect import getattr_static from numbers import Number from typing import TYPE_CHECKING, Any, Literal, TypeVar, Union, overload @@ -69,6 +69,184 @@ log = logging.getLogger(__name__) +def _make_dask_descriptor(func: Callable) -> Callable[[T, type[T], Array], Any]: + """Adapt a function accepting a `dask_array` into a dask-awkward descriptor + that, when invoked, calls the user function and returns its result.
+ + Parameters + ---------- + func : Callable dask-awkward descriptor body + + Returns + ------- + Callable + The callable dask-awkward descriptor. + """ + + def descriptor(instance: T, owner: type[T], dask_array: Array) -> Any: + impl = func.__get__(instance, owner) + return impl(dask_array) + + return descriptor + + +def _make_dask_method(func: Callable) -> Callable[[T, type[T], Array], Callable]: + """Adapt a function accepting a `dask_array` and additional arguments into + a dask-awkward descriptor that returns a callable invoking the bound user function. + + Parameters + ---------- + func : Callable + The dask-awkward descriptor body. + + Returns + ------- + Callable + The callable dask-awkward descriptor. + """ + + def descriptor(instance: T, owner: type[T], dask_array: Array) -> Any: + def impl(*args, **kwargs): + impl = func.__get__(instance, owner) + return impl(dask_array, *args, **kwargs) + + return impl + + return descriptor + + +F = TypeVar("F", bound=Callable) +G = TypeVar("G", bound=Callable) + + +class _DaskProperty(property): + """A property descriptor that exposes a `.dask` method for registering + dask-awkward descriptor implementations. + """ + + _dask_get: Callable | None = None + + def dask(self, func: F) -> _DaskProperty: + assert self._dask_get is None + self._dask_get = _make_dask_descriptor(func) + return self + + +def _adapt_naive_dask_get(func: Callable) -> Callable: + """Adapt a non-dask-awkward user-defined descriptor function into + a dask-awkward aware descriptor that invokes the original function. + + Parameters + ---------- + func : Callable + The non-dask-awkward descriptor body. + + Returns + ------- + Callable + The callable dask-awkward aware descriptor body.
+ """ + + def wrapper(self, dask_array, *args, **kwargs): + return func(self, *args, **kwargs) + + return wrapper + + +@overload +def dask_property(maybe_func: Callable, *, no_dispatch: bool = False) -> _DaskProperty: + """An extension of Python's built-in `property` that supports registration + of a dask getter via `.dask`. + + Parameters + ---------- + maybe_func : Callable, optional + The property getter function. + no_dispatch : bool + If True, re-use the main getter function as the Dask implementation. + + Returns + ------- + Callable + The callable dask-awkward aware descriptor factory or the descriptor itself + """ + + +@overload +def dask_property( + maybe_func: None = None, *, no_dispatch: bool = False +) -> Callable[[Callable], _DaskProperty]: + """An extension of Python's built-in `property` that supports registration + of a dask getter via `.dask`. + + Parameters + ---------- + maybe_func : Callable, optional + The property getter function. + no_dispatch : bool + If True, re-use the main getter function as the Dask implementation. + + Returns + ------- + Callable + The callable dask-awkward aware descriptor factory or the descriptor itself + """ + ... + + +def dask_property(maybe_func=None, *, no_dispatch=False): + """An extension of Python's built-in `property` that supports registration + of a dask getter via `.dask`. + + Parameters + ---------- + maybe_func : Callable, optional + The property getter function. 
+ no_dispatch : bool + If True, re-use the main getter function as the Dask implementation + + Returns + ------- + Callable + The callable dask-awkward aware descriptor factory or the descriptor itself + """ + + def dask_property_wrapper(func: Callable) -> _DaskProperty: + prop = _DaskProperty(func) + if no_dispatch: + return prop.dask(_adapt_naive_dask_get(func)) + else: + return prop + + if maybe_func is None: + return dask_property_wrapper + else: + return dask_property_wrapper(maybe_func) + + +def dask_method(func: F) -> F: + """Decorate an instance method to provide a mechanism for overriding the + implementation for dask-awkward arrays via `.dask`. + + Parameters + ---------- + func : Callable + The method implementation to decorate. + + Returns + ------- + Callable + The callable dask-awkward aware method. + """ + + def dask(dask_func_impl: G) -> F: + func._dask_get = _make_dask_method(dask_func_impl) # type: ignore + return func + + func.dask = dask # type: ignore + return func + + class Scalar(DaskMethodsMixin, DaskOperatorMethodMixin): """Single partition Dask collection representing a lazy Scalar. @@ -1218,101 +1396,34 @@ def __getitem__(self, where): return self._getitem_single(where) - def _call_behavior_method(self, method_name: str, *args: Any, **kwargs: Any) -> Any: - """Call a behavior method for an awkward array. - If the function signature has __dunder__ parameters it is assumed that the - user wants to do the map_partitions dispatch themselves and the _meta's - behavior is called. - If there are no __dunder__ parameters in the function call then the function - is wrapped in map_partitions automatically. 
- """ - if hasattr(self._meta, method_name): - themethod = getattr(self._meta, method_name) - thesig = inspect.signature(themethod) - if "_dask_array_" in thesig.parameters: - if "_dask_array_" not in kwargs: - kwargs["_dask_array_"] = self - return themethod(*args, **kwargs) - return self.map_partitions( - _BehaviorMethodFn(method_name, **kwargs), - *args, - label=hyphenize(method_name), - ) - - raise AttributeError( - f"Method {method_name} is not available to this collection." - ) - - def _call_behavior_property(self, property_name: str) -> Any: - """Call a property for an awkward array. - This also allows for some internal state to be tracked via behaviors - if a user follows the pattern: - - class SomeMixin: - - @property - def the_property(self): - ... - - @property - def a_property(array_context=None) # note: this can be any name - - This pattern is caught if the property has an argument that single - argument is assumed to be the array context (i.e. self) so that self- - referenced re-indexing operations can be hidden in properties. The - user must do the appropriate dispatch of map_partitions. - - If there is no argument the property call is wrapped in map_partitions. - """ - if hasattr(self._meta.__class__, property_name): - thegetter = getattr(self._meta.__class__, property_name).fget.__get__( - self._meta - ) - thesig = inspect.signature(thegetter) - - if len(thesig.parameters) == 1: - binding = thesig.bind(self) - return thegetter(*binding.args, **binding.kwargs) - elif len(thesig.parameters) > 1: - raise RuntimeError( - "Parametrized property cannot have more than one argument, the array context!" - ) - return self.map_partitions( - _BehaviorPropertyFn(property_name), - label=hyphenize(property_name), - ) - raise AttributeError( - f"Property {property_name} is not available to this collection." 
- ) - - def _maybe_behavior_method(self, attr: str) -> bool: - try: - res = getattr(self._meta.__class__, attr) - return (not isinstance(res, property)) and callable(res) - except AttributeError: - return False - - def _maybe_behavior_property(self, attr: str) -> bool: - try: - res = getattr(self._meta.__class__, attr) - return isinstance(res, property) - except AttributeError: - return False + def _is_method_heuristic(self, resolved: Any) -> bool: + return callable(resolved) def __getattr__(self, attr: str) -> Any: if attr not in (self.fields or []): - # check for possible behavior method - if self._maybe_behavior_method(attr): - - def wrapper(*args, **kwargs): - return self._call_behavior_method(attr, *args, **kwargs) - - return wrapper - # check for possible behavior property - elif self._maybe_behavior_property(attr): - return self._call_behavior_property(attr) - - raise AttributeError(f"{attr} not in fields.") + try: + cls_method = getattr_static(self._meta, attr) + except AttributeError: + raise AttributeError(f"{attr} not in fields.") + else: + if hasattr(cls_method, "_dask_get"): + return cls_method._dask_get(self._meta, type(self._meta), self) + elif self._is_method_heuristic(cls_method): + + @wraps(cls_method) + def wrapper(*args, **kwargs): + return self.map_partitions( + _BehaviorMethodFn(attr, **kwargs), + *args, + label=hyphenize(attr), + ) + + return wrapper + else: + return self.map_partitions( + _BehaviorPropertyFn(attr), + label=hyphenize(attr), + ) try: # at this point attr is either a field or we'll have to # raise an exception. 
diff --git a/tests/test_behavior.py b/tests/test_behavior.py index 67d4f78a..8000a4a6 100644 --- a/tests/test_behavior.py +++ b/tests/test_behavior.py @@ -1,5 +1,7 @@ from __future__ import annotations +from typing import no_type_check + import awkward as ak import numpy as np import pytest @@ -23,12 +25,26 @@ def x2(self): def point_abs(self): return np.sqrt(self.x**2 + self.y**2) - @property - def non_dask_property(self, _dask_array_=None): + @dak.dask_property + def some_property(self): return "this is a non-dask property" - def non_dask_method(self, _dask_array_=None): - return _dask_array_ + @some_property.dask + def some_property_dask(self, array): + return f"this is a dask property ({type(array).__name__})" + + @dak.dask_property(no_dispatch=True) + def some_property_both(self): + return "this is a dask AND non-dask property" + + @dak.dask_method + def some_method(self): + return None + + @no_type_check + @some_method.dask + def some_method_dask(self, array): + return array @pytest.mark.xfail( @@ -60,9 +76,17 @@ def test_property_behavior(daa_p1: dak.Array, caa_p1: ak.Array) -> None: assert daa.behavior == caa.behavior - assert daa.non_dask_property == caa.non_dask_property + assert caa.some_property == "this is a non-dask property" + assert daa.some_property == "this is a dask property (Array)" + + assert repr(daa.some_method()) == repr(daa) + assert repr(caa.some_method()) == repr(None) - assert repr(daa.non_dask_method()) == repr(daa) + assert ( + daa.some_property_both + == caa.some_property_both + == "this is a dask AND non-dask property" + ) @pytest.mark.xfail( @@ -73,18 +97,6 @@ def test_nonexistent_behavior(daa_p1: dak.Array, daa_p2: dak.Array) -> None: daa1 = dak.with_name(daa_p1["points"], "Point", behavior=behaviors) daa2 = daa_p2 - with pytest.raises( - AttributeError, - match="Method doesnotexist is not available to this collection", - ): - daa1._call_behavior_method("doesnotexist", daa2) - - with pytest.raises( - AttributeError, - 
match="Property doesnotexist is not available to this collection", - ): - daa1._call_behavior_property("doesnotexist") - # in this case the field check is where we raise with pytest.raises(AttributeError, match="distance not in fields"): daa2.distance(daa1)