Skip to content

Commit

Permalink
Better type checks and better type hints for vector methods
Browse files Browse the repository at this point in the history
  • Loading branch information
Saransh-cpp committed Feb 26, 2024
1 parent 5e2e772 commit 09b2de8
Showing 1 changed file with 69 additions and 62 deletions.
131 changes: 69 additions & 62 deletions src/vector/_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -788,26 +788,32 @@ def neg3D(self: SameVectorType) -> SameVectorType:
"""
raise AssertionError

def cross(self, other: VectorProtocol) -> VectorProtocolSpatial:
def cross(self, other: VectorProtocolSpatial) -> VectorProtocolSpatial:
"""
The 3D cross-product of ``self`` with ``other``.
Even if ``self`` or ``other`` is 4D, the resulting vector(s) is/are 3D.
"""
raise AssertionError

def deltaangle(self, other: VectorProtocol) -> ScalarCollection:
def deltaangle(
self, other: VectorProtocolSpatial | VectorProtocolLorentz
) -> ScalarCollection:
r"""
Angle in 3D space between ``self`` and ``other``, which is always
positive, between $0$ and $\pi$.
"""
raise AssertionError

def deltaeta(self, other: VectorProtocol) -> ScalarCollection:
def deltaeta(
self, other: VectorProtocolSpatial | VectorProtocolLorentz
) -> ScalarCollection:
r"""Signed difference in $\eta$ of ``self`` minus ``other``."""
raise AssertionError

def deltaR(self, other: VectorProtocol) -> ScalarCollection:
def deltaR(
self, other: VectorProtocolSpatial | VectorProtocolLorentz
) -> ScalarCollection:
r"""
Sum in quadrature of :meth:`vector._methods.VectorProtocolPlanar.deltaphi`
and :meth:`vector._methods.VectorProtocolSpatial.deltaeta`:
Expand All @@ -816,7 +822,9 @@ def deltaR(self, other: VectorProtocol) -> ScalarCollection:
"""
raise AssertionError

def deltaR2(self, other: VectorProtocol) -> ScalarCollection:
def deltaR2(
self, other: VectorProtocolSpatial | VectorProtocolLorentz
) -> ScalarCollection:
r"""
Square of the sum in quadrature of
:meth:`vector._methods.VectorProtocolPlanar.deltaphi` and
Expand Down Expand Up @@ -847,7 +855,7 @@ def rotateY(self: SameVectorType, angle: ScalarCollection) -> SameVectorType:
raise AssertionError

def rotate_axis(
self: SameVectorType, axis: VectorProtocol, angle: ScalarCollection
self: SameVectorType, axis: VectorProtocolSpatial, angle: ScalarCollection
) -> SameVectorType:
"""
Rotates the vector(s) by a given ``angle`` (in radians) around the
Expand Down Expand Up @@ -1054,7 +1062,7 @@ def rapidity(self) -> ScalarCollection:
"""
raise AssertionError

def deltaRapidityPhi(self, other: VectorProtocol) -> ScalarCollection:
def deltaRapidityPhi(self, other: VectorProtocolLorentz) -> ScalarCollection:
r"""
Sum in quadrature of :meth:`vector._methods.VectorProtocolPlanar.deltaphi`
and the difference in :attr:`vector._methods.VectorProtocolLorentz.rapidity`
Expand All @@ -1064,7 +1072,7 @@ def deltaRapidityPhi(self, other: VectorProtocol) -> ScalarCollection:
"""
raise AssertionError

def deltaRapidityPhi2(self, other: VectorProtocol) -> ScalarCollection:
def deltaRapidityPhi2(self, other: VectorProtocolLorentz) -> ScalarCollection:
r"""
Square of the sum in quadrature of
:meth:`vector._methods.VectorProtocolPlanar.deltaphi` and the difference in
Expand Down Expand Up @@ -3443,48 +3451,42 @@ def is_parallel(
) -> BoolCollection:
from vector._compute.planar import is_parallel

if not isinstance(other, Vector2D):
return self.to_Vector3D().is_parallel(other, tolerance=tolerance)
else:
return is_parallel.dispatch(tolerance, self, other)
_maybe_dimension_error(self, other, self.is_parallel.__name__)
return is_parallel.dispatch(tolerance, self, other)

Check warning on line 3455 in src/vector/_methods.py

View check run for this annotation

Codecov / codecov/patch

src/vector/_methods.py#L3454-L3455

Added lines #L3454 - L3455 were not covered by tests

def is_antiparallel(
self, other: VectorProtocol, tolerance: ScalarCollection = 1e-5
) -> BoolCollection:
from vector._compute.planar import is_antiparallel

if not isinstance(other, Vector2D):
return self.to_Vector3D().is_antiparallel(other, tolerance=tolerance)
else:
return is_antiparallel.dispatch(tolerance, self, other)
_maybe_dimension_error(self, other, self.is_antiparallel.__name__)
return is_antiparallel.dispatch(tolerance, self, other)

Check warning on line 3463 in src/vector/_methods.py

View check run for this annotation

Codecov / codecov/patch

src/vector/_methods.py#L3462-L3463

Added lines #L3462 - L3463 were not covered by tests

def is_perpendicular(
self, other: VectorProtocol, tolerance: ScalarCollection = 1e-5
) -> BoolCollection:
from vector._compute.planar import is_perpendicular

if not isinstance(other, Vector2D):
return self.to_Vector3D().is_perpendicular(other, tolerance=tolerance)
else:
return is_perpendicular.dispatch(tolerance, self, other)
_maybe_dimension_error(self, other, self.is_perpendicular.__name__)
return is_perpendicular.dispatch(tolerance, self, other)

Check warning on line 3471 in src/vector/_methods.py

View check run for this annotation

Codecov / codecov/patch

src/vector/_methods.py#L3470-L3471

Added lines #L3470 - L3471 were not covered by tests

def unit(self: SameVectorType) -> SameVectorType:
from vector._compute.planar import unit

return unit.dispatch(self)

def dot(self, other: VectorProtocol) -> ScalarCollection:
_is_same_dimension(self, other)
_maybe_dimension_error(self, other, self.dot.__name__)
module = _compute_module_of(self, other)
return module.dot.dispatch(self, other)

def add(self, other: VectorProtocol) -> VectorProtocol:
_is_same_dimension(self, other)
_maybe_dimension_error(self, other, self.add.__name__)
module = _compute_module_of(self, other)
return module.add.dispatch(self, other)

def subtract(self, other: VectorProtocol) -> VectorProtocol:
_is_same_dimension(self, other)
_maybe_dimension_error(self, other, self.subtract.__name__)
module = _compute_module_of(self, other)
return module.subtract.dispatch(self, other)

Expand All @@ -3507,13 +3509,13 @@ def scale(self: SameVectorType, factor: ScalarCollection) -> SameVectorType:
def equal(self, other: VectorProtocol) -> BoolCollection:
from vector._compute.planar import equal

_is_same_dimension(self, other)
_maybe_dimension_error(self, other, self.equal.__name__)
return equal.dispatch(self, other)

def not_equal(self, other: VectorProtocol) -> BoolCollection:
from vector._compute.planar import not_equal

_is_same_dimension(self, other)
_maybe_dimension_error(self, other, self.not_equal.__name__)
return not_equal.dispatch(self, other)

def isclose(
Expand All @@ -3525,7 +3527,7 @@ def isclose(
) -> BoolCollection:
from vector._compute.planar import isclose

_is_same_dimension(self, other)
_maybe_dimension_error(self, other, self.isclose.__name__)
return isclose.dispatch(rtol, atol, equal_nan, self, other)


Expand Down Expand Up @@ -3572,27 +3574,35 @@ def mag2(self) -> ScalarCollection:

return mag2.dispatch(self)

def cross(self, other: VectorProtocol) -> VectorProtocolSpatial:
def cross(self, other: VectorProtocolSpatial) -> VectorProtocolSpatial:
from vector._compute.spatial import cross

return cross.dispatch(self, other)

def deltaangle(self, other: VectorProtocol) -> ScalarCollection:
def deltaangle(
self, other: VectorProtocolSpatial | VectorProtocolLorentz
) -> ScalarCollection:
from vector._compute.spatial import deltaangle

return deltaangle.dispatch(self, other)

def deltaeta(self, other: VectorProtocol) -> ScalarCollection:
def deltaeta(
self, other: VectorProtocolSpatial | VectorProtocolLorentz
) -> ScalarCollection:
from vector._compute.spatial import deltaeta

return deltaeta.dispatch(self, other)

def deltaR(self, other: VectorProtocol) -> ScalarCollection:
def deltaR(
self, other: VectorProtocolSpatial | VectorProtocolLorentz
) -> ScalarCollection:
from vector._compute.spatial import deltaR

return deltaR.dispatch(self, other)

def deltaR2(self, other: VectorProtocol) -> ScalarCollection:
def deltaR2(
self, other: VectorProtocolSpatial | VectorProtocolLorentz
) -> ScalarCollection:
from vector._compute.spatial import deltaR2

return deltaR2.dispatch(self, other)
Expand All @@ -3608,7 +3618,7 @@ def rotateY(self: SameVectorType, angle: ScalarCollection) -> SameVectorType:
return rotateY.dispatch(angle, self)

def rotate_axis(
self: SameVectorType, axis: VectorProtocol, angle: ScalarCollection
self: SameVectorType, axis: VectorProtocolSpatial, angle: ScalarCollection
) -> SameVectorType:
from vector._compute.spatial import rotate_axis

Expand Down Expand Up @@ -3658,48 +3668,42 @@ def is_parallel(
) -> BoolCollection:
from vector._compute.spatial import is_parallel

if isinstance(other, Vector2D):
return is_parallel.dispatch(tolerance, self, other.to_Vector3D())
else:
return is_parallel.dispatch(tolerance, self, other)
_maybe_dimension_error(self, other, self.is_parallel.__name__)
return is_parallel.dispatch(tolerance, self, other)

Check warning on line 3672 in src/vector/_methods.py

View check run for this annotation

Codecov / codecov/patch

src/vector/_methods.py#L3671-L3672

Added lines #L3671 - L3672 were not covered by tests

def is_antiparallel(
self, other: VectorProtocol, tolerance: ScalarCollection = 1e-5
) -> BoolCollection:
from vector._compute.spatial import is_antiparallel

if isinstance(other, Vector2D):
return is_antiparallel.dispatch(tolerance, self, other.to_Vector3D())
else:
return is_antiparallel.dispatch(tolerance, self, other)
_maybe_dimension_error(self, other, self.is_antiparallel.__name__)
return is_antiparallel.dispatch(tolerance, self, other)

Check warning on line 3680 in src/vector/_methods.py

View check run for this annotation

Codecov / codecov/patch

src/vector/_methods.py#L3679-L3680

Added lines #L3679 - L3680 were not covered by tests

def is_perpendicular(
self, other: VectorProtocol, tolerance: ScalarCollection = 1e-5
) -> BoolCollection:
from vector._compute.spatial import is_perpendicular

if isinstance(other, Vector2D):
return is_perpendicular.dispatch(tolerance, self, other.to_Vector3D())
else:
return is_perpendicular.dispatch(tolerance, self, other)
_maybe_dimension_error(self, other, self.is_perpendicular.__name__)
return is_perpendicular.dispatch(tolerance, self, other)

Check warning on line 3688 in src/vector/_methods.py

View check run for this annotation

Codecov / codecov/patch

src/vector/_methods.py#L3687-L3688

Added lines #L3687 - L3688 were not covered by tests

def unit(self: SameVectorType) -> SameVectorType:
from vector._compute.spatial import unit

return unit.dispatch(self)

def dot(self, other: VectorProtocol) -> ScalarCollection:
_is_same_dimension(self, other)
_maybe_dimension_error(self, other, self.dot.__name__)
module = _compute_module_of(self, other)
return module.dot.dispatch(self, other)

def add(self, other: VectorProtocol) -> VectorProtocol:
_is_same_dimension(self, other)
_maybe_dimension_error(self, other, self.add.__name__)
module = _compute_module_of(self, other)
return module.add.dispatch(self, other)

def subtract(self, other: VectorProtocol) -> VectorProtocol:
_is_same_dimension(self, other)
_maybe_dimension_error(self, other, self.subtract.__name__)
module = _compute_module_of(self, other)
return module.subtract.dispatch(self, other)

Expand Down Expand Up @@ -3733,13 +3737,13 @@ def scale(self: SameVectorType, factor: ScalarCollection) -> SameVectorType:
def equal(self, other: VectorProtocol) -> BoolCollection:
from vector._compute.spatial import equal

_is_same_dimension(self, other)
_maybe_dimension_error(self, other, self.equal.__name__)
return equal.dispatch(self, other)

def not_equal(self, other: VectorProtocol) -> BoolCollection:
from vector._compute.spatial import not_equal

_is_same_dimension(self, other)
_maybe_dimension_error(self, other, self.not_equal.__name__)
return not_equal.dispatch(self, other)

def isclose(
Expand All @@ -3751,7 +3755,7 @@ def isclose(
) -> BoolCollection:
from vector._compute.spatial import isclose

_is_same_dimension(self, other)
_maybe_dimension_error(self, other, self.isclose.__name__)
return isclose.dispatch(rtol, atol, equal_nan, self, other)


Expand Down Expand Up @@ -3798,12 +3802,12 @@ def rapidity(self) -> ScalarCollection:

return rapidity.dispatch(self)

def deltaRapidityPhi(self, other: VectorProtocol) -> ScalarCollection:
def deltaRapidityPhi(self, other: VectorProtocolLorentz) -> ScalarCollection:
from vector._compute.lorentz import deltaRapidityPhi

return deltaRapidityPhi.dispatch(self, other)

def deltaRapidityPhi2(self, other: VectorProtocol) -> ScalarCollection:
def deltaRapidityPhi2(self, other: VectorProtocolLorentz) -> ScalarCollection:
from vector._compute.lorentz import deltaRapidityPhi2

return deltaRapidityPhi2.dispatch(self, other)
Expand Down Expand Up @@ -3933,17 +3937,17 @@ def unit(self: SameVectorType) -> SameVectorType:
return unit.dispatch(self)

def dot(self, other: VectorProtocol) -> ScalarCollection:
_is_same_dimension(self, other)
_maybe_dimension_error(self, other, self.dot.__name__)
module = _compute_module_of(self, other)
return module.dot.dispatch(self, other)

def add(self, other: VectorProtocol) -> VectorProtocol:
_is_same_dimension(self, other)
_maybe_dimension_error(self, other, self.add.__name__)
module = _compute_module_of(self, other)
return module.add.dispatch(self, other)

def subtract(self, other: VectorProtocol) -> VectorProtocol:
_is_same_dimension(self, other)
_maybe_dimension_error(self, other, self.subtract.__name__)
module = _compute_module_of(self, other)
return module.subtract.dispatch(self, other)

Expand Down Expand Up @@ -3988,13 +3992,13 @@ def scale(self: SameVectorType, factor: ScalarCollection) -> SameVectorType:
def equal(self, other: VectorProtocol) -> BoolCollection:
from vector._compute.lorentz import equal

_is_same_dimension(self, other)
_maybe_dimension_error(self, other, self.equal.__name__)
return equal.dispatch(self, other)

def not_equal(self, other: VectorProtocol) -> BoolCollection:
from vector._compute.lorentz import not_equal

_is_same_dimension(self, other)
_maybe_dimension_error(self, other, self.not_equal.__name__)
return not_equal.dispatch(self, other)

def isclose(
Expand All @@ -4006,7 +4010,7 @@ def isclose(
) -> BoolCollection:
from vector._compute.lorentz import isclose

_is_same_dimension(self, other)
_maybe_dimension_error(self, other, self.isclose.__name__)
return isclose.dispatch(rtol, atol, equal_nan, self, other)


Expand Down Expand Up @@ -4176,19 +4180,22 @@ def dim(v: VectorProtocol) -> int:
raise TypeError(f"{v!r} is not a vector.Vector")


def _is_same_dimension(v1: VectorProtocol, v2: VectorProtocol) -> None:
def _maybe_dimension_error(
v1: VectorProtocol, v2: VectorProtocol, operation: str
) -> None:
"""Raises an error if the vectors are not of the same dimension."""
if dim(v1) != dim(v2):
raise TypeError(
f"""{v1!r} and {v2!r} do not have the same dimension; use
a.like(b) + b
a.like(b).{operation}(b)
or
a + b.like(a)
a.{operation}(b.like(a))
to project or embed one of the vectors to match the other's dimensionality
or the binary operation equivalent to project or embed one of the vectors
to match the other's dimensionality
"""
)

Expand Down

0 comments on commit 09b2de8

Please sign in to comment.