Commit: Clean up

marvinfriede committed Mar 21, 2024
1 parent fe37ee8 commit 63881c4

Showing 2 changed files with 13 additions and 14 deletions.
2 changes: 1 addition & 1 deletion setup.cfg
@@ -24,7 +24,7 @@ project_urls =
 packages = find:
 install_requires =
     numpy
-    opt_einsum
+    opt-einsum
     tad-mctc
     torch
 python_requires = >=3.8
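
Note (editorial, not part of the commit): on PyPI the distribution is published as opt-einsum, while the importable module keeps the underscore, so the requirement string in setup.cfg and the import in code intentionally differ. A minimal check:

# Installed via the PyPI distribution name declared in setup.cfg ("opt-einsum"),
# but imported under the underscore module name.
import opt_einsum

print(opt_einsum.__version__)
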
25 changes: 12 additions & 13 deletions src/tad_dftd3/model.py
@@ -68,23 +68,22 @@ def atomic_c6(numbers: Tensor, weights: Tensor, reference: Reference) -> Tensor:
         Atomic dispersion coefficients of shape `(..., nat, nat)`.
     """
     # (..., nel, nel, 7, 7) -> (..., nat, nat, 7, 7)
-    c6 = reference.c6[numbers.unsqueeze(-1), numbers.unsqueeze(-2)]
+    rc6 = reference.c6[numbers.unsqueeze(-1), numbers.unsqueeze(-2)]

-    # This old version creates large intermediate tensors and builds the full
-    # matrix before the sum reduction, which requires a lot of memory.
-    # gw = torch.mul(
-    #     weights.unsqueeze(-1).unsqueeze(-3),
-    #     weights.unsqueeze(-2).unsqueeze(-4),
-    # )
-    # return torch.sum(torch.sum(torch.mul(gw, c6), dim=-1), dim=-1)
-
-    # (..., nat, 7) * (..., nat, 7) * (..., nat, nat, 7, 7) -> (..., nat, nat)
     # The default einsum path is fastest if the large tensors comes first.
+    # (..., n1, n2, r1, r2) * (..., n1, r1) * (..., n2, r2) -> (..., n1, n2)
     return einsum(
-        "...ia,...jb,...ijab->...ij",
-        *(weights, weights, c6),
-        optimize=[(1, 2), (0, 1)],  # fastest path
+        "...ijab,...ia,...jb->...ij",
+        *(rc6, weights, weights),
+        optimize=[(0, 1), (0, 1)],
     )

+    # NOTE: This old version creates large intermediate tensors and builds the
+    # full matrix before the sum reduction, which requires a lot of memory.
+    #
+    # gw = w.unsqueeze(-1).unsqueeze(-3) * w.unsqueeze(-2).unsqueeze(-4)
+    # c6 = torch.sum(torch.sum(torch.mul(gw, rc6), dim=-1), dim=-1)
+
+
 def gaussian_weight(dcn: Tensor, factor: float = 4.0) -> Tensor:
     """
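
For context (not part of the commit): a minimal sketch comparing the two contraction strategies on random data, illustrating why the reordering and the explicit path avoid the large broadcasted intermediate. The sizes (nat atoms, 7 references per atom) follow the shape comments above; using opt_einsum.contract in place of the project's einsum wrapper, and the assumption that the wrapper forwards optimize unchanged, are mine.

import torch
from opt_einsum import contract

nat, nref = 50, 7  # assumed sizes: atoms and reference systems per atom
weights = torch.rand(nat, nref, dtype=torch.float64)
rc6 = torch.rand(nat, nat, nref, nref, dtype=torch.float64)

# Old route: broadcast to the full (nat, nat, nref, nref) weight product,
# multiply with rc6, then reduce over both reference axes.
gw = weights.unsqueeze(-1).unsqueeze(-3) * weights.unsqueeze(-2).unsqueeze(-4)
c6_old = torch.sum(torch.sum(gw * rc6, dim=-1), dim=-1)

# New route: contract the reference axes directly. The large tensor comes
# first, and the path [(0, 1), (0, 1)] contracts rc6 with one weights factor
# and then that result with the other, never materializing the full
# (nat, nat, nref, nref) product.
c6_new = contract(
    "ijab,ia,jb->ij",
    rc6,
    weights,
    weights,
    optimize=[(0, 1), (0, 1)],
)

assert torch.allclose(c6_old, c6_new)  # both routes yield the same (nat, nat) C6 matrix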
