Skip to content

Commit

Permalink
Refactor the demotion logic
Browse files Browse the repository at this point in the history
  • Loading branch information
Saransh-cpp committed Jan 25, 2024
1 parent 2f8eb30 commit 411bb92
Showing 1 changed file with 42 additions and 43 deletions.
85 changes: 42 additions & 43 deletions src/vector/_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -4107,11 +4107,46 @@ def _get_handler_index(obj: VectorProtocol) -> int:
def _check_instance(
any_or_all: typing.Callable[[typing.Iterable[bool]], bool],
objects: tuple[VectorProtocol, ...],
clas: type[Vector] | type[vector.VectorAwkward],
clas: type[VectorProtocol],
) -> bool:
return any_or_all(isinstance(v, clas) for v in objects)


def _demote_handler_vector(
handler: VectorProtocol,
objects: tuple[VectorProtocol, ...],
vector_class: type[VectorProtocol],
new_vector: VectorProtocol,
) -> VectorProtocol:
"""
Demotes the handler vector to the lowest possible dimension while respecting
the priority of backends.
"""
# if all the objects are not from the same backend
# choose the {X}D object of the backend with highest priority (if it exists)
# or demote the first encountered object of the backend with highest priority to {X}D
if len({_handler_priority.index(obj.__module__) for obj in objects}) != 1:
new_type = type(new_vector)
flag = 0
# if there is a {X}D object of the backend with highest priority
# make it the new handler
for obj in objects:
if type(obj) == new_type:
handler = obj
flag = 1
# else, demote the dimension of the object of the backend with highest priority
if flag == 0:
handler = new_vector
# if all objects are from the same backend
# use the {X}D one as the handler
else:
for obj in objects:
if isinstance(obj, vector_class):
handler = obj

return handler


def _handler_of(*objects: VectorProtocol) -> VectorProtocol:
"""
Determines which vector should wrap the output of a dispatched function.
Expand All @@ -4133,50 +4168,14 @@ def _handler_of(*objects: VectorProtocol) -> VectorProtocol:
if _check_instance(all, objects, Vector):
# if there is a 2D vector in objects
if _check_instance(any, objects, Vector2D):
# if all the objects are not from the same backend
# choose the 2D object of the backend with highest priority if it exists
# or demote the first encountered object of the backend with highest priority to 2D
if len({_handler_priority.index(obj.__module__) for obj in objects}) != 1:
new_type = type(handler.to_Vector2D())
flag = 0
# if there is a 2D object of the backend with highest priority
# make it the new handler
for obj in objects:
if type(obj) == new_type:
handler = obj
flag = 1
# else, demote the dimension of the object of the backend with highest priority
if flag == 0:
handler = handler.to_Vector2D()
# if all objects are from the same backend
# use the 2D one as the handler
else:
for obj in objects:
if isinstance(obj, Vector2D):
handler = obj
handler = _demote_handler_vector(
handler, objects, Vector2D, handler.to_Vector2D()
)
# if there is no 2D vector but a 3D vector in objects
elif _check_instance(any, objects, Vector3D):
# if all the objects are not from the same backend
# choose the 3D object of the backend with highest priority if it exists
# or demote the first encountered object of the backend with highest priority to 3D
if len({_handler_priority.index(obj.__module__) for obj in objects}) != 1:
new_type = type(handler.to_Vector3D())
flag = 0
# if there is a 3D object of the backend with highest priority
# make it the new handler
for obj in objects:
if type(obj) == new_type:
handler = obj
flag = 1
# else, demote the dimension of the object of the backend with highest priority
if flag == 0:
handler = handler.to_Vector3D()
# if all objects are from the same backend
# use the 3D one as the handler
else:
for obj in objects:
if isinstance(obj, Vector3D):
handler = obj
handler = _demote_handler_vector(
handler, objects, Vector3D, handler.to_Vector3D()
)

return handler

Expand Down

0 comments on commit 411bb92

Please sign in to comment.