Skip to content

Commit

Permalink
addressing some comments Angus
Browse files Browse the repository at this point in the history
  • Loading branch information
douglasdavis committed Nov 30, 2023
1 parent 25558bb commit 54b4d74
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 39 deletions.
15 changes: 8 additions & 7 deletions src/dask_awkward/layers/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from collections.abc import Callable, Mapping
from typing import TYPE_CHECKING, Any, Literal, Protocol, TypeVar, Union, cast

import awkward as ak
from dask.blockwise import Blockwise, BlockwiseDepDict, blockwise_token
from dask.highlevelgraph import MaterializedLayer
from dask.layers import DataFrameTreeReduction
Expand Down Expand Up @@ -107,13 +108,13 @@ def mock_empty(self, backend: BackendT = "cpu") -> AwkwardArray:
)


def io_func_rer_wrapped(func: ImplementsIOFunction) -> bool:
return hasattr(func, "__rer_wrapped__")
def io_func_empty_on_error_wrapped(func: ImplementsIOFunction) -> bool:
return hasattr(func, "_empty_on_error_wrapped")


def maybe_unwrap(func: Callable) -> Callable:
if io_func_rer_wrapped(func):
return func.__rer_wrapped__ # type: ignore
if io_func_empty_on_error_wrapped(func):
return func._empty_on_error_wrapped # type: ignore
return func


Expand Down Expand Up @@ -242,8 +243,8 @@ def prepare_for_projection(self) -> tuple[AwkwardInputLayer, TypeTracerReport, T
ImplementsProjection, fn
).prepare_for_projection()

if io_func_rer_wrapped(self.io_func):
new_return = (new_meta_array, type(new_meta_array)([]))
if io_func_empty_on_error_wrapped(self.io_func):
new_return = (new_meta_array, ak.from_iter([]))
else:
new_return = new_meta_array

Expand All @@ -267,7 +268,7 @@ def project(
fn = maybe_unwrap(self.io_func)
io_func = cast(ImplementsProjection, fn).project(report=report, state=state)

if io_func_rer_wrapped(self.io_func):
if io_func_empty_on_error_wrapped(self.io_func):
io_func = self.io_func.recreate(io_func) # type: ignore

return AwkwardInputLayer(
Expand Down
106 changes: 76 additions & 30 deletions src/dask_awkward/lib/io/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@

import awkward as ak
import numpy as np
from awkward.forms.listoffsetform import ListOffsetForm
from awkward.forms.numpyform import NumpyForm
from awkward.forms.recordform import RecordForm
from awkward.types.numpytype import primitive_to_dtype
from awkward.typetracer import length_zero_if_typetracer
from dask.base import flatten, tokenize
Expand Down Expand Up @@ -496,20 +499,47 @@ def __call__(self, packed_arg):
)


def default_report_success(*args: Any, **kwargs: Any) -> ak.Array:
return ak.Array(
[
{
"args": [],
"kwargs": [],
"exception": "",
"message": "",
},
],
)
_default_failure_array_form = RecordForm(
[
ListOffsetForm(
"i64",
ListOffsetForm(
"i64",
NumpyForm("uint8", parameters={"__array__": "char"}),
parameters={"__array__": "string"},
),
),
ListOffsetForm(
"i64",
ListOffsetForm(
"i64",
ListOffsetForm(
"i64",
NumpyForm("uint8", parameters={"__array__": "char"}),
parameters={"__array__": "string"},
),
),
),
ListOffsetForm(
"i64",
NumpyForm("uint8", parameters={"__array__": "char"}),
parameters={"__array__": "string"},
),
ListOffsetForm(
"i64",
NumpyForm("uint8", parameters={"__array__": "char"}),
parameters={"__array__": "string"},
),
],
["args", "kwargs", "exception", "message"],
)


def on_success_default(*args: Any, **kwargs: Any) -> ak.Array:
return ak.Array(_default_failure_array_form.length_one_array(highlevel=False))

def default_report_failure(

def on_failure_default(
exception: type[BaseException],
*args: Any,
**kwargs: Any,
Expand All @@ -526,38 +556,54 @@ def default_report_failure(
)


class return_empty_on_raise:
class ReturnEmptyOnRaise:
def __init__(
self,
fn: Callable[..., ak.Array],
allowed_exceptions: tuple[type[BaseException], ...],
backend: BackendT,
success_callback: Callable[..., ak.Array],
failure_callback: Callable[..., ak.Array],
on_success: Callable[..., ak.Array],
on_failure: Callable[..., ak.Array],
):
self.__rer_wrapped__ = fn
self._empty_on_error_wrapped = fn
self.fn = fn
self.allowed_exceptions = allowed_exceptions
self.backend = backend
self.success_callback = success_callback
self.failure_callback = failure_callback
self.on_success = on_success
self.on_failure = on_failure

def recreate(self, fn):
return return_empty_on_raise(
fn,
self.allowed_exceptions,
self.backend,
self.success_callback,
self.failure_callback,
self.on_success,
self.on_failure,
)

def __call__(self, *args, **kwargs):
try:
result = self.fn(*args, **kwargs)
return result, self.success_callback(*args, **kwargs)
return result, self.on_success(*args, **kwargs)
except self.allowed_exceptions as err:
result = self.fn.mock_empty(self.backend)
return result, self.failure_callback(err, *args, **kwargs)
return result, self.on_failure(err, *args, **kwargs)


def return_empty_on_raise(
fn: Callable[..., ak.Array],
allowed_exceptions: tuple[type[BaseException], ...],
backend: BackendT,
on_success: Callable[..., ak.Array],
on_failure: Callable[..., ak.Array],
) -> ReturnEmptyOnRaise:
return ReturnEmptyOnRaise(
fn,
allowed_exceptions,
backend,
on_success,
on_failure,
)


@overload
Expand All @@ -571,8 +617,8 @@ def from_map(
meta: ak.Array | None = None,
empty_on_raise: None = None,
empty_backend: None = None,
empty_success_callback: Callable[..., ak.Array] = default_report_success,
empty_failure_callback: Callable[..., ak.Array] = default_report_failure,
on_success: Callable[..., ak.Array] = on_success_default,
on_failure: Callable[..., ak.Array] = on_failure_default,
**kwargs: Any,
) -> Array:
...
Expand All @@ -589,8 +635,8 @@ def from_map(
token: str | None = None,
divisions: tuple[int, ...] | tuple[None, ...] | None = None,
meta: ak.Array | None = None,
empty_success_callback: Callable[..., ak.Array] = default_report_success,
empty_failure_callback: Callable[..., ak.Array] = default_report_failure,
on_success: Callable[..., ak.Array] = on_success_default,
on_failure: Callable[..., ak.Array] = on_failure_default,
**kwargs: Any,
) -> tuple[Array, Array]:
...
Expand All @@ -606,8 +652,8 @@ def from_map(
meta: ak.Array | None = None,
empty_on_raise: tuple[type[BaseException], ...] | None = None,
empty_backend: BackendT | None = None,
empty_success_callback: Callable[..., ak.Array] = default_report_success,
empty_failure_callback: Callable[..., ak.Array] = default_report_failure,
on_success: Callable[..., ak.Array] = on_success_default,
on_failure: Callable[..., ak.Array] = on_failure_default,
**kwargs: Any,
) -> Array | tuple[Array, Array]:
"""Create an Array collection from a custom mapping.
Expand Down Expand Up @@ -731,8 +777,8 @@ def from_map(
io_func,
allowed_exceptions=empty_on_raise,
backend=empty_backend,
success_callback=empty_success_callback,
failure_callback=empty_failure_callback,
on_success=on_success,
on_failure=on_failure,
)

dsk = AwkwardInputLayer(name=name, inputs=inputs, io_func=io_func)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,8 +460,8 @@ def fail(excep, *args, **kwargs):
label="from-lists",
empty_on_raise=(OSError,),
empty_backend="cpu",
empty_failure_callback=fail,
empty_success_callback=succ,
on_failure=fail,
on_success=succ,
)

_, rep = dask.compute(array, report)
Expand Down

0 comments on commit 54b4d74

Please sign in to comment.