Skip to content

Commit

Permalink
add test
Browse files Browse the repository at this point in the history
  • Loading branch information
douglasdavis committed Dec 7, 2023
1 parent 2fc76d9 commit 9153f0b
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/dask_awkward/lib/io/parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def report_failure(exception, *args, **kwargs):
return ak.Array(
[
{
"duration": None,
"columns": [],
"args": [repr(a) for a in args],
"kwargs": [[k, repr(v)] for k, v in kwargs.items()],
"exception": type(exception).__name__,
Expand All @@ -47,11 +47,11 @@ def report_failure(exception, *args, **kwargs):
)


def report_success(duration, *args, **kwargs):
def report_success(columns, *args, **kwargs):
return ak.Array(
[
{
"duration": duration,
"columns": columns,
"args": [repr(a) for a in args],
"kwargs": [[k, repr(v)] for k, v in kwargs.items()],
"exception": None,
Expand Down Expand Up @@ -172,7 +172,7 @@ def __call__(self, *args, **kwargs):
if self.return_report:
try:
result = self.read_fn(source)
return result, report_success(source)
return result, report_success(self.columns, source)
except self.allowed_exceptions as err:
return self.mock_empty(), report_failure(err, source)

Expand Down
13 changes: 13 additions & 0 deletions tests/test_parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
pytest.importorskip("aiohttp")

import awkward as ak
import dask
import fsspec
import pyarrow as pa
import pyarrow.dataset as pad
Expand Down Expand Up @@ -201,3 +202,15 @@ def test_to_parquet_with_prefix(
assert fname.startswith(f"{prefix}")
else:
assert fname.startswith("part")


def test_from_parquet_with_report(
daa: dak.Array,
tmp_path_factory: pytest.TempPathFactory,
) -> None:
p = str(tmp_path_factory.mktemp("from_pq_with_report"))
dak.to_parquet(daa, p)
ds, report = dak.from_parquet(p, report=True)
c_ds, c_report = dask.compute(dak.max(ds.points.x, axis=1), report)
assert len(c_ds)
assert c_report.columns.tolist()[0] == ["points.list.item.x"]

0 comments on commit 9153f0b

Please sign in to comment.