From 09b2de8eeac9310920d19f80f5124b1621c2381c Mon Sep 17 00:00:00 2001 From: Saransh Chopra Date: Mon, 26 Feb 2024 13:33:22 +0100 Subject: [PATCH] Better type checks and better type hints for vector methods --- src/vector/_methods.py | 131 ++++++++++++++++++++++------------------- 1 file changed, 69 insertions(+), 62 deletions(-) diff --git a/src/vector/_methods.py b/src/vector/_methods.py index c5150f8f..cb292382 100644 --- a/src/vector/_methods.py +++ b/src/vector/_methods.py @@ -788,7 +788,7 @@ def neg3D(self: SameVectorType) -> SameVectorType: """ raise AssertionError - def cross(self, other: VectorProtocol) -> VectorProtocolSpatial: + def cross(self, other: VectorProtocolSpatial) -> VectorProtocolSpatial: """ The 3D cross-product of ``self`` with ``other``. @@ -796,18 +796,24 @@ def cross(self, other: VectorProtocol) -> VectorProtocolSpatial: """ raise AssertionError - def deltaangle(self, other: VectorProtocol) -> ScalarCollection: + def deltaangle( + self, other: VectorProtocolSpatial | VectorProtocolLorentz + ) -> ScalarCollection: r""" Angle in 3D space between ``self`` and ``other``, which is always positive, between $0$ and $\pi$. """ raise AssertionError - def deltaeta(self, other: VectorProtocol) -> ScalarCollection: + def deltaeta( + self, other: VectorProtocolSpatial | VectorProtocolLorentz + ) -> ScalarCollection: r"""Signed difference in $\eta$ of ``self`` minus ``other``.""" raise AssertionError - def deltaR(self, other: VectorProtocol) -> ScalarCollection: + def deltaR( + self, other: VectorProtocolSpatial | VectorProtocolLorentz + ) -> ScalarCollection: r""" Sum in quadrature of :meth:`vector._methods.VectorProtocolPlanar.deltaphi` and :meth:`vector._methods.VectorProtocolSpatial.deltaeta`: @@ -816,7 +822,9 @@ def deltaR(self, other: VectorProtocol) -> ScalarCollection: """ raise AssertionError - def deltaR2(self, other: VectorProtocol) -> ScalarCollection: + def deltaR2( + self, other: VectorProtocolSpatial | VectorProtocolLorentz + ) -> ScalarCollection: r""" Square of the sum in quadrature of :meth:`vector._methods.VectorProtocolPlanar.deltaphi` and @@ -847,7 +855,7 @@ def rotateY(self: SameVectorType, angle: ScalarCollection) -> SameVectorType: raise AssertionError def rotate_axis( - self: SameVectorType, axis: VectorProtocol, angle: ScalarCollection + self: SameVectorType, axis: VectorProtocolSpatial, angle: ScalarCollection ) -> SameVectorType: """ Rotates the vector(s) by a given ``angle`` (in radians) around the @@ -1054,7 +1062,7 @@ def rapidity(self) -> ScalarCollection: """ raise AssertionError - def deltaRapidityPhi(self, other: VectorProtocol) -> ScalarCollection: + def deltaRapidityPhi(self, other: VectorProtocolLorentz) -> ScalarCollection: r""" Sum in quadrature of :meth:`vector._methods.VectorProtocolPlanar.deltaphi` and the difference in :attr:`vector._methods.VectorProtocolLorentz.rapidity` @@ -1064,7 +1072,7 @@ def deltaRapidityPhi(self, other: VectorProtocol) -> ScalarCollection: """ raise AssertionError - def deltaRapidityPhi2(self, other: VectorProtocol) -> ScalarCollection: + def deltaRapidityPhi2(self, other: VectorProtocolLorentz) -> ScalarCollection: r""" Square of the sum in quadrature of :meth:`vector._methods.VectorProtocolPlanar.deltaphi` and the difference in @@ -3443,30 +3451,24 @@ def is_parallel( ) -> BoolCollection: from vector._compute.planar import is_parallel - if not isinstance(other, Vector2D): - return self.to_Vector3D().is_parallel(other, tolerance=tolerance) - else: - return is_parallel.dispatch(tolerance, self, other) + _maybe_dimension_error(self, other, self.is_parallel.__name__) + return is_parallel.dispatch(tolerance, self, other) def is_antiparallel( self, other: VectorProtocol, tolerance: ScalarCollection = 1e-5 ) -> BoolCollection: from vector._compute.planar import is_antiparallel - if not isinstance(other, Vector2D): - return self.to_Vector3D().is_antiparallel(other, tolerance=tolerance) - else: - return is_antiparallel.dispatch(tolerance, self, other) + _maybe_dimension_error(self, other, self.is_antiparallel.__name__) + return is_antiparallel.dispatch(tolerance, self, other) def is_perpendicular( self, other: VectorProtocol, tolerance: ScalarCollection = 1e-5 ) -> BoolCollection: from vector._compute.planar import is_perpendicular - if not isinstance(other, Vector2D): - return self.to_Vector3D().is_perpendicular(other, tolerance=tolerance) - else: - return is_perpendicular.dispatch(tolerance, self, other) + _maybe_dimension_error(self, other, self.is_perpendicular.__name__) + return is_perpendicular.dispatch(tolerance, self, other) def unit(self: SameVectorType) -> SameVectorType: from vector._compute.planar import unit @@ -3474,17 +3476,17 @@ def unit(self: SameVectorType) -> SameVectorType: return unit.dispatch(self) def dot(self, other: VectorProtocol) -> ScalarCollection: - _is_same_dimension(self, other) + _maybe_dimension_error(self, other, self.dot.__name__) module = _compute_module_of(self, other) return module.dot.dispatch(self, other) def add(self, other: VectorProtocol) -> VectorProtocol: - _is_same_dimension(self, other) + _maybe_dimension_error(self, other, self.add.__name__) module = _compute_module_of(self, other) return module.add.dispatch(self, other) def subtract(self, other: VectorProtocol) -> VectorProtocol: - _is_same_dimension(self, other) + _maybe_dimension_error(self, other, self.subtract.__name__) module = _compute_module_of(self, other) return module.subtract.dispatch(self, other) @@ -3507,13 +3509,13 @@ def scale(self: SameVectorType, factor: ScalarCollection) -> SameVectorType: def equal(self, other: VectorProtocol) -> BoolCollection: from vector._compute.planar import equal - _is_same_dimension(self, other) + _maybe_dimension_error(self, other, self.equal.__name__) return equal.dispatch(self, other) def not_equal(self, other: VectorProtocol) -> BoolCollection: from vector._compute.planar import not_equal - _is_same_dimension(self, other) + _maybe_dimension_error(self, other, self.not_equal.__name__) return not_equal.dispatch(self, other) def isclose( @@ -3525,7 +3527,7 @@ def isclose( ) -> BoolCollection: from vector._compute.planar import isclose - _is_same_dimension(self, other) + _maybe_dimension_error(self, other, self.isclose.__name__) return isclose.dispatch(rtol, atol, equal_nan, self, other) @@ -3572,27 +3574,35 @@ def mag2(self) -> ScalarCollection: return mag2.dispatch(self) - def cross(self, other: VectorProtocol) -> VectorProtocolSpatial: + def cross(self, other: VectorProtocolSpatial) -> VectorProtocolSpatial: from vector._compute.spatial import cross return cross.dispatch(self, other) - def deltaangle(self, other: VectorProtocol) -> ScalarCollection: + def deltaangle( + self, other: VectorProtocolSpatial | VectorProtocolLorentz + ) -> ScalarCollection: from vector._compute.spatial import deltaangle return deltaangle.dispatch(self, other) - def deltaeta(self, other: VectorProtocol) -> ScalarCollection: + def deltaeta( + self, other: VectorProtocolSpatial | VectorProtocolLorentz + ) -> ScalarCollection: from vector._compute.spatial import deltaeta return deltaeta.dispatch(self, other) - def deltaR(self, other: VectorProtocol) -> ScalarCollection: + def deltaR( + self, other: VectorProtocolSpatial | VectorProtocolLorentz + ) -> ScalarCollection: from vector._compute.spatial import deltaR return deltaR.dispatch(self, other) - def deltaR2(self, other: VectorProtocol) -> ScalarCollection: + def deltaR2( + self, other: VectorProtocolSpatial | VectorProtocolLorentz + ) -> ScalarCollection: from vector._compute.spatial import deltaR2 return deltaR2.dispatch(self, other) @@ -3608,7 +3618,7 @@ def rotateY(self: SameVectorType, angle: ScalarCollection) -> SameVectorType: return rotateY.dispatch(angle, self) def rotate_axis( - self: SameVectorType, axis: VectorProtocol, angle: ScalarCollection + self: SameVectorType, axis: VectorProtocolSpatial, angle: ScalarCollection ) -> SameVectorType: from vector._compute.spatial import rotate_axis @@ -3658,30 +3668,24 @@ def is_parallel( ) -> BoolCollection: from vector._compute.spatial import is_parallel - if isinstance(other, Vector2D): - return is_parallel.dispatch(tolerance, self, other.to_Vector3D()) - else: - return is_parallel.dispatch(tolerance, self, other) + _maybe_dimension_error(self, other, self.is_parallel.__name__) + return is_parallel.dispatch(tolerance, self, other) def is_antiparallel( self, other: VectorProtocol, tolerance: ScalarCollection = 1e-5 ) -> BoolCollection: from vector._compute.spatial import is_antiparallel - if isinstance(other, Vector2D): - return is_antiparallel.dispatch(tolerance, self, other.to_Vector3D()) - else: - return is_antiparallel.dispatch(tolerance, self, other) + _maybe_dimension_error(self, other, self.is_antiparallel.__name__) + return is_antiparallel.dispatch(tolerance, self, other) def is_perpendicular( self, other: VectorProtocol, tolerance: ScalarCollection = 1e-5 ) -> BoolCollection: from vector._compute.spatial import is_perpendicular - if isinstance(other, Vector2D): - return is_perpendicular.dispatch(tolerance, self, other.to_Vector3D()) - else: - return is_perpendicular.dispatch(tolerance, self, other) + _maybe_dimension_error(self, other, self.is_perpendicular.__name__) + return is_perpendicular.dispatch(tolerance, self, other) def unit(self: SameVectorType) -> SameVectorType: from vector._compute.spatial import unit @@ -3689,17 +3693,17 @@ def unit(self: SameVectorType) -> SameVectorType: return unit.dispatch(self) def dot(self, other: VectorProtocol) -> ScalarCollection: - _is_same_dimension(self, other) + _maybe_dimension_error(self, other, self.dot.__name__) module = _compute_module_of(self, other) return module.dot.dispatch(self, other) def add(self, other: VectorProtocol) -> VectorProtocol: - _is_same_dimension(self, other) + _maybe_dimension_error(self, other, self.add.__name__) module = _compute_module_of(self, other) return module.add.dispatch(self, other) def subtract(self, other: VectorProtocol) -> VectorProtocol: - _is_same_dimension(self, other) + _maybe_dimension_error(self, other, self.subtract.__name__) module = _compute_module_of(self, other) return module.subtract.dispatch(self, other) @@ -3733,13 +3737,13 @@ def scale(self: SameVectorType, factor: ScalarCollection) -> SameVectorType: def equal(self, other: VectorProtocol) -> BoolCollection: from vector._compute.spatial import equal - _is_same_dimension(self, other) + _maybe_dimension_error(self, other, self.equal.__name__) return equal.dispatch(self, other) def not_equal(self, other: VectorProtocol) -> BoolCollection: from vector._compute.spatial import not_equal - _is_same_dimension(self, other) + _maybe_dimension_error(self, other, self.not_equal.__name__) return not_equal.dispatch(self, other) def isclose( @@ -3751,7 +3755,7 @@ def isclose( ) -> BoolCollection: from vector._compute.spatial import isclose - _is_same_dimension(self, other) + _maybe_dimension_error(self, other, self.isclose.__name__) return isclose.dispatch(rtol, atol, equal_nan, self, other) @@ -3798,12 +3802,12 @@ def rapidity(self) -> ScalarCollection: return rapidity.dispatch(self) - def deltaRapidityPhi(self, other: VectorProtocol) -> ScalarCollection: + def deltaRapidityPhi(self, other: VectorProtocolLorentz) -> ScalarCollection: from vector._compute.lorentz import deltaRapidityPhi return deltaRapidityPhi.dispatch(self, other) - def deltaRapidityPhi2(self, other: VectorProtocol) -> ScalarCollection: + def deltaRapidityPhi2(self, other: VectorProtocolLorentz) -> ScalarCollection: from vector._compute.lorentz import deltaRapidityPhi2 return deltaRapidityPhi2.dispatch(self, other) @@ -3933,17 +3937,17 @@ def unit(self: SameVectorType) -> SameVectorType: return unit.dispatch(self) def dot(self, other: VectorProtocol) -> ScalarCollection: - _is_same_dimension(self, other) + _maybe_dimension_error(self, other, self.dot.__name__) module = _compute_module_of(self, other) return module.dot.dispatch(self, other) def add(self, other: VectorProtocol) -> VectorProtocol: - _is_same_dimension(self, other) + _maybe_dimension_error(self, other, self.add.__name__) module = _compute_module_of(self, other) return module.add.dispatch(self, other) def subtract(self, other: VectorProtocol) -> VectorProtocol: - _is_same_dimension(self, other) + _maybe_dimension_error(self, other, self.subtract.__name__) module = _compute_module_of(self, other) return module.subtract.dispatch(self, other) @@ -3988,13 +3992,13 @@ def scale(self: SameVectorType, factor: ScalarCollection) -> SameVectorType: def equal(self, other: VectorProtocol) -> BoolCollection: from vector._compute.lorentz import equal - _is_same_dimension(self, other) + _maybe_dimension_error(self, other, self.equal.__name__) return equal.dispatch(self, other) def not_equal(self, other: VectorProtocol) -> BoolCollection: from vector._compute.lorentz import not_equal - _is_same_dimension(self, other) + _maybe_dimension_error(self, other, self.not_equal.__name__) return not_equal.dispatch(self, other) def isclose( @@ -4006,7 +4010,7 @@ def isclose( ) -> BoolCollection: from vector._compute.lorentz import isclose - _is_same_dimension(self, other) + _maybe_dimension_error(self, other, self.isclose.__name__) return isclose.dispatch(rtol, atol, equal_nan, self, other) @@ -4176,19 +4180,22 @@ def dim(v: VectorProtocol) -> int: raise TypeError(f"{v!r} is not a vector.Vector") -def _is_same_dimension(v1: VectorProtocol, v2: VectorProtocol) -> None: +def _maybe_dimension_error( + v1: VectorProtocol, v2: VectorProtocol, operation: str +) -> None: """Raises an error if the vectors are not of the same dimension.""" if dim(v1) != dim(v2): raise TypeError( f"""{v1!r} and {v2!r} do not have the same dimension; use - a.like(b) + b + a.like(b).{operation}(b) or - a + b.like(a) + a.{operation}(b.like(a)) - to project or embed one of the vectors to match the other's dimensionality + or the binary operation equivalent to project or embed one of the vectors + to match the other's dimensionality """ )