diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 338cf80f..1828e2dd 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -13,7 +13,7 @@ repos:
   - id: trailing-whitespace
 - repo: https://github.com/psf/black-pre-commit-mirror
-  rev: 23.10.0
+  rev: 23.10.1
   hooks:
   - id: black
     language_version: python3
@@ -24,7 +24,7 @@ repos:
     - --target-version=py312
 - repo: https://github.com/astral-sh/ruff-pre-commit
-  rev: v0.1.1
+  rev: v0.1.2
   hooks:
   - id: ruff
diff --git a/src/dask_awkward/lib/core.py b/src/dask_awkward/lib/core.py
index afb0b8df..51b9e66d 100644
--- a/src/dask_awkward/lib/core.py
+++ b/src/dask_awkward/lib/core.py
@@ -9,7 +9,7 @@ import warnings
 from collections.abc import Callable, Hashable, Sequence
 from enum import IntEnum
-from functools import cached_property, partial
+from functools import cached_property, partial, wraps
 from numbers import Number
 from typing import TYPE_CHECKING, Any, Literal, TypeVar, Union, overload
@@ -37,7 +37,9 @@ from dask.delayed import Delayed
 from dask.highlevelgraph import HighLevelGraph
 from dask.threaded import get as threaded_get
-from dask.utils import IndexCallable, funcname, is_arraylike, key_split
+from dask.utils import IndexCallable
+from dask.utils import OperatorMethodMixin as DaskOperatorMethodMixin
+from dask.utils import funcname, is_arraylike, key_split
 
 from dask_awkward.layers import AwkwardBlockwiseLayer, AwkwardMaterializedLayer
 from dask_awkward.lib.optimize import all_optimizations
@@ -66,7 +68,7 @@ log = logging.getLogger(__name__)
 
 
-class Scalar(DaskMethodsMixin):
+class Scalar(DaskMethodsMixin, DaskOperatorMethodMixin):
     """Single partition Dask collection representing a lazy Scalar.
 
     The class constructor is not intended for users. Instances of this
@@ -140,6 +142,8 @@ def key(self) -> Key:
     def _check_meta(self, m: Any) -> Any | None:
         if isinstance(m, (MaybeNone, OneOf)) or is_unknown_scalar(m):
             return m
+        elif isinstance(m, ak.Array) and len(m) == 1:
+            return m
         raise TypeError(f"meta must be a typetracer object, not a {type(m)}")
 
     @property
@@ -199,10 +203,6 @@ def __getitem__(self, where: Any) -> Any:
         hlg = HighLevelGraph.from_collections(name, task, dependencies=[self])
         return new_scalar_object(hlg, name, meta=None)
 
-    def __getattr__(self, attr: str) -> Any:
-        d = self.to_delayed(optimize_graph=True)
-        return getattr(d, attr)
-
     @property
     def known_value(self) -> Any | None:
         return self._known_value
@@ -230,6 +230,105 @@ def to_delayed(self, optimize_graph: bool = True) -> Delayed:
         dsk = HighLevelGraph.from_collections(layer, dsk, dependencies=())
         return Delayed(self.key, dsk, layer=layer)
 
+    def __getattr__(self, attr):
+        if attr.startswith("_"):
+            raise AttributeError  # pragma: no cover
+        msg = (
+            "Attribute access on Scalars should be done after converting "
+            "the Scalar collection to delayed with the to_delayed method."
+        )
+        raise AttributeError(msg)
+
+    @classmethod
+    def _get_binary_operator(cls, op, inv=False):
+        def f(self, other):
+            name = f"{op.__name__}-{tokenize(self, other)}"
+            deps = [self]
+            plns = [self.name]
+            if is_dask_collection(other):
+                task = (op, self.key, *other.__dask_keys__())
+                deps.append(other)
+                plns.append(other.name)
+            else:
+                task = (op, self.key, other)
+            graph = HighLevelGraph.from_collections(
+                name,
+                layer=AwkwardMaterializedLayer(
+                    {(name, 0): task},
+                    previous_layer_names=plns,
+                    fn=op,
+                ),
+                dependencies=tuple(deps),
+            )
+            return new_scalar_object(graph, name, meta=None)
+
+        return f
+
+    @classmethod
+    def _get_unary_operator(cls, op, inv=False):
+        def f(self):
+            name = f"{op.__name__}-{tokenize(self)}"
+            layer = AwkwardMaterializedLayer(
+                {(name, 0): (op, self.key)},
+                previous_layer_names=[self.name],
+            )
+            graph = HighLevelGraph.from_collections(
+                name,
+                layer,
+                dependencies=(self,),
+            )
+            return new_scalar_object(graph, name, meta=None)
+
+        return f
+
+
+def _promote_maybenones(op: Callable) -> Callable:
+    """Wrap `op` so that `MaybeNone` arguments are promoted.
+
+    Typetracer graphs (i.e. what is run by our necessary-buffers
+    optimization) need `MaybeNone` results to be promoted to length-1
+    typetracer arrays. `MaybeNone` objects don't support these
+    operators, but arrays do.
+
+    """
+
+    @wraps(op)
+    def f(*args):
+        args = tuple(
+            ak.Array(arg.content) if isinstance(arg, MaybeNone) else arg for arg in args
+        )
+        result = op(*args)
+        return result
+
+    return f
+
+
+for op in [
+    _promote_maybenones(operator.abs),
+    _promote_maybenones(operator.neg),
+    _promote_maybenones(operator.pos),
+    _promote_maybenones(operator.invert),
+    _promote_maybenones(operator.add),
+    _promote_maybenones(operator.sub),
+    _promote_maybenones(operator.mul),
+    _promote_maybenones(operator.floordiv),
+    _promote_maybenones(operator.truediv),
+    _promote_maybenones(operator.mod),
+    _promote_maybenones(operator.pow),
+    _promote_maybenones(operator.and_),
+    _promote_maybenones(operator.or_),
+    _promote_maybenones(operator.xor),
+    _promote_maybenones(operator.lshift),
+    _promote_maybenones(operator.rshift),
+    _promote_maybenones(operator.eq),
+    _promote_maybenones(operator.ge),
+    _promote_maybenones(operator.gt),
+    _promote_maybenones(operator.ne),
+    _promote_maybenones(operator.le),
+    _promote_maybenones(operator.lt),
+]:
+    Scalar._bind_operator(op)
+
 
 def new_scalar_object(dsk: HighLevelGraph, name: str, *, meta: Any) -> Scalar:
     """Instantiate a new scalar collection.
@@ -250,10 +349,10 @@ def new_scalar_object(dsk: HighLevelGraph, name: str, *, meta: Any) -> Scalar:
 
     """
     if meta is None:
-        meta = TypeTracerArray._new(dtype=np.dtype(None), shape=())
+        meta = ak.Array(TypeTracerArray._new(dtype=np.dtype(None), shape=()))
 
     if isinstance(meta, MaybeNone):
-        pass
+        meta = ak.Array(meta.content)
     else:
         try:
             if ak.backend(meta) != "typetracer":
@@ -619,13 +718,13 @@ def __iter__(self):
         )
 
     def _ipython_display_(self):
-        return self._meta._ipython_display_()
+        return self._meta._ipython_display_()  # pragma: no cover
 
     def _ipython_canary_method_should_not_exist_(self):
-        return self._meta._ipython_canary_method_should_not_exist_()
+        return self._meta._ipython_canary_method_should_not_exist_()  # pragma: no cover
 
     def _repr_mimebundle_(self):
-        return self._meta._repr_mimebundle_()
+        return self._meta._repr_mimebundle_()  # pragma: no cover
 
     def _ipython_key_completions_(self) -> list[str]:
         if self._meta is not None:
diff --git a/tests/test_core.py b/tests/test_core.py
index b207e3f8..f77cdb78 100644
--- a/tests/test_core.py
+++ b/tests/test_core.py
@@ -2,6 +2,7 @@
 
 import copy
 import json
+import operator
 import sys
 from collections import namedtuple
 from collections.abc import Callable
@@ -209,7 +210,34 @@ def test_scalar_getitem_getattr() -> None:
     Thing = namedtuple("Thing", "a b c")
     t = Thing(c=3, b=2, a=1)
     s = new_known_scalar(t)
-    assert s.c.compute() == t.c
+    with pytest.raises(AttributeError, match="should be done after converting"):
+        s.c.compute()
+    assert s.to_delayed().c.compute() == t.c
+
+
+@pytest.mark.parametrize("op", [operator.add, operator.truediv, operator.mul])
+def test_scalar_binary_ops(op: Callable, daa: Array, caa: ak.Array) -> None:
+    a1 = dak.max(daa.points.x, axis=None)
+    b1 = dak.min(daa.points.y, axis=None)
+    a2 = ak.max(caa.points.x, axis=None)
+    b2 = ak.min(caa.points.y, axis=None)
+    assert_eq(op(a1, b1), op(a2, b2))
+
+
+@pytest.mark.parametrize("op", [operator.add, operator.truediv, operator.mul])
+def test_scalar_binary_ops_other_not_dak(
+    op: Callable, daa: Array, caa: ak.Array
+) -> None:
+    a1 = dak.max(daa.points.x, axis=None)
+    a2 = ak.max(caa.points.x, axis=None)
+    assert_eq(op(a1, 5), op(a2, 5))
+
+
+@pytest.mark.parametrize("op", [operator.abs])
+def test_scalar_unary_ops(op: Callable, daa: Array, caa: ak.Array) -> None:
+    a1 = dak.max(daa.points.x, axis=None)
+    a2 = ak.max(caa.points.x, axis=None)
+    assert_eq(op(-a1), op(-a2))
 
 
 @pytest.mark.parametrize(
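Usage sketch (not part of the patch): the example below shows roughly what the new `Scalar` operator bindings enable, using the public `dak.from_awkward`, `dak.max`, and `dak.min` APIs. The array construction and the `points.x` / `points.y` field names are illustrative, chosen to mirror the `daa`/`caa` fixtures exercised in the tests above.

```python
import awkward as ak
import dask_awkward as dak

# Illustrative data; the nested "points" records mirror the test fixtures.
events = ak.Array(
    [
        {"points": [{"x": 1.0, "y": 2.0}, {"x": 3.0, "y": 4.0}]},
        {"points": [{"x": 5.0, "y": 6.0}]},
    ]
)
daa = dak.from_awkward(events, npartitions=2)

hi = dak.max(daa.points.x, axis=None)  # lazy Scalar collection
lo = dak.min(daa.points.y, axis=None)  # lazy Scalar collection

# With this patch, arithmetic on Scalars builds a new task graph directly,
# instead of (previously) being proxied through Delayed via __getattr__.
span = hi - lo        # Scalar combined with another Scalar
shifted = hi + 5      # Scalar combined with a plain Python number
magnitude = abs(-lo)  # unary operators are bound as well

print(span.compute(), shifted.compute(), magnitude.compute())

# Attribute access on a Scalar now raises AttributeError and points users at
# to_delayed(); e.g. an attribute lookup must go through hi.to_delayed().
```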