Skip to content

Commit

Permalink
refactor: separate mocking from projection more cleanly
Browse files Browse the repository at this point in the history
  • Loading branch information
agoose77 committed Oct 3, 2023
1 parent f62491c commit 7a10638
Show file tree
Hide file tree
Showing 6 changed files with 64 additions and 40 deletions.
8 changes: 4 additions & 4 deletions src/dask_awkward/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
AwkwardTreeReductionLayer,
ImplementsIOFunction,
ImplementsProjection,
IOFunctionWithMeta,
io_func_implements_project,
IOFunctionWithMocking,
io_func_implements_projection,
)

__all__ = (
Expand All @@ -16,6 +16,6 @@
"AwkwardTreeReductionLayer",
"ImplementsProjection",
"ImplementsIOFunction",
"IOFunctionWithMeta",
"io_func_implements_project",
"IOFunctionWithMocking",
"io_func_implements_projection",
)
56 changes: 38 additions & 18 deletions src/dask_awkward/layers/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,47 +54,51 @@ def __call__(self, *args, **kwargs) -> AwkwardArray:
T = TypeVar("T")


class ImplementsProjection(Protocol):
@property
def meta(self) -> AwkwardArray:
class ImplementsMocking(Protocol):
def mock(self) -> AwkwardArray:
...


class ImplementsProjection(ImplementsMocking, Protocol):
def prepare_for_projection(self) -> tuple[AwkwardArray, T]:
...

def project(self, state: T) -> ImplementsIOFunction:
...


# IO functions may not end up performing buffer projection, so they
# should also support directly returning the result
# IO functions can implement full-blown projection
class ImplementsIOFunctionWithProjection(
ImplementsProjection, ImplementsIOFunction, Protocol
):
...


class IOFunctionWithMeta(ImplementsIOFunctionWithProjection):
# Or they can implement simple mocking
class ImplementsIOFunctionWithMocking(
ImplementsMocking, ImplementsIOFunction, Protocol
):
...


class IOFunctionWithMocking(ImplementsIOFunctionWithMocking):
def __init__(self, meta: AwkwardArray, io_func: ImplementsIOFunction):
self._meta = meta
self._io_func = io_func

def __call__(self, *args, **kwargs) -> AwkwardArray:
return self._io_func(*args, **kwargs)

@property
def meta(self) -> AwkwardArray:
def mock(self) -> AwkwardArray:
return self._meta

def prepare_for_projection(self) -> tuple[AwkwardArray, None]:
return self._meta, None

def project(self, state: None):
return self._io_func
def io_func_implements_projection(func: ImplementsIOFunction) -> bool:
return hasattr(func, "prepare_for_projection")


def io_func_implements_project(func: ImplementsIOFunction) -> bool:
return hasattr(func, "project")
def io_func_implements_mocking(func: ImplementsIOFunction) -> bool:
return hasattr(func, "mock")


class AwkwardInputLayer(AwkwardBlockwiseLayer):
Expand All @@ -108,7 +112,9 @@ def __init__(
*,
name: str,
inputs: Any,
io_func: ImplementsIOFunction | ImplementsIOFunctionWithProjection,
io_func: ImplementsIOFunction
| ImplementsIOFunctionWithMocking
| ImplementsIOFunctionWithProjection,
label: str | None = None,
produces_tasks: bool = False,
creation_info: dict | None = None,
Expand Down Expand Up @@ -142,11 +148,25 @@ def __repr__(self) -> str:
@property
def is_projectable(self) -> bool:
# isinstance(self.io_func, ImplementsProjection)
return io_func_implements_project(self.io_func)
return io_func_implements_projection(self.io_func)

@property
def is_mockable(self) -> bool:
# isinstance(self.io_func, ImplementsMocking)
return io_func_implements_mocking(self.io_func)

def mock(self) -> AwkwardInputLayer:
layer, _ = self.prepare_for_projection()
return layer
assert self.is_mockable

return AwkwardInputLayer(
name=self.name,
inputs=[None][: int(list(self.numblocks.values())[0][0])],
io_func=lambda *_, **__: self.io_func.mock(),
label=self.label,
produces_tasks=self.produces_tasks,
creation_info=self.creation_info,
annotations=self.annotations,
)

def prepare_for_projection(self) -> tuple[AwkwardInputLayer, T]:
"""Mock the input layer as starting with a data-less typetracer.
Expand Down
28 changes: 16 additions & 12 deletions src/dask_awkward/lib/io/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import warnings
from collections.abc import Iterable
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable, Mapping, Protocol
from typing import TYPE_CHECKING, Any, Callable, Mapping, Protocol, cast

import awkward as ak
import numpy as np
Expand All @@ -19,10 +19,13 @@
AwkwardInputLayer,
ImplementsIOFunction,
ImplementsProjection,
IOFunctionWithMeta,
io_func_implements_project,
)
from dask_awkward.layers.layers import AwkwardMaterializedLayer
from dask_awkward.layers.layers import (
AwkwardMaterializedLayer,
ImplementsMocking,
IOFunctionWithMocking,
io_func_implements_mocking,
)
from dask_awkward.lib.core import (
empty_typetracer,
map_partitions,
Expand Down Expand Up @@ -566,18 +569,19 @@ def from_map(
packed=packed,
)

# Special `io_func` implementations can
if io_func_implements_project(func):
# Special `io_func` implementations can implement mocking and optionally
# support buffer projection.
if io_func_implements_mocking(func):
io_func = func
array_meta = func.meta
array_meta = cast(ImplementsMocking, func).mock()
# If we know the meta, we can spoof mocking
elif meta is not None:
io_func = IOFunctionWithMocking(meta, func)
array_meta = meta
# Without `meta`, the meta will be computed by executing the graph
elif meta is None:
else:
io_func = func
array_meta = None
# If we know the meta, we can spoof projection
else:
io_func = IOFunctionWithMeta(meta, func)
array_meta = meta

dsk = AwkwardInputLayer(name=name, inputs=inputs, io_func=io_func)

Expand Down
3 changes: 1 addition & 2 deletions src/dask_awkward/lib/io/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,7 @@ def __init__(
def __call__(self, source: Any) -> ak.Array:
...

@property
def meta(self) -> AwkwardArray:
def mock(self) -> AwkwardArray:
return ak.typetracer.typetracer_from_form(self.form, behavior=self.behavior)

def prepare_for_projection(self) -> tuple[AwkwardArray, dict]:
Expand Down
7 changes: 3 additions & 4 deletions src/dask_awkward/lib/io/parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ def __init__(
def __call__(self, source: Any) -> ak.Array:
...

def mock(self) -> AwkwardArray:
return ak.typetracer.typetracer_from_form(self.form, behavior=self.behavior)

def prepare_for_projection(self) -> tuple[AwkwardArray, dict]:
form = form_with_unique_keys(self.form, "<root>")

Expand All @@ -84,10 +87,6 @@ def prepare_for_projection(self) -> tuple[AwkwardArray, dict]:
"report": report,
}

@property
def meta(self) -> AwkwardArray:
return ak.typetracer.typetracer_from_form(self.form, behavior=self.behavior)

@abc.abstractmethod
def project(self, state: dict) -> _FromParquetFn:
...
Expand Down
2 changes: 2 additions & 0 deletions src/dask_awkward/lib/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,8 @@ def optimize_columns(dsk: HighLevelGraph) -> HighLevelGraph:
projection_layers[name],
layer_to_projection_state[name],
) = lay.prepare_for_projection()
elif lay.is_mockable:
projection_layers[name] = lay.mock()
elif hasattr(lay, "mock"):
projection_layers[name] = lay.mock()

Expand Down

0 comments on commit 7a10638

Please sign in to comment.