Skip to content

Commit

Permalink
feat: allow overriding of protocol methods in subclasses (#128)
Browse files Browse the repository at this point in the history
* Add _handler_of func. MRO classes logic

* Add _handler_of func. test

* Invert _handler_of func. logic

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Tune _get_handler_index func. with Jim comment

* Replace `-1`/`break` logic with an early return.

Co-authored-by: Henry Schreiner <HenrySchreinerIII@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jim Pivarski <jpivarski@users.noreply.github.com>
Co-authored-by: Henry Schreiner <HenrySchreinerIII@gmail.com>
  • Loading branch information
4 people authored Jul 29, 2021
1 parent 5ba1115 commit f02e2f0
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 7 deletions.
24 changes: 17 additions & 7 deletions src/vector/_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# or https://github.com/scikit-hep/vector for details.

import typing
from contextlib import suppress

from vector._typeutils import (
BoolCollection,
Expand Down Expand Up @@ -2533,6 +2534,16 @@ def _from_signature(
]


def _get_handler_index(obj: VectorProtocol) -> int:
"""Returns the index of the first valid handler checking the list of parent classes"""
for cls in type(obj).__mro__:
with suppress(ValueError):
return _handler_priority.index(cls.__module__)
raise AssertionError(
f"Could not find a valid handler for {obj}! This should not happen."
)


def _handler_of(*objects: VectorProtocol) -> VectorProtocol:
"""
Determines which vector should wrap the output of a dispatched function.
Expand All @@ -2544,13 +2555,12 @@ def _handler_of(*objects: VectorProtocol) -> VectorProtocol:
"""
handler = None
for obj in objects:
if isinstance(obj, Vector):
if handler is None:
handler = obj
elif _handler_priority.index(
type(obj).__module__
) > _handler_priority.index(type(handler).__module__):
handler = obj
if not isinstance(obj, Vector):
continue
if handler is None:
handler = obj
elif _get_handler_index(obj) > _get_handler_index(handler):
handler = obj

assert handler is not None
return handler
Expand Down
17 changes: 17 additions & 0 deletions tests/test_methods.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Copyright (c) 2019-2021, Jonas Eschle, Jim Pivarski, Eduardo Rodrigues, and Henry Schreiner.
#
# Distributed under the 3-clause BSD license, see accompanying file LICENSE
# or https://github.com/scikit-hep/vector for details.

import vector


class CustomVector(vector.VectorObject4D):
pass


def test_handler_of():
object_a = CustomVector.from_xyzt(0.0, 0.0, 0.0, 0.0)
object_b = CustomVector.from_xyzt(1.0, 1.0, 1.0, 1.0)
protocol = vector._methods._handler_of(object_a, object_b)
assert protocol == object_a

0 comments on commit f02e2f0

Please sign in to comment.