feat!: add Scalar binary and unary ops; adjust Scalar.__getattr__
douglasdavis authored Oct 26, 2023
1 parent f2ef9f1 commit 816cb30
Showing 3 changed files with 142 additions and 15 deletions.
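In practical terms, this commit lets lazy `Scalar` collections participate directly in Python arithmetic and comparisons. A minimal sketch of the intended usage (the input array, `npartitions`, and variable names here are illustrative assumptions, not taken from the commit itself):

import awkward as ak
import dask_awkward as dak

arr = ak.Array([[1.0, 2.0], [], [3.0, 4.0]])
darr = dak.from_awkward(arr, npartitions=2)

hi = dak.max(darr, axis=None)  # lazy Scalar collection
lo = dak.min(darr, axis=None)  # lazy Scalar collection

span = hi - lo        # Scalar <op> Scalar -> a new lazy Scalar
shifted = span + 10   # Scalar <op> plain Python object also works
mag = abs(-span)      # unary operators (neg, abs, ...) are bound too

print(span.compute(), shifted.compute(), mag.compute())

Under the hood the operators are attached by the `Scalar._bind_operator(op)` loop added in `core.py` below, which `dask.utils.OperatorMethodMixin` routes to the new `_get_binary_operator` and `_get_unary_operator` factories.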
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
@@ -13,7 +13,7 @@ repos:
- id: trailing-whitespace

- repo: https://github.com/psf/black-pre-commit-mirror
rev: 23.10.0
rev: 23.10.1
hooks:
- id: black
language_version: python3
@@ -24,7 +24,7 @@ repos:
- --target-version=py312

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.1
rev: v0.1.2
hooks:
- id: ruff

123 changes: 111 additions & 12 deletions src/dask_awkward/lib/core.py
@@ -9,7 +9,7 @@
import warnings
from collections.abc import Callable, Hashable, Sequence
from enum import IntEnum
from functools import cached_property, partial
from functools import cached_property, partial, wraps
from numbers import Number
from typing import TYPE_CHECKING, Any, Literal, TypeVar, Union, overload

@@ -37,7 +37,9 @@
from dask.delayed import Delayed
from dask.highlevelgraph import HighLevelGraph
from dask.threaded import get as threaded_get
from dask.utils import IndexCallable, funcname, is_arraylike, key_split
from dask.utils import IndexCallable
from dask.utils import OperatorMethodMixin as DaskOperatorMethodMixin
from dask.utils import funcname, is_arraylike, key_split

from dask_awkward.layers import AwkwardBlockwiseLayer, AwkwardMaterializedLayer
from dask_awkward.lib.optimize import all_optimizations
@@ -66,7 +68,7 @@
log = logging.getLogger(__name__)


class Scalar(DaskMethodsMixin):
class Scalar(DaskMethodsMixin, DaskOperatorMethodMixin):
"""Single partition Dask collection representing a lazy Scalar.
The class constructor is not intended for users. Instances of this
@@ -140,6 +142,8 @@ def key(self) -> Key:
def _check_meta(self, m: Any) -> Any | None:
if isinstance(m, (MaybeNone, OneOf)) or is_unknown_scalar(m):
return m
elif isinstance(m, ak.Array) and len(m) == 1:
return m
raise TypeError(f"meta must be a typetracer object, not a {type(m)}")

@property
@@ -199,10 +203,6 @@ def __getitem__(self, where: Any) -> Any:
hlg = HighLevelGraph.from_collections(name, task, dependencies=[self])
return new_scalar_object(hlg, name, meta=None)

def __getattr__(self, attr: str) -> Any:
d = self.to_delayed(optimize_graph=True)
return getattr(d, attr)

@property
def known_value(self) -> Any | None:
return self._known_value
@@ -230,6 +230,105 @@ def to_delayed(self, optimize_graph: bool = True) -> Delayed:
dsk = HighLevelGraph.from_collections(layer, dsk, dependencies=())
return Delayed(self.key, dsk, layer=layer)

def __getattr__(self, attr):
if attr.startswith("_"):
raise AttributeError # pragma: no cover
msg = (
"Attribute access on Scalars should be done after converting "
"the Scalar collection to delayed with the to_delayed method."
)
raise AttributeError(msg)

@classmethod
def _get_binary_operator(cls, op, inv=False):
def f(self, other):
name = f"{op.__name__}-{tokenize(self, other)}"
deps = [self]
plns = [self.name]
if is_dask_collection(other):
task = (op, self.key, *other.__dask_keys__())
deps.append(other)
plns.append(other.name)
else:
task = (op, self.key, other)
graph = HighLevelGraph.from_collections(
name,
layer=AwkwardMaterializedLayer(
{(name, 0): task},
previous_layer_names=plns,
fn=op,
),
dependencies=tuple(deps),
)
return new_scalar_object(graph, name, meta=None)

return f

@classmethod
def _get_unary_operator(cls, op, inv=False):
def f(self):
name = f"{op.__name__}-{tokenize(self)}"
layer = AwkwardMaterializedLayer(
{(name, 0): (op, self.key)},
previous_layer_names=[self.name],
)
graph = HighLevelGraph.from_collections(
name,
layer,
dependencies=(self,),
)
return new_scalar_object(graph, name, meta=None)

return f


def _promote_maybenones(op: Callable) -> Callable:
"""Wrap `op` function such that MaybeNone arguments are promoted.
Typetracer graphs (i.e. what is run by our necessary buffers
optimization) need `MaybeNone` results to be promoted to length 1
typetracer arrays. MaybeNone objects don't support these ops, but
arrays do.
"""

@wraps(op)
def f(*args):
args = tuple(
ak.Array(arg.content) if isinstance(arg, MaybeNone) else arg for arg in args
)
result = op(*args)
return result

return f


for op in [
_promote_maybenones(operator.abs),
_promote_maybenones(operator.neg),
_promote_maybenones(operator.pos),
_promote_maybenones(operator.invert),
_promote_maybenones(operator.add),
_promote_maybenones(operator.sub),
_promote_maybenones(operator.mul),
_promote_maybenones(operator.floordiv),
_promote_maybenones(operator.truediv),
_promote_maybenones(operator.mod),
_promote_maybenones(operator.pow),
_promote_maybenones(operator.and_),
_promote_maybenones(operator.or_),
_promote_maybenones(operator.xor),
_promote_maybenones(operator.lshift),
_promote_maybenones(operator.rshift),
_promote_maybenones(operator.eq),
_promote_maybenones(operator.ge),
_promote_maybenones(operator.gt),
_promote_maybenones(operator.ne),
_promote_maybenones(operator.le),
_promote_maybenones(operator.lt),
]:
Scalar._bind_operator(op)


def new_scalar_object(dsk: HighLevelGraph, name: str, *, meta: Any) -> Scalar:
"""Instantiate a new scalar collection.
@@ -250,10 +349,10 @@ def new_scalar_object(dsk: HighLevelGraph, name: str, *, meta: Any) -> Scalar:
"""
if meta is None:
meta = TypeTracerArray._new(dtype=np.dtype(None), shape=())
meta = ak.Array(TypeTracerArray._new(dtype=np.dtype(None), shape=()))

if isinstance(meta, MaybeNone):
pass
meta = ak.Array(meta.content)
else:
try:
if ak.backend(meta) != "typetracer":
@@ -619,13 +718,13 @@ def __iter__(self):
)

def _ipython_display_(self):
return self._meta._ipython_display_()
return self._meta._ipython_display_() # pragma: no cover

def _ipython_canary_method_should_not_exist_(self):
return self._meta._ipython_canary_method_should_not_exist_()
return self._meta._ipython_canary_method_should_not_exist_() # pragma: no cover

def _repr_mimebundle_(self):
return self._meta._repr_mimebundle_()
return self._meta._repr_mimebundle_() # pragma: no cover

def _ipython_key_completions_(self) -> list[str]:
if self._meta is not None:
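A note on `_promote_maybenones` above: during the typetracer (metadata-only) pass, a reduction's result can be reported as a `MaybeNone` wrapper rather than as an array, and the plain `operator.*` functions are not defined for that wrapper, while length-1 typetracer arrays do support them; the wrapper therefore promotes any `MaybeNone` argument to `ak.Array(arg.content)` before applying the operator. The same wrapping pattern in isolation, as a toy sketch (the `Maybe` class below is a stand-in invented for illustration, not awkward's `MaybeNone`):

import operator
from functools import wraps

class Maybe:
    # Stand-in for awkward's MaybeNone: holds a value that may turn out to be None.
    def __init__(self, content):
        self.content = content

def promote_maybes(op):
    # Unwrap any Maybe arguments before handing them to the real operator,
    # mirroring what _promote_maybenones does with MaybeNone and ak.Array.
    @wraps(op)
    def wrapped(*args):
        args = tuple(a.content if isinstance(a, Maybe) else a for a in args)
        return op(*args)
    return wrapped

add = promote_maybes(operator.add)
print(add(Maybe(3), 4))  # 7 -- the wrapper unwraps, then adds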
30 changes: 29 additions & 1 deletion tests/test_core.py
@@ -2,6 +2,7 @@

import copy
import json
import operator
import sys
from collections import namedtuple
from collections.abc import Callable
@@ -209,7 +210,34 @@ def test_scalar_getitem_getattr() -> None:
Thing = namedtuple("Thing", "a b c")
t = Thing(c=3, b=2, a=1)
s = new_known_scalar(t)
assert s.c.compute() == t.c
with pytest.raises(AttributeError, match="should be done after converting"):
s.c.compute()
assert s.to_delayed().c.compute() == t.c


@pytest.mark.parametrize("op", [operator.add, operator.truediv, operator.mul])
def test_scalar_binary_ops(op: Callable, daa: Array, caa: ak.Array) -> None:
a1 = dak.max(daa.points.x, axis=None)
b1 = dak.min(daa.points.y, axis=None)
a2 = ak.max(caa.points.x, axis=None)
b2 = ak.min(caa.points.y, axis=None)
assert_eq(op(a1, b1), op(a2, b2))


@pytest.mark.parametrize("op", [operator.add, operator.truediv, operator.mul])
def test_scalar_binary_ops_other_not_dak(
op: Callable, daa: Array, caa: ak.Array
) -> None:
a1 = dak.max(daa.points.x, axis=None)
a2 = ak.max(caa.points.x, axis=None)
assert_eq(op(a1, 5), op(a2, 5))


@pytest.mark.parametrize("op", [operator.abs])
def test_scalar_unary_ops(op: Callable, daa: Array, caa: ak.Array) -> None:
a1 = dak.max(daa.points.x, axis=None)
a2 = ak.max(caa.points.x, axis=None)
assert_eq(op(-a1), op(-a2))


@pytest.mark.parametrize(
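Since this is a breaking change (note the `feat!` prefix), the migration for code that relied on implicit attribute forwarding is to convert to a `Delayed` first, as the updated `test_scalar_getitem_getattr` above does. A small before/after sketch (the import path for `new_known_scalar` is an assumption here, mirroring the test suite):

from collections import namedtuple

from dask_awkward.lib.core import new_known_scalar

Thing = namedtuple("Thing", "a b c")
s = new_known_scalar(Thing(a=1, b=2, c=3))

# Before this commit, attribute access was forwarded through to_delayed()
# implicitly, so `s.c.compute()` worked. It now raises AttributeError;
# convert explicitly instead:
value = s.to_delayed().c.compute()
print(value)  # 3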
