diff --git a/tests/test_equivariance.py b/tests/test_equivariance.py
index 1492a9f07..7e6b178ea 100644
--- a/tests/test_equivariance.py
+++ b/tests/test_equivariance.py
@@ -2,6 +2,7 @@
 # Distributed under the MIT License.
 # (See accompanying file README.md file or copy at http://opensource.org/licenses/MIT)
 
+import pytest
 import torch
 from torchmdnet.models.model import create_model
 from utils import load_example_args
@@ -27,7 +28,8 @@ def test_scalar_invariance():
     torch.testing.assert_allclose(y, y_rot)
 
 
-def test_vector_equivariance():
+@pytest.mark.parametrize("model_name", ["equivariant-transformer", "tensornet"])
+def test_vector_equivariance(model_name):
     torch.manual_seed(1234)
     rotate = torch.tensor(
         [
@@ -36,14 +38,23 @@
             [-0.0626055, 0.3134752, 0.9475304],
         ]
     )
-
-    model = create_model(
-        load_example_args(
-            "equivariant-transformer",
-            prior_model=None,
-            output_model="VectorOutput",
+    if model_name == "equivariant-transformer":
+        model = create_model(
+            load_example_args(
+                model_name,
+                prior_model=None,
+                output_model="VectorOutput",
+            )
+        )
+    if model_name == "tensornet":
+        model = create_model(
+            load_example_args(
+                model_name,
+                prior_model=None,
+                vector_output=True,
+                output_model="VectorOutput",
+            )
         )
-    )
     z = torch.ones(100, dtype=torch.long)
     pos = torch.randn(100, 3)
     batch = torch.arange(50, dtype=torch.long).repeat_interleave(2)
diff --git a/torchmdnet/models/model.py b/torchmdnet/models/model.py
index a2a80f901..e4f8dd653 100644
--- a/torchmdnet/models/model.py
+++ b/torchmdnet/models/model.py
@@ -101,6 +101,7 @@ def create_model(args, prior_model=None, mean=None, std=None):
         representation_model = TensorNet(
             equivariance_invariance_group=args["equivariance_invariance_group"],
             static_shapes=args["static_shapes"],
+            vector_output=args["vector_output"],
             **shared_args,
         )
     else:
diff --git a/torchmdnet/models/tensornet.py b/torchmdnet/models/tensornet.py
index a2006c9e7..0d9ad1ce5 100644
--- a/torchmdnet/models/tensornet.py
+++ b/torchmdnet/models/tensornet.py
@@ -48,6 +48,10 @@ def vector_to_symtensor(vector):
     S = 0.5 * (tensor + tensor.transpose(-2, -1)) - I
     return S
 
+def skewtensor_to_vector(tensor):
+    """Converts a skew-symmetric tensor (n_atoms, hidden_channels, 3, 3) to a vector (n_atoms, hidden_channels, 3)."""
+    return torch.stack((tensor[:, :, 1, 2], tensor[:, :, 2, 0], tensor[:, :, 0, 1]), dim=-1)
+
 
 def decompose_tensor(tensor):
     """Full tensor decomposition into irreducible components."""
@@ -120,6 +124,7 @@ class TensorNet(nn.Module):
             (default: :obj:`True`)
         check_errors (bool, optional): Whether to check for errors in the distance module.
             (default: :obj:`True`)
+        vector_output (bool, optional): Whether to also return vector features per atom. (default: :obj:`False`)
     """
 
     def __init__(
@@ -139,6 +144,7 @@ def __init__(
         check_errors=True,
         dtype=torch.float32,
         box_vecs=None,
+        vector_output=False,
     ):
         super(TensorNet, self).__init__()
 
@@ -210,6 +216,7 @@
             box=box_vecs,
             long_edge_index=True,
         )
+        self.vector_output = vector_output
 
         self.reset_parameters()
 
@@ -265,10 +272,22 @@
         x = torch.cat((tensor_norm(I), tensor_norm(A), tensor_norm(S)), dim=-1)
         x = self.out_norm(x)
         x = self.act(self.linear((x)))
-        # # Remove the extra atom
+        # Remove the extra atom
         if self.static_shapes:
            x = x[:-1]
-        return x, None, z, pos, batch
+
+        # Compute the per-atom vector features if requested
+        v = None
+        if self.vector_output:
+            # (n_atoms, hidden_channels, 3, 3) -> (n_atoms, hidden_channels, 3)
+            v = skewtensor_to_vector(A)
+            # (n_atoms, hidden_channels, 3) -> (n_atoms, 3, hidden_channels)
+            v = v.transpose(1, 2)
+
+            if self.static_shapes:
+                v = v[:-1]
+
+        return x, v, z, pos, batch
 
 
 class TensorEmbedding(nn.Module):
diff --git a/torchmdnet/scripts/train.py b/torchmdnet/scripts/train.py
index a51cfe45f..be744bf57 100644
--- a/torchmdnet/scripts/train.py
+++ b/torchmdnet/scripts/train.py
@@ -103,6 +103,7 @@
    `a[1] = a[2] = b[2] = 0`;`a[0] >= 2*cutoff, b[1] >= 2*cutoff, c[2] >= 2*cutoff`;`a[0] >= 2*b[0]`;`a[0] >= 2*c[0]`;`b[1] >= 2*c[1]`;
    These requirements correspond to a particular rotation of the system and reduced form of the vectors, as well as the requirement that the cutoff be no larger than half the box width. Example: [[1,0,0],[0,1,0],[0,0,1]]""")
+    parser.add_argument('--vector-output', type=bool, default=False, help='If true, TensorNet returns vector features per atom in addition to the scalar features')
     parser.add_argument('--static_shapes', type=bool, default=False, help='If true, TensorNet will use statically shaped tensors for the network, making it capturable into a CUDA graphs. In some situations static shapes can lead to a speedup, but it increases memory usage.')
 
     # other args
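
Why reading vectors off A is safe for the equivariance test: the skew-symmetric component of a rank-2 tensor encodes an axial vector, and under a proper rotation R (det R = +1) that vector transforms as skewtensor_to_vector(R A R^T) == R @ skewtensor_to_vector(A). The snippet below is a standalone sanity-check sketch, not part of the patch; the helper is copied locally so it runs on its own.

# Standalone sanity check (not part of the patch): the axial vector of a
# skew-symmetric tensor rotates like an ordinary vector under proper rotations.
import torch


def skewtensor_to_vector(tensor):
    """Converts a skew-symmetric tensor (n_atoms, hidden_channels, 3, 3) to a vector (n_atoms, hidden_channels, 3)."""
    return torch.stack((tensor[:, :, 1, 2], tensor[:, :, 2, 0], tensor[:, :, 0, 1]), dim=-1)


torch.manual_seed(0)
# Random proper rotation: orthogonalize a random matrix and fix the determinant to +1
q, _ = torch.linalg.qr(torch.randn(3, 3))
if torch.det(q) < 0:
    q[:, 0] = -q[:, 0]

# Random skew-symmetric tensors shaped like TensorNet's A: (n_atoms, hidden_channels, 3, 3)
m = torch.randn(5, 8, 3, 3)
a = 0.5 * (m - m.transpose(-2, -1))

lhs = skewtensor_to_vector(q @ a @ q.T)   # vector of the rotated tensors
rhs = skewtensor_to_vector(a) @ q.T       # rotated vectors of the original tensors
print(torch.allclose(lhs, rhs, atol=1e-5))  # expected: True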
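
Usage sketch for the new flag, assuming this patch is applied: with vector_output=True, TensorNet's forward returns per-atom vector features v of shape (n_atoms, 3, hidden_channels) alongside the scalar features x of shape (n_atoms, hidden_channels). The constructor keywords and the (z, pos, batch) call below follow the existing tests; hidden_channels=32 is only an illustrative choice and may need adapting to your configuration.

# Hypothetical usage sketch (assumes the patch is applied; constructor defaults assumed)
import torch
from torchmdnet.models.tensornet import TensorNet

model = TensorNet(hidden_channels=32, vector_output=True)

z = torch.ones(10, dtype=torch.long)        # atomic numbers of a toy 10-atom system
pos = torch.randn(10, 3)                    # random positions
batch = torch.zeros(10, dtype=torch.long)   # all atoms belong to one molecule

x, v, *_ = model(z, pos, batch)
print(x.shape)  # expected: torch.Size([10, 32])    scalar features
print(v.shape)  # expected: torch.Size([10, 3, 32]) vector features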