Skip to content

Commit

Permalink
Add a default dtype for coord classes
Browse files Browse the repository at this point in the history
  • Loading branch information
Saransh-cpp committed May 16, 2024
1 parent 5026638 commit 3c55c39
Showing 1 changed file with 14 additions and 14 deletions.
28 changes: 14 additions & 14 deletions src/vector/backends/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,8 @@ class AzimuthalNumpyXY(AzimuthalNumpy, AzimuthalXY, GetItem, FloatArray): # typ
dtype: numpy.dtype[typing.Any] = numpy.dtype([("x", float), ("y", float)])

def __new__(cls, *args: typing.Any, **kwargs: typing.Any) -> AzimuthalNumpyXY:
if "dtype" in kwargs:
AzimuthalNumpyXY.dtype = numpy.dtype(kwargs["dtype"])
return numpy.array(*args, **kwargs).view(cls)

def __array_finalize__(self, obj: typing.Any) -> None:
Expand All @@ -458,13 +460,11 @@ def __array_finalize__(self, obj: typing.Any) -> None:
def __eq__(self, other: typing.Any) -> bool:
if self.dtype != other.dtype or not isinstance(other, AzimuthalNumpyXY):
return False

Check warning on line 462 in src/vector/backends/numpy.py

View check run for this annotation

Codecov / codecov/patch

src/vector/backends/numpy.py#L462

Added line #L462 was not covered by tests

return all(coord1 == coord2 for coord1, coord2 in zip(self, other))

def __ne__(self, other: typing.Any) -> bool:
if self.dtype != other.dtype or not isinstance(other, AzimuthalNumpyXY):
return True

return any(coord1 != coord2 for coord1, coord2 in zip(self, other))

Check warning on line 468 in src/vector/backends/numpy.py

View check run for this annotation

Codecov / codecov/patch

src/vector/backends/numpy.py#L466-L468

Added lines #L466 - L468 were not covered by tests

@property
Expand Down Expand Up @@ -510,6 +510,8 @@ class AzimuthalNumpyRhoPhi(AzimuthalNumpy, AzimuthalRhoPhi, GetItem, FloatArray)
dtype: numpy.dtype[typing.Any] = numpy.dtype([("rho", float), ("phi", float)])

def __new__(cls, *args: typing.Any, **kwargs: typing.Any) -> AzimuthalNumpyRhoPhi:
if "dtype" in kwargs:
AzimuthalNumpyRhoPhi.dtype = numpy.dtype(kwargs["dtype"])

Check warning on line 514 in src/vector/backends/numpy.py

View check run for this annotation

Codecov / codecov/patch

src/vector/backends/numpy.py#L513-L514

Added lines #L513 - L514 were not covered by tests
return numpy.array(*args, **kwargs).view(cls)

def __array_finalize__(self, obj: typing.Any) -> None:
Expand All @@ -522,13 +524,11 @@ def __array_finalize__(self, obj: typing.Any) -> None:
def __eq__(self, other: typing.Any) -> bool:
if self.dtype != other.dtype or not isinstance(other, AzimuthalNumpyRhoPhi):
return False

return all(coord1 == coord2 for coord1, coord2 in zip(self, other))

Check warning on line 527 in src/vector/backends/numpy.py

View check run for this annotation

Codecov / codecov/patch

src/vector/backends/numpy.py#L525-L527

Added lines #L525 - L527 were not covered by tests

def __ne__(self, other: typing.Any) -> bool:
if self.dtype != other.dtype or not isinstance(other, AzimuthalNumpyRhoPhi):
return True

return any(coord1 != coord2 for coord1, coord2 in zip(self, other))

Check warning on line 532 in src/vector/backends/numpy.py

View check run for this annotation

Codecov / codecov/patch

src/vector/backends/numpy.py#L530-L532

Added lines #L530 - L532 were not covered by tests

@property
Expand Down Expand Up @@ -573,6 +573,8 @@ class LongitudinalNumpyZ(LongitudinalNumpy, LongitudinalZ, GetItem, FloatArray):
dtype: numpy.dtype[typing.Any] = numpy.dtype([("z", float)])

def __new__(cls, *args: typing.Any, **kwargs: typing.Any) -> LongitudinalNumpyZ:
if "dtype" in kwargs:
LongitudinalNumpyZ.dtype = numpy.dtype(kwargs["dtype"])
return numpy.array(*args, **kwargs).view(cls)

def __array_finalize__(self, obj: typing.Any) -> None:
Expand All @@ -585,13 +587,11 @@ def __array_finalize__(self, obj: typing.Any) -> None:
def __eq__(self, other: typing.Any) -> bool:
if self.dtype != other.dtype or not isinstance(other, LongitudinalNumpyZ):
return False

Check warning on line 589 in src/vector/backends/numpy.py

View check run for this annotation

Codecov / codecov/patch

src/vector/backends/numpy.py#L589

Added line #L589 was not covered by tests

return all(coord1 == coord2 for coord1, coord2 in zip(self, other))

def __ne__(self, other: typing.Any) -> bool:
if self.dtype != other.dtype or not isinstance(other, LongitudinalNumpyZ):
return True

return any(coord1 != coord2 for coord1, coord2 in zip(self, other))

Check warning on line 595 in src/vector/backends/numpy.py

View check run for this annotation

Codecov / codecov/patch

src/vector/backends/numpy.py#L593-L595

Added lines #L593 - L595 were not covered by tests

@property
Expand Down Expand Up @@ -631,6 +631,8 @@ class LongitudinalNumpyTheta(LongitudinalNumpy, LongitudinalTheta, GetItem, Floa
dtype: numpy.dtype[typing.Any] = numpy.dtype([("theta", float)])

def __new__(cls, *args: typing.Any, **kwargs: typing.Any) -> LongitudinalNumpyTheta:
if "dtype" in kwargs:
LongitudinalNumpyTheta.dtype = numpy.dtype(kwargs["dtype"])

Check warning on line 635 in src/vector/backends/numpy.py

View check run for this annotation

Codecov / codecov/patch

src/vector/backends/numpy.py#L634-L635

Added lines #L634 - L635 were not covered by tests
return numpy.array(*args, **kwargs).view(cls)

def __array_finalize__(self, obj: typing.Any) -> None:
Expand All @@ -643,13 +645,11 @@ def __array_finalize__(self, obj: typing.Any) -> None:
def __eq__(self, other: typing.Any) -> bool:
if self.dtype != other.dtype or not isinstance(other, LongitudinalNumpyTheta):
return False

return all(coord1 == coord2 for coord1, coord2 in zip(self, other))

Check warning on line 648 in src/vector/backends/numpy.py

View check run for this annotation

Codecov / codecov/patch

src/vector/backends/numpy.py#L646-L648

Added lines #L646 - L648 were not covered by tests

def __ne__(self, other: typing.Any) -> bool:
if self.dtype != other.dtype or not isinstance(other, LongitudinalNumpyTheta):
return True

return any(coord1 != coord2 for coord1, coord2 in zip(self, other))

Check warning on line 653 in src/vector/backends/numpy.py

View check run for this annotation

Codecov / codecov/patch

src/vector/backends/numpy.py#L651-L653

Added lines #L651 - L653 were not covered by tests

@property
Expand Down Expand Up @@ -689,6 +689,8 @@ class LongitudinalNumpyEta(LongitudinalNumpy, LongitudinalEta, GetItem, FloatArr
dtype: numpy.dtype[typing.Any] = numpy.dtype([("eta", float)])

def __new__(cls, *args: typing.Any, **kwargs: typing.Any) -> LongitudinalNumpyEta:
if "dtype" in kwargs:
LongitudinalNumpyEta.dtype = numpy.dtype(kwargs["dtype"])

Check warning on line 693 in src/vector/backends/numpy.py

View check run for this annotation

Codecov / codecov/patch

src/vector/backends/numpy.py#L692-L693

Added lines #L692 - L693 were not covered by tests
return numpy.array(*args, **kwargs).view(cls)

def __array_finalize__(self, obj: typing.Any) -> None:
Expand All @@ -701,13 +703,11 @@ def __array_finalize__(self, obj: typing.Any) -> None:
def __eq__(self, other: typing.Any) -> bool:
if self.dtype != other.dtype or not isinstance(other, LongitudinalNumpyEta):
return False

return all(coord1 == coord2 for coord1, coord2 in zip(self, other))

Check warning on line 706 in src/vector/backends/numpy.py

View check run for this annotation

Codecov / codecov/patch

src/vector/backends/numpy.py#L704-L706

Added lines #L704 - L706 were not covered by tests

def __ne__(self, other: typing.Any) -> bool:
if self.dtype != other.dtype or not isinstance(other, LongitudinalNumpyEta):
return True

return any(coord1 != coord2 for coord1, coord2 in zip(self, other))

Check warning on line 711 in src/vector/backends/numpy.py

View check run for this annotation

Codecov / codecov/patch

src/vector/backends/numpy.py#L709-L711

Added lines #L709 - L711 were not covered by tests

@property
Expand Down Expand Up @@ -747,6 +747,8 @@ class TemporalNumpyT(TemporalNumpy, TemporalT, GetItem, FloatArray): # type: ig
dtype: numpy.dtype[typing.Any] = numpy.dtype([("t", float)])

def __new__(cls, *args: typing.Any, **kwargs: typing.Any) -> TemporalNumpyT:
if "dtype" in kwargs:
TemporalNumpyT.dtype = numpy.dtype(kwargs["dtype"])
return numpy.array(*args, **kwargs).view(cls)

def __array_finalize__(self, obj: typing.Any) -> None:
Expand All @@ -759,13 +761,11 @@ def __array_finalize__(self, obj: typing.Any) -> None:
def __eq__(self, other: typing.Any) -> bool:
if self.dtype != other.dtype or not isinstance(other, TemporalNumpyT):
return False

Check warning on line 763 in src/vector/backends/numpy.py

View check run for this annotation

Codecov / codecov/patch

src/vector/backends/numpy.py#L763

Added line #L763 was not covered by tests

return all(coord1 == coord2 for coord1, coord2 in zip(self, other))

def __ne__(self, other: typing.Any) -> bool:
if self.dtype != other.dtype or not isinstance(other, TemporalNumpyT):
return True

return any(coord1 != coord2 for coord1, coord2 in zip(self, other))

Check warning on line 769 in src/vector/backends/numpy.py

View check run for this annotation

Codecov / codecov/patch

src/vector/backends/numpy.py#L767-L769

Added lines #L767 - L769 were not covered by tests

@property
Expand Down Expand Up @@ -797,6 +797,8 @@ class TemporalNumpyTau(TemporalNumpy, TemporalTau, GetItem, FloatArray): # type
dtype: numpy.dtype[typing.Any] = numpy.dtype([("tau", float)])

def __new__(cls, *args: typing.Any, **kwargs: typing.Any) -> TemporalNumpyTau:
if "dtype" in kwargs:
TemporalNumpyTau.dtype = numpy.dtype(kwargs["dtype"])

Check warning on line 801 in src/vector/backends/numpy.py

View check run for this annotation

Codecov / codecov/patch

src/vector/backends/numpy.py#L800-L801

Added lines #L800 - L801 were not covered by tests
return numpy.array(*args, **kwargs).view(cls)

def __array_finalize__(self, obj: typing.Any) -> None:
Expand All @@ -809,13 +811,11 @@ def __array_finalize__(self, obj: typing.Any) -> None:
def __eq__(self, other: typing.Any) -> bool:
if self.dtype != other.dtype or not isinstance(other, TemporalNumpyTau):
return False

return all(coord1 == coord2 for coord1, coord2 in zip(self, other))

Check warning on line 814 in src/vector/backends/numpy.py

View check run for this annotation

Codecov / codecov/patch

src/vector/backends/numpy.py#L812-L814

Added lines #L812 - L814 were not covered by tests

def __ne__(self, other: typing.Any) -> bool:
if self.dtype != other.dtype or not isinstance(other, TemporalNumpyTau):
return True

return any(coord1 != coord2 for coord1, coord2 in zip(self, other))

Check warning on line 819 in src/vector/backends/numpy.py

View check run for this annotation

Codecov / codecov/patch

src/vector/backends/numpy.py#L817-L819

Added lines #L817 - L819 were not covered by tests

@property
Expand Down

0 comments on commit 3c55c39

Please sign in to comment.