Skip to content

Commit

Permalink
Chunked C6 for Memory Efficiency (#55)
Browse files Browse the repository at this point in the history
  • Loading branch information
marvinfriede authored Apr 26, 2024
1 parent a72e385 commit 9dfd874
Show file tree
Hide file tree
Showing 10 changed files with 649 additions and 68 deletions.
2 changes: 1 addition & 1 deletion examples/single.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import tad_dftd3 as d3

numbers = mctc.convert.symbol_to_number(symbols="C C C C N C S H H H H H".split())
positions = torch.Tensor(
positions = torch.tensor(
[
[-2.56745685564671, -0.02509985979910, 0.00000000000000],
[-1.39177582455797, +2.27696188880014, 0.00000000000000],
Expand Down
34 changes: 19 additions & 15 deletions src/tad_dftd3/disp.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
>>> print(torch.sum(energy[0] - energy[1] - energy[2])) # energy in Hartree
tensor(-0.0003964, dtype=torch.float64)
"""
from typing import Dict, Optional
from __future__ import annotations

import torch
from tad_mctc import storch
Expand All @@ -77,16 +77,17 @@
def dftd3(
numbers: Tensor,
positions: Tensor,
param: Dict[str, Tensor],
param: dict[str, Tensor],
*,
ref: Optional[Reference] = None,
rcov: Optional[Tensor] = None,
rvdw: Optional[Tensor] = None,
r4r2: Optional[Tensor] = None,
cutoff: Optional[Tensor] = None,
ref: Reference | None = None,
rcov: Tensor | None = None,
rvdw: Tensor | None = None,
r4r2: Tensor | None = None,
cutoff: Tensor | None = None,
counting_function: CountingFunction = ncoord.exp_count,
weighting_function: WeightingFunction = model.gaussian_weight,
damping_function: DampingFunction = rational_damping,
chunk_size: int | None = None,
) -> Tensor:
"""
Evaluate DFT-D3 dispersion energy for a batch of geometries.
Expand All @@ -113,6 +114,9 @@ def dftd3(
Function to calculate weight of individual reference systems.
counting_function : Callable, optional
Calculates counting value in range 0 to 1 for each atom pair.
chunk_size : int, optional
Chunk size for chunked computation of huge tensors that otherwise
create memory bottlenecks.
Returns
-------
Expand Down Expand Up @@ -142,7 +146,7 @@ def dftd3(
numbers, positions, counting_function=counting_function, rcov=rcov
)
weights = model.weight_references(numbers, cn, ref, weighting_function)
c6 = model.atomic_c6(numbers, weights, ref)
c6 = model.atomic_c6(numbers, weights, ref, chunk_size=chunk_size)

return dispersion(
numbers,
Expand All @@ -159,12 +163,12 @@ def dftd3(
def dispersion(
numbers: Tensor,
positions: Tensor,
param: Dict[str, Tensor],
param: dict[str, Tensor],
c6: Tensor,
rvdw: Optional[Tensor] = None,
r4r2: Optional[Tensor] = None,
rvdw: Tensor | None = None,
r4r2: Tensor | None = None,
damping_function: DampingFunction = rational_damping,
cutoff: Optional[Tensor] = None,
cutoff: Tensor | None = None,
**kwargs: Any,
) -> Tensor:
"""
Expand Down Expand Up @@ -210,7 +214,7 @@ def dispersion(
)
if torch.max(numbers) >= defaults.MAX_ELEMENT:
raise ValueError(
f"No D3 parameters available for Z > {defaults.MAX_ELEMENT-1} "
f"No D3 parameters available for Z > {defaults.MAX_ELEMENT - 1} "
f"({pse.Z2S[defaults.MAX_ELEMENT]})."
)

Expand All @@ -232,7 +236,7 @@ def dispersion(
def dispersion2(
numbers: Tensor,
positions: Tensor,
param: Dict[str, Tensor],
param: dict[str, Tensor],
c6: Tensor,
r4r2: Tensor,
damping_function: DampingFunction,
Expand Down Expand Up @@ -292,7 +296,7 @@ def dispersion2(
def dispersion3(
numbers: Tensor,
positions: Tensor,
param: Dict[str, Tensor],
param: dict[str, Tensor],
c6: Tensor,
rvdw: Tensor,
cutoff: Tensor,
Expand Down
45 changes: 45 additions & 0 deletions src/tad_dftd3/model/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# This file is part of tad-dftd3.
# SPDX-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Model: Dispersion model
=======================
Implementation of D3 model to obtain atomic C6 coefficients for a given
geometry.
Examples
--------
>>> import torch
>>> import tad_dftd3 as d3
>>> import tad_mctc as mctc
>>> numbers = mctc.convert.symbol_to_number(["O", "H", "H"])
>>> positions = torch.Tensor([
... [+0.00000000000000, +0.00000000000000, -0.73578586109551],
... [+1.44183152868459, +0.00000000000000, +0.36789293054775],
... [-1.44183152868459, +0.00000000000000, +0.36789293054775],
... ])
>>> ref = d3.reference.Reference()
>>> rcov = d3.data.covalent_rad_d3[numbers]
>>> cn = mctc.ncoord.cn_d3(numbers, positions, rcov=rcov, counting_function=d3.ncoord.exp_count)
>>> weights = d3.model.weight_references(numbers, cn, ref, d3.model.gaussian_weight)
>>> c6 = d3.model.atomic_c6(numbers, weights, ref)
>>> torch.set_printoptions(precision=7)
>>> print(c6)
tensor([[10.4130471, 5.4368822, 5.4368822],
[ 5.4368822, 3.0930154, 3.0930154],
[ 5.4368822, 3.0930154, 3.0930154]], dtype=torch.float64)
"""
from .c6 import *
from .weights import *
Loading

0 comments on commit 9dfd874

Please sign in to comment.