From 3c55c39dbd9693febeaf14128b862b011a272dbf Mon Sep 17 00:00:00 2001 From: Saransh Chopra Date: Thu, 16 May 2024 21:08:07 +0200 Subject: [PATCH] Add a default dtype for coord classes --- src/vector/backends/numpy.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/src/vector/backends/numpy.py b/src/vector/backends/numpy.py index f15fb53e..1abfbb9c 100644 --- a/src/vector/backends/numpy.py +++ b/src/vector/backends/numpy.py @@ -446,6 +446,8 @@ class AzimuthalNumpyXY(AzimuthalNumpy, AzimuthalXY, GetItem, FloatArray): # typ dtype: numpy.dtype[typing.Any] = numpy.dtype([("x", float), ("y", float)]) def __new__(cls, *args: typing.Any, **kwargs: typing.Any) -> AzimuthalNumpyXY: + if "dtype" in kwargs: + AzimuthalNumpyXY.dtype = numpy.dtype(kwargs["dtype"]) return numpy.array(*args, **kwargs).view(cls) def __array_finalize__(self, obj: typing.Any) -> None: @@ -458,13 +460,11 @@ def __array_finalize__(self, obj: typing.Any) -> None: def __eq__(self, other: typing.Any) -> bool: if self.dtype != other.dtype or not isinstance(other, AzimuthalNumpyXY): return False - return all(coord1 == coord2 for coord1, coord2 in zip(self, other)) def __ne__(self, other: typing.Any) -> bool: if self.dtype != other.dtype or not isinstance(other, AzimuthalNumpyXY): return True - return any(coord1 != coord2 for coord1, coord2 in zip(self, other)) @property @@ -510,6 +510,8 @@ class AzimuthalNumpyRhoPhi(AzimuthalNumpy, AzimuthalRhoPhi, GetItem, FloatArray) dtype: numpy.dtype[typing.Any] = numpy.dtype([("rho", float), ("phi", float)]) def __new__(cls, *args: typing.Any, **kwargs: typing.Any) -> AzimuthalNumpyRhoPhi: + if "dtype" in kwargs: + AzimuthalNumpyRhoPhi.dtype = numpy.dtype(kwargs["dtype"]) return numpy.array(*args, **kwargs).view(cls) def __array_finalize__(self, obj: typing.Any) -> None: @@ -522,13 +524,11 @@ def __array_finalize__(self, obj: typing.Any) -> None: def __eq__(self, other: typing.Any) -> bool: if self.dtype != other.dtype or not isinstance(other, AzimuthalNumpyRhoPhi): return False - return all(coord1 == coord2 for coord1, coord2 in zip(self, other)) def __ne__(self, other: typing.Any) -> bool: if self.dtype != other.dtype or not isinstance(other, AzimuthalNumpyRhoPhi): return True - return any(coord1 != coord2 for coord1, coord2 in zip(self, other)) @property @@ -573,6 +573,8 @@ class LongitudinalNumpyZ(LongitudinalNumpy, LongitudinalZ, GetItem, FloatArray): dtype: numpy.dtype[typing.Any] = numpy.dtype([("z", float)]) def __new__(cls, *args: typing.Any, **kwargs: typing.Any) -> LongitudinalNumpyZ: + if "dtype" in kwargs: + LongitudinalNumpyZ.dtype = numpy.dtype(kwargs["dtype"]) return numpy.array(*args, **kwargs).view(cls) def __array_finalize__(self, obj: typing.Any) -> None: @@ -585,13 +587,11 @@ def __array_finalize__(self, obj: typing.Any) -> None: def __eq__(self, other: typing.Any) -> bool: if self.dtype != other.dtype or not isinstance(other, LongitudinalNumpyZ): return False - return all(coord1 == coord2 for coord1, coord2 in zip(self, other)) def __ne__(self, other: typing.Any) -> bool: if self.dtype != other.dtype or not isinstance(other, LongitudinalNumpyZ): return True - return any(coord1 != coord2 for coord1, coord2 in zip(self, other)) @property @@ -631,6 +631,8 @@ class LongitudinalNumpyTheta(LongitudinalNumpy, LongitudinalTheta, GetItem, Floa dtype: numpy.dtype[typing.Any] = numpy.dtype([("theta", float)]) def __new__(cls, *args: typing.Any, **kwargs: typing.Any) -> LongitudinalNumpyTheta: + if "dtype" in kwargs: + LongitudinalNumpyTheta.dtype = numpy.dtype(kwargs["dtype"]) return numpy.array(*args, **kwargs).view(cls) def __array_finalize__(self, obj: typing.Any) -> None: @@ -643,13 +645,11 @@ def __array_finalize__(self, obj: typing.Any) -> None: def __eq__(self, other: typing.Any) -> bool: if self.dtype != other.dtype or not isinstance(other, LongitudinalNumpyTheta): return False - return all(coord1 == coord2 for coord1, coord2 in zip(self, other)) def __ne__(self, other: typing.Any) -> bool: if self.dtype != other.dtype or not isinstance(other, LongitudinalNumpyTheta): return True - return any(coord1 != coord2 for coord1, coord2 in zip(self, other)) @property @@ -689,6 +689,8 @@ class LongitudinalNumpyEta(LongitudinalNumpy, LongitudinalEta, GetItem, FloatArr dtype: numpy.dtype[typing.Any] = numpy.dtype([("eta", float)]) def __new__(cls, *args: typing.Any, **kwargs: typing.Any) -> LongitudinalNumpyEta: + if "dtype" in kwargs: + LongitudinalNumpyEta.dtype = numpy.dtype(kwargs["dtype"]) return numpy.array(*args, **kwargs).view(cls) def __array_finalize__(self, obj: typing.Any) -> None: @@ -701,13 +703,11 @@ def __array_finalize__(self, obj: typing.Any) -> None: def __eq__(self, other: typing.Any) -> bool: if self.dtype != other.dtype or not isinstance(other, LongitudinalNumpyEta): return False - return all(coord1 == coord2 for coord1, coord2 in zip(self, other)) def __ne__(self, other: typing.Any) -> bool: if self.dtype != other.dtype or not isinstance(other, LongitudinalNumpyEta): return True - return any(coord1 != coord2 for coord1, coord2 in zip(self, other)) @property @@ -747,6 +747,8 @@ class TemporalNumpyT(TemporalNumpy, TemporalT, GetItem, FloatArray): # type: ig dtype: numpy.dtype[typing.Any] = numpy.dtype([("t", float)]) def __new__(cls, *args: typing.Any, **kwargs: typing.Any) -> TemporalNumpyT: + if "dtype" in kwargs: + TemporalNumpyT.dtype = numpy.dtype(kwargs["dtype"]) return numpy.array(*args, **kwargs).view(cls) def __array_finalize__(self, obj: typing.Any) -> None: @@ -759,13 +761,11 @@ def __array_finalize__(self, obj: typing.Any) -> None: def __eq__(self, other: typing.Any) -> bool: if self.dtype != other.dtype or not isinstance(other, TemporalNumpyT): return False - return all(coord1 == coord2 for coord1, coord2 in zip(self, other)) def __ne__(self, other: typing.Any) -> bool: if self.dtype != other.dtype or not isinstance(other, TemporalNumpyT): return True - return any(coord1 != coord2 for coord1, coord2 in zip(self, other)) @property @@ -797,6 +797,8 @@ class TemporalNumpyTau(TemporalNumpy, TemporalTau, GetItem, FloatArray): # type dtype: numpy.dtype[typing.Any] = numpy.dtype([("tau", float)]) def __new__(cls, *args: typing.Any, **kwargs: typing.Any) -> TemporalNumpyTau: + if "dtype" in kwargs: + TemporalNumpyTau.dtype = numpy.dtype(kwargs["dtype"]) return numpy.array(*args, **kwargs).view(cls) def __array_finalize__(self, obj: typing.Any) -> None: @@ -809,13 +811,11 @@ def __array_finalize__(self, obj: typing.Any) -> None: def __eq__(self, other: typing.Any) -> bool: if self.dtype != other.dtype or not isinstance(other, TemporalNumpyTau): return False - return all(coord1 == coord2 for coord1, coord2 in zip(self, other)) def __ne__(self, other: typing.Any) -> bool: if self.dtype != other.dtype or not isinstance(other, TemporalNumpyTau): return True - return any(coord1 != coord2 for coord1, coord2 in zip(self, other)) @property