
Commit

Reduce memory consumption
marvinfriede committed Mar 21, 2024
1 parent c81c12b commit ee1598c
Showing 1 changed file with 39 additions and 17 deletions.
56 changes: 39 additions & 17 deletions src/tad_dftd4/model.py
@@ -41,6 +41,7 @@
 from __future__ import annotations
 
 import torch
+from tad_mctc.math import einsum
 
 from . import data, params
 from .typing import Tensor, TensorLike
@@ -177,9 +178,9 @@ def weight_references(
         # Consequently, some values become zero although the actual result
         # should be close to one. The problem does not arise when using `torch.
         # double`. In order to avoid this error, which is also difficult to
-        # detect, this part always uses `torch.double`. `params.refcn` is saved
-        # with `torch.double`, but I still made sure...
-        refcn = params.refcn.to(device=self.device, dtype=torch.double)[self.numbers]
+        # detect, this part always uses `torch.double`. `params.refcovcn` is
+        # saved with `torch.double`, but I still made sure...
+        refcn = params.refcovcn.to(device=self.device, dtype=torch.double)[self.numbers]
 
         # For vectorization, we reformulate the Gaussian weighting function:
         # exp(-wf * igw * (cn - cn_ref)^2) = [exp(-(cn - cn_ref)^2)]^(wf * igw)
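
A minimal sketch of this reformulation, using made-up cn, cn_ref, wf, and igw values rather than the model's parameters; it also shows the float32 underflow that motivates the torch.double cast:

import torch

# Hypothetical values, for illustration only.
cn, cn_ref = torch.tensor(3.0), torch.tensor(2.0)
wf, igw = 6.0, 3.0

direct = torch.exp(-wf * igw * (cn - cn_ref) ** 2)
refactored = torch.exp(-((cn - cn_ref) ** 2)) ** (wf * igw)
print(direct, refactored)  # both ~1.5e-8; the two forms agree

# For large CN differences, exp(-(cn - cn_ref)^2) underflows to zero in single
# precision, so the whole weight vanishes; float64 keeps it representable.
diff = torch.tensor(12.0)
print(torch.exp(-(diff.to(torch.float32) ** 2)))  # 0.0 (underflow)
print(torch.exp(-(diff.to(torch.float64) ** 2)))  # ~2.9e-63
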
@@ -252,27 +253,36 @@ def get_atomic_c6(self, gw: Tensor) -> Tensor:
         Parameters
         ----------
         gw : Tensor
-            Weights for the atomic reference systems.
+            Weights for the atomic reference systems of shape
+            `(..., nat, nref)`.
 
         Returns
         -------
         Tensor
-            C6 coefficients for all atom pairs.
+            C6 coefficients for all atom pairs of shape `(..., nat, nat)`.
         """
+        # (..., nunique, r, 23) -> (..., n, r, 23)
         alpha = self.alpha[self.atom_to_unique]
 
-        # shape of alpha: (b, nat, nref, 23)
-        # (b, 1, nat, 1, nref, 23) * (b, nat, 1, nref, 1, 23) =
-        # (b, nat, nat, nref, nref, 23)
-        rc6 = trapzd(
-            alpha.unsqueeze(-4).unsqueeze(-3) * alpha.unsqueeze(-3).unsqueeze(-2)
-        )
+        # (..., n, r, 23) -> (..., n, n, r, r)
+        rc6 = trapzd(alpha)
 
-        # shape of gw: (batch, natoms, nref)
-        # (b, 1, nat, 1, nref)*(b, nat, 1, nref, 1) = (b, nat, nat, nref, nref)
-        g = gw.unsqueeze(-3).unsqueeze(-2) * gw.unsqueeze(-2).unsqueeze(-1)
+        # The default einsum path is fastest if the large tensors come first.
+        # (..., n1, n2, r1, r2) * (..., n1, r1) * (..., n2, r2) -> (..., n1, n2)
+        return einsum(
+            "...ijab,...ia,...jb->...ij",
+            *(rc6, gw, gw),
+            optimize=[(0, 1), (0, 1)],
+        )
 
-        return torch.sum(g * rc6, dim=(-2, -1))
+        # NOTE: This old version creates large intermediate tensors and builds
+        # the full matrix before the sum reduction, requiring a lot of memory.
+        #
+        # (..., 1, n, 1, r) * (..., n, 1, r, 1) = (..., n, n, r, r)
+        # g = gw.unsqueeze(-3).unsqueeze(-2) * gw.unsqueeze(-2).unsqueeze(-1)
+        #
+        # (..., n, n, r, r) * (..., n, n, r, r) -> (..., n, n)
+        # c6 = torch.sum(g * rc6, dim=(-2, -1))
 
     def _zeta(self, gam: Tensor, qref: Tensor, qmod: Tensor) -> Tensor:
         """
@@ -348,7 +358,7 @@ def trapzd(polarizability: Tensor) -> Tensor:
     Parameters
     ----------
     polarizability : Tensor
-        Polarizabilities.
+        Polarizabilities of shape `(..., nat, nref, 23)`.
 
     Returns
     -------
@@ -385,4 +395,16 @@ def trapzd(polarizability: Tensor) -> Tensor:
         ]
     )
 
-    return thopi * torch.sum(weights * polarizability, dim=-1)
+    # NOTE: In the old version, a memory-inefficient intermediate tensor was
+    # created. The new version uses `einsum` to avoid this.
+    #
+    # (..., 1, nat, 1, nref, 23) * (..., nat, 1, nref, 1, 23) =
+    # (..., nat, nat, nref, nref, 23) -> (..., nat, nat, nref, nref)
+    # a = alpha.unsqueeze(-4).unsqueeze(-3) * alpha.unsqueeze(-3).unsqueeze(-2)
+    #
+    # rc6 = thopi * torch.sum(weights * a, dim=-1)
+
+    return thopi * einsum(
+        "w,...iaw,...jbw->...ijab",
+        *(weights, polarizability, polarizability),
+    )
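
Likewise for trapzd, the einsum over the 23 frequency points can be compared with the old broadcasted product on random stand-ins; the sizes and the thopi prefactor below are placeholders, not the library's data:

import math
import torch

nat, nref, nfreq = 8, 7, 23  # made-up atom/reference counts; 23 frequency points
weights = torch.rand(nfreq, dtype=torch.double)
alpha = torch.rand(nat, nref, nfreq, dtype=torch.double)
thopi = 3.0 / math.pi  # placeholder prefactor; the same factor multiplies both routes

# New route: contract the frequency axis directly.
rc6_new = thopi * torch.einsum("w,...iaw,...jbw->...ijab", weights, alpha, alpha)

# Old route: build the (nat, nat, nref, nref, 23) intermediate before summing.
a = alpha.unsqueeze(-4).unsqueeze(-3) * alpha.unsqueeze(-3).unsqueeze(-2)
rc6_old = thopi * torch.sum(weights * a, dim=-1)

assert torch.allclose(rc6_new, rc6_old)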
