From 63881c4de9b19b378d083cefd98ca9ceda309531 Mon Sep 17 00:00:00 2001 From: marvinfriede <51965259+marvinfriede@users.noreply.github.com> Date: Thu, 21 Mar 2024 18:34:12 +0100 Subject: [PATCH] Clean up --- setup.cfg | 2 +- src/tad_dftd3/model.py | 25 ++++++++++++------------- 2 files changed, 13 insertions(+), 14 deletions(-) diff --git a/setup.cfg b/setup.cfg index 6c72592..0d4a5dc 100644 --- a/setup.cfg +++ b/setup.cfg @@ -24,7 +24,7 @@ project_urls = packages = find: install_requires = numpy - opt_einsum + opt-einsum tad-mctc torch python_requires = >=3.8 diff --git a/src/tad_dftd3/model.py b/src/tad_dftd3/model.py index 2a5dd66..4dde8e8 100644 --- a/src/tad_dftd3/model.py +++ b/src/tad_dftd3/model.py @@ -68,23 +68,22 @@ def atomic_c6(numbers: Tensor, weights: Tensor, reference: Reference) -> Tensor: Atomic dispersion coefficients of shape `(..., nat, nat)`. """ # (..., nel, nel, 7, 7) -> (..., nat, nat, 7, 7) - c6 = reference.c6[numbers.unsqueeze(-1), numbers.unsqueeze(-2)] + rc6 = reference.c6[numbers.unsqueeze(-1), numbers.unsqueeze(-2)] - # This old version creates large intermediate tensors and builds the full - # matrix before the sum reduction, which requires a lot of memory. - # gw = torch.mul( - # weights.unsqueeze(-1).unsqueeze(-3), - # weights.unsqueeze(-2).unsqueeze(-4), - # ) - # return torch.sum(torch.sum(torch.mul(gw, c6), dim=-1), dim=-1) - - # (..., nat, 7) * (..., nat, 7) * (..., nat, nat, 7, 7) -> (..., nat, nat) + # The default einsum path is fastest if the large tensor comes first. + # (..., n1, n2, r1, r2) * (..., n1, r1) * (..., n2, r2) -> (..., n1, n2) return einsum( - "...ia,...jb,...ijab->...ij", - *(weights, weights, c6), - optimize=[(1, 2), (0, 1)], # fastest path + "...ijab,...ia,...jb->...ij", + *(rc6, weights, weights), + optimize=[(0, 1), (0, 1)], ) + # NOTE: This old version creates large intermediate tensors and builds the + # full matrix before the sum reduction, which requires a lot of memory. 
+ # + # gw = w.unsqueeze(-1).unsqueeze(-3) * w.unsqueeze(-2).unsqueeze(-4) + # c6 = torch.sum(torch.sum(torch.mul(gw, rc6), dim=-1), dim=-1) + def gaussian_weight(dcn: Tensor, factor: float = 4.0) -> Tensor: """