Skip to content

Commit

Permalink
Fix types
Browse files Browse the repository at this point in the history
  • Loading branch information
marvinfriede committed Jul 25, 2023
1 parent 8a403f1 commit 5c61565
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 6 deletions.
9 changes: 8 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,14 @@ repos:
rev: v2.3.0
hooks:
- id: setup-cfg-fmt
args: [--include-version-classifiers, --max-py-version, "3.11"]
args:
[
--include-version-classifiers,
--min-py-version,
"3.8",
--max-py-version,
"3.11",
]

- repo: https://github.com/asottile/pyupgrade
rev: v3.7.0
Expand Down
10 changes: 6 additions & 4 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,12 @@
Any,
Callable,
Dict,
Optional,
Protocol,
Size,
Tensor,
TensorOrTensors,
Union,
)

from .conftest import FAST_MODE
Expand Down Expand Up @@ -112,7 +114,7 @@ class _GradcheckFunction(Protocol):
Type annotation for gradcheck function.
"""

def __call__(
def __call__( # type: ignore
self,
func: Callable[..., TensorOrTensors],
inputs: TensorOrTensors,
Expand All @@ -139,11 +141,11 @@ class _GradgradcheckFunction(Protocol):
Type annotation for gradgradcheck function.
"""

def __call__(
def __call__( # type: ignore
self,
func: Callable[..., TensorOrTensors],
inputs: TensorOrTensors,
grad_outputs: TensorOrTensors | None = None,
grad_outputs: Optional[TensorOrTensors] = None,
*,
eps: float = 1e-6,
atol: float = 1e-5,
Expand All @@ -162,7 +164,7 @@ def __call__(


def _wrap_gradcheck(
gradcheck_func: _GradcheckFunction | _GradgradcheckFunction,
gradcheck_func: Union[_GradcheckFunction, _GradgradcheckFunction],
func: Callable[..., TensorOrTensors],
diffvars: TensorOrTensors,
**kwargs: Any,
Expand Down
2 changes: 1 addition & 1 deletion tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,6 @@ deps =
pytest-random-order
commands =
coverage erase
coverage run -m pytest -svv {posargs:--random-order-bucket=global tests}
coverage run -m pytest -vv {posargs:--random-order-bucket=global tests}
coverage report -m
coverage xml -o coverage.xml

0 comments on commit 5c61565

Please sign in to comment.