Implement ZBL potential #134

Merged · 13 commits · Nov 9, 2022
56 changes: 56 additions & 0 deletions tests/priors.yaml
@@ -0,0 +1,56 @@
activation: silu
aggr: add
atom_filter: -1
attn_activation: silu
batch_size: 128
coord_files: null
cutoff_lower: 0.0
cutoff_upper: 5.0
derivative: false
distance_influence: both
early_stopping_patience: 150
ema_alpha_neg_dy: 1.0
ema_alpha_y: 1.0
embed_files: null
embedding_dimension: 256
energy_files: null
y_weight: 1.0
force_files: null
neg_dy_weight: 1.0
inference_batch_size: 128
load_model: null
lr: 0.0004
lr_factor: 0.8
lr_min: 1.0e-07
lr_patience: 15
lr_warmup_steps: 10000
max_num_neighbors: 64
max_z: 100
model: equivariant-transformer
neighbor_embedding: true
ngpus: -1
num_epochs: 3000
num_heads: 8
num_layers: 8
num_nodes: 1
num_rbf: 64
num_workers: 6
output_model: Scalar
precision: 32
prior_model:
- ZBL:
cutoff_distance: 4.0
max_num_neighbors: 50
- Atomref
rbf_type: expnorm
redirect: false
reduce_op: add
save_interval: 10
splits: null
standardize: false
test_interval: 10
test_size: null
train_size: 110000
trainable_rbf: false
val_size: 10000
weight_decay: 0.0
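
The prior_model section above exercises both forms the new parser accepts: a mapping with per-prior arguments (ZBL) and a bare name (Atomref). A minimal sketch of how this resolves to prior instances, mirroring tests/test_priors.py (DummyDataset supplies the atomic_number, distance_scale, and energy_scale attributes ZBL needs):

import yaml
from torchmdnet.models.model import create_prior_models
from utils import DummyDataset  # tests/utils.py

# Config fragment mirroring the prior_model block above.
args = yaml.safe_load("""
prior_model:
- ZBL:
    cutoff_distance: 4.0
    max_num_neighbors: 50
- Atomref
""")

dataset = DummyDataset(has_atomref=True)
priors = create_prior_models(args, dataset)
print([type(p).__name__ for p in priors])  # ['ZBL', 'Atomref']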
67 changes: 65 additions & 2 deletions tests/test_priors.py
@@ -3,10 +3,13 @@
import torch
import pytorch_lightning as pl
from torchmdnet import models
from torchmdnet.models.model import create_model
from torchmdnet.priors import Atomref
from torchmdnet.models.model import create_model, create_prior_models
from torchmdnet.module import LNNP
from torchmdnet.priors import Atomref, ZBL
from torch_scatter import scatter
from utils import load_example_args, create_example_batch, DummyDataset
from os.path import dirname, join
import tempfile


@mark.parametrize("model_name", models.__all__)
@@ -31,3 +34,63 @@ def test_atomref(model_name):
# check if the output of both models differs by the expected atomref contribution
expected_offset = scatter(dataset.get_atomref().squeeze()[z], batch).unsqueeze(1)
torch.testing.assert_allclose(x_atomref, x_no_atomref + expected_offset)

def test_zbl():
pos = torch.tensor([[1.0, 0.0, 0.0], [2.5, 0.0, 0.0], [1.0, 1.0, 0.0], [0.0, 0.0, -1.0]], dtype=torch.float32) # Atom positions in Bohr
types = torch.tensor([0, 1, 2, 1], dtype=torch.long) # Atom types
atomic_number = torch.tensor([1, 6, 8], dtype=torch.int8) # Mapping of atom types to atomic numbers
distance_scale = 5.29177210903e-11 # Convert Bohr to meters
energy_scale = 1000.0/6.02214076e23 # Convert kJ/mol to Joules

# Use the ZBL class to compute the energy.

zbl = ZBL(10.0, 5, atomic_number, distance_scale=distance_scale, energy_scale=energy_scale)
energy = zbl.post_reduce(torch.zeros((1,)), types, pos, torch.zeros_like(types))[0]

# Compare to the expected value.

def compute_interaction(pos1, pos2, z1, z2):
delta = pos1-pos2
r = torch.sqrt(torch.dot(delta, delta))
x = r / (0.8854/(z1**0.23 + z2**0.23))
phi = 0.1818*torch.exp(-3.2*x) + 0.5099*torch.exp(-0.9423*x) + 0.2802*torch.exp(-0.4029*x) + 0.02817*torch.exp(-0.2016*x)
cutoff = 0.5*(torch.cos(r*torch.pi/10.0) + 1.0)
return cutoff*phi*(138.935/5.29177210903e-2)*z1*z2/r

expected = 0
for i in range(len(pos)):
for j in range(i):
expected += compute_interaction(pos[i], pos[j], atomic_number[types[i]], atomic_number[types[j]])
torch.testing.assert_allclose(expected, energy)

def test_multiple_priors():
# Create a model from a config file.

dataset = DummyDataset(has_atomref=True)
config_file = join(dirname(__file__), 'priors.yaml')
args = load_example_args('equivariant-transformer', config_file=config_file)
prior_models = create_prior_models(args, dataset)
args['prior_args'] = [p.get_init_args() for p in prior_models]
model = LNNP(args, prior_model=prior_models)
priors = model.model.prior_model

# Make sure the priors were created correctly.

assert len(priors) == 2
assert isinstance(priors[0], ZBL)
assert isinstance(priors[1], Atomref)
assert priors[0].cutoff_distance == 4.0
assert priors[0].max_num_neighbors == 50

# Save and load a checkpoint, and make sure the priors are correct.

with tempfile.NamedTemporaryFile() as f:
torch.save(model, f)
f.seek(0)
model2 = torch.load(f)
priors2 = model2.model.prior_model
assert len(priors2) == 2
assert isinstance(priors2[0], ZBL)
assert isinstance(priors2[1], Atomref)
assert priors2[0].cutoff_distance == priors[0].cutoff_distance
assert priors2[0].max_num_neighbors == priors[0].max_num_neighbors
9 changes: 7 additions & 2 deletions tests/utils.py
@@ -4,8 +4,10 @@
from torch_geometric.data import Dataset, Data


def load_example_args(model_name, remove_prior=False, **kwargs):
with open(join(dirname(dirname(__file__)), "examples", "ET-QM9.yaml"), "r") as f:
def load_example_args(model_name, remove_prior=False, config_file=None, **kwargs):
if config_file is None:
config_file = join(dirname(dirname(__file__)), "examples", "ET-QM9.yaml")
with open(config_file, "r") as f:
args = yaml.load(f, Loader=yaml.FullLoader)
args["model"] = model_name
args["seed"] = 1234
@@ -69,6 +71,9 @@ def _get_atomref(self):
return self.atomref

DummyDataset.get_atomref = _get_atomref
self.atomic_number = torch.arange(max(atom_types)+1)
self.distance_scale = 1.0
self.energy_scale = 1.0

def get(self, idx):
features = dict(z=self.z[idx].clone(), pos=self.pos[idx].clone())
33 changes: 20 additions & 13 deletions torchmdnet/datasets/hdf.py
@@ -1,6 +1,7 @@
import torch
from torch_geometric.data import Dataset, Data
import h5py
import numpy as np


class HDF5(Dataset):
@@ -27,7 +28,12 @@ def __init__(self, filename, **kwargs):
files = [h5py.File(f, "r") for f in self.filename.split(";")]
for file in files:
for group_name in file:
self.num_molecules += len(file[group_name]["energy"])
if group_name == '_metadata':
group = file[group_name]
for name in group:
setattr(self, name, torch.tensor(np.array(group[name])))
else:
self.num_molecules += len(file[group_name]["energy"])
file.close()

def setup_index(self):
@@ -36,18 +42,19 @@ def setup_index(self):
self.index = []
for file in files:
for group_name in file:
group = file[group_name]
types = group["types"]
pos = group["pos"]
energy = group["energy"]
if "forces" in group:
self.has_forces = True
forces = group["forces"]
for i in range(len(energy)):
self.index.append((types, pos, energy, forces, i))
else:
for i in range(len(energy)):
self.index.append((types, pos, energy, i))
if group_name != '_metadata':
group = file[group_name]
types = group["types"]
pos = group["pos"]
energy = group["energy"]
if "forces" in group:
self.has_forces = True
forces = group["forces"]
for i in range(len(energy)):
self.index.append((types, pos, energy, forces, i))
else:
for i in range(len(energy)):
self.index.append((types, pos, energy, i))

assert self.num_molecules == len(self.index), (
"Mismatch between previously calculated "
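
With the _metadata handling above, an HDF5 file can carry the per-dataset attributes the ZBL prior needs alongside the molecule groups. A hedged sketch of writing such a file with h5py (the group and dataset names follow the loader code; shapes and unit choices are illustrative):

import h5py
import numpy as np

# One molecule group plus a `_metadata` group whose datasets become
# attributes of the HDF5 dataset object (the fields ZBL reads).
with h5py.File("example.h5", "w") as f:
    mol = f.create_group("molecule0")
    mol["types"] = np.zeros((10, 4), dtype=np.int64)      # 10 conformations, 4 atoms
    mol["pos"] = np.zeros((10, 4, 3), dtype=np.float32)   # coordinates
    mol["energy"] = np.zeros((10, 1), dtype=np.float32)   # per-conformation energies

    meta = f.create_group("_metadata")
    meta["atomic_number"] = np.array([1, 6, 8], dtype=np.int8)
    meta["distance_scale"] = 5.29177210903e-11            # coordinates in Bohr
    meta["energy_scale"] = 1000.0 / 6.02214076e23         # energies in kJ/mol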
59 changes: 45 additions & 14 deletions torchmdnet/models/model.py
@@ -65,16 +65,8 @@ def create_model(args, prior_model=None, mean=None, std=None):

# prior model
if args["prior_model"] and prior_model is None:
assert "prior_args" in args, (
f"Requested prior model {args['prior_model']} but the "
f'arguments are lacking the key "prior_args".'
)
assert hasattr(priors, args["prior_model"]), (
f'Unknown prior model {args["prior_model"]}. '
f'Available models are {", ".join(priors.__all__)}'
)
# instantiate prior model if it was not passed to create_model (i.e. when loading a model)
prior_model = getattr(priors, args["prior_model"])(**args["prior_args"])
prior_model = create_prior_models(args)

# create output network
output_prefix = "Equivariant" if is_equivariant else ""
@@ -113,6 +105,40 @@ def load_model(filepath, args=None, device="cpu", **kwargs):
return model.to(device)


def create_prior_models(args, dataset=None):
"""Parse the prior_model configuration option and create the prior models."""
prior_models = []
if args['prior_model']:
prior_model = args['prior_model']
prior_names = []
prior_args = []
if not isinstance(prior_model, list):
prior_model = [prior_model]
for prior in prior_model:
if isinstance(prior, dict):
for key, value in prior.items():
prior_names.append(key)
if value is None:
prior_args.append({})
else:
prior_args.append(value)
else:
prior_names.append(prior)
prior_args.append({})
if 'prior_args' in args:
prior_args = args['prior_args']
if not isinstance(prior_args, list):
prior_args = [prior_args]
for name, arg in zip(prior_names, prior_args):
assert hasattr(priors, name), (
f"Unknown prior model {name}. "
f"Available models are {', '.join(priors.__all__)}"
)
# initialize the prior model
prior_models.append(getattr(priors, name)(dataset=dataset, **arg))
return prior_models


class TorchMD_Net(nn.Module):
def __init__(
self,
@@ -127,15 +153,17 @@ def __init__(self,
self.representation_model = representation_model
self.output_model = output_model

self.prior_model = prior_model
if not output_model.allow_prior_model and prior_model is not None:
self.prior_model = None
prior_model = None
rank_zero_warn(
(
"Prior model was given but the output model does "
"not allow prior models. Dropping the prior model."
)
)
if isinstance(prior_model, priors.base.BasePrior):
prior_model = [prior_model]
self.prior_model = None if prior_model is None else torch.nn.ModuleList(prior_model)

self.derivative = derivative

@@ -150,7 +178,8 @@ def reset_parameters(self):
self.representation_model.reset_parameters()
self.output_model.reset_parameters()
if self.prior_model is not None:
self.prior_model.reset_parameters()
for prior in self.prior_model:
prior.reset_parameters()

def forward(
self,
@@ -179,7 +208,8 @@ def forward(

# apply atom-wise prior model
if self.prior_model is not None:
x = self.prior_model.pre_reduce(x, z, pos, batch)
for prior in self.prior_model:
x = prior.pre_reduce(x, z, pos, batch)

# aggregate atoms
x = self.output_model.reduce(x, batch)
@@ -193,7 +223,8 @@ def forward(

# apply molecular-wise prior model
if self.prior_model is not None:
y = self.prior_model.post_reduce(y, z, pos, batch)
for prior in self.prior_model:
y = prior.post_reduce(y, z, pos, batch)

# compute gradients with respect to coordinates
if self.derivative:
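
Since create_prior_models normalizes its input, the prior_model option can be a bare name, a mapping with arguments, or a list mixing both. A quick sketch of the equivalent shapes (argument values taken from tests/priors.yaml; each dict feeds create_prior_models(args, dataset) unchanged):

# Accepted shapes for args["prior_model"]; the first two are wrapped into
# single-element lists internally.
args = {"prior_model": "Atomref"}                            # bare name
args = {"prior_model": {"ZBL": {"cutoff_distance": 4.0,
                                "max_num_neighbors": 50}}}   # name with arguments
args = {"prior_model": [{"ZBL": {"cutoff_distance": 4.0,
                                 "max_num_neighbors": 50}},
                        "Atomref"]}                          # mixed list

If args also carries a prior_args entry (as it does when a trained model is reloaded), those stored arguments take precedence over the per-prior mappings.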
3 changes: 3 additions & 0 deletions torchmdnet/priors/__init__.py
@@ -1 +1,4 @@
from torchmdnet.priors.atomref import Atomref
from torchmdnet.priors.zbl import ZBL

__all__ = ['Atomref', 'ZBL']
54 changes: 54 additions & 0 deletions torchmdnet/priors/zbl.py
@@ -0,0 +1,54 @@
import torch
from torchmdnet.priors.base import BasePrior
from torchmdnet.models.utils import Distance, CosineCutoff

class ZBL(BasePrior):
"""This class implements the Ziegler-Biersack-Littmark (ZBL) potential for screened nuclear repulsion.
It is described in https://doi.org/10.1007/978-3-642-68779-2_5 (equations 9 and 10 on page 147). It
is an empirical potential that does a good job of describing the repulsion between atoms at very short
distances.

To use this prior, the Dataset must provide the following attributes.

atomic_number: 1D tensor of length max_z. atomic_number[z] is the atomic number of atoms with atom type z.
distance_scale: multiply by this factor to convert coordinates stored in the dataset to meters
energy_scale: multiply by this factor to convert energies stored in the dataset to Joules (*not* J/mol)
"""
def __init__(self, cutoff_distance, max_num_neighbors, atomic_number=None, distance_scale=None, energy_scale=None, dataset=None):
super(ZBL, self).__init__()
if atomic_number is None:
atomic_number = dataset.atomic_number
if distance_scale is None:
distance_scale = dataset.distance_scale
if energy_scale is None:
energy_scale = dataset.energy_scale
atomic_number = torch.as_tensor(atomic_number, dtype=torch.int8)
self.register_buffer("atomic_number", atomic_number)
self.distance = Distance(0, cutoff_distance, max_num_neighbors=max_num_neighbors)
self.cutoff = CosineCutoff(cutoff_upper=cutoff_distance)
self.cutoff_distance = cutoff_distance
self.max_num_neighbors = max_num_neighbors
self.distance_scale = distance_scale
self.energy_scale = energy_scale

def get_init_args(self):
return {'cutoff_distance': self.cutoff_distance,
'max_num_neighbors': self.max_num_neighbors,
'atomic_number': self.atomic_number,
'distance_scale': self.distance_scale,
'energy_scale': self.energy_scale}

def reset_parameters(self):
pass

def post_reduce(self, y, z, pos, batch):
edge_index, distance, _ = self.distance(pos, batch)
atomic_number = self.atomic_number[z[edge_index]]
# 5.29e-11 is the Bohr radius in meters. All other numbers are magic constants from the ZBL potential.
a = 0.8854*5.29177210903e-11/(atomic_number[0]**0.23 + atomic_number[1]**0.23)
d = distance*self.distance_scale/a
f = 0.1818*torch.exp(-3.2*d) + 0.5099*torch.exp(-0.9423*d) + 0.2802*torch.exp(-0.4029*d) + 0.02817*torch.exp(-0.2016*d)
f *= self.cutoff(distance)
# Compute the energy, converting to the dataset's units. Multiply by 0.5 because every atom pair
# appears twice.
return y + 0.5*(2.30707755e-28/self.energy_scale/self.distance_scale)*torch.sum(f*atomic_number[0]*atomic_number[1]/distance, dim=-1)
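
For reference, the pairwise energy computed by post_reduce corresponds to the standard ZBL form, reconstructed here from the constants in the code (a_0 is the Bohr radius; the prefactor e^2/(4 pi epsilon_0) is the hard-coded 2.30707755e-28 J·m):

E_{ij} = \frac{e^2}{4\pi\varepsilon_0} \, \frac{Z_i Z_j}{r_{ij}} \, \phi\!\left(\frac{r_{ij}}{a}\right),
\qquad a = \frac{0.8854\, a_0}{Z_i^{0.23} + Z_j^{0.23}}

\phi(x) = 0.1818\, e^{-3.2x} + 0.5099\, e^{-0.9423x} + 0.2802\, e^{-0.4029x} + 0.02817\, e^{-0.2016x}

In the code each pair term is additionally multiplied by the cosine cutoff and halved, since every pair appears twice in the neighbor list.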