Skip to content

Commit

Permalink
fix: momentum coords should not be repeated with generic coords in su…
Browse files Browse the repository at this point in the history
…bclasses (#438)

* fix: momentum coords should not be repeated with generic coords for subclasses

* Added extra coords by mistake

* add tests

* Update noxfile.py
  • Loading branch information
Saransh-cpp authored Mar 15, 2024
1 parent b907615 commit 8de0dfe
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 2 deletions.
44 changes: 42 additions & 2 deletions src/vector/backends/awkward.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,7 +673,13 @@ def _wrap_result(
fields = ak.fields(self)
if num_vecargs == 1:
for name in fields:
if name not in ("x", "y", "rho", "phi"):
if name not in (
"x",
"y",
"rho",
"pt",
"phi",
):
names.append(name)
arrays.append(self[name])

Expand Down Expand Up @@ -720,12 +726,20 @@ def _wrap_result(
"x",
"y",
"rho",
"pt",
"phi",
"z",
"pz",
"theta",
"eta",
"t",
"tau",
"m",
"M",
"mass",
"e",
"E",
"energy",
):
names.append(name)
arrays.append(self[name])
Expand Down Expand Up @@ -774,7 +788,17 @@ def _wrap_result(
fields = ak.fields(self)
if num_vecargs == 1:
for name in fields:
if name not in ("x", "y", "rho", "phi", "z", "theta", "eta"):
if name not in (
"x",
"y",
"rho",
"pt",
"phi",
"z",
"pz",
"theta",
"eta",
):
names.append(name)
arrays.append(self[name])

Expand Down Expand Up @@ -831,12 +855,20 @@ def _wrap_result(
"x",
"y",
"rho",
"pt",
"phi",
"z",
"pz",
"theta",
"eta",
"t",
"tau",
"m",
"M",
"mass",
"e",
"E",
"energy",
):
names.append(name)
arrays.append(self[name])
Expand Down Expand Up @@ -897,12 +929,20 @@ def _wrap_result(
"x",
"y",
"rho",
"pt",
"phi",
"z",
"pz",
"theta",
"eta",
"t",
"tau",
"m",
"M",
"mass",
"e",
"E",
"energy",
):
names.append(name)
arrays.append(self[name])
Expand Down
38 changes: 38 additions & 0 deletions tests/backends/test_awkward.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
from __future__ import annotations

import importlib.metadata
import numbers

import numpy as np
import packaging.version
import pytest

Expand Down Expand Up @@ -895,3 +897,39 @@ def test_momentum_preservation():
# 4D + 3D.like(4D) = 4D
assert isinstance(v3 + v2.like(v3), MomentumAwkward4D)
assert isinstance(v2.like(v3) + v3, MomentumAwkward4D)


def test_subclass_fields():
@ak.mixin_class(vector.backends.awkward.behavior)
class TwoVector(MomentumAwkward2D):
pass

@ak.mixin_class(vector.backends.awkward.behavior)
class ThreeVector(MomentumAwkward3D):
pass

@ak.mixin_class(vector.backends.awkward.behavior)
class LorentzVector(MomentumAwkward4D):
@ak.mixin_class_method(np.divide, {numbers.Number})
def divide(self, factor):
return self.scale(1 / factor)

LorentzVectorArray.ProjectionClass2D = TwoVectorArray # noqa: F821
LorentzVectorArray.ProjectionClass3D = ThreeVectorArray # noqa: F821
LorentzVectorArray.ProjectionClass4D = LorentzVectorArray # noqa: F821
LorentzVectorArray.MomentumClass = LorentzVectorArray # noqa: F821

vec = ak.zip(
{
"pt": [[1, 2], [], [3], [4]],
"eta": [[1.2, 1.4], [], [1.6], [3.4]],
"phi": [[0.3, 0.4], [], [0.5], [0.6]],
"energy": [[50, 51], [], [52], [60]],
},
with_name="LorentzVector",
behavior=vector.backends.awkward.behavior,
)

assert vec.like(vector.obj(x=1, y=2)).fields == ["rho", "phi"]
assert vec.like(vector.obj(x=1, y=2, z=3)).fields == ["rho", "phi", "eta"]
assert (vec / 2).fields == ["rho", "phi", "eta", "t"]

0 comments on commit 8de0dfe

Please sign in to comment.