diff --git a/tests/test_model.py b/tests/test_model.py
index b792595b8..b268540db 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -98,6 +98,16 @@ def test_torchscript_dynamic_shapes(model_name, device):
         grad_outputs=grad_outputs,
     )[0]
 
+@mark.parametrize("model_name", models.__all_models__)
+@mark.parametrize("device", ["cpu", "cuda"])
+def test_torchscript_extra_embedding(model_name, device):
+    if device == "cuda" and not torch.cuda.is_available():
+        pytest.skip("CUDA not available")
+    args = load_example_args(model_name, remove_prior=True)
+    args["extra_embedding"] = "atomic"
+    model = create_model(args)
+    torch.jit.script(model).to(device=device)
+
 #Currently only tensornet is CUDA graph compatible
 @mark.parametrize("model_name", ["tensornet"])
 def test_cuda_graph_compatible(model_name):
@@ -227,3 +237,14 @@ def test_gradients(model_name):
     torch.autograd.gradcheck(
         model, (z, pos, batch), eps=1e-4, atol=1e-3, rtol=1e-2, nondet_tol=1e-3
     )
+
+
+@mark.parametrize("model_name", models.__all_models__)
+@mark.parametrize("use_batch", [True, False])
+def test_extra_embedding(model_name, use_batch):
+    z, pos, batch = create_example_batch()
+    args = load_example_args(model_name, prior_model=None)
+    args["extra_embedding"] = ["atomic", "global"]
+    model = create_model(args)
+    batch = batch if use_batch else None
+    model(z, pos, batch=batch, extra_args={'atomic': torch.rand(6), 'global': torch.rand(2)})
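For context, a minimal sketch of the API surface these tests exercise. The `load_example_args` import path and the `energy, forces` naming are assumptions for illustration; only `create_model` and the `extra_args` call signature come from the patch itself:

```python
# Hedged sketch of the new extra_embedding usage; assumes the load_example_args
# helper from tests/utils.py. Values are illustrative.
import torch
from torchmdnet.models.model import create_model
from utils import load_example_args  # test helper (assumed import path)

args = load_example_args("tensornet", remove_prior=True)
args["extra_embedding"] = ["atomic", "global"]  # names of extra dataset fields
model = create_model(args)

z = torch.tensor([1, 6, 8, 1, 6, 8])      # 6 atoms
pos = torch.rand(6, 3)
batch = torch.tensor([0, 0, 0, 1, 1, 1])  # 2 molecules
energy, forces = model(
    z,
    pos,
    batch=batch,
    extra_args={
        "atomic": torch.rand(6),  # per-atom field: same shape as z
        "global": torch.rand(2),  # per-molecule field: broadcast via batch
    },
)  # forces is None unless the model was created with derivative=True
```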
diff --git a/torchmdnet/models/model.py b/torchmdnet/models/model.py
index a2a80f901..5f0b9b88e 100644
--- a/torchmdnet/models/model.py
+++ b/torchmdnet/models/model.py
@@ -38,6 +38,12 @@ def create_model(args, prior_model=None, mean=None, std=None):
         args["static_shapes"] = False
     if "vector_cutoff" not in args:
         args["vector_cutoff"] = False
+    if "extra_embedding" not in args:
+        extra_embedding = None
+    elif isinstance(args["extra_embedding"], str):
+        extra_embedding = [args["extra_embedding"]]
+    else:
+        extra_embedding = args["extra_embedding"]
 
     shared_args = dict(
         hidden_channels=args["embedding_dimension"],
@@ -57,6 +63,7 @@ def create_model(args, prior_model=None, mean=None, std=None):
             else None
         ),
         dtype=dtype,
+        extra_embedding=extra_embedding,
     )
 
     # representation network
@@ -370,7 +377,7 @@ def forward(
                 If this is omitted, periodic boundary conditions are not applied.
             q (Tensor, optional): Atomic charges in the molecule. Shape: (N,).
             s (Tensor, optional): Atomic spins in the molecule. Shape: (N,).
-            extra_args (Dict[str, Tensor], optional): Extra arguments to pass to the prior model.
+            extra_args (Dict[str, Tensor], optional): Extra arguments to pass to the model.
 
         Returns:
             Tuple[Tensor, Optional[Tensor]]: The output of the model and the derivative of the output with respect to the positions if derivative is True, None otherwise.
@@ -380,9 +387,19 @@ def forward(
 
         if self.derivative:
             pos.requires_grad_(True)
+        if self.representation_model.extra_embedding is None:
+            extra_embedding_args = None
+        else:
+            assert extra_args is not None
+            extra_embedding_args = []
+            for arg in self.representation_model.extra_embedding:
+                t = extra_args[arg]
+                if t.shape != z.shape:
+                    t = t[batch]
+                extra_embedding_args.append(t)
         # run the potentially wrapped representation model
         x, v, z, pos, batch = self.representation_model(
-            z, pos, batch, box=box, q=q, s=s
+            z, pos, batch, box=box, q=q, s=s, extra_embedding_args=extra_embedding_args
        )
         # apply the output network
         x = self.output_model.pre_reduce(x, v, z, pos, batch)
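The dispatch above contains the only shape-dependent logic in the feature: a field whose tensor already matches `z.shape` is used per atom as-is, and anything else is assumed to be per-molecule and expanded through `batch`. A standalone illustration of that broadcasting rule:

```python
import torch

# Per-molecule values are expanded to per-atom values with batch indexing,
# exactly as in the t = t[batch] branch above. Values are illustrative.
z = torch.tensor([1, 6, 8, 1, 6, 8])      # 6 atoms
batch = torch.tensor([0, 0, 0, 1, 1, 1])  # 2 molecules

per_atom = torch.rand(6)   # matches z.shape -> passed through unchanged
per_mol = torch.rand(2)    # one value per molecule
expanded = per_mol[batch]  # shape (6,): each atom gets its molecule's value
assert expanded.shape == z.shape
```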
diff --git a/torchmdnet/models/tensornet.py b/torchmdnet/models/tensornet.py
index a2006c9e7..2ab68bdb4 100644
--- a/torchmdnet/models/tensornet.py
+++ b/torchmdnet/models/tensornet.py
@@ -3,7 +3,7 @@
 # (See accompanying file README.md file or copy at http://opensource.org/licenses/MIT)
 
 import torch
-from typing import Optional, Tuple
+from typing import Optional, List, Tuple
 from torch import Tensor, nn
 from torchmdnet.models.utils import (
     CosineCutoff,
@@ -120,6 +120,9 @@ class TensorNet(nn.Module):
             (default: :obj:`True`)
         check_errors (bool, optional): Whether to check for errors in the distance module.
             (default: :obj:`True`)
+        extra_embedding (tuple, optional): The names of extra fields to append to the
+            embedding vector for each atom.
+            (default: :obj:`None`)
     """
 
     def __init__(
@@ -139,6 +142,7 @@ def __init__(
         check_errors=True,
         dtype=torch.float32,
         box_vecs=None,
+        extra_embedding=None,
     ):
         super(TensorNet, self).__init__()
 
@@ -163,6 +167,7 @@ def __init__(
         self.activation = activation
         self.cutoff_lower = cutoff_lower
         self.cutoff_upper = cutoff_upper
+        self.extra_embedding = extra_embedding
         act_class = act_class_mapping[activation]
         self.distance_expansion = rbf_class_mapping[rbf_type](
             cutoff_lower, cutoff_upper, num_rbf, trainable_rbf
@@ -176,6 +181,7 @@ def __init__(
             trainable_rbf,
             max_z,
             dtype,
+            extra_embedding,
         )
 
         self.layers = nn.ModuleList()
@@ -228,6 +234,7 @@ def forward(
         box: Optional[Tensor] = None,
         q: Optional[Tensor] = None,
         s: Optional[Tensor] = None,
+        extra_embedding_args: Optional[List[Tensor]] = None,
     ) -> Tuple[Tensor, Optional[Tensor], Tensor, Tensor, Tensor]:
         # Obtain graph, with distances and relative position vectors
         edge_index, edge_weight, edge_vec = self.distance(pos, batch, box)
@@ -258,7 +265,7 @@ def forward(
         # Normalizing edge vectors by their length can result in NaNs, breaking Autograd.
         # I avoid dividing by zero by setting the weight of self edges and self loops to 1
         edge_vec = edge_vec / edge_weight.masked_fill(mask, 1).unsqueeze(1)
-        X = self.tensor_embedding(zp, edge_index, edge_weight, edge_vec, edge_attr)
+        X = self.tensor_embedding(zp, edge_index, edge_weight, edge_vec, edge_attr, extra_embedding_args)
         for layer in self.layers:
             X = layer(X, edge_index, edge_weight, edge_attr, q)
         I, A, S = decompose_tensor(X)
@@ -287,6 +294,7 @@ def __init__(
         trainable_rbf=False,
         max_z=128,
         dtype=torch.float32,
+        extra_embedding=None,
     ):
         super(TensorEmbedding, self).__init__()
 
@@ -297,6 +305,10 @@ def __init__(
         self.cutoff = CosineCutoff(cutoff_lower, cutoff_upper)
         self.max_z = max_z
         self.emb = nn.Embedding(max_z, hidden_channels, dtype=dtype)
+        if extra_embedding is not None:
+            self.reshape_embedding = nn.Linear(hidden_channels + len(extra_embedding), hidden_channels, dtype=dtype)
+        else:
+            self.reshape_embedding = None
         self.emb2 = nn.Linear(2 * hidden_channels, hidden_channels, dtype=dtype)
         self.act = activation()
         self.linears_tensor = nn.ModuleList()
@@ -319,6 +331,8 @@ def reset_parameters(self):
         self.distance_proj2.reset_parameters()
         self.distance_proj3.reset_parameters()
         self.emb.reset_parameters()
+        if self.reshape_embedding is not None:
+            self.reshape_embedding.reset_parameters()
         self.emb2.reset_parameters()
         for linear in self.linears_tensor:
             linear.reset_parameters()
@@ -326,8 +340,14 @@ def reset_parameters(self):
             linear.reset_parameters()
         self.init_norm.reset_parameters()
 
-    def _get_atomic_number_message(self, z: Tensor, edge_index: Tensor) -> Tensor:
+    def _get_atomic_number_message(self, z: Tensor, edge_index: Tensor, extra_embedding_args: Optional[List[Tensor]]) -> Tensor:
         Z = self.emb(z)
+        if self.reshape_embedding is not None and extra_embedding_args is not None:
+            tensors = [Z]
+            for t in extra_embedding_args:
+                tensors.append(t.unsqueeze(1))
+            Z = torch.cat(tensors, dim=1)
+            Z = self.reshape_embedding(Z)
         Zij = self.emb2(
             Z.index_select(0, edge_index.t().reshape(-1)).view(
                 -1, self.hidden_channels * 2
@@ -362,8 +382,9 @@ def forward(
         edge_weight: Tensor,
         edge_vec_norm: Tensor,
         edge_attr: Tensor,
+        extra_embedding_args: Optional[List[Tensor]],
     ) -> Tensor:
-        Zij = self._get_atomic_number_message(z, edge_index)
+        Zij = self._get_atomic_number_message(z, edge_index, extra_embedding_args)
         Iij, Aij, Sij = self._get_tensor_messages(
             Zij, edge_weight, edge_vec_norm, edge_attr
         )
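The embedding change follows the same pattern in all four representation models: each extra field contributes one scalar channel per atom, the channels are concatenated onto the learned atom embedding, and `reshape_embedding` projects the result back to `hidden_channels`. A minimal reproduction of that concatenate-and-project step, with illustrative names and sizes:

```python
import torch
from torch import nn

# Standalone sketch of the pattern used in TensorEmbedding above.
hidden_channels, num_extra, num_atoms = 8, 2, 6
emb = nn.Embedding(128, hidden_channels)
reshape_embedding = nn.Linear(hidden_channels + num_extra, hidden_channels)

z = torch.tensor([1, 6, 8, 1, 6, 8])
extras = [torch.rand(num_atoms), torch.rand(num_atoms)]  # already per-atom

x = emb(z)                                        # (6, hidden_channels)
tensors = [x] + [t.unsqueeze(1) for t in extras]  # each extra becomes (6, 1)
x = reshape_embedding(torch.cat(tensors, dim=1))  # back to (6, hidden_channels)
```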
diff --git a/torchmdnet/models/torchmd_et.py b/torchmdnet/models/torchmd_et.py
index 5ff168d54..2befc865b 100644
--- a/torchmdnet/models/torchmd_et.py
+++ b/torchmdnet/models/torchmd_et.py
@@ -2,7 +2,7 @@
 # Distributed under the MIT License.
 # (See accompanying file README.md file or copy at http://opensource.org/licenses/MIT)
 
-from typing import Optional, Tuple
+from typing import Optional, List, Tuple
 import torch
 from torch import Tensor, nn
 from torchmdnet.models.utils import (
@@ -79,7 +79,9 @@ class TorchMD_ET(nn.Module):
             (default: :obj:`False`)
         check_errors (bool, optional): Whether to check for errors in the distance module.
             (default: :obj:`True`)
-
+        extra_embedding (tuple, optional): The names of extra fields to append to the
+            embedding vector for each atom.
+            (default: :obj:`None`)
     """
 
     def __init__(
@@ -102,6 +104,7 @@ def __init__(
         box_vecs=None,
         vector_cutoff=False,
         dtype=torch.float32,
+        extra_embedding=None,
     ):
         super(TorchMD_ET, self).__init__()
 
@@ -133,10 +136,15 @@ def __init__(
         self.cutoff_upper = cutoff_upper
         self.max_z = max_z
         self.dtype = dtype
+        self.extra_embedding = extra_embedding
 
         act_class = act_class_mapping[activation]
 
         self.embedding = nn.Embedding(self.max_z, hidden_channels, dtype=dtype)
+        if extra_embedding is not None:
+            self.reshape_embedding = nn.Linear(hidden_channels + len(extra_embedding), hidden_channels, dtype=dtype)
+        else:
+            self.reshape_embedding = None
 
         self.distance = OptimizedDistance(
             cutoff_lower,
@@ -181,6 +189,8 @@ def __init__(
 
     def reset_parameters(self):
         self.embedding.reset_parameters()
+        if self.reshape_embedding is not None:
+            self.reshape_embedding.reset_parameters()
         self.distance_expansion.reset_parameters()
         if self.neighbor_embedding is not None:
             self.neighbor_embedding.reset_parameters()
@@ -196,8 +206,15 @@ def forward(
         box: Optional[Tensor] = None,
         q: Optional[Tensor] = None,
         s: Optional[Tensor] = None,
+        extra_embedding_args: Optional[List[Tensor]] = None,
     ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
         x = self.embedding(z)
+        if self.reshape_embedding is not None and extra_embedding_args is not None:
+            tensors = [x]
+            for t in extra_embedding_args:
+                tensors.append(t.unsqueeze(1))
+            x = torch.cat(tensors, dim=1)
+            x = self.reshape_embedding(x)
 
         edge_index, edge_weight, edge_vec = self.distance(pos, batch, box)
         # This assert must be here to convince TorchScript that edge_vec is not None
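TorchMD_ET applies the concatenate-and-project step directly to the input embedding. The explicit `Optional[List[Tensor]]` annotation with a `None` default is what keeps the scripted signature backward compatible, which is what `test_torchscript_extra_embedding` checks. A toy module with the same shape of signature (names are illustrative) scripts cleanly:

```python
import torch
from typing import List, Optional
from torch import Tensor, nn

# Sketch: TorchScript requires the extra argument to be typed explicitly,
# and the None default preserves existing call sites.
class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4 + 1, 4)  # hidden + one extra channel

    def forward(self, x: Tensor, extra: Optional[List[Tensor]] = None) -> Tensor:
        if extra is not None:
            tensors = [x]
            for t in extra:
                tensors.append(t.unsqueeze(1))
            x = self.linear(torch.cat(tensors, dim=1))
        return x

scripted = torch.jit.script(Toy())
out = scripted(torch.rand(6, 4), [torch.rand(6)])  # old calls omit the argument
```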
diff --git a/torchmdnet/models/torchmd_gn.py b/torchmdnet/models/torchmd_gn.py
index 31d68ae03..c836cd0d1 100644
--- a/torchmdnet/models/torchmd_gn.py
+++ b/torchmdnet/models/torchmd_gn.py
@@ -2,7 +2,7 @@
 # Distributed under the MIT License.
 # (See accompanying file README.md file or copy at http://opensource.org/licenses/MIT)
 
-from typing import Optional, Tuple
+from typing import Optional, List, Tuple
 import torch
 from torch import Tensor, nn
 from torchmdnet.models.utils import (
@@ -86,7 +86,9 @@ class TorchMD_GN(nn.Module):
             (default: :obj:`None`)
         check_errors (bool, optional): Whether to check for errors in the distance module.
             (default: :obj:`True`)
-
+        extra_embedding (tuple, optional): The names of extra fields to append to the
+            embedding vector for each atom.
+            (default: :obj:`None`)
     """
 
     def __init__(
@@ -107,6 +109,7 @@ def __init__(
         aggr="add",
         dtype=torch.float32,
         box_vecs=None,
+        extra_embedding=None,
     ):
         super(TorchMD_GN, self).__init__()
 
@@ -136,10 +139,15 @@ def __init__(
         self.cutoff_upper = cutoff_upper
         self.max_z = max_z
         self.aggr = aggr
+        self.extra_embedding = extra_embedding
 
         act_class = act_class_mapping[activation]
 
         self.embedding = nn.Embedding(self.max_z, hidden_channels, dtype=dtype)
+        if extra_embedding is not None:
+            self.reshape_embedding = nn.Linear(hidden_channels + len(extra_embedding), hidden_channels, dtype=dtype)
+        else:
+            self.reshape_embedding = None
 
         self.distance = OptimizedDistance(
             cutoff_lower,
@@ -184,6 +192,8 @@ def __init__(
 
     def reset_parameters(self):
         self.embedding.reset_parameters()
+        if self.reshape_embedding is not None:
+            self.reshape_embedding.reset_parameters()
         self.distance_expansion.reset_parameters()
         if self.neighbor_embedding is not None:
             self.neighbor_embedding.reset_parameters()
@@ -198,8 +208,15 @@ def forward(
         box: Optional[Tensor] = None,
         s: Optional[Tensor] = None,
         q: Optional[Tensor] = None,
+        extra_embedding_args: Optional[List[Tensor]] = None,
     ) -> Tuple[Tensor, Optional[Tensor], Tensor, Tensor, Tensor]:
         x = self.embedding(z)
+        if self.reshape_embedding is not None and extra_embedding_args is not None:
+            tensors = [x]
+            for t in extra_embedding_args:
+                tensors.append(t.unsqueeze(1))
+            x = torch.cat(tensors, dim=1)
+            x = self.reshape_embedding(x)
 
         edge_index, edge_weight, _ = self.distance(pos, batch, box)
         edge_attr = self.distance_expansion(edge_weight)
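The only new learnable state in each model is the projection layer, so the parameter overhead is modest; a rough sketch of the cost with the default 256 hidden channels and two extra fields (sizes are illustrative):

```python
from torch import nn

# One extra Linear of size (hidden_channels + num_extra) x hidden_channels
# per representation model.
hidden_channels, num_extra = 256, 2
layer = nn.Linear(hidden_channels + num_extra, hidden_channels)
n_params = sum(p.numel() for p in layer.parameters())
print(n_params)  # (256 + 2) * 256 + 256 = 66304
```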
diff --git a/torchmdnet/models/torchmd_t.py b/torchmdnet/models/torchmd_t.py
index c11efc080..92e2bdd60 100644
--- a/torchmdnet/models/torchmd_t.py
+++ b/torchmdnet/models/torchmd_t.py
@@ -2,7 +2,7 @@
 # Distributed under the MIT License.
 # (See accompanying file README.md file or copy at http://opensource.org/licenses/MIT)
 
-from typing import Optional, Tuple
+from typing import Optional, List, Tuple
 import torch
 from torch import Tensor, nn
 from torchmdnet.models.utils import (
@@ -76,7 +76,9 @@ class TorchMD_T(nn.Module):
             (default: :obj:`None`)
         check_errors (bool, optional): Whether to check for errors in the distance module.
             (default: :obj:`True`)
-
+        extra_embedding (tuple, optional): The names of extra fields to append to the
+            embedding vector for each atom.
+            (default: :obj:`None`)
     """
 
     def __init__(
@@ -98,6 +100,7 @@ def __init__(
         max_num_neighbors=32,
         dtype=torch.float,
         box_vecs=None,
+        extra_embedding=None,
     ):
         super(TorchMD_T, self).__init__()
 
@@ -124,11 +127,16 @@ def __init__(
         self.cutoff_lower = cutoff_lower
         self.cutoff_upper = cutoff_upper
         self.max_z = max_z
+        self.extra_embedding = extra_embedding
 
         act_class = act_class_mapping[activation]
         attn_act_class = act_class_mapping[attn_activation]
 
         self.embedding = nn.Embedding(self.max_z, hidden_channels, dtype=dtype)
+        if extra_embedding is not None:
+            self.reshape_embedding = nn.Linear(hidden_channels + len(extra_embedding), hidden_channels, dtype=dtype)
+        else:
+            self.reshape_embedding = None
 
         self.distance = OptimizedDistance(
             cutoff_lower,
@@ -177,6 +185,8 @@ def __init__(
 
     def reset_parameters(self):
         self.embedding.reset_parameters()
+        if self.reshape_embedding is not None:
+            self.reshape_embedding.reset_parameters()
         self.distance_expansion.reset_parameters()
         if self.neighbor_embedding is not None:
             self.neighbor_embedding.reset_parameters()
@@ -192,8 +202,15 @@ def forward(
         box: Optional[Tensor] = None,
         s: Optional[Tensor] = None,
         q: Optional[Tensor] = None,
+        extra_embedding_args: Optional[List[Tensor]] = None,
     ) -> Tuple[Tensor, Optional[Tensor], Tensor, Tensor, Tensor]:
         x = self.embedding(z)
+        if self.reshape_embedding is not None and extra_embedding_args is not None:
+            tensors = [x]
+            for t in extra_embedding_args:
+                tensors.append(t.unsqueeze(1))
+            x = torch.cat(tensors, dim=1)
+            x = self.reshape_embedding(x)
 
         edge_index, edge_weight, _ = self.distance(pos, batch, box)
         edge_attr = self.distance_expansion(edge_weight)
diff --git a/torchmdnet/optimize.py b/torchmdnet/optimize.py
index 0c7f56513..42291b2ba 100644
--- a/torchmdnet/optimize.py
+++ b/torchmdnet/optimize.py
@@ -2,7 +2,7 @@
 # Distributed under the MIT License.
 # (See accompanying file README.md file or copy at http://opensource.org/licenses/MIT)
 
-from typing import Optional, Tuple
+from typing import Optional, List, Tuple
 import torch as pt
 from NNPOps.CFConv import CFConv
 from NNPOps.CFConvNeighbors import CFConvNeighbors
@@ -33,6 +33,7 @@ def __init__(self, model):
         super().__init__()
 
         self.model = model
+        self.extra_embedding = model.extra_embedding
 
         self.neighbors = CFConvNeighbors(self.model.cutoff_upper)
 
@@ -58,12 +59,19 @@ def forward(
         box: Optional[pt.Tensor] = None,
         q: Optional[pt.Tensor] = None,
         s: Optional[pt.Tensor] = None,
+        extra_embedding_args: Optional[List[pt.Tensor]] = None,
     ) -> Tuple[pt.Tensor, Optional[pt.Tensor], pt.Tensor, pt.Tensor, pt.Tensor]:
 
         assert pt.all(batch == 0)
         assert box is None, "Box is not supported"
 
         x = self.model.embedding(z)
+        if self.model.reshape_embedding is not None and extra_embedding_args is not None:
+            tensors = [x]
+            for t in extra_embedding_args:
+                tensors.append(t.unsqueeze(1))
+            x = pt.cat(tensors, dim=1)
+            x = self.model.reshape_embedding(x)
 
         self.neighbors.build(pos)
         for inter, conv in zip(self.model.interactions, self.convs):
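The optimized wrapper replaces the representation model wholesale, and `TorchMD_Net.forward` reads `representation_model.extra_embedding` to decide whether to build `extra_embedding_args`, so the wrapper must mirror that attribute. A toy illustration of the contract, with illustrative class names:

```python
from torch import nn

# Sketch: any drop-in replacement for a representation model must expose the
# extra_embedding attribute that TorchMD_Net.forward inspects.
class Inner(nn.Module):
    def __init__(self):
        super().__init__()
        self.extra_embedding = ["atomic"]

class Wrapper(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model
        self.extra_embedding = model.extra_embedding  # mirror for callers

wrapped = Wrapper(Inner())
assert wrapped.extra_embedding == ["atomic"]
```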
diff --git a/torchmdnet/scripts/train.py b/torchmdnet/scripts/train.py
index a51cfe45f..26b334f97 100644
--- a/torchmdnet/scripts/train.py
+++ b/torchmdnet/scripts/train.py
@@ -80,6 +80,7 @@ def get_argparse():
     parser.add_argument('--charge', type=bool, default=False, help='Model needs a total charge. Set this to True if your dataset contains charges and you want them passed down to the model.')
     parser.add_argument('--spin', type=bool, default=False, help='Model needs a spin state. Set this to True if your dataset contains spin states and you want them passed down to the model.')
     parser.add_argument('--embedding-dimension', type=int, default=256, help='Embedding dimension')
+    parser.add_argument('--extra-embedding', type=str, default=None, help='Extra fields of the dataset to pass to the model and append to the embedding vector.', action="extend", nargs="*")
     parser.add_argument('--num-layers', type=int, default=6, help='Number of interaction layers in the model')
     parser.add_argument('--num-rbf', type=int, default=64, help='Number of radial basis functions in model')
     parser.add_argument('--activation', type=str, default='silu', choices=list(act_class_mapping.keys()), help='Activation function')
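The new flag combines `action="extend"` with `nargs="*"` (available since Python 3.8), so field names can be supplied in a single invocation or accumulated across repeated flags; a standalone check of that behavior:

```python
import argparse

# Sketch of how the --extra-embedding flag parses; mirrors the definition above.
parser = argparse.ArgumentParser()
parser.add_argument('--extra-embedding', type=str, default=None,
                    action="extend", nargs="*")

# Both spellings yield the same list; dest is extra_embedding.
print(parser.parse_args(
    ['--extra-embedding', 'atomic', 'global']).extra_embedding)
# ['atomic', 'global']
print(parser.parse_args(
    ['--extra-embedding', 'atomic', '--extra-embedding', 'global']).extra_embedding)
# ['atomic', 'global']
```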