Skip to content

Commit

Permalink
Fix handling of very large CNs in weights (#50)
Browse files Browse the repository at this point in the history
For very large coordination numbers (larger than the provided in the
reference systems), the weights become zero because of the exponential
in the weighting function. The Fortran code has some logic for handling
this edge case, which was added to this project too.
  • Loading branch information
marvinfriede authored Mar 28, 2024
1 parent 080ff84 commit df8963c
Show file tree
Hide file tree
Showing 11 changed files with 333 additions and 17 deletions.
1 change: 1 addition & 0 deletions docs/modules/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,6 @@ The following modules are contained with `tad-dftd3`.
defaults
disp
model
ncoord/index
reference
typing/index
3 changes: 3 additions & 0 deletions docs/modules/ncoord/index.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
.. _ncoord:

.. automodule:: tad_dftd3.ncoord
15 changes: 14 additions & 1 deletion src/tad_dftd3/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,19 @@
"""
import torch

from . import damping, data, defaults, disp, model, reference, typing
from . import damping, data, defaults, disp, model, ncoord, reference, typing
from .__version__ import __version__
from .disp import dftd3

__alll__ = [
"dftd3",
"damping",
"data",
"defaults",
"disp",
"model",
"ncoord",
"reference",
"typing",
"__version__",
]
14 changes: 14 additions & 0 deletions src/tad_dftd3/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,20 @@
This module defines the default values for all parameters within DFT-D3.
"""

__all__ = [
"D3_CN_CUTOFF",
"D3_DISP_CUTOFF",
"D3_KCN",
"A1",
"A2",
"S6",
"S8",
"S9",
"RS9",
"ALP",
"MAX_ELEMENT",
]

# DFT-D3

D3_CN_CUTOFF = 25.0
Expand Down
6 changes: 4 additions & 2 deletions src/tad_dftd3/disp.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,11 @@
from typing import Dict, Optional

import torch
from tad_mctc import ncoord, storch
from tad_mctc import storch
from tad_mctc.batch import real_pairs
from tad_mctc.data import pse

from . import data, defaults, model
from . import data, defaults, model, ncoord
from .damping import dispersion_atm, rational_damping
from .reference import Reference
from .typing import (
Expand All @@ -71,6 +71,8 @@
WeightingFunction,
)

__all__ = ["dftd3", "dispersion", "dispersion2", "dispersion3"]


def dftd3(
numbers: Tensor,
Expand Down
54 changes: 41 additions & 13 deletions src/tad_dftd3/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,13 @@
[ 5.4368822, 3.0930154, 3.0930154]], dtype=torch.float64)
"""
import torch
from tad_mctc.batch import real_atoms
from tad_mctc.math import einsum

from .reference import Reference
from .typing import Any, Tensor, WeightingFunction

__all__ = ["atomic_c6", "gaussian_weight", "weight_references"]


def atomic_c6(numbers: Tensor, weights: Tensor, reference: Reference) -> Tensor:
"""
Expand Down Expand Up @@ -131,8 +132,12 @@ def weight_references(
Tensor
Weights of all reference systems
"""
refcn = reference.cn[numbers]
mask = refcn >= 0

mask = reference.cn[numbers] >= 0
zero = torch.tensor(0.0, device=cn.device, dtype=cn.dtype)
zero_double = torch.tensor(0.0, device=cn.device, dtype=torch.double)
one = torch.tensor(1.0, device=cn.device, dtype=cn.dtype)

# Due to the exponentiation, `norms` and `weights` may become very small.
# This may cause problems for the division by `norms`. It may occur that
Expand All @@ -148,23 +153,46 @@ def weight_references(
weights = torch.where(
mask,
weighting_function(dcn, **kwargs),
torch.tensor(0.0, device=dcn.device, dtype=dcn.dtype), # not eps!
zero_double, # not eps!
)

# Nevertheless, we must avoid zero division here in batched calculations.
#
# Previously, a small value was added to `norms` to prevent division by zero
# (`norms = torch.add(torch.sum(weights, dim=-1), 1e-20)`). However, even
# such small values can lead to relatively large deviations because the
# small value is not added to the weights, and hence, the case where
# `weights` and `norms` are equal does not yield one anymore. In fact, the
# test suite fails because some elements deviate up to around 1e-4.
#
# We solve this issue by using a mask from the atoms and only add a small
# value, where the actual padding zeros are.
norms = torch.where(
real_atoms(numbers),
torch.sum(weights, dim=-1),
torch.tensor(torch.finfo(dcn.dtype).eps, device=cn.device, dtype=dcn.dtype),
# We solve this by running in double precision, adding a very small number
# and using multiple masks.

# normalize weights
norm = torch.where(
mask,
torch.sum(weights, dim=-1, keepdim=True),
torch.tensor(1e-300, device=cn.device, dtype=torch.double), # double!
)

# back to real dtype
gw_temp = (weights / norm).type(cn.dtype)

# The following section handles cases with large CNs that lead to zeros in
# after the exponential in the weighting function. If this happens all
# weights become zero, which is not desired. Instead, we set the weight of
# the largest reference number to one.
# This case can occur if the CN of the current (actual) system is too far
# away from the largest CN of the reference systems. An example would be an
# atom within a fullerene (La3N@C80).

# maximum reference CN for each atom
maxcn = torch.max(refcn, dim=-1, keepdim=True)[0]

# prevent division by 0 and small values
exceptional = (torch.isnan(gw_temp)) | (gw_temp > torch.finfo(cn.dtype).max)

gw = torch.where(
exceptional,
torch.where(refcn == maxcn, one, zero),
gw_temp,
)
return (weights / norms.unsqueeze(-1)).type(cn.dtype)

return torch.where(mask, gw, zero)
26 changes: 26 additions & 0 deletions src/tad_dftd3/ncoord/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# This file is part of tad-dftd3.
# SPDX-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Coordination Number
===================
Functions for calculating the D3 coordination numbers.
Only exported for convenience.
"""

from tad_mctc.ncoord.count import exp_count
from tad_mctc.ncoord.d3 import cn_d3

__all__ = ["exp_count", "cn_d3"]
2 changes: 2 additions & 0 deletions src/tad_dftd3/reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@

from .typing import Any, NoReturn, Tensor, get_default_device, get_default_dtype

__all__ = ["Reference"]


def _load_cn(
dtype: torch.dtype = torch.double, device: Optional[torch.device] = None
Expand Down
109 changes: 109 additions & 0 deletions test/test_disp/samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -1544,6 +1544,115 @@ class Record(Molecule, Refs):
),
}
),
"La3N@C80": Refs(
{
"cn": torch.tensor(
[],
dtype=torch.double,
),
"weights": torch.tensor(
[],
dtype=torch.double,
),
"c6": torch.tensor(
[],
dtype=torch.double,
),
"disp2": torch.tensor(
[
-6.4568698147826646e-003,
-6.4559561239969799e-003,
-6.4564281797744585e-003,
-2.7360474586652791e-003,
-1.7407093093953240e-003,
-1.8301394258209106e-003,
-1.8524936502264853e-003,
-1.7350547435936382e-003,
-1.6338086590634386e-003,
-1.5755111668016490e-003,
-1.5618576612617284e-003,
-1.6147968576847084e-003,
-1.7733089538039231e-003,
-1.6245203604557511e-003,
-1.6209618005004513e-003,
-1.7599254182297916e-003,
-1.7369678621080445e-003,
-1.8528080133639840e-003,
-1.9413055642552414e-003,
-1.8525635860998158e-003,
-1.8297098714605152e-003,
-1.7566218566864807e-003,
-1.6202382184123294e-003,
-1.5462201695356063e-003,
-1.5084481213619406e-003,
-1.5140452587746691e-003,
-1.5390048264384269e-003,
-1.5981780755403895e-003,
-1.6506427677755436e-003,
-1.6296629464721212e-003,
-1.5795214054885784e-003,
-1.5089651771174383e-003,
-1.5471741156195414e-003,
-1.5758460954735725e-003,
-1.6340729580740559e-003,
-1.7411978969490475e-003,
-1.8300611979096514e-003,
-1.7577001063579932e-003,
-1.7359143119427235e-003,
-1.6337607742643375e-003,
-1.7409640835067410e-003,
-1.7727788519513107e-003,
-1.6242019237858963e-003,
-1.5788567024896301e-003,
-1.5802817234666848e-003,
-1.6505243428138216e-003,
-1.7989840061950748e-003,
-1.7772803360684576e-003,
-1.6143157287502792e-003,
-1.6293179770636760e-003,
-1.5792498551407159e-003,
-1.5087992421745076e-003,
-1.5140439565480947e-003,
-1.5390983057429895e-003,
-1.5985396168404411e-003,
-1.7439931054253181e-003,
-1.7776307518328823e-003,
-1.7991525034320502e-003,
-1.8968763791053463e-003,
-1.7986664231541786e-003,
-1.6494582261534513e-003,
-1.7416254195092209e-003,
-1.5976268653784852e-003,
-1.6494603571438624e-003,
-1.6149177685073563e-003,
-1.7726437418526361e-003,
-1.6240595135519110e-003,
-1.6205535517789524e-003,
-1.5468593370110194e-003,
-1.5757631710259304e-003,
-1.5621392033851348e-003,
-1.6153080590843665e-003,
-1.6506519485602186e-003,
-1.6295782927553905e-003,
-1.6146709380519242e-003,
-1.5808631439379914e-003,
-1.6511668531186460e-003,
-1.7446879237905224e-003,
-1.7773764100057496e-003,
-1.6142164156269839e-003,
-1.5801412516061992e-003,
-1.5135963042277999e-003,
-1.5384743970934046e-003,
-1.5619345437512383e-003,
],
dtype=torch.double,
),
"disp3": torch.tensor(
[],
dtype=torch.double,
),
}
),
"AmF3": Refs(
{
"cn": torch.tensor(
Expand Down
2 changes: 1 addition & 1 deletion test/test_disp/test_dftd3.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
import pytest
import torch
from tad_mctc.batch import pack
from tad_mctc.ncoord import exp_count

from tad_dftd3 import damping, data, dftd3, model, reference
from tad_dftd3.ncoord import exp_count
from tad_dftd3.typing import DD

from ..conftest import DEVICE
Expand Down
Loading

0 comments on commit df8963c

Please sign in to comment.