diff --git a/tests/priors.yaml b/tests/priors.yaml new file mode 100644 index 000000000..4ec754cac --- /dev/null +++ b/tests/priors.yaml @@ -0,0 +1,56 @@ +activation: silu +aggr: add +atom_filter: -1 +attn_activation: silu +batch_size: 128 +coord_files: null +cutoff_lower: 0.0 +cutoff_upper: 5.0 +derivative: false +distance_influence: both +early_stopping_patience: 150 +ema_alpha_neg_dy: 1.0 +ema_alpha_y: 1.0 +embed_files: null +embedding_dimension: 256 +energy_files: null +y_weight: 1.0 +force_files: null +neg_dy_weight: 1.0 +inference_batch_size: 128 +load_model: null +lr: 0.0004 +lr_factor: 0.8 +lr_min: 1.0e-07 +lr_patience: 15 +lr_warmup_steps: 10000 +max_num_neighbors: 64 +max_z: 100 +model: equivariant-transformer +neighbor_embedding: true +ngpus: -1 +num_epochs: 3000 +num_heads: 8 +num_layers: 8 +num_nodes: 1 +num_rbf: 64 +num_workers: 6 +output_model: Scalar +precision: 32 +prior_model: + - ZBL: + cutoff_distance: 4.0 + max_num_neighbors: 50 + - Atomref +rbf_type: expnorm +redirect: false +reduce_op: add +save_interval: 10 +splits: null +standardize: false +test_interval: 10 +test_size: null +train_size: 110000 +trainable_rbf: false +val_size: 10000 +weight_decay: 0.0 diff --git a/tests/test_priors.py b/tests/test_priors.py index 31481233d..374718929 100644 --- a/tests/test_priors.py +++ b/tests/test_priors.py @@ -3,10 +3,13 @@ import torch import pytorch_lightning as pl from torchmdnet import models -from torchmdnet.models.model import create_model -from torchmdnet.priors import Atomref +from torchmdnet.models.model import create_model, create_prior_models +from torchmdnet.module import LNNP +from torchmdnet.priors import Atomref, ZBL from torch_scatter import scatter from utils import load_example_args, create_example_batch, DummyDataset +from os.path import dirname, join +import tempfile @mark.parametrize("model_name", models.__all__) @@ -31,3 +34,63 @@ def test_atomref(model_name): # check if the output of both models differs by the expected atomref 
def test_zbl():
    """Check the ZBL prior against a directly-computed reference energy."""
    # Atom positions in Bohr.
    coords = torch.tensor(
        [[1.0, 0.0, 0.0], [2.5, 0.0, 0.0], [1.0, 1.0, 0.0], [0.0, 0.0, -1.0]],
        dtype=torch.float32,
    )
    # Atom types, and the atomic number corresponding to each type.
    atom_types = torch.tensor([0, 1, 2, 1], dtype=torch.long)
    z_of_type = torch.tensor([1, 6, 8], dtype=torch.int8)
    bohr_to_m = 5.29177210903e-11        # convert Bohr to meters
    kjmol_to_j = 1000.0 / 6.02214076e23  # convert kJ/mol to Joules

    # Energy computed by the ZBL prior class.
    prior = ZBL(10.0, 5, z_of_type, distance_scale=bohr_to_m, energy_scale=kjmol_to_j)
    computed = prior.post_reduce(
        torch.zeros((1,)), atom_types, coords, torch.zeros_like(atom_types)
    )[0]

    # Reference implementation of one screened pairwise interaction.
    def pair_term(p1, p2, z1, z2):
        diff = p1 - p2
        r = torch.sqrt(torch.dot(diff, diff))
        x = r / (0.8854 / (z1**0.23 + z2**0.23))
        phi = (
            0.1818 * torch.exp(-3.2 * x)
            + 0.5099 * torch.exp(-0.9423 * x)
            + 0.2802 * torch.exp(-0.4029 * x)
            + 0.02817 * torch.exp(-0.2016 * x)
        )
        cutoff = 0.5 * (torch.cos(r * torch.pi / 10.0) + 1.0)
        return cutoff * phi * (138.935 / 5.29177210903e-2) * z1 * z2 / r

    # Sum over all unique atom pairs and compare with the prior's result.
    reference = 0
    for i in range(len(coords)):
        for j in range(i):
            reference += pair_term(
                coords[i], coords[j], z_of_type[atom_types[i]], z_of_type[atom_types[j]]
            )
    torch.testing.assert_allclose(reference, computed)
+ + assert len(priors) == 2 + assert isinstance(priors[0], ZBL) + assert isinstance(priors[1], Atomref) + assert priors[0].cutoff_distance == 4.0 + assert priors[0].max_num_neighbors == 50 + + # Save and load a checkpoint, and make sure the priors are correct. + + with tempfile.NamedTemporaryFile() as f: + torch.save(model, f) + f.seek(0) + model2 = torch.load(f) + priors2 = model2.model.prior_model + assert len(priors2) == 2 + assert isinstance(priors2[0], ZBL) + assert isinstance(priors2[1], Atomref) + assert priors2[0].cutoff_distance == priors[0].cutoff_distance + assert priors2[0].max_num_neighbors == priors[0].max_num_neighbors diff --git a/tests/utils.py b/tests/utils.py index d8cd8322b..b297b5217 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -4,8 +4,10 @@ from torch_geometric.data import Dataset, Data -def load_example_args(model_name, remove_prior=False, **kwargs): - with open(join(dirname(dirname(__file__)), "examples", "ET-QM9.yaml"), "r") as f: +def load_example_args(model_name, remove_prior=False, config_file=None, **kwargs): + if config_file is None: + config_file = join(dirname(dirname(__file__)), "examples", "ET-QM9.yaml") + with open(config_file, "r") as f: args = yaml.load(f, Loader=yaml.FullLoader) args["model"] = model_name args["seed"] = 1234 @@ -69,6 +71,9 @@ def _get_atomref(self): return self.atomref DummyDataset.get_atomref = _get_atomref + self.atomic_number = torch.arange(max(atom_types)+1) + self.distance_scale = 1.0 + self.energy_scale = 1.0 def get(self, idx): features = dict(z=self.z[idx].clone(), pos=self.pos[idx].clone()) diff --git a/torchmdnet/datasets/hdf.py b/torchmdnet/datasets/hdf.py index fcebd01fa..4bfb1edf8 100644 --- a/torchmdnet/datasets/hdf.py +++ b/torchmdnet/datasets/hdf.py @@ -1,6 +1,7 @@ import torch from torch_geometric.data import Dataset, Data import h5py +import numpy as np class HDF5(Dataset): @@ -27,7 +28,12 @@ def __init__(self, filename, **kwargs): files = [h5py.File(f, "r") for f in 
self.filename.split(";")] for file in files: for group_name in file: - self.num_molecules += len(file[group_name]["energy"]) + if group_name == '_metadata': + group = file[group_name] + for name in group: + setattr(self, name, torch.tensor(np.array(group[name]))) + else: + self.num_molecules += len(file[group_name]["energy"]) file.close() def setup_index(self): @@ -36,18 +42,19 @@ def setup_index(self): self.index = [] for file in files: for group_name in file: - group = file[group_name] - types = group["types"] - pos = group["pos"] - energy = group["energy"] - if "forces" in group: - self.has_forces = True - forces = group["forces"] - for i in range(len(energy)): - self.index.append((types, pos, energy, forces, i)) - else: - for i in range(len(energy)): - self.index.append((types, pos, energy, i)) + if group_name != '_metadata': + group = file[group_name] + types = group["types"] + pos = group["pos"] + energy = group["energy"] + if "forces" in group: + self.has_forces = True + forces = group["forces"] + for i in range(len(energy)): + self.index.append((types, pos, energy, forces, i)) + else: + for i in range(len(energy)): + self.index.append((types, pos, energy, i)) assert self.num_molecules == len(self.index), ( "Mismatch between previously calculated " diff --git a/torchmdnet/models/model.py b/torchmdnet/models/model.py index c548f1ad1..2fede80b0 100644 --- a/torchmdnet/models/model.py +++ b/torchmdnet/models/model.py @@ -65,16 +65,8 @@ def create_model(args, prior_model=None, mean=None, std=None): # prior model if args["prior_model"] and prior_model is None: - assert "prior_args" in args, ( - f"Requested prior model {args['prior_model']} but the " - f'arguments are lacking the key "prior_args".' - ) - assert hasattr(priors, args["prior_model"]), ( - f'Unknown prior model {args["prior_model"]}. ' - f'Available models are {", ".join(priors.__all__)}' - ) # instantiate prior model if it was not passed to create_model (i.e. 
def create_prior_models(args, dataset=None):
    """Parse the ``prior_model`` configuration option and create the prior models.

    Args:
        args (dict): configuration dictionary.  ``args['prior_model']`` may be a
            single prior name (str), a dict mapping a prior name to its keyword
            arguments, or a list mixing both forms.  If ``args['prior_args']`` is
            present (e.g. when restoring from a checkpoint), it overrides the
            arguments parsed from ``prior_model``.
        dataset: optional dataset forwarded to each prior's constructor so it can
            pull quantities (atomrefs, unit scales, ...) from the data.

    Returns:
        list: the instantiated prior models (empty if no prior was requested).

    Raises:
        AssertionError: if a requested prior name is not defined in ``priors``.
    """
    prior_models = []
    if args['prior_model']:
        prior_model = args['prior_model']
        prior_names = []
        prior_args = []
        # Normalize the single-prior shorthand to a one-element list.
        if not isinstance(prior_model, list):
            prior_model = [prior_model]
        for prior in prior_model:
            if isinstance(prior, dict):
                # Form: {name: {kwargs...}} or {name: None} (no arguments).
                for key, value in prior.items():
                    prior_names.append(key)
                    prior_args.append({} if value is None else value)
            else:
                # Form: bare string name with no arguments.
                prior_names.append(prior)
                prior_args.append({})
        if 'prior_args' in args:
            # Checkpointed arguments take precedence over the parsed ones.
            prior_args = args['prior_args']
            # Bug fix: the original called isinstance(prior_args) with a single
            # argument, which raises TypeError; the type to test against must be
            # supplied.
            if not isinstance(prior_args, list):
                prior_args = [prior_args]
        for name, arg in zip(prior_names, prior_args):
            assert hasattr(priors, name), (
                f"Unknown prior model {name}. "
                f"Available models are {', '.join(priors.__all__)}"
            )
            # initialize the prior model
            prior_models.append(getattr(priors, name)(dataset=dataset, **arg))
    return prior_models
class ZBL(BasePrior):
    """This class implements the Ziegler-Biersack-Littmark (ZBL) potential for screened nuclear repulsion.

    It is described in https://doi.org/10.1007/978-3-642-68779-2_5 (equations 9 and 10 on page 147).  It
    is an empirical potential that does a good job of describing the repulsion between atoms at very short
    distances.

    To use this prior, the Dataset must provide the following attributes.

    atomic_number: 1D tensor of length max_z. atomic_number[z] is the atomic number of atoms with atom type z.
    distance_scale: multiply by this factor to convert coordinates stored in the dataset to meters
    energy_scale: multiply by this factor to convert energies stored in the dataset to Joules (*not* J/mol)
    """
    def __init__(self, cutoff_distance, max_num_neighbors, atomic_number=None, distance_scale=None, energy_scale=None, dataset=None):
        # Zero-argument super() (modern Python 3 idiom; the original used
        # the redundant super(ZBL, self) form).
        super().__init__()
        # Any parameter not given explicitly falls back to the dataset's attributes.
        if atomic_number is None:
            atomic_number = dataset.atomic_number
        if distance_scale is None:
            distance_scale = dataset.distance_scale
        if energy_scale is None:
            energy_scale = dataset.energy_scale
        atomic_number = torch.as_tensor(atomic_number, dtype=torch.int8)
        # Registered as a buffer so it moves with the module and is checkpointed.
        self.register_buffer("atomic_number", atomic_number)
        self.distance = Distance(0, cutoff_distance, max_num_neighbors=max_num_neighbors)
        self.cutoff = CosineCutoff(cutoff_upper=cutoff_distance)
        self.cutoff_distance = cutoff_distance
        self.max_num_neighbors = max_num_neighbors
        self.distance_scale = distance_scale
        self.energy_scale = energy_scale

    def get_init_args(self):
        """Return the constructor arguments needed to recreate this prior (used when checkpointing)."""
        return {'cutoff_distance': self.cutoff_distance,
                'max_num_neighbors': self.max_num_neighbors,
                'atomic_number': self.atomic_number,
                'distance_scale': self.distance_scale,
                'energy_scale': self.energy_scale}

    def reset_parameters(self):
        # ZBL has no trainable parameters, so there is nothing to reset.
        pass

    def post_reduce(self, y, z, pos, batch):
        """Add the ZBL pairwise repulsion energy to the per-molecule energies ``y``.

        Args:
            y: reduced (per-molecule) energies the repulsion term is added to.
            z: atom types, used to index ``self.atomic_number``.
            pos: atom positions, in the dataset's distance units.
            batch: molecule index of each atom.
        """
        edge_index, distance, _ = self.distance(pos, batch)
        atomic_number = self.atomic_number[z[edge_index]]
        # 5.29e-11 is the Bohr radius in meters. All other numbers are magic constants from the ZBL potential.
        a = 0.8854*5.29177210903e-11/(atomic_number[0]**0.23 + atomic_number[1]**0.23)
        d = distance*self.distance_scale/a
        f = 0.1818*torch.exp(-3.2*d) + 0.5099*torch.exp(-0.9423*d) + 0.2802*torch.exp(-0.4029*d) + 0.02817*torch.exp(-0.2016*d)
        f *= self.cutoff(distance)
        # Compute the energy, converting to the dataset's units. Multiply by 0.5 because every atom pair
        # appears twice.
        return y + 0.5*(2.30707755e-28/self.energy_scale/self.distance_scale)*torch.sum(f*atomic_number[0]*atomic_number[1]/distance, dim=-1)