diff --git a/torchmdnet/models/tensornet.py b/torchmdnet/models/tensornet.py index 97974864f..9484de20c 100644 --- a/torchmdnet/models/tensornet.py +++ b/torchmdnet/models/tensornet.py @@ -402,7 +402,7 @@ def forward(self, X, edge_index, edge_weight, edge_attr): A = self.linears_tensor[4](A.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) S = self.linears_tensor[5](S.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) dX = I + A + S - X = X + dX + dX**2 + X = X + dX + torch.matmul(dX,dX) return X def message(self, I_j, A_j, S_j, edge_attr):