Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat!: harden dtype handling for scalars #418

Merged
79 changes: 55 additions & 24 deletions src/dask_awkward/lib/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,22 @@ def __init__(
self,
dsk: HighLevelGraph,
name: str,
meta: Any,
meta: Any | None = None,
dtype: DTypeLike | None = None,
known_value: Any | None = None,
) -> None:
if not isinstance(dsk, HighLevelGraph):
dsk = HighLevelGraph.from_collections(name, dsk, dependencies=()) # type: ignore
self._dask: HighLevelGraph = dsk
self._name: str = name
self._meta: Any = self._check_meta(meta)
if meta is not None and dtype is None:
self._meta = self._check_meta(meta)
self._dtype = np.dtype(self._meta.type.content.primitive)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the meta always has a `NumpyType` type, we could also just use `self._meta.layout.dtype`. I think that would be slightly better, because it relies on Awkward's own dtype↔primitive conversion instead of converting the primitive back with `np.dtype(...)`.

douglasdavis marked this conversation as resolved.
Show resolved Hide resolved
elif meta is None and dtype is not None:
self._meta = ak.Array(create_unknown_scalar(dtype))
self._dtype = dtype # type: ignore
else:
ValueError("One (and only one) of dtype or meta can be defined.")
self._known_value: Any | None = known_value

def __dask_graph__(self) -> Graph:
Expand Down Expand Up @@ -126,7 +134,7 @@ def _rebuild(self, dsk, *, rename=None):
return type(self)(dsk, name, self._meta, self.known_value)

def __reduce__(self):
return (Scalar, (self.dask, self.name, self._meta, self.known_value))
return (Scalar, (self.dask, self.name, None, self.dtype, self.known_value))

@property
def dask(self) -> HighLevelGraph:
Expand All @@ -140,23 +148,21 @@ def name(self) -> str:
def key(self) -> Key:
return (self._name, 0)

def _check_meta(self, m: Any) -> Any | None:
if m is None:
return m
elif isinstance(m, (MaybeNone, OneOf)) or is_unknown_scalar(m):
return m
def _check_meta(self, m):
if isinstance(m, MaybeNone):
return ak.Array(m.content)
elif isinstance(m, ak.Array) and len(m) == 1:
return m
raise TypeError(f"meta must be a typetracer object, not a {type(m)}")
elif isinstance(m, OneOf) or is_unknown_scalar(m):
if isinstance(m, TypeTracerArray):
return ak.Array(m)
else:
return m
raise TypeError(f"meta must be a typetracer, not a {type(m)}")

@property
def dtype(self) -> np.dtype | None:
try:
if self._meta is not None:
return self._meta.dtype
except AttributeError:
pass
return None
def dtype(self) -> np.dtype:
return self._dtype

@property
def npartitions(self) -> int:
Expand Down Expand Up @@ -185,15 +191,14 @@ def __repr__(self) -> str: # pragma: no cover
return self.__str__()

def __str__(self) -> str:
dt = self.dtype or "Unknown"
if self.known_value is not None:
return (
f"dask.awkward<{key_split(self.name)}, "
"type=Scalar, "
f"dtype={dt}, "
f"dtype={self.dtype}, "
f"known_value={self.known_value}>"
)
return f"dask.awkward<{key_split(self.name)}, type=Scalar, dtype={dt}>"
return f"dask.awkward<{key_split(self.name)}, type=Scalar, dtype={self.dtype}>"

def __getitem__(self, where: Any) -> Any:
msg = (
Expand Down Expand Up @@ -259,7 +264,11 @@ def f(self, other):
),
dependencies=tuple(deps),
)
return new_scalar_object(graph, name, meta=None)
if isinstance(other, Scalar):
meta = op(self._meta, other._meta)
else:
meta = op(self._meta, other)
return new_scalar_object(graph, name, meta=meta)

return f

Expand All @@ -276,7 +285,8 @@ def f(self):
layer,
dependencies=(self,),
)
return new_scalar_object(graph, name, meta=None)
meta = op(self._meta)
return new_scalar_object(graph, name, meta=meta)

return f

Expand Down Expand Up @@ -329,7 +339,13 @@ def f(*args):
Scalar._bind_operator(op)


def new_scalar_object(dsk: HighLevelGraph, name: str, *, meta: Any) -> Scalar:
def new_scalar_object(
dsk: HighLevelGraph,
name: str,
*,
meta: Any | None = None,
dtype: DTypeLike | None = None,
) -> Scalar:
"""Instantiate a new scalar collection.

Parameters
Expand All @@ -347,6 +363,14 @@ def new_scalar_object(dsk: HighLevelGraph, name: str, *, meta: Any) -> Scalar:
Resulting collection.

"""

if meta is not None and dtype is None:
pass
elif meta is None and dtype is not None:
meta = ak.Array(create_unknown_scalar(dtype))
else:
ValueError("One (and only one) of dtype or meta can be defined.")

if isinstance(meta, MaybeNone):
meta = ak.Array(meta.content)
elif meta is not None:
Expand Down Expand Up @@ -407,7 +431,12 @@ def new_known_scalar(
dtype = np.dtype(dtype)
llg = AwkwardMaterializedLayer({(name, 0): s}, previous_layer_names=[])
hlg = HighLevelGraph.from_collections(name, llg, dependencies=())
return Scalar(hlg, name, meta=create_unknown_scalar(dtype), known_value=s)
return Scalar(
hlg,
name,
dtype=dtype,
known_value=s,
)


class Record(Scalar):
Expand All @@ -422,7 +451,9 @@ class will be results from awkward operations.
"""

def __init__(self, dsk: HighLevelGraph, name: str, meta: Any | None = None) -> None:
super().__init__(dsk, name, meta)
self._dask: HighLevelGraph = dsk
self._name: str = name
self._meta: ak.Record = self._check_meta(meta)

def _check_meta(self, m: Any | None) -> Any | None:
if not isinstance(m, ak.Record):
Expand Down
2 changes: 1 addition & 1 deletion src/dask_awkward/lib/io/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -753,7 +753,7 @@ def to_json(
AwkwardMaterializedLayer(dsk, previous_layer_names=[map_res.name]),
dependencies=(map_res,),
)
res = new_scalar_object(graph, name=name, meta=None)
res = new_scalar_object(graph, name=name, dtype="f8")
if compute:
res.compute()
return None
Expand Down
10 changes: 5 additions & 5 deletions src/dask_awkward/lib/io/parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,9 @@ def __call__(self, source: Any) -> Any:
subrg=[None],
subform=self.form,
highlevel=False,
attrs=None,
behavior=None,
fs=self.fs,
behavior=self.behavior,
attrs=self.attrs,
**self.kwargs,
)
return ak.Array(
Expand Down Expand Up @@ -178,9 +178,9 @@ def __call__(self, pair: Any) -> ak.Array:
subrg=subrg,
subform=self.form,
highlevel=False,
attrs=None,
behavior=None,
fs=self.fs,
behavior=self.behavior,
attrs=self.attrs,
**self.kwargs,
)
return ak.Array(
Expand Down Expand Up @@ -637,7 +637,7 @@ def to_parquet(
AwkwardMaterializedLayer(dsk, previous_layer_names=[map_res.name]),
dependencies=[map_res],
)
out = new_scalar_object(graph, final_name, meta=None)
out = new_scalar_object(graph, final_name, dtype="f8")
if compute:
out.compute()
return None
Expand Down
6 changes: 5 additions & 1 deletion src/dask_awkward/lib/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,7 +646,11 @@ def num(
{(name, 0): (_numaxis0, *keys)}, previous_layer_names=[per_axis.name]
)
hlg = HighLevelGraph.from_collections(name, matlayer, dependencies=(per_axis,))
return new_scalar_object(hlg, name, meta=create_unknown_scalar(np.int64))
return new_scalar_object(
hlg,
name,
meta=ak.Array(create_unknown_scalar(np.dtype("int64"))),
)
else:
return map_partitions(
ak.num,
Expand Down
20 changes: 6 additions & 14 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import json
import operator
import sys
from collections import namedtuple
from collections.abc import Callable
from typing import TYPE_CHECKING

Expand Down Expand Up @@ -216,18 +215,13 @@ def test_scalar_collection(daa: Array) -> None:
assert type(daa["points", "x"][0][0]) is Scalar


def test_scalar_getitem_getattr() -> None:
d = {"a": 5}
s = new_known_scalar(d)
with pytest.raises(NotImplementedError, match="should be done after converting"):
s["a"].compute() == d["a"]
assert s.to_delayed()["a"].compute() == d["a"] # type: ignore
Thing = namedtuple("Thing", "a b c")
t = Thing(c=3, b=2, a=1)
s = new_known_scalar(t)
def test_known_scalar() -> None:
i = 5
s = new_known_scalar(5)
assert s.compute() == 5
with pytest.raises(AttributeError, match="should be done after converting"):
s.c.compute()
assert s.to_delayed().c.compute() == t.c
s.denominator.compute()
assert s.to_delayed().denominator.compute() == i.denominator


@pytest.mark.parametrize("op", [operator.add, operator.truediv, operator.mul])
Expand Down Expand Up @@ -476,8 +470,6 @@ def test_scalar_dtype() -> None:
s = 2
c = new_known_scalar(s)
assert c.dtype == np.dtype(type(s))
c._meta = None
assert c.dtype is None


def test_scalar_pickle(daa: Array) -> None:
Expand Down
Loading