diff --git a/src/tad_dftd4/disp.py b/src/tad_dftd4/disp.py
index 3da78a3..81ae917 100644
--- a/src/tad_dftd4/disp.py
+++ b/src/tad_dftd4/disp.py
@@ -30,7 +30,7 @@
 from tad_mctc import storch
 from tad_mctc.batch import real_pairs
 from tad_mctc.ncoord import cn_d4, erf_count
-from tad_multicharge.eeq import get_charges
+from tad_multicharge import get_eeq_charges
 
 from . import data, defaults
 from .cutoff import Cutoff
@@ -113,7 +113,7 @@ def dftd4(
     if r4r2 is None:
         r4r2 = data.R4R2.to(**dd)[numbers]
     if q is None:
-        q = get_charges(numbers, positions, charge, cutoff=cutoff.cn_eeq)
+        q = get_eeq_charges(numbers, positions, charge, cutoff=cutoff.cn_eeq)
 
     if numbers.shape != positions.shape[:-1]:
         raise ValueError(
diff --git a/test/test_disp/samples.py b/test/test_disp/samples.py
index faf1d94..29b8d7c 100644
--- a/test/test_disp/samples.py
+++ b/test/test_disp/samples.py
@@ -123,22 +123,22 @@
         {
             "q": torch.tensor(
                 [
-                    7.733478452374956e-01,
-                    1.076268996435148e-01,
+                    +7.733478452374956e-01,
+                    +1.076268996435148e-01,
                     -3.669996418388237e-01,
-                    4.928336699714377e-02,
+                    +4.928336699714377e-02,
                     -1.833320732359188e-01,
-                    2.333021537750765e-01,
-                    6.618377120945702e-02,
+                    +2.333021537750765e-01,
+                    +6.618377120945702e-02,
                     -5.439442982394790e-01,
                     -2.702644018249256e-01,
-                    2.666190421598861e-01,
-                    2.627250807290775e-01,
+                    +2.666190421598861e-01,
+                    +2.627250807290775e-01,
                     -7.153145902661326e-02,
                     -3.733008547230057e-01,
-                    3.845854622327315e-02,
+                    +3.845854622327315e-02,
                     -5.058512350299781e-01,
-                    5.176772579438197e-01,
+                    +5.176772579438197e-01,
                 ],
                 dtype=torch.float64,
             ),
@@ -211,21 +211,21 @@
         {
             "q": torch.tensor(
                 [
-                    7.383947733600110e-02,
+                    +7.383947733600110e-02,
                     -1.683548174961888e-01,
                     -3.476428218085238e-01,
                     -7.054893587245280e-01,
-                    7.735482364313262e-01,
-                    2.302076155262403e-01,
-                    1.027485077260907e-01,
-                    9.478181684909791e-02,
-                    2.442577506263613e-02,
-                    2.349849530340006e-01,
+                    +7.735482364313262e-01,
+                    +2.302076155262403e-01,
+                    +1.027485077260907e-01,
+                    +9.478181684909791e-02,
+                    +2.442577506263613e-02,
+                    +2.349849530340006e-01,
                     -3.178399496308427e-01,
-                    6.671128897373533e-01,
+                    +6.671128897373533e-01,
                     -4.781198581208199e-01,
-                    6.575365749452318e-02,
-                    1.082591170860242e-01,
+                    +6.575365749452318e-02,
+                    +1.082591170860242e-01,
                     -3.582152405023902e-01,
                 ],
                 dtype=torch.float64,
diff --git a/test/test_grad/test_hessian.py b/test/test_grad/test_hessian.py
index 933aeb6..cfac104 100644
--- a/test/test_grad/test_hessian.py
+++ b/test/test_grad/test_hessian.py
@@ -93,7 +93,6 @@ def test_single(dtype: torch.dtype, name: str) -> None:
         sample["hessian"].to(**dd),
         torch.Size(2 * (numbers.shape[-1], 3)),
     )
-    print(ref)
 
     # variable to be differentiated
     positions.requires_grad_(True)
diff --git a/test/test_model/samples.py b/test/test_model/samples.py
index 44d7437..944bf63 100644
--- a/test/test_model/samples.py
+++ b/test/test_model/samples.py
@@ -30,6 +30,9 @@ class Refs(TypedDict):
     """Format of reference values."""
 
+    q: Tensor
+    """EEQ charges."""
+
     gw: Tensor
     """
     Gaussian weights. Shape must be `(nrefs, natoms)`, which we have to
@@ -48,7 +51,14 @@ class Record(Molecule, Refs):
 
 refs: dict[str, Refs] = {
     "LiH": Refs(
-        {  # CN, q
+        {  # CN
+            "q": torch.tensor(
+                [
+                    3.708714958301688e-01,
+                    -3.708714958301688e-01,
+                ],
+                dtype=torch.float64,
+            ),
             "gw": reshape_fortran(
                 torch.tensor(
                     [
@@ -75,7 +85,17 @@
         }
     ),
     "SiH4": Refs(
-        {  # CN, q
+        {  # CN
+            "q": torch.tensor(
+                [
+                    -8.412842390895063e-02,
+                    2.103210597723753e-02,
+                    2.103210597723774e-02,
+                    2.103210597723764e-02,
+                    2.103210597723773e-02,
+                ],
+                dtype=torch.float64,
+            ),
             "gw": reshape_fortran(
                 torch.tensor(
                     [
@@ -143,6 +163,27 @@
     ),
     "MB16_43_01": Refs(
         {  # CN
+            "q": torch.tensor(
+                [
+                    +7.733478452374956e-01,
+                    +1.076268996435148e-01,
+                    -3.669996418388237e-01,
+                    +4.928336699714377e-02,
+                    -1.833320732359188e-01,
+                    +2.333021537750765e-01,
+                    +6.618377120945702e-02,
+                    -5.439442982394790e-01,
+                    -2.702644018249256e-01,
+                    +2.666190421598861e-01,
+                    +2.627250807290775e-01,
+                    -7.153145902661326e-02,
+                    -3.733008547230057e-01,
+                    +3.845854622327315e-02,
+                    -5.058512350299781e-01,
+                    +5.176772579438197e-01,
+                ],
+                dtype=torch.float64,
+            ),
             "gw": reshape_fortran(
                 torch.tensor(
                     [
@@ -495,7 +536,28 @@
         }
     ),
     "MB16_43_02": Refs(
-        {  # q
+        {
+            "q": torch.tensor(
+                [
+                    +7.383947733600110e-02,
+                    -1.683548174961888e-01,
+                    -3.476428218085238e-01,
+                    -7.054893587245280e-01,
+                    +7.735482364313262e-01,
+                    +2.302076155262403e-01,
+                    +1.027485077260907e-01,
+                    +9.478181684909791e-02,
+                    +2.442577506263613e-02,
+                    +2.349849530340006e-01,
+                    -3.178399496308427e-01,
+                    +6.671128897373533e-01,
+                    -4.781198581208199e-01,
+                    +6.575365749452318e-02,
+                    +1.082591170860242e-01,
+                    -3.582152405023902e-01,
+                ],
+                dtype=torch.float64,
+            ),
             "gw": reshape_fortran(
                 torch.tensor(
                     [
@@ -848,7 +910,28 @@
         }
     ),
     "MB16_43_03": Refs(
-        {  # CN, q
+        {  # CN
+            "q": torch.tensor(
+                [
+                    -1.7778832703574010e-01,
+                    -8.2294323973571670e-01,
+                    4.0457879113787724e-02,
+                    5.7971038082866722e-01,
+                    6.9960183636529338e-01,
+                    6.8430976075776473e-02,
+                    -3.4297147449169296e-01,
+                    4.6495478328605205e-02,
+                    6.7701246205863264e-02,
+                    8.4993144140514468e-02,
+                    -5.2228521752048518e-01,
+                    -2.9251488187370783e-01,
+                    -3.9837556749973635e-01,
+                    2.0976964648102694e-01,
+                    7.2314045922878123e-01,
+                    3.6577661388763623e-02,
+                ],
+                dtype=torch.float64,
+            ),
             "gw": reshape_fortran(
                 torch.tensor(
                     [
diff --git a/test/test_model/test_model.py b/test/test_model/test_model.py
index 948798e..9fa5929 100644
--- a/test/test_model/test_model.py
+++ b/test/test_model/test_model.py
@@ -23,7 +23,6 @@ import torch
 
 from tad_mctc.batch import pack
 from tad_mctc.ncoord import cn_d4
-from tad_multicharge.eeq import get_charges  # get rid!
 
 from tad_dftd4.model import D4Model
 from tad_dftd4.typing import DD
@@ -44,14 +43,12 @@ def test_single(name: str, dtype: torch.dtype) -> None:
     sample = samples[name]
     numbers = sample["numbers"].to(DEVICE)
     positions = sample["positions"].to(**dd)
+    q = sample["q"].to(**dd)
     ref = sample["c6"].to(**dd)
 
     d4 = D4Model(numbers, **dd)
 
     cn = cn_d4(numbers, positions)
-    total_charge = torch.tensor(0.0, **dd)
-    q = get_charges(numbers, positions, total_charge)
-
     gw = d4.weight_references(cn=cn, q=q)
     c6 = d4.get_atomic_c6(gw)
     assert pytest.approx(ref.cpu(), abs=tol, rel=tol) == c6.cpu()
@@ -77,6 +74,12 @@ def test_batch(name1: str, name2: str, dtype: torch.dtype) -> None:
             sample2["positions"].to(**dd),
         ]
     )
+    q = pack(
+        [
+            sample1["q"].to(**dd),
+            sample2["q"].to(**dd),
+        ]
+    )
     refs = pack(
         [
             sample1["c6"].to(**dd),
@@ -87,9 +90,6 @@ def test_batch(name1: str, name2: str, dtype: torch.dtype) -> None:
     d4 = D4Model(numbers, **dd)
 
     cn = cn_d4(numbers, positions)
-    total_charge = torch.zeros(numbers.shape[0], **dd)
-    q = get_charges(numbers, positions, total_charge)
-
     gw = d4.weight_references(cn=cn, q=q)
     c6 = d4.get_atomic_c6(gw)
     assert pytest.approx(refs.cpu(), abs=tol, rel=tol) == c6.cpu()
diff --git a/test/test_model/test_weights.py b/test/test_model/test_weights.py
index 8861c00..fa803b1 100644
--- a/test/test_model/test_weights.py
+++ b/test/test_model/test_weights.py
@@ -25,7 +25,6 @@ import torch.nn.functional as F
 
 from tad_mctc.batch import pack
 from tad_mctc.ncoord import cn_d4
-from tad_multicharge.eeq import get_charges
 
 from tad_dftd4.model import D4Model
 from tad_dftd4.typing import DD
@@ -55,7 +54,7 @@ def single(
         cn = None  # positions.new_zeros(numbers.shape)
 
     if with_q is True:
-        q = get_charges(numbers, positions, torch.tensor(0.0, **dd))
+        q = sample["q"].to(**dd)
     else:
         q = None  # positions.new_zeros(numbers.shape)
 
@@ -120,13 +119,16 @@ def test_batch(name1: str, name2: str, dtype: torch.dtype) -> None:
             sample2["positions"].to(**dd),
         ]
     )
+    q = pack(
+        [
+            sample1["q"].to(**dd),
+            sample2["q"].to(**dd),
+        ]
+    )
 
     d4 = D4Model(numbers, **dd)
 
     cn = cn_d4(numbers, positions)
-    total_charge = positions.new_zeros(numbers.shape[0])
-    q = get_charges(numbers, positions, total_charge)
-
     gwvec = d4.weight_references(cn, q)
 
     # pad reference tensor to always be of shape `(natoms, 7)`
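
Usage sketch (not part of the patch): the hunks above replace `tad_multicharge.eeq.get_charges` with the top-level `tad_multicharge.get_eeq_charges` and switch the model/weight tests to precomputed EEQ reference charges. A minimal standalone call of the new entry point might look as follows; the water geometry in Bohr is made up for illustration, and leaving out the `cutoff` keyword (which `dftd4()` passes as `cutoff=cutoff.cn_eeq`) is an assumption, not something the patch itself exercises.

import torch

from tad_multicharge import get_eeq_charges

# hypothetical example system (water, coordinates in Bohr), not taken from the test samples
numbers = torch.tensor([8, 1, 1])
positions = torch.tensor(
    [
        [+0.00000000, +0.00000000, -0.74288549],
        [-1.43472674, +0.00000000, +0.37144275],
        [+1.43472674, +0.00000000, +0.37144275],
    ],
    dtype=torch.float64,
)
charge = torch.tensor(0.0, dtype=torch.float64)

# same positional signature as in the dftd4() hunk above
q = get_eeq_charges(numbers, positions, charge)
print(q)  # one partial charge per atom; the charges sum to the total charge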