Skip to content

Commit

Permalink
fix: vectors should be demoted to the lowest dimension (#413)
Browse files Browse the repository at this point in the history
* feat: vectors should be demoted to the lowest dimension

* don't use awkward stuff in _methods.py

* Refactor the demotion logic

* Revert changes to _compute_module_of

* Generalise lowering to classes inheriting vector mixins

* Add a note
  • Loading branch information
Saransh-cpp authored Jan 30, 2024
1 parent c4c90b9 commit 6aef998
Show file tree
Hide file tree
Showing 8 changed files with 502 additions and 12 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,8 @@ ak.Array(

All of the keyword arguments and rules that apply to `vector.obj` construction apply to `vector.awk` field names.

Finally, the `VectorAwkward` mixins can be subclassed to create custom vector classes. The awkward behavior classes and projections must be named as `*Array`. For example, `coffea` uses the following names - `TwoVectorArray`, `ThreeVectorArray`, `PolarTwoVectorArray`, `SphericalThreeVectorArray`, ...

## Vector properties

Any geometrical coordinate can be computed from vectors in any coordinate system; they'll be provided or computed as needed.
Expand Down
4 changes: 3 additions & 1 deletion docs/usage/intro.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1060,7 +1060,9 @@
"id": "beginning-expert",
"metadata": {},
"source": [
"All of the keyword arguments and rules that apply to `vector.obj` construction apply to `vector.Array` field names."
"All of the keyword arguments and rules that apply to `vector.obj` construction apply to `vector.Array` field names.\n",
"\n",
"Finally, the `VectorAwkward` mixins can be subclassed to create custom vector classes. The awkward behavior classes and projections must be named as `*Array`. For example, `coffea` uses the following names - `TwoVectorArray`, `ThreeVectorArray`, `PolarTwoVectorArray`, `SphericalThreeVectorArray`, ..."
]
},
{
Expand Down
65 changes: 65 additions & 0 deletions src/vector/_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -4120,6 +4120,58 @@ def _get_handler_index(obj: VectorProtocol) -> int:
)


def _check_instance(
any_or_all: typing.Callable[[typing.Iterable[bool]], bool],
objects: tuple[VectorProtocol, ...],
clas: type[VectorProtocol],
) -> bool:
return any_or_all(isinstance(v, clas) for v in objects)


def _demote_handler_vector(
handler: VectorProtocol,
objects: tuple[VectorProtocol, ...],
vector_class: type[VectorProtocol],
new_vector: VectorProtocol,
) -> VectorProtocol:
"""
Demotes the handler vector to the lowest possible dimension while respecting
the priority of backends.
"""
# if all the objects are not from the same backend
# choose the {X}D object of the backend with highest priority (if it exists)
# or demote the first encountered object of the backend with highest priority to {X}D
backends = [
next(
x.__module__
for x in type(obj).__mro__
if "vector.backends." in x.__module__
)
for obj in objects
]
if len({_handler_priority.index(backend) for backend in backends}) != 1:
new_type = type(new_vector)
flag = 0
# if there is a {X}D object of the backend with highest priority
# make it the new handler
for obj in objects:
if type(obj) == new_type:
handler = obj
flag = 1
break
# else, demote the dimension of the object of the backend with highest priority
if flag == 0:
handler = new_vector
# if all objects are from the same backend
# use the {X}D one as the handler
else:
for obj in objects:
if isinstance(obj, vector_class):
handler = obj

return handler


def _handler_of(*objects: VectorProtocol) -> VectorProtocol:
"""
Determines which vector should wrap the output of a dispatched function.
Expand All @@ -4137,6 +4189,19 @@ def _handler_of(*objects: VectorProtocol) -> VectorProtocol:
handler = obj

assert handler is not None

if _check_instance(all, objects, Vector):
# if there is a 2D vector in objects
if _check_instance(any, objects, Vector2D):
handler = _demote_handler_vector(
handler, objects, Vector2D, handler.to_Vector2D()
)
# if there is no 2D vector but a 3D vector in objects
elif _check_instance(any, objects, Vector3D):
handler = _demote_handler_vector(
handler, objects, Vector3D, handler.to_Vector3D()
)

return handler


Expand Down
14 changes: 8 additions & 6 deletions src/vector/backends/awkward.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,19 +575,21 @@ def elements(self) -> tuple[ArrayOrRecord]:


def _class_to_name(cls: type[VectorProtocol]) -> str:
# respect the type of classes inheriting VectorAwkward classes
is_vector = "vector.backends" in cls.__module__
if issubclass(cls, Momentum):
if issubclass(cls, Vector2D):
return "Momentum2D"
return "Momentum2D" if is_vector else cls.__name__[:-5]
if issubclass(cls, Vector3D):
return "Momentum3D"
return "Momentum3D" if is_vector else cls.__name__[:-5]
if issubclass(cls, Vector4D):
return "Momentum4D"
return "Momentum4D" if is_vector else cls.__name__[:-5]
if issubclass(cls, Vector2D):
return "Vector2D"
return "Vector2D" if is_vector else cls.__name__[:-5]
if issubclass(cls, Vector3D):
return "Vector3D"
return "Vector3D" if is_vector else cls.__name__[:-5]
if issubclass(cls, Vector4D):
return "Vector4D"
return "Vector4D" if is_vector else cls.__name__[:-5]

raise AssertionError(repr(cls))

Expand Down
98 changes: 98 additions & 0 deletions tests/backends/test_awkward.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,3 +602,101 @@ def test_count_4d():
None,
3,
]


def test_demotion():
v1 = vector.zip(
{
"x": [10.0, 20.0, 30.0],
"y": [-10.0, 20.0, 30.0],
},
)
v2 = vector.zip(
{
"x": [10.0, 20.0, 30.0],
"y": [-10.0, 20.0, 30.0],
"z": [5.0, 1.0, 1.0],
},
)
v3 = vector.zip(
{
"x": [10.0, 20.0, 30.0],
"y": [-10.0, 20.0, 30.0],
"z": [5.0, 1.0, 1.0],
"t": [16.0, 31.0, 46.0],
},
)

v1_v2 = vector.zip(
{
"x": [20.0, 40.0, 60.0],
"y": [-20.0, 40.0, 60.0],
},
)
v2_v3 = vector.zip(
{
"x": [20.0, 40.0, 60.0],
"y": [-20.0, 40.0, 60.0],
"z": [10.0, 2.0, 2.0],
},
)
v1_v3 = vector.zip(
{
"x": [20.0, 40.0, 60.0],
"y": [-20.0, 40.0, 60.0],
},
)

# order should not matter
assert all(v1 + v2 == v1_v2)
assert all(v2 + v1 == v1_v2)
assert all(v1 + v3 == v1_v3)
assert all(v3 + v1 == v1_v3)
assert all(v2 + v3 == v2_v3)
assert all(v3 + v2 == v2_v3)

v1 = vector.zip(
{
"px": [10.0, 20.0, 30.0],
"py": [-10.0, 20.0, 30.0],
},
)
v2 = vector.zip(
{
"px": [10.0, 20.0, 30.0],
"py": [-10.0, 20.0, 30.0],
"pz": [5.0, 1.0, 1.0],
},
)
v3 = vector.zip(
{
"px": [10.0, 20.0, 30.0],
"py": [-10.0, 20.0, 30.0],
"pz": [5.0, 1.0, 1.0],
"t": [16.0, 31.0, 46.0],
},
)

# order should not matter
assert all(v1 + v2 == v1_v2)
assert all(v2 + v1 == v1_v2)
assert all(v1 + v3 == v1_v3)
assert all(v3 + v1 == v1_v3)
assert all(v2 + v3 == v2_v3)
assert all(v3 + v2 == v2_v3)

v2 = vector.zip(
{
"x": [10.0, 20.0, 30.0],
"y": [-10.0, 20.0, 30.0],
"z": [5.0, 1.0, 1.0],
},
)

# momentum + generic = generic
assert all(v1 + v2 == v1_v2)
assert all(v2 + v1 == v1_v2)
assert all(v1 + v3 == v1_v3)
assert all(v3 + v1 == v1_v3)
assert all(v2 + v3 == v2_v3)
assert all(v3 + v2 == v2_v3)
118 changes: 118 additions & 0 deletions tests/backends/test_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,3 +486,121 @@ def test_count_nonzero_4d():
assert numpy.count_nonzero(v2, axis=1, keepdims=True).tolist() == [[3], [2]]
assert numpy.count_nonzero(v2, axis=0).tolist() == [2, 2, 1]
assert numpy.count_nonzero(v2, axis=0, keepdims=True).tolist() == [[2, 2, 1]]


def test_demotion():
v1 = vector.array(
{
"x": [10.0, 20.0, 30.0],
"y": [-10.0, 20.0, 30.0],
},
)
v2 = vector.array(
{
"x": [10.0, 20.0, 30.0],
"y": [-10.0, 20.0, 30.0],
"z": [5.0, 1.0, 1.0],
},
)
v3 = vector.array(
{
"x": [10.0, 20.0, 30.0],
"y": [-10.0, 20.0, 30.0],
"z": [5.0, 1.0, 1.0],
"t": [16.0, 31.0, 46.0],
},
)

v1_v2 = vector.array(
{
"x": [20.0, 40.0, 60.0],
"y": [-20.0, 40.0, 60.0],
},
)
v2_v3 = vector.array(
{
"x": [20.0, 40.0, 60.0],
"y": [-20.0, 40.0, 60.0],
"z": [10.0, 2.0, 2.0],
},
)
v1_v3 = vector.array(
{
"x": [20.0, 40.0, 60.0],
"y": [-20.0, 40.0, 60.0],
},
)

# order should not matter
assert all(v1 + v2 == v1_v2)
assert all(v2 + v1 == v1_v2)
assert all(v1 + v3 == v1_v3)
assert all(v3 + v1 == v1_v3)
assert all(v2 + v3 == v2_v3)
assert all(v3 + v2 == v2_v3)

v1 = vector.array(
{
"px": [10.0, 20.0, 30.0],
"py": [-10.0, 20.0, 30.0],
},
)
v2 = vector.array(
{
"px": [10.0, 20.0, 30.0],
"py": [-10.0, 20.0, 30.0],
"pz": [5.0, 1.0, 1.0],
},
)
v3 = vector.array(
{
"px": [10.0, 20.0, 30.0],
"py": [-10.0, 20.0, 30.0],
"pz": [5.0, 1.0, 1.0],
"t": [16.0, 31.0, 46.0],
},
)

p_v1_v2 = vector.array(
{
"px": [20.0, 40.0, 60.0],
"py": [-20.0, 40.0, 60.0],
},
)
p_v2_v3 = vector.array(
{
"px": [20.0, 40.0, 60.0],
"py": [-20.0, 40.0, 60.0],
"pz": [10.0, 2.0, 2.0],
},
)
p_v1_v3 = vector.array(
{
"px": [20.0, 40.0, 60.0],
"py": [-20.0, 40.0, 60.0],
},
)

# order should not matter
assert all(v1 + v2 == p_v1_v2)
assert all(v2 + v1 == p_v1_v2)
assert all(v1 + v3 == p_v1_v3)
assert all(v3 + v1 == p_v1_v3)
assert all(v2 + v3 == p_v2_v3)
assert all(v3 + v2 == p_v2_v3)

v2 = vector.array(
{
"x": [10.0, 20.0, 30.0],
"y": [-10.0, 20.0, 30.0],
"z": [5.0, 1.0, 1.0],
},
)

# momentum + generic = generic
assert all(v1 + v2 == v1_v2)
assert all(v2 + v1 == v1_v2)
assert all(v1 + v3 == v1_v3)
assert all(v3 + v1 == v1_v3)
assert all(v2 + v3 == v2_v3)
assert all(v3 + v2 == v2_v3)
38 changes: 38 additions & 0 deletions tests/backends/test_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,3 +257,41 @@ def test_array_casting():

with pytest.raises(TypeError):
vector.obj(x=1, y=False)


def test_demotion():
v1 = vector.obj(x=0.1, y=0.2)
v2 = vector.obj(x=1, y=2, z=3)
v3 = vector.obj(x=10, y=20, z=30, t=40)

# order should not matter
assert v1 + v2 == vector.obj(x=1.1, y=2.2)
assert v2 + v1 == vector.obj(x=1.1, y=2.2)
assert v1 + v3 == vector.obj(x=10.1, y=20.2)
assert v3 + v1 == vector.obj(x=10.1, y=20.2)
assert v2 + v3 == vector.obj(x=11, y=22, z=33)
assert v3 + v2 == vector.obj(x=11, y=22, z=33)

v1 = vector.obj(px=0.1, py=0.2)
v2 = vector.obj(px=1, py=2, pz=3)
v3 = vector.obj(px=10, py=20, pz=30, t=40)

# order should not matter
assert v1 + v2 == vector.obj(px=1.1, py=2.2)
assert v2 + v1 == vector.obj(px=1.1, py=2.2)
assert v1 + v3 == vector.obj(px=10.1, py=20.2)
assert v3 + v1 == vector.obj(px=10.1, py=20.2)
assert v2 + v3 == vector.obj(px=11, py=22, pz=33)
assert v3 + v2 == vector.obj(px=11, py=22, pz=33)

v1 = vector.obj(px=0.1, py=0.2)
v2 = vector.obj(x=1, y=2, z=3)
v3 = vector.obj(px=10, py=20, pz=30, t=40)

# momentum + generic = generic
assert v1 + v2 == vector.obj(x=1.1, y=2.2)
assert v2 + v1 == vector.obj(x=1.1, y=2.2)
assert v1 + v3 == vector.obj(px=10.1, py=20.2)
assert v3 + v1 == vector.obj(px=10.1, py=20.2)
assert v2 + v3 == vector.obj(x=11, y=22, z=33)
assert v3 + v2 == vector.obj(x=11, y=22, z=33)
Loading

0 comments on commit 6aef998

Please sign in to comment.