
Implement ZBL potential #134

Merged · 13 commits · Nov 9, 2022
2 changes: 1 addition & 1 deletion tests/test_module.py
@@ -43,7 +43,7 @@ def test_train(model_name, use_atomref, tmpdir):
    prior = None
    if use_atomref:
        prior = getattr(priors, args["prior_model"])(dataset=datamodule.dataset)
-        args["prior_args"] = prior.get_init_args()
+        args["prior_init_args"] = prior.get_init_args()
Collaborator:

The loading of the pretrained models (https://github.com/torchmd/torchmd-net/tree/main/examples#loading-checkpoints) fails:

from torchmdnet.models.model import load_model
load_model('ANI1-equivariant_transformer/epoch=359-val_loss=0.0004-test_loss=0.0120.ckpt')

AssertionError: Requested prior model Atomref but the arguments are lacking the key "prior_init_args".

Collaborator (author):

It sounds like we want to redo how prior args are specified as described in #26 (comment). That means we can switch back to prior_args for this value. There will still be compatibility issues, because it will need to become a list with args for multiple prior models, but I can add a check for that case for backward compatibility. I'll go ahead and make the changes in this PR.
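A minimal sketch of the backward-compatibility check mentioned above (the helper name is illustrative, not part of this PR): loading code could wrap the legacy single-dict form into the planned list form so downstream code only ever sees lists:

def normalize_prior_args(prior_args):
    # Older checkpoints store the arguments for a single prior model as one
    # dict; the planned format is a list with one entry per prior model.
    # Wrapping the legacy form keeps downstream code uniform.
    if isinstance(prior_args, dict):
        return [prior_args]
    return prior_args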

Collaborator:


Might be a good idea to add a test case for loading model checkpoints from a previous version.
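A sketch of such a test, assuming a small checkpoint file from an earlier release is committed as a fixture (the path below is hypothetical):

from torchmdnet.models.model import load_model

def test_load_legacy_checkpoint():
    # The fixture would come from a release that still stored the prior
    # arguments under the old "prior_args" key, guarding the
    # backward-compatibility path when the format changes.
    model = load_model("tests/data/legacy_prior_args.ckpt")  # hypothetical fixture
    assert model is not None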


    module = LNNP(args, prior_model=prior)

30 changes: 29 additions & 1 deletion tests/test_priors.py
@@ -4,7 +4,7 @@
import pytorch_lightning as pl
from torchmdnet import models
from torchmdnet.models.model import create_model
-from torchmdnet.priors import Atomref
+from torchmdnet.priors import Atomref, ZBL
from torch_scatter import scatter
from utils import load_example_args, create_example_batch, DummyDataset

@@ -31,3 +31,31 @@ def test_atomref(model_name):
    # check if the output of both models differs by the expected atomref contribution
    expected_offset = scatter(dataset.get_atomref().squeeze()[z], batch).unsqueeze(1)
    torch.testing.assert_allclose(x_atomref, x_no_atomref + expected_offset)

def test_zbl():
    pos = torch.tensor([[1.0, 0.0, 0.0], [2.5, 0.0, 0.0], [1.0, 1.0, 0.0], [0.0, 0.0, -1.0]], dtype=torch.float32)  # Atom positions in Bohr
    types = torch.tensor([0, 1, 2, 1], dtype=torch.long)  # Atom types
    atomic_number = torch.tensor([1, 6, 8], dtype=torch.int8)  # Mapping of atom types to atomic numbers
    distance_scale = 5.29177210903e-11  # Convert Bohr to meters
    energy_scale = 1000.0/6.02214076e23  # Convert kJ/mol to Joules

    # Use the ZBL class to compute the energy.

    zbl = ZBL(10.0, 5, atomic_number, distance_scale=distance_scale, energy_scale=energy_scale)
    energy = zbl.post_reduce(torch.zeros((1,)), types, pos, torch.zeros(types.shape, dtype=torch.long))[0]

    # Compare to the expected value.

    def compute_interaction(pos1, pos2, z1, z2):
        delta = pos1 - pos2
        r = torch.sqrt(torch.dot(delta, delta))
        x = r / (0.8854/(z1**0.23 + z2**0.23))
        phi = 0.1818*torch.exp(-3.2*x) + 0.5099*torch.exp(-0.9423*x) + 0.2802*torch.exp(-0.4029*x) + 0.02817*torch.exp(-0.2016*x)
        cutoff = 0.5*(torch.cos(r*torch.pi/10.0) + 1.0)
        return cutoff*phi*(138.935/5.29177210903e-2)*z1*z2/r

    expected = 0
    for i in range(len(pos)):
        for j in range(i):
            expected += compute_interaction(pos[i], pos[j], atomic_number[types[i]], atomic_number[types[j]])
    torch.testing.assert_allclose(expected, energy)
33 changes: 20 additions & 13 deletions torchmdnet/datasets/hdf.py
@@ -1,6 +1,7 @@
import torch
from torch_geometric.data import Dataset, Data
import h5py
+import numpy as np


class HDF5(Dataset):
@@ -27,7 +28,12 @@ def __init__(self, filename, **kwargs):
        files = [h5py.File(f, "r") for f in self.filename.split(";")]
        for file in files:
            for group_name in file:
-                self.num_molecules += len(file[group_name]["energy"])
+                if group_name == '_metadata':
+                    group = file[group_name]
+                    for name in group:
+                        setattr(self, name, torch.tensor(np.array(group[name])))
+                else:
+                    self.num_molecules += len(file[group_name]["energy"])
            file.close()

    def setup_index(self):

@@ -36,18 +42,19 @@ def setup_index(self):
        self.index = []
        for file in files:
            for group_name in file:
-                group = file[group_name]
-                types = group["types"]
-                pos = group["pos"]
-                energy = group["energy"]
-                if "forces" in group:
-                    self.has_forces = True
-                    forces = group["forces"]
-                    for i in range(len(energy)):
-                        self.index.append((types, pos, energy, forces, i))
-                else:
-                    for i in range(len(energy)):
-                        self.index.append((types, pos, energy, i))
+                if group_name != '_metadata':
+                    group = file[group_name]
+                    types = group["types"]
+                    pos = group["pos"]
+                    energy = group["energy"]
+                    if "forces" in group:
+                        self.has_forces = True
+                        forces = group["forces"]
+                        for i in range(len(energy)):
+                            self.index.append((types, pos, energy, forces, i))
+                    else:
+                        for i in range(len(energy)):
+                            self.index.append((types, pos, energy, i))

        assert self.num_molecules == len(self.index), (
            "Mismatch between previously calculated "
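For reference, a minimal sketch of an HDF5 file that exercises the new _metadata branch above; the group layout follows this loader, and the metadata names match the attributes the ZBL prior documents (shapes and values are illustrative):

import h5py
import numpy as np

with h5py.File("dataset.h5", "w") as f:
    # An ordinary group of conformations, as the loader already expects.
    mol = f.create_group("molecule0")
    mol["types"] = np.zeros((10, 4), dtype=np.int64)     # per-conformation atom types
    mol["pos"] = np.zeros((10, 4, 3), dtype=np.float32)  # coordinates
    mol["energy"] = np.zeros((10, 1), dtype=np.float32)  # energies
    # Every dataset in _metadata becomes a tensor attribute on the HDF5 object.
    meta = f.create_group("_metadata")
    meta["atomic_number"] = np.array([1, 6, 8], dtype=np.int8)
    meta["distance_scale"] = 5.29177210903e-11    # Bohr to meters
    meta["energy_scale"] = 1000.0 / 6.02214076e23  # kJ/mol to Joules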
6 changes: 3 additions & 3 deletions torchmdnet/models/model.py
@@ -65,16 +65,16 @@ def create_model(args, prior_model=None, mean=None, std=None):

    # prior model
    if args["prior_model"] and prior_model is None:
-        assert "prior_args" in args, (
+        assert "prior_init_args" in args, (
            f"Requested prior model {args['prior_model']} but the "
-            f'arguments are lacking the key "prior_args".'
+            f'arguments are lacking the key "prior_init_args".'
        )
        assert hasattr(priors, args["prior_model"]), (
            f'Unknown prior model {args["prior_model"]}. '
            f'Available models are {", ".join(priors.__all__)}'
        )
        # instantiate prior model if it was not passed to create_model (i.e. when loading a model)
-        prior_model = getattr(priors, args["prior_model"])(**args["prior_args"])
+        prior_model = getattr(priors, args["prior_model"])(**args["prior_init_args"])

    # create output network
    output_prefix = "Equivariant" if is_equivariant else ""
1 change: 1 addition & 0 deletions torchmdnet/priors/__init__.py
@@ -1 +1,2 @@
from torchmdnet.priors.atomref import Atomref
+from torchmdnet.priors.zbl import ZBL
54 changes: 54 additions & 0 deletions torchmdnet/priors/zbl.py
@@ -0,0 +1,54 @@
import torch
from torchmdnet.priors.base import BasePrior
from torchmdnet.models.utils import Distance, CosineCutoff

class ZBL(BasePrior):
    """This class implements the Ziegler-Biersack-Littmark (ZBL) potential for screened nuclear repulsion.
    It is described in https://doi.org/10.1007/978-3-642-68779-2_5 (equations 9 and 10 on page 147). It
    is an empirical potential that does a good job of describing the repulsion between atoms at very short
    distances.

    To use this prior, the Dataset must provide the following attributes.

    atomic_number: 1D tensor of length max_z. atomic_number[z] is the atomic number of atoms with atom type z.
    distance_scale: multiply by this factor to convert coordinates stored in the dataset to meters
    energy_scale: multiply by this factor to convert energies stored in the dataset to Joules
    """
    def __init__(self, cutoff_distance, max_num_neighbors, atomic_number=None, distance_scale=None, energy_scale=None, dataset=None):
        super(ZBL, self).__init__()
        if atomic_number is None:
            atomic_number = dataset.atomic_number
        if distance_scale is None:
            distance_scale = dataset.distance_scale
        if energy_scale is None:
            energy_scale = dataset.energy_scale
        atomic_number = torch.as_tensor(atomic_number, dtype=torch.int8)
        self.register_buffer("atomic_number", atomic_number)
        self.distance = Distance(0, cutoff_distance, max_num_neighbors=max_num_neighbors)
        self.cutoff = CosineCutoff(cutoff_upper=cutoff_distance)
        self.cutoff_distance = cutoff_distance
        self.max_num_neighbors = max_num_neighbors
        self.distance_scale = distance_scale
        self.energy_scale = energy_scale

    def get_init_args(self):
        return {'cutoff_distance': self.cutoff_distance,
                'max_num_neighbors': self.max_num_neighbors,
                'atomic_number': self.atomic_number,
                'distance_scale': self.distance_scale,
                'energy_scale': self.energy_scale}

    def reset_parameters(self):
        pass

    def post_reduce(self, y, z, pos, batch):
        edge_index, distance, _ = self.distance(pos, batch)
        atomic_number = self.atomic_number[z[edge_index]]
        # 5.29e-11 is the Bohr radius in meters. All other numbers are magic constants from the ZBL potential.
        a = 0.8854*5.29177210903e-11/(atomic_number[0]**0.23 + atomic_number[1]**0.23)
        d = distance*self.distance_scale/a
        f = 0.1818*torch.exp(-3.2*d) + 0.5099*torch.exp(-0.9423*d) + 0.2802*torch.exp(-0.4029*d) + 0.02817*torch.exp(-0.2016*d)
        f *= self.cutoff(distance)
        # Compute the energy, converting to the dataset's units. Multiply by 0.5 because every atom pair
        # appears twice.
        return y + 0.5*(2.30707755e-28/self.energy_scale/self.distance_scale)*torch.sum(f*atomic_number[0]*atomic_number[1]/distance, dim=-1)
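For reference, the functional form implemented by post_reduce above, with $a_0$ the Bohr radius and all constants taken directly from the code (the prefactor 2.30707755e-28 is $e^2/4\pi\epsilon_0$ in J·m):

$$E(r) = \frac{Z_1 Z_2 e^2}{4\pi\epsilon_0 r}\,\phi(r/a)\,f_{\mathrm{cut}}(r), \qquad a = \frac{0.8854\,a_0}{Z_1^{0.23} + Z_2^{0.23}}$$

$$\phi(x) = 0.1818\,e^{-3.2x} + 0.5099\,e^{-0.9423x} + 0.2802\,e^{-0.4029x} + 0.02817\,e^{-0.2016x}$$

where $f_{\mathrm{cut}}$ is the cosine cutoff applied at cutoff_distance.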
8 changes: 6 additions & 2 deletions torchmdnet/scripts/train.py
@@ -63,6 +63,7 @@ def get_args():
    parser.add_argument('--model', type=str, default='graph-network', choices=models.__all__, help='Which model to train')
    parser.add_argument('--output-model', type=str, default='Scalar', choices=output_modules.__all__, help='The type of output model')
    parser.add_argument('--prior-model', type=str, default=None, choices=priors.__all__, help='Which prior model to use')
+    parser.add_argument('--prior-args', default=None, type=str, help='Additional arguments for the prior model. Need to be specified in JSON format i.e. \'{"cutoff_distance": 10.0, "max_num_neighbors": 100}\'')

    # architectural args
    parser.add_argument('--charge', type=bool, default=False, help='Model needs a total charge')
@@ -125,8 +126,11 @@ def main():
f"Available models are {', '.join(priors.__all__)}"
)
# initialize the prior model
prior = getattr(priors, args.prior_model)(dataset=data.dataset)
args.prior_args = prior.get_init_args()
prior_args = args.prior_args
if prior_args is None:
prior_args = {}
prior = getattr(priors, args.prior_model)(dataset=data.dataset, **prior_args)
args.prior_init_args = prior.get_init_args()

# initialize lightning module
model = LNNP(args, prior_model=prior, mean=data.mean, std=data.std)
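Note that a value supplied via --prior-args arrives as a raw JSON string, so it presumably has to be decoded before being unpacked into keyword arguments as above; a minimal sketch of that step (the helper is illustrative, not part of this diff):

import json

def parse_prior_args(raw):
    # --prior-args comes in as a JSON string from the CLI; values loaded
    # from a YAML config may already be a dict and pass through unchanged.
    if raw is None:
        return {}
    if isinstance(raw, dict):
        return raw
    return json.loads(raw)

prior = getattr(priors, args.prior_model)(dataset=data.dataset, **parse_prior_args(args.prior_args))

Invocation would then look like this, assuming the torchmd-train entry point: torchmd-train --prior-model ZBL --prior-args '{"cutoff_distance": 10.0, "max_num_neighbors": 100}'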
2 changes: 1 addition & 1 deletion torchmdnet/utils.py
@@ -176,7 +176,7 @@ def __call__(self, parser, namespace, values, option_string=None):
        with open(hparams_path, "r") as f:
            config = yaml.load(f, Loader=yaml.FullLoader)
        for key in config.keys():
-            if key not in namespace and key != "prior_args":
+            if key not in namespace and key != "prior_init_args":
                raise ValueError(f"Unknown argument in the model checkpoint: {key}")
        namespace.__dict__.update(config)
        namespace.__dict__.update(load_model=values)