From 19b1720fcc0e9a70ab2e518603c7f862c7a76bd0 Mon Sep 17 00:00:00 2001 From: marvinfriede <51965259+marvinfriede@users.noreply.github.com> Date: Wed, 20 Mar 2024 23:07:28 +0100 Subject: [PATCH 1/4] Reduce memory consumption --- src/tad_dftd3/model.py | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/src/tad_dftd3/model.py b/src/tad_dftd3/model.py index b7dabc2..7247cc8 100644 --- a/src/tad_dftd3/model.py +++ b/src/tad_dftd3/model.py @@ -54,24 +54,31 @@ def atomic_c6(numbers: Tensor, weights: Tensor, reference: Reference) -> Tensor: Parameters ---------- numbers : Tensor - The atomic numbers of the atoms in the system. + The atomic numbers of the atoms in the system of shape `(..., nat)`. weights : Tensor - Weights of all reference systems. + Weights of all reference systems of shape `(..., nat, 7)`. reference : Reference - Reference systems for D3 model. + Reference systems for D3 model. Contains the reference C6 coefficients + of shape `(..., nelements, nelements, 7, 7)`. Returns ------- Tensor - Atomic dispersion coefficients. + Atomic dispersion coefficients of shape `(..., nat, nat)`. """ - + # (..., nel, nel, 7, 7) -> (..., nat, nat, 7, 7) c6 = reference.c6[numbers.unsqueeze(-1), numbers.unsqueeze(-2)] - gw = torch.mul( - weights.unsqueeze(-1).unsqueeze(-3), weights.unsqueeze(-2).unsqueeze(-4) - ) - return torch.sum(torch.sum(torch.mul(gw, c6), dim=-1), dim=-1) + # This old version creates large intermediate tensors and builds the full + # matrix before the sum reduction, which requires a lot of memory. 
+ # gw = torch.mul( + # weights.unsqueeze(-1).unsqueeze(-3), + # weights.unsqueeze(-2).unsqueeze(-4), + # ) + # return torch.sum(torch.sum(torch.mul(gw, c6), dim=-1), dim=-1) + + # (..., nat, 7) * (..., nat, 7) * (..., nat, nat, 7, 7) -> (..., nat, nat) + return torch.einsum("...ia,...jb,...ijab->...ij", weights, weights, c6) def gaussian_weight(dcn: Tensor, factor: float = 4.0) -> Tensor: From d6e62620d4ca216f246a901ff77c2c6e660557e8 Mon Sep 17 00:00:00 2001 From: marvinfriede <51965259+marvinfriede@users.noreply.github.com> Date: Thu, 21 Mar 2024 09:17:53 +0100 Subject: [PATCH 2/4] Use einsum for memory efficiency --- setup.cfg | 1 + src/tad_dftd3/model.py | 7 ++++++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index 7401d3d..6c72592 100644 --- a/setup.cfg +++ b/setup.cfg @@ -24,6 +24,7 @@ project_urls = packages = find: install_requires = numpy + opt_einsum tad-mctc torch python_requires = >=3.8 diff --git a/src/tad_dftd3/model.py b/src/tad_dftd3/model.py index 7247cc8..2a5dd66 100644 --- a/src/tad_dftd3/model.py +++ b/src/tad_dftd3/model.py @@ -42,6 +42,7 @@ """ import torch from tad_mctc.batch import real_atoms +from tad_mctc.math import einsum from .reference import Reference from .typing import Any, Tensor, WeightingFunction @@ -78,7 +79,11 @@ def atomic_c6(numbers: Tensor, weights: Tensor, reference: Reference) -> Tensor: # return torch.sum(torch.sum(torch.mul(gw, c6), dim=-1), dim=-1) # (..., nat, 7) * (..., nat, 7) * (..., nat, nat, 7, 7) -> (..., nat, nat) - return torch.einsum("...ia,...jb,...ijab->...ij", weights, weights, c6) + return einsum( + "...ia,...jb,...ijab->...ij", + *(weights, weights, c6), + optimize=[(1, 2), (0, 1)], # fastest path + ) def gaussian_weight(dcn: Tensor, factor: float = 4.0) -> Tensor: From 47e2e6d0e14ea4c2b64da01e124ec5de3633055e Mon Sep 17 00:00:00 2001 From: marvinfriede <51965259+marvinfriede@users.noreply.github.com> Date: Thu, 21 Mar 2024 18:34:12 +0100 Subject: [PATCH 3/4] Clean 
up --- setup.cfg | 2 +- src/tad_dftd3/model.py | 25 ++++++++++++------------- 2 files changed, 13 insertions(+), 14 deletions(-) diff --git a/setup.cfg b/setup.cfg index 6c72592..0d4a5dc 100644 --- a/setup.cfg +++ b/setup.cfg @@ -24,7 +24,7 @@ project_urls = packages = find: install_requires = numpy - opt_einsum + opt-einsum tad-mctc torch python_requires = >=3.8 diff --git a/src/tad_dftd3/model.py b/src/tad_dftd3/model.py index 2a5dd66..4dde8e8 100644 --- a/src/tad_dftd3/model.py +++ b/src/tad_dftd3/model.py @@ -68,23 +68,22 @@ def atomic_c6(numbers: Tensor, weights: Tensor, reference: Reference) -> Tensor: Atomic dispersion coefficients of shape `(..., nat, nat)`. """ # (..., nel, nel, 7, 7) -> (..., nat, nat, 7, 7) - c6 = reference.c6[numbers.unsqueeze(-1), numbers.unsqueeze(-2)] + rc6 = reference.c6[numbers.unsqueeze(-1), numbers.unsqueeze(-2)] - # This old version creates large intermediate tensors and builds the full - # matrix before the sum reduction, which requires a lot of memory. - # gw = torch.mul( - # weights.unsqueeze(-1).unsqueeze(-3), - # weights.unsqueeze(-2).unsqueeze(-4), - # ) - # return torch.sum(torch.sum(torch.mul(gw, c6), dim=-1), dim=-1) - - # (..., nat, 7) * (..., nat, 7) * (..., nat, nat, 7, 7) -> (..., nat, nat) + # The default einsum path is fastest if the large tensor comes first. + # (..., n1, n2, r1, r2) * (..., n1, r1) * (..., n2, r2) -> (..., n1, n2) return einsum( - "...ia,...jb,...ijab->...ij", - *(weights, weights, c6), - optimize=[(1, 2), (0, 1)], # fastest path + "...ijab,...ia,...jb->...ij", + *(rc6, weights, weights), + optimize=[(0, 1), (0, 1)], ) + # NOTE: This old version creates large intermediate tensors and builds the + # full matrix before the sum reduction, which requires a lot of memory. 
+ # + # gw = w.unsqueeze(-1).unsqueeze(-3) * w.unsqueeze(-2).unsqueeze(-4) + # c6 = torch.sum(torch.sum(torch.mul(gw, rc6), dim=-1), dim=-1) + def gaussian_weight(dcn: Tensor, factor: float = 4.0) -> Tensor: """ From ed53935a79ea975e13fd0c0920ef816b59f51984 Mon Sep 17 00:00:00 2001 From: marvinfriede <51965259+marvinfriede@users.noreply.github.com> Date: Thu, 21 Mar 2024 22:31:01 +0100 Subject: [PATCH 4/4] Import MockTensor --- test/test_model/test_reference.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/test/test_model/test_reference.py b/test/test_model/test_reference.py index 99a6148..b92c7c6 100644 --- a/test/test_model/test_reference.py +++ b/test/test_model/test_reference.py @@ -21,6 +21,7 @@ import pytest import torch from tad_mctc.convert import str_to_device +from tad_mctc.typing import MockTensor from tad_dftd3 import reference from tad_dftd3.typing import DD, Any, Tensor, TypedDict @@ -69,16 +70,6 @@ def test_reference_device(device_str: str, device_str2: str) -> None: def test_reference_different_devices() -> None: - # Custom Tensor class with overridable device property - class MockTensor(Tensor): - @property - def device(self) -> Any: - return self._device - - @device.setter - def device(self, value: Any) -> None: - self._device = value - # Custom mock functions def mock_load_cn(*_: Any, **__: Any) -> Tensor: tensor = MockTensor([1, 2, 3])