From 9c654b1f7351fe9cbcf5ab61334fe69129f5b346 Mon Sep 17 00:00:00 2001
From: Martin Durant
Date: Tue, 9 Apr 2024 16:15:31 -0400
Subject: [PATCH] first working (for parquet)

---
 src/dask_awkward/layers/layers.py  | 14 ++++----------
 src/dask_awkward/lib/io/io.py      |  1 +
 src/dask_awkward/lib/io/parquet.py |  6 +++---
 src/dask_awkward/lib/optimize.py   | 26 ++++++++++++++++++++++----
 4 files changed, 30 insertions(+), 17 deletions(-)

diff --git a/src/dask_awkward/layers/layers.py b/src/dask_awkward/layers/layers.py
index ecc208bf..8350de34 100644
--- a/src/dask_awkward/layers/layers.py
+++ b/src/dask_awkward/layers/layers.py
@@ -57,7 +57,7 @@ def return_report(self) -> bool: ...
 
 
 class ImplementsProjection(Protocol[T]):
-    def project(self, report: TypeTracerReport, state: T) -> ImplementsIOFunction: ...
+    def project(self, columns: list[str]) -> ImplementsIOFunction: ...
 
 
 class ImplementsNecessaryColumns(ImplementsProjection[T], Protocol):
@@ -81,7 +81,7 @@ def __call__(self, *args, **kwargs):
 
 
 def io_func_implements_projection(func: ImplementsIOFunction) -> bool:
-    return hasattr(func, "prepare_for_projection")
+    return hasattr(func, "project")
 
 
 def io_func_implements_columnar(func: ImplementsIOFunction) -> bool:
@@ -158,15 +158,9 @@ def is_projectable(self) -> bool:
     def is_columnar(self) -> bool:
         return io_func_implements_columnar(self.io_func)
 
-    def project(
-        self,
-        report: TypeTracerReport,
-        state: T,
-    ) -> AwkwardInputLayer:
+    def project(self, columns: list[str]) -> AwkwardInputLayer:
         assert self.is_projectable
-        io_func = cast(ImplementsProjection, self.io_func).project(
-            report=report, state=state
-        )
+        io_func = self.io_func.project(columns)
         return AwkwardInputLayer(
             name=self.name,
             inputs=self.inputs,
diff --git a/src/dask_awkward/lib/io/io.py b/src/dask_awkward/lib/io/io.py
index 0298f06a..1270f328 100644
--- a/src/dask_awkward/lib/io/io.py
+++ b/src/dask_awkward/lib/io/io.py
@@ -631,6 +631,7 @@ def from_map(
             behavior=io_func.behavior,
             buffer_key=render_buffer_key,
         )
+        io_func._column_report = report
         report.commit(name)
         array_meta._report = {
             report
diff --git a/src/dask_awkward/lib/io/parquet.py b/src/dask_awkward/lib/io/parquet.py
index 2fc47c25..3625e1d7 100644
--- a/src/dask_awkward/lib/io/parquet.py
+++ b/src/dask_awkward/lib/io/parquet.py
@@ -94,7 +94,7 @@ def __init__(
     def __call__(self, *args, **kwargs): ...
 
     @abc.abstractmethod
-    def project_columns(self, columns): ...
+    def project(self, columns): ...
 
     @property
     def return_report(self) -> bool:
@@ -176,7 +176,7 @@ def __call__(self, *args, **kwargs):
 
         return self.read_fn(source)
 
-    def project_columns(self, columns):
+    def project(self, columns):
         return FromParquetFileWiseFn(
             fs=self.fs,
             form=self.form.select_columns(columns),
@@ -235,7 +235,7 @@ def __call__(self, pair: Any) -> ak.Array:
             attrs=self.attrs,
         )
 
-    def project_columns(self, columns):
+    def project(self, columns):
         return FromParquetFragmentWiseFn(
             fs=self.fs,
             form=self.form.select_columns(columns),
diff --git a/src/dask_awkward/lib/optimize.py b/src/dask_awkward/lib/optimize.py
index cc792db6..22f87e29 100644
--- a/src/dask_awkward/lib/optimize.py
+++ b/src/dask_awkward/lib/optimize.py
@@ -61,8 +61,8 @@ def all_optimizations(dsk: Mapping, keys: Sequence[Key], **_: Any) -> Mapping:
 def optimize(dsk: HighLevelGraph, keys: Sequence[Key], **_: Any) -> Mapping:
     """Run optimizations specific to dask-awkward.
 
-    This is currently limited to determining the necessary columns for
-    input layers.
+    - determine the necessary columns for input layers
+    - fuse linear chains of blockwise operations in linear time
 
     """
     if dask.config.get("awkward.optimization.enabled"):
@@ -104,9 +104,27 @@ def optimize_columns(dsk: HighLevelGraph, keys: Sequence[Key]) -> HighLevelGraph
         New, optimized task graph with column-projected ``AwkwardInputLayer``.
 
     """
-    # TBD
+    dsk2 = dsk.layers.copy()
+    for k, lay in dsk.layers.items():
+        if not isinstance(lay, AwkwardInputLayer):
+            continue
+        rep = lay.io_func._column_report
+        cols = rep.data_touched_in(dsk.layers)
+        new_lay = lay.project([c.replace("@.", "") for c in cols])
+        dsk2[k] = new_lay
+
+    return HighLevelGraph(dsk2, dsk.dependencies)
 
-    return dsk  # HighLevelGraph(layers, dsk.dependencies)
+
+
+def necessary_columns(dsk, keys):
+    out = {}
+    for k, lay in dsk.layers.items():
+        if not isinstance(lay, AwkwardInputLayer):
+            continue
+        rep = lay.io_func._column_report
+        cols = rep.data_touched_in(dsk.layers)
+        out[k] = lay.project([c.replace("@.", "") for c in cols])
+    return out
 
 
 def rewrite_layer_chains(dsk: HighLevelGraph, keys: Sequence[Key]) -> HighLevelGraph: