diff --git a/benchmarks/neighbors.py b/benchmarks/neighbors.py
index a9578283a..d9ffc2298 100644
--- a/benchmarks/neighbors.py
+++ b/benchmarks/neighbors.py
@@ -1,7 +1,72 @@
 import os
 import torch
 import numpy as np
-from torchmdnet.models.utils import Distance, OptimizedDistance
+from torchmdnet.models.utils import OptimizedDistance
+from torch import nn
+from typing import Optional
+from torch_cluster import radius_graph
+
+class Distance(nn.Module):
+    def __init__(
+        self,
+        cutoff_lower,
+        cutoff_upper,
+        max_num_neighbors=32,
+        return_vecs=False,
+        loop=False,
+    ):
+        super(Distance, self).__init__()
+        self.cutoff_lower = cutoff_lower
+        self.cutoff_upper = cutoff_upper
+        self.max_num_neighbors = max_num_neighbors
+        self.return_vecs = return_vecs
+        self.loop = loop
+
+    def forward(self, pos, batch):
+        edge_index = radius_graph(
+            pos,
+            r=self.cutoff_upper,
+            batch=batch,
+            loop=self.loop,
+            max_num_neighbors=self.max_num_neighbors + 1,
+        )
+
+        # make sure we didn't miss any neighbors due to max_num_neighbors
+        assert not (
+            torch.unique(edge_index[0], return_counts=True)[1] > self.max_num_neighbors
+        ).any(), (
+            "The neighbor search missed some atoms due to max_num_neighbors being too low. "
+            "Please increase this parameter to include the maximum number of atoms within the cutoff."
+        )
+
+        edge_vec = pos[edge_index[0]] - pos[edge_index[1]]
+
+        mask: Optional[torch.Tensor] = None
+        if self.loop:
+            # mask out self loops when computing distances because
+            # the norm of 0 produces NaN gradients
+            # NOTE: might influence force predictions as self loop gradients are ignored
+            mask = edge_index[0] != edge_index[1]
+            edge_weight = torch.zeros(
+                edge_vec.size(0), device=edge_vec.device, dtype=edge_vec.dtype
+            )
+            edge_weight[mask] = torch.norm(edge_vec[mask], dim=-1)
+        else:
+            edge_weight = torch.norm(edge_vec, dim=-1)
+
+        lower_mask = edge_weight >= self.cutoff_lower
+        if self.loop and mask is not None:
+            # keep self loops even though they might be below the lower cutoff
+            lower_mask = lower_mask | ~mask
+        edge_index = edge_index[:, lower_mask]
+        edge_weight = edge_weight[lower_mask]
+
+        if self.return_vecs:
+            edge_vec = edge_vec[lower_mask]
+            return edge_index, edge_weight, edge_vec
+        # TODO: return only `edge_index` and `edge_weight` once
+        # Union typing works with TorchScript (https://github.com/pytorch/pytorch/pull/53180)
+        return edge_index, edge_weight, None
 
 
 def benchmark_neighbors(
diff --git a/environment.yml b/environment.yml
index ddada89d6..5814ed3ba 100644
--- a/environment.yml
+++ b/environment.yml
@@ -7,7 +7,6 @@ dependencies:
   - nnpops==0.5
   - pip
   - pytorch==2.0.*
-  - pytorch_cluster==1.6.1
   - pytorch_geometric==2.3.1
   - pytorch_scatter==2.1.1
   - pytorch_sparse==0.6.17
diff --git a/tests/test_cfconv.py b/tests/test_cfconv.py
index f2aeb57a6..94c7a2bbc 100644
--- a/tests/test_cfconv.py
+++ b/tests/test_cfconv.py
@@ -2,7 +2,7 @@
 from pytest import mark
 import torch as pt
 from torchmdnet.models.torchmd_gn import CFConv as RefCFConv
-from torchmdnet.models.utils import Distance, GaussianSmearing, ShiftedSoftplus
+from torchmdnet.models.utils import OptimizedDistance, GaussianSmearing, ShiftedSoftplus
 
 from NNPOps.CFConv import CFConv
 from NNPOps.CFConvNeighbors import CFConvNeighbors
@@ -27,7 +27,7 @@ def test_cfconv(device, num_atoms, num_filters, num_rbfs, cutoff_upper):
     input = pt.rand(num_atoms, num_filters, dtype=pt.float32, device=device)
 
     # Construct a non-optimized CFConv object
-    dist = Distance(0.0, cutoff_upper).to(device)
+    dist = OptimizedDistance(0.0, cutoff_upper).to(device)
     rbf = GaussianSmearing(0.0, cutoff_upper, num_rbfs, trainable=False).to(device)
     net = pt.nn.Sequential(
         pt.nn.Linear(num_rbfs, num_filters),
diff --git a/tests/test_model_utils.py b/tests/test_model_utils.py
index 7fc07b03e..2c36db0a9 100644
--- a/tests/test_model_utils.py
+++ b/tests/test_model_utils.py
@@ -2,7 +2,7 @@
 import torch
 from torch.autograd import grad
 from torchmdnet.models.model import create_model
-from torchmdnet.models.utils import Distance
+from torchmdnet.models.utils import OptimizedDistance
 
 from utils import load_example_args
 
@@ -11,10 +11,10 @@
 @mark.parametrize("return_vecs", [False, True])
 @mark.parametrize("loop", [False, True])
 def test_distance_calculation(cutoff_lower, cutoff_upper, return_vecs, loop):
-    dist = Distance(
+    dist = OptimizedDistance(
         cutoff_lower,
         cutoff_upper,
-        max_num_neighbors=100,
+        max_num_pairs=-100,
         return_vecs=return_vecs,
         loop=loop,
     )
@@ -65,7 +65,7 @@ def test_distance_calculation(cutoff_lower, cutoff_upper, return_vecs, loop):
 
 
 def test_neighbor_count_error():
-    dist = Distance(0, 5, max_num_neighbors=32)
+    dist = OptimizedDistance(0, 5, max_num_pairs=-32)
 
     # single molecule that should produce an error due to exceeding
     # the maximum number of neighbors
diff --git a/tests/test_neighbors.py b/tests/test_neighbors.py
index b62a19ea2..8bff3e2b2 100644
--- a/tests/test_neighbors.py
+++ b/tests/test_neighbors.py
@@ -3,8 +3,7 @@
 import torch
 import torch.jit
 import numpy as np
-from torchmdnet.models.utils import Distance, OptimizedDistance
-
+from torchmdnet.models.utils import OptimizedDistance
 
 def sort_neighbors(neighbors, deltas, distances):
     i_sorted = np.lexsort(neighbors)
@@ -293,116 +292,6 @@ def test_neighbor_autograds(
     torch.autograd.gradcheck(lambda_dist, (positions, batch), eps=1e-4, atol=1e-4, rtol=1e-4, nondet_tol=1e-4)
     torch.autograd.gradgradcheck(lambda_dist, (positions, batch), eps=1e-4, atol=1e-4, rtol=1e-4, nondet_tol=1e-4)
 
-@pytest.mark.parametrize(("device", "strategy"), [("cpu", "brute"), ("cuda", "brute"), ("cuda", "shared"), ("cuda", "cell")])
-@pytest.mark.parametrize("n_batches", [1, 2, 3, 4])
-@pytest.mark.parametrize("cutoff", [0.1, 1.0, 1000.0])
-@pytest.mark.parametrize("loop", [True, False])
-@pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
-@pytest.mark.parametrize("grad", ["deltas", "distances", "combined"])
-def test_compatible_with_distance(device, strategy, n_batches, cutoff, loop, dtype, grad):
-    if device == "cuda" and not torch.cuda.is_available():
-        pytest.skip("CUDA not available")
-    if device == "cpu" and strategy != "brute":
-        pytest.skip("Only brute force supported on CPU")
-
-    torch.manual_seed(4321)
-    n_atoms_per_batch = torch.randint(3, 100, size=(n_batches,))
-    batch = torch.repeat_interleave(
-        torch.arange(n_batches, dtype=torch.long), n_atoms_per_batch
-    ).to(device)
-    cumsum = np.cumsum(np.concatenate([[0], n_atoms_per_batch]))
-    lbox = 10.0
-    pos = torch.rand(cumsum[-1], 3, device=device, dtype=dtype) * lbox
-    # Ensure there is at least one pair
-    pos[0, :] = torch.zeros(3)
-    pos[1, :] = torch.zeros(3)
-
-    ref_pos_cpu = pos.clone().detach()
-    ref_neighbors, _, _ = compute_ref_neighbors(
-        ref_pos_cpu, batch, loop, True, cutoff, None
-    )
-    # Find the particle appearing in the most pairs
-    max_num_neighbors = int(torch.max(torch.bincount(torch.tensor(ref_neighbors[0, :]))))
-    d = Distance(
-        cutoff_lower=0.0,
-        cutoff_upper=cutoff,
-        loop=loop,
-        max_num_neighbors=max_num_neighbors,
-        return_vecs=True,
-    )
-    ref_pos = pos.clone().detach().to(device)
-    ref_pos.requires_grad = True
-    ref_neighbors, ref_distances, ref_distance_vecs = d(ref_pos, batch)
-
-    max_num_pairs = ref_neighbors.shape[1]
-    nl = OptimizedDistance(
-        cutoff_lower=0.0,
-        loop=loop,
-        cutoff_upper=cutoff,
-        max_num_pairs=-max_num_neighbors,
-        strategy=strategy,
-        return_vecs=True,
-        include_transpose=True,
-    )
-    pos.requires_grad = True
-    neighbors, distances, distance_vecs = nl(pos, batch)
-    # Compute gradients
-    if grad == "deltas":
-        ref_distance_vecs.sum().backward(retain_graph=True)
-        distance_vecs.sum().backward(retain_graph=True)
-    elif grad == "distances":
-        ref_distances.sum().backward(retain_graph=True)
-        distances.sum().backward(retain_graph=True)
-    elif grad == "combined":
-        (ref_distance_vecs.sum() + ref_distances.sum()).backward(retain_graph=True)
-        (distance_vecs.sum() + distances.sum()).backward(retain_graph=True)
-    else:
-        raise ValueError("grad")
-    # Save the gradients (first derivatives)
-    ref_first_deriv = ref_pos.grad.clone().requires_grad_(True)
-    first_deriv = pos.grad.clone().requires_grad_(True)
-
-    # Check first derivatives are correct
-    ref_pos_grad_sorted = ref_first_deriv.cpu().detach().numpy()
-    pos_grad_sorted = first_deriv.cpu().detach().numpy()
-    if dtype == torch.float32:
-        assert np.allclose(ref_pos_grad_sorted, pos_grad_sorted, atol=1e-2, rtol=1e-2)
-    else:
-        assert np.allclose(ref_pos_grad_sorted, pos_grad_sorted, atol=1e-8, rtol=1e-5)
-
-    # Zero out the gradients of ref_positions and positions
-    ref_pos.grad.zero_()
-    pos.grad.zero_()
-
-    # Compute second derivatives
-    ref_first_deriv.sum().backward()  # compute second derivatives
-    first_deriv.sum().backward()  # compute second derivatives
-
-    # Check second derivatives are correct
-    ref_pos_grad2_sorted = ref_pos.grad.cpu().detach().numpy()
-    pos_grad2_sorted = pos.grad.cpu().detach().numpy()
-    if dtype == torch.float32:
-        assert np.allclose(ref_pos_grad2_sorted, pos_grad2_sorted, atol=1e-2, rtol=1e-2)
-    else:
-        assert np.allclose(ref_pos_grad2_sorted, pos_grad2_sorted, atol=1e-8, rtol=1e-5)
-
-    ref_neighbors = ref_neighbors.cpu().detach().numpy()
-    ref_distance_vecs = ref_distance_vecs.cpu().detach().numpy()
-    ref_distances = ref_distances.cpu().detach().numpy()
-    ref_neighbors, ref_distance_vecs, ref_distances = sort_neighbors(
-        ref_neighbors, ref_distance_vecs, ref_distances
-    )
-
-    neighbors = neighbors.cpu().detach().numpy()
-    distance_vecs = distance_vecs.cpu().detach().numpy()
-    distances = distances.cpu().detach().numpy()
-    neighbors, distance_vecs, distances = sort_neighbors(
-        neighbors, distance_vecs, distances
-    )
-    assert np.allclose(neighbors, ref_neighbors)
-    assert np.allclose(distances, ref_distances)
-    assert np.allclose(distance_vecs, ref_distance_vecs)
-
 
 @pytest.mark.parametrize("strategy", ["brute", "cell", "shared"])
 @pytest.mark.parametrize("n_batches", [1, 2, 3, 4])
@@ -427,17 +316,9 @@ def test_large_size(strategy, n_batches):
     pos.requires_grad = True
     # Find the particle appearing in the most pairs
    max_num_neighbors = 64
-    d = Distance(
-        cutoff_lower=0.0,
-        cutoff_upper=cutoff,
-        loop=loop,
-        max_num_neighbors=max_num_neighbors,
-        return_vecs=True,
+    ref_neighbors, ref_distance_vecs, ref_distances = compute_ref_neighbors(
+        pos, batch, loop, True, cutoff, None
     )
-    ref_neighbors, ref_distances, ref_distance_vecs = d(pos, batch)
-    ref_neighbors = ref_neighbors.cpu().detach().numpy()
-    ref_distance_vecs = ref_distance_vecs.cpu().detach().numpy()
-    ref_distances = ref_distances.cpu().detach().numpy()
     ref_neighbors, ref_distance_vecs, ref_distances = sort_neighbors(
         ref_neighbors, ref_distance_vecs, ref_distances
     )
diff --git a/torchmdnet/extensions/neighbors/common.cuh b/torchmdnet/extensions/neighbors/common.cuh
index 3585f4719..db83b95d8 100644
--- a/torchmdnet/extensions/neighbors/common.cuh
+++ b/torchmdnet/extensions/neighbors/common.cuh
@@ -94,12 +94,14 @@ template <class scalar_t> struct PairListAccessor {
 template <class scalar_t>
 __device__ void writeAtomPair(PairListAccessor<scalar_t>& list, int i, int j,
                               scalar3<scalar_t> delta, scalar_t distance, int i_pair) {
+    if(i_pair < list.neighbors.size(1)){
     list.neighbors[0][i_pair] = i;
     list.neighbors[1][i_pair] = j;
     list.deltas[i_pair][0] = delta.x;
     list.deltas[i_pair][1] = delta.y;
     list.deltas[i_pair][2] = delta.z;
     list.distances[i_pair] = distance;
+    }
 }
 
 template <class scalar_t>
@@ -107,11 +109,9 @@ __device__ void addAtomPairToList(PairListAccessor<scalar_t>& list, int i, int j,
                                   scalar3<scalar_t> delta, scalar_t distance, bool add_transpose) {
     const int32_t i_pair = atomicAdd(&list.i_curr_pair[0], add_transpose ? 2 : 1);
     // Neighbors after the max number of pairs are ignored, although the pair is counted
-    if (i_pair + add_transpose < list.neighbors.size(1)) {
-        writeAtomPair(list, i, j, delta, distance, i_pair);
-        if (add_transpose) {
-            writeAtomPair(list, j, i, {-delta.x, -delta.y, -delta.z}, distance, i_pair + 1);
-        }
+    writeAtomPair(list, i, j, delta, distance, i_pair);
+    if (add_transpose) {
+        writeAtomPair(list, j, i, {-delta.x, -delta.y, -delta.z}, distance, i_pair + 1);
     }
 }
diff --git a/torchmdnet/models/torchmd_gn.py b/torchmdnet/models/torchmd_gn.py
index 8c78e766c..3b7f7ec9d 100644
--- a/torchmdnet/models/torchmd_gn.py
+++ b/torchmdnet/models/torchmd_gn.py
@@ -5,7 +5,7 @@
 from torchmdnet.models.utils import (
     NeighborEmbedding,
     CosineCutoff,
-    Distance,
+    OptimizedDistance,
     rbf_class_mapping,
     act_class_mapping,
 )
@@ -107,8 +107,8 @@ def __init__(
 
         self.embedding = nn.Embedding(self.max_z, hidden_channels, dtype=dtype)
 
-        self.distance = Distance(
-            cutoff_lower, cutoff_upper, max_num_neighbors=max_num_neighbors
+        self.distance = OptimizedDistance(
+            cutoff_lower, cutoff_upper, max_num_pairs=-max_num_neighbors
         )
         self.distance_expansion = rbf_class_mapping[rbf_type](
             cutoff_lower, cutoff_upper, num_rbf, trainable_rbf, dtype=dtype
diff --git a/torchmdnet/models/torchmd_t.py b/torchmdnet/models/torchmd_t.py
index dce869880..8c93e1317 100644
--- a/torchmdnet/models/torchmd_t.py
+++ b/torchmdnet/models/torchmd_t.py
@@ -5,11 +5,11 @@
 from torchmdnet.models.utils import (
     NeighborEmbedding,
     CosineCutoff,
-    Distance,
+    OptimizedDistance,
     rbf_class_mapping,
     act_class_mapping,
 )
-
+from torch_scatter import scatter
 
 class TorchMD_T(nn.Module):
     r"""The TorchMD Transformer architecture.
@@ -99,8 +99,8 @@ def __init__(
 
         self.embedding = nn.Embedding(self.max_z, hidden_channels, dtype=dtype)
 
-        self.distance = Distance(
-            cutoff_lower, cutoff_upper, max_num_neighbors=max_num_neighbors, loop=True
+        self.distance = OptimizedDistance(
+            cutoff_lower, cutoff_upper, max_num_pairs=-max_num_neighbors, loop=True
         )
         self.distance_expansion = rbf_class_mapping[rbf_type](
             cutoff_lower, cutoff_upper, num_rbf, trainable_rbf, dtype=dtype
diff --git a/torchmdnet/models/utils.py b/torchmdnet/models/utils.py
index 134aa3906..c94313103 100644
--- a/torchmdnet/models/utils.py
+++ b/torchmdnet/models/utils.py
@@ -5,7 +5,6 @@
 from torch import nn
 import torch.nn.functional as F
 from torch_geometric.nn import MessagePassing
-from torch_cluster import radius_graph
 from torchmdnet.extensions import get_neighbor_pairs_kernel
 import warnings
 
@@ -230,7 +229,7 @@ def forward(
         """
         self.box = self.box.to(pos.dtype)
-        max_pairs = self.max_num_pairs
+        max_pairs : int = self.max_num_pairs
         if self.max_num_pairs < 0:
             max_pairs = -self.max_num_pairs * pos.shape[0]
         if batch is None:
@@ -239,7 +238,7 @@
             strategy=self.strategy,
             positions=pos,
             batch=batch,
-            max_num_pairs=max_pairs,
+            max_num_pairs=int(max_pairs),
            cutoff_lower=self.cutoff_lower,
             cutoff_upper=self.cutoff_upper,
             loop=self.loop,
@@ -249,7 +248,7 @@
         )
         if self.check_errors:
             if num_pairs[0] > max_pairs:
-                raise RuntimeError(
+                raise AssertionError(
                     "Found num_pairs({}) > max_num_pairs({})".format(
                         num_pairs[0], max_pairs
                     )
@@ -388,70 +387,6 @@ def forward(self, distances: Tensor) -> Tensor:
             cutoffs = cutoffs * (distances < self.cutoff_upper)
             return cutoffs
 
-
-class Distance(nn.Module):
-    def __init__(
-        self,
-        cutoff_lower,
-        cutoff_upper,
-        max_num_neighbors=32,
-        return_vecs=False,
-        loop=False,
-    ):
-        super(Distance, self).__init__()
-        self.cutoff_lower = cutoff_lower
-        self.cutoff_upper = cutoff_upper
-        self.max_num_neighbors = max_num_neighbors
-        self.return_vecs = return_vecs
-        self.loop = loop
-
-    def forward(self, pos, batch):
-        edge_index = radius_graph(
-            pos,
-            r=self.cutoff_upper,
-            batch=batch,
-            loop=self.loop,
-            max_num_neighbors=self.max_num_neighbors + 1,
-        )
-
-        # make sure we didn't miss any neighbors due to max_num_neighbors
-        assert not (
-            torch.unique(edge_index[0], return_counts=True)[1] > self.max_num_neighbors
-        ).any(), (
-            "The neighbor search missed some atoms due to max_num_neighbors being too low. "
-            "Please increase this parameter to include the maximum number of atoms within the cutoff."
-        )
-
-        edge_vec = pos[edge_index[0]] - pos[edge_index[1]]
-
-        mask: Optional[torch.Tensor] = None
-        if self.loop:
-            # mask out self loops when computing distances because
-            # the norm of 0 produces NaN gradients
-            # NOTE: might influence force predictions as self loop gradients are ignored
-            mask = edge_index[0] != edge_index[1]
-            edge_weight = torch.zeros(
-                edge_vec.size(0), device=edge_vec.device, dtype=edge_vec.dtype
-            )
-            edge_weight[mask] = torch.norm(edge_vec[mask], dim=-1)
-        else:
-            edge_weight = torch.norm(edge_vec, dim=-1)
-
-        lower_mask = edge_weight >= self.cutoff_lower
-        if self.loop and mask is not None:
-            # keep self loops even though they might be below the lower cutoff
-            lower_mask = lower_mask | ~mask
-        edge_index = edge_index[:, lower_mask]
-        edge_weight = edge_weight[lower_mask]
-
-        if self.return_vecs:
-            edge_vec = edge_vec[lower_mask]
-            return edge_index, edge_weight, edge_vec
-        # TODO: return only `edge_index` and `edge_weight` once
-        # Union typing works with TorchScript (https://github.com/pytorch/pytorch/pull/53180)
-        return edge_index, edge_weight, None
-
-
 class GatedEquivariantBlock(nn.Module):
     """Gated Equivariant Block as defined in Schütt et al. (2021): Equivariant message passing for the prediction of tensorial properties and molecular spectra
diff --git a/torchmdnet/priors/coulomb.py b/torchmdnet/priors/coulomb.py
index 75e335fc2..45972f0e5 100644
--- a/torchmdnet/priors/coulomb.py
+++ b/torchmdnet/priors/coulomb.py
@@ -1,7 +1,7 @@
 import torch
 from torchmdnet.priors.base import BasePrior
-from torchmdnet.models.utils import Distance
 from torch_scatter import scatter
+from torchmdnet.models.utils import OptimizedDistance
 from typing import Optional, Dict
 
 class Coulomb(BasePrior):
@@ -20,7 +20,7 @@ def __init__(self, alpha, max_num_neighbors, distance_scale=None, energy_scale=N
             distance_scale = dataset.distance_scale
         if energy_scale is None:
             energy_scale = dataset.energy_scale
-        self.distance = Distance(0, torch.inf, max_num_neighbors=max_num_neighbors)
+        self.distance = OptimizedDistance(0, torch.inf, max_num_pairs=-max_num_neighbors)
         self.alpha = alpha
         self.max_num_neighbors = max_num_neighbors
         self.distance_scale = float(distance_scale)
diff --git a/torchmdnet/priors/d2.py b/torchmdnet/priors/d2.py
index f526c0d1e..953e5a54b 100644
--- a/torchmdnet/priors/d2.py
+++ b/torchmdnet/priors/d2.py
@@ -1,5 +1,5 @@
 from torchmdnet.priors.base import BasePrior
-from torchmdnet.models.utils import Distance
+from torchmdnet.models.utils import OptimizedDistance
 import torch as pt
 from torch_scatter import scatter
 
@@ -135,10 +135,10 @@ def __init__(
         )
 
         # Distance calculator
-        self.distances = Distance(
+        self.distances = OptimizedDistance(
             cutoff_lower=0,
             cutoff_upper=self.cutoff_distance,
-            max_num_neighbors=self.max_num_neighbors,
+            max_num_pairs=-self.max_num_neighbors,
         )
 
         # Parameters (default values from the reference)
diff --git a/torchmdnet/priors/zbl.py b/torchmdnet/priors/zbl.py
index 9dc73a654..2e896ea4e 100644
--- a/torchmdnet/priors/zbl.py
+++ b/torchmdnet/priors/zbl.py
@@ -1,7 +1,7 @@
 import torch
 from torchmdnet.priors.base import BasePrior
-from torchmdnet.models.utils import Distance, CosineCutoff
 from torch_scatter import scatter
+from torchmdnet.models.utils import OptimizedDistance, CosineCutoff
 from typing import Optional, Dict
 
 class ZBL(BasePrior):
@@ -26,7 +26,7 @@ def __init__(self, cutoff_distance, max_num_neighbors, atomic_number=None, dista
             energy_scale = dataset.energy_scale
         atomic_number = torch.as_tensor(atomic_number, dtype=torch.long)
         self.register_buffer("atomic_number", atomic_number)
-        self.distance = Distance(0, cutoff_distance, max_num_neighbors=max_num_neighbors)
+        self.distance = OptimizedDistance(0, cutoff_distance, max_num_pairs=-max_num_neighbors)
         self.cutoff = CosineCutoff(cutoff_upper=cutoff_distance)
         self.cutoff_distance = cutoff_distance
         self.max_num_neighbors = max_num_neighbors