diff --git a/docs/modules/index.rst b/docs/modules/index.rst index 7f1ce71..d47c366 100644 --- a/docs/modules/index.rst +++ b/docs/modules/index.rst @@ -12,5 +12,6 @@ The following modules are contained with `tad-dftd3`. defaults disp model + ncoord/index reference typing/index diff --git a/docs/modules/ncoord/index.rst b/docs/modules/ncoord/index.rst new file mode 100644 index 0000000..9204ede --- /dev/null +++ b/docs/modules/ncoord/index.rst @@ -0,0 +1,3 @@ +.. _ncoord: + +.. automodule:: tad_dftd3.ncoord diff --git a/src/tad_dftd3/__init__.py b/src/tad_dftd3/__init__.py index afa95fc..d81136d 100644 --- a/src/tad_dftd3/__init__.py +++ b/src/tad_dftd3/__init__.py @@ -75,6 +75,19 @@ """ import torch -from . import damping, data, defaults, disp, model, reference, typing +from . import damping, data, defaults, disp, model, ncoord, reference, typing from .__version__ import __version__ from .disp import dftd3 + +__alll__ = [ + "dftd3", + "damping", + "data", + "defaults", + "disp", + "model", + "ncoord", + "reference", + "typing", + "__version__", +] diff --git a/src/tad_dftd3/defaults.py b/src/tad_dftd3/defaults.py index 9c142cc..487e95a 100644 --- a/src/tad_dftd3/defaults.py +++ b/src/tad_dftd3/defaults.py @@ -19,6 +19,20 @@ This module defines the default values for all parameters within DFT-D3. """ +__all__ = [ + "D3_CN_CUTOFF", + "D3_DISP_CUTOFF", + "D3_KCN", + "A1", + "A2", + "S6", + "S8", + "S9", + "RS9", + "ALP", + "MAX_ELEMENT", +] + # DFT-D3 D3_CN_CUTOFF = 25.0 diff --git a/src/tad_dftd3/disp.py b/src/tad_dftd3/disp.py index b5d4476..c928a37 100644 --- a/src/tad_dftd3/disp.py +++ b/src/tad_dftd3/disp.py @@ -55,11 +55,11 @@ from typing import Dict, Optional import torch -from tad_mctc import ncoord, storch +from tad_mctc import storch from tad_mctc.batch import real_pairs from tad_mctc.data import pse -from . import data, defaults, model +from . import data, defaults, model, ncoord from .damping import dispersion_atm, rational_damping from .reference import Reference from .typing import ( @@ -71,6 +71,8 @@ WeightingFunction, ) +__all__ = ["dftd3", "dispersion", "dispersion2", "dispersion3"] + def dftd3( numbers: Tensor, diff --git a/src/tad_dftd3/model.py b/src/tad_dftd3/model.py index 4dde8e8..e8f1559 100644 --- a/src/tad_dftd3/model.py +++ b/src/tad_dftd3/model.py @@ -41,12 +41,13 @@ [ 5.4368822, 3.0930154, 3.0930154]], dtype=torch.float64) """ import torch -from tad_mctc.batch import real_atoms from tad_mctc.math import einsum from .reference import Reference from .typing import Any, Tensor, WeightingFunction +__all__ = ["atomic_c6", "gaussian_weight", "weight_references"] + def atomic_c6(numbers: Tensor, weights: Tensor, reference: Reference) -> Tensor: """ @@ -131,8 +132,12 @@ def weight_references( Tensor Weights of all reference systems """ + refcn = reference.cn[numbers] + mask = refcn >= 0 - mask = reference.cn[numbers] >= 0 + zero = torch.tensor(0.0, device=cn.device, dtype=cn.dtype) + zero_double = torch.tensor(0.0, device=cn.device, dtype=torch.double) + one = torch.tensor(1.0, device=cn.device, dtype=cn.dtype) # Due to the exponentiation, `norms` and `weights` may become very small. # This may cause problems for the division by `norms`. It may occur that @@ -148,23 +153,46 @@ def weight_references( weights = torch.where( mask, weighting_function(dcn, **kwargs), - torch.tensor(0.0, device=dcn.device, dtype=dcn.dtype), # not eps! + zero_double, # not eps! ) - # Nevertheless, we must avoid zero division here in batched calculations. - # # Previously, a small value was added to `norms` to prevent division by zero # (`norms = torch.add(torch.sum(weights, dim=-1), 1e-20)`). However, even # such small values can lead to relatively large deviations because the # small value is not added to the weights, and hence, the case where # `weights` and `norms` are equal does not yield one anymore. In fact, the # test suite fails because some elements deviate up to around 1e-4. - # - # We solve this issue by using a mask from the atoms and only add a small - # value, where the actual padding zeros are. - norms = torch.where( - real_atoms(numbers), - torch.sum(weights, dim=-1), - torch.tensor(torch.finfo(dcn.dtype).eps, device=cn.device, dtype=dcn.dtype), + # We solve this by running in double precision, adding a very small number + # and using multiple masks. + + # normalize weights + norm = torch.where( + mask, + torch.sum(weights, dim=-1, keepdim=True), + torch.tensor(1e-300, device=cn.device, dtype=torch.double), # double! + ) + + # back to real dtype + gw_temp = (weights / norm).type(cn.dtype) + + # The following section handles cases with large CNs that lead to zeros in + # after the exponential in the weighting function. If this happens all + # weights become zero, which is not desired. Instead, we set the weight of + # the largest reference number to one. + # This case can occur if the CN of the current (actual) system is too far + # away from the largest CN of the reference systems. An example would be an + # atom within a fullerene (La3N@C80). + + # maximum reference CN for each atom + maxcn = torch.max(refcn, dim=-1, keepdim=True)[0] + + # prevent division by 0 and small values + exceptional = (torch.isnan(gw_temp)) | (gw_temp > torch.finfo(cn.dtype).max) + + gw = torch.where( + exceptional, + torch.where(refcn == maxcn, one, zero), + gw_temp, ) - return (weights / norms.unsqueeze(-1)).type(cn.dtype) + + return torch.where(mask, gw, zero) diff --git a/src/tad_dftd3/ncoord/__init__.py b/src/tad_dftd3/ncoord/__init__.py new file mode 100644 index 0000000..f1c407e --- /dev/null +++ b/src/tad_dftd3/ncoord/__init__.py @@ -0,0 +1,26 @@ +# This file is part of tad-dftd3. +# SPDX-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Coordination Number +=================== + +Functions for calculating the D3 coordination numbers. +Only exported for convenience. +""" + +from tad_mctc.ncoord.count import exp_count +from tad_mctc.ncoord.d3 import cn_d3 + +__all__ = ["exp_count", "cn_d3"] diff --git a/src/tad_dftd3/reference.py b/src/tad_dftd3/reference.py index b6a8379..5c71fea 100644 --- a/src/tad_dftd3/reference.py +++ b/src/tad_dftd3/reference.py @@ -26,6 +26,8 @@ from .typing import Any, NoReturn, Tensor, get_default_device, get_default_dtype +__all__ = ["Reference"] + def _load_cn( dtype: torch.dtype = torch.double, device: Optional[torch.device] = None diff --git a/test/test_disp/samples.py b/test/test_disp/samples.py index 58c956b..f9f93ef 100644 --- a/test/test_disp/samples.py +++ b/test/test_disp/samples.py @@ -1544,6 +1544,115 @@ class Record(Molecule, Refs): ), } ), + "La3N@C80": Refs( + { + "cn": torch.tensor( + [], + dtype=torch.double, + ), + "weights": torch.tensor( + [], + dtype=torch.double, + ), + "c6": torch.tensor( + [], + dtype=torch.double, + ), + "disp2": torch.tensor( + [ + -6.4568698147826646e-003, + -6.4559561239969799e-003, + -6.4564281797744585e-003, + -2.7360474586652791e-003, + -1.7407093093953240e-003, + -1.8301394258209106e-003, + -1.8524936502264853e-003, + -1.7350547435936382e-003, + -1.6338086590634386e-003, + -1.5755111668016490e-003, + -1.5618576612617284e-003, + -1.6147968576847084e-003, + -1.7733089538039231e-003, + -1.6245203604557511e-003, + -1.6209618005004513e-003, + -1.7599254182297916e-003, + -1.7369678621080445e-003, + -1.8528080133639840e-003, + -1.9413055642552414e-003, + -1.8525635860998158e-003, + -1.8297098714605152e-003, + -1.7566218566864807e-003, + -1.6202382184123294e-003, + -1.5462201695356063e-003, + -1.5084481213619406e-003, + -1.5140452587746691e-003, + -1.5390048264384269e-003, + -1.5981780755403895e-003, + -1.6506427677755436e-003, + -1.6296629464721212e-003, + -1.5795214054885784e-003, + -1.5089651771174383e-003, + -1.5471741156195414e-003, + -1.5758460954735725e-003, + -1.6340729580740559e-003, + -1.7411978969490475e-003, + -1.8300611979096514e-003, + -1.7577001063579932e-003, + -1.7359143119427235e-003, + -1.6337607742643375e-003, + -1.7409640835067410e-003, + -1.7727788519513107e-003, + -1.6242019237858963e-003, + -1.5788567024896301e-003, + -1.5802817234666848e-003, + -1.6505243428138216e-003, + -1.7989840061950748e-003, + -1.7772803360684576e-003, + -1.6143157287502792e-003, + -1.6293179770636760e-003, + -1.5792498551407159e-003, + -1.5087992421745076e-003, + -1.5140439565480947e-003, + -1.5390983057429895e-003, + -1.5985396168404411e-003, + -1.7439931054253181e-003, + -1.7776307518328823e-003, + -1.7991525034320502e-003, + -1.8968763791053463e-003, + -1.7986664231541786e-003, + -1.6494582261534513e-003, + -1.7416254195092209e-003, + -1.5976268653784852e-003, + -1.6494603571438624e-003, + -1.6149177685073563e-003, + -1.7726437418526361e-003, + -1.6240595135519110e-003, + -1.6205535517789524e-003, + -1.5468593370110194e-003, + -1.5757631710259304e-003, + -1.5621392033851348e-003, + -1.6153080590843665e-003, + -1.6506519485602186e-003, + -1.6295782927553905e-003, + -1.6146709380519242e-003, + -1.5808631439379914e-003, + -1.6511668531186460e-003, + -1.7446879237905224e-003, + -1.7773764100057496e-003, + -1.6142164156269839e-003, + -1.5801412516061992e-003, + -1.5135963042277999e-003, + -1.5384743970934046e-003, + -1.5619345437512383e-003, + ], + dtype=torch.double, + ), + "disp3": torch.tensor( + [], + dtype=torch.double, + ), + } + ), "AmF3": Refs( { "cn": torch.tensor( diff --git a/test/test_disp/test_dftd3.py b/test/test_disp/test_dftd3.py index e43d962..778f9c8 100644 --- a/test/test_disp/test_dftd3.py +++ b/test/test_disp/test_dftd3.py @@ -18,9 +18,9 @@ import pytest import torch from tad_mctc.batch import pack -from tad_mctc.ncoord import exp_count from tad_dftd3 import damping, data, dftd3, model, reference +from tad_dftd3.ncoord import exp_count from tad_dftd3.typing import DD from ..conftest import DEVICE diff --git a/test/test_disp/test_special.py b/test/test_disp/test_special.py new file mode 100644 index 0000000..2cee39b --- /dev/null +++ b/test/test_disp/test_special.py @@ -0,0 +1,118 @@ +# This file is part of tad-dftd3. +# SPDX-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Test calculation of dispersion energy for a system, which fail without the +weird handling of exceptional values in the calculation of the weights. +""" +import pytest +import torch +from tad_mctc.batch import pack + +from tad_dftd3 import damping, data, dftd3, model, reference +from tad_dftd3.ncoord import exp_count +from tad_dftd3.typing import DD + +from ..conftest import DEVICE +from .samples import samples + + +@pytest.mark.parametrize("dtype", [torch.float32, torch.float64]) +@pytest.mark.parametrize("name", ["La3N@C80"]) +def test_single(dtype: torch.dtype, name: str) -> None: + dd: DD = {"device": DEVICE, "dtype": dtype} + + sample = samples[name] + numbers = sample["numbers"].to(DEVICE) + positions = sample["positions"].to(**dd) + ref = sample["disp2"].to(**dd) + + rcov = data.COV_D3.to(**dd)[numbers] + rvdw = data.VDW_D3.to(**dd)[numbers.unsqueeze(-1), numbers.unsqueeze(-2)] + r4r2 = data.R4R2.to(**dd)[numbers] + cutoff = torch.tensor(50, **dd) + + # GFN1-xTB parameters + param = { + "s6": torch.tensor(1.0000, **dd), + "s8": torch.tensor(2.4000, **dd), + "s9": torch.tensor(0.0000, **dd), + "alp": torch.tensor(14.00, **dd), + "a1": torch.tensor(0.6300, **dd), + "a2": torch.tensor(5.0000, **dd), + } + + energy = dftd3( + numbers, + positions, + param, + ref=reference.Reference(**dd), + rcov=rcov, + rvdw=rvdw, + r4r2=r4r2, + cutoff=cutoff, + counting_function=exp_count, + weighting_function=model.gaussian_weight, + damping_function=damping.rational_damping, + ) + + assert energy.dtype == dtype + assert pytest.approx(ref.cpu()) == energy.cpu() + + +@pytest.mark.parametrize("dtype", [torch.float32, torch.float64]) +def test_batch(dtype: torch.dtype) -> None: + dd: DD = {"device": DEVICE, "dtype": dtype} + + sample1, sample2 = (samples["LiH"], samples["La3N@C80"]) + numbers = pack( + ( + sample1["numbers"].to(DEVICE), + sample2["numbers"].to(DEVICE), + ) + ) + positions = pack( + ( + sample1["positions"].to(**dd), + sample2["positions"].to(**dd), + ) + ) + ref = pack( + ( + torch.tensor( + [ + -4.1054019506089849e-05, + -4.1054019506089849e-05, + ], + **dd + ), + sample2["disp2"].to(**dd), + ) + ) + + # GFN1-xTB parameters + param = { + "s6": torch.tensor(1.0000, **dd), + "s8": torch.tensor(2.4000, **dd), + "s9": torch.tensor(0.0000, **dd), + "alp": torch.tensor(14.00, **dd), + "a1": torch.tensor(0.6300, **dd), + "a2": torch.tensor(5.0000, **dd), + } + + energy = dftd3(numbers, positions, param) + print(energy.sum(-1)) + print(ref.sum(-1)) + assert energy.dtype == dtype + assert pytest.approx(ref.cpu()) == energy.cpu()