From ee1598c774b278b1cd9aeda8b1d7b8fbc531ac08 Mon Sep 17 00:00:00 2001
From: marvinfriede <51965259+marvinfriede@users.noreply.github.com>
Date: Thu, 21 Mar 2024 18:40:08 +0100
Subject: [PATCH] Reduce memory consumption

---
 src/tad_dftd4/model.py | 56 +++++++++++++++++++++++++++++-------------
 1 file changed, 39 insertions(+), 17 deletions(-)

diff --git a/src/tad_dftd4/model.py b/src/tad_dftd4/model.py
index a59de2f..3609c4a 100644
--- a/src/tad_dftd4/model.py
+++ b/src/tad_dftd4/model.py
@@ -41,6 +41,7 @@
 from __future__ import annotations
 
 import torch
+from tad_mctc.math import einsum
 
 from . import data, params
 from .typing import Tensor, TensorLike
@@ -177,9 +178,9 @@ def weight_references(
         # Consequently, some values become zero although the actual result
         # should be close to one. The problem does not arise when using `torch.
         # double`. In order to avoid this error, which is also difficult to
-        # detect, this part always uses `torch.double`. `params.refcn` is saved
-        # with `torch.double`, but I still made sure...
-        refcn = params.refcn.to(device=self.device, dtype=torch.double)[self.numbers]
+        # detect, this part always uses `torch.double`. `params.refcovcn` is
+        # saved with `torch.double`, but I still made sure...
+        refcn = params.refcovcn.to(device=self.device, dtype=torch.double)[self.numbers]
 
         # For vectorization, we reformulate the Gaussian weighting function:
         # exp(-wf * igw * (cn - cn_ref)^2) = [exp(-(cn - cn_ref)^2)]^(wf * igw)
@@ -252,27 +253,36 @@ def get_atomic_c6(self, gw: Tensor) -> Tensor:
         Parameters
         ----------
         gw : Tensor
-            Weights for the atomic reference systems.
+            Weights for the atomic reference systems of shape
+            `(..., nat, nref)`.
 
         Returns
         -------
         Tensor
-            C6 coefficients for all atom pairs.
+            C6 coefficients for all atom pairs of shape `(..., nat, nat)`.
         """
+        # (..., nunique, r, 23) -> (..., n, r, 23)
         alpha = self.alpha[self.atom_to_unique]
 
-        # shape of alpha: (b, nat, nref, 23)
-        # (b, 1, nat, 1, nref, 23) * (b, nat, 1, nref, 1, 23) =
-        # (b, nat, nat, nref, nref, 23)
-        rc6 = trapzd(
-            alpha.unsqueeze(-4).unsqueeze(-3) * alpha.unsqueeze(-3).unsqueeze(-2)
-        )
+        # (..., n, r, 23) -> (..., n, n, r, r)
+        rc6 = trapzd(alpha)
 
-        # shape of gw: (batch, natoms, nref)
-        # (b, 1, nat, 1, nref)*(b, nat, 1, nref, 1) = (b, nat, nat, nref, nref)
-        g = gw.unsqueeze(-3).unsqueeze(-2) * gw.unsqueeze(-2).unsqueeze(-1)
+        # The default einsum path is fastest if the largest tensor comes first.
+        # (..., n1, n2, r1, r2) * (..., n1, r1) * (..., n2, r2) -> (..., n1, n2)
+        return einsum(
+            "...ijab,...ia,...jb->...ij",
+            *(rc6, gw, gw),
+            optimize=[(0, 1), (0, 1)],
+        )
 
-        return torch.sum(g * rc6, dim=(-2, -1))
+        # NOTE: This old version creates large intermediate tensors and builds
+        # the full matrix before the sum reduction, requiring a lot of memory.
+        #
+        # (..., 1, n, 1, r) * (..., n, 1, r, 1) = (..., n, n, r, r)
+        # g = gw.unsqueeze(-3).unsqueeze(-2) * gw.unsqueeze(-2).unsqueeze(-1)
+        #
+        # (..., n, n, r, r) * (..., n, n, r, r) -> (..., n, n)
+        # c6 = torch.sum(g * rc6, dim=(-2, -1))
 
     def _zeta(self, gam: Tensor, qref: Tensor, qmod: Tensor) -> Tensor:
         """
@@ -348,7 +358,7 @@ def trapzd(polarizability: Tensor) -> Tensor:
     Parameters
     ----------
     polarizability : Tensor
-        Polarizabilities.
+        Polarizabilities of shape `(..., nat, nref, 23)`.
 
     Returns
     -------
@@ -385,4 +395,16 @@ def trapzd(polarizability: Tensor) -> Tensor:
         ]
     )
 
-    return thopi * torch.sum(weights * polarizability, dim=-1)
+    # NOTE: In the old version, a memory inefficient intermediate tensor was
+    # created. The new version uses `einsum` to avoid this.
+    #
+    # (..., 1, nat, 1, nref, 23) * (..., nat, 1, nref, 1, 23) =
+    # (..., nat, nat, nref, nref, 23) -> (..., nat, nat, nref, nref)
+    # a = alpha.unsqueeze(-4).unsqueeze(-3) * alpha.unsqueeze(-3).unsqueeze(-2)
+    #
+    # rc6 = thopi * torch.sum(weights * a, dim=-1)
+
+    return thopi * einsum(
+        "w,...iaw,...jbw->...ijab",
+        *(weights, polarizability, polarizability),
+    )
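
Standalone sanity check (illustrative sketch, not part of the commit): the
snippet below mimics the refactoring above with made-up sizes and random data
and verifies that the einsum contractions reproduce the old broadcast-and-sum
results. It uses `torch.einsum` directly instead of the `tad_mctc.math.einsum`
wrapper imported by the patch (assumed here to behave the same, apart from the
`optimize` path argument), and it only materializes the rank-5 polarizability
product on the "old" side of the comparison.

    import torch

    nat, nref = 8, 7  # hypothetical sizes; 23 = number of frequency points
    alpha = torch.rand(nat, nref, 23, dtype=torch.double)  # reference polarizabilities
    gw = torch.rand(nat, nref, dtype=torch.double)         # reference weights
    weights = torch.rand(23, dtype=torch.double)           # integration weights
    thopi = 3.0 / 3.141592653589793

    # Old formulation: broadcast to (nat, nat, nref, nref, 23), then reduce.
    a = alpha.unsqueeze(-4).unsqueeze(-3) * alpha.unsqueeze(-3).unsqueeze(-2)
    rc6_old = thopi * torch.sum(weights * a, dim=-1)
    g = gw.unsqueeze(-3).unsqueeze(-2) * gw.unsqueeze(-2).unsqueeze(-1)
    c6_old = torch.sum(g * rc6_old, dim=(-2, -1))

    # New formulation: contract directly with einsum, as in the patch.
    rc6_new = thopi * torch.einsum("w,iaw,jbw->ijab", weights, alpha, alpha)
    c6_new = torch.einsum("ijab,ia,jb->ij", rc6_new, gw, gw)

    assert torch.allclose(c6_old, c6_new)

In the patched `trapzd` and `get_atomic_c6`, the same contractions mean that
neither the `(..., nat, nat, nref, nref, 23)` polarizability product nor the
`(..., nat, nat, nref, nref)` weight product `g` is ever allocated, which is
where the memory savings come from.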