From 346af9a7069c5564b7c55a2bd16002ef39ecc6bd Mon Sep 17 00:00:00 2001 From: marvinfriede <51965259+marvinfriede@users.noreply.github.com> Date: Sun, 26 Nov 2023 17:18:55 +0100 Subject: [PATCH] Refactor and GPU tests --- src/tad_dftd4/__version__.py | 1 - src/tad_dftd4/_typing.py | 34 ++++- src/tad_dftd4/charges.py | 34 +++-- src/tad_dftd4/damping/atm.py | 58 +++++--- src/tad_dftd4/defaults.py | 14 +- src/tad_dftd4/disp.py | 20 +-- src/tad_dftd4/ncoord/__init__.py | 7 +- src/tad_dftd4/ncoord/count.py | 3 +- src/tad_dftd4/ncoord/d4.py | 4 +- src/tad_dftd4/ncoord/eeq.py | 4 +- src/tad_dftd4/utils/__init__.py | 26 ++++ src/tad_dftd4/utils/distance.py | 151 +++++++++++++++++++ src/tad_dftd4/utils/grad.py | 156 ++++++++++++++++++++ src/tad_dftd4/{utils.py => utils/misc.py} | 56 ++++--- test/conftest.py | 125 +++++++++++++++- test/molecules.py | 24 --- test/test_charge/samples.py | 3 +- test/test_charge/test_charges.py | 93 ++++++------ test/test_charge/test_general.py | 6 +- test/test_cutoff/test_types.py | 18 +-- test/test_disp/samples.py | 3 +- test/test_disp/test_atm.py | 68 ++++----- test/test_disp/test_full.py | 62 ++++---- test/test_disp/test_grad.py | 45 +++--- test/test_disp/test_twobody.py | 116 ++++++++------- test/test_model/samples.py | 4 +- test/test_model/test_c6.py | 32 ++-- test/test_model/test_model.py | 44 +++--- test/test_model/test_weights.py | 47 +++--- test/test_ncoord/samples.py | 3 +- test/test_ncoord/test_cn_d4.py | 36 +++-- test/test_ncoord/test_cn_eeq.py | 47 +++--- test/test_ncoord/test_general.py | 6 +- test/test_ncoord/test_grad.py | 16 +- test/test_utils/test_cdist.py | 102 +++++++++++++ test/test_utils/test_real.py | 99 +++++++++++-- test/utils.py | 171 +++++++++++++++++++++- 37 files changed, 1305 insertions(+), 433 deletions(-) create mode 100644 src/tad_dftd4/utils/__init__.py create mode 100644 src/tad_dftd4/utils/distance.py create mode 100644 src/tad_dftd4/utils/grad.py rename src/tad_dftd4/{utils.py => utils/misc.py} (68%) create mode 100644 test/test_utils/test_cdist.py diff --git a/src/tad_dftd4/__version__.py b/src/tad_dftd4/__version__.py index e150122..db0c3d8 100644 --- a/src/tad_dftd4/__version__.py +++ b/src/tad_dftd4/__version__.py @@ -18,5 +18,4 @@ """ Module containing the version string. 
""" - __version__ = "0.0.4" diff --git a/src/tad_dftd4/_typing.py b/src/tad_dftd4/_typing.py index c3cc506..4fb4d69 100644 --- a/src/tad_dftd4/_typing.py +++ b/src/tad_dftd4/_typing.py @@ -84,11 +84,11 @@ elif sys.version_info >= (3, 8): # in Python 3.8, "from __future__ import annotations" only affects # type annotations not type aliases - from typing import Dict, List, Tuple, Union + from typing import Dict, List, Union, tuple - Sliceable = Union[List[Tensor], Tuple[Tensor, ...]] - Size = Union[List[int], Tuple[int], torch.Size] - TensorOrTensors = Union[List[Tensor], Tuple[Tensor, ...], Tensor] + Sliceable = Union[List[Tensor], tuple[Tensor, ...]] + Size = Union[List[int], tuple[int], torch.Size] + TensorOrTensors = Union[List[Tensor], tuple[Tensor, ...], Tensor] DampingFunction = Callable[[int, Tensor, Tensor, Dict[str, Tensor]], Tensor] else: raise RuntimeError( @@ -107,6 +107,26 @@ class Molecule(TypedDict): """Tensor of 3D coordinates of shape (n, 3)""" +class DD(TypedDict): + """Collection of torch.device and torch.dtype.""" + + device: torch.device | None + """Device on which a tensor lives.""" + + dtype: torch.dtype + """Floating point precision of a tensor.""" + + +def get_default_device() -> torch.device: + """Default device for tensors.""" + return torch.tensor(1.0).device + + +def get_default_dtype() -> torch.dtype: + """Default data type for floating point tensors.""" + return torch.tensor(1.0).dtype + + class TensorLike: """ Provide `device` and `dtype` as well as `to()` and `type()` for other @@ -118,10 +138,8 @@ class TensorLike: def __init__( self, device: torch.device | None = None, dtype: torch.dtype | None = None ): - self.__device = ( - device if device is not None else torch.device(defaults.TORCH_DEVICE) - ) - self.__dtype = dtype if dtype is not None else defaults.TORCH_DTYPE + self.__device = device if device is not None else get_default_device() + self.__dtype = dtype if dtype is not None else get_default_dtype() @property def device(self) -> torch.device: diff --git a/src/tad_dftd4/charges.py b/src/tad_dftd4/charges.py index b274b87..7f8e2e5 100644 --- a/src/tad_dftd4/charges.py +++ b/src/tad_dftd4/charges.py @@ -52,8 +52,8 @@ import torch -from ._typing import Tensor, TensorLike -from .ncoord import get_coordination_number_eeq +from ._typing import DD, Tensor, TensorLike +from .ncoord import coordination_number_eeq from .utils import real_atoms, real_pairs __all__ = ["ChargeModel", "solve", "get_charges"] @@ -93,6 +93,8 @@ def __init__( self.eta = eta self.rad = rad + print(self.device) + print(self.chi.device) if any( tensor.device != self.device for tensor in (self.chi, self.kcn, self.eta, self.rad) @@ -106,20 +108,28 @@ def __init__( raise RuntimeError("All tensors must have the same dtype!") @classmethod - def param2019(cls) -> ChargeModel: + def param2019( + cls, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ) -> ChargeModel: """ Electronegativity equilibration charge model published in - - E. Caldeweyher, S. Ehlert, A. Hansen, H. Neugebauer, S. Spicher, C. Bannwarth - and S. Grimme, *J. Chem. Phys.*, **2019**, 150, 154122. + - E. Caldeweyher, S. Ehlert, A. Hansen, H. Neugebauer, S. Spicher, + C. Bannwarth and S. Grimme, *J. Chem. Phys.*, **2019**, 150, 154122. 
DOI: `10.1063/1.5090222 `__ """ + dd: dict = {"device": device} + if dtype is not None: + dd["dtype"] = dtype return cls( - _chi2019, - _kcn2019, - _eta2019, - _rad2019, + _chi2019.to(**dd), + _kcn2019.to(**dd), + _eta2019.to(**dd), + _rad2019.to(**dd), + **dd, ) @@ -247,7 +257,7 @@ def solve( Returns ------- (Tensor, Tensor) - Tuple of electrostatic energies and partial charges. + tuple of electrostatic energies and partial charges. Example ------- @@ -371,8 +381,8 @@ def get_charges( Tensor Atomic charges. """ - eeq = ChargeModel.param2019().to(positions.device).type(positions.dtype) - cn = get_coordination_number_eeq(numbers, positions, cutoff=cutoff) + eeq = ChargeModel.param2019(device=positions.device, dtype=positions.dtype) + cn = coordination_number_eeq(numbers, positions, cutoff=cutoff) _, qat = solve(numbers, positions, chrg, eeq, cn) return qat diff --git a/src/tad_dftd4/damping/atm.py b/src/tad_dftd4/damping/atm.py index 353dc1c..f5367ca 100644 --- a/src/tad_dftd4/damping/atm.py +++ b/src/tad_dftd4/damping/atm.py @@ -40,7 +40,7 @@ from .. import defaults from .._typing import Tensor from ..data import r4r2 -from ..utils import real_pairs, real_triples +from ..utils import cdist, real_pairs, real_triples def get_atm_dispersion( @@ -89,15 +89,24 @@ def get_atm_dispersion( cutoff2 = cutoff * cutoff + mask_pairs = real_pairs(numbers, diagonal=False) + mask_triples = real_triples(numbers, diagonal=False, self=False) + + # filler values for masks + eps = torch.tensor(torch.finfo(positions.dtype).eps, **dd) + zero = torch.tensor(0.0, **dd) + one = torch.tensor(1.0, **dd) + # C9_ABC = s9 * sqrt(|C6_AB * C6_AC * C6_BC|) c9 = s9 * torch.sqrt( - torch.abs(c6.unsqueeze(-1) * c6.unsqueeze(-2) * c6.unsqueeze(-3)) + torch.clamp( + torch.abs(c6.unsqueeze(-1) * c6.unsqueeze(-2) * c6.unsqueeze(-3)), min=eps + ) ) - temp = ( - a1 * torch.sqrt(3.0 * r4r2[numbers].unsqueeze(-1) * r4r2[numbers].unsqueeze(-2)) - + a2 - ) + radii = r4r2[numbers].unsqueeze(-1) * r4r2[numbers].unsqueeze(-2) + temp = a1 * torch.sqrt(3.0 * radii) + a2 + r0ij = temp.unsqueeze(-1) r0ik = temp.unsqueeze(-2) r0jk = temp.unsqueeze(-3) @@ -107,11 +116,9 @@ def get_atm_dispersion( # very slow: (pos.unsqueeze(-2) - pos.unsqueeze(-3)).pow(2).sum(-1) distances = torch.pow( torch.where( - real_pairs(numbers, diagonal=False), - torch.cdist( - positions, positions, p=2, compute_mode="use_mm_for_euclid_dist" - ), - torch.tensor(torch.finfo(positions.dtype).eps, **dd), + mask_pairs, + cdist(positions, positions, p=2), + eps, ), 2.0, ) @@ -121,17 +128,30 @@ def get_atm_dispersion( r2jk = distances.unsqueeze(-3) r2 = r2ij * r2ik * r2jk r1 = torch.sqrt(r2) - r3 = r1 * r2 - r5 = r2 * r3 + # add epsilon to avoid zero division later + r3 = torch.where(mask_triples, r1 * r2, eps) + r5 = torch.where(mask_triples, r2 * r3, eps) + + # dividing by tiny numbers leads to huge numbers, which result in NaN's + # upon exponentiation in the subsequent step + base = r0 / torch.where(mask_triples, r1, one) + + # to fix the previous mask, we mask again (not strictly necessary because + # `ang` is also masked and we later multiply with `ang`) + fdamp = torch.where( + mask_triples, + 1.0 / (1.0 + 6.0 * base ** (alp / 3.0)), + zero, + ) - fdamp = 1.0 / (1.0 + 6.0 * (r0 / r1) ** (alp / 3.0)) + s = torch.where( + mask_triples, + (r2ij + r2jk - r2ik) * (r2ij - r2jk + r2ik) * (-r2ij + r2jk + r2ik), + zero, + ) - s = (r2ij + r2jk - r2ik) * (r2ij - r2jk + r2ik) * (-r2ij + r2jk + r2ik) ang = torch.where( - real_triples(numbers, diagonal=False) - * (r2ij <= cutoff2) 
- * (r2jk <= cutoff2) - * (r2jk <= cutoff2), + mask_triples * (r2ij <= cutoff2) * (r2jk <= cutoff2) * (r2jk <= cutoff2), 0.375 * s / r5 + 1.0 / r3, torch.tensor(0.0, **dd), ) diff --git a/src/tad_dftd4/defaults.py b/src/tad_dftd4/defaults.py index 14e37c3..49f45fc 100644 --- a/src/tad_dftd4/defaults.py +++ b/src/tad_dftd4/defaults.py @@ -7,18 +7,6 @@ import torch -# PyTorch - -TORCH_DTYPE = torch.float32 -"""Default data type for floating point tensors.""" - -TORCH_DTYPE_CHOICES = ["float32", "float64", "double", "sp", "dp"] -"""List of possible choices for `TORCH_DTYPE`.""" - -TORCH_DEVICE = "cpu" -"""Default device for tensors.""" - - # DFT-D4 D4_CN_CUTOFF = 30.0 @@ -48,7 +36,7 @@ D4_K6 = 2 * 11.28174**2 # 254.56 """Parameter for electronegativity scaling.""" -# DFT-D4 damping +# DFT-D4 damping parameters A1 = 0.4 """Scaling for the C8 / C6 ratio in the critical radius (0.4).""" diff --git a/src/tad_dftd4/disp.py b/src/tad_dftd4/disp.py index 46f8494..980880d 100644 --- a/src/tad_dftd4/disp.py +++ b/src/tad_dftd4/disp.py @@ -29,13 +29,13 @@ import torch from . import data, defaults -from ._typing import Any, CountingFunction, DampingFunction, Tensor +from ._typing import DD, Any, CountingFunction, DampingFunction, Tensor from .charges import get_charges from .cutoff import Cutoff from .damping import get_atm_dispersion, rational_damping from .model import D4Model -from .ncoord import erf_count, get_coordination_number_d4 -from .utils import real_pairs +from .ncoord import coordination_number_d4, erf_count +from .utils import cdist, real_pairs def dftd4( @@ -98,15 +98,17 @@ def dftd4( Shape inconsistencies between `numbers`, `positions`, `r4r2`, or, `rcov`. """ + dd: DD = {"device": positions.device, "dtype": positions.dtype} + if model is None: - model = D4Model(numbers, device=positions.device, dtype=positions.dtype) + model = D4Model(numbers, **dd) if cutoff is None: - cutoff = Cutoff(device=positions.device, dtype=positions.dtype) + cutoff = Cutoff(**dd) if rcov is None: - rcov = data.cov_rad_d3[numbers].type(positions.dtype).to(positions.device) + rcov = data.cov_rad_d3.to(**dd)[numbers] if r4r2 is None: - r4r2 = data.r4r2[numbers].type(positions.dtype).to(positions.device) + r4r2 = data.r4r2.to(**dd)[numbers] if q is None: q = get_charges(numbers, positions, charge, cutoff=cutoff.cn_eeq) @@ -131,7 +133,7 @@ def dftd4( f"atomic numbers ({numbers.shape}).", ) - cn = get_coordination_number_d4( + cn = coordination_number_d4( numbers, positions, counting_function, rcov, cutoff=cutoff.cn ) weights = model.weight_references(cn, q) @@ -204,7 +206,7 @@ def dispersion2( mask = real_pairs(numbers, diagonal=False) distances = torch.where( mask, - torch.cdist(positions, positions, p=2, compute_mode="use_mm_for_euclid_dist"), + cdist(positions, positions, p=2), torch.tensor(torch.finfo(positions.dtype).eps, **dd), ) diff --git a/src/tad_dftd4/ncoord/__init__.py b/src/tad_dftd4/ncoord/__init__.py index c5bcdd7..a79afa2 100644 --- a/src/tad_dftd4/ncoord/__init__.py +++ b/src/tad_dftd4/ncoord/__init__.py @@ -59,13 +59,12 @@ ... 
)) >>> >>> torch.set_printoptions(precision=7) ->>> print(d4.get_coordination_number_d4(numbers, positions)) +>>> print(d4.coordination_number_d4(numbers, positions)) tensor([[2.6886456, 2.6886456, 2.6314170, 2.6314168, 0.8594539, 0.9231414, 0.8605307, 0.8605307, 0.8594539, 0.9231414, 0.8568342, 0.8568342], [2.6886456, 0.8568335, 2.6314168, 0.8605307, 0.8594532, 0.9231415, 0.0000000, 0.0000000, 0.0000000, 0.0000000, 0.0000000, 0.0000000]]) """ - from .count import derf_count, dexp_count, erf_count, exp_count -from .d4 import get_coordination_number_d4 -from .eeq import get_coordination_number_eeq +from .d4 import coordination_number_d4 +from .eeq import coordination_number_eeq diff --git a/src/tad_dftd4/ncoord/count.py b/src/tad_dftd4/ncoord/count.py index fce1b39..5f40acd 100644 --- a/src/tad_dftd4/ncoord/count.py +++ b/src/tad_dftd4/ncoord/count.py @@ -27,7 +27,6 @@ Additionally, the analytical derivatives for both counting functions are also provided and can be used for checking the autograd results. """ - from math import pi, sqrt import torch @@ -35,6 +34,8 @@ from .. import defaults from .._typing import Tensor +__all__ = ["exp_count", "dexp_count", "erf_count", "derf_count"] + def exp_count(r: Tensor, r0: Tensor, kcn: float = defaults.D4_KCN) -> Tensor: """ diff --git a/src/tad_dftd4/ncoord/d4.py b/src/tad_dftd4/ncoord/d4.py index a186413..c96d3ce 100644 --- a/src/tad_dftd4/ncoord/d4.py +++ b/src/tad_dftd4/ncoord/d4.py @@ -32,10 +32,10 @@ from ..utils import real_pairs from .count import erf_count -__all__ = ["get_coordination_number_d4"] +__all__ = ["coordination_number_d4"] -def get_coordination_number_d4( +def coordination_number_d4( numbers: Tensor, positions: Tensor, counting_function: CountingFunction = erf_count, diff --git a/src/tad_dftd4/ncoord/eeq.py b/src/tad_dftd4/ncoord/eeq.py index d6b8fac..4b75d57 100644 --- a/src/tad_dftd4/ncoord/eeq.py +++ b/src/tad_dftd4/ncoord/eeq.py @@ -31,10 +31,10 @@ from ..utils import real_pairs from .count import erf_count -__all__ = ["get_coordination_number_eeq"] +__all__ = ["coordination_number_eeq"] -def get_coordination_number_eeq( +def coordination_number_eeq( numbers: Tensor, positions: Tensor, counting_function: CountingFunction = erf_count, diff --git a/src/tad_dftd4/utils/__init__.py b/src/tad_dftd4/utils/__init__.py new file mode 100644 index 0000000..45f0462 --- /dev/null +++ b/src/tad_dftd4/utils/__init__.py @@ -0,0 +1,26 @@ +# This file is part of tad-dftd4. +# +# SPDX-Identifier: LGPL-3.0 +# Copyright (C) 2022 Marvin Friede +# +# tad-dftd4 is free software: you can redistribute it and/or modify it under +# the terms of the GNU Lesser General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# tad-dftd4 is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with tad-dftd4. If not, see . +""" +Utility +======= + +Collection of utility functions. +""" +from .distance import * +from .grad import * +from .misc import * diff --git a/src/tad_dftd4/utils/distance.py b/src/tad_dftd4/utils/distance.py new file mode 100644 index 0000000..a983e0f --- /dev/null +++ b/src/tad_dftd4/utils/distance.py @@ -0,0 +1,151 @@ +# This file is part of tad-dftd4. 
+# +# SPDX-Identifier: LGPL-3.0 +# Copyright (C) 2022 Marvin Friede +# +# tad-dftd4 is free software: you can redistribute it and/or modify it under +# the terms of the GNU Lesser General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# tad-dftd4 is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with tad-dftd4. If not, see . +""" +Utility functions: Distance +=========================== + +Functions for calculating the cartesian distance of two vectors. +""" +import torch + +from .._typing import Tensor + +__all__ = ["cdist"] + + +def euclidean_dist_quadratic_expansion(x: Tensor, y: Tensor) -> Tensor: + """ + Computation of euclidean distance matrix via quadratic expansion (sum of + squared differences or L2-norm of differences). + + While this is significantly faster than the "direct expansion" or + "broadcast" approach, it only works for euclidean (p=2) distances. + Additionally, it has issues with numerical stability (the diagonal slightly + deviates from zero for ``x=y``). The numerical stability should not pose + problems, since we must remove zeros anyway for batched calculations. + + For more information, see \ + `this Jupyter notebook `__ or \ + `this discussion thread in the PyTorch forum `__. + + Parameters + ---------- + x : Tensor + First tensor. + y : Tensor + Second tensor (with same shape as first tensor). + + Returns + ------- + Tensor + Pair-wise distance matrix. + """ + eps = torch.tensor( + torch.finfo(x.dtype).eps, + device=x.device, + dtype=x.dtype, + ) + + # using einsum is slightly faster than `torch.pow(x, 2).sum(-1)` + xnorm = torch.einsum("...ij,...ij->...i", x, x) + ynorm = torch.einsum("...ij,...ij->...i", y, y) + + n = xnorm.unsqueeze(-1) + ynorm.unsqueeze(-2) + + # x @ y.mT + prod = torch.einsum("...ik,...jk->...ij", x, y) + + # important: remove negative values that give NaN in backward + return torch.sqrt(torch.clamp(n - 2.0 * prod, min=eps)) + + +def cdist_direct_expansion(x: Tensor, y: Tensor, p: int = 2) -> Tensor: + """ + Computation of cartesian distance matrix. + + Contrary to `euclidean_dist_quadratic_expansion`, this function allows + arbitrary powers but is considerably slower. + + Parameters + ---------- + x : Tensor + First tensor. + y : Tensor + Second tensor (with same shape as first tensor). + p : int, optional + Power used in the distance evaluation (p-norm). Defaults to 2. + + Returns + ------- + Tensor + Pair-wise distance matrix. + """ + eps = torch.tensor( + torch.finfo(x.dtype).eps, + device=x.device, + dtype=x.dtype, + ) + + # unsqueeze different dimension to create matrix + diff = torch.abs(x.unsqueeze(-2) - y.unsqueeze(-3)) + + # einsum is nearly twice as fast! + if p == 2: + distances = torch.einsum("...ijk,...ijk->...ij", diff, diff) + else: + distances = torch.sum(torch.pow(diff, p), -1) + + return torch.pow(torch.clamp(distances, min=eps), 1.0 / p) + + +def cdist(x: Tensor, y: Tensor | None = None, p: int = 2) -> Tensor: + """ + Wrapper for cartesian distance computation. + + This currently replaces the use of ``torch.cdist``, which does not handle + zeros well and produces nan's in the backward pass. 
+ + Additionally, ``torch.cdist`` does not return zero for distances between + same vectors (see `here + `__). + + Parameters + ---------- + x : Tensor + First tensor. + y : Tensor | None, optional + Second tensor. If no second tensor is given (default), the first tensor + is used as the second tensor, too. + p : int, optional + Power used in the distance evaluation (p-norm). Defaults to 2. + + Returns + ------- + Tensor + Pair-wise distance matrix. + """ + if y is None: + y = x + + # faster + if p == 2: + return euclidean_dist_quadratic_expansion(x, y) + + return cdist_direct_expansion(x, y, p=p) diff --git a/src/tad_dftd4/utils/grad.py b/src/tad_dftd4/utils/grad.py new file mode 100644 index 0000000..2ebbdeb --- /dev/null +++ b/src/tad_dftd4/utils/grad.py @@ -0,0 +1,156 @@ +# This file is part of tad-dftd4. +# SPDX-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Utility functions: Gradient +=========================== + +Utilities for calculating gradients and Hessians. +""" +import torch + +from .._typing import Any, Callable, Tensor + +__all__ = ["jac", "hessian"] + + +if torch.__version__ < (2, 0, 0): # type: ignore # pragma: no cover + try: + from functorch import jacrev # type: ignore + except ModuleNotFoundError: + jacrev = None + from torch.autograd.functional import jacobian # type: ignore + +else: # pragma: no cover + from torch.func import jacrev # type: ignore + + +def jac(f: Callable[..., Tensor], argnums: int = 0) -> Any: # pragma: no cover + """ + Wrapper for Jacobian calcluation. + + Parameters + ---------- + f : Callable[[Any], Tensor] + The function whose result is differentiated. + argnums : int, optional + The variable w.r.t. which will be differentiated. Defaults to 0. + """ + + if jacrev is None: + + def wrap(*inps: Any) -> Any: + """ + Wrapper to imitate the calling signature of functorch's `jacrev` + with `torch.autograd.functional.jacobian`. + + Parameters + ---------- + inps : tuple[Any, ...] + The input parameters of the function `f`. + + Returns + ------- + Any + Jacobian function. + + Raises + ------ + RuntimeError + The parameter selected for differentiation (via `argnums`) is + not a tensor. + """ + diffarg = inps[argnums] + if not isinstance(diffarg, Tensor): + raise RuntimeError( + f"The {argnums}'th input parameter must be a tensor but is " + f"of type '{type(diffarg)}'." + ) + + before = inps[:argnums] + after = inps[(argnums + 1) :] + + # `jacobian` only takes tensors, requiring another wrapper than + # passes the non-tensor arguments to the function `f` + def _f(arg: Tensor) -> Tensor: + return f(*(*before, arg, *after)) + + return jacobian(_f, inputs=diffarg) # type: ignore # pylint: disable=used-before-assignment + + return wrap + + return jacrev(f, argnums=argnums) # type: ignore + + +def hessian( + f: Callable[..., Tensor], + inputs: tuple[Any, ...], + argnums: int, + is_batched: bool = False, +) -> Tensor: + """ + Wrapper for Hessian. The Hessian is the Jacobian of the gradient. 
+ + PyTorch, however, suggests calculating the Jacobian of the Jacobian, which + does not yield the correct shape in this case. + + Parameters + ---------- + f : Callable[[Any], Tensor] + The function whose result is differentiated. + inputs : tuple[Any, ...] + The input parameters of `f`. + argnums : int, optional + The variable w.r.t. which will be differentiated. Defaults to 0. + + Returns + ------- + Tensor + The Hessian. + + Raises + ------ + RuntimeError + The parameter selected for differentiation (via `argnums`) is not a + tensor. + """ + + def _grad(*inps: tuple[Any, ...]) -> Tensor: + e = f(*inps).sum() + + if not isinstance(inps[argnums], Tensor): # pragma: no cover + raise RuntimeError( + f"The {argnums}'th input parameter must be a tensor but is of " + f"type '{type(inps[argnums])}'." + ) + + # catch missing gradients + if e.grad_fn is None: + return torch.zeros_like(inps[argnums]) # type: ignore + + (g,) = torch.autograd.grad( + e, + inps[argnums], + create_graph=True, + ) + return g + + _jac = jac(_grad, argnums=argnums) + + if is_batched: + raise NotImplementedError("Batched Hessian not available.") + # dims = tuple(None if x != argnums else 0 for x in range(len(inputs))) + # _jac = torch.func.vmap(_jac, in_dims=dims) + + return _jac(*inputs) # type: ignore diff --git a/src/tad_dftd4/utils.py b/src/tad_dftd4/utils/misc.py similarity index 68% rename from src/tad_dftd4/utils.py rename to src/tad_dftd4/utils/misc.py index cf924bf..b36e788 100644 --- a/src/tad_dftd4/utils.py +++ b/src/tad_dftd4/utils/misc.py @@ -16,8 +16,8 @@ # You should have received a copy of the GNU Lesser General Public License # along with tad-dftd4. If not, see . """ -Miscellaneous functions -======================= +Utility functions: Miscellaneous +================================ Utilities for working with tensors as well as translating between element symbols and atomic numbers. @@ -26,43 +26,48 @@ import torch -from ._typing import Size, Tensor, TensorOrTensors -from .constants import ATOMIC_NUMBER +from .._typing import Size, Tensor, TensorOrTensors +from ..constants import ATOMIC_NUMBER + +__all__ = ["real_atoms", "real_pairs", "real_triples", "pack", "to_number"] def real_atoms(numbers: Tensor) -> Tensor: """ - Generates mask that differentiates real atom and padding. + Create a mask for atoms, discerning padding and actual atoms. + Padding value is zero. Parameters ---------- numbers : Tensor - Atomic numbers of the atoms in the system. + Atomic numbers for all atoms. Returns ------- Tensor - Mask for real atoms. + Mask for atoms that discerns padding and real atoms. """ return numbers != 0 def real_pairs(numbers: Tensor, diagonal: bool = False) -> Tensor: """ - Generates mask that differentiates real atom pairs and padding. + Create a mask for pairs of atoms from atomic numbers, discerning padding + and actual atoms. Padding value is zero. Parameters ---------- numbers : Tensor - Atomic numbers of the atoms in the system. + Atomic numbers for all atoms. diagonal : bool, optional - Whether the diagonal should be masked, i.e. filled with `False`. - Defaults to `False`, i.e., `True` remains on the diagonal for real atoms. + Flag for also writing `False` to the diagonal, i.e., to all pairs + with the same indices. Defaults to `False`, i.e., writing False + to the diagonal. Returns ------- Tensor - Mask for real atom pairs. + Mask for atom pairs that discerns padding and real atoms. 
""" real = real_atoms(numbers) mask = real.unsqueeze(-2) * real.unsqueeze(-1) @@ -71,27 +76,40 @@ def real_pairs(numbers: Tensor, diagonal: bool = False) -> Tensor: return mask -def real_triples(numbers: Tensor, diagonal: bool = False) -> Tensor: +def real_triples( + numbers: torch.Tensor, diagonal: bool = False, self: bool = True +) -> Tensor: """ - Generates mask that differentiates real atom triples and padding. + Create a mask for triples from atomic numbers. Padding value is zero. Parameters ---------- - numbers : Tensor - Atomic numbers of the atoms in the system. + numbers : torch.Tensor + Atomic numbers for all atoms. diagonal : bool, optional - Whether the diagonal should be masked, i.e. filled with `False`. - Defaults to `False`, i.e., `True` remains on the diagonal for real atoms. + Flag for also writing `False` to the space diagonal, i.e., to all + triples with the same indices. Defaults to `False`, i.e., writing False + to the diagonal. + self : bool, optional + Flag for also writing `False` to all triples where at least two indices + are identical. Defaults to `True`, i.e., not writing `False`. Returns ------- Tensor - Mask for real atom triples. + Mask for triples. """ real = real_pairs(numbers, diagonal=True) mask = real.unsqueeze(-3) * real.unsqueeze(-2) * real.unsqueeze(-1) + if diagonal is False: mask *= ~torch.diag_embed(torch.ones_like(real)) + + if self is False: + mask *= ~torch.diag_embed(torch.ones_like(real), offset=0, dim1=-3, dim2=-2) + mask *= ~torch.diag_embed(torch.ones_like(real), offset=0, dim1=-3, dim2=-1) + mask *= ~torch.diag_embed(torch.ones_like(real), offset=0, dim1=-2, dim2=-1) + return mask diff --git a/test/conftest.py b/test/conftest.py index 6cca76c..a0f0fa6 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -28,32 +28,143 @@ torch.set_printoptions(precision=10) +FAST_MODE: bool = True +"""Flag for fast gradient tests.""" -def pytest_addoption(parser: pytest.Parser): +DEVICE: torch.device | None = None +"""Name of Device.""" + + +def pytest_addoption(parser: pytest.Parser) -> None: """Set up additional command line options.""" + parser.addoption( + "--cuda", + action="store_true", + help="Use GPU as default device.", + ) + parser.addoption( "--detect-anomaly", + "--da", action="store_true", - help="Enable more comprehensive gradient tests.", + help="Enable PyTorch's debug mode for gradient tests.", ) + parser.addoption( + "--jit", + action="store_true", + help="Enable JIT during tests (default = False).", + ) -def pytest_configure(config: pytest.Config): + parser.addoption( + "--fast", + action="store_true", + help="Use `fast_mode` for gradient checks (default = True).", + ) + + parser.addoption( + "--slow", + action="store_true", + help="Do *not* use `fast_mode` for gradient checks (default = False).", + ) + + parser.addoption( + "--tpo-linewidth", + action="store", + default=400, + type=int, + help=( + "The number of characters per line for the purpose of inserting " + "line breaks (default = 80). Thresholded matrices will ignore " + "this parameter." + ), + ) + + parser.addoption( + "--tpo-precision", + action="store", + default=6, + type=int, + help=( + "Number of digits of precision for floating point output " "(default = 4)." + ), + ) + + parser.addoption( + "--tpo-threshold", + action="store", + default=1000, + type=int, + help=( + "Total number of array elements which trigger summarization " + "rather than full `repr` (default = 1000)." 
+ ), + ) + + +def pytest_configure(config: pytest.Config) -> None: """Pytest configuration hook.""" if config.getoption("--detect-anomaly"): torch.autograd.anomaly_mode.set_detect_anomaly(True) + if config.getoption("--jit"): + torch.jit._state.enable() # type: ignore # pylint: disable=protected-access + else: + torch.jit._state.disable() # type: ignore # pylint: disable=protected-access + + global FAST_MODE + if config.getoption("--fast"): + FAST_MODE = True + if config.getoption("--slow"): + FAST_MODE = False + + global DEVICE + if config.getoption("--cuda"): + if not torch.cuda.is_available(): + raise RuntimeError("No cuda devices available.") + + if FAST_MODE is True: + FAST_MODE = False + + from warnings import warn + + warn( + "Fast mode for gradient checks not compatible with GPU " + "execution. Switching to slow mode. Use the '--slow' flag " + "for GPU tests ('--cuda') to avoid this warning." + ) + + DEVICE = torch.device("cuda:0") + torch.use_deterministic_algorithms(False) + + # `torch.set_default_tensor_type` is deprecated since 2.1.0 and version + # 2.0.0 introduces `torch.set_default_device` + if torch.__version__ < (2, 0, 0): # type: ignore + torch.set_default_tensor_type("torch.cuda.FloatTensor") # type: ignore + else: + torch.set_default_device(DEVICE) # type: ignore + else: + torch.use_deterministic_algorithms(True) + DEVICE = None + + if config.getoption("--tpo-linewidth"): + torch.set_printoptions(linewidth=config.getoption("--tpo-linewidth")) + + if config.getoption("--tpo-precision"): + torch.set_printoptions(precision=config.getoption("--tpo-precision")) + + if config.getoption("--tpo-threshold"): + torch.set_printoptions(threshold=config.getoption("--tpo-threshold")) + # register an additional marker config.addinivalue_line("markers", "cuda: mark test that require CUDA.") -def pytest_runtest_setup(item: pytest.Function): +def pytest_runtest_setup(item: pytest.Function) -> None: """Custom marker for tests requiring CUDA.""" for _ in item.iter_markers(name="cuda"): if not torch.cuda.is_available(): - pytest.skip( - "Torch not compiled with CUDA enabled or no CUDA device available." - ) + pytest.skip("Torch not compiled with CUDA or no CUDA device available.") diff --git a/test/molecules.py b/test/molecules.py index 9d74cbd..6913d2e 100644 --- a/test/molecules.py +++ b/test/molecules.py @@ -25,30 +25,6 @@ from tad_dftd4._typing import Molecule from tad_dftd4.utils import to_number - -def merge_nested_dicts(a: dict, b: dict) -> dict: - """ - Merge nested dictionaries. Dictionary `a` remains unaltered, while - the corresponding keys of it are added to `b`. - - Parameters - ---------- - a : dict - First dictionary (not changed). - b : dict - Second dictionary (changed). - - Returns - ------- - dict - Merged dictionary `b`. 
- """ - for key in b: - if key in a: - b[key].update(a[key]) - return b - - mols: dict[str, Molecule] = { "H": { "numbers": to_number(["H"]), diff --git a/test/test_charge/samples.py b/test/test_charge/samples.py index b112b44..8eb73a4 100644 --- a/test/test_charge/samples.py +++ b/test/test_charge/samples.py @@ -24,7 +24,8 @@ from tad_dftd4._typing import Molecule, Tensor, TypedDict -from ..molecules import merge_nested_dicts, mols +from ..molecules import mols +from ..utils import merge_nested_dicts class Refs(TypedDict): diff --git a/test/test_charge/test_charges.py b/test/test_charge/test_charges.py index be40552..8769cbd 100644 --- a/test/test_charge/test_charges.py +++ b/test/test_charge/test_charges.py @@ -31,54 +31,53 @@ """ from __future__ import annotations -from math import sqrt - import pytest import torch from tad_dftd4 import charges +from tad_dftd4._typing import DD from tad_dftd4.utils import pack +from ..conftest import DEVICE from .samples import samples @pytest.mark.parametrize("dtype", [torch.float, torch.double]) def test_single(dtype: torch.dtype): - tol = sqrt(torch.finfo(dtype).eps) * 10 + dd: DD = {"device": DEVICE, "dtype": dtype} + tol = torch.finfo(dtype).eps ** 0.5 * 10 sample = samples["NH3-dimer"] - numbers = sample["numbers"] - positions = sample["positions"].type(dtype) - total_charge = sample["total_charge"].type(dtype) - qref = sample["q"].type(dtype) - eref = sample["energy"].type(dtype) + numbers = sample["numbers"].to(DEVICE) + positions = sample["positions"].to(**dd) + total_charge = sample["total_charge"].to(**dd) - cn = torch.tensor( - [3.0, 3.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], - dtype=dtype, - ) - eeq = charges.ChargeModel.param2019().type(dtype) + qref = sample["q"].to(**dd) + eref = sample["energy"].to(**dd) + + cn = torch.tensor([3.0, 3.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], **dd) + eeq = charges.ChargeModel.param2019(**dd) energy, qat = charges.solve(numbers, positions, total_charge, eeq, cn) + tot = torch.sum(qat, -1) assert qat.dtype == energy.dtype == dtype - assert pytest.approx(torch.sum(qat, -1), abs=1e-6) == total_charge - assert pytest.approx(qat, abs=tol) == qref - assert pytest.approx(energy, abs=tol) == eref + assert pytest.approx(total_charge.cpu(), abs=1e-6) == tot.cpu() + assert pytest.approx(qref.cpu(), abs=tol) == qat.cpu() + assert pytest.approx(eref.cpu(), abs=tol) == energy.cpu() @pytest.mark.parametrize("dtype", [torch.float, torch.double]) def test_ghost(dtype: torch.dtype): - tol = sqrt(torch.finfo(dtype).eps) + dd: DD = {"device": DEVICE, "dtype": dtype} + tol = torch.finfo(dtype).eps ** 0.5 sample = samples["NH3-dimer"] - numbers = sample["numbers"].detach() + numbers = sample["numbers"].to(DEVICE).detach() numbers[[1, 5, 6, 7]] = 0 - positions = sample["positions"].type(dtype) - total_charge = sample["total_charge"].type(dtype) - cn = torch.tensor( - [3.0, 3.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], - dtype=dtype, - ) + positions = sample["positions"].to(**dd) + total_charge = sample["total_charge"].to(**dd) + cn = torch.tensor([3.0, 3.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], **dd) + qref = torch.tensor( [ -0.8189238943, @@ -90,7 +89,7 @@ def test_ghost(dtype: torch.dtype): +0.0000000000, +0.0000000000, ], - dtype=dtype, + **dd, ) eref = torch.tensor( [ @@ -103,20 +102,23 @@ def test_ghost(dtype: torch.dtype): +0.0000000000, +0.0000000000, ], - dtype=dtype, + **dd, ) - eeq = charges.ChargeModel.param2019().type(dtype) + + eeq = charges.ChargeModel.param2019(**dd) energy, qat = charges.solve(numbers, positions, total_charge, eeq, cn) + tot = 
torch.sum(qat, -1) assert qat.dtype == energy.dtype == dtype - assert pytest.approx(torch.sum(qat, -1), abs=1e-6) == total_charge - assert pytest.approx(qat, abs=tol) == qref - assert pytest.approx(energy, abs=tol) == eref + assert pytest.approx(total_charge.cpu(), abs=1e-6) == tot.cpu() + assert pytest.approx(qref.cpu(), abs=tol) == qat.cpu() + assert pytest.approx(eref.cpu(), abs=tol) == energy.cpu() @pytest.mark.parametrize("dtype", [torch.float, torch.double]) def test_batch(dtype: torch.dtype): - tol = sqrt(torch.finfo(dtype).eps) + dd: DD = {"device": DEVICE, "dtype": dtype} + tol = torch.finfo(dtype).eps ** 0.5 sample1, sample2 = ( samples["PbH4-BiH3"], @@ -124,27 +126,27 @@ def test_batch(dtype: torch.dtype): ) numbers = pack( ( - sample1["numbers"], - sample2["numbers"], + sample1["numbers"].to(DEVICE), + sample2["numbers"].to(DEVICE), ) ) positions = pack( ( - sample1["positions"].type(dtype), - sample2["positions"].type(dtype), + sample1["positions"].to(**dd), + sample2["positions"].to(**dd), ) ) - total_charge = torch.tensor([0.0, 0.0], dtype=dtype) + total_charge = torch.tensor([0.0, 0.0], **dd) eref = pack( ( - sample1["energy"].type(dtype), - sample2["energy"].type(dtype), + sample1["energy"].to(**dd), + sample2["energy"].to(**dd), ) ) qref = pack( ( - sample1["q"].type(dtype), - sample2["q"].type(dtype), + sample1["q"].to(**dd), + sample2["q"].to(**dd), ) ) @@ -191,12 +193,13 @@ def test_batch(dtype: torch.dtype): 0.9939362885, ], ], - dtype=dtype, + **dd, ) - eeq = charges.ChargeModel.param2019().type(dtype) + eeq = charges.ChargeModel.param2019(**dd) energy, qat = charges.solve(numbers, positions, total_charge, eeq, cn) + tot = torch.sum(qat, -1) assert qat.dtype == energy.dtype == dtype - assert pytest.approx(torch.sum(qat, -1), abs=1e-6) == total_charge - assert pytest.approx(qat, abs=tol) == qref - assert pytest.approx(energy, abs=tol) == eref + assert pytest.approx(total_charge.cpu(), abs=1e-6) == tot.cpu() + assert pytest.approx(qref.cpu(), abs=tol) == qat.cpu() + assert pytest.approx(eref.cpu(), abs=tol) == energy.cpu() diff --git a/test/test_charge/test_general.py b/test/test_charge/test_general.py index bf3a98f..68bc193 100644 --- a/test/test_charge/test_general.py +++ b/test/test_charge/test_general.py @@ -84,10 +84,14 @@ def test_init_dtype_fail() -> None: @pytest.mark.cuda def test_init_device_fail() -> None: t = torch.rand(5) + if "cuda" in str(t.device): + t = t.cpu() + elif "cpu" in str(t.device): + t = t.cuda() # all tensor must be on the same device with pytest.raises(RuntimeError): - charges.ChargeModel(t.to("cuda"), t, t, t) + charges.ChargeModel(t, t, t, t) @pytest.mark.cuda diff --git a/test/test_cutoff/test_types.py b/test/test_cutoff/test_types.py index 0e6a939..e08c666 100644 --- a/test/test_cutoff/test_types.py +++ b/test/test_cutoff/test_types.py @@ -30,10 +30,10 @@ def test_defaults(): cutoff = Cutoff() - assert pytest.approx(defaults.D4_DISP2_CUTOFF) == cutoff.disp2 - assert pytest.approx(defaults.D4_DISP3_CUTOFF) == cutoff.disp3 - assert pytest.approx(defaults.D4_CN_CUTOFF) == cutoff.cn - assert pytest.approx(defaults.D4_CN_EEQ_CUTOFF) == cutoff.cn_eeq + assert pytest.approx(defaults.D4_DISP2_CUTOFF) == cutoff.disp2.cpu() + assert pytest.approx(defaults.D4_DISP3_CUTOFF) == cutoff.disp3.cpu() + assert pytest.approx(defaults.D4_CN_CUTOFF) == cutoff.cn.cpu() + assert pytest.approx(defaults.D4_CN_EEQ_CUTOFF) == cutoff.cn_eeq.cpu() def test_tensor(): @@ -45,7 +45,7 @@ def test_tensor(): assert isinstance(cutoff.cn, Tensor) assert 
isinstance(cutoff.cn_eeq, Tensor) - assert pytest.approx(tmp) == cutoff.disp2 + assert pytest.approx(tmp.cpu()) == cutoff.disp2.cpu() @pytest.mark.parametrize("vals", [(1, 2, -3, 4), (1.0, 2.0, 3.0, -4.0)]) @@ -58,7 +58,7 @@ def test_int_float(vals: tuple[int | float, ...]): assert isinstance(cutoff.cn, Tensor) assert isinstance(cutoff.cn_eeq, Tensor) - assert pytest.approx(vals[0]) == cutoff.disp2 - assert pytest.approx(vals[1]) == cutoff.disp3 - assert pytest.approx(vals[2]) == cutoff.cn - assert pytest.approx(vals[3]) == cutoff.cn_eeq + assert pytest.approx(vals[0]) == cutoff.disp2.cpu() + assert pytest.approx(vals[1]) == cutoff.disp3.cpu() + assert pytest.approx(vals[2]) == cutoff.cn.cpu() + assert pytest.approx(vals[3]) == cutoff.cn_eeq.cpu() diff --git a/test/test_disp/samples.py b/test/test_disp/samples.py index a82234a..cfec053 100644 --- a/test/test_disp/samples.py +++ b/test/test_disp/samples.py @@ -24,7 +24,8 @@ from tad_dftd4._typing import Molecule, Tensor, TypedDict -from ..molecules import merge_nested_dicts, mols +from ..molecules import mols +from ..utils import merge_nested_dicts class Refs(TypedDict): diff --git a/test/test_disp/test_atm.py b/test/test_disp/test_atm.py index afb2ffd..14474ea 100644 --- a/test/test_disp/test_atm.py +++ b/test/test_disp/test_atm.py @@ -18,16 +18,16 @@ """ Test calculation of two-body and three-body dispersion terms. """ -from math import sqrt - import pytest import torch +from tad_dftd4._typing import DD from tad_dftd4.disp import dispersion3 from tad_dftd4.model import D4Model -from tad_dftd4.ncoord import get_coordination_number_d4 +from tad_dftd4.ncoord import coordination_number_d4 from tad_dftd4.utils import pack +from ..conftest import DEVICE from .samples import samples sample_list = ["LiH", "SiH4", "MB16_43_01", "MB16_43_02"] @@ -47,34 +47,35 @@ def test_single_large(name: str, dtype: torch.dtype) -> None: def single(name: str, dtype: torch.dtype) -> None: - tol = sqrt(torch.finfo(dtype).eps) * 10 + dd: DD = {"device": DEVICE, "dtype": dtype} + tol = torch.finfo(dtype).eps ** 0.5 * 10 sample = samples[name] - numbers = sample["numbers"] - positions = sample["positions"].type(dtype) - ref = sample["disp3"].type(dtype) + numbers = sample["numbers"].to(DEVICE) + positions = sample["positions"].to(**dd) + ref = sample["disp3"].to(**dd) # TPSS0-D4-ATM parameters param = { - "s6": positions.new_tensor(1.0), - "s8": positions.new_tensor(1.85897750), - "s9": positions.new_tensor(1.0), - "s10": positions.new_tensor(0.0), - "alp": positions.new_tensor(16.0), - "a1": positions.new_tensor(0.44286966), - "a2": positions.new_tensor(4.60230534), + "s6": torch.tensor(1.00000000, **dd), + "s8": torch.tensor(1.85897750, **dd), + "s9": torch.tensor(1.00000000, **dd), + "s10": torch.tensor(0.0000000, **dd), + "alp": torch.tensor(16.000000, **dd), + "a1": torch.tensor(0.44286966, **dd), + "a2": torch.tensor(4.60230534, **dd), } model = D4Model(numbers, device=positions.device, dtype=positions.dtype) - cn = get_coordination_number_d4(numbers, positions) + cn = coordination_number_d4(numbers, positions) weights = model.weight_references(cn, q=None) c6 = model.get_atomic_c6(weights) - cutoff = positions.new_tensor(40.0) + cutoff = torch.tensor(40.0, **dd) energy = dispersion3(numbers, positions, param, c6, cutoff=cutoff) assert energy.dtype == dtype - assert pytest.approx(ref, abs=tol) == energy + assert pytest.approx(ref.cpu().cpu(), abs=tol) == energy.cpu() @pytest.mark.parametrize("dtype", [torch.float, torch.double]) @@ -93,45 +94,46 @@ def 
test_batch_large(name1: str, name2: str, dtype: torch.dtype) -> None: def batch(name1: str, name2: str, dtype: torch.dtype) -> None: - tol = sqrt(torch.finfo(dtype).eps) * 10 + dd: DD = {"device": DEVICE, "dtype": dtype} + tol = torch.finfo(dtype).eps ** 0.5 * 10 sample1, sample2 = samples[name1], samples[name2] numbers = pack( [ - sample1["numbers"], - sample2["numbers"], + sample1["numbers"].to(DEVICE), + sample2["numbers"].to(DEVICE), ] ) positions = pack( [ - sample1["positions"].type(dtype), - sample2["positions"].type(dtype), + sample1["positions"].to(**dd), + sample2["positions"].to(**dd), ] ) ref = pack( [ - sample1["disp3"].type(dtype), - sample2["disp3"].type(dtype), + sample1["disp3"].to(**dd), + sample2["disp3"].to(**dd), ] ) # TPSS0-D4-ATM parameters param = { - "s6": positions.new_tensor(1.0), - "s8": positions.new_tensor(1.85897750), - "s9": positions.new_tensor(1.0), - "s10": positions.new_tensor(0.0), - "alp": positions.new_tensor(16.0), - "a1": positions.new_tensor(0.44286966), - "a2": positions.new_tensor(4.60230534), + "s6": torch.tensor(1.00000000, **dd), + "s8": torch.tensor(1.85897750, **dd), + "s9": torch.tensor(1.00000000, **dd), + "s10": torch.tensor(0.0000000, **dd), + "alp": torch.tensor(16.000000, **dd), + "a1": torch.tensor(0.44286966, **dd), + "a2": torch.tensor(4.60230534, **dd), } model = D4Model(numbers, device=positions.device, dtype=positions.dtype) - cn = get_coordination_number_d4(numbers, positions) + cn = coordination_number_d4(numbers, positions) weights = model.weight_references(cn, q=None) c6 = model.get_atomic_c6(weights) energy = dispersion3(numbers, positions, param, c6) assert energy.dtype == dtype - assert pytest.approx(ref, abs=tol) == energy + assert pytest.approx(ref.cpu(), abs=tol) == energy.cpu() diff --git a/test/test_disp/test_full.py b/test/test_disp/test_full.py index 1a74d40..69c78dc 100644 --- a/test/test_disp/test_full.py +++ b/test/test_disp/test_full.py @@ -18,18 +18,18 @@ """ Test calculation of two-body and three-body dispersion terms. 
""" -from math import sqrt - import pytest import torch from tad_dftd4 import data +from tad_dftd4._typing import DD from tad_dftd4.charges import get_charges from tad_dftd4.cutoff import Cutoff from tad_dftd4.disp import dftd4 from tad_dftd4.model import D4Model from tad_dftd4.utils import pack +from ..conftest import DEVICE from .samples import samples sample_list = ["LiH", "SiH4", "MB16_43_01", "MB16_43_02"] @@ -49,23 +49,24 @@ def test_single_large(name: str, dtype: torch.dtype) -> None: def single(name: str, dtype: torch.dtype) -> None: - tol = sqrt(torch.finfo(dtype).eps) * 10 + dd: DD = {"device": DEVICE, "dtype": dtype} + tol = torch.finfo(dtype).eps ** 0.5 * 10 sample = samples[name] - numbers = sample["numbers"] - positions = sample["positions"].type(dtype) - charge = positions.new_tensor(0.0) - ref = sample["disp"].type(dtype) + numbers = sample["numbers"].to(DEVICE) + positions = sample["positions"].to(**dd) + charge = torch.tensor(0.0, **dd) + ref = sample["disp"].to(**dd) # TPSS0-D4-ATM parameters param = { - "s6": positions.new_tensor(1.0), - "s8": positions.new_tensor(1.85897750), - "s9": positions.new_tensor(1.0), - "s10": positions.new_tensor(0.0), - "alp": positions.new_tensor(16.0), - "a1": positions.new_tensor(0.44286966), - "a2": positions.new_tensor(4.60230534), + "s6": torch.tensor(1.00000000, **dd), + "s8": torch.tensor(1.85897750, **dd), + "s9": torch.tensor(1.00000000, **dd), + "s10": torch.tensor(0.0000000, **dd), + "alp": torch.tensor(16.000000, **dd), + "a1": torch.tensor(0.44286966, **dd), + "a2": torch.tensor(4.60230534, **dd), } model = D4Model(numbers, dtype=dtype) @@ -87,7 +88,7 @@ def single(name: str, dtype: torch.dtype) -> None: ) assert energy.dtype == dtype - assert pytest.approx(ref, abs=tol) == energy + assert pytest.approx(ref.cpu(), abs=tol) == energy.cpu() @pytest.mark.parametrize("dtype", [torch.float, torch.double]) @@ -106,42 +107,43 @@ def test_batch_large(name1: str, name2: str, dtype: torch.dtype) -> None: def batch(name1: str, name2: str, dtype: torch.dtype) -> None: - tol = sqrt(torch.finfo(dtype).eps) * 10 + dd: DD = {"device": DEVICE, "dtype": dtype} + tol = torch.finfo(dtype).eps ** 0.5 * 10 sample1, sample2 = samples[name1], samples[name2] numbers = pack( [ - sample1["numbers"], - sample2["numbers"], + sample1["numbers"].to(DEVICE), + sample2["numbers"].to(DEVICE), ] ) positions = pack( [ - sample1["positions"].type(dtype), - sample2["positions"].type(dtype), + sample1["positions"].to(**dd), + sample2["positions"].to(**dd), ] ) charge = positions.new_zeros(numbers.shape[0]) ref = pack( [ - sample1["disp"].type(dtype), - sample2["disp"].type(dtype), + sample1["disp"].to(**dd), + sample2["disp"].to(**dd), ] ) # TPSS0-D4-ATM parameters param = { - "s6": positions.new_tensor(1.0), - "s8": positions.new_tensor(1.85897750), - "s9": positions.new_tensor(1.0), - "s10": positions.new_tensor(0.0), - "alp": positions.new_tensor(16.0), - "a1": positions.new_tensor(0.44286966), - "a2": positions.new_tensor(4.60230534), + "s6": torch.tensor(1.00000000, **dd), + "s8": torch.tensor(1.85897750, **dd), + "s9": torch.tensor(1.00000000, **dd), + "s10": torch.tensor(0.0000000, **dd), + "alp": torch.tensor(16.000000, **dd), + "a1": torch.tensor(0.44286966, **dd), + "a2": torch.tensor(4.60230534, **dd), } energy = dftd4(numbers, positions, charge, param) assert energy.dtype == dtype - assert pytest.approx(ref, abs=tol) == energy + assert pytest.approx(ref.cpu(), abs=tol) == energy.cpu() diff --git a/test/test_disp/test_grad.py b/test/test_disp/test_grad.py 
index 9026690..b4f87c9 100644 --- a/test/test_disp/test_grad.py +++ b/test/test_disp/test_grad.py @@ -24,9 +24,10 @@ import torch from torch.autograd.gradcheck import gradcheck -from tad_dftd4._typing import Tensor +from tad_dftd4._typing import DD, Tensor from tad_dftd4.disp import dftd4 +from ..conftest import DEVICE from .samples import samples sample_list = ["LiH", "SiH4", "MB16_43_01"] @@ -35,18 +36,19 @@ @pytest.mark.grad @pytest.mark.parametrize("name", sample_list) def test_grad_param(name) -> None: - dtype = torch.float64 + dd: DD = {"device": DEVICE, "dtype": torch.float64} + sample = samples[name] - numbers = sample["numbers"] - positions = sample["positions"].type(dtype) - charge = positions.new_tensor(0.0) + numbers = sample["numbers"].to(DEVICE) + positions = sample["positions"].to(**dd) + charge = torch.tensor(0.0, **dd) param = ( - positions.new_tensor(1.00000000).requires_grad_(True), - positions.new_tensor(0.78981345).requires_grad_(True), - positions.new_tensor(1.00000000).requires_grad_(True), - positions.new_tensor(0.49484001).requires_grad_(True), - positions.new_tensor(5.73083694).requires_grad_(True), + torch.tensor(1.00000000, **dd, requires_grad=True), + torch.tensor(0.78981345, **dd, requires_grad=True), + torch.tensor(1.00000000, **dd, requires_grad=True), + torch.tensor(0.49484001, **dd, requires_grad=True), + torch.tensor(5.73083694, **dd, requires_grad=True), ) label = ("s6", "s8", "s9", "a1", "a2") @@ -60,21 +62,22 @@ def func(*inputs: Tensor) -> Tensor: @pytest.mark.grad @pytest.mark.parametrize("name", sample_list) def test_grad_positions(name: str) -> None: - dtype = torch.float64 + dd: DD = {"device": DEVICE, "dtype": torch.float64} + sample = samples[name] - numbers = sample["numbers"] - positions = sample["positions"].type(dtype) - charge = positions.new_tensor(0.0) + numbers = sample["numbers"].to(DEVICE) + positions = sample["positions"].to(**dd) + charge = torch.tensor(0.0, **dd) # TPSS0-D4-ATM parameters param = { - "s6": positions.new_tensor(1.0), - "s8": positions.new_tensor(1.85897750), - "s9": positions.new_tensor(1.0), - "s10": positions.new_tensor(0.0), - "alp": positions.new_tensor(16.0), - "a1": positions.new_tensor(0.44286966), - "a2": positions.new_tensor(4.60230534), + "s6": torch.tensor(1.00000000, **dd), + "s8": torch.tensor(1.85897750, **dd), + "s9": torch.tensor(1.00000000, **dd), + "s10": torch.tensor(0.0000000, **dd), + "alp": torch.tensor(16.000000, **dd), + "a1": torch.tensor(0.44286966, **dd), + "a2": torch.tensor(4.60230534, **dd), } pos = positions.detach().clone().requires_grad_(True) diff --git a/test/test_disp/test_twobody.py b/test/test_disp/test_twobody.py index 2705cc4..d28ad81 100644 --- a/test/test_disp/test_twobody.py +++ b/test/test_disp/test_twobody.py @@ -18,18 +18,18 @@ """ Test calculation of two-body and three-body dispersion terms. 
""" -from math import sqrt - import pytest import torch from tad_dftd4 import data +from tad_dftd4._typing import DD from tad_dftd4.charges import get_charges from tad_dftd4.disp import dftd4, dispersion2 from tad_dftd4.model import D4Model -from tad_dftd4.ncoord import get_coordination_number_d4 +from tad_dftd4.ncoord import coordination_number_d4 from tad_dftd4.utils import pack +from ..conftest import DEVICE from .samples import samples sample_list = ["LiH", "SiH4", "MB16_43_01", "MB16_43_02"] @@ -49,28 +49,29 @@ def test_single_large(name: str, dtype: torch.dtype) -> None: def single(name: str, dtype: torch.dtype) -> None: - tol = sqrt(torch.finfo(dtype).eps) * 10 + dd: DD = {"device": DEVICE, "dtype": dtype} + tol = torch.finfo(dtype).eps ** 0.5 * 10 sample = samples[name] - numbers = sample["numbers"] - positions = sample["positions"].type(dtype) - charge = positions.new_tensor(0.0) - ref = sample["disp2"].type(dtype) + numbers = sample["numbers"].to(DEVICE) + positions = sample["positions"].to(**dd) + charge = torch.tensor(0.0, **dd) + ref = sample["disp2"].to(**dd) # TPSS0-D4-ATM parameters param = { - "s6": positions.new_tensor(1.0), - "s8": positions.new_tensor(1.85897750), - "s9": positions.new_tensor(1.0), - "s10": positions.new_tensor(0.0), - "alp": positions.new_tensor(16.0), - "a1": positions.new_tensor(0.44286966), - "a2": positions.new_tensor(4.60230534), + "s6": torch.tensor(1.00000000, **dd), + "s8": torch.tensor(1.85897750, **dd), + "s9": torch.tensor(1.00000000, **dd), + "s10": torch.tensor(0.0000000, **dd), + "alp": torch.tensor(16.000000, **dd), + "a1": torch.tensor(0.44286966, **dd), + "a2": torch.tensor(4.60230534, **dd), } - r4r2 = data.r4r2[numbers].type(positions.dtype) - model = D4Model(numbers, device=positions.device, dtype=positions.dtype) - cn = get_coordination_number_d4(numbers, positions) + r4r2 = data.r4r2.to(**dd)[numbers] + model = D4Model(numbers, **dd) + cn = coordination_number_d4(numbers, positions) q = get_charges(numbers, positions, charge) weights = model.weight_references(cn, q) c6 = model.get_atomic_c6(weights) @@ -78,43 +79,45 @@ def single(name: str, dtype: torch.dtype) -> None: energy = dispersion2(numbers, positions, param, c6, r4r2) assert energy.dtype == dtype - assert pytest.approx(ref, abs=tol) == energy + assert pytest.approx(ref.cpu(), abs=tol) == energy.cpu() @pytest.mark.parametrize("dtype", [torch.float, torch.double]) @pytest.mark.parametrize("name", sample_list) def test_single_s9_zero(name: str, dtype: torch.dtype) -> None: - tol = sqrt(torch.finfo(dtype).eps) * 10 + dd: DD = {"device": DEVICE, "dtype": dtype} + tol = torch.finfo(dtype).eps ** 0.5 * 10 sample = samples[name] - numbers = sample["numbers"] - positions = sample["positions"].type(dtype) - charge = positions.new_tensor(0.0) - ref = sample["disp2"].type(dtype) + numbers = sample["numbers"].to(DEVICE) + positions = sample["positions"].to(**dd) + charge = torch.tensor(0.0, **dd) + ref = sample["disp2"].to(**dd) # TPSS0-D4-ATM parameters param = { - "s8": positions.new_tensor(1.85897750), - "s9": positions.new_tensor(0.0), # skip ATM - "a1": positions.new_tensor(0.44286966), - "a2": positions.new_tensor(4.60230534), + "s8": torch.tensor(1.85897750, **dd), + "s9": torch.tensor(0.00000000, **dd), # skip ATM + "a1": torch.tensor(0.44286966, **dd), + "a2": torch.tensor(4.60230534, **dd), } energy = dftd4(numbers, positions, charge, param) assert energy.dtype == dtype - assert pytest.approx(ref, abs=tol) == energy + assert pytest.approx(ref.cpu(), abs=tol) == energy.cpu() 
@pytest.mark.parametrize("dtype", [torch.float, torch.double]) @pytest.mark.parametrize("name", ["SiH4"]) def test_single_s10_one(name: str, dtype: torch.dtype) -> None: - tol = sqrt(torch.finfo(dtype).eps) * 10 + dd: DD = {"device": DEVICE, "dtype": dtype} + tol = torch.finfo(dtype).eps ** 0.5 * 10 sample = samples[name] - numbers = sample["numbers"] - positions = sample["positions"].type(dtype) - charge = positions.new_tensor(0.0) + numbers = sample["numbers"].to(DEVICE) + positions = sample["positions"].to(**dd) + charge = torch.tensor(0.0, **dd) ref = torch.tensor( [ -8.8928018057670788e-04, @@ -123,22 +126,22 @@ def test_single_s10_one(name: str, dtype: torch.dtype) -> None: -3.3765541880036940e-04, -3.3765541880036940e-04, ], - dtype=dtype, + **dd, ) # TPSS0-D4-ATM parameters param = { - "s8": positions.new_tensor(1.85897750), - "s9": positions.new_tensor(0.0), # skip ATM - "s10": positions.new_tensor(1.0), # quadrupole-quadrupole - "a1": positions.new_tensor(0.44286966), - "a2": positions.new_tensor(4.60230534), + "s8": torch.tensor(1.85897750, **dd), + "s9": torch.tensor(0.00000000, **dd), # skip ATM + "s10": torch.tensor(1.0000000, **dd), # quadrupole-quadrupole + "a1": torch.tensor(0.44286966, **dd), + "a2": torch.tensor(4.60230534, **dd), } energy = dftd4(numbers, positions, charge, param) assert energy.dtype == dtype - assert pytest.approx(ref, abs=tol) == energy + assert pytest.approx(ref.cpu(), abs=tol) == energy.cpu() @pytest.mark.parametrize("dtype", [torch.float, torch.double]) @@ -157,44 +160,45 @@ def test_batch_large(name1: str, name2: str, dtype: torch.dtype) -> None: def batch(name1: str, name2: str, dtype: torch.dtype) -> None: - tol = sqrt(torch.finfo(dtype).eps) * 10 + dd: DD = {"device": DEVICE, "dtype": dtype} + tol = torch.finfo(dtype).eps ** 0.5 * 10 sample1, sample2 = samples[name1], samples[name2] numbers = pack( [ - sample1["numbers"], - sample2["numbers"], + sample1["numbers"].to(DEVICE), + sample2["numbers"].to(DEVICE), ] ) positions = pack( [ - sample1["positions"].type(dtype), - sample2["positions"].type(dtype), + sample1["positions"].to(**dd), + sample2["positions"].to(**dd), ] ) charge = positions.new_zeros(numbers.shape[0]) ref = pack( [ - sample1["disp2"].type(dtype), - sample2["disp2"].type(dtype), + sample1["disp2"].to(**dd), + sample2["disp2"].to(**dd), ] ) # TPSS0-D4-ATM parameters param = { - "s6": positions.new_tensor(1.0), - "s8": positions.new_tensor(1.85897750), - "s9": positions.new_tensor(1.0), - "s10": positions.new_tensor(0.0), - "alp": positions.new_tensor(16.0), - "a1": positions.new_tensor(0.44286966), - "a2": positions.new_tensor(4.60230534), + "s6": torch.tensor(1.00000000, **dd), + "s8": torch.tensor(1.85897750, **dd), + "s9": torch.tensor(1.00000000, **dd), + "s10": torch.tensor(0.0000000, **dd), + "alp": torch.tensor(16.000000, **dd), + "a1": torch.tensor(0.44286966, **dd), + "a2": torch.tensor(4.60230534, **dd), } r4r2 = data.r4r2[numbers].type(positions.dtype) model = D4Model(numbers, device=positions.device, dtype=positions.dtype) - cn = get_coordination_number_d4(numbers, positions) + cn = coordination_number_d4(numbers, positions) q = get_charges(numbers, positions, charge) weights = model.weight_references(cn, q) c6 = model.get_atomic_c6(weights) @@ -202,4 +206,4 @@ def batch(name1: str, name2: str, dtype: torch.dtype) -> None: energy = dispersion2(numbers, positions, param, c6, r4r2) assert energy.dtype == dtype - assert pytest.approx(ref, abs=tol) == energy + assert pytest.approx(ref.cpu(), abs=tol) == energy.cpu() diff 
--git a/test/test_model/samples.py b/test/test_model/samples.py index 7baba42..f46d58e 100644 --- a/test/test_model/samples.py +++ b/test/test_model/samples.py @@ -24,8 +24,8 @@ from tad_dftd4._typing import Molecule, Tensor, TypedDict -from ..molecules import merge_nested_dicts, mols -from ..utils import reshape_fortran +from ..molecules import mols +from ..utils import merge_nested_dicts, reshape_fortran class Refs(TypedDict): diff --git a/test/test_model/test_c6.py b/test/test_model/test_c6.py index 3fbde56..eba9e47 100644 --- a/test/test_model/test_c6.py +++ b/test/test_model/test_c6.py @@ -23,9 +23,11 @@ import torch import torch.nn.functional as F +from tad_dftd4._typing import DD from tad_dftd4.model import D4Model from tad_dftd4.utils import pack +from ..conftest import DEVICE from .samples import samples # only these references use `cn=True` and `q=True` for `gw` @@ -35,15 +37,17 @@ @pytest.mark.parametrize("dtype", [torch.float, torch.double]) @pytest.mark.parametrize("name", sample_list) def test_single(name: str, dtype: torch.dtype) -> None: + dd: DD = {"device": DEVICE, "dtype": dtype} + tol = 1e-4 if dtype == torch.float else 1e-5 sample = samples[name] - numbers = sample["numbers"] - ref = sample["c6"] + numbers = sample["numbers"].to(DEVICE) + ref = sample["c6"].to(**dd) - d4 = D4Model(numbers, dtype=dtype) + d4 = D4Model(numbers, **dd) # pad reference tensor to always be of shape `(natoms, 7)` - src = sample["gw"].type(dtype) + src = sample["gw"].to(**dd) gw = F.pad( input=src, pad=(0, 0, 0, 7 - src.size(0)), @@ -52,33 +56,35 @@ def test_single(name: str, dtype: torch.dtype) -> None: ).mT c6 = d4.get_atomic_c6(gw) - assert pytest.approx(ref, rel=tol) == c6 + assert pytest.approx(ref.cpu(), rel=tol) == c6.cpu() @pytest.mark.parametrize("dtype", [torch.float, torch.double]) @pytest.mark.parametrize("name1", ["LiH"]) @pytest.mark.parametrize("name2", sample_list) def test_batch(name1: str, name2: str, dtype: torch.dtype) -> None: + dd: DD = {"device": DEVICE, "dtype": dtype} + tol = 1e-4 if dtype == torch.float else 1e-5 sample1, sample2 = samples[name1], samples[name2] numbers = pack( [ - sample1["numbers"], - sample2["numbers"], + sample1["numbers"].to(DEVICE), + sample2["numbers"].to(DEVICE), ] ) refs = pack( [ - sample1["c6"], - sample2["c6"], + sample1["c6"].to(**dd), + sample2["c6"].to(**dd), ] ) - d4 = D4Model(numbers, dtype=dtype) + d4 = D4Model(numbers, **dd) # pad reference tensor to always be of shape `(natoms, 7)` - src1 = sample1["gw"].type(dtype) - src2 = sample2["gw"].type(dtype) + src1 = sample1["gw"].to(**dd) + src2 = sample2["gw"].to(**dd) gw = pack( [ @@ -93,4 +99,4 @@ def test_batch(name1: str, name2: str, dtype: torch.dtype) -> None: ) c6 = d4.get_atomic_c6(gw) - assert pytest.approx(refs, rel=tol) == c6 + assert pytest.approx(refs.cpu(), rel=tol) == c6.cpu() diff --git a/test/test_model/test_model.py b/test/test_model/test_model.py index e88cc6c..f5fe137 100644 --- a/test/test_model/test_model.py +++ b/test/test_model/test_model.py @@ -22,11 +22,13 @@ import pytest import torch +from tad_dftd4._typing import DD from tad_dftd4.charges import get_charges from tad_dftd4.model import D4Model -from tad_dftd4.ncoord import get_coordination_number_d4 +from tad_dftd4.ncoord import coordination_number_d4 from tad_dftd4.utils import pack +from ..conftest import DEVICE from .samples import samples # only these references use `cn=True` and `q=True` for `gw` @@ -36,52 +38,58 @@ @pytest.mark.parametrize("dtype", [torch.float, torch.double]) 
@pytest.mark.parametrize("name", sample_list) def test_single(name: str, dtype: torch.dtype) -> None: + dd: DD = {"device": DEVICE, "dtype": dtype} tol = 1e-5 + sample = samples[name] - numbers = sample["numbers"] - positions = sample["positions"].type(dtype) - ref = sample["c6"] + numbers = sample["numbers"].to(DEVICE) + positions = sample["positions"].to(**dd) + ref = sample["c6"].to(**dd) - d4 = D4Model(numbers, dtype=dtype) + d4 = D4Model(numbers, **dd) - cn = get_coordination_number_d4(numbers, positions) - q = get_charges(numbers, positions, positions.new_tensor(0.0)) + cn = coordination_number_d4(numbers, positions) + total_charge = torch.tensor(0.0, **dd) + q = get_charges(numbers, positions, total_charge) gw = d4.weight_references(cn=cn, q=q) c6 = d4.get_atomic_c6(gw) - assert pytest.approx(ref, abs=tol, rel=tol) == c6 + assert pytest.approx(ref.cpu(), abs=tol, rel=tol) == c6.cpu() @pytest.mark.parametrize("dtype", [torch.float, torch.double]) @pytest.mark.parametrize("name1", ["LiH"]) @pytest.mark.parametrize("name2", sample_list) def test_batch(name1: str, name2: str, dtype: torch.dtype) -> None: + dd: DD = {"device": DEVICE, "dtype": dtype} tol = 1e-5 + sample1, sample2 = samples[name1], samples[name2] numbers = pack( [ - sample1["numbers"], - sample2["numbers"], + sample1["numbers"].to(DEVICE), + sample2["numbers"].to(DEVICE), ] ) positions = pack( [ - sample1["positions"].type(dtype), - sample2["positions"].type(dtype), + sample1["positions"].to(**dd), + sample2["positions"].to(**dd), ] ) refs = pack( [ - sample1["c6"], - sample2["c6"], + sample1["c6"].to(**dd), + sample2["c6"].to(**dd), ] ) - d4 = D4Model(numbers, dtype=dtype) + d4 = D4Model(numbers, **dd) - cn = get_coordination_number_d4(numbers, positions) - q = get_charges(numbers, positions, positions.new_zeros(numbers.shape[0])) + cn = coordination_number_d4(numbers, positions) + total_charge = torch.zeros(numbers.shape[0], **dd) + q = get_charges(numbers, positions, total_charge) gw = d4.weight_references(cn=cn, q=q) c6 = d4.get_atomic_c6(gw) - assert pytest.approx(refs, abs=tol, rel=tol) == c6 + assert pytest.approx(refs.cpu(), abs=tol, rel=tol) == c6.cpu() diff --git a/test/test_model/test_weights.py b/test/test_model/test_weights.py index 12e92dc..74f9ff1 100644 --- a/test/test_model/test_weights.py +++ b/test/test_model/test_weights.py @@ -18,18 +18,17 @@ """ Test calculation of DFT-D4 model. 
""" - -from math import sqrt - import pytest import torch import torch.nn.functional as F +from tad_dftd4._typing import DD from tad_dftd4.charges import get_charges from tad_dftd4.model import D4Model -from tad_dftd4.ncoord import get_coordination_number_d4 +from tad_dftd4.ncoord import coordination_number_d4 from tad_dftd4.utils import pack +from ..conftest import DEVICE from .samples import samples @@ -39,15 +38,17 @@ def single( with_cn: bool, with_q: bool, ) -> None: - tol = sqrt(torch.finfo(dtype).eps) * 20 + dd: DD = {"device": DEVICE, "dtype": dtype} + tol = torch.finfo(dtype).eps ** 0.5 * 20 + sample = samples[name] - numbers = sample["numbers"] - positions = sample["positions"].type(dtype) + numbers = sample["numbers"].to(DEVICE) + positions = sample["positions"].to(**dd) - d4 = D4Model(numbers, dtype=dtype) + d4 = D4Model(numbers, **dd) if with_cn is True: - cn = get_coordination_number_d4(numbers, positions) + cn = coordination_number_d4(numbers, positions) else: cn = None # positions.new_zeros(numbers.shape) @@ -59,7 +60,7 @@ def single( gwvec = d4.weight_references(cn, q) # pad reference tensor to always be of shape `(natoms, 7)` - src = sample["gw"].type(dtype) + src = sample["gw"].to(**dd) ref = F.pad( input=src, pad=(0, 0, 0, 7 - src.size(0)), @@ -67,8 +68,9 @@ def single( value=0, ).mT + assert gwvec.dtype == ref.dtype assert gwvec.shape == ref.shape - assert pytest.approx(gwvec, abs=tol) == ref + assert pytest.approx(gwvec.cpu(), abs=tol) == ref.cpu() @pytest.mark.parametrize("dtype", [torch.float, torch.double]) @@ -100,32 +102,34 @@ def test_lih(dtype: torch.dtype) -> None: @pytest.mark.parametrize("name1", ["LiH"]) @pytest.mark.parametrize("name2", ["LiH", "SiH4", "MB16_43_03"]) def test_batch(name1: str, name2: str, dtype: torch.dtype) -> None: - tol = sqrt(torch.finfo(dtype).eps) * 20 + dd: DD = {"device": DEVICE, "dtype": dtype} + tol = torch.finfo(dtype).eps ** 0.5 * 20 + sample1, sample2 = samples[name1], samples[name2] numbers = pack( [ - sample1["numbers"], - sample2["numbers"], + sample1["numbers"].to(DEVICE), + sample2["numbers"].to(DEVICE), ] ) positions = pack( [ - sample1["positions"].type(dtype), - sample2["positions"].type(dtype), + sample1["positions"].to(**dd), + sample2["positions"].to(**dd), ] ) - d4 = D4Model(numbers, dtype=dtype) + d4 = D4Model(numbers, **dd) - cn = get_coordination_number_d4(numbers, positions) + cn = coordination_number_d4(numbers, positions) total_charge = positions.new_zeros(numbers.shape[0]) q = get_charges(numbers, positions, total_charge) gwvec = d4.weight_references(cn, q) # pad reference tensor to always be of shape `(natoms, 7)` - src1 = sample1["gw"].type(dtype) - src2 = sample2["gw"].type(dtype) + src1 = sample1["gw"].to(**dd) + src2 = sample2["gw"].to(**dd) ref = pack( [ @@ -139,5 +143,6 @@ def test_batch(name1: str, name2: str, dtype: torch.dtype) -> None: ] ) + assert gwvec.dtype == ref.dtype assert gwvec.shape == ref.shape - assert pytest.approx(gwvec, abs=tol) == ref + assert pytest.approx(gwvec.cpu(), abs=tol) == ref.cpu() diff --git a/test/test_ncoord/samples.py b/test/test_ncoord/samples.py index 05c2e32..5d4552a 100644 --- a/test/test_ncoord/samples.py +++ b/test/test_ncoord/samples.py @@ -24,7 +24,8 @@ from tad_dftd4._typing import Molecule, Tensor, TypedDict -from ..molecules import merge_nested_dicts, mols +from ..molecules import mols +from ..utils import merge_nested_dicts class Refs(TypedDict): diff --git a/test/test_ncoord/test_cn_d4.py b/test/test_ncoord/test_cn_d4.py index bfe4b4c..d497d19 100644 --- 
a/test/test_ncoord/test_cn_d4.py +++ b/test/test_ncoord/test_cn_d4.py @@ -23,10 +23,12 @@ import pytest import torch +from tad_dftd4._typing import DD from tad_dftd4.data import cov_rad_d3, pauling_en -from tad_dftd4.ncoord import get_coordination_number_d4 as get_cn +from tad_dftd4.ncoord import coordination_number_d4 as get_cn from tad_dftd4.utils import pack +from ..conftest import DEVICE from .samples import samples sample_list = ["MB16_43_01", "MB16_43_02", "MB16_43_02"] @@ -35,42 +37,46 @@ @pytest.mark.parametrize("dtype", [torch.float, torch.double]) @pytest.mark.parametrize("name", sample_list) def test_single(dtype: torch.dtype, name: str) -> None: + dd: DD = {"device": DEVICE, "dtype": dtype} + sample = samples[name] - numbers = sample["numbers"] - positions = sample["positions"].type(dtype) + numbers = sample["numbers"].to(DEVICE) + positions = sample["positions"].to(**dd) - rcov = cov_rad_d3[numbers] - en = pauling_en[numbers].type(dtype) - cutoff = positions.new_tensor(30.0) - ref = sample["cn_d4"].type(dtype) + rcov = cov_rad_d3.to(**dd)[numbers] + en = pauling_en.to(**dd)[numbers] + cutoff = torch.tensor(30.0, **dd) + ref = sample["cn_d4"].to(**dd) cn = get_cn(numbers, positions, rcov=rcov, en=en, cutoff=cutoff) - assert pytest.approx(cn) == ref + assert pytest.approx(cn.cpu()) == ref.cpu() @pytest.mark.parametrize("dtype", [torch.float, torch.double]) @pytest.mark.parametrize("name1", sample_list) @pytest.mark.parametrize("name2", sample_list) def test_batch(dtype: torch.dtype, name1: str, name2: str) -> None: + dd: DD = {"device": DEVICE, "dtype": dtype} + sample1, sample2 = samples[name1], samples[name2] numbers = pack( ( - sample1["numbers"], - sample2["numbers"], + sample1["numbers"].to(DEVICE), + sample2["numbers"].to(DEVICE), ) ) positions = pack( ( - sample1["positions"].type(dtype), - sample2["positions"].type(dtype), + sample1["positions"].to(**dd), + sample2["positions"].to(**dd), ) ) ref = pack( ( - sample1["cn_d4"].type(dtype), - sample2["cn_d4"].type(dtype), + sample1["cn_d4"].to(**dd), + sample2["cn_d4"].to(**dd), ) ) cn = get_cn(numbers, positions) - assert pytest.approx(cn) == ref + assert pytest.approx(cn.cpu()) == ref.cpu() diff --git a/test/test_ncoord/test_cn_eeq.py b/test/test_ncoord/test_cn_eeq.py index 03c70a6..9f7c4d3 100644 --- a/test/test_ncoord/test_cn_eeq.py +++ b/test/test_ncoord/test_cn_eeq.py @@ -23,11 +23,12 @@ import pytest import torch -from tad_dftd4._typing import Tensor +from tad_dftd4._typing import DD, Tensor from tad_dftd4.data import cov_rad_d3 -from tad_dftd4.ncoord import get_coordination_number_eeq as get_cn +from tad_dftd4.ncoord import coordination_number_eeq as get_cn from tad_dftd4.utils import pack +from ..conftest import DEVICE from .samples import samples sample_list = ["MB16_43_01", "MB16_43_02"] @@ -36,55 +37,61 @@ @pytest.mark.parametrize("dtype", [torch.float, torch.double]) @pytest.mark.parametrize("name", sample_list) def test_single(dtype: torch.dtype, name: str) -> None: + dd: DD = {"device": DEVICE, "dtype": dtype} + sample = samples[name] - numbers = sample["numbers"] - positions = sample["positions"].type(dtype) + numbers = sample["numbers"].to(DEVICE) + positions = sample["positions"].to(**dd) - rcov = cov_rad_d3[numbers].type(dtype) - cutoff = positions.new_tensor(30.0) - ref = sample["cn_eeq"].type(dtype) + rcov = cov_rad_d3.to(**dd)[numbers] + cutoff = torch.tensor(30.0, **dd) + ref = sample["cn_eeq"].to(**dd) cn = get_cn(numbers, positions, cutoff=cutoff, rcov=rcov, cn_max=None) - assert pytest.approx(ref) 
== cn + assert pytest.approx(ref.cpu()) == cn.cpu() @pytest.mark.parametrize("dtype", [torch.float, torch.double]) @pytest.mark.parametrize("cn_max", [49, 51.0, torch.tensor(49)]) def test_single_cnmax(dtype: torch.dtype, cn_max: int | float | Tensor) -> None: + dd: DD = {"device": DEVICE, "dtype": dtype} + sample = samples["MB16_43_01"] - numbers = sample["numbers"] - positions = sample["positions"].type(dtype) + numbers = sample["numbers"].to(DEVICE) + positions = sample["positions"].to(**dd) - ref = sample["cn_eeq"].type(dtype) + ref = sample["cn_eeq"].to(**dd) cn = get_cn(numbers, positions, cn_max=cn_max) - assert pytest.approx(ref, abs=1e-5) == cn + assert pytest.approx(ref.cpu(), abs=1e-5) == cn.cpu() @pytest.mark.parametrize("dtype", [torch.float, torch.double]) @pytest.mark.parametrize("name1", sample_list) @pytest.mark.parametrize("name2", sample_list) def test_batch(dtype: torch.dtype, name1: str, name2: str) -> None: + dd: DD = {"device": DEVICE, "dtype": dtype} + sample1, sample2 = samples[name1], samples[name2] numbers = pack( ( - sample1["numbers"], - sample2["numbers"], + sample1["numbers"].to(DEVICE), + sample2["numbers"].to(DEVICE), ) ) positions = pack( ( - sample1["positions"].type(dtype), - sample2["positions"].type(dtype), + sample1["positions"].to(**dd), + sample2["positions"].to(**dd), ) ) ref = pack( ( - sample1["cn_eeq"].type(dtype), - sample2["cn_eeq"].type(dtype), + sample1["cn_eeq"].to(**dd), + sample2["cn_eeq"].to(**dd), ) ) - cutoff = torch.tensor(30.0, dtype=dtype) + cutoff = torch.tensor(30.0, **dd) cn = get_cn(numbers, positions, cutoff=cutoff, cn_max=None) - assert pytest.approx(ref) == cn + assert pytest.approx(ref.cpu()) == cn.cpu() diff --git a/test/test_ncoord/test_general.py b/test/test_ncoord/test_general.py index 345bb87..e6841da 100644 --- a/test/test_ncoord/test_general.py +++ b/test/test_ncoord/test_general.py @@ -25,10 +25,10 @@ from tad_dftd4._typing import Any, CountingFunction, Protocol, Tensor from tad_dftd4.ncoord import ( + coordination_number_d4, + coordination_number_eeq, erf_count, exp_count, - get_coordination_number_d4, - get_coordination_number_eeq, ) @@ -52,7 +52,7 @@ def __call__( @pytest.mark.parametrize( "function", - [get_coordination_number_d4, get_coordination_number_eeq], + [coordination_number_d4, coordination_number_eeq], ) @pytest.mark.parametrize( "counting_function", diff --git a/test/test_ncoord/test_grad.py b/test/test_ncoord/test_grad.py index d1f6308..a1e40fc 100644 --- a/test/test_ncoord/test_grad.py +++ b/test/test_ncoord/test_grad.py @@ -21,14 +21,14 @@ """ from __future__ import annotations -from math import sqrt - import pytest import torch -from tad_dftd4._typing import CountingFunction +from tad_dftd4._typing import DD, CountingFunction from tad_dftd4.ncoord import derf_count, dexp_count, erf_count, exp_count +from ..conftest import DEVICE + @pytest.mark.parametrize("dtype", [torch.float, torch.double]) @pytest.mark.parametrize( @@ -41,11 +41,13 @@ def test_count_grad( dtype: torch.dtype, function: tuple[CountingFunction, CountingFunction] ) -> None: - tol = sqrt(torch.finfo(dtype).eps) * 10 + dd: DD = {"device": DEVICE, "dtype": dtype} + + tol = torch.finfo(dtype).eps ** 0.5 * 10 cf, dcf = function - a = torch.rand(4, dtype=dtype) - b = torch.rand(4, dtype=dtype) + a = torch.rand(4, **dd) + b = torch.rand(4, **dd) a_grad = a.detach().clone().requires_grad_(True) count = cf(a_grad, b) @@ -53,4 +55,4 @@ def test_count_grad( grad_auto = torch.autograd.grad(count.sum(-1), a_grad)[0] grad_expl = dcf(a, b) - 
assert pytest.approx(grad_auto, abs=tol) == grad_expl + assert pytest.approx(grad_auto.cpu(), abs=tol) == grad_expl.cpu() diff --git a/test/test_utils/test_cdist.py b/test/test_utils/test_cdist.py new file mode 100644 index 0000000..b9437dd --- /dev/null +++ b/test/test_utils/test_cdist.py @@ -0,0 +1,102 @@ +# This file is part of tad-dftd4. +# +# SPDX-Identifier: LGPL-3.0 +# Copyright (C) 2022 Marvin Friede +# +# tad-dftd4 is free software: you can redistribute it and/or modify it under +# the terms of the GNU Lesser General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# tad-dftd4 is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with tad-dftd4. If not, see . +""" +Test the utility functions. +""" + +import pytest +import torch + +from tad_dftd4 import utils +from tad_dftd4._typing import DD + +from ..conftest import DEVICE + + +@pytest.mark.parametrize("dtype", [torch.float32, torch.float64]) +def test_all(dtype: torch.dtype) -> None: + """ + The single precision test sometimes fails on my GPU with the following + thresholds: + + ``` + tol = 1e-6 if dtype == torch.float else 1e-14 + ``` + + Only one matrix element seems to be affected. It also appears that the + failure only happens if `torch.rand` was run before. To be precise, + + ``` + pytest -vv test/test_ncoord/test_grad.py test/test_utils/ --cuda --slow + ``` + + fails, while + + ``` + pytest -vv test/test_utils/ --cuda --slow + ``` + + works. It also works if I remove the random tensors in the gradient test + (test/test_ncoord/test_grad.py). + + It can be fixed with + + ``` + torch.use_deterministic_algorithms(True) + ``` + + and following the PyTorch instructions to set a specific + environment variable. + + ``` + CUBLAS_WORKSPACE_CONFIG=:4096:8 pytest -vv test/test_ncoord/test_grad.py test/test_utils/ --cuda --slow + ``` + + (For simplicity, I just reduced the tolerances for single precision.) 
+ """ + dd: DD = {"device": DEVICE, "dtype": dtype} + tol = 1e-6 if dtype == torch.float else 1e-14 + + # only one element actually fails + if "cuda" in str(DEVICE) and dtype == torch.float: + tol = 1e-3 + + x = torch.randn(2, 3, 4, **dd) + + d1 = utils.cdist(x) + d2 = utils.distance.cdist_direct_expansion(x, x, p=2) + d3 = utils.distance.euclidean_dist_quadratic_expansion(x, x) + + assert pytest.approx(d1.cpu(), abs=tol) == d2.cpu() + assert pytest.approx(d2.cpu(), abs=tol) == d3.cpu() + assert pytest.approx(d3.cpu(), abs=tol) == d1.cpu() + + +@pytest.mark.parametrize("dtype", [torch.float32, torch.float64]) +@pytest.mark.parametrize("p", [2, 3, 4, 5]) +def test_ps(dtype: torch.dtype, p: int) -> None: + dd: DD = {"device": DEVICE, "dtype": dtype} + tol = 1e-6 if dtype == torch.float else 1e-14 + + x = torch.randn(2, 4, 5, **dd) + y = torch.randn(2, 4, 5, **dd) + + d1 = utils.cdist(x, y, p=p) + d2 = torch.cdist(x, y, p=p) + + assert pytest.approx(d1.cpu(), abs=tol) == d2.cpu() diff --git a/test/test_utils/test_real.py b/test/test_utils/test_real.py index c70d742..cf1a7fb 100644 --- a/test/test_utils/test_real.py +++ b/test/test_utils/test_real.py @@ -20,7 +20,7 @@ """ import torch -from tad_dftd4.utils import real_atoms, real_pairs, real_triples +from tad_dftd4 import utils def test_real_atoms() -> None: @@ -36,7 +36,7 @@ def test_real_atoms() -> None: [True, True, True, True, True], # CH4 ], ) - mask = real_atoms(numbers) + mask = utils.real_atoms(numbers) assert (mask == ref).all() @@ -45,11 +45,11 @@ def test_real_pairs_single() -> None: size = numbers.shape[0] ref = torch.full((size, size), True) - mask = real_pairs(numbers, diagonal=True) + mask = utils.real_pairs(numbers, diagonal=True) assert (mask == ref).all() ref *= ~torch.diag_embed(torch.ones(size, dtype=torch.bool)) - mask = real_pairs(numbers, diagonal=False) + mask = utils.real_pairs(numbers, diagonal=False) assert (mask == ref).all() @@ -75,7 +75,7 @@ def test_real_pairs_batch() -> None: ], ] ) - mask = real_pairs(numbers, diagonal=True) + mask = utils.real_pairs(numbers, diagonal=True) assert (mask == ref).all() ref = torch.tensor( @@ -92,7 +92,7 @@ def test_real_pairs_batch() -> None: ], ] ) - mask = real_pairs(numbers, diagonal=False) + mask = utils.real_pairs(numbers, diagonal=False) assert (mask == ref).all() @@ -101,11 +101,11 @@ def test_real_triples_single() -> None: size = numbers.shape[0] ref = torch.full((size, size, size), True) - mask = real_triples(numbers, diagonal=True) + mask = utils.real_triples(numbers, diagonal=True) assert (mask == ref).all() ref *= ~torch.diag_embed(torch.ones(size, dtype=torch.bool)) - mask = real_pairs(numbers, diagonal=False) + mask = utils.real_pairs(numbers, diagonal=False) assert (mask == ref).all() @@ -155,7 +155,7 @@ def test_real_triples_batch() -> None: ], ] ) - mask = real_triples(numbers, diagonal=True) + mask = utils.real_triples(numbers, diagonal=True) assert (mask == ref).all() ref = torch.tensor( @@ -196,5 +196,84 @@ def test_real_triples_batch() -> None: ], ] ) - mask = real_triples(numbers, diagonal=False) + mask = utils.real_triples(numbers, diagonal=False) + assert (mask == ref).all() + + +def test_real_triples_self_single() -> None: + numbers = torch.tensor([8, 1, 1]) # H2O + + ref = torch.tensor( + [ + [ + [False, False, False], + [False, False, True], + [False, True, False], + ], + [ + [False, False, True], + [False, False, False], + [True, False, False], + ], + [ + [False, True, False], + [True, False, False], + [False, False, False], + ], + ], + 
dtype=torch.bool, + ) + + mask = utils.real_triples(numbers, self=False) + assert (mask == ref).all() + + +def test_real_triples_self_batch() -> None: + numbers = torch.tensor( + [ + [1, 1, 0], # H2 + [8, 1, 1], # H2O + ], + ) + + ref = torch.tensor( + [ + [ + [ + [False, False, False], + [False, False, False], + [False, False, False], + ], + [ + [False, False, False], + [False, False, False], + [False, False, False], + ], + [ + [False, False, False], + [False, False, False], + [False, False, False], + ], + ], + [ + [ + [False, False, False], + [False, False, True], + [False, True, False], + ], + [ + [False, False, True], + [False, False, False], + [True, False, False], + ], + [ + [False, True, False], + [True, False, False], + [False, False, False], + ], + ], + ] + ) + + mask = utils.real_triples(numbers, self=False) assert (mask == ref).all() diff --git a/test/utils.py b/test/utils.py index 2066969..429c2b5 100644 --- a/test/utils.py +++ b/test/utils.py @@ -21,28 +21,51 @@ from __future__ import annotations import torch +from torch.autograd.gradcheck import gradcheck, gradgradcheck -from tad_dftd4._typing import Tensor +from tad_dftd4._typing import Any, Callable, Protocol, Size, Tensor, TensorOrTensors +from .conftest import FAST_MODE -def reshape_fortran(x: Tensor, shape: torch.Size | tuple): - if len(x.shape) > 0: - x = x.permute(*reversed(range(len(x.shape)))) - return x.reshape(*reversed(shape)).permute(*reversed(range(len(shape)))) + +def merge_nested_dicts(a: dict[str, dict], b: dict[str, dict]) -> dict: # type: ignore[type-arg] + """ + Merge nested dictionaries. dictionary `a` remains unaltered, while + the corresponding keys of it are added to `b`. + + Parameters + ---------- + a : dict + First dictionary (not changed). + b : dict + Second dictionary (changed). + + Returns + ------- + dict + Merged dictionary `b`. + """ + for key in b: + if key in a: + b[key].update(a[key]) + return b def get_device_from_str(s: str) -> torch.device: """ Convert device name to `torch.device`. Critically, this also sets the index for CUDA devices to `torch.cuda.current_device()`. + Parameters ---------- s : str Name of the device as string. + Returns ------- torch.device Device as torch class. + Raises ------ KeyError @@ -57,3 +80,141 @@ def get_device_from_str(s: str) -> torch.device: raise KeyError(f"Unknown device '{s}' given.") return d[s] + + +def reshape_fortran(x: Tensor, shape: Size) -> Tensor: + """ + Implements Fortran's `reshape` function (column-major). + + Parameters + ---------- + x : Tensor + Input tensor + shape : Size + Output size to which `x` is reshaped. + + Returns + ------- + Tensor + Reshaped tensor of size `shape`. + """ + if len(x.shape) > 0: + x = x.permute(*reversed(range(len(x.shape)))) + return x.reshape(*reversed(shape)).permute(*reversed(range(len(shape)))) + + +class _GradcheckFunction(Protocol): + """ + Type annotation for gradcheck function. + """ + + def __call__( # type: ignore + self, + func: Callable[..., TensorOrTensors], + inputs: TensorOrTensors, + *, + eps: float = 1e-6, + atol: float = 1e-5, + rtol: float = 1e-3, + raise_exception: bool = True, + check_sparse_nnz: bool = False, + nondet_tol: float = 0.0, + check_undefined_grad: bool = True, + check_grad_dtypes: bool = False, + check_batched_grad: bool = False, + check_batched_forward_grad: bool = False, + check_forward_ad: bool = False, + check_backward_ad: bool = True, + fast_mode: bool = False, + ) -> bool: + ... 
+ + +class _GradgradcheckFunction(Protocol): + """ + Type annotation for gradgradcheck function. + """ + + def __call__( # type: ignore + self, + func: Callable[..., TensorOrTensors], + inputs: TensorOrTensors, + grad_outputs: TensorOrTensors | None = None, + *, + eps: float = 1e-6, + atol: float = 1e-5, + rtol: float = 1e-3, + gen_non_contig_grad_outputs: bool = False, + raise_exception: bool = True, + nondet_tol: float = 0.0, + check_undefined_grad: bool = True, + check_grad_dtypes: bool = False, + check_batched_grad: bool = False, + check_fwd_over_rev: bool = False, + check_rev_over_rev: bool = True, + fast_mode: bool = False, + ) -> bool: + ... + + +def _wrap_gradcheck( + gradcheck_func: _GradcheckFunction | _GradgradcheckFunction, + func: Callable[..., TensorOrTensors], + diffvars: TensorOrTensors, + **kwargs: Any, +) -> bool: + fast_mode = kwargs.pop("fast_mode", FAST_MODE) + try: + assert gradcheck_func(func, diffvars, fast_mode=fast_mode, **kwargs) + finally: + if isinstance(diffvars, Tensor): + diffvars.detach_() + else: + for diffvar in diffvars: + diffvar.detach_() + + return True + + +def dgradcheck( + func: Callable[..., TensorOrTensors], diffvars: TensorOrTensors, **kwargs: Any +) -> bool: + """ + Wrapper for `torch.autograd.gradcheck` that detaches the differentiated + variables after the check. + + Parameters + ---------- + func : Callable[..., TensorOrTensors] + Forward function. + diffvars : TensorOrTensors + Variables w.r.t. which we differentiate. + + Returns + ------- + bool + Status of check. + """ + return _wrap_gradcheck(gradcheck, func, diffvars, **kwargs) + + +def dgradgradcheck( + func: Callable[..., TensorOrTensors], diffvars: TensorOrTensors, **kwargs: Any +) -> bool: + """ + Wrapper for `torch.autograd.gradgradcheck` that detaches the differentiated + variables after the check. + + Parameters + ---------- + func : Callable[..., TensorOrTensors] + Forward function. + diffvars : TensorOrTensors + Variables w.r.t. which we differentiate. + + Returns + ------- + bool + Status of check. + """ + return _wrap_gradcheck(gradgradcheck, func, diffvars, **kwargs)
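The `_wrap_gradcheck` helper above asserts the underlying check with `fast_mode` taken from `FAST_MODE` in `conftest.py` and then detaches the differentiated variables in a `finally` block, so a test can only reuse the same tensor after re-enabling gradients. A minimal usage sketch; `_energy_like`, the tensor shape, and the import path are invented for illustration and do not correspond to any function in the suite:

```
from __future__ import annotations

import torch

from tad_dftd4._typing import Tensor

# dgradcheck/dgradgradcheck as defined in test/utils.py above; the relative
# import assumes a test module one directory below test/ (an assumption).
from ..utils import dgradcheck, dgradgradcheck


def _energy_like(positions: Tensor) -> Tensor:
    # stand-in for a dispersion energy: any smooth scalar function suffices
    return (positions * positions).sum()


def test_grad_sketch() -> None:
    # gradcheck needs double precision and inputs that require gradients
    positions = torch.rand((4, 3), dtype=torch.double, requires_grad=True)
    assert dgradcheck(_energy_like, positions)

    # dgradcheck detached `positions` in place, so gradients must be
    # re-enabled before the second-derivative check
    positions.requires_grad_(True)
    assert dgradgradcheck(_energy_like, positions)
```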