From 8a403f1e2a7f48d3c8218dfa936b2c230102b7ab Mon Sep 17 00:00:00 2001
From: marvinfriede <51965259+marvinfriede@users.noreply.github.com>
Date: Tue, 25 Jul 2023 12:30:36 +0200
Subject: [PATCH] Improve teardown of gradchecks

---
 src/tad_dftd3/_typing.py      |   1 +
 tests/conftest.py             |  25 ++++++-
 tests/test_grad/test_param.py |  10 +--
 tests/test_grad/test_pos.py   |  10 +--
 tests/utils.py                | 130 +++++++++++++++++++++++++++++++++-
 5 files changed, 163 insertions(+), 13 deletions(-)

diff --git a/src/tad_dftd3/_typing.py b/src/tad_dftd3/_typing.py
index 3bd5238..c408ea3 100644
--- a/src/tad_dftd3/_typing.py
+++ b/src/tad_dftd3/_typing.py
@@ -24,6 +24,7 @@
     List,
     NoReturn,
     Optional,
+    Protocol,
     Tuple,
     TypedDict,
     Union,
diff --git a/tests/conftest.py b/tests/conftest.py
index 6eb7ddf..3a4f0eb 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -25,6 +25,9 @@
 
 torch.set_printoptions(precision=10)
 
+FAST_MODE: bool = True
+"""Flag for fast gradient tests."""
+
 
 def pytest_addoption(parser: pytest.Parser) -> None:
     """Set up additional command line options."""
@@ -42,6 +45,18 @@ def pytest_addoption(parser: pytest.Parser) -> None:
         help="Enable JIT during tests (default = False).",
     )
 
+    parser.addoption(
+        "--fast",
+        action="store_true",
+        help="Use `fast_mode` for gradient checks (default = True).",
+    )
+
+    parser.addoption(
+        "--slow",
+        action="store_true",
+        help="Do *not* use `fast_mode` for gradient checks (default = False).",
+    )
+
     parser.addoption(
         "--tpo-linewidth",
         action="store",
@@ -83,9 +98,15 @@ def pytest_configure(config: pytest.Config) -> None:
         torch.autograd.anomaly_mode.set_detect_anomaly(True)
 
     if config.getoption("--jit"):
-        torch.jit._state.enable()  # type: ignore
+        torch.jit._state.enable()  # type: ignore # pylint: disable=protected-access
     else:
-        torch.jit._state.disable()  # type: ignore
+        torch.jit._state.disable()  # type: ignore # pylint: disable=protected-access
+
+    global FAST_MODE
+    if config.getoption("--fast"):
+        FAST_MODE = True
+    if config.getoption("--slow"):
+        FAST_MODE = False
 
     if config.getoption("--tpo-linewidth"):
         torch.set_printoptions(linewidth=config.getoption("--tpo-linewidth"))
diff --git a/tests/test_grad/test_param.py b/tests/test_grad/test_param.py
index b641d9a..23ea414 100644
--- a/tests/test_grad/test_param.py
+++ b/tests/test_grad/test_param.py
@@ -19,12 +19,12 @@
 
 import pytest
 import torch
-from torch.autograd.gradcheck import gradcheck, gradgradcheck
 
 from tad_dftd3 import dftd3, util
 from tad_dftd3._typing import Callable, Tensor, Tuple
 
 from ..samples import samples
+from ..utils import dgradcheck, dgradgradcheck
 
 sample_list = ["LiH", "SiH4", "MB16_43_01"]
 
@@ -66,7 +66,7 @@ def test_gradcheck(dtype: torch.dtype, name: str) -> None:
     gradient from `torch.autograd.gradcheck`.
     """
     func, diffvars = gradchecker(dtype, name)
-    assert gradcheck(func, diffvars, atol=tol)
+    assert dgradcheck(func, diffvars, atol=tol)
 
 
 @pytest.mark.grad
@@ -78,7 +78,7 @@ def test_gradgradcheck(dtype: torch.dtype, name: str) -> None:
     gradient from `torch.autograd.gradgradcheck`.
     """
     func, diffvars = gradchecker(dtype, name)
-    assert gradgradcheck(func, diffvars, atol=tol)
+    assert dgradgradcheck(func, diffvars, atol=tol)
 
 
 def gradchecker_batch(
@@ -127,7 +127,7 @@ def test_gradcheck_batch(dtype: torch.dtype, name1: str, name2: str) -> None:
     gradient from `torch.autograd.gradcheck`.
     """
     func, diffvars = gradchecker_batch(dtype, name1, name2)
-    assert gradcheck(func, diffvars, atol=tol)
+    assert dgradcheck(func, diffvars, atol=tol)
 
 
 @pytest.mark.grad
@@ -140,4 +140,4 @@ def test_gradgradcheck_batch(dtype: torch.dtype, name1: str, name2: str) -> None
     gradient from `torch.autograd.gradgradcheck`.
     """
     func, diffvars = gradchecker_batch(dtype, name1, name2)
-    assert gradgradcheck(func, diffvars, atol=tol)
+    assert dgradgradcheck(func, diffvars, atol=tol)
diff --git a/tests/test_grad/test_pos.py b/tests/test_grad/test_pos.py
index 1cc94cf..6ec176f 100644
--- a/tests/test_grad/test_pos.py
+++ b/tests/test_grad/test_pos.py
@@ -19,12 +19,12 @@
 
 import pytest
 import torch
-from torch.autograd.gradcheck import gradcheck, gradgradcheck
 
 from tad_dftd3 import dftd3, util
 from tad_dftd3._typing import Callable, Tensor, Tuple
 
 from ..samples import samples
+from ..utils import dgradcheck, dgradgradcheck
 
 sample_list = ["LiH", "SiH4", "MB16_43_01"]
 
@@ -66,7 +66,7 @@ def test_gradcheck(dtype: torch.dtype, name: str) -> None:
     gradient from `torch.autograd.gradcheck`.
     """
     func, diffvars = gradchecker(dtype, name)
-    assert gradcheck(func, diffvars, atol=tol)
+    assert dgradcheck(func, diffvars, atol=tol)
 
     diffvars.detach_()
 
@@ -80,7 +80,7 @@ def test_gradgradcheck(dtype: torch.dtype, name: str) -> None:
     gradient from `torch.autograd.gradgradcheck`.
     """
     func, diffvars = gradchecker(dtype, name)
-    assert gradgradcheck(func, diffvars, atol=tol)
+    assert dgradgradcheck(func, diffvars, atol=tol)
 
     diffvars.detach_()
 
@@ -131,7 +131,7 @@ def test_gradcheck_batch(dtype: torch.dtype, name1: str, name2: str) -> None:
     gradient from `torch.autograd.gradcheck`.
    """
     func, diffvars = gradchecker_batch(dtype, name1, name2)
-    assert gradcheck(func, diffvars, atol=tol)
+    assert dgradcheck(func, diffvars, atol=tol)
 
     diffvars.detach_()
 
@@ -146,7 +146,7 @@ def test_gradgradcheck_batch(dtype: torch.dtype, name1: str, name2: str) -> None
     gradient from `torch.autograd.gradgradcheck`.
     """
     func, diffvars = gradchecker_batch(dtype, name1, name2)
-    assert gradgradcheck(func, diffvars, atol=tol)
+    assert dgradgradcheck(func, diffvars, atol=tol)
 
     diffvars.detach_()
 
diff --git a/tests/utils.py b/tests/utils.py
index 090667d..8a9906a 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -17,8 +17,19 @@
 """
 
 import torch
+from torch.autograd.gradcheck import gradcheck, gradgradcheck
 
-from tad_dftd3._typing import Dict, Size, Tensor
+from tad_dftd3._typing import (
+    Any,
+    Callable,
+    Dict,
+    Protocol,
+    Size,
+    Tensor,
+    TensorOrTensors,
+)
+
+from .conftest import FAST_MODE
 
 
 def merge_nested_dicts(a: Dict[str, Dict], b: Dict[str, Dict]) -> Dict:  # type: ignore[type-arg]
@@ -94,3 +105,120 @@ def reshape_fortran(x: Tensor, shape: Size) -> Tensor:
     if len(x.shape) > 0:
         x = x.permute(*reversed(range(len(x.shape))))
     return x.reshape(*reversed(shape)).permute(*reversed(range(len(shape))))
+
+
+class _GradcheckFunction(Protocol):
+    """
+    Type annotation for gradcheck function.
+    """
+
+    def __call__(
+        self,
+        func: Callable[..., TensorOrTensors],
+        inputs: TensorOrTensors,
+        *,
+        eps: float = 1e-6,
+        atol: float = 1e-5,
+        rtol: float = 1e-3,
+        raise_exception: bool = True,
+        check_sparse_nnz: bool = False,
+        nondet_tol: float = 0.0,
+        check_undefined_grad: bool = True,
+        check_grad_dtypes: bool = False,
+        check_batched_grad: bool = False,
+        check_batched_forward_grad: bool = False,
+        check_forward_ad: bool = False,
+        check_backward_ad: bool = True,
+        fast_mode: bool = False,
+    ) -> bool:
+        ...
+
+
+class _GradgradcheckFunction(Protocol):
+    """
+    Type annotation for gradgradcheck function.
+    """
+
+    def __call__(
+        self,
+        func: Callable[..., TensorOrTensors],
+        inputs: TensorOrTensors,
+        grad_outputs: TensorOrTensors | None = None,
+        *,
+        eps: float = 1e-6,
+        atol: float = 1e-5,
+        rtol: float = 1e-3,
+        gen_non_contig_grad_outputs: bool = False,
+        raise_exception: bool = True,
+        nondet_tol: float = 0.0,
+        check_undefined_grad: bool = True,
+        check_grad_dtypes: bool = False,
+        check_batched_grad: bool = False,
+        check_fwd_over_rev: bool = False,
+        check_rev_over_rev: bool = True,
+        fast_mode: bool = False,
+    ) -> bool:
+        ...
+
+
+def _wrap_gradcheck(
+    gradcheck_func: _GradcheckFunction | _GradgradcheckFunction,
+    func: Callable[..., TensorOrTensors],
+    diffvars: TensorOrTensors,
+    **kwargs: Any,
+) -> bool:
+    fast_mode = kwargs.pop("fast_mode", FAST_MODE)
+    try:
+        assert gradcheck_func(func, diffvars, fast_mode=fast_mode, **kwargs)
+    finally:
+        if isinstance(diffvars, Tensor):
+            diffvars.detach_()
+        else:
+            for diffvar in diffvars:
+                diffvar.detach_()
+
+    return True
+
+
+def dgradcheck(
+    func: Callable[..., TensorOrTensors], diffvars: TensorOrTensors, **kwargs: Any
+) -> bool:
+    """
+    Wrapper for `torch.autograd.gradcheck` that detaches the differentiated
+    variables after the check.
+
+    Parameters
+    ----------
+    func : Callable[..., TensorOrTensors]
+        Forward function.
+    diffvars : TensorOrTensors
+        Variables w.r.t. which we differentiate.
+
+    Returns
+    -------
+    bool
+        Status of check.
+    """
+    return _wrap_gradcheck(gradcheck, func, diffvars, **kwargs)
+
+
+def dgradgradcheck(
+    func: Callable[..., TensorOrTensors], diffvars: TensorOrTensors, **kwargs: Any
+) -> bool:
+    """
+    Wrapper for `torch.autograd.gradgradcheck` that detaches the differentiated
+    variables after the check.
+
+    Parameters
+    ----------
+    func : Callable[..., TensorOrTensors]
+        Forward function.
+    diffvars : TensorOrTensors
+        Variables w.r.t. which we differentiate.
+
+    Returns
+    -------
+    bool
+        Status of check.
+    """
+    return _wrap_gradcheck(gradgradcheck, func, diffvars, **kwargs)
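A note on the new helpers (not part of the patch itself): the core of this change is `_wrap_gradcheck` in tests/utils.py, which runs the check inside try/finally and always detaches the differentiated tensors, so a failing gradcheck can no longer leak tensors with live autograd graphs into subsequent tests. The `fast_mode` default comes from `FAST_MODE` in tests/conftest.py and is toggled on the pytest command line via `--fast`/`--slow`. Below is a minimal standalone sketch of that teardown pattern; the quadratic `energy` function and the tolerance are illustrative placeholders, not taken from the patch.

import torch
from torch.autograd.gradcheck import gradcheck


def energy(pos: torch.Tensor) -> torch.Tensor:
    # illustrative stand-in for the dispersion energy: any differentiable scalar
    return (pos * pos).sum()


# gradcheck needs double precision and requires_grad=True inputs
pos = torch.randn(4, 3, dtype=torch.double, requires_grad=True)

try:
    # fast_mode mirrors the FAST_MODE flag controlled by --fast/--slow
    assert gradcheck(energy, (pos,), atol=1e-7, fast_mode=True)
finally:
    # teardown: drop the autograd graph even if the check above failed
    pos.detach_()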