Skip to content

Commit

Permalink
Refactor and GPU tests
Browse files Browse the repository at this point in the history
  • Loading branch information
marvinfriede committed Nov 26, 2023
1 parent 3ce3dd8 commit 346af9a
Show file tree
Hide file tree
Showing 37 changed files with 1,305 additions and 433 deletions.
1 change: 0 additions & 1 deletion src/tad_dftd4/__version__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,4 @@
"""
Module containing the version string.
"""

__version__ = "0.0.4"
34 changes: 26 additions & 8 deletions src/tad_dftd4/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,11 @@
elif sys.version_info >= (3, 8):
# in Python 3.8, "from __future__ import annotations" only affects
# type annotations not type aliases
from typing import Dict, List, Tuple, Union
from typing import Dict, List, Union, tuple

Sliceable = Union[List[Tensor], Tuple[Tensor, ...]]
Size = Union[List[int], Tuple[int], torch.Size]
TensorOrTensors = Union[List[Tensor], Tuple[Tensor, ...], Tensor]
Sliceable = Union[List[Tensor], tuple[Tensor, ...]]
Size = Union[List[int], tuple[int], torch.Size]
TensorOrTensors = Union[List[Tensor], tuple[Tensor, ...], Tensor]
DampingFunction = Callable[[int, Tensor, Tensor, Dict[str, Tensor]], Tensor]
else:
raise RuntimeError(
Expand All @@ -107,6 +107,26 @@ class Molecule(TypedDict):
"""Tensor of 3D coordinates of shape (n, 3)"""


class DD(TypedDict):
"""Collection of torch.device and torch.dtype."""

device: torch.device | None
"""Device on which a tensor lives."""

dtype: torch.dtype
"""Floating point precision of a tensor."""


def get_default_device() -> torch.device:
"""Default device for tensors."""
return torch.tensor(1.0).device


def get_default_dtype() -> torch.dtype:
"""Default data type for floating point tensors."""
return torch.tensor(1.0).dtype


class TensorLike:
"""
Provide `device` and `dtype` as well as `to()` and `type()` for other
Expand All @@ -118,10 +138,8 @@ class TensorLike:
def __init__(
self, device: torch.device | None = None, dtype: torch.dtype | None = None
):
self.__device = (
device if device is not None else torch.device(defaults.TORCH_DEVICE)
)
self.__dtype = dtype if dtype is not None else defaults.TORCH_DTYPE
self.__device = device if device is not None else get_default_device()
self.__dtype = dtype if dtype is not None else get_default_dtype()

@property
def device(self) -> torch.device:
Expand Down
34 changes: 22 additions & 12 deletions src/tad_dftd4/charges.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@

import torch

from ._typing import Tensor, TensorLike
from .ncoord import get_coordination_number_eeq
from ._typing import DD, Tensor, TensorLike
from .ncoord import coordination_number_eeq
from .utils import real_atoms, real_pairs

__all__ = ["ChargeModel", "solve", "get_charges"]
Expand Down Expand Up @@ -93,6 +93,8 @@ def __init__(
self.eta = eta
self.rad = rad

print(self.device)
print(self.chi.device)
if any(
tensor.device != self.device
for tensor in (self.chi, self.kcn, self.eta, self.rad)
Expand All @@ -106,20 +108,28 @@ def __init__(
raise RuntimeError("All tensors must have the same dtype!")

@classmethod
def param2019(cls) -> ChargeModel:
def param2019(
cls,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
) -> ChargeModel:
"""
Electronegativity equilibration charge model published in
- E. Caldeweyher, S. Ehlert, A. Hansen, H. Neugebauer, S. Spicher, C. Bannwarth
and S. Grimme, *J. Chem. Phys.*, **2019**, 150, 154122.
- E. Caldeweyher, S. Ehlert, A. Hansen, H. Neugebauer, S. Spicher,
C. Bannwarth and S. Grimme, *J. Chem. Phys.*, **2019**, 150, 154122.
DOI: `10.1063/1.5090222 <https://dx.doi.org/10.1063/1.5090222>`__
"""
dd: dict = {"device": device}
if dtype is not None:
dd["dtype"] = dtype

return cls(
_chi2019,
_kcn2019,
_eta2019,
_rad2019,
_chi2019.to(**dd),
_kcn2019.to(**dd),
_eta2019.to(**dd),
_rad2019.to(**dd),
**dd,
)


Expand Down Expand Up @@ -247,7 +257,7 @@ def solve(
Returns
-------
(Tensor, Tensor)
Tuple of electrostatic energies and partial charges.
tuple of electrostatic energies and partial charges.
Example
-------
Expand Down Expand Up @@ -371,8 +381,8 @@ def get_charges(
Tensor
Atomic charges.
"""
eeq = ChargeModel.param2019().to(positions.device).type(positions.dtype)
cn = get_coordination_number_eeq(numbers, positions, cutoff=cutoff)
eeq = ChargeModel.param2019(device=positions.device, dtype=positions.dtype)
cn = coordination_number_eeq(numbers, positions, cutoff=cutoff)
_, qat = solve(numbers, positions, chrg, eeq, cn)

return qat
58 changes: 39 additions & 19 deletions src/tad_dftd4/damping/atm.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
from .. import defaults
from .._typing import Tensor
from ..data import r4r2
from ..utils import real_pairs, real_triples
from ..utils import cdist, real_pairs, real_triples


def get_atm_dispersion(
Expand Down Expand Up @@ -89,15 +89,24 @@ def get_atm_dispersion(

cutoff2 = cutoff * cutoff

mask_pairs = real_pairs(numbers, diagonal=False)
mask_triples = real_triples(numbers, diagonal=False, self=False)

# filler values for masks
eps = torch.tensor(torch.finfo(positions.dtype).eps, **dd)
zero = torch.tensor(0.0, **dd)
one = torch.tensor(1.0, **dd)

# C9_ABC = s9 * sqrt(|C6_AB * C6_AC * C6_BC|)
c9 = s9 * torch.sqrt(
torch.abs(c6.unsqueeze(-1) * c6.unsqueeze(-2) * c6.unsqueeze(-3))
torch.clamp(
torch.abs(c6.unsqueeze(-1) * c6.unsqueeze(-2) * c6.unsqueeze(-3)), min=eps
)
)

temp = (
a1 * torch.sqrt(3.0 * r4r2[numbers].unsqueeze(-1) * r4r2[numbers].unsqueeze(-2))
+ a2
)
radii = r4r2[numbers].unsqueeze(-1) * r4r2[numbers].unsqueeze(-2)
temp = a1 * torch.sqrt(3.0 * radii) + a2

r0ij = temp.unsqueeze(-1)
r0ik = temp.unsqueeze(-2)
r0jk = temp.unsqueeze(-3)
Expand All @@ -107,11 +116,9 @@ def get_atm_dispersion(
# very slow: (pos.unsqueeze(-2) - pos.unsqueeze(-3)).pow(2).sum(-1)
distances = torch.pow(
torch.where(
real_pairs(numbers, diagonal=False),
torch.cdist(
positions, positions, p=2, compute_mode="use_mm_for_euclid_dist"
),
torch.tensor(torch.finfo(positions.dtype).eps, **dd),
mask_pairs,
cdist(positions, positions, p=2),
eps,
),
2.0,
)
Expand All @@ -121,17 +128,30 @@ def get_atm_dispersion(
r2jk = distances.unsqueeze(-3)
r2 = r2ij * r2ik * r2jk
r1 = torch.sqrt(r2)
r3 = r1 * r2
r5 = r2 * r3
# add epsilon to avoid zero division later
r3 = torch.where(mask_triples, r1 * r2, eps)
r5 = torch.where(mask_triples, r2 * r3, eps)

# dividing by tiny numbers leads to huge numbers, which result in NaN's
# upon exponentiation in the subsequent step
base = r0 / torch.where(mask_triples, r1, one)

# to fix the previous mask, we mask again (not strictly necessary because
# `ang` is also masked and we later multiply with `ang`)
fdamp = torch.where(
mask_triples,
1.0 / (1.0 + 6.0 * base ** (alp / 3.0)),
zero,
)

fdamp = 1.0 / (1.0 + 6.0 * (r0 / r1) ** (alp / 3.0))
s = torch.where(
mask_triples,
(r2ij + r2jk - r2ik) * (r2ij - r2jk + r2ik) * (-r2ij + r2jk + r2ik),
zero,
)

s = (r2ij + r2jk - r2ik) * (r2ij - r2jk + r2ik) * (-r2ij + r2jk + r2ik)
ang = torch.where(
real_triples(numbers, diagonal=False)
* (r2ij <= cutoff2)
* (r2jk <= cutoff2)
* (r2jk <= cutoff2),
mask_triples * (r2ij <= cutoff2) * (r2jk <= cutoff2) * (r2jk <= cutoff2),
0.375 * s / r5 + 1.0 / r3,
torch.tensor(0.0, **dd),
)
Expand Down
14 changes: 1 addition & 13 deletions src/tad_dftd4/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,6 @@

import torch

# PyTorch

TORCH_DTYPE = torch.float32
"""Default data type for floating point tensors."""

TORCH_DTYPE_CHOICES = ["float32", "float64", "double", "sp", "dp"]
"""List of possible choices for `TORCH_DTYPE`."""

TORCH_DEVICE = "cpu"
"""Default device for tensors."""


# DFT-D4

D4_CN_CUTOFF = 30.0
Expand Down Expand Up @@ -48,7 +36,7 @@
D4_K6 = 2 * 11.28174**2 # 254.56
"""Parameter for electronegativity scaling."""

# DFT-D4 damping
# DFT-D4 damping parameters

A1 = 0.4
"""Scaling for the C8 / C6 ratio in the critical radius (0.4)."""
Expand Down
20 changes: 11 additions & 9 deletions src/tad_dftd4/disp.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,13 @@
import torch

from . import data, defaults
from ._typing import Any, CountingFunction, DampingFunction, Tensor
from ._typing import DD, Any, CountingFunction, DampingFunction, Tensor
from .charges import get_charges
from .cutoff import Cutoff
from .damping import get_atm_dispersion, rational_damping
from .model import D4Model
from .ncoord import erf_count, get_coordination_number_d4
from .utils import real_pairs
from .ncoord import coordination_number_d4, erf_count
from .utils import cdist, real_pairs


def dftd4(
Expand Down Expand Up @@ -98,15 +98,17 @@ def dftd4(
Shape inconsistencies between `numbers`, `positions`, `r4r2`, or,
`rcov`.
"""
dd: DD = {"device": positions.device, "dtype": positions.dtype}

if model is None:
model = D4Model(numbers, device=positions.device, dtype=positions.dtype)
model = D4Model(numbers, **dd)
if cutoff is None:
cutoff = Cutoff(device=positions.device, dtype=positions.dtype)
cutoff = Cutoff(**dd)

if rcov is None:
rcov = data.cov_rad_d3[numbers].type(positions.dtype).to(positions.device)
rcov = data.cov_rad_d3.to(**dd)[numbers]
if r4r2 is None:
r4r2 = data.r4r2[numbers].type(positions.dtype).to(positions.device)
r4r2 = data.r4r2.to(**dd)[numbers]
if q is None:
q = get_charges(numbers, positions, charge, cutoff=cutoff.cn_eeq)

Expand All @@ -131,7 +133,7 @@ def dftd4(
f"atomic numbers ({numbers.shape}).",
)

cn = get_coordination_number_d4(
cn = coordination_number_d4(
numbers, positions, counting_function, rcov, cutoff=cutoff.cn
)
weights = model.weight_references(cn, q)
Expand Down Expand Up @@ -204,7 +206,7 @@ def dispersion2(
mask = real_pairs(numbers, diagonal=False)
distances = torch.where(
mask,
torch.cdist(positions, positions, p=2, compute_mode="use_mm_for_euclid_dist"),
cdist(positions, positions, p=2),
torch.tensor(torch.finfo(positions.dtype).eps, **dd),
)

Expand Down
7 changes: 3 additions & 4 deletions src/tad_dftd4/ncoord/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,12 @@
... ))
>>>
>>> torch.set_printoptions(precision=7)
>>> print(d4.get_coordination_number_d4(numbers, positions))
>>> print(d4.coordination_number_d4(numbers, positions))
tensor([[2.6886456, 2.6886456, 2.6314170, 2.6314168, 0.8594539, 0.9231414,
0.8605307, 0.8605307, 0.8594539, 0.9231414, 0.8568342, 0.8568342],
[2.6886456, 0.8568335, 2.6314168, 0.8605307, 0.8594532, 0.9231415,
0.0000000, 0.0000000, 0.0000000, 0.0000000, 0.0000000, 0.0000000]])
"""

from .count import derf_count, dexp_count, erf_count, exp_count
from .d4 import get_coordination_number_d4
from .eeq import get_coordination_number_eeq
from .d4 import coordination_number_d4
from .eeq import coordination_number_eeq
3 changes: 2 additions & 1 deletion src/tad_dftd4/ncoord/count.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,15 @@
Additionally, the analytical derivatives for both counting functions are also
provided and can be used for checking the autograd results.
"""

from math import pi, sqrt

import torch

from .. import defaults
from .._typing import Tensor

__all__ = ["exp_count", "dexp_count", "erf_count", "derf_count"]


def exp_count(r: Tensor, r0: Tensor, kcn: float = defaults.D4_KCN) -> Tensor:
"""
Expand Down
4 changes: 2 additions & 2 deletions src/tad_dftd4/ncoord/d4.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@
from ..utils import real_pairs
from .count import erf_count

__all__ = ["get_coordination_number_d4"]
__all__ = ["coordination_number_d4"]


def get_coordination_number_d4(
def coordination_number_d4(
numbers: Tensor,
positions: Tensor,
counting_function: CountingFunction = erf_count,
Expand Down
4 changes: 2 additions & 2 deletions src/tad_dftd4/ncoord/eeq.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@
from ..utils import real_pairs
from .count import erf_count

__all__ = ["get_coordination_number_eeq"]
__all__ = ["coordination_number_eeq"]


def get_coordination_number_eeq(
def coordination_number_eeq(
numbers: Tensor,
positions: Tensor,
counting_function: CountingFunction = erf_count,
Expand Down
Loading

0 comments on commit 346af9a

Please sign in to comment.