
Commit: Improve teardown of gradchecks
marvinfriede committed Jul 25, 2023
1 parent fe01d91 commit 8a403f1
Showing 5 changed files with 163 additions and 13 deletions.
1 change: 1 addition & 0 deletions src/tad_dftd3/_typing.py
@@ -24,6 +24,7 @@
List,
NoReturn,
Optional,
Protocol,
Tuple,
TypedDict,
Union,
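The one-line change above adds `Protocol` to the shared typing module; the new `_GradcheckFunction` and `_GradgradcheckFunction` protocols in `tests/utils.py` further down use it to describe the two gradcheck variants structurally. A minimal sketch of the idea, with made-up names (`_Check`, `run`, `close_to_zero`) that are not taken from the repository:

from typing import Protocol


class _Check(Protocol):
    # any callable with a matching signature satisfies the protocol,
    # no explicit inheritance required
    def __call__(self, x: float, *, tol: float = 1e-6) -> bool:
        ...


def run(check: _Check) -> bool:
    return check(0.5, tol=1e-3)


def close_to_zero(x: float, *, tol: float = 1e-6) -> bool:
    return abs(x) < tol


assert run(close_to_zero) is False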
25 changes: 23 additions & 2 deletions tests/conftest.py
@@ -25,6 +25,9 @@

torch.set_printoptions(precision=10)

FAST_MODE: bool = True
"""Flag for fast gradient tests."""


def pytest_addoption(parser: pytest.Parser) -> None:
"""Set up additional command line options."""
@@ -42,6 +45,18 @@ def pytest_addoption(parser: pytest.Parser) -> None:
help="Enable JIT during tests (default = False).",
)

parser.addoption(
"--fast",
action="store_true",
help="Use `fast_mode` for gradient checks (default = True).",
)

parser.addoption(
"--slow",
action="store_true",
help="Do *not* use `fast_mode` for gradient checks (default = False).",
)

parser.addoption(
"--tpo-linewidth",
action="store",
@@ -83,9 +98,15 @@ def pytest_configure(config: pytest.Config) -> None:
torch.autograd.anomaly_mode.set_detect_anomaly(True)

if config.getoption("--jit"):
torch.jit._state.enable() # type: ignore
torch.jit._state.enable() # type: ignore # pylint: disable=protected-access
else:
torch.jit._state.disable() # type: ignore
torch.jit._state.disable() # type: ignore # pylint: disable=protected-access

global FAST_MODE
if config.getoption("--fast"):
FAST_MODE = True
if config.getoption("--slow"):
FAST_MODE = False

if config.getoption("--tpo-linewidth"):
torch.set_printoptions(linewidth=config.getoption("--tpo-linewidth"))
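For orientation (not part of the diff): the new `--fast`/`--slow` options only set the module-level `FAST_MODE` flag, which `tests/utils.py` later reads as the default `fast_mode` for gradient checks. Because `--slow` is checked after `--fast` in `pytest_configure`, it takes precedence when both are given, e.g. `pytest --slow tests/test_grad/`. A standalone sketch of that resolution, using the hypothetical helper name `resolve_fast_mode`:

def resolve_fast_mode(fast: bool, slow: bool, default: bool = True) -> bool:
    # mirrors the order of the checks in pytest_configure above
    fast_mode = default
    if fast:
        fast_mode = True
    if slow:  # checked last, so --slow wins over --fast
        fast_mode = False
    return fast_mode


assert resolve_fast_mode(fast=False, slow=False) is True  # default
assert resolve_fast_mode(fast=True, slow=True) is False   # --slow wins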
10 changes: 5 additions & 5 deletions tests/test_grad/test_param.py
@@ -19,12 +19,12 @@

import pytest
import torch
from torch.autograd.gradcheck import gradcheck, gradgradcheck

from tad_dftd3 import dftd3, util
from tad_dftd3._typing import Callable, Tensor, Tuple

from ..samples import samples
from ..utils import dgradcheck, dgradgradcheck

sample_list = ["LiH", "SiH4", "MB16_43_01"]

@@ -66,7 +66,7 @@ def test_gradcheck(dtype: torch.dtype, name: str) -> None:
gradient from `torch.autograd.gradcheck`.
"""
func, diffvars = gradchecker(dtype, name)
assert gradcheck(func, diffvars, atol=tol)
assert dgradcheck(func, diffvars, atol=tol)


@pytest.mark.grad
@@ -78,7 +78,7 @@ def test_gradgradcheck(dtype: torch.dtype, name: str) -> None:
gradient from `torch.autograd.gradgradcheck`.
"""
func, diffvars = gradchecker(dtype, name)
assert gradgradcheck(func, diffvars, atol=tol)
assert dgradgradcheck(func, diffvars, atol=tol)


def gradchecker_batch(
@@ -127,7 +127,7 @@ def test_gradcheck_batch(dtype: torch.dtype, name1: str, name2: str) -> None:
gradient from `torch.autograd.gradcheck`.
"""
func, diffvars = gradchecker_batch(dtype, name1, name2)
assert gradcheck(func, diffvars, atol=tol)
assert dgradcheck(func, diffvars, atol=tol)


@pytest.mark.grad
@@ -140,4 +140,4 @@ def test_gradgradcheck_batch(dtype: torch.dtype, name1: str, name2: str) -> None
gradient from `torch.autograd.gradgradcheck`.
"""
func, diffvars = gradchecker_batch(dtype, name1, name2)
assert gradgradcheck(func, diffvars, atol=tol)
assert dgradgradcheck(func, diffvars, atol=tol)
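The parameter tests above hand a tuple of differentiable tensors to the new wrapper, which exercises the tuple branch of its teardown (every element is detached in the `finally` block). A sketch of that pattern with made-up values and a stand-in objective instead of the real closure from `gradchecker`; it assumes placement next to these test modules so the relative import resolves:

import torch

from ..utils import dgradcheck  # wrapper introduced in this commit

# illustrative values only; gradcheck wants double precision
p1 = torch.tensor(0.4, dtype=torch.double, requires_grad=True)
p2 = torch.tensor(1.5, dtype=torch.double, requires_grad=True)


def energy(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # stand-in objective, not the actual dispersion energy
    return (a * b).sqrt() + b


assert dgradcheck(energy, (p1, p2), atol=1e-7)

# teardown: both elements of the tuple were detached afterwards
assert not p1.requires_grad and not p2.requires_grad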
10 changes: 5 additions & 5 deletions tests/test_grad/test_pos.py
@@ -19,12 +19,12 @@

import pytest
import torch
from torch.autograd.gradcheck import gradcheck, gradgradcheck

from tad_dftd3 import dftd3, util
from tad_dftd3._typing import Callable, Tensor, Tuple

from ..samples import samples
from ..utils import dgradcheck, dgradgradcheck

sample_list = ["LiH", "SiH4", "MB16_43_01"]

@@ -66,7 +66,7 @@ def test_gradcheck(dtype: torch.dtype, name: str) -> None:
gradient from `torch.autograd.gradcheck`.
"""
func, diffvars = gradchecker(dtype, name)
assert gradcheck(func, diffvars, atol=tol)
assert dgradcheck(func, diffvars, atol=tol)

diffvars.detach_()

@@ -80,7 +80,7 @@ def test_gradgradcheck(dtype: torch.dtype, name: str) -> None:
gradient from `torch.autograd.gradgradcheck`.
"""
func, diffvars = gradchecker(dtype, name)
assert gradgradcheck(func, diffvars, atol=tol)
assert dgradgradcheck(func, diffvars, atol=tol)

diffvars.detach_()

@@ -131,7 +131,7 @@ def test_gradcheck_batch(dtype: torch.dtype, name1: str, name2: str) -> None:
gradient from `torch.autograd.gradcheck`.
"""
func, diffvars = gradchecker_batch(dtype, name1, name2)
assert gradcheck(func, diffvars, atol=tol)
assert dgradcheck(func, diffvars, atol=tol)

diffvars.detach_()

@@ -146,7 +146,7 @@ def test_gradgradcheck_batch(dtype: torch.dtype, name1: str, name2: str) -> None
gradient from `torch.autograd.gradgradcheck`.
"""
func, diffvars = gradchecker_batch(dtype, name1, name2)
assert gradgradcheck(func, diffvars, atol=tol)
assert dgradgradcheck(func, diffvars, atol=tol)

diffvars.detach_()

130 changes: 129 additions & 1 deletion tests/utils.py
@@ -17,8 +17,19 @@
"""

import torch
from torch.autograd.gradcheck import gradcheck, gradgradcheck

from tad_dftd3._typing import Dict, Size, Tensor
from tad_dftd3._typing import (
Any,
Callable,
Dict,
Protocol,
Size,
Tensor,
TensorOrTensors,
)

from .conftest import FAST_MODE


def merge_nested_dicts(a: Dict[str, Dict], b: Dict[str, Dict]) -> Dict: # type: ignore[type-arg]
@@ -94,3 +105,120 @@ def reshape_fortran(x: Tensor, shape: Size) -> Tensor:
if len(x.shape) > 0:
x = x.permute(*reversed(range(len(x.shape))))
return x.reshape(*reversed(shape)).permute(*reversed(range(len(shape))))


class _GradcheckFunction(Protocol):
"""
Type annotation for gradcheck function.
"""

def __call__(
self,
func: Callable[..., TensorOrTensors],
inputs: TensorOrTensors,
*,
eps: float = 1e-6,
atol: float = 1e-5,
rtol: float = 1e-3,
raise_exception: bool = True,
check_sparse_nnz: bool = False,
nondet_tol: float = 0.0,
check_undefined_grad: bool = True,
check_grad_dtypes: bool = False,
check_batched_grad: bool = False,
check_batched_forward_grad: bool = False,
check_forward_ad: bool = False,
check_backward_ad: bool = True,
fast_mode: bool = False,
) -> bool:
...

Check notice (Code scanning / CodeQL): "Statement has no effect."


class _GradgradcheckFunction(Protocol):
"""
Type annotation for gradgradcheck function.
"""

def __call__(
self,
func: Callable[..., TensorOrTensors],
inputs: TensorOrTensors,
grad_outputs: TensorOrTensors | None = None,
*,
eps: float = 1e-6,
atol: float = 1e-5,
rtol: float = 1e-3,
gen_non_contig_grad_outputs: bool = False,
raise_exception: bool = True,
nondet_tol: float = 0.0,
check_undefined_grad: bool = True,
check_grad_dtypes: bool = False,
check_batched_grad: bool = False,
check_fwd_over_rev: bool = False,
check_rev_over_rev: bool = True,
fast_mode: bool = False,
) -> bool:
...

Check notice (Code scanning / CodeQL): "Statement has no effect."


def _wrap_gradcheck(
gradcheck_func: _GradcheckFunction | _GradgradcheckFunction,
func: Callable[..., TensorOrTensors],
diffvars: TensorOrTensors,
**kwargs: Any,
) -> bool:
fast_mode = kwargs.pop("fast_mode", FAST_MODE)
try:
assert gradcheck_func(func, diffvars, fast_mode=fast_mode, **kwargs)
finally:
if isinstance(diffvars, Tensor):
diffvars.detach_()
else:
for diffvar in diffvars:
diffvar.detach_()

return True


def dgradcheck(
func: Callable[..., TensorOrTensors], diffvars: TensorOrTensors, **kwargs: Any
) -> bool:
"""
Wrapper for `torch.autograd.gradcheck` that detaches the differentiated
variables after the check.
Parameters
----------
func : Callable[..., TensorOrTensors]
Forward function.
diffvars : TensorOrTensors
Variables w.r.t. which we differentiate.
Returns
-------
bool
Status of check.
"""
return _wrap_gradcheck(gradcheck, func, diffvars, **kwargs)


def dgradgradcheck(
func: Callable[..., TensorOrTensors], diffvars: TensorOrTensors, **kwargs: Any
) -> bool:
"""
Wrapper for `torch.autograd.gradgradcheck` that detaches the differentiated
variables after the check.
Parameters
----------
func : Callable[..., TensorOrTensors]
Forward function.
diffvars : TensorOrTensors
Variables w.r.t. which we differentiate.
Returns
-------
bool
Status of check.
"""
return _wrap_gradcheck(gradgradcheck, func, diffvars, **kwargs)
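
A usage sketch for the two wrappers defined above; the tensor, closure, and tolerances are illustrative, and the snippet assumes it lives inside the `tests` package so the relative import works. It shows kwargs being forwarded, an explicit `fast_mode` overriding the `FAST_MODE` default from `conftest.py`, and the inputs ending up detached because the detach runs in a `finally` block:

import torch

from .utils import dgradcheck, dgradgradcheck

x = torch.rand(6, dtype=torch.double, requires_grad=True)


def f(t: torch.Tensor) -> torch.Tensor:
    return (t.sin() * t).sum()


# an explicit fast_mode overrides the FAST_MODE default from conftest.py
assert dgradcheck(f, x, fast_mode=False)
assert not x.requires_grad  # detached in the finally block

# gradcheck requires an input with requires_grad, so re-enable it first
x.requires_grad_(True)
assert dgradgradcheck(f, x, atol=1e-6)
assert not x.requires_grad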
