Skip to content

Commit

Permalink
[pre-commit.ci] pre-commit autoupdate (#46)
Browse files Browse the repository at this point in the history
<!--pre-commit.ci start-->
updates:
- [github.com/psf/black: 23.12.1 →
24.2.0](psf/black@23.12.1...24.2.0)
<!--pre-commit.ci end-->

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
pre-commit-ci[bot] authored Feb 14, 2024
1 parent a8e44b7 commit 6e35b6e
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 13 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ repos:
args: ["--profile", "black", "--filter-files"]

- repo: https://github.com/psf/black
rev: 23.12.1
rev: 24.2.0
hooks:
- id: black
stages: [commit]
Expand Down
8 changes: 2 additions & 6 deletions test/test_grad/test_param.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,7 @@
tol = 1e-8


def gradchecker(
dtype: torch.dtype, name: str
) -> tuple[
def gradchecker(dtype: torch.dtype, name: str) -> tuple[
Callable[[Tensor, Tensor, Tensor, Tensor], Tensor], # autograd function
tuple[Tensor, Tensor, Tensor, Tensor], # differentiable variables
]:
Expand Down Expand Up @@ -85,9 +83,7 @@ def test_gradgradcheck(dtype: torch.dtype, name: str) -> None:
assert dgradgradcheck(func, diffvars, atol=tol, fast_mode=FAST_MODE)


def gradchecker_batch(
dtype: torch.dtype, name1: str, name2: str
) -> tuple[
def gradchecker_batch(dtype: torch.dtype, name1: str, name2: str) -> tuple[
Callable[[Tensor, Tensor, Tensor, Tensor], Tensor], # autograd function
tuple[Tensor, Tensor, Tensor, Tensor], # differentiable variables
]:
Expand Down
8 changes: 2 additions & 6 deletions test/test_grad/test_pos.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,7 @@
tol = 1e-8


def gradchecker(
dtype: torch.dtype, name: str
) -> tuple[
def gradchecker(dtype: torch.dtype, name: str) -> tuple[
Callable[[Tensor], Tensor], # autograd function
Tensor, # differentiable variables
]:
Expand Down Expand Up @@ -86,9 +84,7 @@ def test_gradgradcheck(dtype: torch.dtype, name: str) -> None:
assert dgradgradcheck(func, diffvars, atol=tol, fast_mode=FAST_MODE)


def gradchecker_batch(
dtype: torch.dtype, name1: str, name2: str
) -> tuple[
def gradchecker_batch(dtype: torch.dtype, name1: str, name2: str) -> tuple[
Callable[[Tensor], Tensor], # autograd function
Tensor, # differentiable variables
]:
Expand Down

0 comments on commit 6e35b6e

Please sign in to comment.