From a11684710d6ba22b29acbf709bfbc05b93f3402d Mon Sep 17 00:00:00 2001
From: Raul
Date: Tue, 27 Jun 2023 08:47:48 +0100
Subject: [PATCH] Make TensorNet compatible with TorchScript (#186)

* Change some lines incompatible with jit script
* Remove some empty lines
* Fix typo
* Include an assert to appease TorchScript
* Change a range loop to an enumerate
* Add test for skewtensor function
* Small changes from merge
* Update test
* Update vector_to_skewtensor
* Remove some parenthesis
* Small changes
* Remove skewtensor test
* Annotate types in Atomref
* Simplify a couple of operations
* Check also derivative in torchscript test
* Type annotate forward LLNP
* Try double backward in the TorchScript test
* Change test name
* Annotate forward
* Remove unused variables
* Remove unnecessary enumerates
* Add TorchScript GPU tests
---
 tests/test_model.py            |  46 ++++++++++--
 torchmdnet/models/tensornet.py | 131 +++++++++++++++++++--------------
 torchmdnet/module.py           |  12 ++-
 torchmdnet/priors/atomref.py   |   5 +-
 4 files changed, 128 insertions(+), 66 deletions(-)

diff --git a/tests/test_model.py b/tests/test_model.py
index 489262c85..e2e5010ee 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -37,16 +37,47 @@ def test_forward_output_modules(model_name, output_model, dtype):


 @mark.parametrize("model_name", models.__all__)
-@mark.parametrize("dtype", [torch.float32, torch.float64])
-def test_forward_torchscript(model_name, dtype):
-    if model_name == "tensornet":
-        pytest.skip("TensorNet does not support torchscript.")
+@mark.parametrize("device", ["cpu", "cuda"])
+def test_torchscript(model_name, device):
+    if device == "cuda" and not torch.cuda.is_available():
+        pytest.skip("CUDA not available")
     z, pos, batch = create_example_batch()
+    z = z.to(device)
+    pos = pos.to(device)
+    batch = batch.to(device)
     model = torch.jit.script(
-        create_model(load_example_args(model_name, remove_prior=True, derivative=True, dtype=dtype))
-    )
-    model(z, pos, batch=batch)
+        create_model(load_example_args(model_name, remove_prior=True, derivative=True))
+    ).to(device=device)
+    y, neg_dy = model(z, pos, batch=batch)
+    grad_outputs = [torch.ones_like(neg_dy)]
+    ddy = torch.autograd.grad(
+        [neg_dy],
+        [pos],
+        grad_outputs=grad_outputs,
+    )[0]
+@mark.parametrize("model_name", models.__all__)
+@mark.parametrize("device", ["cpu", "cuda"])
+def test_torchscript_dynamic_shapes(model_name, device):
+    if device == "cuda" and not torch.cuda.is_available():
+        pytest.skip("CUDA not available")
+    z, pos, batch = create_example_batch()
+    model = torch.jit.script(
+        create_model(load_example_args(model_name, remove_prior=True, derivative=True))
+    ).to(device=device)
+    #Repeat the input to make it dynamic
+    for rep in range(0, 5):
+        print(rep)
+        zi = z.repeat_interleave(rep+1, dim=0).to(device=device)
+        posi = pos.repeat_interleave(rep+1, dim=0).to(device=device)
+        batchi = torch.randint(0, 10, (zi.shape[0],)).sort()[0].to(device=device)
+        y, neg_dy = model(zi, posi, batch=batchi)
+        grad_outputs = [torch.ones_like(neg_dy)]
+        ddy = torch.autograd.grad(
+            [neg_dy],
+            [posi],
+            grad_outputs=grad_outputs,
+        )[0]


 @mark.parametrize("model_name", models.__all__)
 def test_seed(model_name):
@@ -59,7 +90,6 @@ def test_seed(model_name):
     for p1, p2 in zip(m1.parameters(), m2.parameters()):
         assert (p1 == p2).all(), "Parameters don't match although using the same seed."

-
 @mark.parametrize("model_name", models.__all__)
 @mark.parametrize(
     "output_model",
diff --git a/torchmdnet/models/tensornet.py b/torchmdnet/models/tensornet.py
index 0b0f4c932..97974864f 100644
--- a/torchmdnet/models/tensornet.py
+++ b/torchmdnet/models/tensornet.py
@@ -13,11 +13,23 @@

 # Creates a skew-symmetric tensor from a vector
 def vector_to_skewtensor(vector):
-    tensor = torch.cross(
-        *torch.broadcast_tensors(
-            vector[..., None], torch.eye(3, 3, device=vector.device, dtype=vector.dtype)[None, None]
-        )
+    batch_size = vector.size(0)
+    zero = torch.zeros(batch_size, device=vector.device, dtype=vector.dtype)
+    tensor = torch.stack(
+        (
+            zero,
+            -vector[:, 2],
+            vector[:, 1],
+            vector[:, 2],
+            zero,
+            -vector[:, 0],
+            -vector[:, 1],
+            vector[:, 0],
+            zero,
+        ),
+        dim=1,
     )
+    tensor = tensor.view(-1, 3, 3)
     return tensor.squeeze(0)


@@ -43,9 +55,9 @@ def decompose_tensor(tensor):

 # Modifies tensor by multiplying invariant features to irreducible components
 def new_radial_tensor(I, A, S, f_I, f_A, f_S):
-    I = (f_I)[..., None, None] * I
-    A = (f_A)[..., None, None] * A
-    S = (f_S)[..., None, None] * S
+    I = f_I[..., None, None] * I
+    A = f_A[..., None, None] * A
+    S = f_S[..., None, None] * S
     return I, A, S


@@ -102,6 +114,7 @@ def __init__(
         dtype=torch.float32,
     ):
         super(TensorNet, self).__init__()
+
         assert rbf_type in rbf_class_mapping, (
             f'Unknown RBF type "{rbf_type}". '
             f'Choose from {", ".join(rbf_class_mapping.keys())}.'
@@ -110,6 +123,7 @@ def __init__(
             f'Unknown activation function "{activation}". '
             f'Choose from {", ".join(act_class_mapping.keys())}.'
         )
+
         assert equivariance_invariance_group in ["O(3)", "SO(3)"], (
             f'Unknown group "{equivariance_invariance_group}". '
             f"Choose O(3) or SO(3)."
@@ -139,6 +153,7 @@ def __init__(
             max_z,
             dtype,
         ).jittable()
+
         self.layers = nn.ModuleList()
         if num_layers != 0:
             for _ in range(num_layers):
@@ -160,23 +175,34 @@ def __init__(

     def reset_parameters(self):
         self.tensor_embedding.reset_parameters()
-        for i in range(self.num_layers):
-            self.layers[i].reset_parameters()
+        for layer in self.layers:
+            layer.reset_parameters()
         self.linear.reset_parameters()
         self.out_norm.reset_parameters()

     def forward(
-        self, z, pos, batch, q: Optional[Tensor] = None, s: Optional[Tensor] = None
-    ):
+        self,
+        z: Tensor,
+        pos: Tensor,
+        batch: Tensor,
+        q: Optional[Tensor] = None,
+        s: Optional[Tensor] = None,
+    ) -> Tuple[Tensor, Optional[Tensor], Tensor, Tensor, Tensor]:
+        # Obtain graph, with distances and relative position vectors
         edge_index, edge_weight, edge_vec = self.distance(pos, batch)
+        # This assert convinces TorchScript that edge_vec is a Tensor and not an Optional[Tensor]
+        assert (
+            edge_vec is not None
+        ), "Distance module did not return directional information"
+        # Expand distances with radial basis functions
         edge_attr = self.distance_expansion(edge_weight)
         # Embedding from edge-wise tensors to node-wise tensors
         X = self.tensor_embedding(z, edge_index, edge_weight, edge_vec, edge_attr)
         # Interaction layers
-        for i in range(self.num_layers):
-            X = self.layers[i](X, edge_index, edge_weight, edge_attr)
+        for layer in self.layers:
+            X = layer(X, edge_index, edge_weight, edge_attr)
         I, A, S = decompose_tensor(X)
         x = torch.cat((tensor_norm(I), tensor_norm(A), tensor_norm(S)), dim=-1)
         x = self.out_norm(x)
@@ -208,15 +234,10 @@ def __init__(
         self.emb2 = nn.Linear(2 * hidden_channels, hidden_channels, dtype=dtype)
         self.act = activation()
         self.linears_tensor = nn.ModuleList()
-        self.linears_tensor.append(
-            nn.Linear(hidden_channels, hidden_channels, bias=False, dtype=dtype)
-        )
-        self.linears_tensor.append(
-            nn.Linear(hidden_channels, hidden_channels, bias=False, dtype=dtype)
-        )
-        self.linears_tensor.append(
-            nn.Linear(hidden_channels, hidden_channels, bias=False, dtype=dtype)
-        )
+        for _ in range(3):
+            self.linears_tensor.append(
+                nn.Linear(hidden_channels, hidden_channels, bias=False)
+            )
         self.linears_scalar = nn.ModuleList()
         self.linears_scalar.append(
             nn.Linear(hidden_channels, 2 * hidden_channels, bias=True, dtype=dtype)
@@ -239,16 +260,26 @@ def reset_parameters(self):
             linear.reset_parameters()
         self.init_norm.reset_parameters()

-    def forward(self, z, edge_index, edge_weight, edge_vec, edge_attr):
+    def forward(
+        self,
+        z: Tensor,
+        edge_index: Tensor,
+        edge_weight: Tensor,
+        edge_vec: Tensor,
+        edge_attr: Tensor,
+    ):
+
         Z = self.emb(z)
         C = self.cutoff(edge_weight)
-        W1 = (self.distance_proj1(edge_attr)) * C.view(-1, 1)
-        W2 = (self.distance_proj2(edge_attr)) * C.view(-1, 1)
-        W3 = (self.distance_proj3(edge_attr)) * C.view(-1, 1)
+        W1 = self.distance_proj1(edge_attr) * C.view(-1, 1)
+        W2 = self.distance_proj2(edge_attr) * C.view(-1, 1)
+        W3 = self.distance_proj3(edge_attr) * C.view(-1, 1)
         mask = edge_index[0] != edge_index[1]
         edge_vec[mask] = edge_vec[mask] / torch.norm(edge_vec[mask], dim=1).unsqueeze(1)
         Iij, Aij, Sij = new_radial_tensor(
-            torch.eye(3, 3, device=edge_vec.device, dtype=edge_vec.dtype)[None, None, :, :],
+            torch.eye(3, 3, device=edge_vec.device, dtype=edge_vec.dtype)[
+                None, None, :, :
+            ],
             vector_to_skewtensor(edge_vec)[..., None, :, :],
             vector_to_symtensor(edge_vec)[..., None, :, :],
             W1,
@@ -262,11 +293,12 @@ def forward(self, z, edge_index, edge_weight, edge_vec, edge_attr):
         I = self.linears_tensor[0](I.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
         A = self.linears_tensor[1](A.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
         S = self.linears_tensor[2](S.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
-        for j in range(len(self.linears_scalar)):
-            norm = self.act(self.linears_scalar[j](norm))
+        for linear_scalar in self.linears_scalar:
+            norm = self.act(linear_scalar(norm))
         norm = norm.reshape(norm.shape[0], self.hidden_channels, 3)
         I, A, S = new_radial_tensor(I, A, S, norm[..., 0], norm[..., 1], norm[..., 2])
         X = I + A + S
+
         return X

     def message(self, Z_i, Z_j, I, A, S):
@@ -275,6 +307,7 @@ def message(self, Z_i, Z_j, I, A, S):
         I = Zij[..., None, None] * I
         A = Zij[..., None, None] * A
         S = Zij[..., None, None] * S
+
         return I, A, S

     def aggregate(
@@ -284,10 +317,12 @@ def aggregate(
         ptr: Optional[torch.Tensor],
         dim_size: Optional[int],
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
         I, A, S = features
         I = scatter(I, index, dim=self.node_dim, dim_size=dim_size)
         A = scatter(A, index, dim=self.node_dim, dim_size=dim_size)
         S = scatter(S, index, dim=self.node_dim, dim_size=dim_size)
+
         return I, A, S

     def update(
@@ -321,24 +356,10 @@ def __init__(
             nn.Linear(2 * hidden_channels, 3 * hidden_channels, bias=True, dtype=dtype)
         )
         self.linears_tensor = nn.ModuleList()
-        self.linears_tensor.append(
-            nn.Linear(hidden_channels, hidden_channels, bias=False, dtype=dtype)
-        )
-        self.linears_tensor.append(
-            nn.Linear(hidden_channels, hidden_channels, bias=False, dtype=dtype)
-        )
-        self.linears_tensor.append(
-            nn.Linear(hidden_channels, hidden_channels, bias=False, dtype=dtype)
-        )
-        self.linears_tensor.append(
-            nn.Linear(hidden_channels, hidden_channels, bias=False, dtype=dtype)
-        )
-        self.linears_tensor.append(
-            nn.Linear(hidden_channels, hidden_channels, bias=False, dtype=dtype)
-        )
-        self.linears_tensor.append(
-            nn.Linear(hidden_channels, hidden_channels, bias=False, dtype=dtype)
-        )
+        for _ in range(6):
+            self.linears_tensor.append(
+                nn.Linear(hidden_channels, hidden_channels, bias=False)
+            )
         self.act = activation()
         self.equivariance_invariance_group = equivariance_invariance_group
         self.reset_parameters()
@@ -350,9 +371,10 @@ def reset_parameters(self):
             linear.reset_parameters()

     def forward(self, X, edge_index, edge_weight, edge_attr):
+
         C = self.cutoff(edge_weight)
-        for i in range(len(self.linears_scalar)):
-            edge_attr = self.act(self.linears_scalar[i](edge_attr))
+        for linear_scalar in self.linears_scalar:
+            edge_attr = self.act(linear_scalar(edge_attr))
         edge_attr = (edge_attr * C.view(-1, 1)).reshape(
             edge_attr.shape[0], self.hidden_channels, 3
         )
@@ -374,19 +396,17 @@ def forward(self, X, edge_index, edge_weight, edge_attr):
         if self.equivariance_invariance_group == "SO(3)":
             B = torch.matmul(Y, msg)
             I, A, S = decompose_tensor(2 * B)
-            norm = tensor_norm(I + A + S)
-            I = I / (norm + 1)[..., None, None]
-            A = A / (norm + 1)[..., None, None]
-            S = S / (norm + 1)[..., None, None]
+            normp1 = (tensor_norm(I + A + S) + 1)[..., None, None]
+            I, A, S = I / normp1, A / normp1, S / normp1
         I = self.linears_tensor[3](I.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
         A = self.linears_tensor[4](A.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
         S = self.linears_tensor[5](S.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
         dX = I + A + S
-        dX = dX + torch.matmul(dX, dX)
-        X = X + dX
+        X = X + dX + dX**2
         return X

     def message(self, I_j, A_j, S_j, edge_attr):
+
         I, A, S = new_radial_tensor(
             I_j, A_j, S_j, edge_attr[..., 0], edge_attr[..., 1], edge_attr[..., 2]
         )
@@ -399,6 +419,7 @@ def aggregate(
         ptr: Optional[torch.Tensor],
         dim_size: Optional[int],
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
         I, A, S = features
         I = scatter(I, index, dim=self.node_dim, dim_size=dim_size)
         A = scatter(A, index, dim=self.node_dim, dim_size=dim_size)
diff --git a/torchmdnet/module.py b/torchmdnet/module.py
index f299b4b7c..a04142a61 100644
--- a/torchmdnet/module.py
+++ b/torchmdnet/module.py
@@ -2,6 +2,8 @@
 from torch.optim import AdamW
 from torch.optim.lr_scheduler import ReduceLROnPlateau
 from torch.nn.functional import mse_loss, l1_loss
+from torch import Tensor
+from typing import Optional, Dict, Tuple
 from pytorch_lightning import LightningModule
 from torchmdnet.models.model import create_model, load_model
@@ -55,7 +57,15 @@ def configure_optimizers(self):
         }
         return [optimizer], [lr_scheduler]

-    def forward(self, z, pos, batch=None, q=None, s=None, extra_args=None):
+    def forward(self,
+                z: Tensor,
+                pos: Tensor,
+                batch: Optional[Tensor] = None,
+                q: Optional[Tensor] = None,
+                s: Optional[Tensor] = None,
+                extra_args: Optional[Dict[str, Tensor]] = None
+                ) -> Tuple[Tensor, Optional[Tensor]]:
+
         return self.model(z, pos, batch=batch, q=q, s=s, extra_args=extra_args)

     def training_step(self, batch, batch_idx):
diff --git a/torchmdnet/priors/atomref.py b/torchmdnet/priors/atomref.py
index c0e069f9a..7ef4f64f1 100644
--- a/torchmdnet/priors/atomref.py
+++ b/torchmdnet/priors/atomref.py
@@ -1,6 +1,7 @@
 from torchmdnet.priors.base import BasePrior
+from typing import Optional, Dict
 import torch
-from torch import nn
+from torch import nn, Tensor
 from pytorch_lightning.utilities import rank_zero_warn

@@ -37,5 +38,5 @@ def reset_parameters(self):
     def get_init_args(self):
         return dict(max_z=self.initial_atomref.size(0))

-    def pre_reduce(self, x, z, pos, batch, extra_args):
+    def pre_reduce(self, x: Tensor, z: Tensor, pos: Tensor, batch: Tensor, extra_args: Optional[Dict[str, Tensor]]):
         return x + self.atomref(z)
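
Usage note: the tests above exercise the pattern this patch enables, namely compiling TensorNet with torch.jit.script and back-propagating through the returned forces. Below is a minimal sketch of that pattern, assuming the repo's create_model factory and the test helpers load_example_args / create_example_batch used in tests/test_model.py; the helper import path is an assumption.

    import torch
    from torchmdnet.models.model import create_model
    # Test helpers; assumed to be importable from the tests directory
    from utils import load_example_args, create_example_batch

    z, pos, batch = create_example_batch()
    # derivative=True makes the model return (energy, neg_dy), where neg_dy are the forces
    args = load_example_args("tensornet", remove_prior=True, derivative=True)
    model = torch.jit.script(create_model(args))
    y, neg_dy = model(z, pos, batch=batch)
    # Differentiate the forces with respect to positions, as test_torchscript does (double backward)
    ddy = torch.autograd.grad(
        [neg_dy], [pos], grad_outputs=[torch.ones_like(neg_dy)]
    )[0]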