diff --git a/setup.cfg b/setup.cfg
index 7401d3d..0d4a5dc 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -24,6 +24,7 @@ project_urls =
 packages = find:
 install_requires =
     numpy
+    opt-einsum
     tad-mctc
     torch
 python_requires = >=3.8
diff --git a/src/tad_dftd3/model.py b/src/tad_dftd3/model.py
index b7dabc2..4dde8e8 100644
--- a/src/tad_dftd3/model.py
+++ b/src/tad_dftd3/model.py
@@ -42,6 +42,7 @@
 """
 import torch
 from tad_mctc.batch import real_atoms
+from tad_mctc.math import einsum
 
 from .reference import Reference
 from .typing import Any, Tensor, WeightingFunction
@@ -54,24 +55,34 @@ def atomic_c6(numbers: Tensor, weights: Tensor, reference: Reference) -> Tensor:
 
     Parameters
     ----------
     numbers : Tensor
-        The atomic numbers of the atoms in the system.
+        The atomic numbers of the atoms in the system of shape `(..., nat)`.
     weights : Tensor
-        Weights of all reference systems.
+        Weights of all reference systems of shape `(..., nat, 7)`.
     reference : Reference
-        Reference systems for D3 model.
+        Reference systems for D3 model. Contains the reference C6 coefficients
+        of shape `(..., nelements, nelements, 7, 7)`.
 
     Returns
     -------
     Tensor
-        Atomic dispersion coefficients.
+        Atomic dispersion coefficients of shape `(..., nat, nat)`.
     """
-
-    c6 = reference.c6[numbers.unsqueeze(-1), numbers.unsqueeze(-2)]
-    gw = torch.mul(
-        weights.unsqueeze(-1).unsqueeze(-3), weights.unsqueeze(-2).unsqueeze(-4)
+    # (..., nel, nel, 7, 7) -> (..., nat, nat, 7, 7)
+    rc6 = reference.c6[numbers.unsqueeze(-1), numbers.unsqueeze(-2)]
+
+    # The default einsum path is fastest if the large tensors come first.
+    # (..., n1, n2, r1, r2) * (..., n1, r1) * (..., n2, r2) -> (..., n1, n2)
+    return einsum(
+        "...ijab,...ia,...jb->...ij",
+        *(rc6, weights, weights),
+        optimize=[(0, 1), (0, 1)],
     )
-    return torch.sum(torch.sum(torch.mul(gw, c6), dim=-1), dim=-1)
+    # NOTE: This old version creates large intermediate tensors and builds the
+    # full matrix before the sum reduction, which requires a lot of memory.
+    #
+    # gw = w.unsqueeze(-1).unsqueeze(-3) * w.unsqueeze(-2).unsqueeze(-4)
+    # c6 = torch.sum(torch.sum(torch.mul(gw, rc6), dim=-1), dim=-1)
 
 
 def gaussian_weight(dcn: Tensor, factor: float = 4.0) -> Tensor:
diff --git a/test/test_model/test_reference.py b/test/test_model/test_reference.py
index 99a6148..b92c7c6 100644
--- a/test/test_model/test_reference.py
+++ b/test/test_model/test_reference.py
@@ -21,6 +21,7 @@
 import pytest
 import torch
 from tad_mctc.convert import str_to_device
+from tad_mctc.typing import MockTensor
 
 from tad_dftd3 import reference
 from tad_dftd3.typing import DD, Any, Tensor, TypedDict
@@ -69,16 +70,6 @@ def test_reference_device(device_str: str, device_str2: str) -> None:
 
 
 def test_reference_different_devices() -> None:
-    # Custom Tensor class with overridable device property
-    class MockTensor(Tensor):
-        @property
-        def device(self) -> Any:
-            return self._device
-
-        @device.setter
-        def device(self, value: Any) -> None:
-            self._device = value
-
     # Custom mock functions
     def mock_load_cn(*_: Any, **__: Any) -> Tensor:
         tensor = MockTensor([1, 2, 3])