From 5c615651f93e4a7397acdfe3f732447b04224fdd Mon Sep 17 00:00:00 2001 From: marvinfriede <51965259+marvinfriede@users.noreply.github.com> Date: Tue, 25 Jul 2023 12:49:42 +0200 Subject: [PATCH] Fix types --- .pre-commit-config.yaml | 9 ++++++++- tests/utils.py | 10 ++++++---- tox.ini | 2 +- 3 files changed, 15 insertions(+), 6 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6bef0c8..9d4d616 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -15,7 +15,14 @@ repos: rev: v2.3.0 hooks: - id: setup-cfg-fmt - args: [--include-version-classifiers, --max-py-version, "3.11"] + args: + [ + --include-version-classifiers, + --min-py-version, + "3.8", + --max-py-version, + "3.11", + ] - repo: https://github.com/asottile/pyupgrade rev: v3.7.0 diff --git a/tests/utils.py b/tests/utils.py index 8a9906a..26bb3ec 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -23,10 +23,12 @@ Any, Callable, Dict, + Optional, Protocol, Size, Tensor, TensorOrTensors, + Union, ) from .conftest import FAST_MODE @@ -112,7 +114,7 @@ class _GradcheckFunction(Protocol): Type annotation for gradcheck function. """ - def __call__( + def __call__( # type: ignore self, func: Callable[..., TensorOrTensors], inputs: TensorOrTensors, @@ -139,11 +141,11 @@ class _GradgradcheckFunction(Protocol): Type annotation for gradgradcheck function. """ - def __call__( + def __call__( # type: ignore self, func: Callable[..., TensorOrTensors], inputs: TensorOrTensors, - grad_outputs: TensorOrTensors | None = None, + grad_outputs: Optional[TensorOrTensors] = None, *, eps: float = 1e-6, atol: float = 1e-5, @@ -162,7 +164,7 @@ def __call__( def _wrap_gradcheck( - gradcheck_func: _GradcheckFunction | _GradgradcheckFunction, + gradcheck_func: Union[_GradcheckFunction, _GradgradcheckFunction], func: Callable[..., TensorOrTensors], diffvars: TensorOrTensors, **kwargs: Any, diff --git a/tox.ini b/tox.ini index a62d990..f01a4cd 100644 --- a/tox.ini +++ b/tox.ini @@ -23,6 +23,6 @@ deps = pytest-random-order commands = coverage erase - coverage run -m pytest -svv {posargs:--random-order-bucket=global tests} + coverage run -m pytest -vv {posargs:--random-order-bucket=global tests} coverage report -m coverage xml -o coverage.xml