From f02e2f0186a478a5b97563d03fd6ed3d38270241 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sinclert=20P=C3=A9rez?= Date: Thu, 29 Jul 2021 17:05:05 +0200 Subject: [PATCH] feat: allow overriding of protocol methods in subclasses (#128) * Add _handler_of func. MRO classes logic * Add _handler_of func. test * Invert _handler_of func. logic * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Tune _get_handler_index func. with Jim comment * Replace `-1`/`break` logic with an early return. Co-authored-by: Henry Schreiner * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Jim Pivarski Co-authored-by: Henry Schreiner --- src/vector/_methods.py | 24 +++++++++++++++++------- tests/test_methods.py | 17 +++++++++++++++++ 2 files changed, 34 insertions(+), 7 deletions(-) create mode 100644 tests/test_methods.py diff --git a/src/vector/_methods.py b/src/vector/_methods.py index 46007f85..7c3bf7ab 100644 --- a/src/vector/_methods.py +++ b/src/vector/_methods.py @@ -4,6 +4,7 @@ # or https://github.com/scikit-hep/vector for details. import typing +from contextlib import suppress from vector._typeutils import ( BoolCollection, @@ -2533,6 +2534,16 @@ def _from_signature( ] +def _get_handler_index(obj: VectorProtocol) -> int: + """Returns the index of the first valid handler checking the list of parent classes""" + for cls in type(obj).__mro__: + with suppress(ValueError): + return _handler_priority.index(cls.__module__) + raise AssertionError( + f"Could not find a valid handler for {obj}! This should not happen." + ) + + def _handler_of(*objects: VectorProtocol) -> VectorProtocol: """ Determines which vector should wrap the output of a dispatched function. @@ -2544,13 +2555,12 @@ def _handler_of(*objects: VectorProtocol) -> VectorProtocol: """ handler = None for obj in objects: - if isinstance(obj, Vector): - if handler is None: - handler = obj - elif _handler_priority.index( - type(obj).__module__ - ) > _handler_priority.index(type(handler).__module__): - handler = obj + if not isinstance(obj, Vector): + continue + if handler is None: + handler = obj + elif _get_handler_index(obj) > _get_handler_index(handler): + handler = obj assert handler is not None return handler diff --git a/tests/test_methods.py b/tests/test_methods.py new file mode 100644 index 00000000..17a9dd16 --- /dev/null +++ b/tests/test_methods.py @@ -0,0 +1,17 @@ +# Copyright (c) 2019-2021, Jonas Eschle, Jim Pivarski, Eduardo Rodrigues, and Henry Schreiner. +# +# Distributed under the 3-clause BSD license, see accompanying file LICENSE +# or https://github.com/scikit-hep/vector for details. + +import vector + + +class CustomVector(vector.VectorObject4D): + pass + + +def test_handler_of(): + object_a = CustomVector.from_xyzt(0.0, 0.0, 0.0, 0.0) + object_b = CustomVector.from_xyzt(1.0, 1.0, 1.0, 1.0) + protocol = vector._methods._handler_of(object_a, object_b) + assert protocol == object_a