Skip to content

Commit

Permalink
first working (for parquet)
Browse files Browse the repository at this point in the history
  • Loading branch information
martindurant committed Apr 9, 2024
1 parent 0c367df commit 9c654b1
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 17 deletions.
14 changes: 4 additions & 10 deletions src/dask_awkward/layers/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def return_report(self) -> bool: ...


class ImplementsProjection(Protocol[T]):
def project(self, report: TypeTracerReport, state: T) -> ImplementsIOFunction: ...
def project(self, columns: list[str]) -> ImplementsIOFunction: ...


class ImplementsNecessaryColumns(ImplementsProjection[T], Protocol):
Expand All @@ -81,7 +81,7 @@ def __call__(self, *args, **kwargs):


def io_func_implements_projection(func: ImplementsIOFunction) -> bool:
return hasattr(func, "prepare_for_projection")
return hasattr(func, "project")


def io_func_implements_columnar(func: ImplementsIOFunction) -> bool:
Expand Down Expand Up @@ -158,15 +158,9 @@ def is_projectable(self) -> bool:
def is_columnar(self) -> bool:
return io_func_implements_columnar(self.io_func)

def project(
self,
report: TypeTracerReport,
state: T,
) -> AwkwardInputLayer:
def project(self, columns: list[str]) -> AwkwardInputLayer:
assert self.is_projectable
io_func = cast(ImplementsProjection, self.io_func).project(
report=report, state=state
)
io_func = self.io_func.project(columns)
return AwkwardInputLayer(
name=self.name,
inputs=self.inputs,
Expand Down
1 change: 1 addition & 0 deletions src/dask_awkward/lib/io/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,6 +631,7 @@ def from_map(
behavior=io_func.behavior,
buffer_key=render_buffer_key,
)
io_func._column_report = report
report.commit(name)
array_meta._report = {
report
Expand Down
6 changes: 3 additions & 3 deletions src/dask_awkward/lib/io/parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def __init__(
def __call__(self, *args, **kwargs): ...

@abc.abstractmethod
def project_columns(self, columns): ...
def project(self, columns): ...

@property
def return_report(self) -> bool:
Expand Down Expand Up @@ -176,7 +176,7 @@ def __call__(self, *args, **kwargs):

return self.read_fn(source)

def project_columns(self, columns):
def project(self, columns):
return FromParquetFileWiseFn(
fs=self.fs,
form=self.form.select_columns(columns),
Expand Down Expand Up @@ -235,7 +235,7 @@ def __call__(self, pair: Any) -> ak.Array:
attrs=self.attrs,
)

def project_columns(self, columns):
def project(self, columns):
return FromParquetFragmentWiseFn(
fs=self.fs,
form=self.form.select_columns(columns),
Expand Down
26 changes: 22 additions & 4 deletions src/dask_awkward/lib/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ def all_optimizations(dsk: Mapping, keys: Sequence[Key], **_: Any) -> Mapping:
def optimize(dsk: HighLevelGraph, keys: Sequence[Key], **_: Any) -> Mapping:
"""Run optimizations specific to dask-awkward.
This is currently limited to determining the necessary columns for
input layers.
- determine the necessary columns for input layers
- fuse linear chains of blockwise operations in linear time
"""
if dask.config.get("awkward.optimization.enabled"):
Expand Down Expand Up @@ -104,9 +104,27 @@ def optimize_columns(dsk: HighLevelGraph, keys: Sequence[Key]) -> HighLevelGraph
New, optimized task graph with column-projected ``AwkwardInputLayer``.
"""
# TBD
dsk2 = dsk.layers.copy()
for k, lay in dsk.layers.items():
if not isinstance(lay, AwkwardInputLayer):
continue
rep = lay.io_func._column_report
cols = rep.data_touched_in(dsk.layers)
new_lay = lay.project([c.replace("@.", "") for c in cols])
dsk2[k] = new_lay

return HighLevelGraph(dsk2, dsk.dependencies)

return dsk # HighLevelGraph(layers, dsk.dependencies)

def necessary_columns(dsk, keys):
    """Project each input layer onto the columns its report says were touched.

    Parameters
    ----------
    dsk
        Graph object exposing a ``layers`` mapping (layer key -> layer).
    keys
        Accepted for signature compatibility; not used by this function.

    Returns
    -------
    dict
        Maps the key of every ``AwkwardInputLayer`` that carries a column
        report to the result of ``layer.project(columns)``. Layers of other
        types, and input layers without a report, are omitted.
    """
    out = {}
    for k, lay in dsk.layers.items():
        if not isinstance(lay, AwkwardInputLayer):
            continue
        # ``_column_report`` is attached to the IO function from outside the
        # layer, so it may be absent; skip such layers instead of failing
        # with AttributeError.
        rep = getattr(lay.io_func, "_column_report", None)
        if rep is None:
            continue
        cols = rep.data_touched_in(dsk.layers)
        # Drop every literal "@." from the reported names before projecting.
        # NOTE(review): presumably "@." is a marker the report prepends to
        # column paths -- confirm against the report implementation.
        out[k] = lay.project([c.replace("@.", "") for c in cols])
    return out


def rewrite_layer_chains(dsk: HighLevelGraph, keys: Sequence[Key]) -> HighLevelGraph:
Expand Down

0 comments on commit 9c654b1

Please sign in to comment.