Skip to content

Commit

Permalink
feat: vectors should be demoted to the lowest dimension
Browse files Browse the repository at this point in the history
  • Loading branch information
Saransh-cpp committed Jan 24, 2024
1 parent a0dd194 commit 4aeda08
Show file tree
Hide file tree
Showing 5 changed files with 492 additions and 29 deletions.
96 changes: 72 additions & 24 deletions src/vector/_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -3927,9 +3927,7 @@ def dim(v: VectorProtocol) -> int:
raise TypeError(f"{v!r} is not a vector.Vector")


def _compute_module_of(
one: VectorProtocol, two: VectorProtocol, nontemporal: bool = False
) -> Module:
def _compute_module_of(one: VectorProtocol, two: VectorProtocol) -> Module:
"""
Determines which compute module to use for functions of two vectors
(the one with minimum dimension).
Expand All @@ -3941,34 +3939,20 @@ def _compute_module_of(
if not isinstance(two, Vector):
raise TypeError(f"{two!r} is not a Vector")

if isinstance(one, Vector2D):
if isinstance(one, Vector2D) or isinstance(two, Vector2D):
import vector._compute.planar

return vector._compute.planar

elif isinstance(one, Vector3D):
if isinstance(two, Vector2D):
import vector._compute.planar
elif isinstance(one, Vector3D) or isinstance(two, Vector3D):
import vector._compute.spatial

return vector._compute.planar
else:
import vector._compute.spatial

return vector._compute.spatial

elif isinstance(one, Vector4D):
if isinstance(two, Vector2D):
import vector._compute.planar
return vector._compute.spatial

return vector._compute.planar
elif isinstance(two, Vector3D) or nontemporal:
import vector._compute.spatial
elif isinstance(one, Vector4D) or isinstance(two, Vector4D):
import vector._compute.lorentz

return vector._compute.spatial
else:
import vector._compute.lorentz

return vector._compute.lorentz
return vector._compute.lorentz

raise AssertionError(repr(one))

Expand Down Expand Up @@ -4120,6 +4104,14 @@ def _get_handler_index(obj: VectorProtocol) -> int:
)


def _check_instance(
any_or_all: typing.Callable[[typing.Iterable[bool]], bool],
objects: tuple[VectorProtocol, ...],
clas: type[Vector] | type[vector.VectorAwkward],
) -> bool:
return any_or_all(isinstance(v, clas) for v in objects)


def _handler_of(*objects: VectorProtocol) -> VectorProtocol:
"""
Determines which vector should wrap the output of a dispatched function.
Expand All @@ -4137,6 +4129,62 @@ def _handler_of(*objects: VectorProtocol) -> VectorProtocol:
handler = obj

assert handler is not None

# if there is a 2D vector in objects
if _check_instance(any, objects, Vector2D):
# if all the objects are not from the same backend
# choose the 2D object of the backend with highest priority if it exists
# or demote the first encountered object of the backend with highest priority to 2D
if (
not _check_instance(all, objects, vector.VectorObject)
and not _check_instance(all, objects, vector.VectorNumpy)
and not _check_instance(all, objects, vector.VectorAwkward)
):
new_type = type(handler.to_Vector2D())
flag = 0
# if there is a 2D object of the backend with highest priority
# make it the new handler
for obj in objects:
if type(obj) == new_type:
handler = obj
flag = 1
# else, demote the dimension of the object of the backend with highest priority
if flag == 0:
handler = handler.to_Vector2D()
# if all objects are from the same backend
# use the 2D one as the handler
else:
for obj in objects:
if isinstance(obj, Vector2D):
handler = obj
# if there is no 2D vector but a 3D vector in objects
elif _check_instance(any, objects, Vector3D):
# if all the objects are not from the same backend
# choose the 3D object of the backend with highest priority if it exists
# or demote the first encountered object of the backend with highest priority to 3D
if (
not _check_instance(all, objects, vector.VectorObject)
and not _check_instance(all, objects, vector.VectorNumpy)
and not _check_instance(all, objects, vector.VectorAwkward)
):
new_type = type(handler.to_Vector3D())
flag = 0
# if there is a 3D object of the backend with highest priority
# make it the new handler
for obj in objects:
if type(obj) == new_type:
handler = obj
flag = 1
# else, demote the dimension of the object of the backend with highest priority
if flag == 0:
handler = handler.to_Vector3D()
# if all objects are from the same backend
# use the 3D one as the handler
else:
for obj in objects:
if isinstance(obj, Vector3D):
handler = obj

return handler


Expand Down
98 changes: 98 additions & 0 deletions tests/backends/test_awkward.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,3 +602,101 @@ def test_count_4d():
None,
3,
]


def test_demotion():
v1 = vector.zip(
{
"x": [10.0, 20.0, 30.0],
"y": [-10.0, 20.0, 30.0],
},
)
v2 = vector.zip(
{
"x": [10.0, 20.0, 30.0],
"y": [-10.0, 20.0, 30.0],
"z": [5.0, 1.0, 1.0],
},
)
v3 = vector.zip(
{
"x": [10.0, 20.0, 30.0],
"y": [-10.0, 20.0, 30.0],
"z": [5.0, 1.0, 1.0],
"t": [16.0, 31.0, 46.0],
},
)

v1_v2 = vector.zip(
{
"x": [20.0, 40.0, 60.0],
"y": [-20.0, 40.0, 60.0],
},
)
v2_v3 = vector.zip(
{
"x": [20.0, 40.0, 60.0],
"y": [-20.0, 40.0, 60.0],
"z": [10.0, 2.0, 2.0],
},
)
v1_v3 = vector.zip(
{
"x": [20.0, 40.0, 60.0],
"y": [-20.0, 40.0, 60.0],
},
)

# order should not matter
assert all(v1 + v2 == v1_v2)
assert all(v2 + v1 == v1_v2)
assert all(v1 + v3 == v1_v3)
assert all(v3 + v1 == v1_v3)
assert all(v2 + v3 == v2_v3)
assert all(v3 + v2 == v2_v3)

v1 = vector.zip(
{
"px": [10.0, 20.0, 30.0],
"py": [-10.0, 20.0, 30.0],
},
)
v2 = vector.zip(
{
"px": [10.0, 20.0, 30.0],
"py": [-10.0, 20.0, 30.0],
"pz": [5.0, 1.0, 1.0],
},
)
v3 = vector.zip(
{
"px": [10.0, 20.0, 30.0],
"py": [-10.0, 20.0, 30.0],
"pz": [5.0, 1.0, 1.0],
"t": [16.0, 31.0, 46.0],
},
)

# order should not matter
assert all(v1 + v2 == v1_v2)
assert all(v2 + v1 == v1_v2)
assert all(v1 + v3 == v1_v3)
assert all(v3 + v1 == v1_v3)
assert all(v2 + v3 == v2_v3)
assert all(v3 + v2 == v2_v3)

v2 = vector.zip(
{
"x": [10.0, 20.0, 30.0],
"y": [-10.0, 20.0, 30.0],
"z": [5.0, 1.0, 1.0],
},
)

# momentum + generic = generic
assert all(v1 + v2 == v1_v2)
assert all(v2 + v1 == v1_v2)
assert all(v1 + v3 == v1_v3)
assert all(v3 + v1 == v1_v3)
assert all(v2 + v3 == v2_v3)
assert all(v3 + v2 == v2_v3)
118 changes: 118 additions & 0 deletions tests/backends/test_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,3 +486,121 @@ def test_count_nonzero_4d():
assert numpy.count_nonzero(v2, axis=1, keepdims=True).tolist() == [[3], [2]]
assert numpy.count_nonzero(v2, axis=0).tolist() == [2, 2, 1]
assert numpy.count_nonzero(v2, axis=0, keepdims=True).tolist() == [[2, 2, 1]]


def test_demotion():
v1 = vector.array(
{
"x": [10.0, 20.0, 30.0],
"y": [-10.0, 20.0, 30.0],
},
)
v2 = vector.array(
{
"x": [10.0, 20.0, 30.0],
"y": [-10.0, 20.0, 30.0],
"z": [5.0, 1.0, 1.0],
},
)
v3 = vector.array(
{
"x": [10.0, 20.0, 30.0],
"y": [-10.0, 20.0, 30.0],
"z": [5.0, 1.0, 1.0],
"t": [16.0, 31.0, 46.0],
},
)

v1_v2 = vector.array(
{
"x": [20.0, 40.0, 60.0],
"y": [-20.0, 40.0, 60.0],
},
)
v2_v3 = vector.array(
{
"x": [20.0, 40.0, 60.0],
"y": [-20.0, 40.0, 60.0],
"z": [10.0, 2.0, 2.0],
},
)
v1_v3 = vector.array(
{
"x": [20.0, 40.0, 60.0],
"y": [-20.0, 40.0, 60.0],
},
)

# order should not matter
assert all(v1 + v2 == v1_v2)
assert all(v2 + v1 == v1_v2)
assert all(v1 + v3 == v1_v3)
assert all(v3 + v1 == v1_v3)
assert all(v2 + v3 == v2_v3)
assert all(v3 + v2 == v2_v3)

v1 = vector.array(
{
"px": [10.0, 20.0, 30.0],
"py": [-10.0, 20.0, 30.0],
},
)
v2 = vector.array(
{
"px": [10.0, 20.0, 30.0],
"py": [-10.0, 20.0, 30.0],
"pz": [5.0, 1.0, 1.0],
},
)
v3 = vector.array(
{
"px": [10.0, 20.0, 30.0],
"py": [-10.0, 20.0, 30.0],
"pz": [5.0, 1.0, 1.0],
"t": [16.0, 31.0, 46.0],
},
)

p_v1_v2 = vector.array(
{
"px": [20.0, 40.0, 60.0],
"py": [-20.0, 40.0, 60.0],
},
)
p_v2_v3 = vector.array(
{
"px": [20.0, 40.0, 60.0],
"py": [-20.0, 40.0, 60.0],
"pz": [10.0, 2.0, 2.0],
},
)
p_v1_v3 = vector.array(
{
"px": [20.0, 40.0, 60.0],
"py": [-20.0, 40.0, 60.0],
},
)

# order should not matter
assert all(v1 + v2 == p_v1_v2)
assert all(v2 + v1 == p_v1_v2)
assert all(v1 + v3 == p_v1_v3)
assert all(v3 + v1 == p_v1_v3)
assert all(v2 + v3 == p_v2_v3)
assert all(v3 + v2 == p_v2_v3)

v2 = vector.zip(
{
"x": [10.0, 20.0, 30.0],
"y": [-10.0, 20.0, 30.0],
"z": [5.0, 1.0, 1.0],
},
)

# momentum + generic = generic
assert all(v1 + v2 == v1_v2)
assert all(v2 + v1 == v1_v2)
assert all(v1 + v3 == v1_v3)
assert all(v3 + v1 == v1_v3)
assert all(v2 + v3 == v2_v3)
assert all(v3 + v2 == v2_v3)
38 changes: 38 additions & 0 deletions tests/backends/test_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,3 +257,41 @@ def test_array_casting():

with pytest.raises(TypeError):
vector.obj(x=1, y=False)


def test_demotion():
v1 = vector.obj(x=0.1, y=0.2)
v2 = vector.obj(x=1, y=2, z=3)
v3 = vector.obj(x=10, y=20, z=30, t=40)

# order should not matter
assert v1 + v2 == vector.obj(x=1.1, y=2.2)
assert v2 + v1 == vector.obj(x=1.1, y=2.2)
assert v1 + v3 == vector.obj(x=10.1, y=20.2)
assert v3 + v1 == vector.obj(x=10.1, y=20.2)
assert v2 + v3 == vector.obj(x=11, y=22, z=33)
assert v3 + v2 == vector.obj(x=11, y=22, z=33)

v1 = vector.obj(px=0.1, py=0.2)
v2 = vector.obj(px=1, py=2, pz=3)
v3 = vector.obj(px=10, py=20, pz=30, t=40)

# order should not matter
assert v1 + v2 == vector.obj(px=1.1, py=2.2)
assert v2 + v1 == vector.obj(px=1.1, py=2.2)
assert v1 + v3 == vector.obj(px=10.1, py=20.2)
assert v3 + v1 == vector.obj(px=10.1, py=20.2)
assert v2 + v3 == vector.obj(px=11, py=22, pz=33)
assert v3 + v2 == vector.obj(px=11, py=22, pz=33)

v1 = vector.obj(px=0.1, py=0.2)
v2 = vector.obj(x=1, y=2, z=3)
v3 = vector.obj(px=10, py=20, pz=30, t=40)

# momentum + generic = generic
assert v1 + v2 == vector.obj(x=1.1, y=2.2)
assert v2 + v1 == vector.obj(x=1.1, y=2.2)
assert v1 + v3 == vector.obj(px=10.1, py=20.2)
assert v3 + v1 == vector.obj(px=10.1, py=20.2)
assert v2 + v3 == vector.obj(x=11, y=22, z=33)
assert v3 + v2 == vector.obj(x=11, y=22, z=33)
Loading

0 comments on commit 4aeda08

Please sign in to comment.