Skip to content

Commit

Permalink
chore: support awkward v1 and v2 together (#226)
Browse files Browse the repository at this point in the history
* Upper cap awkward to v2

* Update `environment.yml`

* Support both `awkward` `v1` and `v2`

* refactor: run tests with awkward._v2

Signed-off-by: Henry Schreiner <henryschreineriii@gmail.com>

* Test `awkward._v2` backend in workflows

* Fix few tests

* patch sys.modules['awkward']._v2 + consistent import names

* Revert some changes

* fix: move __getitem__ to Protocol

Signed-off-by: Henry Schreiner <henryschreineriii@gmail.com>

* chore: pre-commit autoupdate

Signed-off-by: Henry Schreiner <henryschreineriii@gmail.com>

* ci: require Awkward RC for v2

* Update pickle file for `v2`

* Use quotes

* Mark failing tests with `xfail`

* Tidy up

* Sync with `awkward` `v1.9.0`

Signed-off-by: Henry Schreiner <henryschreineriii@gmail.com>
Co-authored-by: Henry Schreiner <henryschreineriii@gmail.com>
  • Loading branch information
Saransh-cpp and henryiii authored Sep 4, 2022
1 parent 709964c commit 1f09bb2
Show file tree
Hide file tree
Showing 11 changed files with 125 additions and 32 deletions.
8 changes: 7 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,15 @@ jobs:
- name: Install develop extras
run: python -m pip install -e .[dev]

- name: Test package
- name: Test package with awkward v1.x
run: python -m pytest -ra --cov=vector tests/

- name: Use awkward v1.9.0 for v2 support
run: python -m pip install -U awkward==1.9.0

- name: Test package with awkward._v2
run: VECTOR_USE_AWKWARDV2=1 python -m pytest -ra --cov=vector tests/

- name: Run doctests
run: xdoctest ./src/vector/

Expand Down
9 changes: 5 additions & 4 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@ dependencies:
- jupyterlab >=1.2
- nb_conda_kernels
- pip >=18
- pytest
- root >=6.18.04
- pytest >=6
- numba >=0.50
- numpy >=1.13.3
- root >=6.18.04
- pip:
- "awkward1>=0.2.29"
- "uproot==3.*"
- "awkward>=1.2.0"
- "uproot==4.*"
- "scikit-hep-testdata>=0.2.0"
- -e .
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ dev = [
"xdoctest>=1",
]
docs = [
"awkward",
"awkward>=1.2",
"ipykernel",
"myst-parser>0.13",
"nbsphinx",
Expand Down
19 changes: 17 additions & 2 deletions src/vector/_typeutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,24 @@
import numpy

if sys.version_info < (3, 8):
from typing_extensions import TypedDict
from typing_extensions import Protocol, TypedDict
else:
from typing import TypedDict
from typing import Protocol, TypedDict


__all__ = [
"Protocol",
"ScalarCollection",
"BoolCollection",
"TransformProtocol2D",
"TransformProtocol3D",
"TransformProtocol4D",
"FloatArray",
]


def __dir__() -> typing.List[str]:
return __all__


# Represents a number, a NumPy array, an Awkward Array, etc., of non-vectors.
Expand Down
44 changes: 29 additions & 15 deletions src/vector/backends/awkward.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
Vector4D,
VectorProtocol,
)
from vector._typeutils import BoolCollection, ScalarCollection
from vector._typeutils import BoolCollection, Protocol, ScalarCollection
from vector.backends.numpy import VectorNumpy2D, VectorNumpy3D, VectorNumpy4D
from vector.backends.object import (
AzimuthalObjectRhoPhi,
Expand Down Expand Up @@ -194,7 +194,7 @@ def from_fields(cls, array: ak.Array) -> "LongitudinalAwkward":
Examples:
>>> import vector
>>> import awkward as ak
>>> import awkward as ak
>>> a = ak.Array([{"theta": [1, 0]}])
>>> l = vector.backends.awkward.LongitudinalAwkward.from_fields(a)
>>> l
Expand Down Expand Up @@ -225,7 +225,7 @@ def from_momentum_fields(cls, array: ak.Array) -> "LongitudinalAwkward":
Examples:
>>> import vector
>>> import awkward as ak
>>> import awkward as ak
>>> a = ak.Array([{"theta": [1, 0]}])
>>> l = vector.backends.awkward.LongitudinalAwkward.from_momentum_fields(a)
>>> l
Expand Down Expand Up @@ -272,7 +272,7 @@ def from_fields(cls, array: ak.Array) -> "TemporalAwkward":
Examples:
>>> import vector
>>> import awkward as ak
>>> import awkward as ak
>>> a = ak.Array([{"tau": [1, 0]}])
>>> t = vector.backends.awkward.TemporalAwkward.from_fields(a)
>>> t
Expand Down Expand Up @@ -300,7 +300,7 @@ def from_momentum_fields(cls, array: ak.Array) -> "TemporalAwkward":
Examples:
>>> import vector
>>> import awkward as ak
>>> import awkward as ak
>>> a = ak.Array([{"mass": [1, 0]}])
>>> t = vector.backends.awkward.TemporalAwkward.from_momentum_fields(a)
>>> t
Expand Down Expand Up @@ -594,27 +594,31 @@ def _class_to_name(cls: typing.Type[VectorProtocol]) -> str:
# the vector class ############################################################


def _yes_record(x: ak.Array) -> typing.Optional[typing.Union[float, ak.Record]]:
def _yes_record(
x: ak.Array,
) -> typing.Optional[typing.Union[float, ak.Record]]:
return x[0]


def _no_record(x: ak.Array) -> typing.Optional[ak.Array]:
return x


# Type for mixing in Awkward later
class AwkwardProtocol(Protocol):
def __getitem__(
self, where: typing.Any
) -> typing.Optional[typing.Union[float, ak.Array, ak.Record]]:
...


class VectorAwkward:
"""One dimensional vector class for the Awkward backend."""

lib: types.ModuleType = numpy

def __getitem__(
self, where: typing.Any
) -> typing.Optional[typing.Union[float, ak.Array, ak.Record]]:
# "__getitem__" undefined in superclass
return super().__getitem__(where) # type: ignore[misc]

def _wrap_result(
self,
self: AwkwardProtocol,
cls: typing.Any,
result: typing.Any,
returns: typing.Any,
Expand Down Expand Up @@ -1598,9 +1602,19 @@ class MomentumRecord4D(MomentumAwkward4D, ak.Record): # type: ignore[misc]
def _arraytype_of(awkwardtype: typing.Any, component: str) -> typing.Any:
import numba

if isinstance(awkwardtype, ak._connect._numba.layout.NumpyArrayType):
if isinstance(
awkwardtype,
ak._connect.numba.layout.NumpyArrayType
if hasattr(ak._connect, "numba") # Awkward v2
else ak._connect._numba.layout.NumpyArrayType,
):
return awkwardtype.arraytype
elif isinstance(awkwardtype, ak._connect._numba.layout.IndexedArrayType):
elif isinstance(
awkwardtype,
ak._connect.numba.layout.IndexedArrayType
if hasattr(ak._connect, "numba") # Awkward v2
else ak._connect._numba.layout.IndexedArrayType,
):
return _arraytype_of(awkwardtype.contenttype, component)
else:
raise numba.TypingError(
Expand Down
30 changes: 25 additions & 5 deletions src/vector/backends/awkward_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,17 +214,37 @@ def _is_type_safe(array_type: typing.Any) -> bool:
awkward.types.OptionType,
),
):
array_type = array_type.type
# .content is Awkward v2
array_type = (
array_type.content if hasattr(array_type, "content") else array_type.type
)

if not isinstance(array_type, awkward.types.RecordType):
return False

for field_type in array_type.fields():
# .contents is Awkward v2
contents = (
array_type.contents if hasattr(array_type, "contents") else array_type.fields()
)
for field_type in contents:
if isinstance(field_type, awkward.types.OptionType):
field_type = field_type.type
if not isinstance(field_type, awkward.types.PrimitiveType):
field_type = (
field_type.content
if hasattr(array_type, "content")
else field_type.type
)
if not isinstance(
field_type,
awkward.types.NumpyType
if hasattr(awkward.types, "NumpyType")
else awkward.types.PrimitiveType,
):
return False
dt = field_type.dtype
dt = (
field_type.primitive
if hasattr(field_type, "primitive")
else field_type.dtype
)
if (
not dt.startswith("int")
and not dt.startswith("uint")
Expand Down
10 changes: 9 additions & 1 deletion tests/backends/test_awkward.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
# Distributed under the 3-clause BSD license, see accompanying file LICENSE
# or https://github.com/scikit-hep/vector for details.

import os

import pytest

import vector

ak = pytest.importorskip("awkward")
pytest.importorskip("awkward")

pytestmark = pytest.mark.awkward

Expand Down Expand Up @@ -75,6 +77,12 @@ def test_rotateZ():
assert out.wow.tolist() == [[99], [], [123]]


# awkward._v2 has not yet registered NumPy dispatch mechanisms
# see https://github.com/scikit-hep/awkward/issues/1638
# TODO: ensure this passes once awkward v2 is out
@pytest.mark.xfail(
strict=True if os.environ.get("VECTOR_USE_AWKWARDV2") is not None else False
)
def test_projection():
array = vector.Array(
[
Expand Down
8 changes: 8 additions & 0 deletions tests/backends/test_awkward_numba.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
# Distributed under the 3-clause BSD license, see accompanying file LICENSE
# or https://github.com/scikit-hep/vector for details.

import os

import pytest

import vector
Expand All @@ -15,6 +17,12 @@
pytestmark = [pytest.mark.numba, pytest.mark.awkward]


# awkward._v2 has not yet registered Numba dispatch mechanisms
# see https://github.com/scikit-hep/awkward/discussions/1639
# TODO: ensure this passes once awkward v2 is out
@pytest.mark.xfail(
strict=True if os.environ.get("VECTOR_USE_AWKWARDV2") is not None else False
)
def test():
@numba.njit
def extract(x):
Expand Down
8 changes: 8 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import os
import sys

if os.environ.get("VECTOR_USE_AWKWARDV2", None):
import awkward._v2

sys.modules["awkward"] = awkward._v2
sys.modules["awkward"]._v2 = awkward._v2
Binary file added tests/samples/issue-161-v2.pkl
Binary file not shown.
19 changes: 16 additions & 3 deletions tests/test_issues.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@ def test_issue_99():
}


# awkward._v2 has not yet registered Numba dispatch mechanisms
# see https://github.com/scikit-hep/awkward/discussions/1639
# TODO: ensure this passes once awkward v2 is out
@pytest.mark.xfail(
strict=True if os.environ.get("VECTOR_USE_AWKWARDV2") is not None else False
)
def test_issue_161():
ak = pytest.importorskip("awkward")
nb = pytest.importorskip("numba")
Expand All @@ -36,6 +42,13 @@ def repro(generator_like_jet_constituents):
for generator_like_constituent in sublist:
s += generator_like_constituent.pt

with open(os.path.join("tests", "samples", "issue-161.pkl"), "rb") as f:
a = ak.from_buffers(*pickle.load(f))
repro(generator_like_jet_constituents=a.constituents)
file_path = (
os.path.join("tests", "samples", "issue-161.pkl")
if os.getenv("VECTOR_USE_AWKWARDV2") is None
else os.path.join("tests", "samples", "issue-161-v2.pkl")
)

f = open(file_path, "rb")
a = ak.from_buffers(*pickle.load(f))
f.close()
repro(generator_like_jet_constituents=a.constituents)

0 comments on commit 1f09bb2

Please sign in to comment.