diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6bef0c8..9d4d616 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -15,7 +15,14 @@ repos: rev: v2.3.0 hooks: - id: setup-cfg-fmt - args: [--include-version-classifiers, --max-py-version, "3.11"] + args: + [ + --include-version-classifiers, + --min-py-version, + "3.8", + --max-py-version, + "3.11", + ] - repo: https://github.com/asottile/pyupgrade rev: v3.7.0 diff --git a/tests/utils.py b/tests/utils.py index 8a9906a..26bb3ec 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -23,10 +23,12 @@ Any, Callable, Dict, + Optional, Protocol, Size, Tensor, TensorOrTensors, + Union, ) from .conftest import FAST_MODE @@ -112,7 +114,7 @@ class _GradcheckFunction(Protocol): Type annotation for gradcheck function. """ - def __call__( + def __call__( # type: ignore self, func: Callable[..., TensorOrTensors], inputs: TensorOrTensors, @@ -139,11 +141,11 @@ class _GradgradcheckFunction(Protocol): Type annotation for gradgradcheck function. """ - def __call__( + def __call__( # type: ignore self, func: Callable[..., TensorOrTensors], inputs: TensorOrTensors, - grad_outputs: TensorOrTensors | None = None, + grad_outputs: Optional[TensorOrTensors] = None, *, eps: float = 1e-6, atol: float = 1e-5, @@ -162,7 +164,7 @@ def __call__( def _wrap_gradcheck( - gradcheck_func: _GradcheckFunction | _GradgradcheckFunction, + gradcheck_func: Union[_GradcheckFunction, _GradgradcheckFunction], func: Callable[..., TensorOrTensors], diffvars: TensorOrTensors, **kwargs: Any, diff --git a/tox.ini b/tox.ini index a62d990..f01a4cd 100644 --- a/tox.ini +++ b/tox.ini @@ -23,6 +23,6 @@ deps = pytest-random-order commands = coverage erase - coverage run -m pytest -svv {posargs:--random-order-bucket=global tests} + coverage run -m pytest -vv {posargs:--random-order-bucket=global tests} coverage report -m coverage xml -o coverage.xml