Skip to content

Commit

Permalink
feat: add 'dak.backend' and verify that it overloads 'ak.backend'
Browse files Browse the repository at this point in the history
  • Loading branch information
jpivarski committed Apr 11, 2024
1 parent b49b967 commit b96e9f0
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 7 deletions.
2 changes: 1 addition & 1 deletion src/dask_awkward/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
map_partitions,
partition_compatibility,
)
from dask_awkward.lib.describe import fields
from dask_awkward.lib.describe import backend, fields
from dask_awkward.lib.inspect import (
report_necessary_buffers,
report_necessary_columns,
Expand Down
2 changes: 1 addition & 1 deletion src/dask_awkward/lib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
map_partitions,
partition_compatibility,
)
from dask_awkward.lib.describe import fields
from dask_awkward.lib.describe import backend, fields
from dask_awkward.lib.inspect import (
report_necessary_buffers,
report_necessary_columns,
Expand Down
23 changes: 22 additions & 1 deletion src/dask_awkward/lib/describe.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

from dask_awkward.lib.core import Array, Record
import awkward as ak

from dask_awkward.lib.core import Array, Record, Scalar


def fields(collection: Array | Record) -> list[str] | None:
Expand All @@ -19,3 +21,22 @@ def fields(collection: Array | Record) -> list[str] | None:
"""
return collection.fields


def backend(*arrays: Array | Record) -> str:
"""Get the name of the backend used by `arrays`.
Parameters
----------
arrays : dask_awkward.Array or dask_awkward.Record
Array or Record collection
Returns
-------
str
The backend name, which is always `"typetracer"` for
dask-awkward arrays.
"""
return ak.backend(
*[x._meta if isinstance(x, (Array, Record, Scalar)) else x for x in arrays]
)
18 changes: 14 additions & 4 deletions tests/test_describe.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,24 @@
from typing import Any

import awkward as ak
import pytest

import dask_awkward as dak


def test_fields(ndjson_points_file: str) -> None:
@pytest.mark.parametrize("quak", [ak, dak])
def test_fields(ndjson_points_file: str, quak: Any) -> None:
daa = dak.from_json([ndjson_points_file] * 2)
# records fields same as array of records fields
assert dak.fields(daa[0].points) == dak.fields(daa.points)
assert quak.fields(daa[0].points) == quak.fields(daa.points)
# computed is same as collection
assert dak.fields(daa) == ak.fields(daa.compute())
assert quak.fields(daa) == ak.fields(daa.compute())
daa.reset_meta()
# removed meta gives None fields
assert dak.fields(daa) == []
assert quak.fields(daa) == []


@pytest.mark.parametrize("quak", [ak, dak])
def test_backend(ndjson_points_file: str, quak: Any) -> None:
daa = dak.from_json([ndjson_points_file] * 2)
assert quak.backend(daa) == "typetracer"
1 change: 1 addition & 0 deletions tests/test_inspect.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def test_visualize_works(daa):
dask.compute(query, optimize_graph=True)


@pytest.mark.xfail()
def test_basic_root_works():
pytest.importorskip("hist")
pytest.importorskip("uproot")
Expand Down

0 comments on commit b96e9f0

Please sign in to comment.