diff --git a/src/vector/_methods.py b/src/vector/_methods.py index df41fb06..5d25abf4 100644 --- a/src/vector/_methods.py +++ b/src/vector/_methods.py @@ -3927,9 +3927,7 @@ def dim(v: VectorProtocol) -> int: raise TypeError(f"{v!r} is not a vector.Vector") -def _compute_module_of( - one: VectorProtocol, two: VectorProtocol, nontemporal: bool = False -) -> Module: +def _compute_module_of(one: VectorProtocol, two: VectorProtocol) -> Module: """ Determines which compute module to use for functions of two vectors (the one with minimum dimension). @@ -3941,34 +3939,20 @@ def _compute_module_of( if not isinstance(two, Vector): raise TypeError(f"{two!r} is not a Vector") - if isinstance(one, Vector2D): + if isinstance(one, Vector2D) or isinstance(two, Vector2D): import vector._compute.planar return vector._compute.planar - elif isinstance(one, Vector3D): - if isinstance(two, Vector2D): - import vector._compute.planar + elif isinstance(one, Vector3D) or isinstance(two, Vector3D): + import vector._compute.spatial - return vector._compute.planar - else: - import vector._compute.spatial - - return vector._compute.spatial - - elif isinstance(one, Vector4D): - if isinstance(two, Vector2D): - import vector._compute.planar + return vector._compute.spatial - return vector._compute.planar - elif isinstance(two, Vector3D) or nontemporal: - import vector._compute.spatial + elif isinstance(one, Vector4D) or isinstance(two, Vector4D): + import vector._compute.lorentz - return vector._compute.spatial - else: - import vector._compute.lorentz - - return vector._compute.lorentz + return vector._compute.lorentz raise AssertionError(repr(one)) @@ -4120,6 +4104,14 @@ def _get_handler_index(obj: VectorProtocol) -> int: ) +def _check_instance( + any_or_all: typing.Callable[[typing.Iterable[bool]], bool], + objects: tuple[VectorProtocol, ...], + clas: type[Vector] | type[vector.VectorAwkward], +) -> bool: + return any_or_all(isinstance(v, clas) for v in objects) + + def _handler_of(*objects: VectorProtocol) -> VectorProtocol: """ Determines which vector should wrap the output of a dispatched function. @@ -4137,6 +4129,62 @@ def _handler_of(*objects: VectorProtocol) -> VectorProtocol: handler = obj assert handler is not None + + # if there is a 2D vector in objects + if _check_instance(any, objects, Vector2D): + # if all the objects are not from the same backend + # choose the 2D object of the backend with highest priority if it exists + # or demote the first encountered object of the backend with highest priority to 2D + if ( + not _check_instance(all, objects, vector.VectorObject) + and not _check_instance(all, objects, vector.VectorNumpy) + and not _check_instance(all, objects, vector.VectorAwkward) + ): + new_type = type(handler.to_Vector2D()) + flag = 0 + # if there is a 2D object of the backend with highest priority + # make it the new handler + for obj in objects: + if type(obj) == new_type: + handler = obj + flag = 1 + # else, demote the dimension of the object of the backend with highest priority + if flag == 0: + handler = handler.to_Vector2D() + # if all objects are from the same backend + # use the 2D one as the handler + else: + for obj in objects: + if isinstance(obj, Vector2D): + handler = obj + # if there is no 2D vector but a 3D vector in objects + elif _check_instance(any, objects, Vector3D): + # if all the objects are not from the same backend + # choose the 3D object of the backend with highest priority if it exists + # or demote the first encountered object of the backend with highest priority to 3D + if ( + not _check_instance(all, objects, vector.VectorObject) + and not _check_instance(all, objects, vector.VectorNumpy) + and not _check_instance(all, objects, vector.VectorAwkward) + ): + new_type = type(handler.to_Vector3D()) + flag = 0 + # if there is a 3D object of the backend with highest priority + # make it the new handler + for obj in objects: + if type(obj) == new_type: + handler = obj + flag = 1 + # else, demote the dimension of the object of the backend with highest priority + if flag == 0: + handler = handler.to_Vector3D() + # if all objects are from the same backend + # use the 3D one as the handler + else: + for obj in objects: + if isinstance(obj, Vector3D): + handler = obj + return handler diff --git a/tests/backends/test_awkward.py b/tests/backends/test_awkward.py index 4786fbb6..9797be1b 100644 --- a/tests/backends/test_awkward.py +++ b/tests/backends/test_awkward.py @@ -602,3 +602,101 @@ def test_count_4d(): None, 3, ] + + +def test_demotion(): + v1 = vector.zip( + { + "x": [10.0, 20.0, 30.0], + "y": [-10.0, 20.0, 30.0], + }, + ) + v2 = vector.zip( + { + "x": [10.0, 20.0, 30.0], + "y": [-10.0, 20.0, 30.0], + "z": [5.0, 1.0, 1.0], + }, + ) + v3 = vector.zip( + { + "x": [10.0, 20.0, 30.0], + "y": [-10.0, 20.0, 30.0], + "z": [5.0, 1.0, 1.0], + "t": [16.0, 31.0, 46.0], + }, + ) + + v1_v2 = vector.zip( + { + "x": [20.0, 40.0, 60.0], + "y": [-20.0, 40.0, 60.0], + }, + ) + v2_v3 = vector.zip( + { + "x": [20.0, 40.0, 60.0], + "y": [-20.0, 40.0, 60.0], + "z": [10.0, 2.0, 2.0], + }, + ) + v1_v3 = vector.zip( + { + "x": [20.0, 40.0, 60.0], + "y": [-20.0, 40.0, 60.0], + }, + ) + + # order should not matter + assert all(v1 + v2 == v1_v2) + assert all(v2 + v1 == v1_v2) + assert all(v1 + v3 == v1_v3) + assert all(v3 + v1 == v1_v3) + assert all(v2 + v3 == v2_v3) + assert all(v3 + v2 == v2_v3) + + v1 = vector.zip( + { + "px": [10.0, 20.0, 30.0], + "py": [-10.0, 20.0, 30.0], + }, + ) + v2 = vector.zip( + { + "px": [10.0, 20.0, 30.0], + "py": [-10.0, 20.0, 30.0], + "pz": [5.0, 1.0, 1.0], + }, + ) + v3 = vector.zip( + { + "px": [10.0, 20.0, 30.0], + "py": [-10.0, 20.0, 30.0], + "pz": [5.0, 1.0, 1.0], + "t": [16.0, 31.0, 46.0], + }, + ) + + # order should not matter + assert all(v1 + v2 == v1_v2) + assert all(v2 + v1 == v1_v2) + assert all(v1 + v3 == v1_v3) + assert all(v3 + v1 == v1_v3) + assert all(v2 + v3 == v2_v3) + assert all(v3 + v2 == v2_v3) + + v2 = vector.zip( + { + "x": [10.0, 20.0, 30.0], + "y": [-10.0, 20.0, 30.0], + "z": [5.0, 1.0, 1.0], + }, + ) + + # momentum + generic = generic + assert all(v1 + v2 == v1_v2) + assert all(v2 + v1 == v1_v2) + assert all(v1 + v3 == v1_v3) + assert all(v3 + v1 == v1_v3) + assert all(v2 + v3 == v2_v3) + assert all(v3 + v2 == v2_v3) diff --git a/tests/backends/test_numpy.py b/tests/backends/test_numpy.py index 0eea3517..aa3523b0 100644 --- a/tests/backends/test_numpy.py +++ b/tests/backends/test_numpy.py @@ -486,3 +486,121 @@ def test_count_nonzero_4d(): assert numpy.count_nonzero(v2, axis=1, keepdims=True).tolist() == [[3], [2]] assert numpy.count_nonzero(v2, axis=0).tolist() == [2, 2, 1] assert numpy.count_nonzero(v2, axis=0, keepdims=True).tolist() == [[2, 2, 1]] + + +def test_demotion(): + v1 = vector.array( + { + "x": [10.0, 20.0, 30.0], + "y": [-10.0, 20.0, 30.0], + }, + ) + v2 = vector.array( + { + "x": [10.0, 20.0, 30.0], + "y": [-10.0, 20.0, 30.0], + "z": [5.0, 1.0, 1.0], + }, + ) + v3 = vector.array( + { + "x": [10.0, 20.0, 30.0], + "y": [-10.0, 20.0, 30.0], + "z": [5.0, 1.0, 1.0], + "t": [16.0, 31.0, 46.0], + }, + ) + + v1_v2 = vector.array( + { + "x": [20.0, 40.0, 60.0], + "y": [-20.0, 40.0, 60.0], + }, + ) + v2_v3 = vector.array( + { + "x": [20.0, 40.0, 60.0], + "y": [-20.0, 40.0, 60.0], + "z": [10.0, 2.0, 2.0], + }, + ) + v1_v3 = vector.array( + { + "x": [20.0, 40.0, 60.0], + "y": [-20.0, 40.0, 60.0], + }, + ) + + # order should not matter + assert all(v1 + v2 == v1_v2) + assert all(v2 + v1 == v1_v2) + assert all(v1 + v3 == v1_v3) + assert all(v3 + v1 == v1_v3) + assert all(v2 + v3 == v2_v3) + assert all(v3 + v2 == v2_v3) + + v1 = vector.array( + { + "px": [10.0, 20.0, 30.0], + "py": [-10.0, 20.0, 30.0], + }, + ) + v2 = vector.array( + { + "px": [10.0, 20.0, 30.0], + "py": [-10.0, 20.0, 30.0], + "pz": [5.0, 1.0, 1.0], + }, + ) + v3 = vector.array( + { + "px": [10.0, 20.0, 30.0], + "py": [-10.0, 20.0, 30.0], + "pz": [5.0, 1.0, 1.0], + "t": [16.0, 31.0, 46.0], + }, + ) + + p_v1_v2 = vector.array( + { + "px": [20.0, 40.0, 60.0], + "py": [-20.0, 40.0, 60.0], + }, + ) + p_v2_v3 = vector.array( + { + "px": [20.0, 40.0, 60.0], + "py": [-20.0, 40.0, 60.0], + "pz": [10.0, 2.0, 2.0], + }, + ) + p_v1_v3 = vector.array( + { + "px": [20.0, 40.0, 60.0], + "py": [-20.0, 40.0, 60.0], + }, + ) + + # order should not matter + assert all(v1 + v2 == p_v1_v2) + assert all(v2 + v1 == p_v1_v2) + assert all(v1 + v3 == p_v1_v3) + assert all(v3 + v1 == p_v1_v3) + assert all(v2 + v3 == p_v2_v3) + assert all(v3 + v2 == p_v2_v3) + + v2 = vector.zip( + { + "x": [10.0, 20.0, 30.0], + "y": [-10.0, 20.0, 30.0], + "z": [5.0, 1.0, 1.0], + }, + ) + + # momentum + generic = generic + assert all(v1 + v2 == v1_v2) + assert all(v2 + v1 == v1_v2) + assert all(v1 + v3 == v1_v3) + assert all(v3 + v1 == v1_v3) + assert all(v2 + v3 == v2_v3) + assert all(v3 + v2 == v2_v3) diff --git a/tests/backends/test_object.py b/tests/backends/test_object.py index a976611f..c50598c2 100644 --- a/tests/backends/test_object.py +++ b/tests/backends/test_object.py @@ -257,3 +257,41 @@ def test_array_casting(): with pytest.raises(TypeError): vector.obj(x=1, y=False) + + +def test_demotion(): + v1 = vector.obj(x=0.1, y=0.2) + v2 = vector.obj(x=1, y=2, z=3) + v3 = vector.obj(x=10, y=20, z=30, t=40) + + # order should not matter + assert v1 + v2 == vector.obj(x=1.1, y=2.2) + assert v2 + v1 == vector.obj(x=1.1, y=2.2) + assert v1 + v3 == vector.obj(x=10.1, y=20.2) + assert v3 + v1 == vector.obj(x=10.1, y=20.2) + assert v2 + v3 == vector.obj(x=11, y=22, z=33) + assert v3 + v2 == vector.obj(x=11, y=22, z=33) + + v1 = vector.obj(px=0.1, py=0.2) + v2 = vector.obj(px=1, py=2, pz=3) + v3 = vector.obj(px=10, py=20, pz=30, t=40) + + # order should not matter + assert v1 + v2 == vector.obj(px=1.1, py=2.2) + assert v2 + v1 == vector.obj(px=1.1, py=2.2) + assert v1 + v3 == vector.obj(px=10.1, py=20.2) + assert v3 + v1 == vector.obj(px=10.1, py=20.2) + assert v2 + v3 == vector.obj(px=11, py=22, pz=33) + assert v3 + v2 == vector.obj(px=11, py=22, pz=33) + + v1 = vector.obj(px=0.1, py=0.2) + v2 = vector.obj(x=1, y=2, z=3) + v3 = vector.obj(px=10, py=20, pz=30, t=40) + + # momentum + generic = generic + assert v1 + v2 == vector.obj(x=1.1, y=2.2) + assert v2 + v1 == vector.obj(x=1.1, y=2.2) + assert v1 + v3 == vector.obj(px=10.1, py=20.2) + assert v3 + v1 == vector.obj(px=10.1, py=20.2) + assert v2 + v3 == vector.obj(x=11, y=22, z=33) + assert v3 + v2 == vector.obj(x=11, y=22, z=33) diff --git a/tests/test_methods.py b/tests/test_methods.py index a69d6815..f8a3c673 100644 --- a/tests/test_methods.py +++ b/tests/test_methods.py @@ -6,14 +6,175 @@ from __future__ import annotations import vector +from vector import VectorObject2D, VectorObject3D, VectorObject4D -class CustomVector(vector.VectorObject4D): - pass +def test_handler_of(): + object_a = VectorObject4D.from_xyzt(0.0, 0.0, 0.0, 0.0) + object_b = VectorObject4D.from_xyzt(1.0, 1.0, 1.0, 1.0) + protocol = vector._methods._handler_of(object_a, object_b) + assert protocol == object_a + object_a = VectorObject3D.from_xyz(0.0, 0.0, 0.0) + object_b = VectorObject4D.from_xyzt(1.0, 1.0, 1.0, 1.0) + protocol = vector._methods._handler_of(object_a, object_b) + assert protocol == object_a -def test_handler_of(): - object_a = CustomVector.from_xyzt(0.0, 0.0, 0.0, 0.0) - object_b = CustomVector.from_xyzt(1.0, 1.0, 1.0, 1.0) + object_a = VectorObject4D.from_xyzt(0.0, 0.0, 0.0, 0.0) + object_b = VectorObject3D.from_xyz(1.0, 1.0, 1.0) + protocol = vector._methods._handler_of(object_a, object_b) + assert protocol == object_b + + object_a = VectorObject2D.from_xy(0.0, 0.0) + object_b = VectorObject4D.from_xyzt(1.0, 1.0, 1.0, 1.0) + protocol = vector._methods._handler_of(object_a, object_b) + assert protocol == object_a + + object_a = VectorObject4D.from_xyzt(0.0, 0.0, 0.0, 0.0) + object_b = VectorObject2D.from_xy(1.0, 1.0) + protocol = vector._methods._handler_of(object_a, object_b) + assert protocol == object_b + + object_a = VectorObject2D.from_xy(0.0, 0.0) + object_b = VectorObject3D.from_xyz(1.0, 1.0, 1.0) protocol = vector._methods._handler_of(object_a, object_b) assert protocol == object_a + + object_a = VectorObject3D.from_xyz(0.0, 0.0, 0.0) + object_b = VectorObject2D.from_xy(1.0, 1.0) + protocol = vector._methods._handler_of(object_a, object_b) + assert protocol == object_b + + awkward_a = vector.zip( + { + "x": [10.0, 20.0, 30.0], + "y": [-10.0, 20.0, 30.0], + "z": [5.0, 10.0, 15.0], + "t": [16.0, 31.0, 46.0], + }, + ) + object_b = VectorObject2D.from_xy(1.0, 1.0) + protocol = vector._methods._handler_of(awkward_a, object_b) + # chooses awkward backend and converts the vector to 2D + assert all(protocol == awkward_a.to_Vector2D()) + + awkward_a = vector.zip( + { + "x": [10.0, 20.0, 30.0], + "y": [-10.0, 20.0, 30.0], + }, + ) + object_b = VectorObject4D.from_xyzt(1.0, 1.0, 1.0, 1.0) + protocol = vector._methods._handler_of(object_b, awkward_a) + # chooses awkward backend and the vector is already of the + # lower dimension + assert all(protocol == awkward_a) + + awkward_a = vector.zip( + { + "x": [10.0, 20.0, 30.0], + "y": [-10.0, 20.0, 30.0], + }, + ) + awkward_b = vector.zip( + { + "x": [1.0, 2.0, 3.0], + "y": [-1.0, 2.0, 3.0], + "z": [5.0, 10.0, 15.0], + "t": [16.0, 31.0, 46.0], + }, + ) + object_b = VectorObject4D.from_xyzt(1.0, 1.0, 1.0, 1.0) + protocol = vector._methods._handler_of(object_b, awkward_a, awkward_b) + # chooses awkward backend and the 2D awkward vector + # (first encountered awkward vector) + assert all(protocol == awkward_a) + + awkward_a = vector.zip( + { + "x": [10.0, 20.0, 30.0], + "y": [-10.0, 20.0, 30.0], + "z": [-10.0, 20.0, 30.0], + }, + ) + awkward_b = vector.zip( + { + "x": [1.0, 2.0, 3.0], + "y": [-1.0, 2.0, 3.0], + "z": [5.0, 10.0, 15.0], + "t": [16.0, 31.0, 46.0], + }, + ) + object_b = VectorObject2D.from_xy(1.0, 1.0) + protocol = vector._methods._handler_of(awkward_b, object_b, awkward_a) + # chooses awkward backend and converts awkward_b to 2D + # (first encountered awkward vector) + assert all(protocol == awkward_b.to_Vector2D()) + + awkward_a = vector.zip( + { + "x": [10.0, 20.0, 30.0], + "y": [-10.0, 20.0, 30.0], + "z": [5.0, 1.0, 1.0], + }, + ) + awkward_b = vector.zip( + { + "x": [1.0, 2.0, 3.0], + "y": [-1.0, 2.0, 3.0], + "z": [5.0, 10.0, 15.0], + "t": [16.0, 31.0, 46.0], + }, + ) + object_b = VectorObject2D.from_xy(1.0, 1.0) + protocol = vector._methods._handler_of(object_b, awkward_a, awkward_b) + # chooses awkward backend and converts the vector to 2D + # (the first awkward vector encountered is used as the base) + assert all(protocol == awkward_a.to_Vector2D()) + + numpy_a = vector.array( + { + "x": [1.1, 1.2, 1.3, 1.4, 1.5], + "y": [2.1, 2.2, 2.3, 2.4, 2.5], + "z": [3.1, 3.2, 3.3, 3.4, 3.5], + } + ) + awkward_b = vector.zip( + { + "x": [1.0, 2.0, 3.0], + "y": [-1.0, 2.0, 3.0], + "z": [5.0, 10.0, 15.0], + "t": [16.0, 31.0, 46.0], + }, + ) + object_b = VectorObject2D.from_xy(1.0, 1.0) + protocol = vector._methods._handler_of(object_b, numpy_a, awkward_b) + # chooses awkward backend and converts the vector to 2D + assert all(protocol == awkward_b.to_Vector2D()) + + awkward_a = vector.zip( + { + "x": [10.0, 20.0, 30.0], + "y": [-10.0, 20.0, 30.0], + "z": [5.0, 1.0, 1.0], + }, + ) + numpy_a = vector.array( + { + "x": [1.1, 1.2, 1.3, 1.4, 1.5], + "y": [2.1, 2.2, 2.3, 2.4, 2.5], + } + ) + awkward_b = vector.zip( + { + "x": [1.0, 2.0, 3.0], + "y": [-1.0, 2.0, 3.0], + "z": [5.0, 10.0, 15.0], + "t": [16.0, 31.0, 46.0], + }, + ) + object_b = VectorObject3D.from_xyz(1.0, 1.0, 1.0) + protocol = vector._methods._handler_of(object_b, awkward_a, awkward_b, numpy_a) + # chooses awkward backend and converts the vector to 2D + # (the first awkward vector encountered is used as the base) + assert all(protocol == awkward_a.to_Vector2D())