Fix handling of very large CNs in weights #50

Merged · 4 commits · Mar 28, 2024
Changes from 2 commits
2 changes: 2 additions & 0 deletions src/tad_dftd3/disp.py
@@ -140,7 +140,9 @@ def dftd3(
numbers, positions, counting_function=counting_function, rcov=rcov
)
weights = model.weight_references(numbers, cn, ref, weighting_function)
print(weights)
c6 = model.atomic_c6(numbers, weights, ref)
print(c6)

return dispersion(
numbers,
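For context, the two quantities printed above are the core of the D3 reference interpolation: the per-atom reference weights from `weight_references` are contracted with tabulated reference C6 coefficients to give pairwise C6 values. The following is a minimal sketch of that contraction with made-up shapes and random data; the tensor names and the exact layout of the reference table are illustrative assumptions, not the library's internals:

import torch

# Illustrative sizes: nat atoms, nref reference systems per atom.
nat, nref = 3, 5
weights = torch.rand(nat, nref, dtype=torch.double)            # output of weight_references
weights = weights / weights.sum(-1, keepdim=True)              # normalized per atom
c6_ref = torch.rand(nat, nat, nref, nref, dtype=torch.double)  # tabulated reference C6

# Pairwise C6 as a weighted sum over both atoms' reference systems:
# c6[i, j] = sum_a sum_b weights[i, a] * weights[j, b] * c6_ref[i, j, a, b]
c6 = torch.einsum("ia,jb,ijab->ij", weights, weights, c6_ref)
print(c6.shape)  # torch.Size([3, 3])
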
52 changes: 40 additions & 12 deletions src/tad_dftd3/model.py
@@ -41,6 +41,7 @@
[ 5.4368822, 3.0930154, 3.0930154]], dtype=torch.float64)
"""
import torch
from tad_mctc import storch
from tad_mctc.batch import real_atoms
from tad_mctc.math import einsum

@@ -131,8 +132,12 @@
Tensor
Weights of all reference systems
"""
refcn = reference.cn[numbers]
mask = refcn >= 0

mask = reference.cn[numbers] >= 0
zero = torch.tensor(0.0, device=cn.device, dtype=cn.dtype)
zero_double = torch.tensor(0.0, device=cn.device, dtype=torch.double)
one = torch.tensor(1.0, device=cn.device, dtype=cn.dtype)

# Due to the exponentiation, `norms` and `weights` may become very small.
# This may cause problems for the division by `norms`. It may occur that
@@ -148,23 +153,46 @@
weights = torch.where(
mask,
weighting_function(dcn, **kwargs),
torch.tensor(0.0, device=dcn.device, dtype=dcn.dtype), # not eps!
zero_double, # not eps!
)

# Nevertheless, we must avoid zero division here in batched calculations.
#
# Previously, a small value was added to `norms` to prevent division by zero
# (`norms = torch.add(torch.sum(weights, dim=-1), 1e-20)`). However, even
# such small values can lead to relatively large deviations because the
# small value is not added to the weights, and hence, the case where
# `weights` and `norms` are equal does not yield one anymore. In fact, the
# test suite fails because some elements deviate up to around 1e-4.
#
# We solve this issue by using a mask from the atoms and only add a small
# value, where the actual padding zeros are.
norms = torch.where(
real_atoms(numbers),
torch.sum(weights, dim=-1),
torch.tensor(torch.finfo(dcn.dtype).eps, device=cn.device, dtype=dcn.dtype),
# We solve this by running in double precision, adding a very small number
# and using multiple masks.

# normalize weights
norm = torch.where(
mask,
torch.sum(weights, dim=-1, keepdim=True),
torch.tensor(1e-300, device=cn.device, dtype=torch.double), # double!
)
return (weights / norms.unsqueeze(-1)).type(cn.dtype)

# back to real dtype
gw_temp = (weights / norm).type(cn.dtype)

# The following section handles cases with large CNs that lead to zeros
# after the exponential in the weighting function. If this happens, all
# weights become zero, which is not desired. Instead, we set the weight of
# the reference system with the largest CN to one.
# This case can occur if the CN of the current (actual) system is too far
# away from the largest CN of the reference systems. An example would be an
# atom within a fullerene (La3N@C80).

# maximum reference CN for each atom
maxcn = torch.max(refcn, dim=-1, keepdim=True)[0]

# prevent division by 0 and small values
exceptional = (torch.isnan(gw_temp)) | (gw_temp > torch.finfo(cn.dtype).max)

gw = torch.where(
exceptional,
torch.where(refcn == maxcn, one, zero),
gw_temp,
)

return torch.where(mask, gw, zero)
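
To make the failure mode addressed by this hunk concrete: with Gaussian weighting, exp(-wf * (CN - CNref)^2) underflows to zero for every reference system once the actual CN is far from all reference CNs, so the subsequent normalization divides zero by zero. The following standalone sketch uses a simplified single-Gaussian weighting and hypothetical CN values (both are assumptions for illustration, not the PR's code) to show the underflow and the fallback of assigning the full weight to the reference with the largest CN:

import torch

def simple_gaussian_weight(dcn: torch.Tensor, wf: float = 4.0) -> torch.Tensor:
    # Simplified stand-in for the D3 Gaussian weighting; enough to show the underflow.
    return torch.exp(-wf * dcn**2)

refcn = torch.tensor([0.0, 1.0, 2.0, 3.0], dtype=torch.double)  # reference CNs (hypothetical)
cn = torch.tensor(20.0, dtype=torch.double)                     # actual CN, far outside the reference range

weights = simple_gaussian_weight(cn - refcn)  # every entry underflows to exactly 0.0
norm = weights.sum()
gw = weights / norm                           # 0 / 0 -> NaN for every reference

# Fallback corresponding to this PR's idea: if the normalized weights are not
# finite, put all weight on the reference system with the largest CN.
if not torch.isfinite(gw).all():
    gw = torch.where(refcn == refcn.max(), torch.ones_like(gw), torch.zeros_like(gw))

print(gw)  # tensor([0., 0., 0., 1.], dtype=torch.float64)
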
109 changes: 109 additions & 0 deletions test/test_disp/samples.py
@@ -1544,6 +1544,115 @@ class Record(Molecule, Refs):
),
}
),
"La3N@C80": Refs(
{
"cn": torch.tensor(
[],
dtype=torch.double,
),
"weights": torch.tensor(
[],
dtype=torch.double,
),
"c6": torch.tensor(
[],
dtype=torch.double,
),
"disp2": torch.tensor(
[
-6.4568698147826646e-003,
-6.4559561239969799e-003,
-6.4564281797744585e-003,
-2.7360474586652791e-003,
-1.7407093093953240e-003,
-1.8301394258209106e-003,
-1.8524936502264853e-003,
-1.7350547435936382e-003,
-1.6338086590634386e-003,
-1.5755111668016490e-003,
-1.5618576612617284e-003,
-1.6147968576847084e-003,
-1.7733089538039231e-003,
-1.6245203604557511e-003,
-1.6209618005004513e-003,
-1.7599254182297916e-003,
-1.7369678621080445e-003,
-1.8528080133639840e-003,
-1.9413055642552414e-003,
-1.8525635860998158e-003,
-1.8297098714605152e-003,
-1.7566218566864807e-003,
-1.6202382184123294e-003,
-1.5462201695356063e-003,
-1.5084481213619406e-003,
-1.5140452587746691e-003,
-1.5390048264384269e-003,
-1.5981780755403895e-003,
-1.6506427677755436e-003,
-1.6296629464721212e-003,
-1.5795214054885784e-003,
-1.5089651771174383e-003,
-1.5471741156195414e-003,
-1.5758460954735725e-003,
-1.6340729580740559e-003,
-1.7411978969490475e-003,
-1.8300611979096514e-003,
-1.7577001063579932e-003,
-1.7359143119427235e-003,
-1.6337607742643375e-003,
-1.7409640835067410e-003,
-1.7727788519513107e-003,
-1.6242019237858963e-003,
-1.5788567024896301e-003,
-1.5802817234666848e-003,
-1.6505243428138216e-003,
-1.7989840061950748e-003,
-1.7772803360684576e-003,
-1.6143157287502792e-003,
-1.6293179770636760e-003,
-1.5792498551407159e-003,
-1.5087992421745076e-003,
-1.5140439565480947e-003,
-1.5390983057429895e-003,
-1.5985396168404411e-003,
-1.7439931054253181e-003,
-1.7776307518328823e-003,
-1.7991525034320502e-003,
-1.8968763791053463e-003,
-1.7986664231541786e-003,
-1.6494582261534513e-003,
-1.7416254195092209e-003,
-1.5976268653784852e-003,
-1.6494603571438624e-003,
-1.6149177685073563e-003,
-1.7726437418526361e-003,
-1.6240595135519110e-003,
-1.6205535517789524e-003,
-1.5468593370110194e-003,
-1.5757631710259304e-003,
-1.5621392033851348e-003,
-1.6153080590843665e-003,
-1.6506519485602186e-003,
-1.6295782927553905e-003,
-1.6146709380519242e-003,
-1.5808631439379914e-003,
-1.6511668531186460e-003,
-1.7446879237905224e-003,
-1.7773764100057496e-003,
-1.6142164156269839e-003,
-1.5801412516061992e-003,
-1.5135963042277999e-003,
-1.5384743970934046e-003,
-1.5619345437512383e-003,
],
dtype=torch.double,
),
"disp3": torch.tensor(
[],
dtype=torch.double,
),
}
),
"AmF3": Refs(
{
"cn": torch.tensor(
118 changes: 118 additions & 0 deletions test/test_disp/test_special.py
@@ -0,0 +1,118 @@
# This file is part of tad-dftd3.
# SPDX-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Test calculation of the dispersion energy for a system that fails without the
special handling of exceptional values in the calculation of the weights.
"""
import pytest
import torch
from tad_mctc.batch import pack
from tad_mctc.ncoord import exp_count

from tad_dftd3 import damping, data, dftd3, model, reference
from tad_dftd3.typing import DD

from ..conftest import DEVICE
from .samples import samples


@pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
@pytest.mark.parametrize("name", ["La3N@C80"])
def test_single(dtype: torch.dtype, name: str) -> None:
dd: DD = {"device": DEVICE, "dtype": dtype}

sample = samples[name]
numbers = sample["numbers"].to(DEVICE)
positions = sample["positions"].to(**dd)
ref = sample["disp2"].to(**dd)

rcov = data.COV_D3.to(**dd)[numbers]
rvdw = data.VDW_D3.to(**dd)[numbers.unsqueeze(-1), numbers.unsqueeze(-2)]
r4r2 = data.R4R2.to(**dd)[numbers]
cutoff = torch.tensor(50, **dd)

# GFN1-xTB parameters
param = {
"s6": torch.tensor(1.0000, **dd),
"s8": torch.tensor(2.4000, **dd),
"s9": torch.tensor(0.0000, **dd),
"alp": torch.tensor(14.00, **dd),
"a1": torch.tensor(0.6300, **dd),
"a2": torch.tensor(5.0000, **dd),
}

energy = dftd3(
numbers,
positions,
param,
ref=reference.Reference(**dd),
rcov=rcov,
rvdw=rvdw,
r4r2=r4r2,
cutoff=cutoff,
counting_function=exp_count,
weighting_function=model.gaussian_weight,
damping_function=damping.rational_damping,
)

assert energy.dtype == dtype
assert pytest.approx(ref.cpu()) == energy.cpu()


@pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
def test_batch(dtype: torch.dtype) -> None:
dd: DD = {"device": DEVICE, "dtype": dtype}

sample1, sample2 = (samples["LiH"], samples["La3N@C80"])
numbers = pack(
(
sample1["numbers"].to(DEVICE),
sample2["numbers"].to(DEVICE),
)
)
positions = pack(
(
sample1["positions"].to(**dd),
sample2["positions"].to(**dd),
)
)
ref = pack(
(
torch.tensor(
[
-4.1054019506089849e-05,
-4.1054019506089849e-05,
],
**dd
),
sample2["disp2"].to(**dd),
)
)

# GFN1-xTB parameters
param = {
"s6": torch.tensor(1.0000, **dd),
"s8": torch.tensor(2.4000, **dd),
"s9": torch.tensor(0.0000, **dd),
"alp": torch.tensor(14.00, **dd),
"a1": torch.tensor(0.6300, **dd),
"a2": torch.tensor(5.0000, **dd),
}

energy = dftd3(numbers, positions, param)
print(energy.sum(-1))
print(ref.sum(-1))
assert energy.dtype == dtype
assert pytest.approx(ref.cpu()) == energy.cpu()