Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
douglasdavis committed Sep 25, 2023
1 parent ffa20fb commit 2ac41fc
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 0 deletions.
5 changes: 5 additions & 0 deletions src/dask_awkward/lib/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,9 @@ def rewrite_layer_chains(dsk: HighLevelGraph, keys: Any) -> HighLevelGraph:
chains = []
deps = copy.copy(dsk.dependencies)

# TODO: add some comments to the chaining algorithm w.r.t. when we
# use it and when we don't.
required_layers = {k[0] for k in keys}
layers = {}
# find chains; each chain list is at least two keys long
dependents = dsk.dependents
Expand All @@ -250,6 +253,7 @@ def rewrite_layer_chains(dsk: HighLevelGraph, keys: Any) -> HighLevelGraph:
and dsk.dependencies[list(children)[0]] == {lay}
and isinstance(dsk.layers[list(children)[0]], AwkwardBlockwiseLayer)
and len(dsk.layers[lay]) == len(dsk.layers[list(children)[0]])
and lay not in required_layers
):
# walk forwards
lay = list(children)[0]
Expand All @@ -263,6 +267,7 @@ def rewrite_layer_chains(dsk: HighLevelGraph, keys: Any) -> HighLevelGraph:
and dependents[list(parents)[0]] == {lay}
and isinstance(dsk.layers[list(parents)[0]], AwkwardBlockwiseLayer)
and len(dsk.layers[lay]) == len(dsk.layers[list(parents)[0]])
and list(parents)[0] not in required_layers
):
# walk backwards
lay = list(parents)[0]
Expand Down
49 changes: 49 additions & 0 deletions tests/test_optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import dask

import dask_awkward as dak
from dask_awkward.lib.testutils import assert_eq


def test_multiple_computes(pq_points_dir: str) -> None:
Expand All @@ -27,3 +28,51 @@ def test_multiple_computes(pq_points_dir: str) -> None:

things = dask.compute(ds1.points, ds2.points.x, ds2.points.y, ds1.points.y, ds3)
assert things[-1].tolist() == ak.Array(lists[0] + lists[1]).tolist() # type: ignore


def identity(x):
    """Return ``x`` unchanged.

    A no-op callable; in this file it is passed to ``map_partitions`` so
    that a new layer is added to the graph without altering the data.
    """
    return x


def test_multiple_compute_incapsulated() -> None:
    """Compute a collection together with a collection derived from it.

    ``darray_result`` is ``darray`` passed through ``identity`` via
    ``map_partitions``. Both are requested in a single ``dask.compute``
    call, and the two results must be equal in value and in layout form.
    """
    # Index with [[0, 2]] so the source array is a selection of the
    # original rows rather than the plain literal.
    array = ak.Array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])[[0, 2]]
    darray = dak.from_awkward(array, 1)
    darray_result = darray.map_partitions(identity)

    # Request the input layer and its direct dependent at the same time.
    first, second = dask.compute(darray, darray_result)

    assert ak.almost_equal(first, second)
    assert first.layout.form == second.layout.form


def test_multiple_computes_multiple_incapsulated(daa, caa) -> None:
    """Check results and layer counts when several chained steps are requested.

    The same five-step arithmetic chain is built on ``daa`` and on ``caa``
    (presumably the lazy dask-awkward fixture and its concrete awkward
    counterpart -- confirm against the fixture definitions). The test then
    verifies that:

    * computing several steps of the chain in one ``dask.compute`` call
      gives results matching the ``caa``-based steps;
    * optimizing two adjacent steps together leaves 2 layers in each
      optimized graph;
    * optimizing a single step alone collapses its graph to 1 layer.
    """
    dstep1 = daa.points.x
    dstep2 = dstep1**2
    dstep3 = dstep2 + 2
    dstep4 = dstep3 - 1
    dstep5 = dstep4 - dstep2

    # Mirror of the chain above, built from ``caa`` for comparison.
    cstep1 = caa.points.x
    cstep2 = cstep1**2
    cstep3 = cstep2 + 2
    cstep4 = cstep3 - 1
    cstep5 = cstep4 - cstep2

    # multiple computes all work and evaluate to the expected result
    c5, c4, c2 = dask.compute(dstep5, dstep4, dstep2)
    assert_eq(c5, cstep5)
    assert_eq(c2, cstep2)
    assert_eq(c4, cstep4)

    # if optimized together we still have 2 layers in each graph
    opt4, opt3 = dask.optimize(dstep4, dstep3)
    assert len(opt4.dask.layers) == 2
    assert len(opt3.dask.layers) == 2
    assert_eq(opt4, cstep4)
    assert_eq(opt3, cstep3)

    # if optimized alone, the entire chain is fused down to 1 layer
    (opt4_alone,) = dask.optimize(dstep4)
    assert len(opt4_alone.dask.layers) == 1
    assert_eq(opt4_alone, opt4)

0 comments on commit 2ac41fc

Please sign in to comment.