Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: simultaneous computes when one is encapsulated by the other doesn't over optimize #375

Merged
merged 2 commits into from
Sep 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
211 changes: 120 additions & 91 deletions src/dask_awkward/lib/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import TYPE_CHECKING, Any

import dask.config
from awkward.typetracer import touch_data
from dask.blockwise import fuse_roots, optimize_blockwise
from dask.core import flatten
from dask.highlevelgraph import HighLevelGraph
Expand Down Expand Up @@ -107,12 +108,12 @@ def optimize_columns(dsk: HighLevelGraph) -> HighLevelGraph:
Parameters
----------
dsk : HighLevelGraph
Original high level dask graph
Task graph to optimize.

Returns
-------
HighLevelGraph
New dask graph with a modified ``AwkwardInputLayer``.
New, optimized task graph with column-projected ``AwkwardInputLayer``.

"""
layers = dsk.layers.copy() # type: ignore
Expand All @@ -128,109 +129,41 @@ def optimize_columns(dsk: HighLevelGraph) -> HighLevelGraph:
return HighLevelGraph(layers, deps)


def _projectable_input_layer_names(dsk: HighLevelGraph) -> list[str]:
"""Get list of column-projectable AwkwardInputLayer names.

Parameters
----------
dsk : HighLevelGraph
Task graph of interest

Returns
-------
list[str]
Names of the AwkwardInputLayers in the graph that are
column-projectable.

"""
return [
n
for n, v in dsk.layers.items()
if isinstance(v, AwkwardInputLayer) and hasattr(v.io_func, "project_columns")
# following condition means dep/pickled layers cannot be optimised
and hasattr(v, "_meta")
]


def _layers_with_annotation(dsk: HighLevelGraph, key: str) -> list[str]:
return [n for n, v in dsk.layers.items() if (v.annotations or {}).get(key)]


def _ak_output_layer_names(dsk: HighLevelGraph) -> list[str]:
"""Get a list output layer names.

Output layer names are annotated with 'ak_output'.
def rewrite_layer_chains(dsk: HighLevelGraph, keys: Any) -> HighLevelGraph:
"""Smush chains of blockwise layers into a single layer.

The logic here identifies chains by popping layers (in arbitrary
order) from a set of all layers in the task graph and walking
through the dependencies (parent layers) and dependents (child
layers). If a multi layer chain is discovered we compress it into
a single layer with the second loop below (for chain in chains;
that step rewrites the graph). In the chain building logic, if a
layer exists in the `keys` argument (the keys necessary for the
compute that we are optimizing for), we shortcircuit the logic to
ensure we do not chain layers that contain a necessary key inside
(these layers are called `required_layers` below).

Parameters
----------
dsk : HighLevelGraph
Graph of interest.
Task graph to optimize.
keys : Any
Keys that are requested by the compute that is being
optimized.

Returns
-------
list[str]
Names of the output layers.
HighLevelGraph
New, optimized task graph.

"""
return _layers_with_annotation(dsk, "ak_output")


def _opt_touch_all_layer_names(dsk: HighLevelGraph) -> list[str]:
return [n for n, v in dsk.layers.items() if hasattr(v, "_opt_touch_all")]
# return _layers_with_annotation(dsk, "ak_touch_all")


def _has_projectable_awkward_io_layer(dsk: HighLevelGraph) -> bool:
"""Check if a graph at least one AwkwardInputLayer that is project-able."""
for _, v in dsk.layers.items():
if isinstance(v, AwkwardInputLayer) and hasattr(v.io_func, "project_columns"):
return True
return False


def _touch_all_data(*args, **kwargs):
"""Mock writing an ak.Array to disk by touching data buffers."""
import awkward as ak

for arg in args + tuple(kwargs.values()):
ak.typetracer.touch_data(arg)


def _mock_output(layer):
"""Update a layer to run the _touch_all_data."""
assert len(layer.dsk) == 1

new_layer = copy.deepcopy(layer)
mp = new_layer.dsk.copy()
for k in iter(mp.keys()):
mp[k] = (_touch_all_data,) + mp[k][1:]
new_layer.dsk = mp
return new_layer


def _touch_and_call_fn(fn, *args, **kwargs):
_touch_all_data(*args, **kwargs)
return fn(*args, **kwargs)


def _touch_and_call(layer):
assert len(layer.dsk) == 1

new_layer = copy.deepcopy(layer)
mp = new_layer.dsk.copy()
for k in iter(mp.keys()):
mp[k] = (_touch_and_call_fn,) + mp[k]
new_layer.dsk = mp
return new_layer


def rewrite_layer_chains(dsk: HighLevelGraph, keys: Any) -> HighLevelGraph:
# dask.optimization.fuse_liner for blockwise layers
import copy

chains = []
deps = copy.copy(dsk.dependencies)

required_layers = {k[0] for k in keys}
layers = {}
# find chains; each chain list is at least two keys long
dependents = dsk.dependents
Expand All @@ -250,6 +183,7 @@ def rewrite_layer_chains(dsk: HighLevelGraph, keys: Any) -> HighLevelGraph:
and dsk.dependencies[list(children)[0]] == {lay}
and isinstance(dsk.layers[list(children)[0]], AwkwardBlockwiseLayer)
and len(dsk.layers[lay]) == len(dsk.layers[list(children)[0]])
and lay not in required_layers
):
# walk forwards
lay = list(children)[0]
Expand All @@ -263,6 +197,7 @@ def rewrite_layer_chains(dsk: HighLevelGraph, keys: Any) -> HighLevelGraph:
and dependents[list(parents)[0]] == {lay}
and isinstance(dsk.layers[list(parents)[0]], AwkwardBlockwiseLayer)
and len(dsk.layers[lay]) == len(dsk.layers[list(parents)[0]])
and list(parents)[0] not in required_layers
):
# walk backwards
lay = list(parents)[0]
Expand Down Expand Up @@ -316,6 +251,100 @@ def rewrite_layer_chains(dsk: HighLevelGraph, keys: Any) -> HighLevelGraph:
return HighLevelGraph(layers, deps)


def _projectable_input_layer_names(dsk: HighLevelGraph) -> list[str]:
"""Get list of column-projectable AwkwardInputLayer names.

Parameters
----------
dsk : HighLevelGraph
Task graph of interest

Returns
-------
list[str]
Names of the AwkwardInputLayers in the graph that are
column-projectable.

"""
return [
n
for n, v in dsk.layers.items()
if isinstance(v, AwkwardInputLayer) and hasattr(v.io_func, "project_columns")
# following condition means dep/pickled layers cannot be optimised
and hasattr(v, "_meta")
]


def _layers_with_annotation(dsk: HighLevelGraph, key: str) -> list[str]:
return [n for n, v in dsk.layers.items() if (v.annotations or {}).get(key)]


def _ak_output_layer_names(dsk: HighLevelGraph) -> list[str]:
"""Get a list output layer names.

Output layer names are annotated with 'ak_output'.

Parameters
----------
dsk : HighLevelGraph
Graph of interest.

Returns
-------
list[str]
Names of the output layers.

"""
return _layers_with_annotation(dsk, "ak_output")


def _opt_touch_all_layer_names(dsk: HighLevelGraph) -> list[str]:
return [n for n, v in dsk.layers.items() if hasattr(v, "_opt_touch_all")]
# return _layers_with_annotation(dsk, "ak_touch_all")


def _has_projectable_awkward_io_layer(dsk: HighLevelGraph) -> bool:
"""Check if a graph at least one AwkwardInputLayer that is project-able."""
for _, v in dsk.layers.items():
if isinstance(v, AwkwardInputLayer) and hasattr(v.io_func, "project_columns"):
return True
return False


def _touch_all_data(*args, **kwargs):
"""Mock writing an ak.Array to disk by touching data buffers."""
for arg in args + tuple(kwargs.values()):
touch_data(arg)


def _mock_output(layer):
"""Update a layer to run the _touch_all_data."""
assert len(layer.dsk) == 1

new_layer = copy.deepcopy(layer)
mp = new_layer.dsk.copy()
for k in iter(mp.keys()):
mp[k] = (_touch_all_data,) + mp[k][1:]
new_layer.dsk = mp
return new_layer


def _touch_and_call_fn(fn, *args, **kwargs):
_touch_all_data(*args, **kwargs)
return fn(*args, **kwargs)


def _touch_and_call(layer):
assert len(layer.dsk) == 1

new_layer = copy.deepcopy(layer)
mp = new_layer.dsk.copy()
for k in iter(mp.keys()):
mp[k] = (_touch_and_call_fn,) + mp[k]
new_layer.dsk = mp
return new_layer


def _recursive_replace(args, layer, parent, indices):
args2 = []
for arg in args:
Expand Down Expand Up @@ -393,7 +422,7 @@ def _get_column_reports(dsk: HighLevelGraph) -> dict[str, Any]:
results = get_sync(hlg, leaf_layers_keys)
for out in results:
if isinstance(out, (ak.Array, ak.Record)):
ak.typetracer.touch_data(out)
touch_data(out)
except Exception as err:
on_fail = dask.config.get("awkward.optimization.on-fail")
# this is the default, throw a warning but skip the optimization.
Expand Down
49 changes: 49 additions & 0 deletions tests/test_optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import dask

import dask_awkward as dak
from dask_awkward.lib.testutils import assert_eq


def test_multiple_computes(pq_points_dir: str) -> None:
Expand All @@ -27,3 +28,51 @@ def test_multiple_computes(pq_points_dir: str) -> None:

things = dask.compute(ds1.points, ds2.points.x, ds2.points.y, ds1.points.y, ds3)
assert things[-1].tolist() == ak.Array(lists[0] + lists[1]).tolist() # type: ignore


def identity(x):
return x


def test_multiple_compute_incapsulated():
array = ak.Array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])[[0, 2]]
darray = dak.from_awkward(array, 1)
darray_result = darray.map_partitions(identity)

first, second = dask.compute(darray, darray_result)

assert ak.almost_equal(first, second)
assert first.layout.form == second.layout.form


def test_multiple_computes_multiple_incapsulated(daa, caa):
dstep1 = daa.points.x
dstep2 = dstep1**2
dstep3 = dstep2 + 2
dstep4 = dstep3 - 1
dstep5 = dstep4 - dstep2

cstep1 = caa.points.x
cstep2 = cstep1**2
cstep3 = cstep2 + 2
cstep4 = cstep3 - 1
cstep5 = cstep4 - cstep2

# multiple computes all work and evaluate to the expected result
c5, c4, c2 = dask.compute(dstep5, dstep4, dstep2)
assert_eq(c5, cstep5)
assert_eq(c2, cstep2)
assert_eq(c4, cstep4)

# if optimized together we still have 2 layers
opt4, opt3 = dask.optimize(dstep4, dstep3)
assert len(opt4.dask.layers) == 2
assert len(opt3.dask.layers) == 2
assert_eq(opt4, cstep4)
assert_eq(opt3, cstep3)

# if optimized alone we get optimized to 1 entire chain smushed
# down to 1 layer
(opt4_alone,) = dask.optimize(dstep4)
assert len(opt4_alone.dask.layers) == 1
assert_eq(opt4_alone, opt4)
Loading