Skip to content

Commit

Permalink
OOP based report implementation (#433)
Browse files Browse the repository at this point in the history
  • Loading branch information
douglasdavis authored Dec 11, 2023
1 parent 2ce95fa commit 910ca6b
Show file tree
Hide file tree
Showing 12 changed files with 314 additions and 203 deletions.
44 changes: 24 additions & 20 deletions src/dask_awkward/layers/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,21 +48,29 @@ def __repr__(self) -> str:


class ImplementsIOFunction(Protocol):
def __call__(self, *args: Any, **kwargs: Any) -> AwkwardArray:
def __call__(self, *args, **kwargs):
...


T = TypeVar("T")


class ImplementsMocking(Protocol):
class ImplementsMocking(ImplementsIOFunction, Protocol):
def mock(self) -> AwkwardArray:
...


class ImplementsMockEmpty(ImplementsIOFunction, Protocol):
def mock_empty(self, backend: BackendT) -> AwkwardArray:
...


class ImplementsReport(ImplementsIOFunction, Protocol):
@property
def return_report(self) -> bool:
...


class ImplementsProjection(ImplementsMocking, Protocol[T]):
def prepare_for_projection(self) -> tuple[AwkwardArray, TypeTracerReport, T]:
...
Expand Down Expand Up @@ -93,19 +101,6 @@ def mock(self) -> AwkwardArray:
assert self._meta is not None
return self._meta

def mock_empty(self, backend: BackendT = "cpu") -> AwkwardArray:
import awkward as ak

if backend not in ("cpu", "jax", "cuda"):
raise ValueError(
f"backend must be one of 'cpu', 'jax', or 'cuda', received {backend}"
)
return ak.to_backend(
self.mock().layout.form.length_zero_array(highlevel=False),
backend=backend,
highlevel=True,
)


def io_func_implements_projection(func: ImplementsIOFunction) -> bool:
return hasattr(func, "prepare_for_projection")
Expand All @@ -123,6 +118,10 @@ def io_func_implements_columnar(func: ImplementsIOFunction) -> bool:
return hasattr(func, "necessary_columns")


def io_func_implements_report(func: ImplementsIOFunction) -> bool:
return hasattr(func, "return_report")


class AwkwardInputLayer(AwkwardBlockwiseLayer):
"""A layer known to perform IO and produce Awkward arrays
Expand Down Expand Up @@ -183,7 +182,6 @@ def is_columnar(self) -> bool:

def mock(self) -> AwkwardInputLayer:
assert self.is_mockable

return AwkwardInputLayer(
name=self.name,
inputs=[None][: int(list(self.numblocks.values())[0][0])],
Expand Down Expand Up @@ -229,10 +227,15 @@ def prepare_for_projection(self) -> tuple[AwkwardInputLayer, TypeTracerReport, T
ImplementsProjection, self.io_func
).prepare_for_projection()

new_return = new_meta_array
if io_func_implements_report(self.io_func):
if cast(ImplementsReport, self.io_func).return_report:
new_return = (new_meta_array, type(new_meta_array)([]))

new_input_layer = AwkwardInputLayer(
name=self.name,
inputs=[None][: int(list(self.numblocks.values())[0][0])],
io_func=lambda *_, **__: new_meta_array,
io_func=lambda *_, **__: new_return,
label=self.label,
produces_tasks=self.produces_tasks,
creation_info=self.creation_info,
Expand All @@ -246,12 +249,13 @@ def project(
state: T,
) -> AwkwardInputLayer:
assert self.is_projectable
io_func = cast(ImplementsProjection, self.io_func).project(
report=report, state=state
)
return AwkwardInputLayer(
name=self.name,
inputs=self.inputs,
io_func=cast(ImplementsProjection, self.io_func).project(
report=report, state=state
),
io_func=io_func,
label=self.label,
produces_tasks=self.produces_tasks,
creation_info=self.creation_info,
Expand Down
22 changes: 14 additions & 8 deletions src/dask_awkward/lib/io/columnar.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any, Protocol, TypeVar, cast
from typing import TYPE_CHECKING, Protocol, TypeVar, cast

import awkward as ak
from awkward import Array as AwkwardArray
from awkward.forms import Form
from awkward.typetracer import typetracer_from_form, typetracer_with_report

from dask_awkward.layers.layers import (
BackendT,
Expand Down Expand Up @@ -43,7 +44,7 @@ def behavior(self) -> dict | None:
def project_columns(self: T, columns: frozenset[str]) -> T:
...

def __call__(self, *args: Any, **kwargs: Any) -> AwkwardArray:
def __call__(self, *args, **kwargs):
...


Expand All @@ -60,13 +61,18 @@ class ColumnProjectionMixin(ImplementsNecessaryColumns[FormStructure]):
"""

def mock(self: S) -> AwkwardArray:
return ak.typetracer.typetracer_from_form(self.form, behavior=self.behavior)
return cast(
AwkwardArray, typetracer_from_form(self.form, behavior=self.behavior)
)

def mock_empty(self: S, backend: BackendT = "cpu") -> AwkwardArray:
return ak.to_backend(
self.form.length_zero_array(highlevel=False, behavior=self.behavior),
backend,
highlevel=True,
return cast(
AwkwardArray,
ak.to_backend(
self.form.length_zero_array(highlevel=False, behavior=self.behavior),
backend,
highlevel=True,
),
)

def prepare_for_projection(
Expand All @@ -75,7 +81,7 @@ def prepare_for_projection(
form = form_with_unique_keys(self.form, "@")

# Build typetracer and associated report object
(meta, report) = ak.typetracer.typetracer_with_report(
(meta, report) = typetracer_with_report(
form,
highlevel=True,
behavior=self.behavior,
Expand Down
100 changes: 33 additions & 67 deletions src/dask_awkward/lib/io/io.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import functools
import logging
import math
from collections.abc import Callable, Iterable, Mapping
Expand All @@ -20,18 +19,20 @@
AwkwardBlockwiseLayer,
AwkwardInputLayer,
AwkwardMaterializedLayer,
BackendT,
ImplementsMocking,
ImplementsReport,
IOFunctionWithMocking,
io_func_implements_mock_empty,
io_func_implements_mocking,
io_func_implements_report,
)
from dask_awkward.lib.core import (
Array,
empty_typetracer,
map_partitions,
new_array_object,
typetracer_array,
)
from dask_awkward.utils import first, second

if TYPE_CHECKING:
from dask.array.core import Array as DaskArray
Expand All @@ -40,8 +41,6 @@
from dask.delayed import Delayed
from fsspec.spec import AbstractFileSystem

from dask_awkward.lib.core import Array


logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -100,16 +99,19 @@ def from_awkward(
starts = locs[:-1]
stops = locs[1:]
meta = typetracer_array(source)
return from_map(
_FromAwkwardFn(source),
starts,
stops,
label=label or "from-awkward",
token=tokenize(source, npartitions),
divisions=locs,
meta=meta,
behavior=behavior,
attrs=attrs,
return cast(
Array,
from_map(
_FromAwkwardFn(source),
starts,
stops,
label=label or "from-awkward",
token=tokenize(source, npartitions),
divisions=locs,
meta=meta,
behavior=behavior,
attrs=attrs,
),
)


Expand Down Expand Up @@ -158,12 +160,15 @@ def from_lists(
"""
lists = list(source)
divs = (0, *np.cumsum(list(map(len, lists))))
return from_map(
_FromListsFn(behavior=behavior, attrs=attrs),
lists,
meta=typetracer_array(ak.Array(lists[0], attrs=attrs, behavior=behavior)),
divisions=divs,
label="from-lists",
return cast(
Array,
from_map(
_FromListsFn(behavior=behavior, attrs=attrs),
lists,
meta=typetracer_array(ak.Array(lists[0], attrs=attrs, behavior=behavior)),
divisions=divs,
label="from-lists",
),
)


Expand Down Expand Up @@ -496,31 +501,6 @@ def __call__(self, packed_arg):
)


def return_empty_on_raise(
fn: Callable,
allowed_exceptions: tuple[type[BaseException], ...],
backend: BackendT,
) -> Callable:
@functools.wraps(fn)
def wrapped(*args, **kwargs):
try:
return fn(*args, **kwargs)
except allowed_exceptions as err:
logmsg = (
"%s call failed with args %s and kwargs %s; empty array returned. %s"
% (
str(fn),
str(args),
str(kwargs),
str(err),
)
)
logger.info(logmsg)
return fn.mock_empty(backend)

return wrapped


def from_map(
func: Callable,
*iterables: Iterable,
Expand All @@ -529,10 +509,8 @@ def from_map(
token: str | None = None,
divisions: tuple[int, ...] | tuple[None, ...] | None = None,
meta: ak.Array | None = None,
empty_on_raise: tuple[type[BaseException], ...] | None = None,
empty_backend: BackendT | None = None,
**kwargs: Any,
) -> Array:
) -> Array | tuple[Array, Array]:
"""Create an Array collection from a custom mapping.
Parameters
Expand All @@ -557,12 +535,6 @@ def from_map(
meta : Array, optional
Collection metadata array, if known (the awkward-array type
tracer)
empty_on_raise : tuple[type[BaseException], ...], optional
Set of exceptions that can be caught to return an empty array
at compute time if file IO raises.
empty_backend : str,
The backend for the empty array resulting from a failed read
when `empty_on_raise` is defined.
**kwargs : Any
Keyword arguments passed to `func`.
Expand Down Expand Up @@ -644,18 +616,6 @@ def from_map(
io_func = func
array_meta = None

if (empty_on_raise and not empty_backend) or (empty_backend and not empty_on_raise):
raise ValueError("empty_on_raise and empty_backend must be used together.")

if empty_on_raise and empty_backend:
if not io_func_implements_mock_empty(io_func):
raise ValueError("io_func must implement mock_empty method.")
io_func = return_empty_on_raise(
io_func,
allowed_exceptions=empty_on_raise,
backend=empty_backend,
)

dsk = AwkwardInputLayer(name=name, inputs=inputs, io_func=io_func)

hlg = HighLevelGraph.from_collections(name, dsk)
Expand All @@ -664,6 +624,12 @@ def from_map(
else:
result = new_array_object(hlg, name, meta=array_meta, npartitions=len(inputs))

if io_func_implements_report(io_func):
if cast(ImplementsReport, io_func).return_report:
res = result.map_partitions(first, meta=array_meta, output_divisions=1)
rep = result.map_partitions(second, meta=empty_typetracer())
return res, rep

return result


Expand Down
Loading

0 comments on commit 910ca6b

Please sign in to comment.