Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Udpate multicharge API call #41

Merged
merged 1 commit into from
Jan 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/tad_dftd4/disp.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from tad_mctc import storch
from tad_mctc.batch import real_pairs
from tad_mctc.ncoord import cn_d4, erf_count
from tad_multicharge.eeq import get_charges
from tad_multicharge import get_eeq_charges

from . import data, defaults
from .cutoff import Cutoff
Expand Down Expand Up @@ -113,7 +113,7 @@ def dftd4(
if r4r2 is None:
r4r2 = data.R4R2.to(**dd)[numbers]
if q is None:
q = get_charges(numbers, positions, charge, cutoff=cutoff.cn_eeq)
q = get_eeq_charges(numbers, positions, charge, cutoff=cutoff.cn_eeq)

if numbers.shape != positions.shape[:-1]:
raise ValueError(
Expand Down
38 changes: 19 additions & 19 deletions test/test_disp/samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,22 +123,22 @@ class Record(Molecule, Refs):
{
"q": torch.tensor(
[
7.733478452374956e-01,
1.076268996435148e-01,
+7.733478452374956e-01,
+1.076268996435148e-01,
-3.669996418388237e-01,
4.928336699714377e-02,
+4.928336699714377e-02,
-1.833320732359188e-01,
2.333021537750765e-01,
6.618377120945702e-02,
+2.333021537750765e-01,
+6.618377120945702e-02,
-5.439442982394790e-01,
-2.702644018249256e-01,
2.666190421598861e-01,
2.627250807290775e-01,
+2.666190421598861e-01,
+2.627250807290775e-01,
-7.153145902661326e-02,
-3.733008547230057e-01,
3.845854622327315e-02,
+3.845854622327315e-02,
-5.058512350299781e-01,
5.176772579438197e-01,
+5.176772579438197e-01,
],
dtype=torch.float64,
),
Expand Down Expand Up @@ -211,21 +211,21 @@ class Record(Molecule, Refs):
{
"q": torch.tensor(
[
7.383947733600110e-02,
+7.383947733600110e-02,
-1.683548174961888e-01,
-3.476428218085238e-01,
-7.054893587245280e-01,
7.735482364313262e-01,
2.302076155262403e-01,
1.027485077260907e-01,
9.478181684909791e-02,
2.442577506263613e-02,
2.349849530340006e-01,
+7.735482364313262e-01,
+2.302076155262403e-01,
+1.027485077260907e-01,
+9.478181684909791e-02,
+2.442577506263613e-02,
+2.349849530340006e-01,
-3.178399496308427e-01,
6.671128897373533e-01,
+6.671128897373533e-01,
-4.781198581208199e-01,
6.575365749452318e-02,
1.082591170860242e-01,
+6.575365749452318e-02,
+1.082591170860242e-01,
-3.582152405023902e-01,
],
dtype=torch.float64,
Expand Down
1 change: 0 additions & 1 deletion test/test_grad/test_hessian.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,6 @@ def test_single(dtype: torch.dtype, name: str) -> None:
sample["hessian"].to(**dd),
torch.Size(2 * (numbers.shape[-1], 3)),
)
print(ref)

# variable to be differentiated
positions.requires_grad_(True)
Expand Down
91 changes: 87 additions & 4 deletions test/test_model/samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@
class Refs(TypedDict):
"""Format of reference values."""

q: Tensor
"""EEQ charges."""

gw: Tensor
"""
Gaussian weights. Shape must be `(nrefs, natoms)`, which we have to
Expand All @@ -48,7 +51,14 @@ class Record(Molecule, Refs):

refs: dict[str, Refs] = {
"LiH": Refs(
{ # CN, q
{ # CN
"q": torch.tensor(
[
3.708714958301688e-01,
-3.708714958301688e-01,
],
dtype=torch.float64,
),
"gw": reshape_fortran(
torch.tensor(
[
Expand All @@ -75,7 +85,17 @@ class Record(Molecule, Refs):
}
),
"SiH4": Refs(
{ # CN, q
{ # CN
"q": torch.tensor(
[
-8.412842390895063e-02,
2.103210597723753e-02,
2.103210597723774e-02,
2.103210597723764e-02,
2.103210597723773e-02,
],
dtype=torch.float64,
),
"gw": reshape_fortran(
torch.tensor(
[
Expand Down Expand Up @@ -143,6 +163,27 @@ class Record(Molecule, Refs):
),
"MB16_43_01": Refs(
{ # CN
"q": torch.tensor(
[
+7.733478452374956e-01,
+1.076268996435148e-01,
-3.669996418388237e-01,
+4.928336699714377e-02,
-1.833320732359188e-01,
+2.333021537750765e-01,
+6.618377120945702e-02,
-5.439442982394790e-01,
-2.702644018249256e-01,
+2.666190421598861e-01,
+2.627250807290775e-01,
-7.153145902661326e-02,
-3.733008547230057e-01,
+3.845854622327315e-02,
-5.058512350299781e-01,
+5.176772579438197e-01,
],
dtype=torch.float64,
),
"gw": reshape_fortran(
torch.tensor(
[
Expand Down Expand Up @@ -495,7 +536,28 @@ class Record(Molecule, Refs):
}
),
"MB16_43_02": Refs(
{ # q
{
"q": torch.tensor(
[
+7.383947733600110e-02,
-1.683548174961888e-01,
-3.476428218085238e-01,
-7.054893587245280e-01,
+7.735482364313262e-01,
+2.302076155262403e-01,
+1.027485077260907e-01,
+9.478181684909791e-02,
+2.442577506263613e-02,
+2.349849530340006e-01,
-3.178399496308427e-01,
+6.671128897373533e-01,
-4.781198581208199e-01,
+6.575365749452318e-02,
+1.082591170860242e-01,
-3.582152405023902e-01,
],
dtype=torch.float64,
),
"gw": reshape_fortran(
torch.tensor(
[
Expand Down Expand Up @@ -848,7 +910,28 @@ class Record(Molecule, Refs):
}
),
"MB16_43_03": Refs(
{ # CN, q
{ # CN
"q": torch.tensor(
[
-1.7778832703574010e-01,
-8.2294323973571670e-01,
4.0457879113787724e-02,
5.7971038082866722e-01,
6.9960183636529338e-01,
6.8430976075776473e-02,
-3.4297147449169296e-01,
4.6495478328605205e-02,
6.7701246205863264e-02,
8.4993144140514468e-02,
-5.2228521752048518e-01,
-2.9251488187370783e-01,
-3.9837556749973635e-01,
2.0976964648102694e-01,
7.2314045922878123e-01,
3.6577661388763623e-02,
],
dtype=torch.float64,
),
"gw": reshape_fortran(
torch.tensor(
[
Expand Down
14 changes: 7 additions & 7 deletions test/test_model/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import torch
from tad_mctc.batch import pack
from tad_mctc.ncoord import cn_d4
from tad_multicharge.eeq import get_charges # get rid!

from tad_dftd4.model import D4Model
from tad_dftd4.typing import DD
Expand All @@ -44,14 +43,12 @@ def test_single(name: str, dtype: torch.dtype) -> None:
sample = samples[name]
numbers = sample["numbers"].to(DEVICE)
positions = sample["positions"].to(**dd)
q = sample["q"].to(**dd)
ref = sample["c6"].to(**dd)

d4 = D4Model(numbers, **dd)

cn = cn_d4(numbers, positions)
total_charge = torch.tensor(0.0, **dd)
q = get_charges(numbers, positions, total_charge)

gw = d4.weight_references(cn=cn, q=q)
c6 = d4.get_atomic_c6(gw)
assert pytest.approx(ref.cpu(), abs=tol, rel=tol) == c6.cpu()
Expand All @@ -77,6 +74,12 @@ def test_batch(name1: str, name2: str, dtype: torch.dtype) -> None:
sample2["positions"].to(**dd),
]
)
q = pack(
[
sample1["q"].to(**dd),
sample2["q"].to(**dd),
]
)
refs = pack(
[
sample1["c6"].to(**dd),
Expand All @@ -87,9 +90,6 @@ def test_batch(name1: str, name2: str, dtype: torch.dtype) -> None:
d4 = D4Model(numbers, **dd)

cn = cn_d4(numbers, positions)
total_charge = torch.zeros(numbers.shape[0], **dd)
q = get_charges(numbers, positions, total_charge)

gw = d4.weight_references(cn=cn, q=q)
c6 = d4.get_atomic_c6(gw)
assert pytest.approx(refs.cpu(), abs=tol, rel=tol) == c6.cpu()
12 changes: 7 additions & 5 deletions test/test_model/test_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import torch.nn.functional as F
from tad_mctc.batch import pack
from tad_mctc.ncoord import cn_d4
from tad_multicharge.eeq import get_charges

from tad_dftd4.model import D4Model
from tad_dftd4.typing import DD
Expand Down Expand Up @@ -55,7 +54,7 @@ def single(
cn = None # positions.new_zeros(numbers.shape)

if with_q is True:
q = get_charges(numbers, positions, torch.tensor(0.0, **dd))
q = sample["q"].to(**dd)
else:
q = None # positions.new_zeros(numbers.shape)

Expand Down Expand Up @@ -120,13 +119,16 @@ def test_batch(name1: str, name2: str, dtype: torch.dtype) -> None:
sample2["positions"].to(**dd),
]
)
q = pack(
[
sample1["q"].to(**dd),
sample2["q"].to(**dd),
]
)

d4 = D4Model(numbers, **dd)

cn = cn_d4(numbers, positions)
total_charge = positions.new_zeros(numbers.shape[0])
q = get_charges(numbers, positions, total_charge)

gwvec = d4.weight_references(cn, q)

# pad reference tensor to always be of shape `(natoms, 7)`
Expand Down