From cf991f6fa6a7f49638bdc11e5f79c5774b53f829 Mon Sep 17 00:00:00 2001
From: Marvin Friede <51965259+marvinfriede@users.noreply.github.com>
Date: Thu, 21 Mar 2024 21:08:07 +0100
Subject: [PATCH] Reduce memory consumption (#45)

---
 src/tad_dftd4/model.py | 50 ++++++++++++++++++++++++++++++------------
 1 file changed, 36 insertions(+), 14 deletions(-)

diff --git a/src/tad_dftd4/model.py b/src/tad_dftd4/model.py
index 63b3e4c..3609c4a 100644
--- a/src/tad_dftd4/model.py
+++ b/src/tad_dftd4/model.py
@@ -41,6 +41,7 @@
 from __future__ import annotations
 
 import torch
+from tad_mctc.math import einsum
 
 from . import data, params
 from .typing import Tensor, TensorLike
@@ -252,27 +253,36 @@ def get_atomic_c6(self, gw: Tensor) -> Tensor:
         Parameters
         ----------
         gw : Tensor
-            Weights for the atomic reference systems.
+            Weights for the atomic reference systems of shape
+            `(..., nat, nref)`.
 
         Returns
         -------
         Tensor
-            C6 coefficients for all atom pairs.
+            C6 coefficients for all atom pairs of shape `(..., nat, nat)`.
         """
+        # (..., nunique, r, 23) -> (..., n, r, 23)
         alpha = self.alpha[self.atom_to_unique]
 
-        # shape of alpha: (b, nat, nref, 23)
-        # (b, 1, nat, 1, nref, 23) * (b, nat, 1, nref, 1, 23) =
-        # (b, nat, nat, nref, nref, 23)
-        rc6 = trapzd(
-            alpha.unsqueeze(-4).unsqueeze(-3) * alpha.unsqueeze(-3).unsqueeze(-2)
-        )
+        # (..., n, r, 23) -> (..., n, n, r, r)
+        rc6 = trapzd(alpha)
 
-        # shape of gw: (batch, natoms, nref)
-        # (b, 1, nat, 1, nref)*(b, nat, 1, nref, 1) = (b, nat, nat, nref, nref)
-        g = gw.unsqueeze(-3).unsqueeze(-2) * gw.unsqueeze(-2).unsqueeze(-1)
+        # The default einsum path is fastest if the largest tensor comes first.
+        # (..., n1, n2, r1, r2) * (..., n1, r1) * (..., n2, r2) -> (..., n1, n2)
+        return einsum(
+            "...ijab,...ia,...jb->...ij",
+            *(rc6, gw, gw),
+            optimize=[(0, 1), (0, 1)],
+        )
 
-        return torch.sum(g * rc6, dim=(-2, -1))
+        # NOTE: This old version creates large intermediate tensors and builds
+        # the full matrix before the sum reduction, requiring a lot of memory.
+        #
+        # (..., 1, n, 1, r) * (..., n, 1, r, 1) = (..., n, n, r, r)
+        # g = gw.unsqueeze(-3).unsqueeze(-2) * gw.unsqueeze(-2).unsqueeze(-1)
+        #
+        # (..., n, n, r, r) * (..., n, n, r, r) -> (..., n, n)
+        # c6 = torch.sum(g * rc6, dim=(-2, -1))
 
     def _zeta(self, gam: Tensor, qref: Tensor, qmod: Tensor) -> Tensor:
         """
@@ -348,7 +358,7 @@ def trapzd(polarizability: Tensor) -> Tensor:
     Parameters
     ----------
     polarizability : Tensor
-        Polarizabilities.
+        Polarizabilities of shape `(..., nat, nref, 23)`.
 
     Returns
     -------
     Tensor
@@ -385,4 +395,16 @@ def trapzd(polarizability: Tensor) -> Tensor:
         ]
     )
 
-    return thopi * torch.sum(weights * polarizability, dim=-1)
+    # NOTE: In the old version, a memory-inefficient intermediate tensor was
+    # created. The new version uses `einsum` to avoid this.
+    #
+    # (..., 1, nat, 1, nref, 23) * (..., nat, 1, nref, 1, 23) =
+    # (..., nat, nat, nref, nref, 23) -> (..., nat, nat, nref, nref)
+    # a = alpha.unsqueeze(-4).unsqueeze(-3) * alpha.unsqueeze(-3).unsqueeze(-2)
+    #
+    # rc6 = thopi * torch.sum(weights * a, dim=-1)
+
+    return thopi * einsum(
+        "w,...iaw,...jbw->...ijab",
+        *(weights, polarizability, polarizability),
+    )
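
As a quick sanity check of the contraction this patch introduces (not part of the patch itself), the following standalone sketch compares the old broadcast-and-sum reduction with a direct einsum contraction using plain torch.einsum. The shapes follow the docstrings above (nat atoms, nref references); the sizes and random data are illustrative only.

    # Illustrative sketch; shapes follow the patch docstrings, data is random.
    import torch

    nat, nref = 100, 7
    rc6 = torch.rand(nat, nat, nref, nref)  # pairwise reference C6, as returned by trapzd
    gw = torch.rand(nat, nref)              # weights for the atomic reference systems

    # Old approach: materialize a (nat, nat, nref, nref) weight tensor, then reduce.
    g = gw.unsqueeze(-3).unsqueeze(-2) * gw.unsqueeze(-2).unsqueeze(-1)
    c6_old = torch.sum(g * rc6, dim=(-2, -1))

    # New approach: contract directly; no extra (nat, nat, nref, nref) intermediate
    # beyond rc6 itself.
    c6_new = torch.einsum("ijab,ia,jb->ij", rc6, gw, gw)

    assert torch.allclose(c6_old, c6_new, atol=1e-6)

Both paths compute c6[i, j] = sum_ab gw[i, a] * gw[j, b] * rc6[i, j, a, b]; the einsum form simply avoids building the broadcasted weight tensor, which is where the memory saving comes from.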