Skip to content

Commit

Permalink
refactor: use public typetracer API (#894)
Browse files Browse the repository at this point in the history
* refactor: use public typetracer API

* refactor: use newer API

* chore: bump dask-awkward version

* refactor: use `ak.almost_equal` instead of dask-awkward's testutils

---------

Co-authored-by: Jim Pivarski <jpivarski@users.noreply.github.com>
  • Loading branch information
agoose77 and jpivarski authored Sep 14, 2023
1 parent 39dff81 commit ac23179
Show file tree
Hide file tree
Showing 6 changed files with 26 additions and 27 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ dynamic = [
[project.optional-dependencies]
dev = [
"boost_histogram>=0.13",
"dask-awkward>=2022.12a3;python_version >= \"3.8\"",
"dask-awkward>=2023.9.0;python_version >= \"3.8\"",
"dask[array];python_version >= \"3.8\"",
"hist>=1.2",
"pandas",
Expand Down
16 changes: 6 additions & 10 deletions src/uproot/_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -806,11 +806,9 @@ def project_columns(self, common_keys=None, original_form=None):
if self.form_mapping is not None:
awkward = uproot.extras.awkward()

(
new_meta_labelled,
report,
) = awkward._nplikes.typetracer.typetracer_with_report(self.rendered_form)
tt = awkward.Array(new_meta_labelled)
tt, report = awkward.typetracer.typetracer_with_report(
self.rendered_form, highlevel=True
)

if common_keys is not None:
for key in common_keys:
Expand Down Expand Up @@ -950,11 +948,9 @@ def project_columns(self, columns=None, original_form=None):
if self.form_mapping is not None:
awkward = uproot.extras.awkward()

(
new_meta_labelled,
report,
) = awkward._nplikes.typetracer.typetracer_with_report(self.rendered_form)
tt = awkward.Array(new_meta_labelled)
tt, report = awkward.typetracer.typetracer_with_report(
self.rendered_form, highlevel=True
)

if columns is not None:
for key in columns:
Expand Down
8 changes: 7 additions & 1 deletion src/uproot/extras.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,8 +325,14 @@ def dask_awkward():
or
conda install -c conda-forge dask dask-awkward"""
) from err
else:
if parse_version("2023.9.0") <= parse_version(dask_awkward.__version__):
return dask_awkward
else:
raise ModuleNotFoundError(
"Uproot 5.x can only be used with dask-awkward 2023.9.0 or newer; you have dask-awkward {}".format(
dask_awkward.__version__
)
)


def awkward_pandas():
Expand Down
9 changes: 3 additions & 6 deletions tests/test_0700-dask-empty-arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,12 @@
import numpy
import pytest
import skhep_testdata

import awkward as ak
import uproot

dask = pytest.importorskip("dask")
dask_awkward = pytest.importorskip("dask_awkward")

pytest.importorskip("pyarrow") # dask_awkward.lib.testutils needs pyarrow
from dask_awkward.lib.testutils import assert_eq


def test_dask_numpy_empty_arrays():
test_path = skhep_testdata.data_path("uproot-issue-697.root") + ":tree"
Expand Down Expand Up @@ -52,7 +49,7 @@ def test_dask_awkward_empty_arrays():
ak_array = ttree.arrays()
dak_array = uproot.dask(test_path, library="ak")

assert_eq(dak_array, ak_array)
assert ak.almost_equal(dak_array.compute(scheduler="synchronous"), ak_array)


def test_dask_delayed_open_awkward():
Expand All @@ -62,7 +59,7 @@ def test_dask_delayed_open_awkward():
ak_array = ttree.arrays()
dak_array = uproot.dask(test_path, library="ak", open_files=False)

assert_eq(dak_array, ak_array)
assert ak.almost_equal(dak_array.compute(scheduler="synchronous"), ak_array)


def test_no_common_tree_branches():
Expand Down
9 changes: 4 additions & 5 deletions tests/test_0755-dask-awkward-column-projection.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/uproot5/blob/main/LICENSE

import numpy
import awkward as ak
import pytest
import skhep_testdata

Expand All @@ -9,9 +10,6 @@
dask = pytest.importorskip("dask")
dask_awkward = pytest.importorskip("dask_awkward")

pytest.importorskip("pyarrow") # dask_awkward.lib.testutils needs pyarrow
from dask_awkward.lib.testutils import assert_eq


def test_column_projection_sanity_check():
test_path = skhep_testdata.data_path("uproot-Zmumu.root") + ":events"
Expand All @@ -20,6 +18,7 @@ def test_column_projection_sanity_check():
ak_array = ttree.arrays()
dak_array = uproot.dask(test_path, library="ak")

assert_eq(
dak_array[["px1", "px2", "py1", "py2"]], ak_array[["px1", "px2", "py1", "py2"]]
assert ak.almost_equal(
dak_array[["px1", "px2", "py1", "py2"]].compute(scheduler="synchronous"),
ak_array[["px1", "px2", "py1", "py2"]],
)
9 changes: 5 additions & 4 deletions tests/test_0876-uproot-dask-blind-steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,11 @@
import numpy
import pytest
import skhep_testdata
import awkward as ak

import uproot

dask_awkward = pytest.importorskip("dask_awkward")
pytest.importorskip("pyarrow") # dask_awkward.lib.testutils needs pyarrow
from dask_awkward.lib.testutils import assert_eq


@pytest.mark.parametrize("library", ["np", "ak"])
Expand Down Expand Up @@ -60,7 +59,9 @@ def test_uproot_dask_steps(library, step_size, steps_per_file, open_files):
assert all(comp), f"Incorrect array at key {key}"

else:
assert_eq(
dask_arrays[["px1", "px2", "py1", "py2"]],
assert ak.almost_equal(
dask_arrays[["px1", "px2", "py1", "py2"]].compute(
scheduler="synchronous"
),
arrays[["px1", "px2", "py1", "py2"]],
)

0 comments on commit ac23179

Please sign in to comment.