Skip to content

Commit

Permalink
Remove dependency on torch_cluster (#228)
Browse files Browse the repository at this point in the history
* Remove dependency on torch_cluster

* Remove import

* Change RuntimeError to AssertionError to mimic old Distance behavior

* Fix import

* Hint max_pairs to be an int

* Fix potential invalid access when adding neighbor pairs

* Fix test

* Fix test again
  • Loading branch information
RaulPPelaez authored Oct 9, 2023
1 parent a54a8a6 commit 0ab52d5
Show file tree
Hide file tree
Showing 12 changed files with 97 additions and 217 deletions.
67 changes: 66 additions & 1 deletion benchmarks/neighbors.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,72 @@
import os
import torch
import numpy as np
from torchmdnet.models.utils import Distance, OptimizedDistance
from torchmdnet.models.utils import OptimizedDistance
from torch import nn
from typing import Optional
from torch_cluster import radius_graph

class Distance(nn.Module):
    """Neighbor search based on ``radius_graph``, kept here for benchmarking.

    For every pair returned by ``radius_graph`` within ``cutoff_upper`` the
    module computes the displacement vector and its norm, then discards
    pairs whose distance is below ``cutoff_lower``.

    Args:
        cutoff_lower: Pairs closer than this distance are dropped
            (self loops are always kept when ``loop`` is set).
        cutoff_upper: Search radius handed to ``radius_graph`` as ``r``.
        max_num_neighbors: Largest number of pairs any single index may
            appear in before the search is considered to have overflowed.
        return_vecs: When true, ``forward`` also returns the displacement
            vectors; otherwise the third return value is ``None``.
        loop: Forwarded to ``radius_graph``; when true, self loops are
            given a distance of exactly zero.
    """

    def __init__(
        self,
        cutoff_lower,
        cutoff_upper,
        max_num_neighbors=32,
        return_vecs=False,
        loop=False,
    ):
        super().__init__()
        self.cutoff_lower = cutoff_lower
        self.cutoff_upper = cutoff_upper
        self.max_num_neighbors = max_num_neighbors
        self.return_vecs = return_vecs
        self.loop = loop

    def forward(self, pos, batch):
        """Compute neighbor pairs, distances and (optionally) displacements.

        Args:
            pos: Positions, indexed by particle along the first dimension
                (presumably of shape ``(N, 3)`` -- confirm against callers).
            batch: Per-particle batch assignment forwarded to
                ``radius_graph``.

        Returns:
            A tuple ``(edge_index, edge_weight, edge_vec)`` where
            ``edge_index`` holds the pair indices in its two rows,
            ``edge_weight`` the pair distances, and ``edge_vec`` the
            displacements ``pos[edge_index[0]] - pos[edge_index[1]]`` or
            ``None`` when ``return_vecs`` is false.

        Raises:
            AssertionError: If any index occurs in more than
                ``max_num_neighbors`` pairs, i.e. neighbors may have been
                missed.
        """
        # Request one neighbor more than allowed so that an overflow shows up
        # as a count strictly greater than max_num_neighbors.
        edge_index = radius_graph(
            pos,
            r=self.cutoff_upper,
            batch=batch,
            loop=self.loop,
            max_num_neighbors=self.max_num_neighbors + 1,
        )

        # make sure we didn't miss any neighbors due to max_num_neighbors
        # NOTE: raised explicitly (instead of via an `assert` statement) so the
        # check is not stripped when Python runs with -O.
        neighbor_counts = torch.unique(edge_index[0], return_counts=True)[1]
        if (neighbor_counts > self.max_num_neighbors).any():
            raise AssertionError(
                "The neighbor search missed some atoms due to max_num_neighbors being too low. "
                "Please increase this parameter to include the maximum number of atoms within the cutoff."
            )

        edge_vec = pos[edge_index[0]] - pos[edge_index[1]]

        mask: Optional[torch.Tensor] = None
        if self.loop:
            # mask out self loops when computing distances because
            # the norm of 0 produces NaN gradients
            # NOTE: might influence force predictions as self loop gradients are ignored
            mask = edge_index[0] != edge_index[1]
            edge_weight = torch.zeros(
                edge_vec.size(0), device=edge_vec.device, dtype=edge_vec.dtype
            )
            edge_weight[mask] = torch.norm(edge_vec[mask], dim=-1)
        else:
            edge_weight = torch.norm(edge_vec, dim=-1)

        lower_mask = edge_weight >= self.cutoff_lower
        if self.loop and mask is not None:
            # keep self loops even though they might be below the lower cutoff
            lower_mask = lower_mask | ~mask
        edge_index = edge_index[:, lower_mask]
        edge_weight = edge_weight[lower_mask]

        if self.return_vecs:
            edge_vec = edge_vec[lower_mask]
            return edge_index, edge_weight, edge_vec
        # TODO: return only `edge_index` and `edge_weight` once
        # Union typing works with TorchScript (https://github.com/pytorch/pytorch/pull/53180)
        return edge_index, edge_weight, None


def benchmark_neighbors(
Expand Down
1 change: 0 additions & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ dependencies:
- nnpops==0.5
- pip
- pytorch==2.0.*
- pytorch_cluster==1.6.1
- pytorch_geometric==2.3.1
- pytorch_scatter==2.1.1
- pytorch_sparse==0.6.17
Expand Down
4 changes: 2 additions & 2 deletions tests/test_cfconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from pytest import mark
import torch as pt
from torchmdnet.models.torchmd_gn import CFConv as RefCFConv
from torchmdnet.models.utils import Distance, GaussianSmearing, ShiftedSoftplus
from torchmdnet.models.utils import OptimizedDistance, GaussianSmearing, ShiftedSoftplus

from NNPOps.CFConv import CFConv
from NNPOps.CFConvNeighbors import CFConvNeighbors
Expand All @@ -27,7 +27,7 @@ def test_cfconv(device, num_atoms, num_filters, num_rbfs, cutoff_upper):
input = pt.rand(num_atoms, num_filters, dtype=pt.float32, device=device)

# Construct a non-optimized CFConv object
dist = Distance(0.0, cutoff_upper).to(device)
dist = OptimizedDistance(0.0, cutoff_upper).to(device)
rbf = GaussianSmearing(0.0, cutoff_upper, num_rbfs, trainable=False).to(device)
net = pt.nn.Sequential(
pt.nn.Linear(num_rbfs, num_filters),
Expand Down
8 changes: 4 additions & 4 deletions tests/test_model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import torch
from torch.autograd import grad
from torchmdnet.models.model import create_model
from torchmdnet.models.utils import Distance
from torchmdnet.models.utils import OptimizedDistance
from utils import load_example_args


Expand All @@ -11,10 +11,10 @@
@mark.parametrize("return_vecs", [False, True])
@mark.parametrize("loop", [False, True])
def test_distance_calculation(cutoff_lower, cutoff_upper, return_vecs, loop):
dist = Distance(
dist = OptimizedDistance(
cutoff_lower,
cutoff_upper,
max_num_neighbors=100,
max_num_pairs=-100,
return_vecs=return_vecs,
loop=loop,
)
Expand Down Expand Up @@ -65,7 +65,7 @@ def test_distance_calculation(cutoff_lower, cutoff_upper, return_vecs, loop):


def test_neighbor_count_error():
dist = Distance(0, 5, max_num_neighbors=32)
dist = OptimizedDistance(0, 5, max_num_pairs=-32)

# single molecule that should produce an error due to exceeding
# the maximum number of neighbors
Expand Down
125 changes: 3 additions & 122 deletions tests/test_neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
import torch
import torch.jit
import numpy as np
from torchmdnet.models.utils import Distance, OptimizedDistance

from torchmdnet.models.utils import OptimizedDistance

def sort_neighbors(neighbors, deltas, distances):
i_sorted = np.lexsort(neighbors)
Expand Down Expand Up @@ -293,116 +292,6 @@ def test_neighbor_autograds(
torch.autograd.gradcheck(lambda_dist, (positions, batch), eps=1e-4, atol=1e-4, rtol=1e-4, nondet_tol=1e-4)
torch.autograd.gradgradcheck(lambda_dist, (positions, batch), eps=1e-4, atol=1e-4, rtol=1e-4, nondet_tol=1e-4)

@pytest.mark.parametrize(("device", "strategy"), [("cpu", "brute"), ("cuda", "brute"), ("cuda", "shared"), ("cuda", "cell")])
@pytest.mark.parametrize("n_batches", [1, 2, 3, 4])
@pytest.mark.parametrize("cutoff", [0.1, 1.0, 1000.0])
@pytest.mark.parametrize("loop", [True, False])
@pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
@pytest.mark.parametrize("grad", ["deltas", "distances", "combined"])
def test_compatible_with_distance(device, strategy, n_batches, cutoff, loop, dtype, grad):
    """Check that OptimizedDistance reproduces the results of Distance.

    Both modules are run on the same random positions and compared on:
    the first derivatives of the selected output(s) w.r.t. positions, the
    second derivatives, and finally the sorted neighbor pairs, distances
    and displacement vectors.

    ``grad`` selects which outputs are differentiated: the displacement
    vectors ("deltas"), the distances ("distances") or their sum
    ("combined").
    """
    if device == "cuda" and not torch.cuda.is_available():
        pytest.skip("CUDA not available")
    if device == "cpu" and strategy != "brute":
        pytest.skip("Only brute force supported on CPU")

    # Fixed seed so every parametrization sees the same random system
    torch.manual_seed(4321)
    n_atoms_per_batch = torch.randint(3, 100, size=(n_batches,))
    batch = torch.repeat_interleave(
        torch.arange(n_batches, dtype=torch.long), n_atoms_per_batch
    ).to(device)
    cumsum = np.cumsum(np.concatenate([[0], n_atoms_per_batch]))
    lbox = 10.0
    pos = torch.rand(cumsum[-1], 3, device=device, dtype=dtype) * lbox
    # Ensure there is at least one pair
    pos[0, :] = torch.zeros(3)
    pos[1, :] = torch.zeros(3)

    ref_pos_cpu = pos.clone().detach()
    ref_neighbors, _, _ = compute_ref_neighbors(
        ref_pos_cpu, batch, loop, True, cutoff, None
    )
    # Find the particle appearing in the most pairs
    max_num_neighbors = int(torch.max(torch.bincount(torch.tensor(ref_neighbors[0, :]))))
    d = Distance(
        cutoff_lower=0.0,
        cutoff_upper=cutoff,
        loop=loop,
        max_num_neighbors=max_num_neighbors,
        return_vecs=True,
    )
    ref_pos = pos.clone().detach().to(device)
    ref_pos.requires_grad = True
    ref_neighbors, ref_distances, ref_distance_vecs = d(ref_pos, batch)

    nl = OptimizedDistance(
        cutoff_lower=0.0,
        loop=loop,
        cutoff_upper=cutoff,
        max_num_pairs=-max_num_neighbors,
        strategy=strategy,
        return_vecs=True,
        include_transpose=True,
    )
    pos.requires_grad = True
    neighbors, distances, distance_vecs = nl(pos, batch)
    # Compute gradients
    if grad == "deltas":
        ref_distance_vecs.sum().backward(retain_graph=True)
        distance_vecs.sum().backward(retain_graph=True)
    elif grad == "distances":
        ref_distances.sum().backward(retain_graph=True)
        distances.sum().backward(retain_graph=True)
    elif grad == "combined":
        (ref_distance_vecs.sum() + ref_distances.sum()).backward(retain_graph=True)
        (distance_vecs.sum() + distances.sum()).backward(retain_graph=True)
    else:
        raise ValueError("grad")
    # Save the gradients (first derivatives)
    ref_first_deriv = ref_pos.grad.clone().requires_grad_(True)
    first_deriv = pos.grad.clone().requires_grad_(True)

    # Check first derivatives are correct
    ref_pos_grad_sorted = ref_first_deriv.cpu().detach().numpy()
    pos_grad_sorted = first_deriv.cpu().detach().numpy()
    if dtype == torch.float32:
        assert np.allclose(ref_pos_grad_sorted, pos_grad_sorted, atol=1e-2, rtol=1e-2)
    else:
        assert np.allclose(ref_pos_grad_sorted, pos_grad_sorted, atol=1e-8, rtol=1e-5)

    # Zero out the gradients of ref_positions and positions
    ref_pos.grad.zero_()
    pos.grad.zero_()

    # Compute second derivatives
    ref_first_deriv.sum().backward()  # compute second derivatives
    first_deriv.sum().backward()  # compute second derivatives

    # Check second derivatives are correct
    ref_pos_grad2_sorted = ref_pos.grad.cpu().detach().numpy()
    pos_grad2_sorted = pos.grad.cpu().detach().numpy()
    if dtype == torch.float32:
        assert np.allclose(ref_pos_grad2_sorted, pos_grad2_sorted, atol=1e-2, rtol=1e-2)
    else:
        assert np.allclose(ref_pos_grad2_sorted, pos_grad2_sorted, atol=1e-8, rtol=1e-5)

    # Bring both results into the same pair ordering before comparing them
    ref_neighbors = ref_neighbors.cpu().detach().numpy()
    ref_distance_vecs = ref_distance_vecs.cpu().detach().numpy()
    ref_distances = ref_distances.cpu().detach().numpy()
    ref_neighbors, ref_distance_vecs, ref_distances = sort_neighbors(
        ref_neighbors, ref_distance_vecs, ref_distances
    )

    neighbors = neighbors.cpu().detach().numpy()
    distance_vecs = distance_vecs.cpu().detach().numpy()
    distances = distances.cpu().detach().numpy()
    neighbors, distance_vecs, distances = sort_neighbors(
        neighbors, distance_vecs, distances
    )
    assert np.allclose(neighbors, ref_neighbors)
    assert np.allclose(distances, ref_distances)
    assert np.allclose(distance_vecs, ref_distance_vecs)


@pytest.mark.parametrize("strategy", ["brute", "cell", "shared"])
@pytest.mark.parametrize("n_batches", [1, 2, 3, 4])
Expand All @@ -427,17 +316,9 @@ def test_large_size(strategy, n_batches):
pos.requires_grad = True
# Find the particle appearing in the most pairs
max_num_neighbors = 64
d = Distance(
cutoff_lower=0.0,
cutoff_upper=cutoff,
loop=loop,
max_num_neighbors=max_num_neighbors,
return_vecs=True,
ref_neighbors, ref_distance_vecs, ref_distances = compute_ref_neighbors(
pos, batch, loop, True, cutoff, None
)
ref_neighbors, ref_distances, ref_distance_vecs = d(pos, batch)
ref_neighbors = ref_neighbors.cpu().detach().numpy()
ref_distance_vecs = ref_distance_vecs.cpu().detach().numpy()
ref_distances = ref_distances.cpu().detach().numpy()
ref_neighbors, ref_distance_vecs, ref_distances = sort_neighbors(
ref_neighbors, ref_distance_vecs, ref_distances
)
Expand Down
10 changes: 5 additions & 5 deletions torchmdnet/extensions/neighbors/common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -94,24 +94,24 @@ template <class scalar_t> struct PairListAccessor {
/*
 * Store the pair (i, j), its displacement and its distance in slot i_pair of
 * the pair list. Slots at or beyond the second dimension of list.neighbors
 * are ignored, so nothing is written for them.
 */
template <typename scalar_t>
__device__ void writeAtomPair(PairListAccessor<scalar_t>& list, int i, int j,
                              scalar3<scalar_t> delta, scalar_t distance, int i_pair) {
    // Guard: drop writes that would fall outside the allocated pair slots
    if (!(i_pair < list.neighbors.size(1))) {
        return;
    }
    list.neighbors[0][i_pair] = i;
    list.neighbors[1][i_pair] = j;
    list.deltas[i_pair][0] = delta.x;
    list.deltas[i_pair][1] = delta.y;
    list.deltas[i_pair][2] = delta.z;
    list.distances[i_pair] = distance;
}

template <typename scalar_t>
__device__ void addAtomPairToList(PairListAccessor<scalar_t>& list, int i, int j,
scalar3<scalar_t> delta, scalar_t distance, bool add_transpose) {
const int32_t i_pair = atomicAdd(&list.i_curr_pair[0], add_transpose ? 2 : 1);
// Neighbors after the max number of pairs are ignored, although the pair is counted
if (i_pair + add_transpose < list.neighbors.size(1)) {
writeAtomPair(list, i, j, delta, distance, i_pair);
if (add_transpose) {
writeAtomPair(list, j, i, {-delta.x, -delta.y, -delta.z}, distance, i_pair + 1);
}
writeAtomPair(list, i, j, delta, distance, i_pair);
if (add_transpose) {
writeAtomPair(list, j, i, {-delta.x, -delta.y, -delta.z}, distance, i_pair + 1);
}
}

Expand Down
6 changes: 3 additions & 3 deletions torchmdnet/models/torchmd_gn.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from torchmdnet.models.utils import (
NeighborEmbedding,
CosineCutoff,
Distance,
OptimizedDistance,
rbf_class_mapping,
act_class_mapping,
)
Expand Down Expand Up @@ -107,8 +107,8 @@ def __init__(

self.embedding = nn.Embedding(self.max_z, hidden_channels, dtype=dtype)

self.distance = Distance(
cutoff_lower, cutoff_upper, max_num_neighbors=max_num_neighbors
self.distance = OptimizedDistance(
cutoff_lower, cutoff_upper, max_num_pairs=-max_num_neighbors
)
self.distance_expansion = rbf_class_mapping[rbf_type](
cutoff_lower, cutoff_upper, num_rbf, trainable_rbf, dtype=dtype
Expand Down
8 changes: 4 additions & 4 deletions torchmdnet/models/torchmd_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
from torchmdnet.models.utils import (
NeighborEmbedding,
CosineCutoff,
Distance,
OptimizedDistance,
rbf_class_mapping,
act_class_mapping,
)

from torch_scatter import scatter

class TorchMD_T(nn.Module):
r"""The TorchMD Transformer architecture.
Expand Down Expand Up @@ -99,8 +99,8 @@ def __init__(

self.embedding = nn.Embedding(self.max_z, hidden_channels, dtype=dtype)

self.distance = Distance(
cutoff_lower, cutoff_upper, max_num_neighbors=max_num_neighbors, loop=True
self.distance = OptimizedDistance(
cutoff_lower, cutoff_upper, max_num_pairs=-max_num_neighbors, loop=True
)
self.distance_expansion = rbf_class_mapping[rbf_type](
cutoff_lower, cutoff_upper, num_rbf, trainable_rbf, dtype=dtype
Expand Down
Loading

0 comments on commit 0ab52d5

Please sign in to comment.