Skip to content

Commit

Permalink
fix awkward tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Saransh-cpp committed Mar 15, 2024
1 parent ec24efd commit 619442a
Showing 1 changed file with 15 additions and 20 deletions.
35 changes: 15 additions & 20 deletions tests/backends/test_awkward.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,6 @@

import vector
from vector import VectorObject2D
from vector.backends.awkward import (
MomentumAwkward2D,
MomentumAwkward3D,
MomentumAwkward4D,
)

ak = pytest.importorskip("awkward")

Expand Down Expand Up @@ -880,36 +875,36 @@ def test_momentum_preservation():

# momentum + generic = momentum
# 2D + 3D.like(2D) = 2D
assert isinstance(v1 + v2.like(v1), MomentumAwkward2D)
assert isinstance(v2.like(v1) + v1, MomentumAwkward2D)
assert isinstance(v1 + v2.like(v1), vector.backends.awkward.MomentumAwkward2D)
assert isinstance(v2.like(v1) + v1, vector.backends.awkward.MomentumAwkward2D)
# 2D + 4D.like(2D) = 2D
assert isinstance(v1 + v3.like(v1), MomentumAwkward2D)
assert isinstance(v3.like(v1) + v1, MomentumAwkward2D)
assert isinstance(v1 + v3.like(v1), vector.backends.awkward.MomentumAwkward2D)
assert isinstance(v3.like(v1) + v1, vector.backends.awkward.MomentumAwkward2D)
# 3D + 2D.like(3D) = 3D
assert isinstance(v2 + v1.like(v2), MomentumAwkward3D)
assert isinstance(v1.like(v2) + v2, MomentumAwkward3D)
assert isinstance(v2 + v1.like(v2), vector.backends.awkward.MomentumAwkward3D)
assert isinstance(v1.like(v2) + v2, vector.backends.awkward.MomentumAwkward3D)
# 3D + 4D.like(3D) = 3D
assert isinstance(v2 + v3.like(v2), MomentumAwkward3D)
assert isinstance(v3.like(v2) + v2, MomentumAwkward3D)
assert isinstance(v2 + v3.like(v2), vector.backends.awkward.MomentumAwkward3D)
assert isinstance(v3.like(v2) + v2, vector.backends.awkward.MomentumAwkward3D)
# 4D + 2D.like(4D) = 4D
assert isinstance(v3 + v1.like(v3), MomentumAwkward4D)
assert isinstance(v1.like(v3) + v3, MomentumAwkward4D)
assert isinstance(v3 + v1.like(v3), vector.backends.awkward.MomentumAwkward4D)
assert isinstance(v1.like(v3) + v3, vector.backends.awkward.MomentumAwkward4D)
# 4D + 3D.like(4D) = 4D
assert isinstance(v3 + v2.like(v3), MomentumAwkward4D)
assert isinstance(v2.like(v3) + v3, MomentumAwkward4D)
assert isinstance(v3 + v2.like(v3), vector.backends.awkward.MomentumAwkward4D)
assert isinstance(v2.like(v3) + v3, vector.backends.awkward.MomentumAwkward4D)


def test_subclass_fields():
@ak.mixin_class(vector.backends.awkward.behavior)
class TwoVector(MomentumAwkward2D):
class TwoVector(vector.backends.awkward.MomentumAwkward2D):
pass

@ak.mixin_class(vector.backends.awkward.behavior)
class ThreeVector(MomentumAwkward3D):
class ThreeVector(vector.backends.awkward.MomentumAwkward3D):
pass

@ak.mixin_class(vector.backends.awkward.behavior)
class LorentzVector(MomentumAwkward4D):
class LorentzVector(vector.backends.awkward.MomentumAwkward4D):
@ak.mixin_class_method(np.divide, {numbers.Number})
def divide(self, factor):
return self.scale(1 / factor)
Expand Down

0 comments on commit 619442a

Please sign in to comment.