Skip to content

Commit

Permalink
Explicit error for more methods
Browse files Browse the repository at this point in the history
  • Loading branch information
Saransh-cpp committed Mar 5, 2024
1 parent 0b11c7d commit dfaa766
Show file tree
Hide file tree
Showing 5 changed files with 124 additions and 118 deletions.
114 changes: 73 additions & 41 deletions src/vector/_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -1091,7 +1091,7 @@ def neg4D(self: SameVectorType) -> SameVectorType:
"""Same as multiplying by -1."""
raise AssertionError

def boost_p4(self: SameVectorType, p4: VectorProtocolLorentz) -> SameVectorType:
def boost_p4(self: SameVectorType, p4: MomentumProtocolLorentz) -> SameVectorType:
"""
Boosts the vector or array of vectors in a direction and magnitude given
by the 4D vector or array of vectors ``p4``.
Expand All @@ -1118,7 +1118,7 @@ def boost_p4(self: SameVectorType, p4: VectorProtocolLorentz) -> SameVectorType:
raise AssertionError

def boost_beta3(
self: SameVectorType, beta3: VectorProtocolSpatial
self: SameVectorType, beta3: MomentumProtocolSpatial
) -> SameVectorType:
"""
Boosts the vector or array of vectors in a direction and magnitude given
Expand All @@ -1136,7 +1136,9 @@ def boost_beta3(
"""
raise AssertionError

def boost(self: SameVectorType, booster: VectorProtocol) -> SameVectorType:
def boost(
self: SameVectorType, booster: MomentumProtocolSpatial | MomentumProtocolLorentz
) -> SameVectorType:
"""
Boosts the vector or array of vectors using the 3D or 4D ``booster``.
Expand All @@ -1159,7 +1161,7 @@ def boost(self: SameVectorType, booster: VectorProtocol) -> SameVectorType:
raise AssertionError

def boostCM_of_p4(
self: SameVectorType, p4: VectorProtocolLorentz
self: SameVectorType, p4: MomentumProtocolLorentz
) -> SameVectorType:
"""
Boosts the vector or array of vectors to the center-of-mass (CM) frame of
Expand All @@ -1177,7 +1179,7 @@ def boostCM_of_p4(
raise AssertionError

def boostCM_of_beta3(
self: SameVectorType, beta3: VectorProtocolSpatial
self: SameVectorType, beta3: MomentumProtocolSpatial
) -> SameVectorType:
"""
Boosts the vector or array of vectors to the center-of-mass (CM) frame of
Expand All @@ -1188,7 +1190,9 @@ def boostCM_of_beta3(
"""
raise AssertionError

def boostCM_of(self: SameVectorType, booster: VectorProtocol) -> SameVectorType:
def boostCM_of(
self: SameVectorType, booster: MomentumProtocolSpatial | MomentumProtocolLorentz
) -> SameVectorType:
"""
Boosts the vector or array of vectors to the center-of-mass (CM) frame of
the 3D or 4D ``booster``.
Expand Down Expand Up @@ -3451,23 +3455,23 @@ def is_parallel(
) -> BoolCollection:
from vector._compute.planar import is_parallel

_maybe_dimension_error(self, other, self.is_parallel.__name__)
_maybe_same_dimension_error(self, other, self.is_parallel.__name__)
return is_parallel.dispatch(tolerance, self, other)

def is_antiparallel(
self, other: VectorProtocol, tolerance: ScalarCollection = 1e-5
) -> BoolCollection:
from vector._compute.planar import is_antiparallel

_maybe_dimension_error(self, other, self.is_antiparallel.__name__)
_maybe_same_dimension_error(self, other, self.is_antiparallel.__name__)
return is_antiparallel.dispatch(tolerance, self, other)

def is_perpendicular(
self, other: VectorProtocol, tolerance: ScalarCollection = 1e-5
) -> BoolCollection:
from vector._compute.planar import is_perpendicular

_maybe_dimension_error(self, other, self.is_perpendicular.__name__)
_maybe_same_dimension_error(self, other, self.is_perpendicular.__name__)
return is_perpendicular.dispatch(tolerance, self, other)

def unit(self: SameVectorType) -> SameVectorType:
Expand All @@ -3476,17 +3480,17 @@ def unit(self: SameVectorType) -> SameVectorType:
return unit.dispatch(self)

def dot(self, other: VectorProtocol) -> ScalarCollection:
_maybe_dimension_error(self, other, self.dot.__name__)
_maybe_same_dimension_error(self, other, self.dot.__name__)
module = _compute_module_of(self, other)
return module.dot.dispatch(self, other)

def add(self, other: VectorProtocol) -> VectorProtocol:
_maybe_dimension_error(self, other, self.add.__name__)
_maybe_same_dimension_error(self, other, self.add.__name__)
module = _compute_module_of(self, other)
return module.add.dispatch(self, other)

def subtract(self, other: VectorProtocol) -> VectorProtocol:
_maybe_dimension_error(self, other, self.subtract.__name__)
_maybe_same_dimension_error(self, other, self.subtract.__name__)
module = _compute_module_of(self, other)
return module.subtract.dispatch(self, other)

Expand All @@ -3509,13 +3513,13 @@ def scale(self: SameVectorType, factor: ScalarCollection) -> SameVectorType:
def equal(self, other: VectorProtocol) -> BoolCollection:
from vector._compute.planar import equal

_maybe_dimension_error(self, other, self.equal.__name__)
_maybe_same_dimension_error(self, other, self.equal.__name__)
return equal.dispatch(self, other)

def not_equal(self, other: VectorProtocol) -> BoolCollection:
from vector._compute.planar import not_equal

_maybe_dimension_error(self, other, self.not_equal.__name__)
_maybe_same_dimension_error(self, other, self.not_equal.__name__)
return not_equal.dispatch(self, other)

def isclose(
Expand All @@ -3527,7 +3531,7 @@ def isclose(
) -> BoolCollection:
from vector._compute.planar import isclose

_maybe_dimension_error(self, other, self.isclose.__name__)
_maybe_same_dimension_error(self, other, self.isclose.__name__)
return isclose.dispatch(rtol, atol, equal_nan, self, other)


Expand Down Expand Up @@ -3577,34 +3581,44 @@ def mag2(self) -> ScalarCollection:
def cross(self, other: VectorProtocolSpatial) -> VectorProtocolSpatial:
from vector._compute.spatial import cross

if dim(self) != 3 or dim(other) != 3:
raise TypeError("cross is only defined for 3D vectors")
return cross.dispatch(self, other)

def deltaangle(
self, other: VectorProtocolSpatial | VectorProtocolLorentz
) -> ScalarCollection:
from vector._compute.spatial import deltaangle

if dim(other) != 3 and dim(other) != 4:
raise TypeError(f"{other!r} is not a 3D or a 4D vector")
return deltaangle.dispatch(self, other)

def deltaeta(
self, other: VectorProtocolSpatial | VectorProtocolLorentz
) -> ScalarCollection:
from vector._compute.spatial import deltaeta

if dim(other) != 3 and dim(other) != 4:
raise TypeError(f"{other!r} is not a 3D or a 4D vector")
return deltaeta.dispatch(self, other)

def deltaR(
self, other: VectorProtocolSpatial | VectorProtocolLorentz
) -> ScalarCollection:
from vector._compute.spatial import deltaR

if dim(other) != 3 and dim(other) != 4:
raise TypeError(f"{other!r} is not a 3D or a 4D vector")
return deltaR.dispatch(self, other)

def deltaR2(
self, other: VectorProtocolSpatial | VectorProtocolLorentz
) -> ScalarCollection:
from vector._compute.spatial import deltaR2

if dim(other) != 3 and dim(other) != 4:
raise TypeError(f"{other!r} is not a 3D or a 4D vector")
return deltaR2.dispatch(self, other)

def rotateX(self: SameVectorType, angle: ScalarCollection) -> SameVectorType:
Expand All @@ -3622,6 +3636,8 @@ def rotate_axis(
) -> SameVectorType:
from vector._compute.spatial import rotate_axis

if dim(axis) != 3:
raise TypeError(f"{axis!r} is not a 3D vector")
return rotate_axis.dispatch(angle, axis, self)

def rotate_euler(
Expand Down Expand Up @@ -3668,23 +3684,23 @@ def is_parallel(
) -> BoolCollection:
from vector._compute.spatial import is_parallel

_maybe_dimension_error(self, other, self.is_parallel.__name__)
_maybe_same_dimension_error(self, other, self.is_parallel.__name__)
return is_parallel.dispatch(tolerance, self, other)

def is_antiparallel(
self, other: VectorProtocol, tolerance: ScalarCollection = 1e-5
) -> BoolCollection:
from vector._compute.spatial import is_antiparallel

_maybe_dimension_error(self, other, self.is_antiparallel.__name__)
_maybe_same_dimension_error(self, other, self.is_antiparallel.__name__)
return is_antiparallel.dispatch(tolerance, self, other)

def is_perpendicular(
self, other: VectorProtocol, tolerance: ScalarCollection = 1e-5
) -> BoolCollection:
from vector._compute.spatial import is_perpendicular

_maybe_dimension_error(self, other, self.is_perpendicular.__name__)
_maybe_same_dimension_error(self, other, self.is_perpendicular.__name__)
return is_perpendicular.dispatch(tolerance, self, other)

def unit(self: SameVectorType) -> SameVectorType:
Expand All @@ -3693,17 +3709,17 @@ def unit(self: SameVectorType) -> SameVectorType:
return unit.dispatch(self)

def dot(self, other: VectorProtocol) -> ScalarCollection:
_maybe_dimension_error(self, other, self.dot.__name__)
_maybe_same_dimension_error(self, other, self.dot.__name__)
module = _compute_module_of(self, other)
return module.dot.dispatch(self, other)

def add(self, other: VectorProtocol) -> VectorProtocol:
_maybe_dimension_error(self, other, self.add.__name__)
_maybe_same_dimension_error(self, other, self.add.__name__)
module = _compute_module_of(self, other)
return module.add.dispatch(self, other)

def subtract(self, other: VectorProtocol) -> VectorProtocol:
_maybe_dimension_error(self, other, self.subtract.__name__)
_maybe_same_dimension_error(self, other, self.subtract.__name__)
module = _compute_module_of(self, other)
return module.subtract.dispatch(self, other)

Expand Down Expand Up @@ -3737,13 +3753,13 @@ def scale(self: SameVectorType, factor: ScalarCollection) -> SameVectorType:
def equal(self, other: VectorProtocol) -> BoolCollection:
from vector._compute.spatial import equal

_maybe_dimension_error(self, other, self.equal.__name__)
_maybe_same_dimension_error(self, other, self.equal.__name__)
return equal.dispatch(self, other)

def not_equal(self, other: VectorProtocol) -> BoolCollection:
from vector._compute.spatial import not_equal

_maybe_dimension_error(self, other, self.not_equal.__name__)
_maybe_same_dimension_error(self, other, self.not_equal.__name__)
return not_equal.dispatch(self, other)

def isclose(
Expand All @@ -3755,7 +3771,7 @@ def isclose(
) -> BoolCollection:
from vector._compute.spatial import isclose

_maybe_dimension_error(self, other, self.isclose.__name__)
_maybe_same_dimension_error(self, other, self.isclose.__name__)
return isclose.dispatch(rtol, atol, equal_nan, self, other)


Expand Down Expand Up @@ -3805,31 +3821,41 @@ def rapidity(self) -> ScalarCollection:
def deltaRapidityPhi(self, other: VectorProtocolLorentz) -> ScalarCollection:
from vector._compute.lorentz import deltaRapidityPhi

if not dim(other) == 4:
raise TypeError(f"{other!r} is not a 4D vector")
return deltaRapidityPhi.dispatch(self, other)

def deltaRapidityPhi2(self, other: VectorProtocolLorentz) -> ScalarCollection:
from vector._compute.lorentz import deltaRapidityPhi2

if not dim(other) == 4:
raise TypeError(f"{other!r} is not a 4D vector")
return deltaRapidityPhi2.dispatch(self, other)

def boost_p4(self: SameVectorType, p4: VectorProtocolLorentz) -> SameVectorType:
def boost_p4(self: SameVectorType, p4: MomentumProtocolLorentz) -> SameVectorType:
from vector._compute.lorentz import boost_p4

if dim(p4) != 4 or not isinstance(p4, LorentzMomentum):
raise TypeError(f"{p4!r} is not a 4D momentum vector")
return boost_p4.dispatch(self, p4)

def boost_beta3(
self: SameVectorType, beta3: VectorProtocolSpatial
self: SameVectorType, beta3: MomentumProtocolSpatial
) -> SameVectorType:
from vector._compute.lorentz import boost_beta3

if dim(beta3) != 3 or not isinstance(beta3, SpatialMomentum):
raise TypeError(f"{beta3!r} is not a 3D momentum vector")
return boost_beta3.dispatch(self, beta3)

def boost(self: SameVectorType, booster: VectorProtocol) -> SameVectorType:
def boost(
self: SameVectorType, booster: MomentumProtocolSpatial | MomentumProtocolLorentz
) -> SameVectorType:
from vector._compute.lorentz import boost_beta3, boost_p4

if isinstance(booster, Vector3D):
if isinstance(booster, SpatialMomentum):
return boost_beta3.dispatch(self, booster)
elif isinstance(booster, Vector4D):
elif isinstance(booster, LorentzMomentum):
return boost_p4.dispatch(self, booster)
else:
raise TypeError(
Expand All @@ -3838,25 +3864,31 @@ def boost(self: SameVectorType, booster: VectorProtocol) -> SameVectorType:
)

def boostCM_of_p4(
self: SameVectorType, p4: VectorProtocolLorentz
self: SameVectorType, p4: MomentumProtocolLorentz
) -> SameVectorType:
from vector._compute.lorentz import boost_p4

if dim(p4) != 4 or not isinstance(p4, LorentzMomentum):
raise TypeError(f"{p4!r} is not a 4D momentum vector")
return boost_p4.dispatch(self, p4.neg3D)

def boostCM_of_beta3(
self: SameVectorType, beta3: VectorProtocolSpatial
self: SameVectorType, beta3: MomentumProtocolSpatial
) -> SameVectorType:
from vector._compute.lorentz import boost_beta3

if dim(beta3) != 3 or not isinstance(beta3, SpatialMomentum):
raise TypeError(f"{beta3!r} is not a 3D momentum vector")
return boost_beta3.dispatch(self, beta3.neg3D)

def boostCM_of(self: SameVectorType, booster: VectorProtocol) -> SameVectorType:
def boostCM_of(
self: SameVectorType, booster: MomentumProtocolSpatial | MomentumProtocolLorentz
) -> SameVectorType:
from vector._compute.lorentz import boost_beta3, boost_p4

if isinstance(booster, Vector3D):
if isinstance(booster, SpatialMomentum):
return boost_beta3.dispatch(self, booster.neg3D)
elif isinstance(booster, Vector4D):
elif isinstance(booster, LorentzMomentum):
return boost_p4.dispatch(self, booster.neg3D)
else:
raise TypeError(
Expand Down Expand Up @@ -3937,17 +3969,17 @@ def unit(self: SameVectorType) -> SameVectorType:
return unit.dispatch(self)

def dot(self, other: VectorProtocol) -> ScalarCollection:
_maybe_dimension_error(self, other, self.dot.__name__)
_maybe_same_dimension_error(self, other, self.dot.__name__)
module = _compute_module_of(self, other)
return module.dot.dispatch(self, other)

def add(self, other: VectorProtocol) -> VectorProtocol:
_maybe_dimension_error(self, other, self.add.__name__)
_maybe_same_dimension_error(self, other, self.add.__name__)
module = _compute_module_of(self, other)
return module.add.dispatch(self, other)

def subtract(self, other: VectorProtocol) -> VectorProtocol:
_maybe_dimension_error(self, other, self.subtract.__name__)
_maybe_same_dimension_error(self, other, self.subtract.__name__)
module = _compute_module_of(self, other)
return module.subtract.dispatch(self, other)

Expand Down Expand Up @@ -3992,13 +4024,13 @@ def scale(self: SameVectorType, factor: ScalarCollection) -> SameVectorType:
def equal(self, other: VectorProtocol) -> BoolCollection:
from vector._compute.lorentz import equal

_maybe_dimension_error(self, other, self.equal.__name__)
_maybe_same_dimension_error(self, other, self.equal.__name__)
return equal.dispatch(self, other)

def not_equal(self, other: VectorProtocol) -> BoolCollection:
from vector._compute.lorentz import not_equal

_maybe_dimension_error(self, other, self.not_equal.__name__)
_maybe_same_dimension_error(self, other, self.not_equal.__name__)
return not_equal.dispatch(self, other)

def isclose(
Expand All @@ -4010,7 +4042,7 @@ def isclose(
) -> BoolCollection:
from vector._compute.lorentz import isclose

_maybe_dimension_error(self, other, self.isclose.__name__)
_maybe_same_dimension_error(self, other, self.isclose.__name__)
return isclose.dispatch(rtol, atol, equal_nan, self, other)


Expand Down Expand Up @@ -4180,7 +4212,7 @@ def dim(v: VectorProtocol) -> int:
raise TypeError(f"{v!r} is not a vector.Vector")


def _maybe_dimension_error(
def _maybe_same_dimension_error(
v1: VectorProtocol, v2: VectorProtocol, operation: str
) -> None:
"""Raises an error if the vectors are not of the same dimension."""
Expand Down
Loading

0 comments on commit dfaa766

Please sign in to comment.