Skip to content

Commit

Permalink
Fix random NaNs for very small distances
Browse files Browse the repository at this point in the history
  • Loading branch information
marvinfriede committed Apr 22, 2024
1 parent 907dd4c commit 2211df4
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 3 deletions.
2 changes: 1 addition & 1 deletion examples/single.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import tad_dftd3 as d3

numbers = mctc.convert.symbol_to_number(symbols="C C C C N C S H H H H H".split())
positions = torch.Tensor(
positions = torch.tensor(
[
[-2.56745685564671, -0.02509985979910, 0.00000000000000],
[-1.39177582455797, +2.27696188880014, 0.00000000000000],
Expand Down
8 changes: 6 additions & 2 deletions src/tad_dftd3/model/weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@

from ..reference import Reference
from ..typing import Any, Tensor, WeightingFunction
from tad_mctc import storch

__all__ = ["gaussian_weight", "weight_references"]

Expand Down Expand Up @@ -129,15 +130,18 @@ def weight_references(
# We solve this by running in double precision, adding a very small number
# and using multiple masks.

small = torch.tensor(1e-300, device=cn.device, dtype=torch.double)

# normalize weights
norm = torch.where(
mask,
torch.sum(weights, dim=-1, keepdim=True),
torch.tensor(1e-300, device=cn.device, dtype=torch.double), # double!
small, # double!
)

# back to real dtype
gw_temp = (weights / norm).type(cn.dtype)
# gw_temp = (storch.divide(weights, norm, eps=small)).type(cn.dtype)
gw_temp = storch.divide(weights, norm, eps=small).type(cn.dtype)

# The following section handles cases with large CNs that lead to zeros in
# after the exponential in the weighting function. If this happens all
Expand Down

0 comments on commit 2211df4

Please sign in to comment.