Skip to content

Commit

Permalink
Bug fix (#19)
Browse files Browse the repository at this point in the history
- fix package_data in setup.cfg (v0.1.1 does not package non-Python data, omitting the npy file!)
- the C6 coefficients show minor deviations because of floating point inaccuracies and small values
- improve typing to fix mypy errors
- rigorous gradient tests (includes some `nan` fixes)
- hessian test
  • Loading branch information
marvinfriede authored May 22, 2023
1 parent d93d17a commit 22df391
Show file tree
Hide file tree
Showing 27 changed files with 5,773 additions and 974 deletions.
4 changes: 2 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,6 @@ dev =
tox

[options.package_data]
dxtb =
tad_dftd3 =
py.typed
src/*/*.npy
*.npy
12 changes: 9 additions & 3 deletions src/tad_dftd3/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@

from . import damping, data, disp, model, ncoord, reference, util
from .typing import (
DD,
CountingFunction,
DampingFunction,
Dict,
Expand All @@ -93,6 +94,7 @@ def dftd3(
rcov: Optional[Tensor] = None,
rvdw: Optional[Tensor] = None,
r4r2: Optional[Tensor] = None,
cutoff: Optional[Tensor] = None,
counting_function: CountingFunction = ncoord.exp_count,
weighting_function: WeightingFunction = model.gaussian_weight,
damping_function: DampingFunction = damping.rational_damping,
Expand Down Expand Up @@ -125,12 +127,15 @@ def dftd3(
Returns
-------
torch.Tensor
DFT-D3 dispersion energy for each geometry.
Tensor
Atom-resolved DFT-D3 dispersion energy for each geometry.
"""
dd: DD = {"device": positions.device, "dtype": positions.dtype}

if cutoff is None:
cutoff = torch.tensor(50.0, **dd)
if ref is None:
ref = reference.Reference().type(positions.dtype).to(positions.device)
ref = reference.Reference(**dd)
if rcov is None:
rcov = data.covalent_rad_d3[numbers].type(positions.dtype).to(positions.device)
if rvdw is None:
Expand All @@ -155,6 +160,7 @@ def dftd3(
rvdw,
r4r2,
damping_function,
cutoff=cutoff,
)

return energy
33 changes: 18 additions & 15 deletions src/tad_dftd3/damping/atm.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
import torch

from .. import defaults
from ..typing import Tensor
from ..util import real_pairs, real_triples
from ..typing import DD, Tensor
from ..util import cdist, real_pairs, real_triples


def dispersion_atm(
Expand Down Expand Up @@ -60,7 +60,7 @@ def dispersion_atm(
Tensor
Atom-resolved ATM dispersion energy.
"""
dd = {"device": positions.device, "dtype": positions.dtype}
dd: DD = {"device": positions.device, "dtype": positions.dtype}

s9 = s9.type(positions.dtype).to(positions.device)
rs9 = rs9.type(positions.dtype).to(positions.device)
Expand All @@ -69,9 +69,16 @@ def dispersion_atm(
cutoff2 = cutoff * cutoff
srvdw = rs9 * rvdw

mask_pairs = real_pairs(numbers, diagonal=False)
mask_triples = real_triples(numbers, diagonal=False)

eps = torch.tensor(torch.finfo(positions.dtype).eps, **dd)

# C9_ABC = s9 * sqrt(|C6_AB * C6_AC * C6_BC|)
c9 = s9 * torch.sqrt(
torch.abs(c6.unsqueeze(-1) * c6.unsqueeze(-2) * c6.unsqueeze(-3))
torch.clamp(
torch.abs(c6.unsqueeze(-1) * c6.unsqueeze(-2) * c6.unsqueeze(-3)), min=eps
)
)

r0ij = srvdw.unsqueeze(-1)
Expand All @@ -83,11 +90,9 @@ def dispersion_atm(
# very slow: (pos.unsqueeze(-2) - pos.unsqueeze(-3)).pow(2).sum(-1)
distances = torch.pow(
torch.where(
real_pairs(numbers, diagonal=False),
torch.cdist(
positions, positions, p=2, compute_mode="use_mm_for_euclid_dist"
),
torch.tensor(torch.finfo(positions.dtype).eps, **dd),
mask_pairs,
cdist(positions, positions, p=2),
eps,
),
2.0,
)
Expand All @@ -97,17 +102,15 @@ def dispersion_atm(
r2jk = distances.unsqueeze(-3)
r2 = r2ij * r2ik * r2jk
r1 = torch.sqrt(r2)
r3 = r1 * r2
r5 = r2 * r3
# add epsilon to avoid zero division later
r3 = torch.where(mask_triples, r1 * r2, eps)
r5 = torch.where(mask_triples, r2 * r3, eps)

fdamp = 1.0 / (1.0 + 6.0 * (r0 / r1) ** ((alp + 2.0) / 3.0))

s = (r2ij + r2jk - r2ik) * (r2ij - r2jk + r2ik) * (-r2ij + r2jk + r2ik)
ang = torch.where(
real_triples(numbers, diagonal=False)
* (r2ij <= cutoff2)
* (r2jk <= cutoff2)
* (r2jk <= cutoff2),
mask_triples * (r2ij <= cutoff2) * (r2jk <= cutoff2) * (r2jk <= cutoff2),
0.375 * s / r5 + 1.0 / r3,
torch.tensor(0.0, **dd),
)
Expand Down
4 changes: 2 additions & 2 deletions src/tad_dftd3/damping/rational.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import torch

from .. import defaults
from ..typing import Dict, Tensor
from ..typing import DD, Dict, Tensor


def rational_damping(
Expand Down Expand Up @@ -43,7 +43,7 @@ def rational_damping(
Tensor
Values of the damping function.
"""
dd = {"device": distances.device, "dtype": distances.dtype}
dd: DD = {"device": distances.device, "dtype": distances.dtype}

a1 = param.get("a1", torch.tensor(defaults.A1, **dd))
a2 = param.get("a2", torch.tensor(defaults.A2, **dd))
Expand Down
17 changes: 11 additions & 6 deletions src/tad_dftd3/disp.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@

from . import data, defaults
from .damping import dispersion_atm, rational_damping
from .typing import Any, DampingFunction, Dict, Optional, Tensor
from .util import real_pairs
from .typing import DD, Any, DampingFunction, Dict, Optional, Tensor
from .util import cdist, real_pairs


def dispersion(
Expand Down Expand Up @@ -91,8 +91,13 @@ def dispersion(
damping_function : Callable
Damping function evaluate distance dependent contributions.
Additional arguments are passed through to the function.
Returns
-------
Tensor
Atom-resolved DFT-D3 dispersion energy for each geometry.
"""
dd = {"device": positions.device, "dtype": positions.dtype}
dd: DD = {"device": positions.device, "dtype": positions.dtype}

if cutoff is None:
cutoff = torch.tensor(50.0, **dd)
Expand Down Expand Up @@ -157,12 +162,12 @@ def dispersion2(
Damping function evaluate distance dependent contributions.
Additional arguments are passed through to the function.
"""
dd = {"device": positions.device, "dtype": positions.dtype}
dd: DD = {"device": positions.device, "dtype": positions.dtype}

mask = real_pairs(numbers, diagonal=False)
distances = torch.where(
mask,
torch.cdist(positions, positions, p=2, compute_mode="use_mm_for_euclid_dist"),
cdist(positions, positions, p=2),
torch.tensor(torch.finfo(positions.dtype).eps, **dd),
)

Expand Down Expand Up @@ -224,7 +229,7 @@ def dispersion3(
Tensor
Atom-resolved three-body dispersion energy.
"""
dd = {"device": positions.device, "dtype": positions.dtype}
dd: DD = {"device": positions.device, "dtype": positions.dtype}

alp = param.get("alp", torch.tensor(14.0, **dd))
s9 = param.get("s9", torch.tensor(1.0, **dd))
Expand Down
34 changes: 30 additions & 4 deletions src/tad_dftd3/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@

from .reference import Reference
from .typing import Any, Tensor, WeightingFunction
from .util import real_atoms


def atomic_c6(
Expand Down Expand Up @@ -101,7 +102,6 @@ def weight_references(
cn: Tensor,
reference: Reference,
weighting_function: WeightingFunction = gaussian_weight,
epsilon: float = 1.0e-20,
**kwargs: Any,
) -> Tensor:
"""
Expand All @@ -126,11 +126,37 @@ def weight_references(

mask = reference.cn[numbers] >= 0

# Due to the exponentiation, `norms` and `weights` may become very small.
# This may cause problems for the division by `norms`. It may occur that
# `weights` and `norms` are equal, in which case the result should be
# exactly one. This might, however, not be the case and ultimately cause
# larger deviations in the final values.
#
# If the values become even smaller, we may have to evaluate this portion
# in double precision to retain the correct results. This must be done in
# the D4 variant because the weighting functions contains higher powers,
# which lead to values down to 1e-300.
dcn = reference.cn[numbers] - cn.unsqueeze(-1)
weights = torch.where(
mask,
weighting_function(reference.cn[numbers] - cn.unsqueeze(-1), **kwargs),
torch.tensor(0.0, device=cn.device, dtype=cn.dtype),
weighting_function(dcn, **kwargs),
torch.tensor(0.0, device=dcn.device, dtype=dcn.dtype), # not eps!
)
norms = torch.add(torch.sum(weights, dim=-1), epsilon)

# Nevertheless, we must avoid zero division here in batched calculations.
#
# Previously, a small value was added to `norms` to prevent division by zero
# (`norms = torch.add(torch.sum(weights, dim=-1), 1e-20)`). However, even
# such small values can lead to relatively large deviations because the
# small value is not added to the weights, and hence, the case where
# `weights` and `norms` are equal does not yield one anymore. In fact, the
# test suite fails because some elements deviate up to around 1e-4.
#
# We solve this issue by using a mask from the atoms and only add a small
# value, where the actual padding zeros are.
norms = torch.where(
real_atoms(numbers),
torch.sum(weights, dim=-1),
torch.tensor(torch.finfo(dcn.dtype).eps, device=cn.device, dtype=dcn.dtype),
)
return weights / norms.unsqueeze(-1)
8 changes: 4 additions & 4 deletions src/tad_dftd3/ncoord.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@
import torch

from . import data
from .typing import Any, CountingFunction, Optional, Tensor
from .util import real_pairs
from .typing import DD, Any, CountingFunction, Optional, Tensor
from .util import cdist, real_pairs


def exp_count(r: Tensor, r0: Tensor, kcn: float = 16.0) -> Tensor:
Expand Down Expand Up @@ -116,7 +116,7 @@ def coordination_number(
-------
Tensor: The coordination number of each atom in the system.
"""
dd = {"device": positions.device, "dtype": positions.dtype}
dd: DD = {"device": positions.device, "dtype": positions.dtype}

if cutoff is None:
cutoff = torch.tensor(25.0, **dd)
Expand All @@ -132,7 +132,7 @@ def coordination_number(
mask = real_pairs(numbers, diagonal=False)
distances = torch.where(
mask,
torch.cdist(positions, positions, p=2, compute_mode="use_mm_for_euclid_dist"),
cdist(positions, positions, p=2),
torch.tensor(torch.finfo(positions.dtype).eps, **dd),
)

Expand Down
35 changes: 25 additions & 10 deletions src/tad_dftd3/reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@
from .typing import Any, NoReturn, Optional, Tensor


def _load_cn(dtype: torch.dtype = torch.float) -> Tensor:
def _load_cn(
dtype: torch.dtype = torch.float, device: Optional[torch.device] = None
) -> Tensor:
return torch.tensor(
[
[-1.0000, -1.0000, -1.0000, -1.0000, -1.0000], # None
Expand Down Expand Up @@ -124,11 +126,15 @@ def _load_cn(dtype: torch.dtype = torch.float) -> Tensor:
[+0.0000, +2.8878, -1.0000, -1.0000, -1.0000], # U
[+0.0000, +2.9095, -1.0000, -1.0000, -1.0000], # Np
[+0.0000, +1.9209, -1.0000, -1.0000, -1.0000], # Pu
]
).type(dtype)
],
device=device,
dtype=dtype,
)


def _load_c6(dtype: torch.dtype = torch.float) -> Tensor:
def _load_c6(
dtype: torch.dtype = torch.float, device: Optional[torch.device] = None
) -> Tensor:
"""
Load reference C6 coefficients from file and fill them into a tensor
"""
Expand All @@ -138,13 +144,14 @@ def _load_c6(dtype: torch.dtype = torch.float) -> Tensor:

import numpy as np

ref = torch.from_numpy(
np.load(op.join(op.dirname(__file__), "reference-c6.npy"))
).type(dtype)
path = op.join(op.dirname(__file__), "reference-c6.npy")
ref = torch.from_numpy(np.load(path)).type(dtype).to(device)

n_element = (math.isqrt(8 * ref.shape[0] + 1) - 1) // 2 + 1
n_reference = ref.shape[-1]
c6 = torch.zeros((n_element, n_element, n_reference, n_reference), dtype=dtype)
c6 = torch.zeros(
(n_element, n_element, n_reference, n_reference), dtype=dtype, device=device
)

for i in range(1, n_element):
for j in range(1, n_element):
Expand Down Expand Up @@ -176,12 +183,20 @@ def __init__(
self,
cn: Optional[Tensor] = None,
c6: Optional[Tensor] = None,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
if cn is None:
cn = _load_cn()
cn = _load_cn(
dtype if dtype is not None else torch.float,
device=device,
)
self.cn = cn
if c6 is None:
c6 = _load_c6()
c6 = _load_c6(
dtype if dtype is not None else torch.float,
device=device,
)
self.c6 = c6

self.__dtype = self.c6.dtype
Expand Down
10 changes: 10 additions & 0 deletions src/tad_dftd3/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,13 @@ class Molecule(TypedDict):

positions: Tensor
"""Tensor of 3D coordinates of shape (n, 3)"""


class DD(TypedDict):
"""Collection of torch.device and torch.dtype."""

device: Union[torch.device, None]
"""Device on which a tensor lives."""

dtype: torch.dtype
"""Floating point precision of a tensor."""
Loading

0 comments on commit 22df391

Please sign in to comment.