From 76a48d1cf2f82452727c2dc8c181a9f97caf6a1c Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Sat, 29 Apr 2023 18:46:40 +0200 Subject: [PATCH 01/76] Initial implementation, brute force --- tests/test_neighbors.py | 123 +++++++++++++++ torchmdnet/neighbors/__init__.py | 10 ++ torchmdnet/neighbors/neighbors.cpp | 5 + torchmdnet/neighbors/neighbors_cpu.cpp | 95 ++++++++++++ torchmdnet/neighbors/neighbors_cuda.cu | 206 +++++++++++++++++++++++++ 5 files changed, 439 insertions(+) create mode 100644 tests/test_neighbors.py create mode 100644 torchmdnet/neighbors/__init__.py create mode 100644 torchmdnet/neighbors/neighbors.cpp create mode 100644 torchmdnet/neighbors/neighbors_cpu.cpp create mode 100644 torchmdnet/neighbors/neighbors_cuda.cu diff --git a/tests/test_neighbors.py b/tests/test_neighbors.py new file mode 100644 index 000000000..2696b5099 --- /dev/null +++ b/tests/test_neighbors.py @@ -0,0 +1,123 @@ +import os +import pytest +import torch +import numpy as np +from torch.cuda import check_error + +from torchmdnet.neighbors import get_neighbor_pairs + + +class DistanceCellList(torch.nn.Module): + def __init__( + self, + cutoff_upper, + max_num_pairs=32, + loop=False, + ): + super(DistanceCellList, self).__init__() + """ Compute the neighbor list for a given cutoff. + Parameters + ---------- + cutoff_upper : float + Upper cutoff for the neighbor list. + max_num_pairs : int + Maximum number of pairs to store. + loop : bool + Whether to include self interactions (pair (i,i)). + """ + self.cutoff_upper = cutoff_upper + self.max_num_pairs = max_num_pairs + self.loop = loop + + def forward(self, pos, batch): + """ + Parameters + ---------- + pos : torch.Tensor + shape (N, 3) + batch : torch.Tensor + shape (N,) + Returns + ------- + neighbors : torch.Tensor + List of neighbors for each atom in the batch. + shape (2, max_num_pairs) + distances : torch.Tensor + List of distances for each atom in the batch. 
+ shape (max_num_pairs,) + distance_vecs : torch.Tensor + List of distance vectors for each atom in the batch. + shape (max_num_pairs, 3) + + """ + neighbors, distance_vecs, distances = get_neighbor_pairs( + pos, + cutoff=self.cutoff_upper, + batch=batch, + max_num_pairs=self.max_num_pairs, + check_errors=True + ) + return neighbors, distances, distance_vecs + +def sort_neighbors(neighbors, deltas, distances): + i_sorted = np.lexsort(neighbors) + return neighbors[:, i_sorted], deltas[i_sorted], distances[i_sorted] + +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +@pytest.mark.parametrize("n_batches", [1, 2, 3]) +@pytest.mark.parametrize("cutoff", [0.1, 1.5, 1000.0]) +def test_neighbors(device, n_batches, cutoff): + if device == "cuda" and not torch.cuda.is_available(): + pytest.skip("CUDA not available") + n_atoms_per_batch = np.random.randint(2, 4, size=n_batches) + batch = torch.tensor([i for i in range(n_batches) for j in range(n_atoms_per_batch[i])]) + cumsum = np.cumsum( np.concatenate([[0], n_atoms_per_batch])) + pos = torch.randn(cumsum[-1], 3, device=device) + #Ensure there is at least one pair + pos[0,:] = torch.zeros(3) + pos[1,:] = torch.zeros(3) + pos.requires_grad = True + print("Pos") + print(pos.shape) + print("Batch") + print(batch) + ref_neighbors = np.concatenate([np.tril_indices(n_atoms_per_batch[i], -1)+cumsum[i] for i in range(n_batches)], axis=1) + print("Neighbors_i concat") + print(ref_neighbors.shape) + print(ref_neighbors) + pos_np = pos.cpu().detach().numpy() + ref_distances = np.linalg.norm(pos_np[ref_neighbors[0]] - pos_np[ref_neighbors[1]], axis=-1) + ref_distance_vecs = pos_np[ref_neighbors[0]] - pos_np[ref_neighbors[1]] + ref_neighbors, ref_distance_vecs, ref_distances = sort_neighbors(ref_neighbors, ref_distance_vecs, ref_distances) + #remove pairs with distance > cutoff + mask = ref_distances < cutoff + ref_neighbors = ref_neighbors[:, mask] + ref_distance_vecs = ref_distance_vecs[mask] + ref_distances = ref_distances[mask] + 
max_num_pairs = ref_neighbors.shape[1] + + nl = DistanceCellList(cutoff_upper=cutoff, max_num_pairs=max_num_pairs) + batch.to(device) + neighbors, distances, distance_vecs = nl(pos, batch) + neighbors = neighbors.cpu().detach().numpy() + distance_vecs = distance_vecs.cpu().detach().numpy() + distances = distances.cpu().detach().numpy() + assert neighbors.shape == (2, max_num_pairs) + assert distances.shape == (max_num_pairs,) + assert distance_vecs.shape == (max_num_pairs, 3) + + print("Neighbors") + print(neighbors) + print(ref_neighbors) + # print("Distances") + # print(distances) + # print(ref_distances) + # print("Distance vecs") + # print(distance_vecs) + # print(ref_distance_vecs) + neighbors, distance_vecs, distances = sort_neighbors(neighbors, distance_vecs, distances) + + + assert np.allclose(neighbors, ref_neighbors) + assert np.allclose(distances, ref_distances) + assert np.allclose(distance_vecs, ref_distance_vecs) diff --git a/torchmdnet/neighbors/__init__.py b/torchmdnet/neighbors/__init__.py new file mode 100644 index 000000000..24919f466 --- /dev/null +++ b/torchmdnet/neighbors/__init__.py @@ -0,0 +1,10 @@ +import os +import torch as pt +from torch.utils import cpp_extension + +src_dir = os.path.dirname(__file__) +sources = ['neighbors.cpp', 'neighbors_cpu.cpp'] + (['neighbors_cuda.cu'] if pt.cuda.is_available() else []) +sources = [os.path.join(src_dir, name) for name in sources] + +cpp_extension.load(name='neighbors', sources=sources, is_python_module=False) +get_neighbor_pairs = pt.ops.neighbors.get_neighbor_pairs diff --git a/torchmdnet/neighbors/neighbors.cpp b/torchmdnet/neighbors/neighbors.cpp new file mode 100644 index 000000000..45af7f723 --- /dev/null +++ b/torchmdnet/neighbors/neighbors.cpp @@ -0,0 +1,5 @@ +#include + +TORCH_LIBRARY(neighbors, m) { + m.def("get_neighbor_pairs(Tensor positions, Tensor batch, Scalar cutoff, Scalar max_num_pairs, bool check_errors) -> (Tensor neighbors, Tensor distances, Tensor distance_vecs)"); +} diff 
--git a/torchmdnet/neighbors/neighbors_cpu.cpp b/torchmdnet/neighbors/neighbors_cpu.cpp new file mode 100644 index 000000000..57bc5e332 --- /dev/null +++ b/torchmdnet/neighbors/neighbors_cpu.cpp @@ -0,0 +1,95 @@ +#include +#include + +using std::tuple; +using torch::div; +using torch::full; +using torch::index_select; +using torch::indexing::Slice; +using torch::arange; +using torch::frobenius_norm; +using torch::kInt32; +using torch::Scalar; +using torch::hstack; +using torch::vstack; +using torch::Tensor; +using torch::outer; +using torch::round; + +static tuple forward(const Tensor& positions, const Tensor& batch, const Scalar& cutoff, + const Scalar& max_num_pairs, bool checkErrors) { + TORCH_CHECK(positions.dim() == 2, "Expected \"positions\" to have two dimensions"); + TORCH_CHECK(positions.size(0) > 0, "Expected the 1nd dimension size of \"positions\" to be more than 0"); + TORCH_CHECK(positions.size(1) == 3, "Expected the 2nd dimension size of \"positions\" to be 3"); + TORCH_CHECK(positions.is_contiguous(), "Expected \"positions\" to be contiguous"); + + TORCH_CHECK(cutoff.to() > 0, "Expected \"cutoff\" to be positive"); + auto box_vectors = torch::empty(0); + if (box_vectors.size(0) != 0) { + TORCH_CHECK(box_vectors.dim() == 2, "Expected \"box_vectors\" to have two dimensions"); + TORCH_CHECK(box_vectors.size(0) == 3 && box_vectors.size(1) == 3, + "Expected \"box_vectors\" to have shape (3, 3)"); + double v[3][3]; + for (int i = 0; i < 3; i++) + for (int j = 0; j < 3; j++) + v[i][j] = box_vectors[i][j].item(); + double c = cutoff.to(); + TORCH_CHECK(v[0][1] == 0, "Invalid box vectors: box_vectors[0][1] != 0"); + TORCH_CHECK(v[0][2] == 0, "Invalid box vectors: box_vectors[0][2] != 0"); + TORCH_CHECK(v[1][2] == 0, "Invalid box vectors: box_vectors[1][2] != 0"); + TORCH_CHECK(v[0][0] >= 2 * c, "Invalid box vectors: box_vectors[0][0] < 2*cutoff"); + TORCH_CHECK(v[1][1] >= 2 * c, "Invalid box vectors: box_vectors[1][1] < 2*cutoff"); + TORCH_CHECK(v[2][2] >= 2 
* c, "Invalid box vectors: box_vectors[2][2] < 2*cutoff"); + TORCH_CHECK(v[0][0] >= 2 * v[1][0], "Invalid box vectors: box_vectors[0][0] < 2*box_vectors[1][0]"); + TORCH_CHECK(v[0][0] >= 2 * v[2][0], "Invalid box vectors: box_vectors[0][0] < 2*box_vectors[1][0]"); + TORCH_CHECK(v[1][1] >= 2 * v[2][1], "Invalid box vectors: box_vectors[1][1] < 2*box_vectors[2][1]"); + } + const int max_num_pairs_ = max_num_pairs.to(); + TORCH_CHECK(max_num_pairs_ > 0, + "Expected \"max_num_neighbors\" to be positive"); + const int n_atoms = positions.size(0); + const int n_batches = batch[n_atoms - 1].item() + 1; + int current_offset = 0; + std::vector batch_i; + int n_pairs = 0; + Tensor neighbors = torch::empty({0}, positions.options().dtype(kInt32)); + Tensor distances = torch::empty({0}, positions.options()); + Tensor deltas = torch::empty({0}, positions.options()); + for(int i = 0; i < n_batches; i++){ + batch_i.clear(); + for(int j = current_offset; j < n_atoms; j++){ + if(batch[j].item() == i){ + batch_i.push_back(j); + } + else{ + break; + } + } + const int n_atoms_i = batch_i.size(); + Tensor positions_i = index_select(positions, 0, torch::tensor(batch_i, kInt32)); + Tensor indices_i = arange(0, n_atoms_i * (n_atoms_i - 1) / 2, positions.options().dtype(kInt32)); + Tensor rows_i = (((8 * indices_i + 1).sqrt() + 1) / 2).floor().to(kInt32); + rows_i -= (rows_i * (rows_i - 1) > 2 * indices_i).to(kInt32); + Tensor columns_i = indices_i - div(rows_i * (rows_i - 1), 2, "floor"); + Tensor neighbors_i = vstack({rows_i, columns_i}); + Tensor deltas_i = index_select(positions_i, 0, rows_i) - index_select(positions_i, 0, columns_i); + Tensor distances_i = frobenius_norm(deltas_i, 1); + const Tensor mask = distances_i <= cutoff; + neighbors_i = neighbors_i.index({Slice(), mask}) + current_offset; + n_pairs += distances_i.size(0); + TORCH_CHECK(n_pairs >= 0, "The maximum number of pairs has been exceed! 
Increase \"max_num_neighbors\""); + neighbors = torch::hstack({neighbors, neighbors_i}); + current_offset += n_atoms_i; + } + if(n_batches > 1){ + neighbors = torch::cat(neighbors,0).to(kInt32); + } + deltas = index_select(positions, 0, neighbors[0]) - index_select(positions, 0, neighbors[1]); + distances = frobenius_norm(deltas, 1); + + return {neighbors, deltas, distances}; +} + +TORCH_LIBRARY_IMPL(neighbors, CPU, m) { + m.impl("get_neighbor_pairs", &forward); +} diff --git a/torchmdnet/neighbors/neighbors_cuda.cu b/torchmdnet/neighbors/neighbors_cuda.cu new file mode 100644 index 000000000..15a2a687c --- /dev/null +++ b/torchmdnet/neighbors/neighbors_cuda.cu @@ -0,0 +1,206 @@ +#include +#include +#include +#include +#include + +using c10::cuda::CUDAStreamGuard; +using c10::cuda::getCurrentCUDAStream; +using std::make_tuple; +using std::max; +using torch::empty; +using torch::full; +using torch::kInt32; +using torch::PackedTensorAccessor32; +using torch::RestrictPtrTraits; +using torch::Scalar; +using torch::Tensor; +using torch::TensorOptions; +using torch::zeros; +using torch::autograd::AutogradContext; +using torch::autograd::Function; +using torch::autograd::tensor_list; + +template +using Accessor = PackedTensorAccessor32; + +template inline Accessor get_accessor(const Tensor& tensor) { + return tensor.packed_accessor32(); +}; + +template __device__ __forceinline__ scalar_t sqrt_(scalar_t x){}; +template <> __device__ __forceinline__ float sqrt_(float x) { + return ::sqrtf(x); +}; +template <> __device__ __forceinline__ double sqrt_(double x) { + return ::sqrt(x); +}; + +template __global__ void forward_kernel( + const int32_t num_all_pairs, + const Accessor positions, + const Accessor batch_offsets, + const int32_t batch_index, + const scalar_t cutoff2, + Accessor i_curr_pair, + Accessor neighbors, + Accessor deltas, + Accessor distances) { + const int32_t index = blockIdx.x * blockDim.x + threadIdx.x; + if (index >= num_all_pairs) return; + + int32_t row = 
floor((sqrtf(8 * index + 1) + 1) / 2); + if (row * (row - 1) > 2 * index) row--; + const int32_t position_offset = batch_index?batch_offsets[batch_index-1]:0; + const int32_t column = (index - row * (row - 1) / 2) + position_offset; + row += position_offset; + scalar_t delta_x = positions[row][0] - positions[column][0]; + scalar_t delta_y = positions[row][1] - positions[column][1]; + scalar_t delta_z = positions[row][2] - positions[column][2]; + const scalar_t distance2 = delta_x * delta_x + delta_y * delta_y + delta_z * delta_z; + + if (distance2 > cutoff2) return; + + const int32_t i_pair = atomicAdd(&i_curr_pair[0], 1); + //We handle too many neighbors outside of the kernel + if (i_pair < neighbors.size(1)) { + neighbors[0][i_pair] = row; + neighbors[1][i_pair] = column; + deltas[i_pair][0] = delta_x; + deltas[i_pair][1] = delta_y; + deltas[i_pair][2] = delta_z; + distances[i_pair] = sqrt_(distance2); + } +} + +template +__global__ void backward_kernel(const Accessor neighbors, const Accessor deltas, + const Accessor distances, const Accessor grad_distances, + Accessor grad_positions) { + const int32_t i_pair = blockIdx.x * blockDim.x + threadIdx.x; + const int32_t num_pairs = neighbors.size(1); + if (i_pair >= num_pairs) + return; + + const int32_t i_dir = blockIdx.y; + const int32_t i_atom = neighbors[i_dir][i_pair]; + if (i_atom < 0) + return; + + const int32_t i_comp = blockIdx.z; + const scalar_t grad = deltas[i_pair][i_comp] / distances[i_pair] * grad_distances[i_pair]; + atomicAdd(&grad_positions[i_atom][i_comp], (i_dir ? 
-1 : 1) * grad); +} + +static void checkInput(const Tensor& positions, const Tensor& batch) { + TORCH_CHECK(positions.dim() == 2, "Expected \"positions\" to have two dimensions"); + TORCH_CHECK(positions.size(0) > 0, "Expected the 1nd dimension size of \"positions\" to be more than 0"); + TORCH_CHECK(positions.size(1) == 3, "Expected the 2nd dimension size of \"positions\" to be 3"); + TORCH_CHECK(positions.is_contiguous(), "Expected \"positions\" to be contiguous"); + + TORCH_CHECK(batch.dim() == 1, "Expected \"batch\" to have one dimension"); + TORCH_CHECK( + batch.size(0) == positions.size(0), + "Expected the 1st dimension size of \"batch\" to be the same as the 1st dimension size of \"positions\""); + TORCH_CHECK(batch.is_contiguous(), "Expected \"batch\" to be contiguous"); +} + +class Autograd : public Function { +public: + static tensor_list forward(AutogradContext* ctx, const Tensor& positions, + const Tensor& batch, + const Scalar& cutoff, const Scalar& max_num_pairs, + bool checkErrors){ + // This version works with batches + // Batch contains the molecule index for each atom in positions + // Neighbors are only calculated within the same molecule + // Batch is a 1D tensor of size (N_atoms) + // Batch is assumed to be sorted and starts at zero. 
+ // Batch is assumed to be contiguous + // Batch is assumed to be of type int32 + // Batch is assumed to be non-negative + // Each batch can have a different number of atoms + checkInput(positions, batch); + const int max_num_pairs_ = max_num_pairs.to(); + TORCH_CHECK(max_num_pairs_ > 0, "Expected \"max_num_neighbors\" to be positive"); + const int num_atoms = positions.size(0); + const int num_pairs = max_num_pairs_; + const Tensor num_atoms_per_batch = torch::bincount(batch); + const int n_batches = num_atoms_per_batch.size(0); + const TensorOptions options = positions.options(); + const Tensor batch_offsets = torch::cumsum(num_atoms_per_batch, 0, torch::kInt32).to(positions.device()); + const auto stream = getCurrentCUDAStream(positions.get_device()); + + const Tensor neighbors = full({2, num_pairs}, -1, options.dtype(kInt32)); + const Tensor deltas = empty({num_pairs, 3}, options); + const Tensor distances = full(num_pairs, 0, options); + const Tensor i_curr_pair = zeros(1, options.dtype(kInt32)); + { + const CUDAStreamGuard guard(stream); + for(int i=0; i(); + const int num_all_pairs = num_atoms * (num_atoms - 1) / 2; + const int num_threads = 128; + const int num_blocks = max((num_all_pairs + num_threads - 1) / num_threads, 1); + AT_DISPATCH_FLOATING_TYPES(positions.scalar_type(), "get_neighbor_pairs_forward", [&]() { + const scalar_t cutoff_ = cutoff.to(); + TORCH_CHECK(cutoff_ > 0, "Expected \"cutoff\" to be positive"); + forward_kernel<<>>( + num_all_pairs, + get_accessor(positions), + get_accessor(batch_offsets), + i, + cutoff_ * cutoff_, get_accessor(i_curr_pair), + get_accessor(neighbors), get_accessor(deltas), + get_accessor(distances)); + }); + } + } + // Synchronize and check the number of pairs found. Note that this is incompatible with CUDA graphs + if (checkErrors) { + int num_found_pairs = i_curr_pair.item(); + TORCH_CHECK(num_found_pairs <= max_num_pairs_, "Too many neighbor pairs found. 
Maximum is " + std::to_string(max_num_pairs_), " but found " + std::to_string(num_found_pairs)); + } + + neighbors.resize_({2, i_curr_pair[0].item()}); + deltas.resize_({i_curr_pair[0].item(), 3}); + distances.resize_(i_curr_pair[0].item()); + ctx->save_for_backward({neighbors, deltas, distances}); + ctx->saved_data["num_atoms"] = num_atoms; + return {neighbors, deltas, distances}; + } + + static tensor_list backward(AutogradContext* ctx, tensor_list grad_inputs) { + const Tensor grad_distances = grad_inputs[1]; + const int num_atoms = ctx->saved_data["num_atoms"].toInt(); + const int num_pairs = grad_distances.size(0); + const int num_threads = 128; + const int num_blocks_x = max((num_pairs + num_threads - 1) / num_threads, 1); + const dim3 blocks(num_blocks_x, 2, 3); + const auto stream = getCurrentCUDAStream(grad_distances.get_device()); + + const tensor_list data = ctx->get_saved_variables(); + const Tensor neighbors = data[0]; + const Tensor deltas = data[1]; + const Tensor distances = data[2]; + const Tensor grad_positions = zeros({num_atoms, 3}, grad_distances.options()); + + AT_DISPATCH_FLOATING_TYPES(grad_distances.scalar_type(), "get_neighbor_pairs_backward", [&]() { + const CUDAStreamGuard guard(stream); + backward_kernel<<>>( + get_accessor(neighbors), get_accessor(deltas), + get_accessor(distances), get_accessor(grad_distances), + get_accessor(grad_positions)); + }); + + return {grad_positions, Tensor(), Tensor()}; + } +}; + +TORCH_LIBRARY_IMPL(neighbors, AutogradCUDA, m) { + m.impl("get_neighbor_pairs", + [](const Tensor& positions, const Tensor& batch, const Scalar& cutoff, const Scalar& max_num_pairs, bool checkErrors) { + const tensor_list results = Autograd::apply(positions, batch, cutoff, max_num_pairs, checkErrors); + return std::make_tuple(results[0], results[1], results[2]); + }); +} From 92eaea8ff89e50eacd4b3e2892c99d22a9fa50e7 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Mon, 1 May 2023 13:39:14 +0200 Subject: [PATCH 02/76] Move 
DistanceCellList to utils --- tests/test_neighbors.py | 85 +++----------------------------------- torchmdnet/models/utils.py | 56 ++++++++++++++++++++++++- 2 files changed, 61 insertions(+), 80 deletions(-) diff --git a/tests/test_neighbors.py b/tests/test_neighbors.py index 2696b5099..3b19811bc 100644 --- a/tests/test_neighbors.py +++ b/tests/test_neighbors.py @@ -2,74 +2,20 @@ import pytest import torch import numpy as np -from torch.cuda import check_error - -from torchmdnet.neighbors import get_neighbor_pairs - - -class DistanceCellList(torch.nn.Module): - def __init__( - self, - cutoff_upper, - max_num_pairs=32, - loop=False, - ): - super(DistanceCellList, self).__init__() - """ Compute the neighbor list for a given cutoff. - Parameters - ---------- - cutoff_upper : float - Upper cutoff for the neighbor list. - max_num_pairs : int - Maximum number of pairs to store. - loop : bool - Whether to include self interactions (pair (i,i)). - """ - self.cutoff_upper = cutoff_upper - self.max_num_pairs = max_num_pairs - self.loop = loop - - def forward(self, pos, batch): - """ - Parameters - ---------- - pos : torch.Tensor - shape (N, 3) - batch : torch.Tensor - shape (N,) - Returns - ------- - neighbors : torch.Tensor - List of neighbors for each atom in the batch. - shape (2, max_num_pairs) - distances : torch.Tensor - List of distances for each atom in the batch. - shape (max_num_pairs,) - distance_vecs : torch.Tensor - List of distance vectors for each atom in the batch. 
- shape (max_num_pairs, 3) - - """ - neighbors, distance_vecs, distances = get_neighbor_pairs( - pos, - cutoff=self.cutoff_upper, - batch=batch, - max_num_pairs=self.max_num_pairs, - check_errors=True - ) - return neighbors, distances, distance_vecs +from torchmdnet.models.utils import DistanceCellList def sort_neighbors(neighbors, deltas, distances): i_sorted = np.lexsort(neighbors) return neighbors[:, i_sorted], deltas[i_sorted], distances[i_sorted] @pytest.mark.parametrize("device", ["cpu", "cuda"]) -@pytest.mark.parametrize("n_batches", [1, 2, 3]) +@pytest.mark.parametrize("strategy", ["brute", "cell"]) +@pytest.mark.parametrize("n_batches", [1, 2, 3,100]) @pytest.mark.parametrize("cutoff", [0.1, 1.5, 1000.0]) -def test_neighbors(device, n_batches, cutoff): +def test_neighbors(device, strategy, n_batches, cutoff): if device == "cuda" and not torch.cuda.is_available(): pytest.skip("CUDA not available") - n_atoms_per_batch = np.random.randint(2, 4, size=n_batches) + n_atoms_per_batch = np.random.randint(2, 100, size=n_batches) batch = torch.tensor([i for i in range(n_batches) for j in range(n_atoms_per_batch[i])]) cumsum = np.cumsum( np.concatenate([[0], n_atoms_per_batch])) pos = torch.randn(cumsum[-1], 3, device=device) @@ -77,14 +23,7 @@ def test_neighbors(device, n_batches, cutoff): pos[0,:] = torch.zeros(3) pos[1,:] = torch.zeros(3) pos.requires_grad = True - print("Pos") - print(pos.shape) - print("Batch") - print(batch) ref_neighbors = np.concatenate([np.tril_indices(n_atoms_per_batch[i], -1)+cumsum[i] for i in range(n_batches)], axis=1) - print("Neighbors_i concat") - print(ref_neighbors.shape) - print(ref_neighbors) pos_np = pos.cpu().detach().numpy() ref_distances = np.linalg.norm(pos_np[ref_neighbors[0]] - pos_np[ref_neighbors[1]], axis=-1) ref_distance_vecs = pos_np[ref_neighbors[0]] - pos_np[ref_neighbors[1]] @@ -96,7 +35,7 @@ def test_neighbors(device, n_batches, cutoff): ref_distances = ref_distances[mask] max_num_pairs = ref_neighbors.shape[1] - 
nl = DistanceCellList(cutoff_upper=cutoff, max_num_pairs=max_num_pairs) + nl = DistanceCellList(cutoff_upper=cutoff, max_num_pairs=max_num_pairs, strategy=strategy) batch.to(device) neighbors, distances, distance_vecs = nl(pos, batch) neighbors = neighbors.cpu().detach().numpy() @@ -105,19 +44,7 @@ def test_neighbors(device, n_batches, cutoff): assert neighbors.shape == (2, max_num_pairs) assert distances.shape == (max_num_pairs,) assert distance_vecs.shape == (max_num_pairs, 3) - - print("Neighbors") - print(neighbors) - print(ref_neighbors) - # print("Distances") - # print(distances) - # print(ref_distances) - # print("Distance vecs") - # print(distance_vecs) - # print(ref_distance_vecs) neighbors, distance_vecs, distances = sort_neighbors(neighbors, distance_vecs, distances) - - assert np.allclose(neighbors, ref_neighbors) assert np.allclose(distances, ref_distances) assert np.allclose(distance_vecs, ref_distance_vecs) diff --git a/torchmdnet/models/utils.py b/torchmdnet/models/utils.py index d1f18c6e6..ccd2b516d 100644 --- a/torchmdnet/models/utils.py +++ b/torchmdnet/models/utils.py @@ -76,6 +76,61 @@ def forward(self, z, x, edge_index, edge_weight, edge_attr): def message(self, x_j, W): return x_j * W +from torchmdnet.neighbors import get_neighbor_pairs, get_neighbor_pairs_cell +class DistanceCellList(torch.nn.Module): + def __init__( + self, + cutoff_upper, + max_num_pairs=32, + loop=False, + strategy="cell", + ): + super(DistanceCellList, self).__init__() + """ Compute the neighbor list for a given cutoff. + Parameters + ---------- + cutoff_upper : float + Upper cutoff for the neighbor list. + max_num_pairs : int + Maximum number of pairs to store. + loop : bool + Whether to include self interactions (pair (i,i)). 
+ """ + self.cutoff_upper = cutoff_upper + self.max_num_pairs = max_num_pairs + self.loop = loop + self.strategy = strategy + + def forward(self, pos, batch): + """ + Parameters + ---------- + pos : torch.Tensor + shape (N, 3) + batch : torch.Tensor + shape (N,) + Returns + ------- + neighbors : torch.Tensor + List of neighbors for each atom in the batch. + shape (2, max_num_pairs) + distances : torch.Tensor + List of distances for each atom in the batch. + shape (max_num_pairs,) + distance_vecs : torch.Tensor + List of distance vectors for each atom in the batch. + shape (max_num_pairs, 3) + + """ + function = get_neighbor_pairs if self.strategy == "brute" else get_neighbor_pairs_cell + neighbors, distance_vecs, distances = function( + pos, + cutoff=self.cutoff_upper, + batch=batch, + max_num_pairs=self.max_num_pairs, + check_errors=True + ) + return neighbors, distances, distance_vecs class GaussianSmearing(nn.Module): def __init__(self, cutoff_lower=0.0, cutoff_upper=5.0, num_rbf=50, trainable=True): @@ -252,7 +307,6 @@ def forward(self, pos, batch): # Union typing works with TorchScript (https://github.com/pytorch/pytorch/pull/53180) return edge_index, edge_weight, None - class GatedEquivariantBlock(nn.Module): """Gated Equivariant Block as defined in Schütt et al. 
(2021): Equivariant message passing for the prediction of tensorial properties and molecular spectra From 6097a19eb983639e51e01e36df3606b25091f477 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Mon, 1 May 2023 13:39:34 +0200 Subject: [PATCH 03/76] Add benchmark --- benchmarks/neighbors.py | 73 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 73 insertions(+) create mode 100644 benchmarks/neighbors.py diff --git a/benchmarks/neighbors.py b/benchmarks/neighbors.py new file mode 100644 index 000000000..dfe9b91d2 --- /dev/null +++ b/benchmarks/neighbors.py @@ -0,0 +1,73 @@ +import os +import torch +import numpy as np +from torchmdnet.models.utils import DistanceCellList + + +def benchmark_neighbors(device, strategy, n_batches, total_num_particles): + """Benchmark the neighbor list generation. + + Parameters + ---------- + device : str + Device to use for the benchmark. + strategy : str + Strategy to use for the neighbor list generation (cell, brute). + n_batches : int + Number of batches to generate. + total_num_particles : int + Total number of particles. + Returns + ------- + float + Average time per batch in seconds. 
+ """ + density = 0.5; + num_particles = total_num_particles // n_batches + expected_num_neighbors = min(num_particles, 32); + cutoff = np.cbrt(3 * expected_num_neighbors / (4 * np.pi * density)); + n_atoms_per_batch = np.random.randint(num_particles-10, num_particles+10, size=n_batches) + #Fix the last batch so that the total number of particles is correct + n_atoms_per_batch[-1] += total_num_particles - n_atoms_per_batch.sum() + if n_atoms_per_batch[-1] < 0: + n_atoms_per_batch[-1] = 1 + + lbox = np.cbrt(num_particles / density); + batch = torch.tensor([i for i in range(n_batches) for j in range(n_atoms_per_batch[i])], device=device) + cumsum = np.cumsum( np.concatenate([[0], n_atoms_per_batch])) + pos = torch.rand(cumsum[-1], 3, device=device)*lbox + max_num_pairs = torch.tensor(expected_num_neighbors * n_atoms_per_batch.sum(), dtype=torch.int64).item() + nl = DistanceCellList(cutoff_upper=cutoff, max_num_pairs=max_num_pairs, strategy=strategy) + #Warmup + neighbors, distances, distance_vecs = nl(pos, batch) + if device == 'cuda': + torch.cuda.synchronize() + #Benchmark using torch profiler + nruns = 100 + if device == 'cuda': + torch.cuda.synchronize() + + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + start.record() + with torch.autograd.profiler.profile(use_cuda=True) as prof: + for i in range(nruns): + neighbors, distances, distance_vecs = nl(pos, batch) + end.record() + if device == 'cuda': + torch.cuda.synchronize() + #Final time + return (start.elapsed_time(end) / nruns) + + +if __name__ == '__main__': + n_particles = 10000 + print("Benchmarking neighbor list generation for {} particles".format(n_particles)) + #Loop over different number of batches + for n_batches in [1, 10, 100, 1000]: + time = benchmark_neighbors(device='cuda', + strategy='brute', + n_batches=n_batches, + total_num_particles=n_particles) + print("Time for {} batches: {} ms".format(n_batches, time)) From 
85b77b315a2c5f2d451ca9375763f062840e9b66 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Mon, 1 May 2023 13:39:47 +0200 Subject: [PATCH 04/76] Add CellList implementation source --- torchmdnet/neighbors/__init__.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torchmdnet/neighbors/__init__.py b/torchmdnet/neighbors/__init__.py index 24919f466..6282fb558 100644 --- a/torchmdnet/neighbors/__init__.py +++ b/torchmdnet/neighbors/__init__.py @@ -3,8 +3,9 @@ from torch.utils import cpp_extension src_dir = os.path.dirname(__file__) -sources = ['neighbors.cpp', 'neighbors_cpu.cpp'] + (['neighbors_cuda.cu'] if pt.cuda.is_available() else []) +sources = ['neighbors.cpp', 'neighbors_cpu.cpp'] + (['neighbors_cuda.cu', 'neighbors_cuda_cell.cu'] if pt.cuda.is_available() else []) sources = [os.path.join(src_dir, name) for name in sources] - cpp_extension.load(name='neighbors', sources=sources, is_python_module=False) + get_neighbor_pairs = pt.ops.neighbors.get_neighbor_pairs +get_neighbor_pairs_cell = pt.ops.neighbors.get_neighbor_pairs_cell From 6f7c165d38a3adca4106ddfaaf7ccc9e0471c7b6 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Mon, 1 May 2023 13:40:30 +0200 Subject: [PATCH 05/76] Add CellList implementation source --- torchmdnet/neighbors/neighbors.cpp | 1 + torchmdnet/neighbors/neighbors_cpu.cpp | 4 +- torchmdnet/neighbors/neighbors_cuda.cu | 130 ++++++------- torchmdnet/neighbors/neighbors_cuda_cell.cu | 201 ++++++++++++++++++++ 4 files changed, 267 insertions(+), 69 deletions(-) create mode 100644 torchmdnet/neighbors/neighbors_cuda_cell.cu diff --git a/torchmdnet/neighbors/neighbors.cpp b/torchmdnet/neighbors/neighbors.cpp index 45af7f723..4ff69ce31 100644 --- a/torchmdnet/neighbors/neighbors.cpp +++ b/torchmdnet/neighbors/neighbors.cpp @@ -2,4 +2,5 @@ TORCH_LIBRARY(neighbors, m) { m.def("get_neighbor_pairs(Tensor positions, Tensor batch, Scalar cutoff, Scalar max_num_pairs, bool check_errors) -> (Tensor neighbors, Tensor distances, Tensor 
distance_vecs)"); + m.def("get_neighbor_pairs_cell(Tensor positions, Tensor batch, Scalar cutoff, Scalar max_num_pairs, bool check_errors) -> (Tensor neighbors, Tensor distances, Tensor distance_vecs)"); } diff --git a/torchmdnet/neighbors/neighbors_cpu.cpp b/torchmdnet/neighbors/neighbors_cpu.cpp index 57bc5e332..3ee5e250b 100644 --- a/torchmdnet/neighbors/neighbors_cpu.cpp +++ b/torchmdnet/neighbors/neighbors_cpu.cpp @@ -44,8 +44,7 @@ static tuple forward(const Tensor& positions, const Tens TORCH_CHECK(v[0][0] >= 2 * v[2][0], "Invalid box vectors: box_vectors[0][0] < 2*box_vectors[1][0]"); TORCH_CHECK(v[1][1] >= 2 * v[2][1], "Invalid box vectors: box_vectors[1][1] < 2*box_vectors[2][1]"); } - const int max_num_pairs_ = max_num_pairs.to(); - TORCH_CHECK(max_num_pairs_ > 0, + TORCH_CHECK(max_num_pairs.toLong() > 0, "Expected \"max_num_neighbors\" to be positive"); const int n_atoms = positions.size(0); const int n_batches = batch[n_atoms - 1].item() + 1; @@ -92,4 +91,5 @@ static tuple forward(const Tensor& positions, const Tens TORCH_LIBRARY_IMPL(neighbors, CPU, m) { m.impl("get_neighbor_pairs", &forward); + m.impl("get_neighbor_pairs_cell", &forward); } diff --git a/torchmdnet/neighbors/neighbors_cuda.cu b/torchmdnet/neighbors/neighbors_cuda.cu index 15a2a687c..62991f22c 100644 --- a/torchmdnet/neighbors/neighbors_cuda.cu +++ b/torchmdnet/neighbors/neighbors_cuda.cu @@ -3,7 +3,9 @@ #include #include #include - +#include +#include +#include using c10::cuda::CUDAStreamGuard; using c10::cuda::getCurrentCUDAStream; using std::make_tuple; @@ -36,22 +38,19 @@ template <> __device__ __forceinline__ double sqrt_(double x) { return ::sqrt(x); }; -template __global__ void forward_kernel( - const int32_t num_all_pairs, - const Accessor positions, - const Accessor batch_offsets, - const int32_t batch_index, - const scalar_t cutoff2, - Accessor i_curr_pair, - Accessor neighbors, - Accessor deltas, - Accessor distances) { +template +__global__ void forward_kernel(const int32_t 
num_all_pairs, const Accessor positions, + const Accessor batch_offsets, const int32_t batch_index, + const scalar_t cutoff2, Accessor i_curr_pair, Accessor neighbors, + Accessor deltas, Accessor distances) { const int32_t index = blockIdx.x * blockDim.x + threadIdx.x; - if (index >= num_all_pairs) return; + if (index >= num_all_pairs) + return; int32_t row = floor((sqrtf(8 * index + 1) + 1) / 2); - if (row * (row - 1) > 2 * index) row--; - const int32_t position_offset = batch_index?batch_offsets[batch_index-1]:0; + if (row * (row - 1) > 2 * index) + row--; + const int32_t position_offset = batch_index ? batch_offsets[batch_index - 1] : 0; const int32_t column = (index - row * (row - 1) / 2) + position_offset; row += position_offset; scalar_t delta_x = positions[row][0] - positions[column][0]; @@ -59,10 +58,11 @@ template __global__ void forward_kernel( scalar_t delta_z = positions[row][2] - positions[column][2]; const scalar_t distance2 = delta_x * delta_x + delta_y * delta_y + delta_z * delta_z; - if (distance2 > cutoff2) return; + if (distance2 > cutoff2) + return; const int32_t i_pair = atomicAdd(&i_curr_pair[0], 1); - //We handle too many neighbors outside of the kernel + // We handle too many neighbors outside of the kernel if (i_pair < neighbors.size(1)) { neighbors[0][i_pair] = row; neighbors[1][i_pair] = column; @@ -73,6 +73,7 @@ template __global__ void forward_kernel( } } + template __global__ void backward_kernel(const Accessor neighbors, const Accessor deltas, const Accessor distances, const Accessor grad_distances, @@ -93,6 +94,15 @@ __global__ void backward_kernel(const Accessor neighbors, const Acce } static void checkInput(const Tensor& positions, const Tensor& batch) { + // This version works with batches + // Batch contains the molecule index for each atom in positions + // Neighbors are only calculated within the same molecule + // Batch is a 1D tensor of size (N_atoms) + // Batch is assumed to be sorted and starts at zero. 
+ // Batch is assumed to be contiguous + // Batch is assumed to be of type torch::kLong + // Batch is assumed to be non-negative + // Each batch can have a different number of atoms TORCH_CHECK(positions.dim() == 2, "Expected \"positions\" to have two dimensions"); TORCH_CHECK(positions.size(0) > 0, "Expected the 1nd dimension size of \"positions\" to be more than 0"); TORCH_CHECK(positions.size(1) == 3, "Expected the 2nd dimension size of \"positions\" to be 3"); @@ -103,74 +113,60 @@ static void checkInput(const Tensor& positions, const Tensor& batch) { batch.size(0) == positions.size(0), "Expected the 1st dimension size of \"batch\" to be the same as the 1st dimension size of \"positions\""); TORCH_CHECK(batch.is_contiguous(), "Expected \"batch\" to be contiguous"); + TORCH_CHECK(batch.dtype() == torch::kLong, "Expected \"batch\" to have torch::kLong dtype"); } class Autograd : public Function { public: - static tensor_list forward(AutogradContext* ctx, const Tensor& positions, - const Tensor& batch, - const Scalar& cutoff, const Scalar& max_num_pairs, - bool checkErrors){ - // This version works with batches - // Batch contains the molecule index for each atom in positions - // Neighbors are only calculated within the same molecule - // Batch is a 1D tensor of size (N_atoms) - // Batch is assumed to be sorted and starts at zero. 
- // Batch is assumed to be contiguous - // Batch is assumed to be of type int32 - // Batch is assumed to be non-negative - // Each batch can have a different number of atoms + static tensor_list forward(AutogradContext* ctx, const Tensor& positions, const Tensor& batch, const Scalar& cutoff, + const Scalar& max_num_pairs, bool checkErrors) { checkInput(positions, batch); - const int max_num_pairs_ = max_num_pairs.to(); + const auto max_num_pairs_ = max_num_pairs.toLong(); TORCH_CHECK(max_num_pairs_ > 0, "Expected \"max_num_neighbors\" to be positive"); const int num_atoms = positions.size(0); const int num_pairs = max_num_pairs_; - const Tensor num_atoms_per_batch = torch::bincount(batch); + const TensorOptions options = positions.options(); + const auto stream = getCurrentCUDAStream(positions.get_device()); + const Tensor num_atoms_per_batch = torch::bincount(batch).to(torch::kCPU); const int n_batches = num_atoms_per_batch.size(0); - const TensorOptions options = positions.options(); const Tensor batch_offsets = torch::cumsum(num_atoms_per_batch, 0, torch::kInt32).to(positions.device()); - const auto stream = getCurrentCUDAStream(positions.get_device()); - const Tensor neighbors = full({2, num_pairs}, -1, options.dtype(kInt32)); const Tensor deltas = empty({num_pairs, 3}, options); const Tensor distances = full(num_pairs, 0, options); const Tensor i_curr_pair = zeros(1, options.dtype(kInt32)); - { - const CUDAStreamGuard guard(stream); - for(int i=0; i(); - const int num_all_pairs = num_atoms * (num_atoms - 1) / 2; - const int num_threads = 128; - const int num_blocks = max((num_all_pairs + num_threads - 1) / num_threads, 1); - AT_DISPATCH_FLOATING_TYPES(positions.scalar_type(), "get_neighbor_pairs_forward", [&]() { - const scalar_t cutoff_ = cutoff.to(); - TORCH_CHECK(cutoff_ > 0, "Expected \"cutoff\" to be positive"); - forward_kernel<<>>( - num_all_pairs, - get_accessor(positions), - get_accessor(batch_offsets), - i, - cutoff_ * cutoff_, 
get_accessor(i_curr_pair), - get_accessor(neighbors), get_accessor(deltas), - get_accessor(distances)); - }); - } - } - // Synchronize and check the number of pairs found. Note that this is incompatible with CUDA graphs + { + const CUDAStreamGuard guard(stream); + for (int i = 0; i < n_batches; i++) { + const int num_atoms = num_atoms_per_batch[i].item(); + const int num_all_pairs = num_atoms * (num_atoms - 1) / 2; + const int num_threads = 128; + const int num_blocks = max((num_all_pairs + num_threads - 1) / num_threads, 1); + AT_DISPATCH_FLOATING_TYPES(positions.scalar_type(), "get_neighbor_pairs_forward", [&]() { + const scalar_t cutoff_ = cutoff.to(); + TORCH_CHECK(cutoff_ > 0, "Expected \"cutoff\" to be positive"); + forward_kernel<<>>( + num_all_pairs, get_accessor(positions), get_accessor(batch_offsets), i, + cutoff_ * cutoff_, get_accessor(i_curr_pair), get_accessor(neighbors), + get_accessor(deltas), get_accessor(distances)); + }); + } + } + // Synchronize and check the number of pairs found. Note that this is incompatible with CUDA graphs if (checkErrors) { - int num_found_pairs = i_curr_pair.item(); - TORCH_CHECK(num_found_pairs <= max_num_pairs_, "Too many neighbor pairs found. Maximum is " + std::to_string(max_num_pairs_), " but found " + std::to_string(num_found_pairs)); + int num_found_pairs = i_curr_pair.item(); + TORCH_CHECK(num_found_pairs <= max_num_pairs_, + "Too many neighbor pairs found. 
Maximum is " + std::to_string(max_num_pairs_), + " but found " + std::to_string(num_found_pairs)); } - - neighbors.resize_({2, i_curr_pair[0].item()}); + neighbors.resize_({2, i_curr_pair[0].item()}); deltas.resize_({i_curr_pair[0].item(), 3}); distances.resize_(i_curr_pair[0].item()); ctx->save_for_backward({neighbors, deltas, distances}); ctx->saved_data["num_atoms"] = num_atoms; return {neighbors, deltas, distances}; - } + } - static tensor_list backward(AutogradContext* ctx, tensor_list grad_inputs) { + static tensor_list backward(AutogradContext* ctx, tensor_list grad_inputs) { const Tensor grad_distances = grad_inputs[1]; const int num_atoms = ctx->saved_data["num_atoms"].toInt(); const int num_pairs = grad_distances.size(0); @@ -198,9 +194,9 @@ public: }; TORCH_LIBRARY_IMPL(neighbors, AutogradCUDA, m) { - m.impl("get_neighbor_pairs", - [](const Tensor& positions, const Tensor& batch, const Scalar& cutoff, const Scalar& max_num_pairs, bool checkErrors) { - const tensor_list results = Autograd::apply(positions, batch, cutoff, max_num_pairs, checkErrors); - return std::make_tuple(results[0], results[1], results[2]); - }); + m.impl("get_neighbor_pairs", [](const Tensor& positions, const Tensor& batch, const Scalar& cutoff, + const Scalar& max_num_pairs, bool checkErrors) { + const tensor_list results = Autograd::apply(positions, batch, cutoff, max_num_pairs, checkErrors); + return std::make_tuple(results[0], results[1], results[2]); + }); } diff --git a/torchmdnet/neighbors/neighbors_cuda_cell.cu b/torchmdnet/neighbors/neighbors_cuda_cell.cu new file mode 100644 index 000000000..3d3f7610e --- /dev/null +++ b/torchmdnet/neighbors/neighbors_cuda_cell.cu @@ -0,0 +1,201 @@ +#include +#include +#include +#include +#include +#include +#include +#include +using c10::cuda::CUDAStreamGuard; +using c10::cuda::getCurrentCUDAStream; +using std::make_tuple; +using std::max; +using torch::empty; +using torch::full; +using torch::kInt32; +using torch::PackedTensorAccessor32; 
+using torch::RestrictPtrTraits; +using torch::Scalar; +using torch::Tensor; +using torch::TensorOptions; +using torch::zeros; +using torch::autograd::AutogradContext; +using torch::autograd::Function; +using torch::autograd::tensor_list; + +template +using Accessor = PackedTensorAccessor32; + +template inline Accessor get_accessor(const Tensor& tensor) { + return tensor.packed_accessor32(); +}; + +template __device__ __forceinline__ scalar_t sqrt_(scalar_t x){}; +template <> __device__ __forceinline__ float sqrt_(float x) { + return ::sqrtf(x); +}; +template <> __device__ __forceinline__ double sqrt_(double x) { + return ::sqrt(x); +}; + +template +__global__ void forward_kernel(const int32_t num_all_pairs, const Accessor positions, + const Accessor batch_offsets, const int32_t batch_index, + const scalar_t cutoff2, Accessor i_curr_pair, Accessor neighbors, + Accessor deltas, Accessor distances) { + const int32_t index = blockIdx.x * blockDim.x + threadIdx.x; + if (index >= num_all_pairs) + return; + + int32_t row = floor((sqrtf(8 * index + 1) + 1) / 2); + if (row * (row - 1) > 2 * index) + row--; + const int32_t position_offset = batch_index ? 
batch_offsets[batch_index - 1] : 0; + const int32_t column = (index - row * (row - 1) / 2) + position_offset; + row += position_offset; + scalar_t delta_x = positions[row][0] - positions[column][0]; + scalar_t delta_y = positions[row][1] - positions[column][1]; + scalar_t delta_z = positions[row][2] - positions[column][2]; + const scalar_t distance2 = delta_x * delta_x + delta_y * delta_y + delta_z * delta_z; + + if (distance2 > cutoff2) + return; + + const int32_t i_pair = atomicAdd(&i_curr_pair[0], 1); + // We handle too many neighbors outside of the kernel + if (i_pair < neighbors.size(1)) { + neighbors[0][i_pair] = row; + neighbors[1][i_pair] = column; + deltas[i_pair][0] = delta_x; + deltas[i_pair][1] = delta_y; + deltas[i_pair][2] = delta_z; + distances[i_pair] = sqrt_(distance2); + } +} + +template +__global__ void backward_kernel(const Accessor neighbors, const Accessor deltas, + const Accessor distances, const Accessor grad_distances, + Accessor grad_positions) { + const int32_t i_pair = blockIdx.x * blockDim.x + threadIdx.x; + const int32_t num_pairs = neighbors.size(1); + if (i_pair >= num_pairs) + return; + + const int32_t i_dir = blockIdx.y; + const int32_t i_atom = neighbors[i_dir][i_pair]; + if (i_atom < 0) + return; + + const int32_t i_comp = blockIdx.z; + const scalar_t grad = deltas[i_pair][i_comp] / distances[i_pair] * grad_distances[i_pair]; + atomicAdd(&grad_positions[i_atom][i_comp], (i_dir ? -1 : 1) * grad); +} + +static void checkInput(const Tensor& positions, const Tensor& batch) { + // This version works with batches + // Batch contains the molecule index for each atom in positions + // Neighbors are only calculated within the same molecule + // Batch is a 1D tensor of size (N_atoms) + // Batch is assumed to be sorted and starts at zero. 
+ // Batch is assumed to be contiguous + // Batch is assumed to be of type torch::kLong + // Batch is assumed to be non-negative + // Each batch can have a different number of atoms + TORCH_CHECK(positions.dim() == 2, "Expected \"positions\" to have two dimensions"); + TORCH_CHECK(positions.size(0) > 0, "Expected the 1nd dimension size of \"positions\" to be more than 0"); + TORCH_CHECK(positions.size(1) == 3, "Expected the 2nd dimension size of \"positions\" to be 3"); + TORCH_CHECK(positions.is_contiguous(), "Expected \"positions\" to be contiguous"); + + TORCH_CHECK(batch.dim() == 1, "Expected \"batch\" to have one dimension"); + TORCH_CHECK( + batch.size(0) == positions.size(0), + "Expected the 1st dimension size of \"batch\" to be the same as the 1st dimension size of \"positions\""); + TORCH_CHECK(batch.is_contiguous(), "Expected \"batch\" to be contiguous"); + TORCH_CHECK(batch.dtype() == torch::kLong, "Expected \"batch\" to have torch::kLong dtype"); +} + +class Autograd : public Function { +public: + static tensor_list forward(AutogradContext* ctx, const Tensor& positions, const Tensor& batch, const Scalar& cutoff, + const Scalar& max_num_pairs, bool checkErrors) { + checkInput(positions, batch); + const auto max_num_pairs_ = max_num_pairs.toLong(); + TORCH_CHECK(max_num_pairs_ > 0, "Expected \"max_num_neighbors\" to be positive"); + const int num_atoms = positions.size(0); + const int num_pairs = max_num_pairs_; + const TensorOptions options = positions.options(); + const auto stream = getCurrentCUDAStream(positions.get_device()); + const Tensor num_atoms_per_batch = torch::bincount(batch); + const int n_batches = num_atoms_per_batch.size(0); + const Tensor batch_offsets = torch::cumsum(num_atoms_per_batch, 0, torch::kInt32).to(positions.device()); + const Tensor neighbors = full({2, num_pairs}, -1, options.dtype(kInt32)); + const Tensor deltas = empty({num_pairs, 3}, options); + const Tensor distances = full(num_pairs, 0, options); + const Tensor 
i_curr_pair = zeros(1, options.dtype(kInt32)); + { + const CUDAStreamGuard guard(stream); + for (int i = 0; i < n_batches; i++) { + const int num_atoms = num_atoms_per_batch[i].item(); + const int num_all_pairs = num_atoms * (num_atoms - 1) / 2; + const int num_threads = 128; + const int num_blocks = max((num_all_pairs + num_threads - 1) / num_threads, 1); + AT_DISPATCH_FLOATING_TYPES(positions.scalar_type(), "get_neighbor_pairs_forward", [&]() { + const scalar_t cutoff_ = cutoff.to(); + TORCH_CHECK(cutoff_ > 0, "Expected \"cutoff\" to be positive"); + forward_kernel<<>>( + num_all_pairs, get_accessor(positions), get_accessor(batch_offsets), i, + cutoff_ * cutoff_, get_accessor(i_curr_pair), get_accessor(neighbors), + get_accessor(deltas), get_accessor(distances)); + }); + } + } + // Synchronize and check the number of pairs found. Note that this is incompatible with CUDA graphs + if (checkErrors) { + int num_found_pairs = i_curr_pair.item(); + TORCH_CHECK(num_found_pairs <= max_num_pairs_, + "Too many neighbor pairs found. 
Maximum is " + std::to_string(max_num_pairs_), + " but found " + std::to_string(num_found_pairs)); + } + neighbors.resize_({2, i_curr_pair[0].item()}); + deltas.resize_({i_curr_pair[0].item(), 3}); + distances.resize_(i_curr_pair[0].item()); + ctx->save_for_backward({neighbors, deltas, distances}); + ctx->saved_data["num_atoms"] = num_atoms; + return {neighbors, deltas, distances}; + } + + static tensor_list backward(AutogradContext* ctx, tensor_list grad_inputs) { + const Tensor grad_distances = grad_inputs[1]; + const int num_atoms = ctx->saved_data["num_atoms"].toInt(); + const int num_pairs = grad_distances.size(0); + const int num_threads = 128; + const int num_blocks_x = max((num_pairs + num_threads - 1) / num_threads, 1); + const dim3 blocks(num_blocks_x, 2, 3); + const auto stream = getCurrentCUDAStream(grad_distances.get_device()); + + const tensor_list data = ctx->get_saved_variables(); + const Tensor neighbors = data[0]; + const Tensor deltas = data[1]; + const Tensor distances = data[2]; + const Tensor grad_positions = zeros({num_atoms, 3}, grad_distances.options()); + + AT_DISPATCH_FLOATING_TYPES(grad_distances.scalar_type(), "get_neighbor_pairs_backward", [&]() { + const CUDAStreamGuard guard(stream); + backward_kernel<<>>( + get_accessor(neighbors), get_accessor(deltas), + get_accessor(distances), get_accessor(grad_distances), + get_accessor(grad_positions)); + }); + + return {grad_positions, Tensor(), Tensor()}; + } +}; + +TORCH_LIBRARY_IMPL(neighbors, AutogradCUDA, m) { + m.impl("get_neighbor_pairs_cell", [](const Tensor& positions, const Tensor& batch, const Scalar& cutoff, + const Scalar& max_num_pairs, bool checkErrors) { + const tensor_list results = Autograd::apply(positions, batch, cutoff, max_num_pairs, checkErrors); + return std::make_tuple(results[0], results[1], results[2]); + }); +} From f0ddc727e9564eb8bd057dca7fee619ee0e25d0a Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Tue, 2 May 2023 20:09:33 +0200 Subject: [PATCH 06/76] Add 
strategy and box as parameters Use kInt32 for batch Update cell list impl. --- torchmdnet/models/utils.py | 15 +- torchmdnet/neighbors/neighbors.cpp | 4 +- torchmdnet/neighbors/neighbors_cpu.cpp | 2 +- torchmdnet/neighbors/neighbors_cuda.cu | 4 +- torchmdnet/neighbors/neighbors_cuda_cell.cu | 467 ++++++++++++++++---- 5 files changed, 396 insertions(+), 96 deletions(-) diff --git a/torchmdnet/models/utils.py b/torchmdnet/models/utils.py index ccd2b516d..8b314ef55 100644 --- a/torchmdnet/models/utils.py +++ b/torchmdnet/models/utils.py @@ -82,8 +82,8 @@ def __init__( self, cutoff_upper, max_num_pairs=32, - loop=False, strategy="cell", + box=None ): super(DistanceCellList, self).__init__() """ Compute the neighbor list for a given cutoff. @@ -93,13 +93,17 @@ def __init__( Upper cutoff for the neighbor list. max_num_pairs : int Maximum number of pairs to store. - loop : bool - Whether to include self interactions (pair (i,i)). + strategy : str + Strategy to use for computing the neighbor list. Can be one of + ["brute", "cell"]. 
+ box : torch.Tensor + Size of the box shape (3,) or None + """ self.cutoff_upper = cutoff_upper self.max_num_pairs = max_num_pairs - self.loop = loop self.strategy = strategy + self.box = box def forward(self, pos, batch): """ @@ -128,7 +132,8 @@ def forward(self, pos, batch): cutoff=self.cutoff_upper, batch=batch, max_num_pairs=self.max_num_pairs, - check_errors=True + check_errors=True, + box_size=self.box ) return neighbors, distances, distance_vecs diff --git a/torchmdnet/neighbors/neighbors.cpp b/torchmdnet/neighbors/neighbors.cpp index 4ff69ce31..cb2bfe2d3 100644 --- a/torchmdnet/neighbors/neighbors.cpp +++ b/torchmdnet/neighbors/neighbors.cpp @@ -1,6 +1,6 @@ #include TORCH_LIBRARY(neighbors, m) { - m.def("get_neighbor_pairs(Tensor positions, Tensor batch, Scalar cutoff, Scalar max_num_pairs, bool check_errors) -> (Tensor neighbors, Tensor distances, Tensor distance_vecs)"); - m.def("get_neighbor_pairs_cell(Tensor positions, Tensor batch, Scalar cutoff, Scalar max_num_pairs, bool check_errors) -> (Tensor neighbors, Tensor distances, Tensor distance_vecs)"); + m.def("get_neighbor_pairs(Tensor positions, Tensor batch, Tensor box_size, Scalar cutoff, Scalar max_num_pairs, bool check_errors) -> (Tensor neighbors, Tensor distances, Tensor distance_vecs)"); + m.def("get_neighbor_pairs_cell(Tensor positions, Tensor batch, Tensor box_size,Scalar cutoff, Scalar max_num_pairs, bool check_errors) -> (Tensor neighbors, Tensor distances, Tensor distance_vecs)"); } diff --git a/torchmdnet/neighbors/neighbors_cpu.cpp b/torchmdnet/neighbors/neighbors_cpu.cpp index 3ee5e250b..954a0d5fe 100644 --- a/torchmdnet/neighbors/neighbors_cpu.cpp +++ b/torchmdnet/neighbors/neighbors_cpu.cpp @@ -16,7 +16,7 @@ using torch::Tensor; using torch::outer; using torch::round; -static tuple forward(const Tensor& positions, const Tensor& batch, const Scalar& cutoff, +static tuple forward(const Tensor& positions, const Tensor& batch, const Tensor &box_size, const Scalar& cutoff, const Scalar& 
max_num_pairs, bool checkErrors) { TORCH_CHECK(positions.dim() == 2, "Expected \"positions\" to have two dimensions"); TORCH_CHECK(positions.size(0) > 0, "Expected the 1nd dimension size of \"positions\" to be more than 0"); diff --git a/torchmdnet/neighbors/neighbors_cuda.cu b/torchmdnet/neighbors/neighbors_cuda.cu index 62991f22c..8d9f45918 100644 --- a/torchmdnet/neighbors/neighbors_cuda.cu +++ b/torchmdnet/neighbors/neighbors_cuda.cu @@ -113,7 +113,7 @@ static void checkInput(const Tensor& positions, const Tensor& batch) { batch.size(0) == positions.size(0), "Expected the 1st dimension size of \"batch\" to be the same as the 1st dimension size of \"positions\""); TORCH_CHECK(batch.is_contiguous(), "Expected \"batch\" to be contiguous"); - TORCH_CHECK(batch.dtype() == torch::kLong, "Expected \"batch\" to have torch::kLong dtype"); + TORCH_CHECK(batch.dtype() == torch::kInt32, "Expected \"batch\" to have torch::kInt32 dtype"); } class Autograd : public Function { @@ -194,7 +194,7 @@ public: }; TORCH_LIBRARY_IMPL(neighbors, AutogradCUDA, m) { - m.impl("get_neighbor_pairs", [](const Tensor& positions, const Tensor& batch, const Scalar& cutoff, + m.impl("get_neighbor_pairs", [](const Tensor& positions, const Tensor& batch, const Tensor& box_size, const Scalar& cutoff, const Scalar& max_num_pairs, bool checkErrors) { const tensor_list results = Autograd::apply(positions, batch, cutoff, max_num_pairs, checkErrors); return std::make_tuple(results[0], results[1], results[2]); diff --git a/torchmdnet/neighbors/neighbors_cuda_cell.cu b/torchmdnet/neighbors/neighbors_cuda_cell.cu index 3d3f7610e..b98cd06ec 100644 --- a/torchmdnet/neighbors/neighbors_cuda_cell.cu +++ b/torchmdnet/neighbors/neighbors_cuda_cell.cu @@ -1,11 +1,12 @@ #include #include +#include #include -#include -#include #include -#include #include +#include +#include +#include using c10::cuda::CUDAStreamGuard; using c10::cuda::getCurrentCUDAStream; using std::make_tuple; @@ -26,7 +27,8 @@ using 
torch::autograd::tensor_list; template using Accessor = PackedTensorAccessor32; -template inline Accessor get_accessor(const Tensor& tensor) { +template +inline Accessor get_accessor(const Tensor& tensor) { return tensor.packed_accessor32(); }; @@ -39,44 +41,10 @@ template <> __device__ __forceinline__ double sqrt_(double x) { }; template -__global__ void forward_kernel(const int32_t num_all_pairs, const Accessor positions, - const Accessor batch_offsets, const int32_t batch_index, - const scalar_t cutoff2, Accessor i_curr_pair, Accessor neighbors, - Accessor deltas, Accessor distances) { - const int32_t index = blockIdx.x * blockDim.x + threadIdx.x; - if (index >= num_all_pairs) - return; - - int32_t row = floor((sqrtf(8 * index + 1) + 1) / 2); - if (row * (row - 1) > 2 * index) - row--; - const int32_t position_offset = batch_index ? batch_offsets[batch_index - 1] : 0; - const int32_t column = (index - row * (row - 1) / 2) + position_offset; - row += position_offset; - scalar_t delta_x = positions[row][0] - positions[column][0]; - scalar_t delta_y = positions[row][1] - positions[column][1]; - scalar_t delta_z = positions[row][2] - positions[column][2]; - const scalar_t distance2 = delta_x * delta_x + delta_y * delta_y + delta_z * delta_z; - - if (distance2 > cutoff2) - return; - - const int32_t i_pair = atomicAdd(&i_curr_pair[0], 1); - // We handle too many neighbors outside of the kernel - if (i_pair < neighbors.size(1)) { - neighbors[0][i_pair] = row; - neighbors[1][i_pair] = column; - deltas[i_pair][0] = delta_x; - deltas[i_pair][1] = delta_y; - deltas[i_pair][2] = delta_z; - distances[i_pair] = sqrt_(distance2); - } -} - -template -__global__ void backward_kernel(const Accessor neighbors, const Accessor deltas, - const Accessor distances, const Accessor grad_distances, - Accessor grad_positions) { +__global__ void +backward_kernel(const Accessor neighbors, const Accessor deltas, + const Accessor distances, const Accessor grad_distances, + Accessor 
grad_positions) { const int32_t i_pair = blockIdx.x * blockDim.x + threadIdx.x; const int32_t num_pairs = neighbors.size(1); if (i_pair >= num_pairs) @@ -99,67 +67,391 @@ static void checkInput(const Tensor& positions, const Tensor& batch) { // Batch is a 1D tensor of size (N_atoms) // Batch is assumed to be sorted and starts at zero. // Batch is assumed to be contiguous - // Batch is assumed to be of type torch::kLong + // Batch is assumed to be of type torch::kInt32 // Batch is assumed to be non-negative // Each batch can have a different number of atoms TORCH_CHECK(positions.dim() == 2, "Expected \"positions\" to have two dimensions"); - TORCH_CHECK(positions.size(0) > 0, "Expected the 1nd dimension size of \"positions\" to be more than 0"); + TORCH_CHECK(positions.size(0) > 0, + "Expected the 1nd dimension size of \"positions\" to be more than 0"); TORCH_CHECK(positions.size(1) == 3, "Expected the 2nd dimension size of \"positions\" to be 3"); TORCH_CHECK(positions.is_contiguous(), "Expected \"positions\" to be contiguous"); TORCH_CHECK(batch.dim() == 1, "Expected \"batch\" to have one dimension"); - TORCH_CHECK( - batch.size(0) == positions.size(0), - "Expected the 1st dimension size of \"batch\" to be the same as the 1st dimension size of \"positions\""); + TORCH_CHECK(batch.size(0) == positions.size(0), + "Expected the 1st dimension size of \"batch\" to be the same as the 1st dimension " + "size of \"positions\""); TORCH_CHECK(batch.is_contiguous(), "Expected \"batch\" to be contiguous"); - TORCH_CHECK(batch.dtype() == torch::kLong, "Expected \"batch\" to have torch::kLong dtype"); + TORCH_CHECK(batch.dtype() == torch::kInt32, "Expected \"batch\" to be of type torch::kInt32"); +} + +// Encodes an integer lower than 1024 as a 32 bit integer by filling every third bit. 
+inline __host__ __device__ uint encodeMorton(const uint& i) { + uint x = i; + x &= 0x3ff; + x = (x | x << 16) & 0x30000ff; + x = (x | x << 8) & 0x300f00f; + x = (x | x << 4) & 0x30c30c3; + x = (x | x << 2) & 0x9249249; + return x; +} + +// Fuse three 10 bit numbers in 32 bits, producing a Z order Morton hash +inline __host__ __device__ uint hashMorton(int3 ci) { + return encodeMorton(ci.x) | (encodeMorton(ci.y) << 1) | (encodeMorton(ci.z) << 2); +} + +// Use Minimum Image Convention to take a point to the unit cell +__device__ auto takeToUnitCell(float3 p, float3 box_size) { + p.x = p.x - floorf(p.x / box_size.x + float(0.5)) * box_size.x; + p.y = p.y - floorf(p.y / box_size.y + float(0.5)) * box_size.y; + p.z = p.z - floorf(p.z / box_size.z + float(0.5)) * box_size.z; + return p; +} + +// Get the number of cells in each dimension +__host__ __device__ int3 getNumberCells(float3 box_size, float cutoff) { + int3 cell_dim = make_int3(box_size.x / cutoff, box_size.y / cutoff, box_size.z / cutoff); + // Minimum 3 cells in each dimension + cell_dim.x = thrust::max(cell_dim.x, 3); + cell_dim.y = thrust::max(cell_dim.y, 3); + cell_dim.z = thrust::max(cell_dim.z, 3); +// In the host, throw if there are more than 1024 cells in any dimension +#ifndef __CUDA_ARCH__ + if (cell_dim.x > 1024 || cell_dim.y > 1024 || cell_dim.z > 1024) { + throw std::runtime_error("Too many cells in one dimension. Maximum is 1024"); + } +#endif + return cell_dim; +} + +// Get the cell coordinates of a point +__device__ int3 getCell(float3 p, float3 box_size, float cutoff) { + p = takeToUnitCell(p, box_size); + int cx = floorf(p.x / cutoff); + int cy = floorf(p.y / cutoff); + int cz = floorf(p.z / cutoff); + int3 cell_dim = getNumberCells(box_size, cutoff); + if (cx == cell_dim.x) + cx = 0; + if (cy == cell_dim.y) + cy = 0; + if (cz == cell_dim.z) + cz = 0; + return make_int3(cx, cy, cz); +} + +// Assign a hash to each atom based on its position and batch. 
+// This hash is such that atoms in the same cell and batch have the same hash. +template +__global__ void assignHash(const Accessor positions, uint64_t* hash_keys, + Accessor hash_values, const Accessor batch, + float3 box_size, float cutoff, int32_t num_atoms) { + const int32_t i_atom = blockIdx.x * blockDim.x + threadIdx.x; + if (i_atom >= num_atoms) + return; + const int32_t i_batch = batch[i_atom]; + // Move to the unit cell + float3 pi = make_float3(positions[i_atom][0], positions[i_atom][1], positions[i_atom][2]); + auto ci = getCell(pi, box_size, cutoff); + // Calculate the hash + const int32_t hash = hashMorton(ci); + // Create a hash combining the Morton hash and the batch index, so that atoms in the same batch + // are contiguous + const int64_t hash_final = (static_cast(i_batch) << 32) | hash; + hash_keys[i_atom] = hash_final; + hash_values[i_atom] = i_atom; +} + +// Adaptor from pytorch cached allocator to thrust +template class CudaAllocator { +public: + using value_type = T; + CudaAllocator() { + } + T* allocate(std::ptrdiff_t num_elements) { + return static_cast(at::cuda::getCUDADeviceAllocator()->raw_allocate(num_elements * sizeof(T))); + } + void deallocate(T* ptr, size_t) { + at::cuda::getCUDADeviceAllocator()->raw_deallocate(ptr); + } +}; + +// Sort the positions by hash, based on the cell assigned to each position and the batch index +static auto sortPositionsByHash(const Tensor& positions, const Tensor& batch, float3 box_size, + float cutoff) { + const int num_atoms = positions.size(0); + const auto options = positions.options(); + thrust::device_vector hash_keys(num_atoms); + Tensor hash_values = empty({num_atoms}, options.dtype(torch::kInt32)); + const int threads = 128; + const int blocks = (num_atoms + threads - 1) / threads; + auto stream = at::cuda::getCurrentCUDAStream(); + AT_DISPATCH_FLOATING_TYPES(positions.scalar_type(), "assignHash", [&] { + assignHash<<>>( + get_accessor(positions), thrust::raw_pointer_cast(hash_keys.data()), + 
get_accessor(hash_values), get_accessor(batch), box_size, + cutoff, num_atoms); + }); + + std::cout << "hash_values: " << hash_values << std::endl; + std::cout << "hash_keys: " << std::endl; + for (int i = 0; i < num_atoms; i++) { + uint64_t hi = hash_keys[i]; + std::cout << std::bitset<64>(hi) << std::endl; + } + + // Sort positions by hash_values using thrust + thrust::device_ptr index_ptr(hash_values.data_ptr()); + // pytorch allocator + CudaAllocator allocator; + // Adapted for thrust + thrust::sort_by_key(thrust::cuda::par.on(stream), hash_keys.begin(), hash_keys.end(), + index_ptr); + std::cout << "sorted hash_values: " << hash_values << std::endl; + // Print values of hash_keys in binary + std::cout << "sorted hash_keys: " << std::endl; + for (int i = 0; i < num_atoms; i++) { + uint64_t hi = hash_keys[i]; + std::cout << std::bitset<64>(hi) << std::endl; + } + + Tensor sorted_positions = positions.index_select(0, hash_values); + return std::make_tuple(sorted_positions, hash_values); +} + +__device__ int getCellIndex(int3 cell, int3 cell_dim) { + return cell.x + cell_dim.x * (cell.y + cell_dim.y * cell.z); +} + +template +__global__ void fillCellOffsetsD(const Accessor sorted_positions, + const Accessor sorted_indices, + Accessor cell_start, Accessor cell_end, + const Accessor batch, float3 box_size, float cutoff) { + const int32_t i_atom = blockIdx.x * blockDim.x + threadIdx.x; + if (i_atom >= sorted_positions.size(0)) + return; + const int32_t i_batch = batch[sorted_indices[i_atom]]; + const float3 pi = make_float3(sorted_positions[i_atom][0], sorted_positions[i_atom][1], + sorted_positions[i_atom][2]); + const int3 cell_dim = getNumberCells(box_size, cutoff); + const int icell = getCellIndex(getCell(pi, box_size, cutoff), cell_dim); + int im1_cell; + if (i_atom > 0) { + int im1 = i_atom - 1; + const float3 pim1 = make_float3(sorted_positions[im1][0], sorted_positions[im1][1], + sorted_positions[im1][2]); + im1_cell = getCellIndex(getCell(pim1, box_size, 
cutoff), cell_dim); + } else { + im1_cell = 0; + } + if (icell != im1_cell || i_atom == 0) { + int n_cells = cell_start.size(0); + if(icell>=n_cells or im1_cell>=n_cells) { + printf("icell: %d, im1_cell: %d, n_cells: %d\n", icell, im1_cell, n_cells); + return; + } + cell_start[icell][i_batch] = i_atom; + if (i_atom > 0) + cell_end[im1_cell][i_batch] = i_atom; + } + if(i_atom == sorted_positions.size(0) - 1) { + cell_end[icell][i_batch] = i_atom + 1; + } +} + +// Fill the cell offsets for each batch, identifying the start and end of each cell for each batch +// in the sorted positions +static auto fillCellOffsets(const Tensor& sorted_positions, + const Tensor& sorted_indices, + const Tensor& batch, float3 box_size, + float cutoff) { + const TensorOptions options = sorted_positions.options(); + const int num_batches = batch[-1].item() + 1; + const int3 cell_dim = getNumberCells(box_size, cutoff); + const int num_cells = cell_dim.x * cell_dim.y * cell_dim.z; + const Tensor cell_start = full({num_cells, num_batches}, -1, options.dtype(torch::kInt)); + const Tensor cell_end = empty({num_cells, num_batches}, options.dtype(torch::kInt)); + std::cerr<<"num_cells: "<>>( + get_accessor(sorted_positions), + get_accessor(sorted_indices), + get_accessor(cell_start), + get_accessor(cell_end), get_accessor(batch), box_size, cutoff); + }); + std::cerr<<"cell_start: "<= cell_dim.x) + periodic_cell.x -= cell_dim.x; + if (cell.y < 0) + periodic_cell.y += cell_dim.y; + if (cell.y >= cell_dim.y) + periodic_cell.y -= cell_dim.y; + if (cell.z < 0) + periodic_cell.z += cell_dim.z; + if (cell.z >= cell_dim.z) + periodic_cell.z -= cell_dim.z; + return periodic_cell; +} + +// Traverse the cell list for each atom and find the neighbors +template +__global__ void +forward_kernel(const Accessor sorted_positions, + const Accessor original_index, const Accessor batch, + const Accessor cell_start, const Accessor cell_end, + Accessor neighbors, Accessor deltas, + Accessor distances, Accessor 
i_curr_pair, int num_atoms, + int num_pairs, float3 box_size, float cutoff) { + const int32_t i_atom = blockIdx.x * blockDim.x + threadIdx.x; + if (i_atom >= num_atoms) + return; + // Each batch has its own cell list, starting at cell_start[0][i_batch] and ending at + // cell_start[ncells-1][i_batch] Each thread is responsible for a single atom Each thread will + // loop over all atoms in the cell list of the current batch + const int ori = original_index[i_atom]; + const int i_batch = batch[ori]; + float3 pi = make_float3(sorted_positions[i_atom][0], sorted_positions[i_atom][1], + sorted_positions[i_atom][2]); + const int3 cell_i = getCell(pi, box_size, cutoff); + const int3 cell_dim = getNumberCells(box_size, cutoff); + const int i_cell_index = getCellIndex(cell_i, cell_dim); + // Loop over the 27 cells around the current cell + for (int i = 0; i < 27; i++) { + auto cell_j = cell_i; + cell_j.x += i % 3 - 1; + cell_j.y += (i / 3) % 3 - 1; + cell_j.z += i / 9 - 1; + cell_j = getPeriodicCell(cell_j, cell_dim); + int icellj = getCellIndex(cell_j, cell_dim); + const int firstParticle = cell_start[icellj][i_batch]; + if (firstParticle != -1) { // Continue only if there are particles in this cell + // Index of the last particle in the cell's list + const int lastParticle = cell_end[icellj][i_batch]; + const int nincell = lastParticle - firstParticle; + for (int j = 0; j < nincell; j++) { + int cur_j = j + firstParticle; + if (cur_j < i_atom) { + float3 pj = make_float3(sorted_positions[cur_j][0], sorted_positions[cur_j][1], + sorted_positions[cur_j][2]); + const scalar_t dx = pi.x - pj.x; + const scalar_t dy = pi.y - pj.y; + const scalar_t dz = pi.z - pj.z; + const scalar_t distance2 = dx * dx + dy * dy + dz * dz; + if (distance2 < cutoff * cutoff) { + const int32_t i_pair = atomicAdd(&i_curr_pair[0], 1); + // We handle too many neighbors outside of the kernel + if (i_pair < neighbors.size(1)) { + neighbors[0][i_pair] = ori; + neighbors[1][i_pair] = 
original_index[cur_j]; + deltas[i_pair][0] = dx; + deltas[i_pair][1] = dy; + deltas[i_pair][2] = dz; + distances[i_pair] = sqrt_(distance2); + } + } + } // endfor + } // endif + } // endfor + } } class Autograd : public Function { public: - static tensor_list forward(AutogradContext* ctx, const Tensor& positions, const Tensor& batch, const Scalar& cutoff, + static tensor_list forward(AutogradContext* ctx, const Tensor& positions, const Tensor& batch, + const Tensor& box_size, const Scalar& cutoff, const Scalar& max_num_pairs, bool checkErrors) { + + // The algorithm for the cell list construction can be summarized in three separate steps: + // 1. Hash (label) the particles according to the cell (bin) they lie in. + // 2. Sort the particles and hashes using the hashes as the ordering label + // (technically this is known as sorting by key). So that particles with positions + // lying in the same cell become contiguous in memory. + // 3. Identify where each cell starts and ends in the sorted particle positions + // array. 
checkInput(positions, batch); + TORCH_CHECK(box_size.size(0) == 3, "Expected \"box_size\" to have 3 elements"); const auto max_num_pairs_ = max_num_pairs.toLong(); TORCH_CHECK(max_num_pairs_ > 0, "Expected \"max_num_neighbors\" to be positive"); const int num_atoms = positions.size(0); const int num_pairs = max_num_pairs_; const TensorOptions options = positions.options(); - const auto stream = getCurrentCUDAStream(positions.get_device()); - const Tensor num_atoms_per_batch = torch::bincount(batch); - const int n_batches = num_atoms_per_batch.size(0); - const Tensor batch_offsets = torch::cumsum(num_atoms_per_batch, 0, torch::kInt32).to(positions.device()); + // Steps 1 and 2 + float3 box_size_ = make_float3(box_size[0].item(), box_size[1].item(), + box_size[2].item()); + float cutoff_ = cutoff.toFloat(); + Tensor sorted_positions, hash_values; + std::cerr << "before sortPositionsByHash" << std::endl; + std::tie(sorted_positions, hash_values) = + sortPositionsByHash(positions, batch, box_size_, cutoff_); + cudaDeviceSynchronize(); + std::cerr << "after sortPositionsByHash" << std::endl; + // Step 3 + Tensor cell_start, cell_end; + std::cerr << "before fillCellOffsets" << std::endl; + std::tie(cell_start, cell_end) = + fillCellOffsets(sorted_positions, hash_values, + batch, box_size_, cutoff_); + cudaDeviceSynchronize(); + std::cerr << "after fillCellOffsets" << std::endl; + cudaDeviceSynchronize(); + std::cerr << "Number of pairs: " << num_pairs << std::endl; + cudaDeviceSynchronize(); + std::cerr <<"Allocating memory for neighbors" << std::endl; const Tensor neighbors = full({2, num_pairs}, -1, options.dtype(kInt32)); + std::cerr << "Allocating memory for deltas" << std::endl; const Tensor deltas = empty({num_pairs, 3}, options); + std::cerr << "Allocating memory for distances" << std::endl; const Tensor distances = full(num_pairs, 0, options); const Tensor i_curr_pair = zeros(1, options.dtype(kInt32)); - { + cudaDeviceSynchronize(); + std::cerr << "before 
forward_kernel" << std::endl; + const auto stream = getCurrentCUDAStream(positions.get_device()); + { // Use the cell list for each batch to find the neighbors const CUDAStreamGuard guard(stream); - for (int i = 0; i < n_batches; i++) { - const int num_atoms = num_atoms_per_batch[i].item(); - const int num_all_pairs = num_atoms * (num_atoms - 1) / 2; - const int num_threads = 128; - const int num_blocks = max((num_all_pairs + num_threads - 1) / num_threads, 1); - AT_DISPATCH_FLOATING_TYPES(positions.scalar_type(), "get_neighbor_pairs_forward", [&]() { - const scalar_t cutoff_ = cutoff.to(); - TORCH_CHECK(cutoff_ > 0, "Expected \"cutoff\" to be positive"); - forward_kernel<<>>( - num_all_pairs, get_accessor(positions), get_accessor(batch_offsets), i, - cutoff_ * cutoff_, get_accessor(i_curr_pair), get_accessor(neighbors), - get_accessor(deltas), get_accessor(distances)); - }); - } + AT_DISPATCH_FLOATING_TYPES(positions.scalar_type(), "forward", [&] { + const int threads = 128; + const int blocks = (num_atoms + threads - 1) / threads; + forward_kernel<<>>( + get_accessor(sorted_positions), + get_accessor(hash_values), get_accessor(batch), + get_accessor(cell_start), get_accessor(cell_end), + get_accessor(neighbors), get_accessor(deltas), + get_accessor(distances), get_accessor(i_curr_pair), + num_atoms, num_pairs, box_size_, cutoff_); + }); } - // Synchronize and check the number of pairs found. Note that this is incompatible with CUDA graphs + cudaDeviceSynchronize(); + std::cerr << "after forward_kernel" << std::endl; + // Synchronize and check the number of pairs found. Note that this is incompatible with CUDA + // graphs if (checkErrors) { - int num_found_pairs = i_curr_pair.item(); + std::cout<<"checking errors"<(); TORCH_CHECK(num_found_pairs <= max_num_pairs_, - "Too many neighbor pairs found. Maximum is " + std::to_string(max_num_pairs_), + "Too many neighbor pairs found. 
Maximum is " + + std::to_string(max_num_pairs_), " but found " + std::to_string(num_found_pairs)); + std::cout<<"no errors"<()}); - deltas.resize_({i_curr_pair[0].item(), 3}); - distances.resize_(i_curr_pair[0].item()); + std::cout<<"before resize"<()}); + deltas.resize_({i_curr_pair[0].item(), 3}); + distances.resize_(i_curr_pair[0].item()); + std::cout<<"after resize"<save_for_backward({neighbors, deltas, distances}); ctx->saved_data["num_atoms"] = num_atoms; return {neighbors, deltas, distances}; @@ -180,22 +472,25 @@ public: const Tensor distances = data[2]; const Tensor grad_positions = zeros({num_atoms, 3}, grad_distances.options()); - AT_DISPATCH_FLOATING_TYPES(grad_distances.scalar_type(), "get_neighbor_pairs_backward", [&]() { - const CUDAStreamGuard guard(stream); - backward_kernel<<>>( - get_accessor(neighbors), get_accessor(deltas), - get_accessor(distances), get_accessor(grad_distances), - get_accessor(grad_positions)); - }); + AT_DISPATCH_FLOATING_TYPES( + grad_distances.scalar_type(), "get_neighbor_pairs_backward", [&]() { + const CUDAStreamGuard guard(stream); + backward_kernel<<>>( + get_accessor(neighbors), get_accessor(deltas), + get_accessor(distances), get_accessor(grad_distances), + get_accessor(grad_positions)); + }); return {grad_positions, Tensor(), Tensor()}; } }; TORCH_LIBRARY_IMPL(neighbors, AutogradCUDA, m) { - m.impl("get_neighbor_pairs_cell", [](const Tensor& positions, const Tensor& batch, const Scalar& cutoff, - const Scalar& max_num_pairs, bool checkErrors) { - const tensor_list results = Autograd::apply(positions, batch, cutoff, max_num_pairs, checkErrors); - return std::make_tuple(results[0], results[1], results[2]); - }); + m.impl("get_neighbor_pairs_cell", + [](const Tensor& positions, const Tensor& batch, const Tensor& box_size, + const Scalar& cutoff, const Scalar& max_num_pairs, bool checkErrors) { + const tensor_list results = + Autograd::apply(positions, batch, box_size, cutoff, max_num_pairs, checkErrors); + return 
std::make_tuple(results[0], results[1], results[2]); + }); } From cb8a5865c1db557449678b122b9367ae08961912 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Tue, 2 May 2023 20:10:06 +0200 Subject: [PATCH 07/76] Update test --- tests/test_neighbors.py | 29 +++++++++++++++++++++-------- 1 file changed, 21 insertions(+), 8 deletions(-) diff --git a/tests/test_neighbors.py b/tests/test_neighbors.py index 3b19811bc..2a42564c4 100644 --- a/tests/test_neighbors.py +++ b/tests/test_neighbors.py @@ -9,20 +9,29 @@ def sort_neighbors(neighbors, deltas, distances): return neighbors[:, i_sorted], deltas[i_sorted], distances[i_sorted] @pytest.mark.parametrize("device", ["cpu", "cuda"]) -@pytest.mark.parametrize("strategy", ["brute", "cell"]) -@pytest.mark.parametrize("n_batches", [1, 2, 3,100]) -@pytest.mark.parametrize("cutoff", [0.1, 1.5, 1000.0]) +@pytest.mark.parametrize("strategy", ["cell"]) +@pytest.mark.parametrize("n_batches", [1]) +@pytest.mark.parametrize("cutoff", [1]) def test_neighbors(device, strategy, n_batches, cutoff): if device == "cuda" and not torch.cuda.is_available(): pytest.skip("CUDA not available") - n_atoms_per_batch = np.random.randint(2, 100, size=n_batches) - batch = torch.tensor([i for i in range(n_batches) for j in range(n_atoms_per_batch[i])]) + + np.random.seed(1234) + torch.manual_seed(4321) + n_atoms_per_batch = np.random.randint(2, 10, size=n_batches) + batch = torch.tensor([i for i in range(n_batches) for j in range(n_atoms_per_batch[i])], device=device, dtype=torch.int) cumsum = np.cumsum( np.concatenate([[0], n_atoms_per_batch])) - pos = torch.randn(cumsum[-1], 3, device=device) + lbox=10.0 + pos = torch.rand(cumsum[-1], 3, device=device)*lbox #Ensure there is at least one pair pos[0,:] = torch.zeros(3) pos[1,:] = torch.zeros(3) pos.requires_grad = True + print("batch") + print(batch) + print("pos") + print(pos) + ref_neighbors = np.concatenate([np.tril_indices(n_atoms_per_batch[i], -1)+cumsum[i] for i in range(n_batches)], axis=1) pos_np = 
pos.cpu().detach().numpy() ref_distances = np.linalg.norm(pos_np[ref_neighbors[0]] - pos_np[ref_neighbors[1]], axis=-1) @@ -34,13 +43,17 @@ def test_neighbors(device, strategy, n_batches, cutoff): ref_distance_vecs = ref_distance_vecs[mask] ref_distances = ref_distances[mask] max_num_pairs = ref_neighbors.shape[1] - - nl = DistanceCellList(cutoff_upper=cutoff, max_num_pairs=max_num_pairs, strategy=strategy) + box = torch.tensor([lbox, lbox, lbox]) + nl = DistanceCellList(cutoff_upper=cutoff, max_num_pairs=max_num_pairs, strategy=strategy, box=box) batch.to(device) neighbors, distances, distance_vecs = nl(pos, batch) neighbors = neighbors.cpu().detach().numpy() distance_vecs = distance_vecs.cpu().detach().numpy() distances = distances.cpu().detach().numpy() + print("neighbors") + print(neighbors) + print("ref_neighbors") + print(ref_neighbors) assert neighbors.shape == (2, max_num_pairs) assert distances.shape == (max_num_pairs,) assert distance_vecs.shape == (max_num_pairs, 3) From 0700aae5afb2ef5d38d93a106d0fed7b3c0bf981 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Tue, 2 May 2023 23:38:11 +0200 Subject: [PATCH 08/76] Working implementation --- torchmdnet/neighbors/neighbors_cuda_cell.cu | 112 ++++++-------------- 1 file changed, 33 insertions(+), 79 deletions(-) diff --git a/torchmdnet/neighbors/neighbors_cuda_cell.cu b/torchmdnet/neighbors/neighbors_cuda_cell.cu index b98cd06ec..1636ad16b 100644 --- a/torchmdnet/neighbors/neighbors_cuda_cell.cu +++ b/torchmdnet/neighbors/neighbors_cuda_cell.cu @@ -1,6 +1,6 @@ +#include #include #include -#include #include #include #include @@ -127,9 +127,9 @@ __host__ __device__ int3 getNumberCells(float3 box_size, float cutoff) { // Get the cell coordinates of a point __device__ int3 getCell(float3 p, float3 box_size, float cutoff) { p = takeToUnitCell(p, box_size); - int cx = floorf(p.x / cutoff); - int cy = floorf(p.y / cutoff); - int cz = floorf(p.z / cutoff); + int cx = floorf((p.x + float(0.5) * box_size.x) / cutoff); 
+ int cy = floorf((p.y + float(0.5) * box_size.y) / cutoff); + int cz = floorf((p.z + float(0.5) * box_size.z) / cutoff); int3 cell_dim = getNumberCells(box_size, cutoff); if (cx == cell_dim.x) cx = 0; @@ -155,9 +155,9 @@ __global__ void assignHash(const Accessor positions, uint64_t* hash auto ci = getCell(pi, box_size, cutoff); // Calculate the hash const int32_t hash = hashMorton(ci); - // Create a hash combining the Morton hash and the batch index, so that atoms in the same batch + // Create a hash combining the Morton hash and the batch index, so that atoms in the same cell // are contiguous - const int64_t hash_final = (static_cast(i_batch) << 32) | hash; + const int64_t hash_final = (static_cast(hash) << 32) | i_batch; hash_keys[i_atom] = hash_final; hash_values[i_atom] = i_atom; } @@ -169,10 +169,11 @@ public: CudaAllocator() { } T* allocate(std::ptrdiff_t num_elements) { - return static_cast(at::cuda::getCUDADeviceAllocator()->raw_allocate(num_elements * sizeof(T))); + return static_cast( + at::cuda::getCUDADeviceAllocator()->raw_allocate(num_elements * sizeof(T))); } void deallocate(T* ptr, size_t) { - at::cuda::getCUDADeviceAllocator()->raw_deallocate(ptr); + at::cuda::getCUDADeviceAllocator()->raw_deallocate(ptr); } }; @@ -192,14 +193,6 @@ static auto sortPositionsByHash(const Tensor& positions, const Tensor& batch, fl get_accessor(hash_values), get_accessor(batch), box_size, cutoff, num_atoms); }); - - std::cout << "hash_values: " << hash_values << std::endl; - std::cout << "hash_keys: " << std::endl; - for (int i = 0; i < num_atoms; i++) { - uint64_t hi = hash_keys[i]; - std::cout << std::bitset<64>(hi) << std::endl; - } - // Sort positions by hash_values using thrust thrust::device_ptr index_ptr(hash_values.data_ptr()); // pytorch allocator @@ -207,14 +200,6 @@ static auto sortPositionsByHash(const Tensor& positions, const Tensor& batch, fl // Adapted for thrust thrust::sort_by_key(thrust::cuda::par.on(stream), hash_keys.begin(), hash_keys.end(), 
index_ptr); - std::cout << "sorted hash_values: " << hash_values << std::endl; - // Print values of hash_keys in binary - std::cout << "sorted hash_keys: " << std::endl; - for (int i = 0; i < num_atoms; i++) { - uint64_t hi = hash_keys[i]; - std::cout << std::bitset<64>(hi) << std::endl; - } - Tensor sorted_positions = positions.index_select(0, hash_values); return std::make_tuple(sorted_positions, hash_values); } @@ -225,13 +210,13 @@ __device__ int getCellIndex(int3 cell, int3 cell_dim) { template __global__ void fillCellOffsetsD(const Accessor sorted_positions, - const Accessor sorted_indices, - Accessor cell_start, Accessor cell_end, + const Accessor sorted_indices, + Accessor cell_start, Accessor cell_end, const Accessor batch, float3 box_size, float cutoff) { const int32_t i_atom = blockIdx.x * blockDim.x + threadIdx.x; if (i_atom >= sorted_positions.size(0)) return; - const int32_t i_batch = batch[sorted_indices[i_atom]]; + // const int32_t i_batch = batch[sorted_indices[i_atom]]; const float3 pi = make_float3(sorted_positions[i_atom][0], sorted_positions[i_atom][1], sorted_positions[i_atom][2]); const int3 cell_dim = getNumberCells(box_size, cutoff); @@ -246,46 +231,34 @@ __global__ void fillCellOffsetsD(const Accessor sorted_positions, im1_cell = 0; } if (icell != im1_cell || i_atom == 0) { - int n_cells = cell_start.size(0); - if(icell>=n_cells or im1_cell>=n_cells) { - printf("icell: %d, im1_cell: %d, n_cells: %d\n", icell, im1_cell, n_cells); - return; - } - cell_start[icell][i_batch] = i_atom; + int n_cells = cell_start.size(0); + cell_start[icell] = i_atom; if (i_atom > 0) - cell_end[im1_cell][i_batch] = i_atom; + cell_end[im1_cell] = i_atom; } - if(i_atom == sorted_positions.size(0) - 1) { - cell_end[icell][i_batch] = i_atom + 1; + if (i_atom == sorted_positions.size(0) - 1) { + cell_end[icell] = i_atom + 1; } } // Fill the cell offsets for each batch, identifying the start and end of each cell for each batch // in the sorted positions -static auto 
fillCellOffsets(const Tensor& sorted_positions, - const Tensor& sorted_indices, - const Tensor& batch, float3 box_size, - float cutoff) { +static auto fillCellOffsets(const Tensor& sorted_positions, const Tensor& sorted_indices, + const Tensor& batch, float3 box_size, float cutoff) { const TensorOptions options = sorted_positions.options(); - const int num_batches = batch[-1].item() + 1; const int3 cell_dim = getNumberCells(box_size, cutoff); const int num_cells = cell_dim.x * cell_dim.y * cell_dim.z; - const Tensor cell_start = full({num_cells, num_batches}, -1, options.dtype(torch::kInt)); - const Tensor cell_end = empty({num_cells, num_batches}, options.dtype(torch::kInt)); - std::cerr<<"num_cells: "<>>( - get_accessor(sorted_positions), - get_accessor(sorted_indices), - get_accessor(cell_start), - get_accessor(cell_end), get_accessor(batch), box_size, cutoff); + get_accessor(sorted_positions), get_accessor(sorted_indices), + get_accessor(cell_start), get_accessor(cell_end), + get_accessor(batch), box_size, cutoff); }); - std::cerr<<"cell_start: "< __global__ void forward_kernel(const Accessor sorted_positions, const Accessor original_index, const Accessor batch, - const Accessor cell_start, const Accessor cell_end, + const Accessor cell_start, const Accessor cell_end, Accessor neighbors, Accessor deltas, Accessor distances, Accessor i_curr_pair, int num_atoms, int num_pairs, float3 box_size, float cutoff) { @@ -337,14 +310,16 @@ forward_kernel(const Accessor sorted_positions, cell_j.z += i / 9 - 1; cell_j = getPeriodicCell(cell_j, cell_dim); int icellj = getCellIndex(cell_j, cell_dim); - const int firstParticle = cell_start[icellj][i_batch]; + const int firstParticle = cell_start[icellj]; if (firstParticle != -1) { // Continue only if there are particles in this cell // Index of the last particle in the cell's list - const int lastParticle = cell_end[icellj][i_batch]; + const int lastParticle = cell_end[icellj]; const int nincell = lastParticle - firstParticle; 
for (int j = 0; j < nincell; j++) { int cur_j = j + firstParticle; - if (cur_j < i_atom) { + int orj = original_index[cur_j]; + int j_batch = batch[orj]; + if (cur_j < i_atom and j_batch == i_batch) { float3 pj = make_float3(sorted_positions[cur_j][0], sorted_positions[cur_j][1], sorted_positions[cur_j][2]); const scalar_t dx = pi.x - pj.x; @@ -394,32 +369,17 @@ public: box_size[2].item()); float cutoff_ = cutoff.toFloat(); Tensor sorted_positions, hash_values; - std::cerr << "before sortPositionsByHash" << std::endl; std::tie(sorted_positions, hash_values) = sortPositionsByHash(positions, batch, box_size_, cutoff_); cudaDeviceSynchronize(); - std::cerr << "after sortPositionsByHash" << std::endl; - // Step 3 Tensor cell_start, cell_end; - std::cerr << "before fillCellOffsets" << std::endl; std::tie(cell_start, cell_end) = - fillCellOffsets(sorted_positions, hash_values, - batch, box_size_, cutoff_); - cudaDeviceSynchronize(); - std::cerr << "after fillCellOffsets" << std::endl; - cudaDeviceSynchronize(); - std::cerr << "Number of pairs: " << num_pairs << std::endl; - cudaDeviceSynchronize(); - std::cerr <<"Allocating memory for neighbors" << std::endl; + fillCellOffsets(sorted_positions, hash_values, batch, box_size_, cutoff_); const Tensor neighbors = full({2, num_pairs}, -1, options.dtype(kInt32)); - std::cerr << "Allocating memory for deltas" << std::endl; const Tensor deltas = empty({num_pairs, 3}, options); - std::cerr << "Allocating memory for distances" << std::endl; const Tensor distances = full(num_pairs, 0, options); const Tensor i_curr_pair = zeros(1, options.dtype(kInt32)); - cudaDeviceSynchronize(); - std::cerr << "before forward_kernel" << std::endl; - const auto stream = getCurrentCUDAStream(positions.get_device()); + const auto stream = getCurrentCUDAStream(positions.get_device()); { // Use the cell list for each batch to find the neighbors const CUDAStreamGuard guard(stream); AT_DISPATCH_FLOATING_TYPES(positions.scalar_type(), "forward", [&] { @@ 
-428,30 +388,24 @@ public: forward_kernel<<>>( get_accessor(sorted_positions), get_accessor(hash_values), get_accessor(batch), - get_accessor(cell_start), get_accessor(cell_end), + get_accessor(cell_start), get_accessor(cell_end), get_accessor(neighbors), get_accessor(deltas), get_accessor(distances), get_accessor(i_curr_pair), num_atoms, num_pairs, box_size_, cutoff_); }); } - cudaDeviceSynchronize(); - std::cerr << "after forward_kernel" << std::endl; // Synchronize and check the number of pairs found. Note that this is incompatible with CUDA // graphs if (checkErrors) { - std::cout<<"checking errors"<(); TORCH_CHECK(num_found_pairs <= max_num_pairs_, "Too many neighbor pairs found. Maximum is " + std::to_string(max_num_pairs_), " but found " + std::to_string(num_found_pairs)); - std::cout<<"no errors"<()}); deltas.resize_({i_curr_pair[0].item(), 3}); distances.resize_(i_curr_pair[0].item()); - std::cout<<"after resize"<save_for_backward({neighbors, deltas, distances}); ctx->saved_data["num_atoms"] = num_atoms; return {neighbors, deltas, distances}; From 658eb0a79004e06378ed1fcda3d24c471c28c527 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Wed, 3 May 2023 11:42:06 +0200 Subject: [PATCH 09/76] Adapt to work with scalar_t instead of float/double Document cell list implementation Clean up a bit --- torchmdnet/neighbors/neighbors_cuda_cell.cu | 264 ++++++++++++++------ 1 file changed, 182 insertions(+), 82 deletions(-) diff --git a/torchmdnet/neighbors/neighbors_cuda_cell.cu b/torchmdnet/neighbors/neighbors_cuda_cell.cu index 1636ad16b..6bd1a4353 100644 --- a/torchmdnet/neighbors/neighbors_cuda_cell.cu +++ b/torchmdnet/neighbors/neighbors_cuda_cell.cu @@ -1,3 +1,6 @@ +/* Raul P. Pelaez 2023. Batched cell list neighbor list implementation for CUDA. 
+ + */ #include #include #include @@ -24,6 +27,10 @@ using torch::autograd::AutogradContext; using torch::autograd::Function; using torch::autograd::tensor_list; +template struct scalar3 { + scalar_t x, y, z; +}; + template using Accessor = PackedTensorAccessor32; @@ -45,6 +52,11 @@ __global__ void backward_kernel(const Accessor neighbors, const Accessor deltas, const Accessor distances, const Accessor grad_distances, Accessor grad_positions) { + // What the backward kernel does: + // For each pair of atoms, it calculates the gradient of the distance between them + // with respect to the positions of the atoms. + // The gradient is then added to the gradient of the positions. + // The gradient of the distance is calculated using the chain rule: const int32_t i_pair = blockIdx.x * blockDim.x + threadIdx.x; const int32_t num_pairs = neighbors.size(1); if (i_pair >= num_pairs) @@ -61,7 +73,6 @@ backward_kernel(const Accessor neighbors, const Accessor +__device__ auto takeToUnitCell(scalar3 p, scalar3 box_size) { + p.x = p.x - floorf(p.x / box_size.x + scalar_t(0.5)) * box_size.x; + p.y = p.y - floorf(p.y / box_size.y + scalar_t(0.5)) * box_size.y; + p.z = p.z - floorf(p.z / box_size.z + scalar_t(0.5)) * box_size.z; return p; } -// Get the number of cells in each dimension -__host__ __device__ int3 getNumberCells(float3 box_size, float cutoff) { +/* + * @brief Calculates the cell dimensions for a given box size and cutoff + * @param box_size The box size + * @param cutoff The cutoff + * @return The cell dimensions + */ +template +__host__ __device__ int3 getCellDimensions(scalar3 box_size, scalar_t cutoff) { int3 cell_dim = make_int3(box_size.x / cutoff, box_size.y / cutoff, box_size.z / cutoff); // Minimum 3 cells in each dimension cell_dim.x = thrust::max(cell_dim.x, 3); @@ -124,13 +157,23 @@ __host__ __device__ int3 getNumberCells(float3 box_size, float cutoff) { return cell_dim; } -// Get the cell coordinates of a point -__device__ int3 getCell(float3 p, float3 
box_size, float cutoff) { +/* + * @brief Get the cell index of a point + * @param p The point position + * @param box_size The size of the box in each dimension + * @param cutoff The cutoff + * @return The cell index + */ +template +__device__ int3 getCell(scalar3 p, scalar3 box_size, scalar_t cutoff) { p = takeToUnitCell(p, box_size); - int cx = floorf((p.x + float(0.5) * box_size.x) / cutoff); - int cy = floorf((p.y + float(0.5) * box_size.y) / cutoff); - int cz = floorf((p.z + float(0.5) * box_size.z) / cutoff); - int3 cell_dim = getNumberCells(box_size, cutoff); + // Take to the [0, box_size] range and divide by cutoff (which is the cell size) + int cx = floorf((p.x + scalar_t(0.5) * box_size.x) / cutoff); + int cy = floorf((p.y + scalar_t(0.5) * box_size.y) / cutoff); + int cz = floorf((p.z + scalar_t(0.5) * box_size.z) / cutoff); + int3 cell_dim = getCellDimensions(box_size, cutoff); + // Wrap around. If the position of a particle is exactly box_size, it will be in the last cell, + // which results in an illegal access down the line. if (cx == cell_dim.x) cx = 0; if (cy == cell_dim.y) @@ -140,18 +183,50 @@ __device__ int3 getCell(float3 p, float3 box_size, float cutoff) { return make_int3(cx, cy, cz); } +/* + * @brief Get the index of a cell in a 1D array of cells. + * @param cell The cell coordinates, assumed to be in the range [0, cell_dim]. 
+ * @param cell_dim The number of cells in each dimension + */ +__device__ int getCellIndex(int3 cell, int3 cell_dim) { + return cell.x + cell_dim.x * (cell.y + cell_dim.y * cell.z); +} + +/* + @brief Fold a cell coordinate to the range [0, cell_dim) + @param cell The cell coordinate + @param cell_dim The dimensions of the grid + @return The folded cell coordinate +*/ +__device__ int3 getPeriodicCell(int3 cell, int3 cell_dim) { + int3 periodic_cell = cell; + if (cell.x < 0) + periodic_cell.x += cell_dim.x; + if (cell.x >= cell_dim.x) + periodic_cell.x -= cell_dim.x; + if (cell.y < 0) + periodic_cell.y += cell_dim.y; + if (cell.y >= cell_dim.y) + periodic_cell.y -= cell_dim.y; + if (cell.z < 0) + periodic_cell.z += cell_dim.z; + if (cell.z >= cell_dim.z) + periodic_cell.z -= cell_dim.z; + return periodic_cell; +} + // Assign a hash to each atom based on its position and batch. // This hash is such that atoms in the same cell and batch have the same hash. template __global__ void assignHash(const Accessor positions, uint64_t* hash_keys, Accessor hash_values, const Accessor batch, - float3 box_size, float cutoff, int32_t num_atoms) { + scalar3 box_size, scalar_t cutoff, int32_t num_atoms) { const int32_t i_atom = blockIdx.x * blockDim.x + threadIdx.x; if (i_atom >= num_atoms) return; const int32_t i_batch = batch[i_atom]; // Move to the unit cell - float3 pi = make_float3(positions[i_atom][0], positions[i_atom][1], positions[i_atom][2]); + scalar3 pi = {positions[i_atom][0], positions[i_atom][1], positions[i_atom][2]}; auto ci = getCell(pi, box_size, cutoff); // Calculate the hash const int32_t hash = hashMorton(ci); @@ -177,9 +252,19 @@ public: } }; -// Sort the positions by hash, based on the cell assigned to each position and the batch index -static auto sortPositionsByHash(const Tensor& positions, const Tensor& batch, float3 box_size, - float cutoff) { +/* + * @brief Sort the positions by hash, based on the cell assigned to each position and the batch + * index + 
* @param positions The positions of the atoms + * @param batch The batch index of each atom + * @param box_size The size of the box in each dimension + * @param cutoff The cutoff + * @return A tuple of the sorted positions and the original indices of each atom in the sorted list + */ + +static auto sortPositionsByHash(const Tensor& positions, const Tensor& batch, + const Tensor& box_size, const Scalar& cutoff) { + const int num_atoms = positions.size(0); const auto options = positions.options(); thrust::device_vector hash_keys(num_atoms); @@ -188,44 +273,43 @@ static auto sortPositionsByHash(const Tensor& positions, const Tensor& batch, fl const int blocks = (num_atoms + threads - 1) / threads; auto stream = at::cuda::getCurrentCUDAStream(); AT_DISPATCH_FLOATING_TYPES(positions.scalar_type(), "assignHash", [&] { + scalar_t cutoff_ = cutoff.to(); + scalar3 box_size_ = {box_size[0].item(), box_size[1].item(), + box_size[2].item()}; assignHash<<>>( get_accessor(positions), thrust::raw_pointer_cast(hash_keys.data()), - get_accessor(hash_values), get_accessor(batch), box_size, - cutoff, num_atoms); + get_accessor(hash_values), get_accessor(batch), box_size_, + cutoff_, num_atoms); }); - // Sort positions by hash_values using thrust thrust::device_ptr index_ptr(hash_values.data_ptr()); - // pytorch allocator CudaAllocator allocator; - // Adapted for thrust thrust::sort_by_key(thrust::cuda::par.on(stream), hash_keys.begin(), hash_keys.end(), index_ptr); Tensor sorted_positions = positions.index_select(0, hash_values); return std::make_tuple(sorted_positions, hash_values); } -__device__ int getCellIndex(int3 cell, int3 cell_dim) { - return cell.x + cell_dim.x * (cell.y + cell_dim.y * cell.z); -} - template __global__ void fillCellOffsetsD(const Accessor sorted_positions, const Accessor sorted_indices, Accessor cell_start, Accessor cell_end, - const Accessor batch, float3 box_size, float cutoff) { + const Accessor batch, scalar3 box_size, + scalar_t cutoff) { + // Since 
positions are sorted by cell, for a given atom, if the previous atom is in a different + // cell, then the current atom is the first atom in its cell We use this fact to fill the + // cell_start and cell_end arrays const int32_t i_atom = blockIdx.x * blockDim.x + threadIdx.x; if (i_atom >= sorted_positions.size(0)) return; - // const int32_t i_batch = batch[sorted_indices[i_atom]]; - const float3 pi = make_float3(sorted_positions[i_atom][0], sorted_positions[i_atom][1], - sorted_positions[i_atom][2]); - const int3 cell_dim = getNumberCells(box_size, cutoff); + const scalar3 pi = {sorted_positions[i_atom][0], sorted_positions[i_atom][1], + sorted_positions[i_atom][2]}; + const int3 cell_dim = getCellDimensions(box_size, cutoff); const int icell = getCellIndex(getCell(pi, box_size, cutoff), cell_dim); int im1_cell; if (i_atom > 0) { int im1 = i_atom - 1; - const float3 pim1 = make_float3(sorted_positions[im1][0], sorted_positions[im1][1], - sorted_positions[im1][2]); + const scalar3 pim1 = {sorted_positions[im1][0], sorted_positions[im1][1], + sorted_positions[im1][2]}; im1_cell = getCellIndex(getCell(pim1, box_size, cutoff), cell_dim); } else { im1_cell = 0; @@ -241,12 +325,28 @@ __global__ void fillCellOffsetsD(const Accessor sorted_positions, } } -// Fill the cell offsets for each batch, identifying the start and end of each cell for each batch -// in the sorted positions +/* + @brief + Fill the cell offsets for each batch, identifying the start and end of each cell in the sorted + positions + @param sorted_positions The positions sorted by cell + @param sorted_indices The original indices of the sorted positions + @param batch The batch index of each position + @param box_size The size of the box + @param cutoff The cutoff distance + @return A tuple of cell_start and cell_end arrays +*/ static auto fillCellOffsets(const Tensor& sorted_positions, const Tensor& sorted_indices, - const Tensor& batch, float3 box_size, float cutoff) { + const Tensor& batch, const 
Tensor& box_size, const Scalar& cutoff) { const TensorOptions options = sorted_positions.options(); - const int3 cell_dim = getNumberCells(box_size, cutoff); + + int3 cell_dim; + AT_DISPATCH_FLOATING_TYPES(sorted_positions.scalar_type(), "fillCellOffsets", [&] { + scalar_t cutoff_ = cutoff.to(); + scalar3 box_size_ = {box_size[0].item(), box_size[1].item(), + box_size[2].item()}; + cell_dim = getCellDimensions(box_size_, cutoff_); + }); const int num_cells = cell_dim.x * cell_dim.y * cell_dim.z; const Tensor cell_start = full({num_cells}, -1, options.dtype(torch::kInt)); const Tensor cell_end = empty({num_cells}, options.dtype(torch::kInt)); @@ -254,30 +354,32 @@ static auto fillCellOffsets(const Tensor& sorted_positions, const Tensor& sorted const int blocks = (sorted_positions.size(0) + threads - 1) / threads; AT_DISPATCH_FLOATING_TYPES(sorted_positions.scalar_type(), "fillCellOffsets", [&] { auto stream = at::cuda::getCurrentCUDAStream(); + scalar_t cutoff_ = cutoff.to(); + scalar3 box_size_ = {box_size[0].item(), box_size[1].item(), + box_size[2].item()}; fillCellOffsetsD<<>>( get_accessor(sorted_positions), get_accessor(sorted_indices), get_accessor(cell_start), get_accessor(cell_end), - get_accessor(batch), box_size, cutoff); + get_accessor(batch), box_size_, cutoff_); }); return std::make_tuple(cell_start, cell_end); } -// Fold a cell coordinate to the range [0, cell_dim) -__device__ int3 getPeriodicCell(int3 cell, int3 cell_dim) { - int3 periodic_cell = cell; - if (cell.x < 0) - periodic_cell.x += cell_dim.x; - if (cell.x >= cell_dim.x) - periodic_cell.x -= cell_dim.x; - if (cell.y < 0) - periodic_cell.y += cell_dim.y; - if (cell.y >= cell_dim.y) - periodic_cell.y -= cell_dim.y; - if (cell.z < 0) - periodic_cell.z += cell_dim.z; - if (cell.z >= cell_dim.z) - periodic_cell.z -= cell_dim.z; - return periodic_cell; +/* + @brief Get the cell index of the i'th neighboring cell for a given cell + @param cell_i The cell coordinates + @param i The index of the 
neighboring cell, from 0 to 26 + @param cell_dim The dimensions of the cell grid + @return The cell index of the i'th neighboring cell +*/ +__device__ int getNeighborCellIndex(int3 cell_i, int i, int3 cell_dim) { + auto cell_j = cell_i; + cell_j.x += i % 3 - 1; + cell_j.y += (i / 3) % 3 - 1; + cell_j.z += i / 9 - 1; + cell_j = getPeriodicCell(cell_j, cell_dim); + int icellj = getCellIndex(cell_j, cell_dim); + return icellj; } // Traverse the cell list for each atom and find the neighbors @@ -288,28 +390,22 @@ forward_kernel(const Accessor sorted_positions, const Accessor cell_start, const Accessor cell_end, Accessor neighbors, Accessor deltas, Accessor distances, Accessor i_curr_pair, int num_atoms, - int num_pairs, float3 box_size, float cutoff) { + int num_pairs, scalar3 box_size, scalar_t cutoff) { + // Each atom traverses the cells around it and finds the neighbors + // Atoms for all batches are placed in the same cell list, but other batches are ignored while + // traversing const int32_t i_atom = blockIdx.x * blockDim.x + threadIdx.x; if (i_atom >= num_atoms) return; - // Each batch has its own cell list, starting at cell_start[0][i_batch] and ending at - // cell_start[ncells-1][i_batch] Each thread is responsible for a single atom Each thread will - // loop over all atoms in the cell list of the current batch const int ori = original_index[i_atom]; const int i_batch = batch[ori]; - float3 pi = make_float3(sorted_positions[i_atom][0], sorted_positions[i_atom][1], - sorted_positions[i_atom][2]); + const scalar3 pi = {sorted_positions[i_atom][0], sorted_positions[i_atom][1], + sorted_positions[i_atom][2]}; const int3 cell_i = getCell(pi, box_size, cutoff); - const int3 cell_dim = getNumberCells(box_size, cutoff); - const int i_cell_index = getCellIndex(cell_i, cell_dim); + const int3 cell_dim = getCellDimensions(box_size, cutoff); // Loop over the 27 cells around the current cell for (int i = 0; i < 27; i++) { - auto cell_j = cell_i; - cell_j.x += i % 3 - 1; - 
cell_j.y += (i / 3) % 3 - 1; - cell_j.z += i / 9 - 1; - cell_j = getPeriodicCell(cell_j, cell_dim); - int icellj = getCellIndex(cell_j, cell_dim); + int icellj = getNeighborCellIndex(cell_i, i, cell_dim); const int firstParticle = cell_start[icellj]; if (firstParticle != -1) { // Continue only if there are particles in this cell // Index of the last particle in the cell's list @@ -319,9 +415,13 @@ forward_kernel(const Accessor sorted_positions, int cur_j = j + firstParticle; int orj = original_index[cur_j]; int j_batch = batch[orj]; - if (cur_j < i_atom and j_batch == i_batch) { - float3 pj = make_float3(sorted_positions[cur_j][0], sorted_positions[cur_j][1], - sorted_positions[cur_j][2]); + if (j_batch > + i_batch) // Particles are sorted by batch after cell, so we can break early here + break; + if (orj < ori and j_batch == i_batch) { + const scalar3 pj = {sorted_positions[cur_j][0], + sorted_positions[cur_j][1], + sorted_positions[cur_j][2]}; const scalar_t dx = pi.x - pj.x; const scalar_t dy = pi.y - pj.y; const scalar_t dz = pi.z - pj.z; @@ -331,7 +431,7 @@ forward_kernel(const Accessor sorted_positions, // We handle too many neighbors outside of the kernel if (i_pair < neighbors.size(1)) { neighbors[0][i_pair] = ori; - neighbors[1][i_pair] = original_index[cur_j]; + neighbors[1][i_pair] = orj; deltas[i_pair][0] = dx; deltas[i_pair][1] = dy; deltas[i_pair][2] = dz; @@ -349,7 +449,6 @@ public: static tensor_list forward(AutogradContext* ctx, const Tensor& positions, const Tensor& batch, const Tensor& box_size, const Scalar& cutoff, const Scalar& max_num_pairs, bool checkErrors) { - // The algorithm for the cell list construction can be summarized in three separate steps: // 1. Hash (label) the particles according to the cell (bin) they lie in. // 2. 
Sort the particles and hashes using the hashes as the ordering label @@ -365,16 +464,13 @@ public: const int num_pairs = max_num_pairs_; const TensorOptions options = positions.options(); // Steps 1 and 2 - float3 box_size_ = make_float3(box_size[0].item(), box_size[1].item(), - box_size[2].item()); - float cutoff_ = cutoff.toFloat(); Tensor sorted_positions, hash_values; std::tie(sorted_positions, hash_values) = - sortPositionsByHash(positions, batch, box_size_, cutoff_); + sortPositionsByHash(positions, batch, box_size, cutoff); cudaDeviceSynchronize(); Tensor cell_start, cell_end; std::tie(cell_start, cell_end) = - fillCellOffsets(sorted_positions, hash_values, batch, box_size_, cutoff_); + fillCellOffsets(sorted_positions, hash_values, batch, box_size, cutoff); const Tensor neighbors = full({2, num_pairs}, -1, options.dtype(kInt32)); const Tensor deltas = empty({num_pairs, 3}, options); const Tensor distances = full(num_pairs, 0, options); @@ -383,6 +479,10 @@ public: { // Use the cell list for each batch to find the neighbors const CUDAStreamGuard guard(stream); AT_DISPATCH_FLOATING_TYPES(positions.scalar_type(), "forward", [&] { + const scalar_t cutoff_ = cutoff.to(); + const scalar3 box_size_ = {box_size[0].item(), + box_size[1].item(), + box_size[2].item()}; const int threads = 128; const int blocks = (num_atoms + threads - 1) / threads; forward_kernel<<>>( From badecdc48e3298331a0d67a5006ca99d3abc3150 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Wed, 3 May 2023 11:58:48 +0200 Subject: [PATCH 10/76] Adapt brute force method to understand batch --- torchmdnet/neighbors/neighbors_cuda.cu | 99 +++++++++++++------------- 1 file changed, 51 insertions(+), 48 deletions(-) diff --git a/torchmdnet/neighbors/neighbors_cuda.cu b/torchmdnet/neighbors/neighbors_cuda.cu index 8d9f45918..24d645812 100644 --- a/torchmdnet/neighbors/neighbors_cuda.cu +++ b/torchmdnet/neighbors/neighbors_cuda.cu @@ -3,9 +3,6 @@ #include #include #include -#include -#include -#include 
using c10::cuda::CUDAStreamGuard; using c10::cuda::getCurrentCUDAStream; using std::make_tuple; @@ -26,7 +23,8 @@ using torch::autograd::tensor_list; template using Accessor = PackedTensorAccessor32; -template inline Accessor get_accessor(const Tensor& tensor) { +template +inline Accessor get_accessor(const Tensor& tensor) { return tensor.packed_accessor32(); }; @@ -39,10 +37,12 @@ template <> __device__ __forceinline__ double sqrt_(double x) { }; template -__global__ void forward_kernel(const int32_t num_all_pairs, const Accessor positions, - const Accessor batch_offsets, const int32_t batch_index, - const scalar_t cutoff2, Accessor i_curr_pair, Accessor neighbors, - Accessor deltas, Accessor distances) { +__global__ void +forward_kernel(const int32_t num_all_pairs, const Accessor positions, + // const Accessor batch_offsets, const int32_t batch_index, + const Accessor batch, const scalar_t cutoff2, + Accessor i_curr_pair, Accessor neighbors, + Accessor deltas, Accessor distances) { const int32_t index = blockIdx.x * blockDim.x + threadIdx.x; if (index >= num_all_pairs) return; @@ -50,14 +50,13 @@ __global__ void forward_kernel(const int32_t num_all_pairs, const Accessor 2 * index) row--; - const int32_t position_offset = batch_index ? 
batch_offsets[batch_index - 1] : 0; - const int32_t column = (index - row * (row - 1) / 2) + position_offset; - row += position_offset; + const int32_t column = (index - row * (row - 1) / 2); + if (batch[row] != batch[column]) + return; scalar_t delta_x = positions[row][0] - positions[column][0]; scalar_t delta_y = positions[row][1] - positions[column][1]; scalar_t delta_z = positions[row][2] - positions[column][2]; const scalar_t distance2 = delta_x * delta_x + delta_y * delta_y + delta_z * delta_z; - if (distance2 > cutoff2) return; @@ -73,11 +72,11 @@ __global__ void forward_kernel(const int32_t num_all_pairs, const Accessor -__global__ void backward_kernel(const Accessor neighbors, const Accessor deltas, - const Accessor distances, const Accessor grad_distances, - Accessor grad_positions) { +__global__ void +backward_kernel(const Accessor neighbors, const Accessor deltas, + const Accessor distances, const Accessor grad_distances, + Accessor grad_positions) { const int32_t i_pair = blockIdx.x * blockDim.x + threadIdx.x; const int32_t num_pairs = neighbors.size(1); if (i_pair >= num_pairs) @@ -104,22 +103,24 @@ static void checkInput(const Tensor& positions, const Tensor& batch) { // Batch is assumed to be non-negative // Each batch can have a different number of atoms TORCH_CHECK(positions.dim() == 2, "Expected \"positions\" to have two dimensions"); - TORCH_CHECK(positions.size(0) > 0, "Expected the 1nd dimension size of \"positions\" to be more than 0"); + TORCH_CHECK(positions.size(0) > 0, + "Expected the 1nd dimension size of \"positions\" to be more than 0"); TORCH_CHECK(positions.size(1) == 3, "Expected the 2nd dimension size of \"positions\" to be 3"); TORCH_CHECK(positions.is_contiguous(), "Expected \"positions\" to be contiguous"); TORCH_CHECK(batch.dim() == 1, "Expected \"batch\" to have one dimension"); - TORCH_CHECK( - batch.size(0) == positions.size(0), - "Expected the 1st dimension size of \"batch\" to be the same as the 1st dimension size of 
\"positions\""); + TORCH_CHECK(batch.size(0) == positions.size(0), + "Expected the 1st dimension size of \"batch\" to be the same as the 1st dimension " + "size of \"positions\""); TORCH_CHECK(batch.is_contiguous(), "Expected \"batch\" to be contiguous"); TORCH_CHECK(batch.dtype() == torch::kInt32, "Expected \"batch\" to have torch::kInt32 dtype"); } class Autograd : public Function { public: - static tensor_list forward(AutogradContext* ctx, const Tensor& positions, const Tensor& batch, const Scalar& cutoff, - const Scalar& max_num_pairs, bool checkErrors) { + static tensor_list forward(AutogradContext* ctx, const Tensor& positions, const Tensor& batch, + const Scalar& cutoff, const Scalar& max_num_pairs, + bool checkErrors) { checkInput(positions, batch); const auto max_num_pairs_ = max_num_pairs.toLong(); TORCH_CHECK(max_num_pairs_ > 0, "Expected \"max_num_neighbors\" to be positive"); @@ -127,35 +128,34 @@ public: const int num_pairs = max_num_pairs_; const TensorOptions options = positions.options(); const auto stream = getCurrentCUDAStream(positions.get_device()); - const Tensor num_atoms_per_batch = torch::bincount(batch).to(torch::kCPU); - const int n_batches = num_atoms_per_batch.size(0); - const Tensor batch_offsets = torch::cumsum(num_atoms_per_batch, 0, torch::kInt32).to(positions.device()); const Tensor neighbors = full({2, num_pairs}, -1, options.dtype(kInt32)); const Tensor deltas = empty({num_pairs, 3}, options); const Tensor distances = full(num_pairs, 0, options); const Tensor i_curr_pair = zeros(1, options.dtype(kInt32)); { const CUDAStreamGuard guard(stream); - for (int i = 0; i < n_batches; i++) { - const int num_atoms = num_atoms_per_batch[i].item(); - const int num_all_pairs = num_atoms * (num_atoms - 1) / 2; - const int num_threads = 128; - const int num_blocks = max((num_all_pairs + num_threads - 1) / num_threads, 1); - AT_DISPATCH_FLOATING_TYPES(positions.scalar_type(), "get_neighbor_pairs_forward", [&]() { + const int num_atoms = 
positions.size(0); + const int num_all_pairs = num_atoms * (num_atoms - 1) / 2; + const int num_threads = 128; + const int num_blocks = max((num_all_pairs + num_threads - 1) / num_threads, 1); + AT_DISPATCH_FLOATING_TYPES( + positions.scalar_type(), "get_neighbor_pairs_forward", [&]() { const scalar_t cutoff_ = cutoff.to(); TORCH_CHECK(cutoff_ > 0, "Expected \"cutoff\" to be positive"); forward_kernel<<>>( - num_all_pairs, get_accessor(positions), get_accessor(batch_offsets), i, - cutoff_ * cutoff_, get_accessor(i_curr_pair), get_accessor(neighbors), + num_all_pairs, get_accessor(positions), + get_accessor(batch), cutoff_ * cutoff_, + get_accessor(i_curr_pair), get_accessor(neighbors), get_accessor(deltas), get_accessor(distances)); }); - } } - // Synchronize and check the number of pairs found. Note that this is incompatible with CUDA graphs + // Synchronize and check the number of pairs found. Note that this is incompatible with CUDA + // graphs if (checkErrors) { int num_found_pairs = i_curr_pair.item(); TORCH_CHECK(num_found_pairs <= max_num_pairs_, - "Too many neighbor pairs found. Maximum is " + std::to_string(max_num_pairs_), + "Too many neighbor pairs found. 
Maximum is " + + std::to_string(max_num_pairs_), " but found " + std::to_string(num_found_pairs)); } neighbors.resize_({2, i_curr_pair[0].item()}); @@ -181,22 +181,25 @@ public: const Tensor distances = data[2]; const Tensor grad_positions = zeros({num_atoms, 3}, grad_distances.options()); - AT_DISPATCH_FLOATING_TYPES(grad_distances.scalar_type(), "get_neighbor_pairs_backward", [&]() { - const CUDAStreamGuard guard(stream); - backward_kernel<<>>( - get_accessor(neighbors), get_accessor(deltas), - get_accessor(distances), get_accessor(grad_distances), - get_accessor(grad_positions)); - }); + AT_DISPATCH_FLOATING_TYPES( + grad_distances.scalar_type(), "get_neighbor_pairs_backward", [&]() { + const CUDAStreamGuard guard(stream); + backward_kernel<<>>( + get_accessor(neighbors), get_accessor(deltas), + get_accessor(distances), get_accessor(grad_distances), + get_accessor(grad_positions)); + }); return {grad_positions, Tensor(), Tensor()}; } }; TORCH_LIBRARY_IMPL(neighbors, AutogradCUDA, m) { - m.impl("get_neighbor_pairs", [](const Tensor& positions, const Tensor& batch, const Tensor& box_size, const Scalar& cutoff, - const Scalar& max_num_pairs, bool checkErrors) { - const tensor_list results = Autograd::apply(positions, batch, cutoff, max_num_pairs, checkErrors); - return std::make_tuple(results[0], results[1], results[2]); - }); + m.impl("get_neighbor_pairs", + [](const Tensor& positions, const Tensor& batch, const Tensor& box_size, + const Scalar& cutoff, const Scalar& max_num_pairs, bool checkErrors) { + const tensor_list results = + Autograd::apply(positions, batch, cutoff, max_num_pairs, checkErrors); + return std::make_tuple(results[0], results[1], results[2]); + }); } From a25eeecaa231fec9243ea55f79a6c6f1ed669376 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Wed, 3 May 2023 11:59:27 +0200 Subject: [PATCH 11/76] Update test_neigbors --- tests/test_neighbors.py | 16 +++------------- 1 file changed, 3 insertions(+), 13 deletions(-) diff --git 
a/tests/test_neighbors.py b/tests/test_neighbors.py index 2a42564c4..df9219852 100644 --- a/tests/test_neighbors.py +++ b/tests/test_neighbors.py @@ -9,13 +9,12 @@ def sort_neighbors(neighbors, deltas, distances): return neighbors[:, i_sorted], deltas[i_sorted], distances[i_sorted] @pytest.mark.parametrize("device", ["cpu", "cuda"]) -@pytest.mark.parametrize("strategy", ["cell"]) -@pytest.mark.parametrize("n_batches", [1]) -@pytest.mark.parametrize("cutoff", [1]) +@pytest.mark.parametrize("strategy", ["brute", "cell"]) +@pytest.mark.parametrize("n_batches", [1, 2, 3, 4, 128]) +@pytest.mark.parametrize("cutoff", [0.1, 1.0, 1000.0]) def test_neighbors(device, strategy, n_batches, cutoff): if device == "cuda" and not torch.cuda.is_available(): pytest.skip("CUDA not available") - np.random.seed(1234) torch.manual_seed(4321) n_atoms_per_batch = np.random.randint(2, 10, size=n_batches) @@ -27,11 +26,6 @@ def test_neighbors(device, strategy, n_batches, cutoff): pos[0,:] = torch.zeros(3) pos[1,:] = torch.zeros(3) pos.requires_grad = True - print("batch") - print(batch) - print("pos") - print(pos) - ref_neighbors = np.concatenate([np.tril_indices(n_atoms_per_batch[i], -1)+cumsum[i] for i in range(n_batches)], axis=1) pos_np = pos.cpu().detach().numpy() ref_distances = np.linalg.norm(pos_np[ref_neighbors[0]] - pos_np[ref_neighbors[1]], axis=-1) @@ -50,10 +44,6 @@ def test_neighbors(device, strategy, n_batches, cutoff): neighbors = neighbors.cpu().detach().numpy() distance_vecs = distance_vecs.cpu().detach().numpy() distances = distances.cpu().detach().numpy() - print("neighbors") - print(neighbors) - print("ref_neighbors") - print(ref_neighbors) assert neighbors.shape == (2, max_num_pairs) assert distances.shape == (max_num_pairs,) assert distance_vecs.shape == (max_num_pairs, 3) From 7040359b4165651b72e29c72571de5f61feeb927 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Wed, 3 May 2023 12:29:36 +0200 Subject: [PATCH 12/76] Ensure number of total pairs to check does not 
overflow int32 --- torchmdnet/neighbors/neighbors_cuda.cu | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/torchmdnet/neighbors/neighbors_cuda.cu b/torchmdnet/neighbors/neighbors_cuda.cu index 24d645812..84aba9306 100644 --- a/torchmdnet/neighbors/neighbors_cuda.cu +++ b/torchmdnet/neighbors/neighbors_cuda.cu @@ -38,12 +38,11 @@ template <> __device__ __forceinline__ double sqrt_(double x) { template __global__ void -forward_kernel(const int32_t num_all_pairs, const Accessor positions, - // const Accessor batch_offsets, const int32_t batch_index, - const Accessor batch, const scalar_t cutoff2, +forward_kernel(const int64_t num_all_pairs, const Accessor positions, + const Accessor batch, scalar_t cutoff2, Accessor i_curr_pair, Accessor neighbors, Accessor deltas, Accessor distances) { - const int32_t index = blockIdx.x * blockDim.x + threadIdx.x; + const int64_t index = blockIdx.x * blockDim.x + threadIdx.x; if (index >= num_all_pairs) return; @@ -134,10 +133,10 @@ public: const Tensor i_curr_pair = zeros(1, options.dtype(kInt32)); { const CUDAStreamGuard guard(stream); - const int num_atoms = positions.size(0); - const int num_all_pairs = num_atoms * (num_atoms - 1) / 2; - const int num_threads = 128; - const int num_blocks = max((num_all_pairs + num_threads - 1) / num_threads, 1); + const int32_t num_atoms = positions.size(0); + const int64_t num_all_pairs = num_atoms * (num_atoms - 1) / 2; + const int64_t num_threads = 128; + const int64_t num_blocks = max((num_all_pairs + num_threads - 1) / num_threads, 1l); AT_DISPATCH_FLOATING_TYPES( positions.scalar_type(), "get_neighbor_pairs_forward", [&]() { const scalar_t cutoff_ = cutoff.to(); From 9afad752e238413a0904aa2d11f9396d895c1a5d Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Wed, 3 May 2023 12:30:18 +0200 Subject: [PATCH 13/76] Update benchmark --- benchmarks/neighbors.py | 41 ++++++++++++++++++++++++----------------- 1 file changed, 24 insertions(+), 17 deletions(-) diff 
--git a/benchmarks/neighbors.py b/benchmarks/neighbors.py index dfe9b91d2..453742b0a 100644 --- a/benchmarks/neighbors.py +++ b/benchmarks/neighbors.py @@ -4,7 +4,7 @@ from torchmdnet.models.utils import DistanceCellList -def benchmark_neighbors(device, strategy, n_batches, total_num_particles): +def benchmark_neighbors(device, strategy, n_batches, total_num_particles, mean_num_neighbors=32): """Benchmark the neighbor list generation. Parameters @@ -24,7 +24,7 @@ def benchmark_neighbors(device, strategy, n_batches, total_num_particles): """ density = 0.5; num_particles = total_num_particles // n_batches - expected_num_neighbors = min(num_particles, 32); + expected_num_neighbors = mean_num_neighbors cutoff = np.cbrt(3 * expected_num_neighbors / (4 * np.pi * density)); n_atoms_per_batch = np.random.randint(num_particles-10, num_particles+10, size=n_batches) #Fix the last batch so that the total number of particles is correct @@ -33,11 +33,12 @@ def benchmark_neighbors(device, strategy, n_batches, total_num_particles): n_atoms_per_batch[-1] = 1 lbox = np.cbrt(num_particles / density); - batch = torch.tensor([i for i in range(n_batches) for j in range(n_atoms_per_batch[i])], device=device) + batch = torch.repeat_interleave(torch.arange(n_batches, dtype=torch.int32, device=device), torch.tensor(n_atoms_per_batch, dtype=torch.int32, device=device)) + cumsum = np.cumsum( np.concatenate([[0], n_atoms_per_batch])) pos = torch.rand(cumsum[-1], 3, device=device)*lbox max_num_pairs = torch.tensor(expected_num_neighbors * n_atoms_per_batch.sum(), dtype=torch.int64).item() - nl = DistanceCellList(cutoff_upper=cutoff, max_num_pairs=max_num_pairs, strategy=strategy) + nl = DistanceCellList(cutoff_upper=cutoff, max_num_pairs=max_num_pairs, strategy=strategy, box=torch.Tensor([lbox, lbox, lbox])) #Warmup neighbors, distances, distance_vecs = nl(pos, batch) if device == 'cuda': @@ -51,23 +52,29 @@ def benchmark_neighbors(device, strategy, n_batches, total_num_particles): end = 
torch.cuda.Event(enable_timing=True) start.record() - with torch.autograd.profiler.profile(use_cuda=True) as prof: - for i in range(nruns): - neighbors, distances, distance_vecs = nl(pos, batch) + for i in range(nruns): + neighbors, distances, distance_vecs = nl(pos, batch) end.record() if device == 'cuda': torch.cuda.synchronize() #Final time return (start.elapsed_time(end) / nruns) - if __name__ == '__main__': - n_particles = 10000 - print("Benchmarking neighbor list generation for {} particles".format(n_particles)) - #Loop over different number of batches - for n_batches in [1, 10, 100, 1000]: - time = benchmark_neighbors(device='cuda', - strategy='brute', - n_batches=n_batches, - total_num_particles=n_particles) - print("Time for {} batches: {} ms".format(n_batches, time)) + n_particles = 100000 + mean_num_neighbors = min(n_particles, 128); + print("Benchmarking neighbor list generation for {} particles with {} neighbors on average".format(n_particles, mean_num_neighbors)) + for strategy in ['brute', 'cell']: + print("Strategy: {}".format(strategy)) + print("--------") + print("{:<10} {:<10}".format("Batch size", "Time (ms)")) + print("{:<10} {:<10}".format("----------", "---------")) + #Loop over different number of batches + for n_batches in [1, 10, 100, 1000]: + time = benchmark_neighbors(device='cuda', + strategy=strategy, + n_batches=n_batches, + total_num_particles=n_particles, + mean_num_neighbors=mean_num_neighbors + ) + print("{:<10} {:<10.2f}".format(n_batches, time)) From 2589479c2f919bb0b36fcfbfb1187b4f7b11e3d6 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Wed, 3 May 2023 12:58:57 +0200 Subject: [PATCH 14/76] Improve efficiency of tests and benchmark --- benchmarks/neighbors.py | 27 ++++++++++++++++----------- tests/test_neighbors.py | 7 +++---- 2 files changed, 19 insertions(+), 15 deletions(-) diff --git a/benchmarks/neighbors.py b/benchmarks/neighbors.py index 453742b0a..13a2c2c68 100644 --- a/benchmarks/neighbors.py +++ 
b/benchmarks/neighbors.py @@ -26,21 +26,26 @@ def benchmark_neighbors(device, strategy, n_batches, total_num_particles, mean_n num_particles = total_num_particles // n_batches expected_num_neighbors = mean_num_neighbors cutoff = np.cbrt(3 * expected_num_neighbors / (4 * np.pi * density)); - n_atoms_per_batch = np.random.randint(num_particles-10, num_particles+10, size=n_batches) - #Fix the last batch so that the total number of particles is correct - n_atoms_per_batch[-1] += total_num_particles - n_atoms_per_batch.sum() - if n_atoms_per_batch[-1] < 0: - n_atoms_per_batch[-1] = 1 - + n_atoms_per_batch = torch.randint(num_particles-10, num_particles+10, size=(n_batches,)) + #Fix so that the total number of particles is correct. Special care if the difference is negative + difference = total_num_particles - n_atoms_per_batch.sum() + if n_atoms_per_batch[-1] + difference > 0: + n_atoms_per_batch[-1] += difference + else: + while difference < 0: + i = np.random.randint(0, n_batches) + if n_atoms_per_batch[i] > 2: + n_atoms_per_batch[i] -= 1 + difference += 1 lbox = np.cbrt(num_particles / density); - batch = torch.repeat_interleave(torch.arange(n_batches, dtype=torch.int32, device=device), torch.tensor(n_atoms_per_batch, dtype=torch.int32, device=device)) - + batch = torch.repeat_interleave(torch.arange(n_batches, dtype=torch.int32), n_atoms_per_batch).to(device) cumsum = np.cumsum( np.concatenate([[0], n_atoms_per_batch])) pos = torch.rand(cumsum[-1], 3, device=device)*lbox - max_num_pairs = torch.tensor(expected_num_neighbors * n_atoms_per_batch.sum(), dtype=torch.int64).item() + max_num_pairs = (expected_num_neighbors * n_atoms_per_batch.sum()).item() nl = DistanceCellList(cutoff_upper=cutoff, max_num_pairs=max_num_pairs, strategy=strategy, box=torch.Tensor([lbox, lbox, lbox])) #Warmup - neighbors, distances, distance_vecs = nl(pos, batch) + for i in range(10): + neighbors, distances, distance_vecs = nl(pos, batch) if device == 'cuda': torch.cuda.synchronize() 
#Benchmark using torch profiler @@ -62,7 +67,7 @@ def benchmark_neighbors(device, strategy, n_batches, total_num_particles, mean_n if __name__ == '__main__': n_particles = 100000 - mean_num_neighbors = min(n_particles, 128); + mean_num_neighbors = min(n_particles, 16); print("Benchmarking neighbor list generation for {} particles with {} neighbors on average".format(n_particles, mean_num_neighbors)) for strategy in ['brute', 'cell']: print("Strategy: {}".format(strategy)) diff --git a/tests/test_neighbors.py b/tests/test_neighbors.py index df9219852..1a757e8e1 100644 --- a/tests/test_neighbors.py +++ b/tests/test_neighbors.py @@ -15,10 +15,9 @@ def sort_neighbors(neighbors, deltas, distances): def test_neighbors(device, strategy, n_batches, cutoff): if device == "cuda" and not torch.cuda.is_available(): pytest.skip("CUDA not available") - np.random.seed(1234) torch.manual_seed(4321) - n_atoms_per_batch = np.random.randint(2, 10, size=n_batches) - batch = torch.tensor([i for i in range(n_batches) for j in range(n_atoms_per_batch[i])], device=device, dtype=torch.int) + n_atoms_per_batch = torch.randint(3, 100, size=(n_batches,)) + batch = torch.repeat_interleave(torch.arange(n_batches, dtype=torch.int32), n_atoms_per_batch).to(device) cumsum = np.cumsum( np.concatenate([[0], n_atoms_per_batch])) lbox=10.0 pos = torch.rand(cumsum[-1], 3, device=device)*lbox @@ -26,7 +25,7 @@ def test_neighbors(device, strategy, n_batches, cutoff): pos[0,:] = torch.zeros(3) pos[1,:] = torch.zeros(3) pos.requires_grad = True - ref_neighbors = np.concatenate([np.tril_indices(n_atoms_per_batch[i], -1)+cumsum[i] for i in range(n_batches)], axis=1) + ref_neighbors = np.concatenate([np.tril_indices(int(n_atoms_per_batch[i]), -1)+cumsum[i] for i in range(n_batches)], axis=1) pos_np = pos.cpu().detach().numpy() ref_distances = np.linalg.norm(pos_np[ref_neighbors[0]] - pos_np[ref_neighbors[1]], axis=-1) ref_distance_vecs = pos_np[ref_neighbors[0]] - pos_np[ref_neighbors[1]] From 
810918607f7b1a656bbfac1cce8c25001e29fd42 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Wed, 3 May 2023 13:47:21 +0200 Subject: [PATCH 15/76] Add neighbor sources to setup.py --- setup.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/setup.py b/setup.py index 3ace81d3f..619d760de 100644 --- a/setup.py +++ b/setup.py @@ -15,5 +15,7 @@ name="torchmd-net", version=version, packages=find_packages(), + package_data={"torchmdnet": ["neighbors/neighbors*"]}, + include_package_data=True, entry_points={"console_scripts": ["torchmd-train = torchmdnet.scripts.train:main"]}, ) From f838068546d0cd3a322fd4c299cc14dce23853fb Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Wed, 3 May 2023 13:50:33 +0200 Subject: [PATCH 16/76] Remove unnecessary synchronization barrier --- torchmdnet/neighbors/neighbors_cuda_cell.cu | 1 - 1 file changed, 1 deletion(-) diff --git a/torchmdnet/neighbors/neighbors_cuda_cell.cu b/torchmdnet/neighbors/neighbors_cuda_cell.cu index 6bd1a4353..5d0ca113e 100644 --- a/torchmdnet/neighbors/neighbors_cuda_cell.cu +++ b/torchmdnet/neighbors/neighbors_cuda_cell.cu @@ -467,7 +467,6 @@ public: Tensor sorted_positions, hash_values; std::tie(sorted_positions, hash_values) = sortPositionsByHash(positions, batch, box_size, cutoff); - cudaDeviceSynchronize(); Tensor cell_start, cell_end; std::tie(cell_start, cell_end) = fillCellOffsets(sorted_positions, hash_values, batch, box_size, cutoff); From 97679724d3ccac207836b5f4fbb72e1094745f3f Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Wed, 3 May 2023 16:51:48 +0200 Subject: [PATCH 17/76] Add loop and cutoff_lower parameters --- tests/test_neighbors.py | 9 +- torchmdnet/models/utils.py | 24 +++- torchmdnet/neighbors/neighbors.cpp | 4 +- torchmdnet/neighbors/neighbors_cpu.cpp | 20 +++- torchmdnet/neighbors/neighbors_cuda.cu | 122 ++++++++++++++------ torchmdnet/neighbors/neighbors_cuda_cell.cu | 44 ++++--- 6 files changed, 153 insertions(+), 70 deletions(-) diff --git a/tests/test_neighbors.py 
b/tests/test_neighbors.py index 1a757e8e1..52460070b 100644 --- a/tests/test_neighbors.py +++ b/tests/test_neighbors.py @@ -12,7 +12,8 @@ def sort_neighbors(neighbors, deltas, distances): @pytest.mark.parametrize("strategy", ["brute", "cell"]) @pytest.mark.parametrize("n_batches", [1, 2, 3, 4, 128]) @pytest.mark.parametrize("cutoff", [0.1, 1.0, 1000.0]) -def test_neighbors(device, strategy, n_batches, cutoff): +@pytest.mark.parametrize("loop", [True, False]) +def test_neighbors(device, strategy, n_batches, cutoff, loop): if device == "cuda" and not torch.cuda.is_available(): pytest.skip("CUDA not available") torch.manual_seed(4321) @@ -26,6 +27,9 @@ def test_neighbors(device, strategy, n_batches, cutoff): pos[1,:] = torch.zeros(3) pos.requires_grad = True ref_neighbors = np.concatenate([np.tril_indices(int(n_atoms_per_batch[i]), -1)+cumsum[i] for i in range(n_batches)], axis=1) + if(loop): # Add self interactions + ilist=np.arange(cumsum[-1]) + ref_neighbors = np.concatenate([ref_neighbors, np.stack([ilist, ilist], axis=0)], axis=1) pos_np = pos.cpu().detach().numpy() ref_distances = np.linalg.norm(pos_np[ref_neighbors[0]] - pos_np[ref_neighbors[1]], axis=-1) ref_distance_vecs = pos_np[ref_neighbors[0]] - pos_np[ref_neighbors[1]] @@ -37,7 +41,8 @@ def test_neighbors(device, strategy, n_batches, cutoff): ref_distances = ref_distances[mask] max_num_pairs = ref_neighbors.shape[1] box = torch.tensor([lbox, lbox, lbox]) - nl = DistanceCellList(cutoff_upper=cutoff, max_num_pairs=max_num_pairs, strategy=strategy, box=box) + + nl = DistanceCellList(cutoff_lower=0.0, loop=loop, cutoff_upper=cutoff, max_num_pairs=max_num_pairs, strategy=strategy, box=box) batch.to(device) neighbors, distances, distance_vecs = nl(pos, batch) neighbors = neighbors.cpu().detach().numpy() diff --git a/torchmdnet/models/utils.py b/torchmdnet/models/utils.py index 8b314ef55..8ccc04382 100644 --- a/torchmdnet/models/utils.py +++ b/torchmdnet/models/utils.py @@ -79,16 +79,20 @@ def message(self, 
x_j, W): from torchmdnet.neighbors import get_neighbor_pairs, get_neighbor_pairs_cell class DistanceCellList(torch.nn.Module): def __init__( - self, - cutoff_upper, - max_num_pairs=32, - strategy="cell", + self, + cutoff_upper, + cutoff_lower=0.0, + max_num_pairs=32, + loop=False, + strategy="cell", box=None ): super(DistanceCellList, self).__init__() """ Compute the neighbor list for a given cutoff. Parameters ---------- + cutoff_lower : float + Lower cutoff for the neighbor list. cutoff_upper : float Upper cutoff for the neighbor list. max_num_pairs : int @@ -98,12 +102,20 @@ def __init__( ["brute", "cell"]. box : torch.Tensor Size of the box shape (3,) or None + loop : bool + Whether to include self-interactions. """ self.cutoff_upper = cutoff_upper + self.cutoff_lower = cutoff_lower self.max_num_pairs = max_num_pairs self.strategy = strategy self.box = box + self.loop = loop + #Default the box to 3 times the cutoff + if self.box is None and self.strategy == "cell": + self.box = torch.tensor([cutoff_upper * 3] * 3) + def forward(self, pos, batch): """ @@ -129,7 +141,9 @@ def forward(self, pos, batch): function = get_neighbor_pairs if self.strategy == "brute" else get_neighbor_pairs_cell neighbors, distance_vecs, distances = function( pos, - cutoff=self.cutoff_upper, + cutoff_lower=self.cutoff_lower, + cutoff_upper=self.cutoff_upper, + loop=self.loop, batch=batch, max_num_pairs=self.max_num_pairs, check_errors=True, diff --git a/torchmdnet/neighbors/neighbors.cpp b/torchmdnet/neighbors/neighbors.cpp index cb2bfe2d3..7813f81f0 100644 --- a/torchmdnet/neighbors/neighbors.cpp +++ b/torchmdnet/neighbors/neighbors.cpp @@ -1,6 +1,6 @@ #include TORCH_LIBRARY(neighbors, m) { - m.def("get_neighbor_pairs(Tensor positions, Tensor batch, Tensor box_size, Scalar cutoff, Scalar max_num_pairs, bool check_errors) -> (Tensor neighbors, Tensor distances, Tensor distance_vecs)"); - m.def("get_neighbor_pairs_cell(Tensor positions, Tensor batch, Tensor box_size,Scalar cutoff, Scalar 
max_num_pairs, bool check_errors) -> (Tensor neighbors, Tensor distances, Tensor distance_vecs)"); + m.def("get_neighbor_pairs(Tensor positions, Tensor batch, Tensor box_size, Scalar cutoff_lower, Scalar cutoff_upper, Scalar max_num_pairs, bool loop, bool check_errors) -> (Tensor neighbors, Tensor distances, Tensor distance_vecs)"); + m.def("get_neighbor_pairs_cell(Tensor positions, Tensor batch, Tensor box_size, Scalar cutoff_lower, Scalar cutoff_upper, Scalar max_num_pairs, bool loop, bool check_errors) -> (Tensor neighbors, Tensor distances, Tensor distance_vecs)"); } diff --git a/torchmdnet/neighbors/neighbors_cpu.cpp b/torchmdnet/neighbors/neighbors_cpu.cpp index 954a0d5fe..04d7802ee 100644 --- a/torchmdnet/neighbors/neighbors_cpu.cpp +++ b/torchmdnet/neighbors/neighbors_cpu.cpp @@ -16,14 +16,15 @@ using torch::Tensor; using torch::outer; using torch::round; -static tuple forward(const Tensor& positions, const Tensor& batch, const Tensor &box_size, const Scalar& cutoff, - const Scalar& max_num_pairs, bool checkErrors) { +static tuple forward(const Tensor& positions, const Tensor& batch, const Tensor &box_size, + const Scalar& cutoff_lower, const Scalar& cutoff_upper, + const Scalar& max_num_pairs, bool loop, bool checkErrors) { TORCH_CHECK(positions.dim() == 2, "Expected \"positions\" to have two dimensions"); TORCH_CHECK(positions.size(0) > 0, "Expected the 1nd dimension size of \"positions\" to be more than 0"); TORCH_CHECK(positions.size(1) == 3, "Expected the 2nd dimension size of \"positions\" to be 3"); TORCH_CHECK(positions.is_contiguous(), "Expected \"positions\" to be contiguous"); - TORCH_CHECK(cutoff.to() > 0, "Expected \"cutoff\" to be positive"); + TORCH_CHECK(cutoff_upper.to() > 0, "Expected \"cutoff\" to be positive"); auto box_vectors = torch::empty(0); if (box_vectors.size(0) != 0) { TORCH_CHECK(box_vectors.dim() == 2, "Expected \"box_vectors\" to have two dimensions"); @@ -33,7 +34,7 @@ static tuple forward(const Tensor& positions, const Tens 
for (int i = 0; i < 3; i++) for (int j = 0; j < 3; j++) v[i][j] = box_vectors[i][j].item(); - double c = cutoff.to(); + double c = cutoff_upper.to(); TORCH_CHECK(v[0][1] == 0, "Invalid box vectors: box_vectors[0][1] != 0"); TORCH_CHECK(v[0][2] == 0, "Invalid box vectors: box_vectors[0][2] != 0"); TORCH_CHECK(v[1][2] == 0, "Invalid box vectors: box_vectors[1][2] != 0"); @@ -73,9 +74,16 @@ static tuple forward(const Tensor& positions, const Tens Tensor neighbors_i = vstack({rows_i, columns_i}); Tensor deltas_i = index_select(positions_i, 0, rows_i) - index_select(positions_i, 0, columns_i); Tensor distances_i = frobenius_norm(deltas_i, 1); - const Tensor mask = distances_i <= cutoff; + const Tensor mask_upper = distances_i <= cutoff_upper; + const Tensor mask_lower = distances_i >= cutoff_lower; + const Tensor mask = mask_upper*mask_lower; neighbors_i = neighbors_i.index({Slice(), mask}) + current_offset; - n_pairs += distances_i.size(0); + //Add self interaction using batch_i + if(loop){ + const Tensor batch_i_tensor = torch::tensor(batch_i, kInt32); + neighbors_i = torch::hstack({neighbors_i, torch::stack({batch_i_tensor, batch_i_tensor})}); + } + n_pairs += neighbors_i.size(1); TORCH_CHECK(n_pairs >= 0, "The maximum number of pairs has been exceed! 
Increase \"max_num_neighbors\""); neighbors = torch::hstack({neighbors, neighbors_i}); current_offset += n_atoms_i; diff --git a/torchmdnet/neighbors/neighbors_cuda.cu b/torchmdnet/neighbors/neighbors_cuda.cu index 84aba9306..fb4cd2777 100644 --- a/torchmdnet/neighbors/neighbors_cuda.cu +++ b/torchmdnet/neighbors/neighbors_cuda.cu @@ -36,38 +36,73 @@ template <> __device__ __forceinline__ double sqrt_(double x) { return ::sqrt(x); }; +__device__ int32_t get_row(int index) { + int32_t row = floor((sqrtf(8 * index + 1) + 1) / 2); + if (row * (row - 1) > 2 * index) + row--; + return row; +} + template -__global__ void -forward_kernel(const int64_t num_all_pairs, const Accessor positions, - const Accessor batch, scalar_t cutoff2, - Accessor i_curr_pair, Accessor neighbors, - Accessor deltas, Accessor distances) { +__global__ void forward_kernel(const int64_t num_all_pairs, const Accessor positions, + const Accessor batch, scalar_t cutoff_lower2, + scalar_t cutoff_upper2, bool loop, Accessor i_curr_pair, + Accessor neighbors, Accessor deltas, + Accessor distances) { const int64_t index = blockIdx.x * blockDim.x + threadIdx.x; if (index >= num_all_pairs) return; - int32_t row = floor((sqrtf(8 * index + 1) + 1) / 2); - if (row * (row - 1) > 2 * index) - row--; + int32_t row = get_row(index); const int32_t column = (index - row * (row - 1) / 2); - if (batch[row] != batch[column]) - return; - scalar_t delta_x = positions[row][0] - positions[column][0]; - scalar_t delta_y = positions[row][1] - positions[column][1]; - scalar_t delta_z = positions[row][2] - positions[column][2]; - const scalar_t distance2 = delta_x * delta_x + delta_y * delta_y + delta_z * delta_z; - if (distance2 > cutoff2) - return; + if (batch[row] == batch[column]) { + scalar_t delta_x = positions[row][0] - positions[column][0]; + scalar_t delta_y = positions[row][1] - positions[column][1]; + scalar_t delta_z = positions[row][2] - positions[column][2]; + const scalar_t distance2 = delta_x * delta_x + 
delta_y * delta_y + delta_z * delta_z; + if (distance2 <= cutoff_upper2 && distance2 >= cutoff_lower2) { + const int32_t i_pair = atomicAdd(&i_curr_pair[0], 1); + // We handle too many neighbors outside of the kernel + if (i_pair < neighbors.size(1)) { + neighbors[0][i_pair] = row; + neighbors[1][i_pair] = column; + deltas[i_pair][0] = delta_x; + deltas[i_pair][1] = delta_y; + deltas[i_pair][2] = delta_z; + distances[i_pair] = sqrt_(distance2); + } + } + } + // If loop is true and this is the first thread dealing with particle "row" add the self + // interaction + // if (loop && ((column == 0) || (index == row * (row + 1) / 2))) { + // const int32_t i_pair = atomicAdd(&i_curr_pair[0], 1); + // if (i_pair < neighbors.size(1)) { + // neighbors[0][i_pair] = row; + // neighbors[1][i_pair] = row; + // deltas[i_pair][0] = 0; + // deltas[i_pair][1] = 0; + // deltas[i_pair][2] = 0; + // distances[i_pair] = 0; + // } + // } +} +template +__global__ void add_self_kernel(const int num_atoms, Accessor positions, + Accessor i_curr_pair, Accessor neighbors, + Accessor deltas, Accessor distances) { + const int32_t i_atom = blockIdx.x * blockDim.x + threadIdx.x; + if (i_atom >= num_atoms) + return; const int32_t i_pair = atomicAdd(&i_curr_pair[0], 1); - // We handle too many neighbors outside of the kernel if (i_pair < neighbors.size(1)) { - neighbors[0][i_pair] = row; - neighbors[1][i_pair] = column; - deltas[i_pair][0] = delta_x; - deltas[i_pair][1] = delta_y; - deltas[i_pair][2] = delta_z; - distances[i_pair] = sqrt_(distance2); + neighbors[0][i_pair] = i_atom; + neighbors[1][i_pair] = i_atom; + deltas[i_pair][0] = 0; + deltas[i_pair][1] = 0; + deltas[i_pair][2] = 0; + distances[i_pair] = 0; } } @@ -118,8 +153,8 @@ static void checkInput(const Tensor& positions, const Tensor& batch) { class Autograd : public Function { public: static tensor_list forward(AutogradContext* ctx, const Tensor& positions, const Tensor& batch, - const Scalar& cutoff, const Scalar& max_num_pairs, - 
bool checkErrors) { + const Scalar& cutoff_lower, const Scalar& cutoff_upper, + const Scalar& max_num_pairs, bool loop, bool checkErrors) { checkInput(positions, batch); const auto max_num_pairs_ = max_num_pairs.toLong(); TORCH_CHECK(max_num_pairs_ > 0, "Expected \"max_num_neighbors\" to be positive"); @@ -139,13 +174,25 @@ public: const int64_t num_blocks = max((num_all_pairs + num_threads - 1) / num_threads, 1l); AT_DISPATCH_FLOATING_TYPES( positions.scalar_type(), "get_neighbor_pairs_forward", [&]() { - const scalar_t cutoff_ = cutoff.to(); - TORCH_CHECK(cutoff_ > 0, "Expected \"cutoff\" to be positive"); + const scalar_t cutoff_upper_ = cutoff_upper.to(); + const scalar_t cutoff_lower_ = cutoff_lower.to(); + TORCH_CHECK(cutoff_upper_ > 0, "Expected \"cutoff\" to be positive"); forward_kernel<<>>( num_all_pairs, get_accessor(positions), - get_accessor(batch), cutoff_ * cutoff_, - get_accessor(i_curr_pair), get_accessor(neighbors), - get_accessor(deltas), get_accessor(distances)); + get_accessor(batch), cutoff_lower_ * cutoff_lower_, + cutoff_upper_ * cutoff_upper_, loop, get_accessor(i_curr_pair), + get_accessor(neighbors), get_accessor(deltas), + get_accessor(distances)); + if (loop) { + const int64_t num_threads = 128; + const int64_t num_blocks = + max((num_atoms + num_threads - 1) / num_threads, 1l); + add_self_kernel<<>>( + num_atoms, get_accessor(positions), + get_accessor(i_curr_pair), + get_accessor(neighbors), get_accessor(deltas), + get_accessor(distances)); + } }); } // Synchronize and check the number of pairs found. 
Note that this is incompatible with CUDA @@ -194,11 +241,12 @@ public: }; TORCH_LIBRARY_IMPL(neighbors, AutogradCUDA, m) { - m.impl("get_neighbor_pairs", - [](const Tensor& positions, const Tensor& batch, const Tensor& box_size, - const Scalar& cutoff, const Scalar& max_num_pairs, bool checkErrors) { - const tensor_list results = - Autograd::apply(positions, batch, cutoff, max_num_pairs, checkErrors); - return std::make_tuple(results[0], results[1], results[2]); - }); + m.impl("get_neighbor_pairs", [](const Tensor& positions, const Tensor& batch, + const Tensor& box_size, const Scalar& cutoff_lower, + const Scalar& cutoff_upper, const Scalar& max_num_pairs, + bool loop, bool checkErrors) { + const tensor_list results = Autograd::apply(positions, batch, cutoff_lower, cutoff_upper, + max_num_pairs, loop, checkErrors); + return std::make_tuple(results[0], results[1], results[2]); + }); } diff --git a/torchmdnet/neighbors/neighbors_cuda_cell.cu b/torchmdnet/neighbors/neighbors_cuda_cell.cu index 5d0ca113e..d8ca5a9cf 100644 --- a/torchmdnet/neighbors/neighbors_cuda_cell.cu +++ b/torchmdnet/neighbors/neighbors_cuda_cell.cu @@ -390,7 +390,8 @@ forward_kernel(const Accessor sorted_positions, const Accessor cell_start, const Accessor cell_end, Accessor neighbors, Accessor deltas, Accessor distances, Accessor i_curr_pair, int num_atoms, - int num_pairs, scalar3 box_size, scalar_t cutoff) { + int num_pairs, scalar3 box_size, scalar_t cutoff_lower, + scalar_t cutoff_upper, bool loop) { // Each atom traverses the cells around it and finds the neighbors // Atoms for all batches are placed in the same cell list, but other batches are ignored while // traversing @@ -401,8 +402,8 @@ forward_kernel(const Accessor sorted_positions, const int i_batch = batch[ori]; const scalar3 pi = {sorted_positions[i_atom][0], sorted_positions[i_atom][1], sorted_positions[i_atom][2]}; - const int3 cell_i = getCell(pi, box_size, cutoff); - const int3 cell_dim = getCellDimensions(box_size, cutoff); + 
const int3 cell_i = getCell(pi, box_size, cutoff_upper); + const int3 cell_dim = getCellDimensions(box_size, cutoff_upper); // Loop over the 27 cells around the current cell for (int i = 0; i < 27; i++) { int icellj = getNeighborCellIndex(cell_i, i, cell_dim); @@ -418,7 +419,7 @@ forward_kernel(const Accessor sorted_positions, if (j_batch > i_batch) // Particles are sorted by batch after cell, so we can break early here break; - if (orj < ori and j_batch == i_batch) { + if ((orj < ori and j_batch == i_batch) or (loop and orj == ori)) { const scalar3 pj = {sorted_positions[cur_j][0], sorted_positions[cur_j][1], sorted_positions[cur_j][2]}; @@ -426,7 +427,10 @@ forward_kernel(const Accessor sorted_positions, const scalar_t dy = pi.y - pj.y; const scalar_t dz = pi.z - pj.z; const scalar_t distance2 = dx * dx + dy * dy + dz * dz; - if (distance2 < cutoff * cutoff) { + const scalar_t cutoff_upper2 = cutoff_upper * cutoff_upper; + const scalar_t cutoff_lower2 = cutoff_lower * cutoff_lower; + if ((distance2 <= cutoff_upper2 and distance2 >= cutoff_lower2) or + (loop and orj == ori)) { const int32_t i_pair = atomicAdd(&i_curr_pair[0], 1); // We handle too many neighbors outside of the kernel if (i_pair < neighbors.size(1)) { @@ -447,8 +451,9 @@ forward_kernel(const Accessor sorted_positions, class Autograd : public Function { public: static tensor_list forward(AutogradContext* ctx, const Tensor& positions, const Tensor& batch, - const Tensor& box_size, const Scalar& cutoff, - const Scalar& max_num_pairs, bool checkErrors) { + const Tensor& box_size, const Scalar& cutoff_lower, + const Scalar& cutoff_upper, const Scalar& max_num_pairs, bool loop, + bool checkErrors) { // The algorithm for the cell list construction can be summarized in three separate steps: // 1. Hash (label) the particles according to the cell (bin) they lie in. // 2. 
Sort the particles and hashes using the hashes as the ordering label @@ -466,10 +471,10 @@ public: // Steps 1 and 2 Tensor sorted_positions, hash_values; std::tie(sorted_positions, hash_values) = - sortPositionsByHash(positions, batch, box_size, cutoff); + sortPositionsByHash(positions, batch, box_size, cutoff_upper); Tensor cell_start, cell_end; std::tie(cell_start, cell_end) = - fillCellOffsets(sorted_positions, hash_values, batch, box_size, cutoff); + fillCellOffsets(sorted_positions, hash_values, batch, box_size, cutoff_upper); const Tensor neighbors = full({2, num_pairs}, -1, options.dtype(kInt32)); const Tensor deltas = empty({num_pairs, 3}, options); const Tensor distances = full(num_pairs, 0, options); @@ -478,7 +483,9 @@ public: { // Use the cell list for each batch to find the neighbors const CUDAStreamGuard guard(stream); AT_DISPATCH_FLOATING_TYPES(positions.scalar_type(), "forward", [&] { - const scalar_t cutoff_ = cutoff.to(); + const scalar_t cutoff_upper_ = cutoff_upper.to(); + TORCH_CHECK(cutoff_upper_ > 0, "Expected cutoff_upper to be positive"); + const scalar_t cutoff_lower_ = cutoff_lower.to(); const scalar3 box_size_ = {box_size[0].item(), box_size[1].item(), box_size[2].item()}; @@ -490,7 +497,7 @@ public: get_accessor(cell_start), get_accessor(cell_end), get_accessor(neighbors), get_accessor(deltas), get_accessor(distances), get_accessor(i_curr_pair), - num_atoms, num_pairs, box_size_, cutoff_); + num_atoms, num_pairs, box_size_, cutoff_lower_, cutoff_upper_, loop); }); } // Synchronize and check the number of pairs found. 
Note that this is incompatible with CUDA @@ -539,11 +546,12 @@ public: }; TORCH_LIBRARY_IMPL(neighbors, AutogradCUDA, m) { - m.impl("get_neighbor_pairs_cell", - [](const Tensor& positions, const Tensor& batch, const Tensor& box_size, - const Scalar& cutoff, const Scalar& max_num_pairs, bool checkErrors) { - const tensor_list results = - Autograd::apply(positions, batch, box_size, cutoff, max_num_pairs, checkErrors); - return std::make_tuple(results[0], results[1], results[2]); - }); + m.impl("get_neighbor_pairs_cell", [](const Tensor& positions, const Tensor& batch, + const Tensor& box_size, const Scalar& cutoff_lower, + const Scalar& cutoff_upper, const Scalar& max_num_pairs, + bool loop, bool checkErrors) { + const tensor_list results = Autograd::apply(positions, batch, box_size, cutoff_lower, + cutoff_upper, max_num_pairs, loop, checkErrors); + return std::make_tuple(results[0], results[1], results[2]); + }); } From 6f1645df2771ee92fbefe894d5df63b88649e0e6 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Wed, 3 May 2023 16:55:04 +0200 Subject: [PATCH 18/76] Added the return_vecs option for api compatibility with Distance --- tests/test_neighbors.py | 2 +- torchmdnet/models/utils.py | 8 +++++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/tests/test_neighbors.py b/tests/test_neighbors.py index 52460070b..32c630231 100644 --- a/tests/test_neighbors.py +++ b/tests/test_neighbors.py @@ -42,7 +42,7 @@ def test_neighbors(device, strategy, n_batches, cutoff, loop): max_num_pairs = ref_neighbors.shape[1] box = torch.tensor([lbox, lbox, lbox]) - nl = DistanceCellList(cutoff_lower=0.0, loop=loop, cutoff_upper=cutoff, max_num_pairs=max_num_pairs, strategy=strategy, box=box) + nl = DistanceCellList(cutoff_lower=0.0, loop=loop, cutoff_upper=cutoff, max_num_pairs=max_num_pairs, strategy=strategy, box=box, return_vecs=True) batch.to(device) neighbors, distances, distance_vecs = nl(pos, batch) neighbors = neighbors.cpu().detach().numpy() diff --git 
a/torchmdnet/models/utils.py b/torchmdnet/models/utils.py index 8ccc04382..2b46067bd 100644 --- a/torchmdnet/models/utils.py +++ b/torchmdnet/models/utils.py @@ -83,6 +83,7 @@ def __init__( cutoff_upper, cutoff_lower=0.0, max_num_pairs=32, + return_vecs=False, loop=False, strategy="cell", box=None @@ -104,6 +105,8 @@ def __init__( Size of the box shape (3,) or None loop : bool Whether to include self-interactions. + return_vecs : bool + Whether to return the distance vectors. """ self.cutoff_upper = cutoff_upper @@ -149,7 +152,10 @@ def forward(self, pos, batch): check_errors=True, box_size=self.box ) - return neighbors, distances, distance_vecs + if self.return_vecs: + return neighbors, distances, distance_vecs + else: + return neighbors, distances, None class GaussianSmearing(nn.Module): def __init__(self, cutoff_lower=0.0, cutoff_upper=5.0, num_rbf=50, trainable=True): From 9cb9ca6c1b3ae8d88db52c4f4b870142d56ad4b6 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Wed, 3 May 2023 16:56:35 +0200 Subject: [PATCH 19/76] Fix typo --- torchmdnet/models/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchmdnet/models/utils.py b/torchmdnet/models/utils.py index 2b46067bd..be6d55653 100644 --- a/torchmdnet/models/utils.py +++ b/torchmdnet/models/utils.py @@ -115,6 +115,7 @@ def __init__( self.strategy = strategy self.box = box self.loop = loop + self.return_vecs = return_vecs #Default the box to 3 times the cutoff if self.box is None and self.strategy == "cell": self.box = torch.tensor([cutoff_upper * 3] * 3) From 9acb8b025da0c925a50981d31b9296a38e2f47b6 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Wed, 3 May 2023 17:26:19 +0200 Subject: [PATCH 20/76] Remove (-1,-1) pairs python-side --- torchmdnet/models/utils.py | 12 +++++++++--- torchmdnet/neighbors/neighbors_cuda.cu | 3 --- torchmdnet/neighbors/neighbors_cuda_cell.cu | 3 --- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/torchmdnet/models/utils.py b/torchmdnet/models/utils.py index 
be6d55653..89c34d1e7 100644 --- a/torchmdnet/models/utils.py +++ b/torchmdnet/models/utils.py @@ -133,13 +133,13 @@ def forward(self, pos, batch): ------- neighbors : torch.Tensor List of neighbors for each atom in the batch. - shape (2, max_num_pairs) + shape (2, num_found_pairs) distances : torch.Tensor List of distances for each atom in the batch. - shape (max_num_pairs,) + shape (num_found_pairs,) distance_vecs : torch.Tensor List of distance vectors for each atom in the batch. - shape (max_num_pairs, 3) + shape (num_found_pairs, 3) """ function = get_neighbor_pairs if self.strategy == "brute" else get_neighbor_pairs_cell @@ -153,6 +153,12 @@ def forward(self, pos, batch): check_errors=True, box_size=self.box ) + #Remove (-1,-1) pairs + mask = neighbors[0] != -1 + neighbors = neighbors[:, mask] + distances = distances[mask] + distance_vecs = distance_vecs[mask,:] + if self.return_vecs: return neighbors, distances, distance_vecs else: diff --git a/torchmdnet/neighbors/neighbors_cuda.cu b/torchmdnet/neighbors/neighbors_cuda.cu index fb4cd2777..a6488aaaa 100644 --- a/torchmdnet/neighbors/neighbors_cuda.cu +++ b/torchmdnet/neighbors/neighbors_cuda.cu @@ -204,9 +204,6 @@ public: std::to_string(max_num_pairs_), " but found " + std::to_string(num_found_pairs)); } - neighbors.resize_({2, i_curr_pair[0].item()}); - deltas.resize_({i_curr_pair[0].item(), 3}); - distances.resize_(i_curr_pair[0].item()); ctx->save_for_backward({neighbors, deltas, distances}); ctx->saved_data["num_atoms"] = num_atoms; return {neighbors, deltas, distances}; diff --git a/torchmdnet/neighbors/neighbors_cuda_cell.cu b/torchmdnet/neighbors/neighbors_cuda_cell.cu index d8ca5a9cf..06ae31f81 100644 --- a/torchmdnet/neighbors/neighbors_cuda_cell.cu +++ b/torchmdnet/neighbors/neighbors_cuda_cell.cu @@ -509,9 +509,6 @@ public: std::to_string(max_num_pairs_), " but found " + std::to_string(num_found_pairs)); } - neighbors.resize_({2, i_curr_pair[0].item()}); - deltas.resize_({i_curr_pair[0].item(), 
3}); - distances.resize_(i_curr_pair[0].item()); ctx->save_for_backward({neighbors, deltas, distances}); ctx->saved_data["num_atoms"] = num_atoms; return {neighbors, deltas, distances}; From 9f42956556a472c14306317543879a3bad6e1e3c Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Thu, 4 May 2023 10:16:59 +0200 Subject: [PATCH 21/76] Return i,j and j,i pairs to mimic Distance (probably will put it behind an option) Add test to check identical outputs compared to Distance --- tests/test_neighbors.py | 83 +++++++++++++++++---- torchmdnet/neighbors/neighbors_cpu.cpp | 2 + torchmdnet/neighbors/neighbors_cuda.cu | 11 ++- torchmdnet/neighbors/neighbors_cuda_cell.cu | 11 ++- 4 files changed, 88 insertions(+), 19 deletions(-) diff --git a/tests/test_neighbors.py b/tests/test_neighbors.py index 32c630231..1c1b56c94 100644 --- a/tests/test_neighbors.py +++ b/tests/test_neighbors.py @@ -2,12 +2,35 @@ import pytest import torch import numpy as np -from torchmdnet.models.utils import DistanceCellList +from torchmdnet.models.utils import Distance, DistanceCellList def sort_neighbors(neighbors, deltas, distances): i_sorted = np.lexsort(neighbors) return neighbors[:, i_sorted], deltas[i_sorted], distances[i_sorted] + +def compute_ref_neighbors(pos, batch, loop, cutoff): + batch = batch.cpu() + n_atoms_per_batch = torch.bincount(batch) + n_batches = n_atoms_per_batch.shape[0] + cumsum = torch.cumsum(torch.cat([torch.tensor([0]), n_atoms_per_batch]), dim=0).cpu().detach().numpy() + ref_neighbors = np.concatenate([np.tril_indices(int(n_atoms_per_batch[i]), -1)+cumsum[i] for i in range(n_batches)], axis=1) + #add the upper triangle + ref_neighbors = np.concatenate([ref_neighbors, np.flip(ref_neighbors, axis=0)], axis=1) + if(loop): # Add self interactions + ilist=np.arange(cumsum[-1]) + ref_neighbors = np.concatenate([ref_neighbors, np.stack([ilist, ilist], axis=0)], axis=1) + pos_np = pos.cpu().detach().numpy() + ref_distances = np.linalg.norm(pos_np[ref_neighbors[0]] - 
pos_np[ref_neighbors[1]], axis=-1) + ref_distance_vecs = pos_np[ref_neighbors[0]] - pos_np[ref_neighbors[1]] + #remove pairs with distance > cutoff + mask = ref_distances < cutoff + ref_neighbors = ref_neighbors[:, mask] + ref_distance_vecs = ref_distance_vecs[mask] + ref_distances = ref_distances[mask] + ref_neighbors, ref_distance_vecs, ref_distances = sort_neighbors(ref_neighbors, ref_distance_vecs, ref_distances) + return ref_neighbors, ref_distance_vecs, ref_distances + @pytest.mark.parametrize("device", ["cpu", "cuda"]) @pytest.mark.parametrize("strategy", ["brute", "cell"]) @pytest.mark.parametrize("n_batches", [1, 2, 3, 4, 128]) @@ -26,19 +49,7 @@ def test_neighbors(device, strategy, n_batches, cutoff, loop): pos[0,:] = torch.zeros(3) pos[1,:] = torch.zeros(3) pos.requires_grad = True - ref_neighbors = np.concatenate([np.tril_indices(int(n_atoms_per_batch[i]), -1)+cumsum[i] for i in range(n_batches)], axis=1) - if(loop): # Add self interactions - ilist=np.arange(cumsum[-1]) - ref_neighbors = np.concatenate([ref_neighbors, np.stack([ilist, ilist], axis=0)], axis=1) - pos_np = pos.cpu().detach().numpy() - ref_distances = np.linalg.norm(pos_np[ref_neighbors[0]] - pos_np[ref_neighbors[1]], axis=-1) - ref_distance_vecs = pos_np[ref_neighbors[0]] - pos_np[ref_neighbors[1]] - ref_neighbors, ref_distance_vecs, ref_distances = sort_neighbors(ref_neighbors, ref_distance_vecs, ref_distances) - #remove pairs with distance > cutoff - mask = ref_distances < cutoff - ref_neighbors = ref_neighbors[:, mask] - ref_distance_vecs = ref_distance_vecs[mask] - ref_distances = ref_distances[mask] + ref_neighbors, ref_distance_vecs, ref_distances = compute_ref_neighbors(pos, batch, loop, cutoff) max_num_pairs = ref_neighbors.shape[1] box = torch.tensor([lbox, lbox, lbox]) @@ -48,9 +59,53 @@ def test_neighbors(device, strategy, n_batches, cutoff, loop): neighbors = neighbors.cpu().detach().numpy() distance_vecs = distance_vecs.cpu().detach().numpy() distances = 
distances.cpu().detach().numpy() + neighbors, distance_vecs, distances = sort_neighbors(neighbors, distance_vecs, distances) assert neighbors.shape == (2, max_num_pairs) assert distances.shape == (max_num_pairs,) assert distance_vecs.shape == (max_num_pairs, 3) + + assert np.allclose(neighbors, ref_neighbors) + assert np.allclose(distances, ref_distances) + assert np.allclose(distance_vecs, ref_distance_vecs) + + +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +@pytest.mark.parametrize("strategy", ["brute", "cell"]) +@pytest.mark.parametrize("n_batches", [1, 2, 3, 4]) +@pytest.mark.parametrize("cutoff", [0.1, 1.0, 1000.0]) +@pytest.mark.parametrize("loop", [True, False]) +def test_compatible_with_distance(device, strategy, n_batches, cutoff, loop): + if device == "cuda" and not torch.cuda.is_available(): + pytest.skip("CUDA not available") + torch.manual_seed(4321) + n_atoms_per_batch = torch.randint(3, 100, size=(n_batches,)) + batch = torch.repeat_interleave(torch.arange(n_batches, dtype=torch.long), n_atoms_per_batch).to(device) + cumsum = np.cumsum( np.concatenate([[0], n_atoms_per_batch])) + lbox=10.0 + pos = torch.rand(cumsum[-1], 3, device=device)*lbox + #Ensure there is at least one pair + pos[0,:] = torch.zeros(3) + pos[1,:] = torch.zeros(3) + pos.requires_grad = True + ref_neighbors, ref_distance_vecs, ref_distances = compute_ref_neighbors(pos, batch, loop, cutoff) + #Find the particle appearing in the most pairs + max_num_neighbors = torch.max(torch.bincount(torch.tensor(ref_neighbors[0,:]))) + d = Distance(cutoff_lower=0.0, cutoff_upper=cutoff, loop=loop, max_num_neighbors=max_num_neighbors, return_vecs=True) + ref_neighbors, ref_distances, ref_distance_vecs = d(pos, batch) + ref_neighbors = ref_neighbors.cpu().detach().numpy() + ref_distance_vecs = ref_distance_vecs.cpu().detach().numpy() + ref_distances = ref_distances.cpu().detach().numpy() + ref_neighbors, ref_distance_vecs, ref_distances = sort_neighbors(ref_neighbors, ref_distance_vecs, 
ref_distances) + + max_num_pairs = ref_neighbors.shape[1] + box = torch.tensor([lbox, lbox, lbox]) + nl = DistanceCellList(cutoff_lower=0.0, loop=loop, cutoff_upper=cutoff, max_num_pairs=max_num_pairs, strategy=strategy, box=box, return_vecs=True) + batch = batch.to(torch.int32).to(device) + neighbors, distances, distance_vecs = nl(pos, batch) + + neighbors = neighbors.cpu().detach().numpy() + distance_vecs = distance_vecs.cpu().detach().numpy() + distances = distances.cpu().detach().numpy() neighbors, distance_vecs, distances = sort_neighbors(neighbors, distance_vecs, distances) assert np.allclose(neighbors, ref_neighbors) assert np.allclose(distances, ref_distances) diff --git a/torchmdnet/neighbors/neighbors_cpu.cpp b/torchmdnet/neighbors/neighbors_cpu.cpp index 04d7802ee..fcae246d5 100644 --- a/torchmdnet/neighbors/neighbors_cpu.cpp +++ b/torchmdnet/neighbors/neighbors_cpu.cpp @@ -78,6 +78,8 @@ static tuple forward(const Tensor& positions, const Tens const Tensor mask_lower = distances_i >= cutoff_lower; const Tensor mask = mask_upper*mask_lower; neighbors_i = neighbors_i.index({Slice(), mask}) + current_offset; + //Add the transposed pairs + neighbors_i = torch::hstack({neighbors_i, torch::stack({neighbors_i[1], neighbors_i[0]})}); //Add self interaction using batch_i if(loop){ const Tensor batch_i_tensor = torch::tensor(batch_i, kInt32); diff --git a/torchmdnet/neighbors/neighbors_cuda.cu b/torchmdnet/neighbors/neighbors_cuda.cu index a6488aaaa..1146faae3 100644 --- a/torchmdnet/neighbors/neighbors_cuda.cu +++ b/torchmdnet/neighbors/neighbors_cuda.cu @@ -61,15 +61,22 @@ __global__ void forward_kernel(const int64_t num_all_pairs, const Accessor= cutoff_lower2) { - const int32_t i_pair = atomicAdd(&i_curr_pair[0], 1); + const int32_t i_pair = atomicAdd(&i_curr_pair[0], 2); // We handle too many neighbors outside of the kernel if (i_pair < neighbors.size(1)) { + const scalar_t r2 = sqrt_(distance2); neighbors[0][i_pair] = row; neighbors[1][i_pair] = column; 
deltas[i_pair][0] = delta_x; deltas[i_pair][1] = delta_y; deltas[i_pair][2] = delta_z; - distances[i_pair] = sqrt_(distance2); + distances[i_pair] = r2; + neighbors[0][i_pair+1] = column; + neighbors[1][i_pair+1] = row; + deltas[i_pair+1][0] = -delta_x; + deltas[i_pair+1][1] = -delta_y; + deltas[i_pair+1][2] = -delta_z; + distances[i_pair+1] = r2; } } } diff --git a/torchmdnet/neighbors/neighbors_cuda_cell.cu b/torchmdnet/neighbors/neighbors_cuda_cell.cu index 06ae31f81..7a0413f03 100644 --- a/torchmdnet/neighbors/neighbors_cuda_cell.cu +++ b/torchmdnet/neighbors/neighbors_cuda_cell.cu @@ -391,7 +391,7 @@ forward_kernel(const Accessor sorted_positions, Accessor neighbors, Accessor deltas, Accessor distances, Accessor i_curr_pair, int num_atoms, int num_pairs, scalar3 box_size, scalar_t cutoff_lower, - scalar_t cutoff_upper, bool loop) { + scalar_t cutoff_upper, bool loop, bool include_traspose) { // Each atom traverses the cells around it and finds the neighbors // Atoms for all batches are placed in the same cell list, but other batches are ignored while // traversing @@ -419,7 +419,10 @@ forward_kernel(const Accessor sorted_positions, if (j_batch > i_batch) // Particles are sorted by batch after cell, so we can break early here break; - if ((orj < ori and j_batch == i_batch) or (loop and orj == ori)) { + const bool includePair = + (j_batch == i_batch) and + ((orj != ori and (orj < ori or include_traspose)) or (loop and orj == ori)); + if (includePair) { const scalar3 pj = {sorted_positions[cur_j][0], sorted_positions[cur_j][1], sorted_positions[cur_j][2]}; @@ -491,13 +494,15 @@ public: box_size[2].item()}; const int threads = 128; const int blocks = (num_atoms + threads - 1) / threads; + bool include_traspose = true; forward_kernel<<>>( get_accessor(sorted_positions), get_accessor(hash_values), get_accessor(batch), get_accessor(cell_start), get_accessor(cell_end), get_accessor(neighbors), get_accessor(deltas), get_accessor(distances), get_accessor(i_curr_pair), - 
num_atoms, num_pairs, box_size_, cutoff_lower_, cutoff_upper_, loop); + num_atoms, num_pairs, box_size_, cutoff_lower_, cutoff_upper_, loop, + include_traspose); }); } // Synchronize and check the number of pairs found. Note that this is incompatible with CUDA From e6004dd4561fde580567521779f14468cd36f97f Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Thu, 4 May 2023 10:25:47 +0200 Subject: [PATCH 22/76] Use kLong for batch --- tests/test_neighbors.py | 4 +--- torchmdnet/neighbors/neighbors_cuda.cu | 6 +++--- torchmdnet/neighbors/neighbors_cuda_cell.cu | 24 ++++++++++----------- 3 files changed, 16 insertions(+), 18 deletions(-) diff --git a/tests/test_neighbors.py b/tests/test_neighbors.py index 1c1b56c94..67870de8a 100644 --- a/tests/test_neighbors.py +++ b/tests/test_neighbors.py @@ -41,7 +41,7 @@ def test_neighbors(device, strategy, n_batches, cutoff, loop): pytest.skip("CUDA not available") torch.manual_seed(4321) n_atoms_per_batch = torch.randint(3, 100, size=(n_batches,)) - batch = torch.repeat_interleave(torch.arange(n_batches, dtype=torch.int32), n_atoms_per_batch).to(device) + batch = torch.repeat_interleave(torch.arange(n_batches, dtype=torch.int64), n_atoms_per_batch).to(device) cumsum = np.cumsum( np.concatenate([[0], n_atoms_per_batch])) lbox=10.0 pos = torch.rand(cumsum[-1], 3, device=device)*lbox @@ -100,9 +100,7 @@ def test_compatible_with_distance(device, strategy, n_batches, cutoff, loop): max_num_pairs = ref_neighbors.shape[1] box = torch.tensor([lbox, lbox, lbox]) nl = DistanceCellList(cutoff_lower=0.0, loop=loop, cutoff_upper=cutoff, max_num_pairs=max_num_pairs, strategy=strategy, box=box, return_vecs=True) - batch = batch.to(torch.int32).to(device) neighbors, distances, distance_vecs = nl(pos, batch) - neighbors = neighbors.cpu().detach().numpy() distance_vecs = distance_vecs.cpu().detach().numpy() distances = distances.cpu().detach().numpy() diff --git a/torchmdnet/neighbors/neighbors_cuda.cu b/torchmdnet/neighbors/neighbors_cuda.cu index 
1146faae3..2f1ce5514 100644 --- a/torchmdnet/neighbors/neighbors_cuda.cu +++ b/torchmdnet/neighbors/neighbors_cuda.cu @@ -45,7 +45,7 @@ __device__ int32_t get_row(int index) { template __global__ void forward_kernel(const int64_t num_all_pairs, const Accessor positions, - const Accessor batch, scalar_t cutoff_lower2, + const Accessor batch, scalar_t cutoff_lower2, scalar_t cutoff_upper2, bool loop, Accessor i_curr_pair, Accessor neighbors, Accessor deltas, Accessor distances) { @@ -154,7 +154,7 @@ static void checkInput(const Tensor& positions, const Tensor& batch) { "Expected the 1st dimension size of \"batch\" to be the same as the 1st dimension " "size of \"positions\""); TORCH_CHECK(batch.is_contiguous(), "Expected \"batch\" to be contiguous"); - TORCH_CHECK(batch.dtype() == torch::kInt32, "Expected \"batch\" to have torch::kInt32 dtype"); + TORCH_CHECK(batch.dtype() == torch::kInt64, "Expected \"batch\" to be of type torch::kLong"); } class Autograd : public Function { @@ -186,7 +186,7 @@ public: TORCH_CHECK(cutoff_upper_ > 0, "Expected \"cutoff\" to be positive"); forward_kernel<<>>( num_all_pairs, get_accessor(positions), - get_accessor(batch), cutoff_lower_ * cutoff_lower_, + get_accessor(batch), cutoff_lower_ * cutoff_lower_, cutoff_upper_ * cutoff_upper_, loop, get_accessor(i_curr_pair), get_accessor(neighbors), get_accessor(deltas), get_accessor(distances)); diff --git a/torchmdnet/neighbors/neighbors_cuda_cell.cu b/torchmdnet/neighbors/neighbors_cuda_cell.cu index 7a0413f03..c00dd0e41 100644 --- a/torchmdnet/neighbors/neighbors_cuda_cell.cu +++ b/torchmdnet/neighbors/neighbors_cuda_cell.cu @@ -78,7 +78,7 @@ static void checkInput(const Tensor& positions, const Tensor& batch) { // Batch is a 1D tensor of size (N_atoms) // Batch is assumed to be sorted and starts at zero. 
// Batch is assumed to be contiguous - // Batch is assumed to be of type torch::kInt32 + // Batch is assumed to be of type torch::kLong // Batch is assumed to be non-negative // Each batch can have a different number of atoms TORCH_CHECK(positions.dim() == 2, "Expected \"positions\" to have two dimensions"); @@ -92,7 +92,7 @@ static void checkInput(const Tensor& positions, const Tensor& batch) { "Expected the 1st dimension size of \"batch\" to be the same as the 1st dimension " "size of \"positions\""); TORCH_CHECK(batch.is_contiguous(), "Expected \"batch\" to be contiguous"); - TORCH_CHECK(batch.dtype() == torch::kInt32, "Expected \"batch\" to be of type torch::kInt32"); + TORCH_CHECK(batch.dtype() == torch::kInt64, "Expected \"batch\" to be of type torch::kLong"); } /* @@ -219,7 +219,7 @@ __device__ int3 getPeriodicCell(int3 cell, int3 cell_dim) { // This hash is such that atoms in the same cell and batch have the same hash. template __global__ void assignHash(const Accessor positions, uint64_t* hash_keys, - Accessor hash_values, const Accessor batch, + Accessor hash_values, const Accessor batch, scalar3 box_size, scalar_t cutoff, int32_t num_atoms) { const int32_t i_atom = blockIdx.x * blockDim.x + threadIdx.x; if (i_atom >= num_atoms) @@ -278,7 +278,7 @@ static auto sortPositionsByHash(const Tensor& positions, const Tensor& batch, box_size[2].item()}; assignHash<<>>( get_accessor(positions), thrust::raw_pointer_cast(hash_keys.data()), - get_accessor(hash_values), get_accessor(batch), box_size_, + get_accessor(hash_values), get_accessor(batch), box_size_, cutoff_, num_atoms); }); thrust::device_ptr index_ptr(hash_values.data_ptr()); @@ -293,7 +293,7 @@ template __global__ void fillCellOffsetsD(const Accessor sorted_positions, const Accessor sorted_indices, Accessor cell_start, Accessor cell_end, - const Accessor batch, scalar3 box_size, + const Accessor batch, scalar3 box_size, scalar_t cutoff) { // Since positions are sorted by cell, for a given atom, if the 
previous atom is in a different // cell, then the current atom is the first atom in its cell We use this fact to fill the @@ -360,7 +360,7 @@ static auto fillCellOffsets(const Tensor& sorted_positions, const Tensor& sorted fillCellOffsetsD<<>>( get_accessor(sorted_positions), get_accessor(sorted_indices), get_accessor(cell_start), get_accessor(cell_end), - get_accessor(batch), box_size_, cutoff_); + get_accessor(batch), box_size_, cutoff_); }); return std::make_tuple(cell_start, cell_end); } @@ -386,7 +386,7 @@ __device__ int getNeighborCellIndex(int3 cell_i, int i, int3 cell_dim) { template __global__ void forward_kernel(const Accessor sorted_positions, - const Accessor original_index, const Accessor batch, + const Accessor original_index, const Accessor batch, const Accessor cell_start, const Accessor cell_end, Accessor neighbors, Accessor deltas, Accessor distances, Accessor i_curr_pair, int num_atoms, @@ -399,7 +399,7 @@ forward_kernel(const Accessor sorted_positions, if (i_atom >= num_atoms) return; const int ori = original_index[i_atom]; - const int i_batch = batch[ori]; + const auto i_batch = batch[ori]; const scalar3 pi = {sorted_positions[i_atom][0], sorted_positions[i_atom][1], sorted_positions[i_atom][2]}; const int3 cell_i = getCell(pi, box_size, cutoff_upper); @@ -413,9 +413,9 @@ forward_kernel(const Accessor sorted_positions, const int lastParticle = cell_end[icellj]; const int nincell = lastParticle - firstParticle; for (int j = 0; j < nincell; j++) { - int cur_j = j + firstParticle; - int orj = original_index[cur_j]; - int j_batch = batch[orj]; + const int cur_j = j + firstParticle; + const int orj = original_index[cur_j]; + const auto j_batch = batch[orj]; if (j_batch > i_batch) // Particles are sorted by batch after cell, so we can break early here break; @@ -497,7 +497,7 @@ public: bool include_traspose = true; forward_kernel<<>>( get_accessor(sorted_positions), - get_accessor(hash_values), get_accessor(batch), + get_accessor(hash_values), 
get_accessor(batch), get_accessor(cell_start), get_accessor(cell_end), get_accessor(neighbors), get_accessor(deltas), get_accessor(distances), get_accessor(i_curr_pair), From 493a9ae4198b7a4153db68566416369e89ecbc14 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Thu, 4 May 2023 12:08:39 +0200 Subject: [PATCH 23/76] Add options resize_to_fit (whether to trim -1,-1 pairs) and include_traspose (whether to include redundant pairs) --- tests/test_neighbors.py | 18 +-- torchmdnet/models/utils.py | 29 +++-- torchmdnet/neighbors/neighbors.cpp | 4 +- torchmdnet/neighbors/neighbors_cpu.cpp | 115 +++++++++++--------- torchmdnet/neighbors/neighbors_cuda.cu | 50 +++++---- torchmdnet/neighbors/neighbors_cuda_cell.cu | 38 +++---- 6 files changed, 142 insertions(+), 112 deletions(-) diff --git a/tests/test_neighbors.py b/tests/test_neighbors.py index 67870de8a..f6cfe95fe 100644 --- a/tests/test_neighbors.py +++ b/tests/test_neighbors.py @@ -9,14 +9,15 @@ def sort_neighbors(neighbors, deltas, distances): return neighbors[:, i_sorted], deltas[i_sorted], distances[i_sorted] -def compute_ref_neighbors(pos, batch, loop, cutoff): +def compute_ref_neighbors(pos, batch, loop, include_transpose, cutoff): batch = batch.cpu() n_atoms_per_batch = torch.bincount(batch) n_batches = n_atoms_per_batch.shape[0] cumsum = torch.cumsum(torch.cat([torch.tensor([0]), n_atoms_per_batch]), dim=0).cpu().detach().numpy() ref_neighbors = np.concatenate([np.tril_indices(int(n_atoms_per_batch[i]), -1)+cumsum[i] for i in range(n_batches)], axis=1) - #add the upper triangle - ref_neighbors = np.concatenate([ref_neighbors, np.flip(ref_neighbors, axis=0)], axis=1) + # add the upper triangle + if(include_transpose): + ref_neighbors = np.concatenate([ref_neighbors, np.flip(ref_neighbors, axis=0)], axis=1) if(loop): # Add self interactions ilist=np.arange(cumsum[-1]) ref_neighbors = np.concatenate([ref_neighbors, np.stack([ilist, ilist], axis=0)], axis=1) @@ -36,7 +37,8 @@ def compute_ref_neighbors(pos, batch, 
loop, cutoff): @pytest.mark.parametrize("n_batches", [1, 2, 3, 4, 128]) @pytest.mark.parametrize("cutoff", [0.1, 1.0, 1000.0]) @pytest.mark.parametrize("loop", [True, False]) -def test_neighbors(device, strategy, n_batches, cutoff, loop): +@pytest.mark.parametrize("include_transpose", [True, False]) +def test_neighbors(device, strategy, n_batches, cutoff, loop, include_transpose): if device == "cuda" and not torch.cuda.is_available(): pytest.skip("CUDA not available") torch.manual_seed(4321) @@ -49,11 +51,11 @@ def test_neighbors(device, strategy, n_batches, cutoff, loop): pos[0,:] = torch.zeros(3) pos[1,:] = torch.zeros(3) pos.requires_grad = True - ref_neighbors, ref_distance_vecs, ref_distances = compute_ref_neighbors(pos, batch, loop, cutoff) + ref_neighbors, ref_distance_vecs, ref_distances = compute_ref_neighbors(pos, batch, loop, include_transpose, cutoff) max_num_pairs = ref_neighbors.shape[1] box = torch.tensor([lbox, lbox, lbox]) - nl = DistanceCellList(cutoff_lower=0.0, loop=loop, cutoff_upper=cutoff, max_num_pairs=max_num_pairs, strategy=strategy, box=box, return_vecs=True) + nl = DistanceCellList(cutoff_lower=0.0, loop=loop, cutoff_upper=cutoff, max_num_pairs=max_num_pairs, strategy=strategy, box=box, return_vecs=True, include_transpose=include_transpose) batch.to(device) neighbors, distances, distance_vecs = nl(pos, batch) neighbors = neighbors.cpu().detach().numpy() @@ -87,7 +89,7 @@ def test_compatible_with_distance(device, strategy, n_batches, cutoff, loop): pos[0,:] = torch.zeros(3) pos[1,:] = torch.zeros(3) pos.requires_grad = True - ref_neighbors, ref_distance_vecs, ref_distances = compute_ref_neighbors(pos, batch, loop, cutoff) + ref_neighbors, ref_distance_vecs, ref_distances = compute_ref_neighbors(pos, batch, loop, True, cutoff) #Find the particle appearing in the most pairs max_num_neighbors = torch.max(torch.bincount(torch.tensor(ref_neighbors[0,:]))) d = Distance(cutoff_lower=0.0, cutoff_upper=cutoff, loop=loop, 
max_num_neighbors=max_num_neighbors, return_vecs=True) @@ -99,7 +101,7 @@ def test_compatible_with_distance(device, strategy, n_batches, cutoff, loop): max_num_pairs = ref_neighbors.shape[1] box = torch.tensor([lbox, lbox, lbox]) - nl = DistanceCellList(cutoff_lower=0.0, loop=loop, cutoff_upper=cutoff, max_num_pairs=max_num_pairs, strategy=strategy, box=box, return_vecs=True) + nl = DistanceCellList(cutoff_lower=0.0, loop=loop, cutoff_upper=cutoff, max_num_pairs=max_num_pairs, strategy=strategy, box=box, return_vecs=True, include_transpose=True) neighbors, distances, distance_vecs = nl(pos, batch) neighbors = neighbors.cpu().detach().numpy() distance_vecs = distance_vecs.cpu().detach().numpy() diff --git a/torchmdnet/models/utils.py b/torchmdnet/models/utils.py index 89c34d1e7..726c11192 100644 --- a/torchmdnet/models/utils.py +++ b/torchmdnet/models/utils.py @@ -85,7 +85,9 @@ def __init__( max_num_pairs=32, return_vecs=False, loop=False, - strategy="cell", + strategy="brute", + include_transpose=True, + resize_to_fit=True, box=None ): super(DistanceCellList, self).__init__() @@ -105,6 +107,10 @@ def __init__( Size of the box shape (3,) or None loop : bool Whether to include self-interactions. + include_transpose : bool + Whether to include the transpose of the neighbor list. + resize_to_fit : bool + Whether to resize the neighbor list to the actual number of pairs found. return_vecs : bool Whether to return the distance vectors. @@ -116,6 +122,8 @@ def __init__( self.box = box self.loop = loop self.return_vecs = return_vecs + self.include_transpose = include_transpose + self.resize_to_fit = resize_to_fit #Default the box to 3 times the cutoff if self.box is None and self.strategy == "cell": self.box = torch.tensor([cutoff_upper * 3] * 3) @@ -133,13 +141,16 @@ def forward(self, pos, batch): ------- neighbors : torch.Tensor List of neighbors for each atom in the batch. 
- shape (2, num_found_pairs) + shape (2, num_found_pairs or max_num_pairs) distances : torch.Tensor List of distances for each atom in the batch. - shape (num_found_pairs,) + shape (num_found_pairs or max_num_pairs,) distance_vecs : torch.Tensor List of distance vectors for each atom in the batch. - shape (num_found_pairs, 3) + shape (num_found_pairs or max_num_pairs, 3) + + If resize_to_fit is True, the tensors will be trimmed to the actual number of pairs found. + otherwise the tensors will have size max_num_pairs, with neighbor pairs (-1, -1) at the end. """ function = get_neighbor_pairs if self.strategy == "brute" else get_neighbor_pairs_cell @@ -151,13 +162,15 @@ def forward(self, pos, batch): batch=batch, max_num_pairs=self.max_num_pairs, check_errors=True, + include_transpose=self.include_transpose, box_size=self.box ) #Remove (-1,-1) pairs - mask = neighbors[0] != -1 - neighbors = neighbors[:, mask] - distances = distances[mask] - distance_vecs = distance_vecs[mask,:] + if self.resize_to_fit: + mask = neighbors[0] != -1 + neighbors = neighbors[:, mask] + distances = distances[mask] + distance_vecs = distance_vecs[mask,:] if self.return_vecs: return neighbors, distances, distance_vecs diff --git a/torchmdnet/neighbors/neighbors.cpp b/torchmdnet/neighbors/neighbors.cpp index 7813f81f0..64c4aefb5 100644 --- a/torchmdnet/neighbors/neighbors.cpp +++ b/torchmdnet/neighbors/neighbors.cpp @@ -1,6 +1,6 @@ #include TORCH_LIBRARY(neighbors, m) { - m.def("get_neighbor_pairs(Tensor positions, Tensor batch, Tensor box_size, Scalar cutoff_lower, Scalar cutoff_upper, Scalar max_num_pairs, bool loop, bool check_errors) -> (Tensor neighbors, Tensor distances, Tensor distance_vecs)"); - m.def("get_neighbor_pairs_cell(Tensor positions, Tensor batch, Tensor box_size, Scalar cutoff_lower, Scalar cutoff_upper, Scalar max_num_pairs, bool loop, bool check_errors) -> (Tensor neighbors, Tensor distances, Tensor distance_vecs)"); + m.def("get_neighbor_pairs(Tensor positions, Tensor 
batch, Tensor box_size, Scalar cutoff_lower, Scalar cutoff_upper, Scalar max_num_pairs, bool loop, bool include_transpose, bool check_errors) -> (Tensor neighbors, Tensor distances, Tensor distance_vecs)"); + m.def("get_neighbor_pairs_cell(Tensor positions, Tensor batch, Tensor box_size, Scalar cutoff_lower, Scalar cutoff_upper, Scalar max_num_pairs, bool loop, bool include_transpose, bool check_errors) -> (Tensor neighbors, Tensor distances, Tensor distance_vecs)"); } diff --git a/torchmdnet/neighbors/neighbors_cpu.cpp b/torchmdnet/neighbors/neighbors_cpu.cpp index fcae246d5..807c29257 100644 --- a/torchmdnet/neighbors/neighbors_cpu.cpp +++ b/torchmdnet/neighbors/neighbors_cpu.cpp @@ -2,25 +2,28 @@ #include using std::tuple; +using torch::arange; using torch::div; +using torch::frobenius_norm; using torch::full; +using torch::hstack; using torch::index_select; -using torch::indexing::Slice; -using torch::arange; -using torch::frobenius_norm; using torch::kInt32; -using torch::Scalar; -using torch::hstack; -using torch::vstack; -using torch::Tensor; using torch::outer; using torch::round; +using torch::Scalar; +using torch::Tensor; +using torch::vstack; +using torch::indexing::Slice; -static tuple forward(const Tensor& positions, const Tensor& batch, const Tensor &box_size, - const Scalar& cutoff_lower, const Scalar& cutoff_upper, - const Scalar& max_num_pairs, bool loop, bool checkErrors) { +static tuple forward(const Tensor& positions, const Tensor& batch, + const Tensor& box_size, const Scalar& cutoff_lower, + const Scalar& cutoff_upper, + const Scalar& max_num_pairs, bool loop, + bool include_transpose, bool checkErrors) { TORCH_CHECK(positions.dim() == 2, "Expected \"positions\" to have two dimensions"); - TORCH_CHECK(positions.size(0) > 0, "Expected the 1nd dimension size of \"positions\" to be more than 0"); + TORCH_CHECK(positions.size(0) > 0, + "Expected the 1nd dimension size of \"positions\" to be more than 0"); TORCH_CHECK(positions.size(1) == 3, 
"Expected the 2nd dimension size of \"positions\" to be 3"); TORCH_CHECK(positions.is_contiguous(), "Expected \"positions\" to be contiguous"); @@ -41,12 +44,14 @@ static tuple forward(const Tensor& positions, const Tens TORCH_CHECK(v[0][0] >= 2 * c, "Invalid box vectors: box_vectors[0][0] < 2*cutoff"); TORCH_CHECK(v[1][1] >= 2 * c, "Invalid box vectors: box_vectors[1][1] < 2*cutoff"); TORCH_CHECK(v[2][2] >= 2 * c, "Invalid box vectors: box_vectors[2][2] < 2*cutoff"); - TORCH_CHECK(v[0][0] >= 2 * v[1][0], "Invalid box vectors: box_vectors[0][0] < 2*box_vectors[1][0]"); - TORCH_CHECK(v[0][0] >= 2 * v[2][0], "Invalid box vectors: box_vectors[0][0] < 2*box_vectors[1][0]"); - TORCH_CHECK(v[1][1] >= 2 * v[2][1], "Invalid box vectors: box_vectors[1][1] < 2*box_vectors[2][1]"); + TORCH_CHECK(v[0][0] >= 2 * v[1][0], + "Invalid box vectors: box_vectors[0][0] < 2*box_vectors[1][0]"); + TORCH_CHECK(v[0][0] >= 2 * v[2][0], + "Invalid box vectors: box_vectors[0][0] < 2*box_vectors[1][0]"); + TORCH_CHECK(v[1][1] >= 2 * v[2][1], + "Invalid box vectors: box_vectors[1][1] < 2*box_vectors[2][1]"); } - TORCH_CHECK(max_num_pairs.toLong() > 0, - "Expected \"max_num_neighbors\" to be positive"); + TORCH_CHECK(max_num_pairs.toLong() > 0, "Expected \"max_num_neighbors\" to be positive"); const int n_atoms = positions.size(0); const int n_batches = batch[n_atoms - 1].item() + 1; int current_offset = 0; @@ -55,43 +60,49 @@ static tuple forward(const Tensor& positions, const Tens Tensor neighbors = torch::empty({0}, positions.options().dtype(kInt32)); Tensor distances = torch::empty({0}, positions.options()); Tensor deltas = torch::empty({0}, positions.options()); - for(int i = 0; i < n_batches; i++){ - batch_i.clear(); - for(int j = current_offset; j < n_atoms; j++){ - if(batch[j].item() == i){ - batch_i.push_back(j); - } - else{ - break; - } - } - const int n_atoms_i = batch_i.size(); - Tensor positions_i = index_select(positions, 0, torch::tensor(batch_i, kInt32)); - Tensor indices_i = 
arange(0, n_atoms_i * (n_atoms_i - 1) / 2, positions.options().dtype(kInt32)); - Tensor rows_i = (((8 * indices_i + 1).sqrt() + 1) / 2).floor().to(kInt32); - rows_i -= (rows_i * (rows_i - 1) > 2 * indices_i).to(kInt32); - Tensor columns_i = indices_i - div(rows_i * (rows_i - 1), 2, "floor"); - Tensor neighbors_i = vstack({rows_i, columns_i}); - Tensor deltas_i = index_select(positions_i, 0, rows_i) - index_select(positions_i, 0, columns_i); - Tensor distances_i = frobenius_norm(deltas_i, 1); - const Tensor mask_upper = distances_i <= cutoff_upper; - const Tensor mask_lower = distances_i >= cutoff_lower; - const Tensor mask = mask_upper*mask_lower; - neighbors_i = neighbors_i.index({Slice(), mask}) + current_offset; - //Add the transposed pairs - neighbors_i = torch::hstack({neighbors_i, torch::stack({neighbors_i[1], neighbors_i[0]})}); - //Add self interaction using batch_i - if(loop){ - const Tensor batch_i_tensor = torch::tensor(batch_i, kInt32); - neighbors_i = torch::hstack({neighbors_i, torch::stack({batch_i_tensor, batch_i_tensor})}); - } - n_pairs += neighbors_i.size(1); - TORCH_CHECK(n_pairs >= 0, "The maximum number of pairs has been exceed! 
Increase \"max_num_neighbors\""); - neighbors = torch::hstack({neighbors, neighbors_i}); - current_offset += n_atoms_i; + for (int i = 0; i < n_batches; i++) { + batch_i.clear(); + for (int j = current_offset; j < n_atoms; j++) { + if (batch[j].item() == i) { + batch_i.push_back(j); + } else { + break; + } + } + const int n_atoms_i = batch_i.size(); + Tensor positions_i = index_select(positions, 0, torch::tensor(batch_i, kInt32)); + Tensor indices_i = + arange(0, n_atoms_i * (n_atoms_i - 1) / 2, positions.options().dtype(kInt32)); + Tensor rows_i = (((8 * indices_i + 1).sqrt() + 1) / 2).floor().to(kInt32); + rows_i -= (rows_i * (rows_i - 1) > 2 * indices_i).to(kInt32); + Tensor columns_i = indices_i - div(rows_i * (rows_i - 1), 2, "floor"); + Tensor neighbors_i = vstack({rows_i, columns_i}); + Tensor deltas_i = + index_select(positions_i, 0, rows_i) - index_select(positions_i, 0, columns_i); + Tensor distances_i = frobenius_norm(deltas_i, 1); + const Tensor mask_upper = distances_i <= cutoff_upper; + const Tensor mask_lower = distances_i >= cutoff_lower; + const Tensor mask = mask_upper * mask_lower; + neighbors_i = neighbors_i.index({Slice(), mask}) + current_offset; + // Add the transposed pairs + if (include_transpose) { + neighbors_i = + torch::hstack({neighbors_i, torch::stack({neighbors_i[1], neighbors_i[0]})}); + } + // Add self interaction using batch_i + if (loop) { + const Tensor batch_i_tensor = torch::tensor(batch_i, kInt32); + neighbors_i = + torch::hstack({neighbors_i, torch::stack({batch_i_tensor, batch_i_tensor})}); + } + n_pairs += neighbors_i.size(1); + TORCH_CHECK(n_pairs >= 0, + "The maximum number of pairs has been exceed! 
Increase \"max_num_neighbors\""); + neighbors = torch::hstack({neighbors, neighbors_i}); + current_offset += n_atoms_i; } - if(n_batches > 1){ - neighbors = torch::cat(neighbors,0).to(kInt32); + if (n_batches > 1) { + neighbors = torch::cat(neighbors, 0).to(kInt32); } deltas = index_select(positions, 0, neighbors[0]) - index_select(positions, 0, neighbors[1]); distances = frobenius_norm(deltas, 1); diff --git a/torchmdnet/neighbors/neighbors_cuda.cu b/torchmdnet/neighbors/neighbors_cuda.cu index 2f1ce5514..e2df6ce55 100644 --- a/torchmdnet/neighbors/neighbors_cuda.cu +++ b/torchmdnet/neighbors/neighbors_cuda.cu @@ -46,9 +46,9 @@ __device__ int32_t get_row(int index) { template __global__ void forward_kernel(const int64_t num_all_pairs, const Accessor positions, const Accessor batch, scalar_t cutoff_lower2, - scalar_t cutoff_upper2, bool loop, Accessor i_curr_pair, - Accessor neighbors, Accessor deltas, - Accessor distances) { + scalar_t cutoff_upper2, bool loop, bool include_transpose, + Accessor i_curr_pair, Accessor neighbors, + Accessor deltas, Accessor distances) { const int64_t index = blockIdx.x * blockDim.x + threadIdx.x; if (index >= num_all_pairs) return; @@ -61,22 +61,24 @@ __global__ void forward_kernel(const int64_t num_all_pairs, const Accessor= cutoff_lower2) { - const int32_t i_pair = atomicAdd(&i_curr_pair[0], 2); + const int32_t i_pair = atomicAdd(&i_curr_pair[0], include_transpose ? 
2 : 1); // We handle too many neighbors outside of the kernel if (i_pair < neighbors.size(1)) { - const scalar_t r2 = sqrt_(distance2); + const scalar_t r2 = sqrt_(distance2); neighbors[0][i_pair] = row; neighbors[1][i_pair] = column; deltas[i_pair][0] = delta_x; deltas[i_pair][1] = delta_y; deltas[i_pair][2] = delta_z; distances[i_pair] = r2; - neighbors[0][i_pair+1] = column; - neighbors[1][i_pair+1] = row; - deltas[i_pair+1][0] = -delta_x; - deltas[i_pair+1][1] = -delta_y; - deltas[i_pair+1][2] = -delta_z; - distances[i_pair+1] = r2; + if (include_transpose) { + neighbors[0][i_pair + 1] = column; + neighbors[1][i_pair + 1] = row; + deltas[i_pair + 1][0] = -delta_x; + deltas[i_pair + 1][1] = -delta_y; + deltas[i_pair + 1][2] = -delta_z; + distances[i_pair + 1] = r2; + } } } } @@ -161,7 +163,8 @@ class Autograd : public Function { public: static tensor_list forward(AutogradContext* ctx, const Tensor& positions, const Tensor& batch, const Scalar& cutoff_lower, const Scalar& cutoff_upper, - const Scalar& max_num_pairs, bool loop, bool checkErrors) { + const Scalar& max_num_pairs, bool loop, bool include_transpose, + bool checkErrors) { checkInput(positions, batch); const auto max_num_pairs_ = max_num_pairs.toLong(); TORCH_CHECK(max_num_pairs_ > 0, "Expected \"max_num_neighbors\" to be positive"); @@ -187,9 +190,9 @@ public: forward_kernel<<>>( num_all_pairs, get_accessor(positions), get_accessor(batch), cutoff_lower_ * cutoff_lower_, - cutoff_upper_ * cutoff_upper_, loop, get_accessor(i_curr_pair), - get_accessor(neighbors), get_accessor(deltas), - get_accessor(distances)); + cutoff_upper_ * cutoff_upper_, loop, include_transpose, + get_accessor(i_curr_pair), get_accessor(neighbors), + get_accessor(deltas), get_accessor(distances)); if (loop) { const int64_t num_threads = 128; const int64_t num_blocks = @@ -245,12 +248,13 @@ public: }; TORCH_LIBRARY_IMPL(neighbors, AutogradCUDA, m) { - m.impl("get_neighbor_pairs", [](const Tensor& positions, const Tensor& batch, - 
const Tensor& box_size, const Scalar& cutoff_lower, - const Scalar& cutoff_upper, const Scalar& max_num_pairs, - bool loop, bool checkErrors) { - const tensor_list results = Autograd::apply(positions, batch, cutoff_lower, cutoff_upper, - max_num_pairs, loop, checkErrors); - return std::make_tuple(results[0], results[1], results[2]); - }); + m.impl("get_neighbor_pairs", + [](const Tensor& positions, const Tensor& batch, const Tensor& box_size, + const Scalar& cutoff_lower, const Scalar& cutoff_upper, const Scalar& max_num_pairs, + bool loop, bool include_transpose, bool checkErrors) { + const tensor_list results = + Autograd::apply(positions, batch, cutoff_lower, cutoff_upper, max_num_pairs, + loop, include_transpose, checkErrors); + return std::make_tuple(results[0], results[1], results[2]); + }); } diff --git a/torchmdnet/neighbors/neighbors_cuda_cell.cu b/torchmdnet/neighbors/neighbors_cuda_cell.cu index c00dd0e41..ff66a755c 100644 --- a/torchmdnet/neighbors/neighbors_cuda_cell.cu +++ b/torchmdnet/neighbors/neighbors_cuda_cell.cu @@ -391,7 +391,7 @@ forward_kernel(const Accessor sorted_positions, Accessor neighbors, Accessor deltas, Accessor distances, Accessor i_curr_pair, int num_atoms, int num_pairs, scalar3 box_size, scalar_t cutoff_lower, - scalar_t cutoff_upper, bool loop, bool include_traspose) { + scalar_t cutoff_upper, bool loop, bool include_transpose) { // Each atom traverses the cells around it and finds the neighbors // Atoms for all batches are placed in the same cell list, but other batches are ignored while // traversing @@ -421,7 +421,7 @@ forward_kernel(const Accessor sorted_positions, break; const bool includePair = (j_batch == i_batch) and - ((orj != ori and (orj < ori or include_traspose)) or (loop and orj == ori)); + ((orj != ori and (orj < ori or include_transpose)) or (loop and orj == ori)); if (includePair) { const scalar3 pj = {sorted_positions[cur_j][0], sorted_positions[cur_j][1], @@ -443,12 +443,12 @@ forward_kernel(const Accessor 
sorted_positions, deltas[i_pair][1] = dy; deltas[i_pair][2] = dz; distances[i_pair] = sqrt_(distance2); - } - } - } // endfor - } // endif - } // endfor - } + } // endif + } // endif + } // endfor + } // endif + } // endfor + } // endfor } class Autograd : public Function { @@ -456,7 +456,7 @@ public: static tensor_list forward(AutogradContext* ctx, const Tensor& positions, const Tensor& batch, const Tensor& box_size, const Scalar& cutoff_lower, const Scalar& cutoff_upper, const Scalar& max_num_pairs, bool loop, - bool checkErrors) { + bool include_transpose, bool checkErrors) { // The algorithm for the cell list construction can be summarized in three separate steps: // 1. Hash (label) the particles according to the cell (bin) they lie in. // 2. Sort the particles and hashes using the hashes as the ordering label @@ -494,7 +494,6 @@ public: box_size[2].item()}; const int threads = 128; const int blocks = (num_atoms + threads - 1) / threads; - bool include_traspose = true; forward_kernel<<>>( get_accessor(sorted_positions), get_accessor(hash_values), get_accessor(batch), @@ -502,7 +501,7 @@ public: get_accessor(neighbors), get_accessor(deltas), get_accessor(distances), get_accessor(i_curr_pair), num_atoms, num_pairs, box_size_, cutoff_lower_, cutoff_upper_, loop, - include_traspose); + include_transpose); }); } // Synchronize and check the number of pairs found. 
Note that this is incompatible with CUDA @@ -548,12 +547,13 @@ public: }; TORCH_LIBRARY_IMPL(neighbors, AutogradCUDA, m) { - m.impl("get_neighbor_pairs_cell", [](const Tensor& positions, const Tensor& batch, - const Tensor& box_size, const Scalar& cutoff_lower, - const Scalar& cutoff_upper, const Scalar& max_num_pairs, - bool loop, bool checkErrors) { - const tensor_list results = Autograd::apply(positions, batch, box_size, cutoff_lower, - cutoff_upper, max_num_pairs, loop, checkErrors); - return std::make_tuple(results[0], results[1], results[2]); - }); + m.impl("get_neighbor_pairs_cell", + [](const Tensor& positions, const Tensor& batch, const Tensor& box_size, + const Scalar& cutoff_lower, const Scalar& cutoff_upper, const Scalar& max_num_pairs, + bool loop, bool include_transpose, bool checkErrors) { + const tensor_list results = + Autograd::apply(positions, batch, box_size, cutoff_lower, cutoff_upper, + max_num_pairs, loop, include_transpose, checkErrors); + return std::make_tuple(results[0], results[1], results[2]); + }); } From b066b31574fb06f0bc21c72fece1a0c89d4baab0 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Fri, 5 May 2023 21:47:11 +0200 Subject: [PATCH 24/76] Add Periodic Boundary Conditions (CPU and brute force support triclinic boxes, cell list only rectangular) Some optimizations overall Update tests Update benchmark --- benchmarks/neighbors.py | 77 ++++++++++++--- tests/test_neighbors.py | 81 ++++++++++++++-- torchmdnet/models/utils.py | 22 +++-- torchmdnet/neighbors/neighbors.cpp | 4 +- torchmdnet/neighbors/neighbors_cpu.cpp | 30 ++++-- torchmdnet/neighbors/neighbors_cuda.cu | 90 +++++++++-------- torchmdnet/neighbors/neighbors_cuda_cell.cu | 102 ++++++++++++-------- 7 files changed, 282 insertions(+), 124 deletions(-) diff --git a/benchmarks/neighbors.py b/benchmarks/neighbors.py index 13a2c2c68..6d468a879 100644 --- a/benchmarks/neighbors.py +++ b/benchmarks/neighbors.py @@ -1,7 +1,9 @@ import os import torch import numpy as np -from 
torchmdnet.models.utils import DistanceCellList +from torchmdnet.models.utils import Distance, DistanceCellList + + def benchmark_neighbors(device, strategy, n_batches, total_num_particles, mean_num_neighbors=32): @@ -22,34 +24,48 @@ def benchmark_neighbors(device, strategy, n_batches, total_num_particles, mean_n float Average time per batch in seconds. """ - density = 0.5; + density = 0.7; + torch.random.manual_seed(12344) num_particles = total_num_particles // n_batches expected_num_neighbors = mean_num_neighbors cutoff = np.cbrt(3 * expected_num_neighbors / (4 * np.pi * density)); - n_atoms_per_batch = torch.randint(num_particles-10, num_particles+10, size=(n_batches,)) + n_atoms_per_batch = torch.randint(int(num_particles/2), int(num_particles*2), size=(n_batches,),device="cpu") #Fix so that the total number of particles is correct. Special care if the difference is negative difference = total_num_particles - n_atoms_per_batch.sum() - if n_atoms_per_batch[-1] + difference > 0: - n_atoms_per_batch[-1] += difference + if difference > 0: + while difference > 0: + i = np.random.randint(0, n_batches) + n_atoms_per_batch[i] += 1 + difference -= 1 + else: while difference < 0: i = np.random.randint(0, n_batches) - if n_atoms_per_batch[i] > 2: + if n_atoms_per_batch[i] > num_particles: n_atoms_per_batch[i] -= 1 difference += 1 lbox = np.cbrt(num_particles / density); - batch = torch.repeat_interleave(torch.arange(n_batches, dtype=torch.int32), n_atoms_per_batch).to(device) + batch = torch.repeat_interleave(torch.arange(n_batches, dtype=torch.int64), n_atoms_per_batch).to(device) cumsum = np.cumsum( np.concatenate([[0], n_atoms_per_batch])) - pos = torch.rand(cumsum[-1], 3, device=device)*lbox - max_num_pairs = (expected_num_neighbors * n_atoms_per_batch.sum()).item() - nl = DistanceCellList(cutoff_upper=cutoff, max_num_pairs=max_num_pairs, strategy=strategy, box=torch.Tensor([lbox, lbox, lbox])) + pos = torch.rand(cumsum[-1], 3, device="cpu").to(device)*lbox + if 
strategy != 'distance': + max_num_pairs = (expected_num_neighbors * n_atoms_per_batch.sum()).item()*2 + box = torch.eye(3, device=device)*lbox + nl = DistanceCellList(cutoff_upper=cutoff, max_num_pairs=max_num_pairs, strategy=strategy, box=box) + else: + max_num_neighbors = int(expected_num_neighbors*5) + nl = Distance(loop=False, cutoff_lower=0.0, cutoff_upper=cutoff, max_num_neighbors=max_num_neighbors) #Warmup for i in range(10): neighbors, distances, distance_vecs = nl(pos, batch) + #print + print("Batch with largest number of atoms: {}".format(int(n_atoms_per_batch.max()))) + print("Batch with smallest number of atoms: {}".format(int(n_atoms_per_batch.min()))) + print("Number of pairs: {}, Number of particles: {}".format(int(neighbors.shape[1]), int(n_atoms_per_batch.to(torch.double). + sum().item()))) if device == 'cuda': torch.cuda.synchronize() - #Benchmark using torch profiler - nruns = 100 + nruns = 10 if device == 'cuda': torch.cuda.synchronize() @@ -66,20 +82,49 @@ def benchmark_neighbors(device, strategy, n_batches, total_num_particles, mean_n return (start.elapsed_time(end) / nruns) if __name__ == '__main__': - n_particles = 100000 + n_particles = 32767 mean_num_neighbors = min(n_particles, 16); print("Benchmarking neighbor list generation for {} particles with {} neighbors on average".format(n_particles, mean_num_neighbors)) - for strategy in ['brute', 'cell']: + results = {} + batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512] + for strategy in ['brute', 'cell', 'distance']: print("Strategy: {}".format(strategy)) print("--------") print("{:<10} {:<10}".format("Batch size", "Time (ms)")) print("{:<10} {:<10}".format("----------", "---------")) - #Loop over different number of batches - for n_batches in [1, 10, 100, 1000]: + #Loop over different number of batches, random + for n_batches in batch_sizes: time = benchmark_neighbors(device='cuda', strategy=strategy, n_batches=n_batches, total_num_particles=n_particles, 
mean_num_neighbors=mean_num_neighbors ) + #Store results in a dictionary + results[strategy, n_batches] = time print("{:<10} {:<10.2f}".format(n_batches, time)) + print("\n") + print("Summary") + print("-------") + print("{:<10} {:<21} {:<18} {:<10}".format("Batch size", "Brute(ms)", "Cell(ms)", "Distance(ms)")) + print("{:<10} {:<21} {:<18} {:<10}".format("----------", "---------", "--------", "-----------")) + #Print a column per strategy, show speedup over Distance in parenthesis + for n_batches in batch_sizes: + base = results['distance', n_batches] + print("{:<10} {:<4.2f} x{:<14.2f} {:<4.2f} x{:<14.2f} {:<10.2f}".format(n_batches, + results['brute', n_batches], + base/results['brute', n_batches], + results['cell', n_batches], + base/results['cell', n_batches], + results['distance', n_batches])) + + #Print a second table showing time per atom, show in ns + print("\n") + print("Time per atom") + print("{:<10} {:<10} {:<10} {:<10}".format("Batch size", "Brute(ns)", "Cell(ns)", "Distance(ns)")) + print("{:<10} {:<10} {:<10} {:<10}".format("----------", "---------", "--------", "-----------")) + for n_batches in batch_sizes: + print("{:<10} {:<10.2f} {:<10.2f} {:<10.2f}".format(n_batches, + results['brute', n_batches]/n_particles*1e6, + results['cell', n_batches]/n_particles*1e6, + results['distance', n_batches]/n_particles*1e6)) diff --git a/tests/test_neighbors.py b/tests/test_neighbors.py index f6cfe95fe..d46031d47 100644 --- a/tests/test_neighbors.py +++ b/tests/test_neighbors.py @@ -9,7 +9,18 @@ def sort_neighbors(neighbors, deltas, distances): return neighbors[:, i_sorted], deltas[i_sorted], distances[i_sorted] -def compute_ref_neighbors(pos, batch, loop, include_transpose, cutoff): +def apply_pbc(deltas, box_vectors): + if(box_vectors is None): + return deltas + else: + ref_vectors = box_vectors.cpu().detach().numpy() + deltas -= np.outer(np.round(deltas[:,2]/ref_vectors[2,2]), ref_vectors[2]) + deltas -= np.outer(np.round(deltas[:,1]/ref_vectors[1,1]), 
ref_vectors[1]) + deltas -= np.outer(np.round(deltas[:,0]/ref_vectors[0,0]), ref_vectors[0]) + return deltas + + +def compute_ref_neighbors(pos, batch, loop, include_transpose, cutoff, box_vectors): batch = batch.cpu() n_atoms_per_batch = torch.bincount(batch) n_batches = n_atoms_per_batch.shape[0] @@ -22,8 +33,9 @@ def compute_ref_neighbors(pos, batch, loop, include_transpose, cutoff): ilist=np.arange(cumsum[-1]) ref_neighbors = np.concatenate([ref_neighbors, np.stack([ilist, ilist], axis=0)], axis=1) pos_np = pos.cpu().detach().numpy() - ref_distances = np.linalg.norm(pos_np[ref_neighbors[0]] - pos_np[ref_neighbors[1]], axis=-1) - ref_distance_vecs = pos_np[ref_neighbors[0]] - pos_np[ref_neighbors[1]] + ref_distance_vecs = apply_pbc(pos_np[ref_neighbors[0]] - pos_np[ref_neighbors[1]], box_vectors) + ref_distances = np.linalg.norm(ref_distance_vecs, axis=-1) + #remove pairs with distance > cutoff mask = ref_distances < cutoff ref_neighbors = ref_neighbors[:, mask] @@ -35,12 +47,15 @@ def compute_ref_neighbors(pos, batch, loop, include_transpose, cutoff): @pytest.mark.parametrize("device", ["cpu", "cuda"]) @pytest.mark.parametrize("strategy", ["brute", "cell"]) @pytest.mark.parametrize("n_batches", [1, 2, 3, 4, 128]) -@pytest.mark.parametrize("cutoff", [0.1, 1.0, 1000.0]) +@pytest.mark.parametrize("cutoff", [0.1, 1.0, 3.0, 4.9]) @pytest.mark.parametrize("loop", [True, False]) @pytest.mark.parametrize("include_transpose", [True, False]) -def test_neighbors(device, strategy, n_batches, cutoff, loop, include_transpose): +@pytest.mark.parametrize("box_type", [None, "triclinic", "rectangular"]) +def test_neighbors(device, strategy, n_batches, cutoff, loop, include_transpose, box_type): if device == "cuda" and not torch.cuda.is_available(): pytest.skip("CUDA not available") + if box_type == "triclinic" and strategy == "cell": + pytest.skip("Triclinic only supported for brute force") torch.manual_seed(4321) n_atoms_per_batch = torch.randint(3, 100, size=(n_batches,)) 
batch = torch.repeat_interleave(torch.arange(n_batches, dtype=torch.int64), n_atoms_per_batch).to(device) @@ -51,10 +66,12 @@ def test_neighbors(device, strategy, n_batches, cutoff, loop, include_transpose) pos[0,:] = torch.zeros(3) pos[1,:] = torch.zeros(3) pos.requires_grad = True - ref_neighbors, ref_distance_vecs, ref_distances = compute_ref_neighbors(pos, batch, loop, include_transpose, cutoff) + if(box_type is None): + box = None + else: + box = torch.tensor([[lbox, 0.0, 0.0], [0.0, lbox, 0.0], [0.0, 0.0, lbox]]).to(pos.dtype).to(device) + ref_neighbors, ref_distance_vecs, ref_distances = compute_ref_neighbors(pos, batch, loop, include_transpose, cutoff, box) max_num_pairs = ref_neighbors.shape[1] - box = torch.tensor([lbox, lbox, lbox]) - nl = DistanceCellList(cutoff_lower=0.0, loop=loop, cutoff_upper=cutoff, max_num_pairs=max_num_pairs, strategy=strategy, box=box, return_vecs=True, include_transpose=include_transpose) batch.to(device) neighbors, distances, distance_vecs = nl(pos, batch) @@ -89,7 +106,7 @@ def test_compatible_with_distance(device, strategy, n_batches, cutoff, loop): pos[0,:] = torch.zeros(3) pos[1,:] = torch.zeros(3) pos.requires_grad = True - ref_neighbors, ref_distance_vecs, ref_distances = compute_ref_neighbors(pos, batch, loop, True, cutoff) + ref_neighbors, ref_distance_vecs, ref_distances = compute_ref_neighbors(pos, batch, loop, True, cutoff, None) #Find the particle appearing in the most pairs max_num_neighbors = torch.max(torch.bincount(torch.tensor(ref_neighbors[0,:]))) d = Distance(cutoff_lower=0.0, cutoff_upper=cutoff, loop=loop, max_num_neighbors=max_num_neighbors, return_vecs=True) @@ -100,7 +117,7 @@ def test_compatible_with_distance(device, strategy, n_batches, cutoff, loop): ref_neighbors, ref_distance_vecs, ref_distances = sort_neighbors(ref_neighbors, ref_distance_vecs, ref_distances) max_num_pairs = ref_neighbors.shape[1] - box = torch.tensor([lbox, lbox, lbox]) + box = None nl = DistanceCellList(cutoff_lower=0.0, 
loop=loop, cutoff_upper=cutoff, max_num_pairs=max_num_pairs, strategy=strategy, box=box, return_vecs=True, include_transpose=True) neighbors, distances, distance_vecs = nl(pos, batch) neighbors = neighbors.cpu().detach().numpy() @@ -110,3 +127,47 @@ def test_compatible_with_distance(device, strategy, n_batches, cutoff, loop): assert np.allclose(neighbors, ref_neighbors) assert np.allclose(distances, ref_distances) assert np.allclose(distance_vecs, ref_distance_vecs) + + + +@pytest.mark.parametrize("strategy", ["brute", "cell"]) +@pytest.mark.parametrize("n_batches", [1, 2, 3, 4]) +def test_large_size(strategy, n_batches): + device = "cuda" + cutoff = 1.76 + loop = False + if device == "cuda" and not torch.cuda.is_available(): + pytest.skip("CUDA not available") + torch.manual_seed(4321) + num_atoms=int(32000/n_batches) + n_atoms_per_batch = torch.ones(n_batches, dtype=torch.int64)*num_atoms + batch = torch.repeat_interleave(torch.arange(n_batches, dtype=torch.int64), n_atoms_per_batch).to(device) + cumsum = np.cumsum( np.concatenate([[0], n_atoms_per_batch])) + lbox=45.0 + pos = torch.rand(cumsum[-1], 3, device=device)*lbox + #Ensure there is at least one pair + pos[0,:] = torch.zeros(3) + pos[1,:] = torch.zeros(3) + pos.requires_grad = True + #Find the particle appearing in the most pairs + max_num_neighbors = 64 + d = Distance(cutoff_lower=0.0, cutoff_upper=cutoff, loop=loop, max_num_neighbors=max_num_neighbors, return_vecs=True) + ref_neighbors, ref_distances, ref_distance_vecs = d(pos, batch) + ref_neighbors = ref_neighbors.cpu().detach().numpy() + ref_distance_vecs = ref_distance_vecs.cpu().detach().numpy() + ref_distances = ref_distances.cpu().detach().numpy() + ref_neighbors, ref_distance_vecs, ref_distances = sort_neighbors(ref_neighbors, ref_distance_vecs, ref_distances) + + max_num_pairs = ref_neighbors.shape[1] + + #Must check without PBC since Distance does not support it + box = None #torch.tensor([[lbox, 0.0, 0.0], [0.0, lbox, 0.0], [0.0, 0.0, 
lbox]]).to(pos.dtype).to(device) + nl = DistanceCellList(cutoff_lower=0.0, loop=loop, cutoff_upper=cutoff, max_num_pairs=max_num_pairs, strategy=strategy, box=box, return_vecs=True, include_transpose=True, resize_to_fit=True) + neighbors, distances, distance_vecs = nl(pos, batch) + neighbors = neighbors.cpu().detach().numpy() + distance_vecs = distance_vecs.cpu().detach().numpy() + distances = distances.cpu().detach().numpy() + neighbors, distance_vecs, distances = sort_neighbors(neighbors, distance_vecs, distances) + assert np.allclose(neighbors, ref_neighbors) + assert np.allclose(distances, ref_distances) + assert np.allclose(distance_vecs, ref_distance_vecs) diff --git a/torchmdnet/models/utils.py b/torchmdnet/models/utils.py index 726c11192..7ddd2dccf 100644 --- a/torchmdnet/models/utils.py +++ b/torchmdnet/models/utils.py @@ -104,7 +104,8 @@ def __init__( Strategy to use for computing the neighbor list. Can be one of ["brute", "cell"]. box : torch.Tensor - Size of the box shape (3,) or None + Size of the box shape (3,3) or None. + If strategy is "cell", the box must be diagonal. loop : bool Whether to include self-interactions. include_transpose : bool @@ -113,7 +114,6 @@ def __init__( Whether to resize the neighbor list to the actual number of pairs found. return_vecs : bool Whether to return the distance vectors. 
- """ self.cutoff_upper = cutoff_upper self.cutoff_lower = cutoff_lower @@ -124,10 +124,13 @@ def __init__( self.return_vecs = return_vecs self.include_transpose = include_transpose self.resize_to_fit = resize_to_fit - #Default the box to 3 times the cutoff - if self.box is None and self.strategy == "cell": - self.box = torch.tensor([cutoff_upper * 3] * 3) - + self.use_periodic = True + if self.box is None: + self.use_periodic = False + if self.strategy == "cell": + #Default the box to 3 times the cutoff, really inefficient for the cell list + lbox = cutoff_upper * 3.0 + self.box = torch.tensor([[lbox, 0, 0], [0, lbox, 0], [0, 0, lbox]]) def forward(self, pos, batch): """ @@ -154,6 +157,9 @@ def forward(self, pos, batch): """ function = get_neighbor_pairs if self.strategy == "brute" else get_neighbor_pairs_cell + if self.box is None: + self.box = torch.empty((0, 0), dtype=pos.dtype) + self.box = self.box.to(pos.dtype).to(pos.device) neighbors, distance_vecs, distances = function( pos, cutoff_lower=self.cutoff_lower, @@ -163,7 +169,8 @@ def forward(self, pos, batch): max_num_pairs=self.max_num_pairs, check_errors=True, include_transpose=self.include_transpose, - box_size=self.box + box_vectors=self.box, + use_periodic=self.use_periodic ) #Remove (-1,-1) pairs if self.resize_to_fit: @@ -171,7 +178,6 @@ def forward(self, pos, batch): neighbors = neighbors[:, mask] distances = distances[mask] distance_vecs = distance_vecs[mask,:] - if self.return_vecs: return neighbors, distances, distance_vecs else: diff --git a/torchmdnet/neighbors/neighbors.cpp b/torchmdnet/neighbors/neighbors.cpp index 64c4aefb5..e4b5b53bb 100644 --- a/torchmdnet/neighbors/neighbors.cpp +++ b/torchmdnet/neighbors/neighbors.cpp @@ -1,6 +1,6 @@ #include TORCH_LIBRARY(neighbors, m) { - m.def("get_neighbor_pairs(Tensor positions, Tensor batch, Tensor box_size, Scalar cutoff_lower, Scalar cutoff_upper, Scalar max_num_pairs, bool loop, bool include_transpose, bool check_errors) -> (Tensor neighbors, 
Tensor distances, Tensor distance_vecs)"); - m.def("get_neighbor_pairs_cell(Tensor positions, Tensor batch, Tensor box_size, Scalar cutoff_lower, Scalar cutoff_upper, Scalar max_num_pairs, bool loop, bool include_transpose, bool check_errors) -> (Tensor neighbors, Tensor distances, Tensor distance_vecs)"); + m.def("get_neighbor_pairs(Tensor positions, Tensor batch, Tensor box_vectors, bool use_periodic, Scalar cutoff_lower, Scalar cutoff_upper, Scalar max_num_pairs, bool loop, bool include_transpose, bool check_errors) -> (Tensor neighbors, Tensor distances, Tensor distance_vecs)"); + m.def("get_neighbor_pairs_cell(Tensor positions, Tensor batch, Tensor box_vectors, bool use_periodic, Scalar cutoff_lower, Scalar cutoff_upper, Scalar max_num_pairs, bool loop, bool include_transpose, bool check_errors) -> (Tensor neighbors, Tensor distances, Tensor distance_vecs)"); } diff --git a/torchmdnet/neighbors/neighbors_cpu.cpp b/torchmdnet/neighbors/neighbors_cpu.cpp index 807c29257..0dd6bf798 100644 --- a/torchmdnet/neighbors/neighbors_cpu.cpp +++ b/torchmdnet/neighbors/neighbors_cpu.cpp @@ -17,8 +17,8 @@ using torch::vstack; using torch::indexing::Slice; static tuple forward(const Tensor& positions, const Tensor& batch, - const Tensor& box_size, const Scalar& cutoff_lower, - const Scalar& cutoff_upper, + const Tensor& box_vectors, bool use_periodic, + const Scalar& cutoff_lower, const Scalar& cutoff_upper, const Scalar& max_num_pairs, bool loop, bool include_transpose, bool checkErrors) { TORCH_CHECK(positions.dim() == 2, "Expected \"positions\" to have two dimensions"); @@ -28,8 +28,7 @@ static tuple forward(const Tensor& positions, const Tens TORCH_CHECK(positions.is_contiguous(), "Expected \"positions\" to be contiguous"); TORCH_CHECK(cutoff_upper.to() > 0, "Expected \"cutoff\" to be positive"); - auto box_vectors = torch::empty(0); - if (box_vectors.size(0) != 0) { + if (use_periodic) { TORCH_CHECK(box_vectors.dim() == 2, "Expected \"box_vectors\" to have two 
dimensions"); TORCH_CHECK(box_vectors.size(0) == 3 && box_vectors.size(1) == 3, "Expected \"box_vectors\" to have shape (3, 3)"); @@ -69,18 +68,25 @@ static tuple forward(const Tensor& positions, const Tens break; } } + // batch_i = torch.where(batch[current_offset:] == i) + const int n_atoms_i = batch_i.size(); Tensor positions_i = index_select(positions, 0, torch::tensor(batch_i, kInt32)); Tensor indices_i = - arange(0, n_atoms_i * (n_atoms_i - 1) / 2, positions.options().dtype(kInt32)); - Tensor rows_i = (((8 * indices_i + 1).sqrt() + 1) / 2).floor().to(kInt32); - rows_i -= (rows_i * (rows_i - 1) > 2 * indices_i).to(kInt32); - Tensor columns_i = indices_i - div(rows_i * (rows_i - 1), 2, "floor"); + arange(0, n_atoms_i * (n_atoms_i - 1l) / 2l, positions.options().dtype(torch::kLong)); + Tensor rows_i = (((8l * indices_i + 1l).sqrt() + 1l) / 2l).floor().to(torch::kLong); + rows_i -= (rows_i * (rows_i - 1l) > 2l * indices_i).to(torch::kLong); + Tensor columns_i = indices_i - div(rows_i * (rows_i - 1l), 2, "floor"); Tensor neighbors_i = vstack({rows_i, columns_i}); Tensor deltas_i = index_select(positions_i, 0, rows_i) - index_select(positions_i, 0, columns_i); + if (use_periodic) { + deltas_i -= outer(round(deltas_i.index({Slice(), 2})/box_vectors.index({2, 2})), box_vectors.index({2})); + deltas_i -= outer(round(deltas_i.index({Slice(), 1})/box_vectors.index({1, 1})), box_vectors.index({1})); + deltas_i -= outer(round(deltas_i.index({Slice(), 0})/box_vectors.index({0, 0})), box_vectors.index({0})); + } Tensor distances_i = frobenius_norm(deltas_i, 1); - const Tensor mask_upper = distances_i <= cutoff_upper; + const Tensor mask_upper = distances_i < cutoff_upper; const Tensor mask_lower = distances_i >= cutoff_lower; const Tensor mask = mask_upper * mask_lower; neighbors_i = neighbors_i.index({Slice(), mask}) + current_offset; @@ -105,8 +111,12 @@ static tuple forward(const Tensor& positions, const Tens neighbors = torch::cat(neighbors, 0).to(kInt32); } deltas = 
index_select(positions, 0, neighbors[0]) - index_select(positions, 0, neighbors[1]); + if (use_periodic) { + deltas -= outer(round(deltas.index({Slice(), 2})/box_vectors.index({2, 2})), box_vectors.index({2})); + deltas -= outer(round(deltas.index({Slice(), 1})/box_vectors.index({1, 1})), box_vectors.index({1})); + deltas -= outer(round(deltas.index({Slice(), 0})/box_vectors.index({0, 0})), box_vectors.index({0})); + } distances = frobenius_norm(deltas, 1); - return {neighbors, deltas, distances}; } diff --git a/torchmdnet/neighbors/neighbors_cuda.cu b/torchmdnet/neighbors/neighbors_cuda.cu index e2df6ce55..579894fc8 100644 --- a/torchmdnet/neighbors/neighbors_cuda.cu +++ b/torchmdnet/neighbors/neighbors_cuda.cu @@ -1,8 +1,8 @@ #include #include #include +#include #include -#include using c10::cuda::CUDAStreamGuard; using c10::cuda::getCurrentCUDAStream; using std::make_tuple; @@ -36,31 +36,50 @@ template <> __device__ __forceinline__ double sqrt_(double x) { return ::sqrt(x); }; -__device__ int32_t get_row(int index) { - int32_t row = floor((sqrtf(8 * index + 1) + 1) / 2); +__device__ uint32_t get_row(uint32_t index) { + uint32_t row = floor((sqrtf(8 * index + 1) + 1) / 2); if (row * (row - 1) > 2 * index) row--; return row; } template -__global__ void forward_kernel(const int64_t num_all_pairs, const Accessor positions, +__device__ auto apply_pbc(scalar_t delta_x, scalar_t delta_y, scalar_t delta_z, + const Accessor box_vectors) { + scalar_t scale3 = round(delta_z / box_vectors[2][2]); + delta_x -= scale3 * box_vectors[2][0]; + delta_y -= scale3 * box_vectors[2][1]; + delta_z -= scale3 * box_vectors[2][2]; + scalar_t scale2 = round(delta_y / box_vectors[1][1]); + delta_x -= scale2 * box_vectors[1][0]; + delta_y -= scale2 * box_vectors[1][1]; + scalar_t scale1 = round(delta_x / box_vectors[0][0]); + delta_x -= scale1 * box_vectors[0][0]; + return thrust::make_tuple(delta_x, delta_y, delta_z); +} + +template +__global__ void forward_kernel(uint32_t num_all_pairs, 
const Accessor positions, const Accessor batch, scalar_t cutoff_lower2, scalar_t cutoff_upper2, bool loop, bool include_transpose, Accessor i_curr_pair, Accessor neighbors, - Accessor deltas, Accessor distances) { - const int64_t index = blockIdx.x * blockDim.x + threadIdx.x; + Accessor deltas, Accessor distances, + bool use_periodic, const Accessor box_vectors) { + const uint32_t index = blockIdx.x * blockDim.x + threadIdx.x; if (index >= num_all_pairs) return; - - int32_t row = get_row(index); - const int32_t column = (index - row * (row - 1) / 2); + const uint32_t row = get_row(index); + const uint32_t column = (index - row * (row - 1) / 2); if (batch[row] == batch[column]) { scalar_t delta_x = positions[row][0] - positions[column][0]; scalar_t delta_y = positions[row][1] - positions[column][1]; scalar_t delta_z = positions[row][2] - positions[column][2]; + if (use_periodic) { + thrust::tie(delta_x, delta_y, delta_z) = + apply_pbc(delta_x, delta_y, delta_z, box_vectors); + } const scalar_t distance2 = delta_x * delta_x + delta_y * delta_y + delta_z * delta_z; - if (distance2 <= cutoff_upper2 && distance2 >= cutoff_lower2) { + if (distance2 < cutoff_upper2 && distance2 >= cutoff_lower2) { const int32_t i_pair = atomicAdd(&i_curr_pair[0], include_transpose ? 
2 : 1); // We handle too many neighbors outside of the kernel if (i_pair < neighbors.size(1)) { @@ -82,19 +101,6 @@ __global__ void forward_kernel(const int64_t num_all_pairs, const Accessor @@ -150,7 +156,8 @@ static void checkInput(const Tensor& positions, const Tensor& batch) { "Expected the 1nd dimension size of \"positions\" to be more than 0"); TORCH_CHECK(positions.size(1) == 3, "Expected the 2nd dimension size of \"positions\" to be 3"); TORCH_CHECK(positions.is_contiguous(), "Expected \"positions\" to be contiguous"); - + TORCH_CHECK(positions.size(0) < 1l << 15l, + "Expected the 1st dimension size of \"positions\" to be less than ", 1l << 15l); TORCH_CHECK(batch.dim() == 1, "Expected \"batch\" to have one dimension"); TORCH_CHECK(batch.size(0) == positions.size(0), "Expected the 1st dimension size of \"batch\" to be the same as the 1st dimension " @@ -163,11 +170,17 @@ class Autograd : public Function { public: static tensor_list forward(AutogradContext* ctx, const Tensor& positions, const Tensor& batch, const Scalar& cutoff_lower, const Scalar& cutoff_upper, + const Tensor& box_vectors, bool use_periodic, const Scalar& max_num_pairs, bool loop, bool include_transpose, bool checkErrors) { checkInput(positions, batch); const auto max_num_pairs_ = max_num_pairs.toLong(); TORCH_CHECK(max_num_pairs_ > 0, "Expected \"max_num_neighbors\" to be positive"); + if (use_periodic) { + TORCH_CHECK(box_vectors.dim() == 2, "Expected \"box_vectors\" to have two dimensions"); + TORCH_CHECK(box_vectors.size(0) == 3 && box_vectors.size(1) == 3, + "Expected \"box_vectors\" to have shape (3, 3)"); + } const int num_atoms = positions.size(0); const int num_pairs = max_num_pairs_; const TensorOptions options = positions.options(); @@ -178,10 +191,10 @@ public: const Tensor i_curr_pair = zeros(1, options.dtype(kInt32)); { const CUDAStreamGuard guard(stream); - const int32_t num_atoms = positions.size(0); - const int64_t num_all_pairs = num_atoms * (num_atoms - 1) / 2; - const 
int64_t num_threads = 128; - const int64_t num_blocks = max((num_all_pairs + num_threads - 1) / num_threads, 1l); + const uint64_t num_atoms = positions.size(0); + const uint64_t num_all_pairs = num_atoms * (num_atoms - 1ul) / 2ul; + const uint64_t num_threads = 128; + const uint64_t num_blocks = max((num_all_pairs + num_threads - 1ul) / num_threads, 1ul); AT_DISPATCH_FLOATING_TYPES( positions.scalar_type(), "get_neighbor_pairs_forward", [&]() { const scalar_t cutoff_upper_ = cutoff_upper.to(); @@ -192,11 +205,12 @@ public: get_accessor(batch), cutoff_lower_ * cutoff_lower_, cutoff_upper_ * cutoff_upper_, loop, include_transpose, get_accessor(i_curr_pair), get_accessor(neighbors), - get_accessor(deltas), get_accessor(distances)); + get_accessor(deltas), get_accessor(distances), + use_periodic, get_accessor(box_vectors)); if (loop) { - const int64_t num_threads = 128; - const int64_t num_blocks = - max((num_atoms + num_threads - 1) / num_threads, 1l); + const uint64_t num_threads = 128; + const uint64_t num_blocks = + max((num_atoms + num_threads - 1ul) / num_threads, 1ul); add_self_kernel<<>>( num_atoms, get_accessor(positions), get_accessor(i_curr_pair), @@ -249,12 +263,12 @@ public: TORCH_LIBRARY_IMPL(neighbors, AutogradCUDA, m) { m.impl("get_neighbor_pairs", - [](const Tensor& positions, const Tensor& batch, const Tensor& box_size, - const Scalar& cutoff_lower, const Scalar& cutoff_upper, const Scalar& max_num_pairs, - bool loop, bool include_transpose, bool checkErrors) { - const tensor_list results = - Autograd::apply(positions, batch, cutoff_lower, cutoff_upper, max_num_pairs, - loop, include_transpose, checkErrors); + [](const Tensor& positions, const Tensor& batch, const Tensor& box_vectors, + bool use_periodic, const Scalar& cutoff_lower, const Scalar& cutoff_upper, + const Scalar& max_num_pairs, bool loop, bool include_transpose, bool checkErrors) { + const tensor_list results = Autograd::apply( + positions, batch, cutoff_lower, cutoff_upper, box_vectors, 
use_periodic, + max_num_pairs, loop, include_transpose, checkErrors); return std::make_tuple(results[0], results[1], results[2]); }); } diff --git a/torchmdnet/neighbors/neighbors_cuda_cell.cu b/torchmdnet/neighbors/neighbors_cuda_cell.cu index ff66a755c..42c44a716 100644 --- a/torchmdnet/neighbors/neighbors_cuda_cell.cu +++ b/torchmdnet/neighbors/neighbors_cuda_cell.cu @@ -128,7 +128,7 @@ inline __host__ __device__ uint hashMorton(int3 ci) { * @return The point in the unit cell */ template -__device__ auto takeToUnitCell(scalar3 p, scalar3 box_size) { +__device__ auto apply_pbc(scalar3 p, scalar3 box_size) { p.x = p.x - floorf(p.x / box_size.x + scalar_t(0.5)) * box_size.x; p.y = p.y - floorf(p.y / box_size.y + scalar_t(0.5)) * box_size.y; p.z = p.z - floorf(p.z / box_size.z + scalar_t(0.5)) * box_size.z; @@ -166,7 +166,7 @@ __host__ __device__ int3 getCellDimensions(scalar3 box_size, scalar_t */ template __device__ int3 getCell(scalar3 p, scalar3 box_size, scalar_t cutoff) { - p = takeToUnitCell(p, box_size); + p = apply_pbc(p, box_size); // Take to the [0, box_size] range and divide by cutoff (which is the cell size) int cx = floorf((p.x + scalar_t(0.5) * box_size.x) / cutoff); int cy = floorf((p.y + scalar_t(0.5) * box_size.y) / cutoff); @@ -261,7 +261,6 @@ public: * @param cutoff The cutoff * @return A tuple of the sorted positions and the original indices of each atom in the sorted list */ - static auto sortPositionsByHash(const Tensor& positions, const Tensor& batch, const Tensor& box_size, const Scalar& cutoff) { @@ -274,15 +273,15 @@ static auto sortPositionsByHash(const Tensor& positions, const Tensor& batch, auto stream = at::cuda::getCurrentCUDAStream(); AT_DISPATCH_FLOATING_TYPES(positions.scalar_type(), "assignHash", [&] { scalar_t cutoff_ = cutoff.to(); - scalar3 box_size_ = {box_size[0].item(), box_size[1].item(), - box_size[2].item()}; + scalar3 box_size_ = {box_size[0][0].item(), + box_size[1][1].item(), + box_size[2][2].item()}; assignHash<<>>( 
get_accessor(positions), thrust::raw_pointer_cast(hash_keys.data()), get_accessor(hash_values), get_accessor(batch), box_size_, cutoff_, num_atoms); }); thrust::device_ptr index_ptr(hash_values.data_ptr()); - CudaAllocator allocator; thrust::sort_by_key(thrust::cuda::par.on(stream), hash_keys.begin(), hash_keys.end(), index_ptr); Tensor sorted_positions = positions.index_select(0, hash_values); @@ -343,8 +342,9 @@ static auto fillCellOffsets(const Tensor& sorted_positions, const Tensor& sorted int3 cell_dim; AT_DISPATCH_FLOATING_TYPES(sorted_positions.scalar_type(), "fillCellOffsets", [&] { scalar_t cutoff_ = cutoff.to(); - scalar3 box_size_ = {box_size[0].item(), box_size[1].item(), - box_size[2].item()}; + scalar3 box_size_ = {box_size[0][0].item(), + box_size[1][1].item(), + box_size[2][2].item()}; cell_dim = getCellDimensions(box_size_, cutoff_); }); const int num_cells = cell_dim.x * cell_dim.y * cell_dim.z; @@ -355,8 +355,9 @@ static auto fillCellOffsets(const Tensor& sorted_positions, const Tensor& sorted AT_DISPATCH_FLOATING_TYPES(sorted_positions.scalar_type(), "fillCellOffsets", [&] { auto stream = at::cuda::getCurrentCUDAStream(); scalar_t cutoff_ = cutoff.to(); - scalar3 box_size_ = {box_size[0].item(), box_size[1].item(), - box_size[2].item()}; + scalar3 box_size_ = {box_size[0][0].item(), + box_size[1][1].item(), + box_size[2][2].item()}; fillCellOffsetsD<<>>( get_accessor(sorted_positions), get_accessor(sorted_indices), get_accessor(cell_start), get_accessor(cell_end), @@ -390,7 +391,7 @@ forward_kernel(const Accessor sorted_positions, const Accessor cell_start, const Accessor cell_end, Accessor neighbors, Accessor deltas, Accessor distances, Accessor i_curr_pair, int num_atoms, - int num_pairs, scalar3 box_size, scalar_t cutoff_lower, + int num_pairs, bool use_periodic, scalar3 box_size, scalar_t cutoff_lower, scalar_t cutoff_upper, bool loop, bool include_transpose) { // Each atom traverses the cells around it and finds the neighbors // Atoms for 
all batches are placed in the same cell list, but other batches are ignored while @@ -419,30 +420,42 @@ forward_kernel(const Accessor sorted_positions, if (j_batch > i_batch) // Particles are sorted by batch after cell, so we can break early here break; - const bool includePair = - (j_batch == i_batch) and - ((orj != ori and (orj < ori or include_transpose)) or (loop and orj == ori)); - if (includePair) { + const bool testPair = + (j_batch == i_batch) and ((orj < ori) or (loop and orj == ori)); + if (testPair) { const scalar3 pj = {sorted_positions[cur_j][0], sorted_positions[cur_j][1], sorted_positions[cur_j][2]}; - const scalar_t dx = pi.x - pj.x; - const scalar_t dy = pi.y - pj.y; - const scalar_t dz = pi.z - pj.z; - const scalar_t distance2 = dx * dx + dy * dy + dz * dz; + scalar3 delta = {pi.x - pj.x, pi.y - pj.y, pi.z - pj.z}; + if (use_periodic) { + delta = apply_pbc(delta, box_size); + } + const scalar_t distance2 = + delta.x * delta.x + delta.y * delta.y + delta.z * delta.z; const scalar_t cutoff_upper2 = cutoff_upper * cutoff_upper; const scalar_t cutoff_lower2 = cutoff_lower * cutoff_lower; - if ((distance2 <= cutoff_upper2 and distance2 >= cutoff_lower2) or + if ((distance2 < cutoff_upper2 and distance2 >= cutoff_lower2) or (loop and orj == ori)) { - const int32_t i_pair = atomicAdd(&i_curr_pair[0], 1); + const bool requires_transpose = (orj != ori) and include_transpose; + const int32_t i_pair = + atomicAdd(&i_curr_pair[0], requires_transpose ? 
2 : 1); // We handle too many neighbors outside of the kernel if (i_pair < neighbors.size(1)) { + const scalar_t sqrt_distance2 = sqrt_(distance2); neighbors[0][i_pair] = ori; neighbors[1][i_pair] = orj; - deltas[i_pair][0] = dx; - deltas[i_pair][1] = dy; - deltas[i_pair][2] = dz; - distances[i_pair] = sqrt_(distance2); + deltas[i_pair][0] = delta.x; + deltas[i_pair][1] = delta.y; + deltas[i_pair][2] = delta.z; + distances[i_pair] = sqrt_distance2; + if (requires_transpose) { + neighbors[0][i_pair + 1] = orj; + neighbors[1][i_pair + 1] = ori; + deltas[i_pair + 1][0] = -delta.x; + deltas[i_pair + 1][1] = -delta.y; + deltas[i_pair + 1][2] = -delta.z; + distances[i_pair + 1] = sqrt_distance2; + } } // endif } // endif } // endfor @@ -454,9 +467,10 @@ forward_kernel(const Accessor sorted_positions, class Autograd : public Function { public: static tensor_list forward(AutogradContext* ctx, const Tensor& positions, const Tensor& batch, - const Tensor& box_size, const Scalar& cutoff_lower, - const Scalar& cutoff_upper, const Scalar& max_num_pairs, bool loop, - bool include_transpose, bool checkErrors) { + const Tensor& box_size_gpu, bool use_periodic, + const Scalar& cutoff_lower, const Scalar& cutoff_upper, + const Scalar& max_num_pairs, bool loop, bool include_transpose, + bool checkErrors) { // The algorithm for the cell list construction can be summarized in three separate steps: // 1. Hash (label) the particles according to the cell (bin) they lie in. // 2. Sort the particles and hashes using the hashes as the ordering label @@ -465,7 +479,15 @@ public: // 3. Identify where each cell starts and ends in the sorted particle positions // array. 
checkInput(positions, batch); - TORCH_CHECK(box_size.size(0) == 3, "Expected \"box_size\" to have 3 elements"); + auto box_size = box_size_gpu.cpu(); + TORCH_CHECK(box_size.dim() == 2, "Expected \"box_size\" to have two dimensions"); + TORCH_CHECK(box_size.size(0) == 3 && box_size.size(1) == 3, + "Expected \"box_size\" to have shape (3, 3)"); + //Ensure that box size has no non-zero values outside of the diagonal + TORCH_CHECK(box_size[0][1].item() == 0 && box_size[0][2].item() == 0 && + box_size[1][0].item() == 0 && box_size[1][2].item() == 0 && + box_size[2][0].item() == 0 && box_size[2][1].item() == 0, + "Expected \"box_size\" to be diagonal"); const auto max_num_pairs_ = max_num_pairs.toLong(); TORCH_CHECK(max_num_pairs_ > 0, "Expected \"max_num_neighbors\" to be positive"); const int num_atoms = positions.size(0); @@ -489,9 +511,9 @@ public: const scalar_t cutoff_upper_ = cutoff_upper.to(); TORCH_CHECK(cutoff_upper_ > 0, "Expected cutoff_upper to be positive"); const scalar_t cutoff_lower_ = cutoff_lower.to(); - const scalar3 box_size_ = {box_size[0].item(), - box_size[1].item(), - box_size[2].item()}; + const scalar3 box_size_ = {box_size[0][0].item(), + box_size[1][1].item(), + box_size[2][2].item()}; const int threads = 128; const int blocks = (num_atoms + threads - 1) / threads; forward_kernel<<>>( @@ -500,8 +522,8 @@ public: get_accessor(cell_start), get_accessor(cell_end), get_accessor(neighbors), get_accessor(deltas), get_accessor(distances), get_accessor(i_curr_pair), - num_atoms, num_pairs, box_size_, cutoff_lower_, cutoff_upper_, loop, - include_transpose); + num_atoms, num_pairs, use_periodic, box_size_, + cutoff_lower_, cutoff_upper_, loop, include_transpose); }); } // Synchronize and check the number of pairs found. 
Note that this is incompatible with CUDA @@ -548,12 +570,12 @@ public: TORCH_LIBRARY_IMPL(neighbors, AutogradCUDA, m) { m.impl("get_neighbor_pairs_cell", - [](const Tensor& positions, const Tensor& batch, const Tensor& box_size, - const Scalar& cutoff_lower, const Scalar& cutoff_upper, const Scalar& max_num_pairs, - bool loop, bool include_transpose, bool checkErrors) { - const tensor_list results = - Autograd::apply(positions, batch, box_size, cutoff_lower, cutoff_upper, - max_num_pairs, loop, include_transpose, checkErrors); + [](const Tensor& positions, const Tensor& batch, const Tensor& box_vectors, + bool use_periodic, const Scalar& cutoff_lower, const Scalar& cutoff_upper, + const Scalar& max_num_pairs, bool loop, bool include_transpose, bool checkErrors) { + const tensor_list results = Autograd::apply( + positions, batch, box_vectors, use_periodic, cutoff_lower, cutoff_upper, + max_num_pairs, loop, include_transpose, checkErrors); return std::make_tuple(results[0], results[1], results[2]); }); } From 730ecd458eea34f7e18cf47960d6ad4e5561c9ef Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Sat, 6 May 2023 19:44:18 +0200 Subject: [PATCH 25/76] Add shared-memory nbody algorithm Move error checking python-side update tests update benchmarks --- benchmarks/neighbors.py | 54 ++-- setup.py | 2 +- tests/test_neighbors.py | 11 +- torchmdnet/models/utils.py | 28 +- torchmdnet/neighbors/__init__.py | 3 +- torchmdnet/neighbors/common.cuh | 120 +++++++++ torchmdnet/neighbors/neighbors.cpp | 5 +- torchmdnet/neighbors/neighbors_cpu.cpp | 35 ++- torchmdnet/neighbors/neighbors_cuda.cu | 279 ++++++++++++-------- torchmdnet/neighbors/neighbors_cuda_cell.cu | 137 +++------- 10 files changed, 406 insertions(+), 268 deletions(-) create mode 100644 torchmdnet/neighbors/common.cuh diff --git a/benchmarks/neighbors.py b/benchmarks/neighbors.py index 6d468a879..a2874c08d 100644 --- a/benchmarks/neighbors.py +++ b/benchmarks/neighbors.py @@ -6,7 +6,8 @@ -def 
benchmark_neighbors(device, strategy, n_batches, total_num_particles, mean_num_neighbors=32): + +def benchmark_neighbors(device, strategy, n_batches, total_num_particles, mean_num_neighbors, density): """Benchmark the neighbor list generation. Parameters @@ -19,13 +20,17 @@ def benchmark_neighbors(device, strategy, n_batches, total_num_particles, mean_n Number of batches to generate. total_num_particles : int Total number of particles. + mean_num_neighbors : int + Mean number of neighbors per particle. + density : float + Density of the system. Returns ------- float Average time per batch in seconds. """ - density = 0.7; torch.random.manual_seed(12344) + np.random.seed(43211) num_particles = total_num_particles // n_batches expected_num_neighbors = mean_num_neighbors cutoff = np.cbrt(3 * expected_num_neighbors / (4 * np.pi * density)); @@ -51,21 +56,16 @@ def benchmark_neighbors(device, strategy, n_batches, total_num_particles, mean_n if strategy != 'distance': max_num_pairs = (expected_num_neighbors * n_atoms_per_batch.sum()).item()*2 box = torch.eye(3, device=device)*lbox - nl = DistanceCellList(cutoff_upper=cutoff, max_num_pairs=max_num_pairs, strategy=strategy, box=box) + nl = DistanceCellList(cutoff_upper=cutoff, max_num_pairs=max_num_pairs, strategy=strategy, box=box, loop=False, include_transpose=True, resize_to_fit=False) else: max_num_neighbors = int(expected_num_neighbors*5) nl = Distance(loop=False, cutoff_lower=0.0, cutoff_upper=cutoff, max_num_neighbors=max_num_neighbors) #Warmup for i in range(10): neighbors, distances, distance_vecs = nl(pos, batch) - #print - print("Batch with largest number of atoms: {}".format(int(n_atoms_per_batch.max()))) - print("Batch with smallest number of atoms: {}".format(int(n_atoms_per_batch.min()))) - print("Number of pairs: {}, Number of particles: {}".format(int(neighbors.shape[1]), int(n_atoms_per_batch.to(torch.double). 
- sum().item()))) if device == 'cuda': torch.cuda.synchronize() - nruns = 10 + nruns = 50 if device == 'cuda': torch.cuda.synchronize() @@ -83,35 +83,38 @@ def benchmark_neighbors(device, strategy, n_batches, total_num_particles, mean_n if __name__ == '__main__': n_particles = 32767 - mean_num_neighbors = min(n_particles, 16); + mean_num_neighbors = min(n_particles, 64); + density=0.5 print("Benchmarking neighbor list generation for {} particles with {} neighbors on average".format(n_particles, mean_num_neighbors)) results = {} - batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512] - for strategy in ['brute', 'cell', 'distance']: - print("Strategy: {}".format(strategy)) - print("--------") - print("{:<10} {:<10}".format("Batch size", "Time (ms)")) - print("{:<10} {:<10}".format("----------", "---------")) + batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024] + for strategy in ['shared', 'brute', 'cell', 'distance']: + # print("Strategy: {}".format(strategy)) + # print("--------") + # print("{:<10} {:<10}".format("Batch size", "Time (ms)")) + # print("{:<10} {:<10}".format("----------", "---------")) #Loop over different number of batches, random for n_batches in batch_sizes: time = benchmark_neighbors(device='cuda', strategy=strategy, n_batches=n_batches, total_num_particles=n_particles, - mean_num_neighbors=mean_num_neighbors + mean_num_neighbors=mean_num_neighbors, + density=density ) #Store results in a dictionary results[strategy, n_batches] = time - print("{:<10} {:<10.2f}".format(n_batches, time)) - print("\n") + #print("{:<10} {:<10.2f}".format(n_batches, time)) print("Summary") print("-------") - print("{:<10} {:<21} {:<18} {:<10}".format("Batch size", "Brute(ms)", "Cell(ms)", "Distance(ms)")) - print("{:<10} {:<21} {:<18} {:<10}".format("----------", "---------", "--------", "-----------")) + print("{:<10} {:<21} {:<21} {:<18} {:<10}".format("Batch size", "Shared(ms)", "Brute(ms)", "Cell(ms)", "Distance(ms)")) + print("{:<10} {:<21} {:<21} {:<18} 
{:<10}".format("----------", "---------", "---------", "---------", "---------")) #Print a column per strategy, show speedup over Distance in parenthesis for n_batches in batch_sizes: base = results['distance', n_batches] - print("{:<10} {:<4.2f} x{:<14.2f} {:<4.2f} x{:<14.2f} {:<10.2f}".format(n_batches, + print("{:<10} {:<4.2f} x{:<14.2f} {:<4.2f} x{:<14.2f} {:<4.2f} x{:<14.2f} {:<10.2f}".format(n_batches, + results['shared', n_batches], + base/results['shared', n_batches], results['brute', n_batches], base/results['brute', n_batches], results['cell', n_batches], @@ -121,10 +124,11 @@ def benchmark_neighbors(device, strategy, n_batches, total_num_particles, mean_n #Print a second table showing time per atom, show in ns print("\n") print("Time per atom") - print("{:<10} {:<10} {:<10} {:<10}".format("Batch size", "Brute(ns)", "Cell(ns)", "Distance(ns)")) - print("{:<10} {:<10} {:<10} {:<10}".format("----------", "---------", "--------", "-----------")) + print("{:<10} {:<10} {:<10} {:<10} {:<10}".format("Batch size", "Shared(ns)", "Brute(ns)", "Cell(ns)", "Distance(ns)")) + print("{:<10} {:<10} {:<10} {:<10} {:<10}".format("----------", "---------", "---------", "---------", "---------")) for n_batches in batch_sizes: - print("{:<10} {:<10.2f} {:<10.2f} {:<10.2f}".format(n_batches, + print("{:<10} {:<10.2f} {:<10.2f} {:<10.2f} {:<10.2f}".format(n_batches, + results['shared', n_batches]/n_particles*1e6, results['brute', n_batches]/n_particles*1e6, results['cell', n_batches]/n_particles*1e6, results['distance', n_batches]/n_particles*1e6)) diff --git a/setup.py b/setup.py index 619d760de..dcf242017 100644 --- a/setup.py +++ b/setup.py @@ -15,7 +15,7 @@ name="torchmd-net", version=version, packages=find_packages(), - package_data={"torchmdnet": ["neighbors/neighbors*"]}, + package_data={"torchmdnet": ["neighbors/neighbors*", "neighbors/common.cuh"]}, include_package_data=True, entry_points={"console_scripts": ["torchmd-train = torchmdnet.scripts.train:main"]}, ) diff 
--git a/tests/test_neighbors.py b/tests/test_neighbors.py index d46031d47..d84a59368 100644 --- a/tests/test_neighbors.py +++ b/tests/test_neighbors.py @@ -45,7 +45,7 @@ def compute_ref_neighbors(pos, batch, loop, include_transpose, cutoff, box_vecto return ref_neighbors, ref_distance_vecs, ref_distances @pytest.mark.parametrize("device", ["cpu", "cuda"]) -@pytest.mark.parametrize("strategy", ["brute", "cell"]) +@pytest.mark.parametrize("strategy", ["brute", "shared", "cell"]) @pytest.mark.parametrize("n_batches", [1, 2, 3, 4, 128]) @pytest.mark.parametrize("cutoff", [0.1, 1.0, 3.0, 4.9]) @pytest.mark.parametrize("loop", [True, False]) @@ -56,6 +56,8 @@ def test_neighbors(device, strategy, n_batches, cutoff, loop, include_transpose, pytest.skip("CUDA not available") if box_type == "triclinic" and strategy == "cell": pytest.skip("Triclinic only supported for brute force") + if device=="cpu" and strategy!="brute": + pytest.skip("Only brute force supported on CPU") torch.manual_seed(4321) n_atoms_per_batch = torch.randint(3, 100, size=(n_batches,)) batch = torch.repeat_interleave(torch.arange(n_batches, dtype=torch.int64), n_atoms_per_batch).to(device) @@ -89,13 +91,16 @@ def test_neighbors(device, strategy, n_batches, cutoff, loop, include_transpose, @pytest.mark.parametrize("device", ["cpu", "cuda"]) -@pytest.mark.parametrize("strategy", ["brute", "cell"]) +@pytest.mark.parametrize("strategy", ["brute", "cell", "shared"]) @pytest.mark.parametrize("n_batches", [1, 2, 3, 4]) @pytest.mark.parametrize("cutoff", [0.1, 1.0, 1000.0]) @pytest.mark.parametrize("loop", [True, False]) def test_compatible_with_distance(device, strategy, n_batches, cutoff, loop): if device == "cuda" and not torch.cuda.is_available(): pytest.skip("CUDA not available") + if device=="cpu" and strategy!="brute": + pytest.skip("Only brute force supported on CPU") + torch.manual_seed(4321) n_atoms_per_batch = torch.randint(3, 100, size=(n_batches,)) batch = 
torch.repeat_interleave(torch.arange(n_batches, dtype=torch.long), n_atoms_per_batch).to(device) @@ -130,7 +135,7 @@ def test_compatible_with_distance(device, strategy, n_batches, cutoff, loop): -@pytest.mark.parametrize("strategy", ["brute", "cell"]) +@pytest.mark.parametrize("strategy", ["brute", "cell", "shared"]) @pytest.mark.parametrize("n_batches", [1, 2, 3, 4]) def test_large_size(strategy, n_batches): device = "cuda" diff --git a/torchmdnet/models/utils.py b/torchmdnet/models/utils.py index 7ddd2dccf..e06721367 100644 --- a/torchmdnet/models/utils.py +++ b/torchmdnet/models/utils.py @@ -76,18 +76,22 @@ def forward(self, z, x, edge_index, edge_weight, edge_attr): def message(self, x_j, W): return x_j * W -from torchmdnet.neighbors import get_neighbor_pairs, get_neighbor_pairs_cell +from torchmdnet.neighbors import get_neighbor_pairs_brute, get_neighbor_pairs_cell, get_neighbor_pairs_shared class DistanceCellList(torch.nn.Module): + + _backends = { "brute": get_neighbor_pairs_brute, "cell": get_neighbor_pairs_cell, "shared": get_neighbor_pairs_shared } + def __init__( self, - cutoff_upper, cutoff_lower=0.0, + cutoff_upper=5.0, max_num_pairs=32, return_vecs=False, loop=False, strategy="brute", include_transpose=True, resize_to_fit=True, + check_errors=False, box=None ): super(DistanceCellList, self).__init__() @@ -100,6 +104,7 @@ def __init__( Upper cutoff for the neighbor list. max_num_pairs : int Maximum number of pairs to store. + If negative, it is interpreted as (minus) the maximum number of neighbors per atom. strategy : str Strategy to use for computing the neighbor list. Can be one of ["brute", "cell"]. @@ -112,6 +117,8 @@ def __init__( Whether to include the transpose of the neighbor list. resize_to_fit : bool Whether to resize the neighbor list to the actual number of pairs found. + check_errors : bool + Whether to check for too many pairs. return_vecs : bool Whether to return the distance vectors. 
""" @@ -132,6 +139,11 @@ def __init__( lbox = cutoff_upper * 3.0 self.box = torch.tensor([[lbox, 0, 0], [0, lbox, 0], [0, 0, lbox]]) + self.kernel = self._backends[self.strategy] + if self.kernel is None: + raise ValueError("Unknown strategy: {}".format(self.strategy)) + self.check_errors = check_errors + def forward(self, pos, batch): """ Parameters @@ -156,22 +168,26 @@ def forward(self, pos, batch): otherwise the tensors will have size max_num_pairs, with neighbor pairs (-1, -1) at the end. """ - function = get_neighbor_pairs if self.strategy == "brute" else get_neighbor_pairs_cell if self.box is None: self.box = torch.empty((0, 0), dtype=pos.dtype) self.box = self.box.to(pos.dtype).to(pos.device) - neighbors, distance_vecs, distances = function( + max_pairs = self.max_num_pairs + if self.max_num_pairs < 0: + max_pairs = -self.max_num_pairs*pos.shape[0] + neighbors, distance_vecs, distances, num_pairs = self.kernel( pos, cutoff_lower=self.cutoff_lower, cutoff_upper=self.cutoff_upper, loop=self.loop, batch=batch, - max_num_pairs=self.max_num_pairs, - check_errors=True, + max_num_pairs=max_pairs, include_transpose=self.include_transpose, box_vectors=self.box, use_periodic=self.use_periodic ) + if self.check_errors: + if num_pairs[0] > self.max_num_pairs: + raise RuntimeError("Found num_pairs({}) > max_num_pairs({})".format(num_pairs[0], self.max_num_pairs)) #Remove (-1,-1) pairs if self.resize_to_fit: mask = neighbors[0] != -1 diff --git a/torchmdnet/neighbors/__init__.py b/torchmdnet/neighbors/__init__.py index 6282fb558..2664da064 100644 --- a/torchmdnet/neighbors/__init__.py +++ b/torchmdnet/neighbors/__init__.py @@ -7,5 +7,6 @@ sources = [os.path.join(src_dir, name) for name in sources] cpp_extension.load(name='neighbors', sources=sources, is_python_module=False) -get_neighbor_pairs = pt.ops.neighbors.get_neighbor_pairs +get_neighbor_pairs_brute = pt.ops.neighbors.get_neighbor_pairs_brute +get_neighbor_pairs_shared = pt.ops.neighbors.get_neighbor_pairs_shared 
get_neighbor_pairs_cell = pt.ops.neighbors.get_neighbor_pairs_cell diff --git a/torchmdnet/neighbors/common.cuh b/torchmdnet/neighbors/common.cuh new file mode 100644 index 000000000..b6962f874 --- /dev/null +++ b/torchmdnet/neighbors/common.cuh @@ -0,0 +1,120 @@ +#pragma once +#include +#include +#include +#include + +using c10::cuda::CUDAStreamGuard; +using c10::cuda::getCurrentCUDAStream; +using torch::empty; +using torch::full; +using torch::kInt32; +using torch::Scalar; +using torch::Tensor; +using torch::TensorOptions; +using torch::zeros; +using torch::autograd::AutogradContext; +using torch::autograd::Function; +using torch::autograd::tensor_list; + +template +using Accessor = torch::PackedTensorAccessor32; + +template +inline Accessor get_accessor(const Tensor& tensor) { + return tensor.packed_accessor32(); +}; + +template __device__ __forceinline__ scalar_t sqrt_(scalar_t x){}; +template <> __device__ __forceinline__ float sqrt_(float x) { + return ::sqrtf(x); +}; +template <> __device__ __forceinline__ double sqrt_(double x) { + return ::sqrt(x); +}; + +template struct vec4 { + using type = void; +}; +template <> struct vec4 { + using type = float4; +}; +template <> struct vec4 { + using type = double4; +}; + +template using scalar4 = typename vec4::type; + +template struct vec3 { + using type = void; +}; +template <> struct vec3 { + using type = float3; +}; +template <> struct vec3 { + using type = double3; +}; + +template using scalar3 = typename vec3::type; + +namespace rect { + +/* + * @brief Takes a point to the unit cell in the range [-0.5, 0.5]*box_size using Minimum Image + * Convention + * @param p The point position + * @param box_size The box size + * @return The point in the unit cell + */ +template +__device__ auto apply_pbc(scalar3 p, scalar3 box_size) { + p.x = p.x - floorf(p.x / box_size.x + scalar_t(0.5)) * box_size.x; + p.y = p.y - floorf(p.y / box_size.y + scalar_t(0.5)) * box_size.y; + p.z = p.z - floorf(p.z / box_size.z + 
scalar_t(0.5)) * box_size.z; + return p; +} + +template +__device__ auto compute_distance(scalar3 pos_i, scalar3 pos_j, + bool use_periodic, scalar3 box_size) { + scalar3 delta = {pos_i.x - pos_j.x, pos_i.y - pos_j.y, pos_i.z - pos_j.z}; + if (use_periodic) { + delta = apply_pbc(delta, box_size); + } + return delta; +} + +} // namespace rect +namespace triclinic { +/* + * @brief Takes a point to the unit cell using Minimum Image + * Convention + * @param p The point position + * @param box_vectors The box vectors (3x3 matrix) + * @return The point in the unit cell + */ +template +__device__ auto apply_pbc(scalar3 delta, const Accessor box_vectors) { + scalar_t scale3 = round(delta.z / box_vectors[2][2]); + delta.x -= scale3 * box_vectors[2][0]; + delta.y -= scale3 * box_vectors[2][1]; + delta.z -= scale3 * box_vectors[2][2]; + scalar_t scale2 = round(delta.y / box_vectors[1][1]); + delta.x -= scale2 * box_vectors[1][0]; + delta.y -= scale2 * box_vectors[1][1]; + scalar_t scale1 = round(delta.x / box_vectors[0][0]); + delta.x -= scale1 * box_vectors[0][0]; + return delta; +} + +template +__device__ auto compute_distance(scalar3 pos_i, scalar3 pos_j, + bool use_periodic, const Accessor box_vectors) { + scalar3 delta = {pos_i.x - pos_j.x, pos_i.y - pos_j.y, pos_i.z - pos_j.z}; + if (use_periodic) { + delta = apply_pbc(delta, box_vectors); + } + return delta; +} + +} // namespace triclinic diff --git a/torchmdnet/neighbors/neighbors.cpp b/torchmdnet/neighbors/neighbors.cpp index e4b5b53bb..e8c3dd950 100644 --- a/torchmdnet/neighbors/neighbors.cpp +++ b/torchmdnet/neighbors/neighbors.cpp @@ -1,6 +1,7 @@ #include TORCH_LIBRARY(neighbors, m) { - m.def("get_neighbor_pairs(Tensor positions, Tensor batch, Tensor box_vectors, bool use_periodic, Scalar cutoff_lower, Scalar cutoff_upper, Scalar max_num_pairs, bool loop, bool include_transpose, bool check_errors) -> (Tensor neighbors, Tensor distances, Tensor distance_vecs)"); - m.def("get_neighbor_pairs_cell(Tensor positions, 
Tensor batch, Tensor box_vectors, bool use_periodic, Scalar cutoff_lower, Scalar cutoff_upper, Scalar max_num_pairs, bool loop, bool include_transpose, bool check_errors) -> (Tensor neighbors, Tensor distances, Tensor distance_vecs)"); + m.def("get_neighbor_pairs_brute(Tensor positions, Tensor batch, Tensor box_vectors, bool use_periodic, Scalar cutoff_lower, Scalar cutoff_upper, Scalar max_num_pairs, bool loop, bool include_transpose) -> (Tensor neighbors, Tensor distances, Tensor distance_vecs, Tensor num_pairs)"); + m.def("get_neighbor_pairs_shared(Tensor positions, Tensor batch, Tensor box_vectors, bool use_periodic, Scalar cutoff_lower, Scalar cutoff_upper, Scalar max_num_pairs, bool loop, bool include_transpose) -> (Tensor neighbors, Tensor distances, Tensor distance_vecs, Tensor num_pairs)"); + m.def("get_neighbor_pairs_cell(Tensor positions, Tensor batch, Tensor box_vectors, bool use_periodic, Scalar cutoff_lower, Scalar cutoff_upper, Scalar max_num_pairs, bool loop, bool include_transpose) -> (Tensor neighbors, Tensor distances, Tensor distance_vecs, Tensor num_pairs)"); } diff --git a/torchmdnet/neighbors/neighbors_cpu.cpp b/torchmdnet/neighbors/neighbors_cpu.cpp index 0dd6bf798..79ae130ac 100644 --- a/torchmdnet/neighbors/neighbors_cpu.cpp +++ b/torchmdnet/neighbors/neighbors_cpu.cpp @@ -16,11 +16,10 @@ using torch::Tensor; using torch::vstack; using torch::indexing::Slice; -static tuple forward(const Tensor& positions, const Tensor& batch, - const Tensor& box_vectors, bool use_periodic, - const Scalar& cutoff_lower, const Scalar& cutoff_upper, - const Scalar& max_num_pairs, bool loop, - bool include_transpose, bool checkErrors) { +static tuple +forward(const Tensor& positions, const Tensor& batch, const Tensor& box_vectors, bool use_periodic, + const Scalar& cutoff_lower, const Scalar& cutoff_upper, const Scalar& max_num_pairs, + bool loop, bool include_transpose) { TORCH_CHECK(positions.dim() == 2, "Expected \"positions\" to have two dimensions"); 
TORCH_CHECK(positions.size(0) > 0, "Expected the 1nd dimension size of \"positions\" to be more than 0"); @@ -68,7 +67,6 @@ static tuple forward(const Tensor& positions, const Tens break; } } - // batch_i = torch.where(batch[current_offset:] == i) const int n_atoms_i = batch_i.size(); Tensor positions_i = index_select(positions, 0, torch::tensor(batch_i, kInt32)); @@ -81,9 +79,12 @@ static tuple forward(const Tensor& positions, const Tens Tensor deltas_i = index_select(positions_i, 0, rows_i) - index_select(positions_i, 0, columns_i); if (use_periodic) { - deltas_i -= outer(round(deltas_i.index({Slice(), 2})/box_vectors.index({2, 2})), box_vectors.index({2})); - deltas_i -= outer(round(deltas_i.index({Slice(), 1})/box_vectors.index({1, 1})), box_vectors.index({1})); - deltas_i -= outer(round(deltas_i.index({Slice(), 0})/box_vectors.index({0, 0})), box_vectors.index({0})); + deltas_i -= outer(round(deltas_i.index({Slice(), 2}) / box_vectors.index({2, 2})), + box_vectors.index({2})); + deltas_i -= outer(round(deltas_i.index({Slice(), 1}) / box_vectors.index({1, 1})), + box_vectors.index({1})); + deltas_i -= outer(round(deltas_i.index({Slice(), 0}) / box_vectors.index({0, 0})), + box_vectors.index({0})); } Tensor distances_i = frobenius_norm(deltas_i, 1); const Tensor mask_upper = distances_i < cutoff_upper; @@ -112,15 +113,21 @@ static tuple forward(const Tensor& positions, const Tens } deltas = index_select(positions, 0, neighbors[0]) - index_select(positions, 0, neighbors[1]); if (use_periodic) { - deltas -= outer(round(deltas.index({Slice(), 2})/box_vectors.index({2, 2})), box_vectors.index({2})); - deltas -= outer(round(deltas.index({Slice(), 1})/box_vectors.index({1, 1})), box_vectors.index({1})); - deltas -= outer(round(deltas.index({Slice(), 0})/box_vectors.index({0, 0})), box_vectors.index({0})); + deltas -= outer(round(deltas.index({Slice(), 2}) / box_vectors.index({2, 2})), + box_vectors.index({2})); + deltas -= outer(round(deltas.index({Slice(), 1}) / 
box_vectors.index({1, 1})), + box_vectors.index({1})); + deltas -= outer(round(deltas.index({Slice(), 0}) / box_vectors.index({0, 0})), + box_vectors.index({0})); } distances = frobenius_norm(deltas, 1); - return {neighbors, deltas, distances}; + Tensor num_pairs_found = torch::empty(1, distances.options().dtype(kInt32)); + num_pairs_found[0] = distances.size(0); + return {neighbors, deltas, distances, num_pairs_found}; } TORCH_LIBRARY_IMPL(neighbors, CPU, m) { - m.impl("get_neighbor_pairs", &forward); + m.impl("get_neighbor_pairs_brute", &forward); + m.impl("get_neighbor_pairs_shared", &forward); m.impl("get_neighbor_pairs_cell", &forward); } diff --git a/torchmdnet/neighbors/neighbors_cuda.cu b/torchmdnet/neighbors/neighbors_cuda.cu index 579894fc8..a5120b7a3 100644 --- a/torchmdnet/neighbors/neighbors_cuda.cu +++ b/torchmdnet/neighbors/neighbors_cuda.cu @@ -1,40 +1,7 @@ +#include "common.cuh" #include -#include -#include #include #include -using c10::cuda::CUDAStreamGuard; -using c10::cuda::getCurrentCUDAStream; -using std::make_tuple; -using std::max; -using torch::empty; -using torch::full; -using torch::kInt32; -using torch::PackedTensorAccessor32; -using torch::RestrictPtrTraits; -using torch::Scalar; -using torch::Tensor; -using torch::TensorOptions; -using torch::zeros; -using torch::autograd::AutogradContext; -using torch::autograd::Function; -using torch::autograd::tensor_list; - -template -using Accessor = PackedTensorAccessor32; - -template -inline Accessor get_accessor(const Tensor& tensor) { - return tensor.packed_accessor32(); -}; - -template __device__ __forceinline__ scalar_t sqrt_(scalar_t x){}; -template <> __device__ __forceinline__ float sqrt_(float x) { - return ::sqrtf(x); -}; -template <> __device__ __forceinline__ double sqrt_(double x) { - return ::sqrt(x); -}; __device__ uint32_t get_row(uint32_t index) { uint32_t row = floor((sqrtf(8 * index + 1) + 1) / 2); @@ -44,58 +11,41 @@ __device__ uint32_t get_row(uint32_t index) { } template 
-__device__ auto apply_pbc(scalar_t delta_x, scalar_t delta_y, scalar_t delta_z, - const Accessor box_vectors) { - scalar_t scale3 = round(delta_z / box_vectors[2][2]); - delta_x -= scale3 * box_vectors[2][0]; - delta_y -= scale3 * box_vectors[2][1]; - delta_z -= scale3 * box_vectors[2][2]; - scalar_t scale2 = round(delta_y / box_vectors[1][1]); - delta_x -= scale2 * box_vectors[1][0]; - delta_y -= scale2 * box_vectors[1][1]; - scalar_t scale1 = round(delta_x / box_vectors[0][0]); - delta_x -= scale1 * box_vectors[0][0]; - return thrust::make_tuple(delta_x, delta_y, delta_z); -} - -template -__global__ void forward_kernel(uint32_t num_all_pairs, const Accessor positions, - const Accessor batch, scalar_t cutoff_lower2, - scalar_t cutoff_upper2, bool loop, bool include_transpose, - Accessor i_curr_pair, Accessor neighbors, - Accessor deltas, Accessor distances, - bool use_periodic, const Accessor box_vectors) { +__global__ void forward_kernel_brute(uint32_t num_all_pairs, const Accessor positions, + const Accessor batch, scalar_t cutoff_lower2, + scalar_t cutoff_upper2, bool loop, bool include_transpose, + Accessor i_curr_pair, + Accessor neighbors, Accessor deltas, + Accessor distances, bool use_periodic, + const Accessor box_vectors) { const uint32_t index = blockIdx.x * blockDim.x + threadIdx.x; if (index >= num_all_pairs) return; const uint32_t row = get_row(index); const uint32_t column = (index - row * (row - 1) / 2); if (batch[row] == batch[column]) { - scalar_t delta_x = positions[row][0] - positions[column][0]; - scalar_t delta_y = positions[row][1] - positions[column][1]; - scalar_t delta_z = positions[row][2] - positions[column][2]; - if (use_periodic) { - thrust::tie(delta_x, delta_y, delta_z) = - apply_pbc(delta_x, delta_y, delta_z, box_vectors); - } - const scalar_t distance2 = delta_x * delta_x + delta_y * delta_y + delta_z * delta_z; + const scalar3 pos_i{positions[row][0], positions[row][1], positions[row][2]}; + const scalar3 
pos_j{positions[column][0], positions[column][1], + positions[column][2]}; + const auto delta = triclinic::compute_distance(pos_i, pos_j, use_periodic, box_vectors); + const scalar_t distance2 = delta.x * delta.x + delta.y * delta.y + delta.z * delta.z; if (distance2 < cutoff_upper2 && distance2 >= cutoff_lower2) { const int32_t i_pair = atomicAdd(&i_curr_pair[0], include_transpose ? 2 : 1); // We handle too many neighbors outside of the kernel - if (i_pair < neighbors.size(1)) { + if (i_pair + include_transpose < neighbors.size(1)) { const scalar_t r2 = sqrt_(distance2); neighbors[0][i_pair] = row; neighbors[1][i_pair] = column; - deltas[i_pair][0] = delta_x; - deltas[i_pair][1] = delta_y; - deltas[i_pair][2] = delta_z; + deltas[i_pair][0] = delta.x; + deltas[i_pair][1] = delta.y; + deltas[i_pair][2] = delta.z; distances[i_pair] = r2; if (include_transpose) { neighbors[0][i_pair + 1] = column; neighbors[1][i_pair + 1] = row; - deltas[i_pair + 1][0] = -delta_x; - deltas[i_pair + 1][1] = -delta_y; - deltas[i_pair + 1][2] = -delta_z; + deltas[i_pair + 1][0] = -delta.x; + deltas[i_pair + 1][1] = -delta.y; + deltas[i_pair + 1][2] = -delta.z; distances[i_pair + 1] = r2; } } @@ -121,6 +71,85 @@ __global__ void add_self_kernel(const int num_atoms, Accessor posit } } +template +__global__ void forward_kernel_shared(uint32_t num_atoms, const Accessor positions, + const Accessor batch, scalar_t cutoff_lower2, + scalar_t cutoff_upper2, bool loop, bool include_transpose, + Accessor i_curr_pair, + Accessor neighbors, Accessor deltas, + Accessor distances, int32_t num_tiles, + bool use_periodic, const Accessor box_vectors) { + // A thread per atom + const int id = blockIdx.x * blockDim.x + threadIdx.x; + // All threads must pass through __syncthreads, + // but when N is not a multiple of 32 some threads are assigned a particle i>N. 
+ // This threads cant return, so they are masked to not do any work + const bool active = id < num_atoms; + __shared__ scalar3 sh_pos[BLOCKSIZE]; + __shared__ int64_t sh_batch[BLOCKSIZE]; + scalar3 pos_i; + int64_t batch_i; + if(active){ + pos_i = {positions[id][0], positions[id][1], positions[id][2]}; + batch_i = batch[id]; + } + // Distribute the N particles in a group of tiles. Storing in each tile blockDim.x values in + // shared memory. This way all threads are accesing the same memory addresses at the same time + for (int tile = 0; tile < num_tiles; tile++) { + // Load this tiles particles values to shared memory + const int i_load = tile * blockDim.x + threadIdx.x; + if (i_load < num_atoms) { // Even if im not active, my thread may load a value each tile to + // shared memory. + sh_pos[threadIdx.x] = {positions[i_load][0], positions[i_load][1], + positions[i_load][2]}; + sh_batch[threadIdx.x] = batch[i_load]; + } + // Wait for all threads to arrive + __syncthreads(); + // Go through all the particles in the current tile +#pragma unroll 8 + for (int counter = 0; counter < blockDim.x; counter++) { + if (!active) + break; // An out of bounds thread must be masked + const int cur_j = tile * blockDim.x + counter; + const bool testPair = cur_j < num_atoms and (cur_j < id or (loop and cur_j == id)); + if (testPair) { + const auto batch_j = sh_batch[counter]; + if (batch_i == batch_j) { + const auto pos_j = sh_pos[counter]; + const auto delta = + triclinic::compute_distance(pos_i, pos_j, use_periodic, box_vectors); + const scalar_t distance2 = + delta.x * delta.x + delta.y * delta.y + delta.z * delta.z; + if (distance2 < cutoff_upper2 && distance2 >= cutoff_lower2) { + const bool requires_transpose = include_transpose && !(cur_j == id); + const int32_t i_pair = + atomicAdd(&i_curr_pair[0], requires_transpose ? 
2 : 1); + if (i_pair + requires_transpose < neighbors.size(1)) { + const auto distance = sqrt_(distance2); + neighbors[0][i_pair] = id; + neighbors[1][i_pair] = cur_j; + deltas[i_pair][0] = delta.x; + deltas[i_pair][1] = delta.y; + deltas[i_pair][2] = delta.z; + distances[i_pair] = distance; + if (requires_transpose) { + neighbors[0][i_pair + 1] = cur_j; + neighbors[1][i_pair + 1] = id; + deltas[i_pair + 1][0] = -delta.x; + deltas[i_pair + 1][1] = -delta.y; + deltas[i_pair + 1][2] = -delta.z; + distances[i_pair + 1] = distance; + } + } + } + } + } + } + __syncthreads(); + } +} + template __global__ void backward_kernel(const Accessor neighbors, const Accessor deltas, @@ -166,13 +195,15 @@ static void checkInput(const Tensor& positions, const Tensor& batch) { TORCH_CHECK(batch.dtype() == torch::kInt64, "Expected \"batch\" to be of type torch::kLong"); } +enum class strategy { brute, shared }; + class Autograd : public Function { public: static tensor_list forward(AutogradContext* ctx, const Tensor& positions, const Tensor& batch, const Scalar& cutoff_lower, const Scalar& cutoff_upper, const Tensor& box_vectors, bool use_periodic, const Scalar& max_num_pairs, bool loop, bool include_transpose, - bool checkErrors) { + strategy strat) { checkInput(positions, batch); const auto max_num_pairs_ = max_num_pairs.toLong(); TORCH_CHECK(max_num_pairs_ > 0, "Expected \"max_num_neighbors\" to be positive"); @@ -191,54 +222,69 @@ public: const Tensor i_curr_pair = zeros(1, options.dtype(kInt32)); { const CUDAStreamGuard guard(stream); - const uint64_t num_atoms = positions.size(0); - const uint64_t num_all_pairs = num_atoms * (num_atoms - 1ul) / 2ul; - const uint64_t num_threads = 128; - const uint64_t num_blocks = max((num_all_pairs + num_threads - 1ul) / num_threads, 1ul); - AT_DISPATCH_FLOATING_TYPES( - positions.scalar_type(), "get_neighbor_pairs_forward", [&]() { - const scalar_t cutoff_upper_ = cutoff_upper.to(); - const scalar_t cutoff_lower_ = cutoff_lower.to(); - 
TORCH_CHECK(cutoff_upper_ > 0, "Expected \"cutoff\" to be positive"); - forward_kernel<<>>( - num_all_pairs, get_accessor(positions), - get_accessor(batch), cutoff_lower_ * cutoff_lower_, - cutoff_upper_ * cutoff_upper_, loop, include_transpose, - get_accessor(i_curr_pair), get_accessor(neighbors), - get_accessor(deltas), get_accessor(distances), - use_periodic, get_accessor(box_vectors)); - if (loop) { - const uint64_t num_threads = 128; - const uint64_t num_blocks = - max((num_atoms + num_threads - 1ul) / num_threads, 1ul); - add_self_kernel<<>>( + const int32_t num_atoms = positions.size(0); + if (strat == strategy::brute) { + const uint64_t num_all_pairs = num_atoms * (num_atoms - 1ul) / 2ul; + const uint64_t num_threads = 128; + const uint64_t num_blocks = + std::max((num_all_pairs + num_threads - 1ul) / num_threads, 1ul); + AT_DISPATCH_FLOATING_TYPES( + positions.scalar_type(), "get_neighbor_pairs_forward", [&]() { + const scalar_t cutoff_upper_ = cutoff_upper.to(); + const scalar_t cutoff_lower_ = cutoff_lower.to(); + TORCH_CHECK(cutoff_upper_ > 0, "Expected \"cutoff\" to be positive"); + forward_kernel_brute<<>>( + num_all_pairs, get_accessor(positions), + get_accessor(batch), cutoff_lower_ * cutoff_lower_, + cutoff_upper_ * cutoff_upper_, loop, include_transpose, + get_accessor(i_curr_pair), + get_accessor(neighbors), get_accessor(deltas), + get_accessor(distances), use_periodic, + get_accessor(box_vectors)); + if (loop) { + const uint64_t num_threads = 128; + const uint64_t num_blocks = + std::max((num_atoms + num_threads - 1ul) / num_threads, 1ul); + add_self_kernel<<>>( + num_atoms, get_accessor(positions), + get_accessor(i_curr_pair), + get_accessor(neighbors), + get_accessor(deltas), + get_accessor(distances)); + } + }); + } else if (strat == strategy::shared) { + AT_DISPATCH_FLOATING_TYPES( + positions.scalar_type(), "get_neighbor_pairs_shared_forward", [&]() { + const scalar_t cutoff_upper_ = cutoff_upper.to(); + const scalar_t cutoff_lower_ = 
cutoff_lower.to(); + TORCH_CHECK(cutoff_upper_ > 0, "Expected \"cutoff\" to be positive"); + constexpr int BLOCKSIZE = 64; + const int num_blocks = std::max((num_atoms + BLOCKSIZE - 1) / BLOCKSIZE, 1); + const int num_threads = BLOCKSIZE; + const int num_tiles = num_blocks; + forward_kernel_shared<<>>( num_atoms, get_accessor(positions), + get_accessor(batch), cutoff_lower_ * cutoff_lower_, + cutoff_upper_ * cutoff_upper_, loop, include_transpose, get_accessor(i_curr_pair), get_accessor(neighbors), get_accessor(deltas), - get_accessor(distances)); - } - }); - } - // Synchronize and check the number of pairs found. Note that this is incompatible with CUDA - // graphs - if (checkErrors) { - int num_found_pairs = i_curr_pair.item(); - TORCH_CHECK(num_found_pairs <= max_num_pairs_, - "Too many neighbor pairs found. Maximum is " + - std::to_string(max_num_pairs_), - " but found " + std::to_string(num_found_pairs)); + get_accessor(distances), num_tiles, use_periodic, + get_accessor(box_vectors)); + }); + } } ctx->save_for_backward({neighbors, deltas, distances}); ctx->saved_data["num_atoms"] = num_atoms; - return {neighbors, deltas, distances}; + return {neighbors, deltas, distances, i_curr_pair}; } static tensor_list backward(AutogradContext* ctx, tensor_list grad_inputs) { const Tensor grad_distances = grad_inputs[1]; const int num_atoms = ctx->saved_data["num_atoms"].toInt(); const int num_pairs = grad_distances.size(0); - const int num_threads = 128; - const int num_blocks_x = max((num_pairs + num_threads - 1) / num_threads, 1); + const int num_threads = 32; + const int num_blocks_x = std::max((num_pairs + num_threads - 1) / num_threads, 1); const dim3 blocks(num_blocks_x, 2, 3); const auto stream = getCurrentCUDAStream(grad_distances.get_device()); @@ -257,18 +303,27 @@ public: get_accessor(grad_positions)); }); - return {grad_positions, Tensor(), Tensor()}; + return {grad_positions, Tensor(), Tensor(), Tensor()}; } }; TORCH_LIBRARY_IMPL(neighbors, AutogradCUDA, m) 
{ - m.impl("get_neighbor_pairs", + m.impl("get_neighbor_pairs_brute", + [](const Tensor& positions, const Tensor& batch, const Tensor& box_vectors, + bool use_periodic, const Scalar& cutoff_lower, const Scalar& cutoff_upper, + const Scalar& max_num_pairs, bool loop, bool include_transpose) { + const tensor_list results = Autograd::apply( + positions, batch, cutoff_lower, cutoff_upper, box_vectors, use_periodic, + max_num_pairs, loop, include_transpose, strategy::brute); + return std::make_tuple(results[0], results[1], results[2], results[3]); + }); + m.impl("get_neighbor_pairs_shared", [](const Tensor& positions, const Tensor& batch, const Tensor& box_vectors, bool use_periodic, const Scalar& cutoff_lower, const Scalar& cutoff_upper, - const Scalar& max_num_pairs, bool loop, bool include_transpose, bool checkErrors) { + const Scalar& max_num_pairs, bool loop, bool include_transpose) { const tensor_list results = Autograd::apply( positions, batch, cutoff_lower, cutoff_upper, box_vectors, use_periodic, - max_num_pairs, loop, include_transpose, checkErrors); - return std::make_tuple(results[0], results[1], results[2]); + max_num_pairs, loop, include_transpose, strategy::shared); + return std::make_tuple(results[0], results[1], results[2], results[3]); }); } diff --git a/torchmdnet/neighbors/neighbors_cuda_cell.cu b/torchmdnet/neighbors/neighbors_cuda_cell.cu index 42c44a716..0ed3bf5bc 100644 --- a/torchmdnet/neighbors/neighbors_cuda_cell.cu +++ b/torchmdnet/neighbors/neighbors_cuda_cell.cu @@ -1,51 +1,10 @@ /* Raul P. Pelaez 2023. Batched cell list neighbor list implementation for CUDA. 
*/ -#include -#include -#include -#include +#include "common.cuh" #include #include #include -#include -#include -using c10::cuda::CUDAStreamGuard; -using c10::cuda::getCurrentCUDAStream; -using std::make_tuple; -using std::max; -using torch::empty; -using torch::full; -using torch::kInt32; -using torch::PackedTensorAccessor32; -using torch::RestrictPtrTraits; -using torch::Scalar; -using torch::Tensor; -using torch::TensorOptions; -using torch::zeros; -using torch::autograd::AutogradContext; -using torch::autograd::Function; -using torch::autograd::tensor_list; - -template struct scalar3 { - scalar_t x, y, z; -}; - -template -using Accessor = PackedTensorAccessor32; - -template -inline Accessor get_accessor(const Tensor& tensor) { - return tensor.packed_accessor32(); -}; - -template __device__ __forceinline__ scalar_t sqrt_(scalar_t x){}; -template <> __device__ __forceinline__ float sqrt_(float x) { - return ::sqrtf(x); -}; -template <> __device__ __forceinline__ double sqrt_(double x) { - return ::sqrt(x); -}; template __global__ void @@ -120,21 +79,6 @@ inline __host__ __device__ uint hashMorton(int3 ci) { return encodeMorton(ci.x) | (encodeMorton(ci.y) << 1) | (encodeMorton(ci.z) << 2); } -/* - * @brief Takes a point to the unit cell in the range [-0.5, 0.5]*box_size using Minimum Image - * Convention - * @param p The point position - * @param box_size The box size - * @return The point in the unit cell - */ -template -__device__ auto apply_pbc(scalar3 p, scalar3 box_size) { - p.x = p.x - floorf(p.x / box_size.x + scalar_t(0.5)) * box_size.x; - p.y = p.y - floorf(p.y / box_size.y + scalar_t(0.5)) * box_size.y; - p.z = p.z - floorf(p.z / box_size.z + scalar_t(0.5)) * box_size.z; - return p; -} - /* * @brief Calculates the cell dimensions for a given box size and cutoff * @param box_size The box size @@ -166,7 +110,7 @@ __host__ __device__ int3 getCellDimensions(scalar3 box_size, scalar_t */ template __device__ int3 getCell(scalar3 p, scalar3 box_size, scalar_t 
cutoff) { - p = apply_pbc(p, box_size); + p = rect::apply_pbc(p, box_size); // Take to the [0, box_size] range and divide by cutoff (which is the cell size) int cx = floorf((p.x + scalar_t(0.5) * box_size.x) / cutoff); int cy = floorf((p.y + scalar_t(0.5) * box_size.y) / cutoff); @@ -263,7 +207,6 @@ public: */ static auto sortPositionsByHash(const Tensor& positions, const Tensor& batch, const Tensor& box_size, const Scalar& cutoff) { - const int num_atoms = positions.size(0); const auto options = positions.options(); thrust::device_vector hash_keys(num_atoms); @@ -274,7 +217,7 @@ static auto sortPositionsByHash(const Tensor& positions, const Tensor& batch, AT_DISPATCH_FLOATING_TYPES(positions.scalar_type(), "assignHash", [&] { scalar_t cutoff_ = cutoff.to(); scalar3 box_size_ = {box_size[0][0].item(), - box_size[1][1].item(), + box_size[1][1].item(), box_size[2][2].item()}; assignHash<<>>( get_accessor(positions), thrust::raw_pointer_cast(hash_keys.data()), @@ -338,12 +281,11 @@ __global__ void fillCellOffsetsD(const Accessor sorted_positions, static auto fillCellOffsets(const Tensor& sorted_positions, const Tensor& sorted_indices, const Tensor& batch, const Tensor& box_size, const Scalar& cutoff) { const TensorOptions options = sorted_positions.options(); - int3 cell_dim; AT_DISPATCH_FLOATING_TYPES(sorted_positions.scalar_type(), "fillCellOffsets", [&] { scalar_t cutoff_ = cutoff.to(); scalar3 box_size_ = {box_size[0][0].item(), - box_size[1][1].item(), + box_size[1][1].item(), box_size[2][2].item()}; cell_dim = getCellDimensions(box_size_, cutoff_); }); @@ -356,7 +298,7 @@ static auto fillCellOffsets(const Tensor& sorted_positions, const Tensor& sorted auto stream = at::cuda::getCurrentCUDAStream(); scalar_t cutoff_ = cutoff.to(); scalar3 box_size_ = {box_size[0][0].item(), - box_size[1][1].item(), + box_size[1][1].item(), box_size[2][2].item()}; fillCellOffsetsD<<>>( get_accessor(sorted_positions), get_accessor(sorted_indices), @@ -426,35 +368,33 @@ 
forward_kernel(const Accessor sorted_positions, const scalar3 pj = {sorted_positions[cur_j][0], sorted_positions[cur_j][1], sorted_positions[cur_j][2]}; - scalar3 delta = {pi.x - pj.x, pi.y - pj.y, pi.z - pj.z}; - if (use_periodic) { - delta = apply_pbc(delta, box_size); - } + const auto delta = + rect::compute_distance(pi, pj, use_periodic, box_size); const scalar_t distance2 = delta.x * delta.x + delta.y * delta.y + delta.z * delta.z; const scalar_t cutoff_upper2 = cutoff_upper * cutoff_upper; const scalar_t cutoff_lower2 = cutoff_lower * cutoff_lower; if ((distance2 < cutoff_upper2 and distance2 >= cutoff_lower2) or (loop and orj == ori)) { - const bool requires_transpose = (orj != ori) and include_transpose; + const bool requires_transpose = include_transpose and (orj != ori); const int32_t i_pair = atomicAdd(&i_curr_pair[0], requires_transpose ? 2 : 1); // We handle too many neighbors outside of the kernel - if (i_pair < neighbors.size(1)) { - const scalar_t sqrt_distance2 = sqrt_(distance2); + if (i_pair + requires_transpose < neighbors.size(1)) { + const scalar_t distance = sqrt_(distance2); neighbors[0][i_pair] = ori; neighbors[1][i_pair] = orj; deltas[i_pair][0] = delta.x; deltas[i_pair][1] = delta.y; deltas[i_pair][2] = delta.z; - distances[i_pair] = sqrt_distance2; + distances[i_pair] = distance; if (requires_transpose) { neighbors[0][i_pair + 1] = orj; neighbors[1][i_pair + 1] = ori; deltas[i_pair + 1][0] = -delta.x; deltas[i_pair + 1][1] = -delta.y; deltas[i_pair + 1][2] = -delta.z; - distances[i_pair + 1] = sqrt_distance2; + distances[i_pair + 1] = distance; } } // endif } // endif @@ -469,8 +409,7 @@ public: static tensor_list forward(AutogradContext* ctx, const Tensor& positions, const Tensor& batch, const Tensor& box_size_gpu, bool use_periodic, const Scalar& cutoff_lower, const Scalar& cutoff_upper, - const Scalar& max_num_pairs, bool loop, bool include_transpose, - bool checkErrors) { + const Scalar& max_num_pairs, bool loop, bool 
include_transpose) { // The algorithm for the cell list construction can be summarized in three separate steps: // 1. Hash (label) the particles according to the cell (bin) they lie in. // 2. Sort the particles and hashes using the hashes as the ordering label @@ -479,19 +418,18 @@ public: // 3. Identify where each cell starts and ends in the sorted particle positions // array. checkInput(positions, batch); - auto box_size = box_size_gpu.cpu(); + auto box_size = box_size_gpu.cpu(); TORCH_CHECK(box_size.dim() == 2, "Expected \"box_size\" to have two dimensions"); TORCH_CHECK(box_size.size(0) == 3 && box_size.size(1) == 3, "Expected \"box_size\" to have shape (3, 3)"); - //Ensure that box size has no non-zero values outside of the diagonal - TORCH_CHECK(box_size[0][1].item() == 0 && box_size[0][2].item() == 0 && - box_size[1][0].item() == 0 && box_size[1][2].item() == 0 && - box_size[2][0].item() == 0 && box_size[2][1].item() == 0, - "Expected \"box_size\" to be diagonal"); - const auto max_num_pairs_ = max_num_pairs.toLong(); + // Ensure that box size has no non-zero values outside of the diagonal + TORCH_CHECK(box_size[0][1].item() == 0 && box_size[0][2].item() == 0 && + box_size[1][0].item() == 0 && box_size[1][2].item() == 0 && + box_size[2][0].item() == 0 && box_size[2][1].item() == 0, + "Expected \"box_size\" to be diagonal"); + const auto max_num_pairs_ = max_num_pairs.toInt(); TORCH_CHECK(max_num_pairs_ > 0, "Expected \"max_num_neighbors\" to be positive"); const int num_atoms = positions.size(0); - const int num_pairs = max_num_pairs_; const TensorOptions options = positions.options(); // Steps 1 and 2 Tensor sorted_positions, hash_values; @@ -500,9 +438,9 @@ public: Tensor cell_start, cell_end; std::tie(cell_start, cell_end) = fillCellOffsets(sorted_positions, hash_values, batch, box_size, cutoff_upper); - const Tensor neighbors = full({2, num_pairs}, -1, options.dtype(kInt32)); - const Tensor deltas = empty({num_pairs, 3}, options); - const Tensor 
distances = full(num_pairs, 0, options); + const Tensor neighbors = full({2, max_num_pairs_}, -1, options.dtype(kInt32)); + const Tensor deltas = empty({max_num_pairs_, 3}, options); + const Tensor distances = full(max_num_pairs_, 0, options); const Tensor i_curr_pair = zeros(1, options.dtype(kInt32)); const auto stream = getCurrentCUDAStream(positions.get_device()); { // Use the cell list for each batch to find the neighbors @@ -522,22 +460,13 @@ public: get_accessor(cell_start), get_accessor(cell_end), get_accessor(neighbors), get_accessor(deltas), get_accessor(distances), get_accessor(i_curr_pair), - num_atoms, num_pairs, use_periodic, box_size_, - cutoff_lower_, cutoff_upper_, loop, include_transpose); + num_atoms, max_num_pairs_, use_periodic, box_size_, cutoff_lower_, + cutoff_upper_, loop, include_transpose); }); } - // Synchronize and check the number of pairs found. Note that this is incompatible with CUDA - // graphs - if (checkErrors) { - int num_found_pairs = i_curr_pair[0].item(); - TORCH_CHECK(num_found_pairs <= max_num_pairs_, - "Too many neighbor pairs found. 
Maximum is " + - std::to_string(max_num_pairs_), - " but found " + std::to_string(num_found_pairs)); - } ctx->save_for_backward({neighbors, deltas, distances}); ctx->saved_data["num_atoms"] = num_atoms; - return {neighbors, deltas, distances}; + return {neighbors, deltas, distances, i_curr_pair}; } static tensor_list backward(AutogradContext* ctx, tensor_list grad_inputs) { @@ -545,7 +474,7 @@ public: const int num_atoms = ctx->saved_data["num_atoms"].toInt(); const int num_pairs = grad_distances.size(0); const int num_threads = 128; - const int num_blocks_x = max((num_pairs + num_threads - 1) / num_threads, 1); + const int num_blocks_x = std::max((num_pairs + num_threads - 1) / num_threads, 1); const dim3 blocks(num_blocks_x, 2, 3); const auto stream = getCurrentCUDAStream(grad_distances.get_device()); @@ -564,7 +493,7 @@ public: get_accessor(grad_positions)); }); - return {grad_positions, Tensor(), Tensor()}; + return {grad_positions, Tensor(), Tensor(), Tensor()}; } }; @@ -572,10 +501,10 @@ TORCH_LIBRARY_IMPL(neighbors, AutogradCUDA, m) { m.impl("get_neighbor_pairs_cell", [](const Tensor& positions, const Tensor& batch, const Tensor& box_vectors, bool use_periodic, const Scalar& cutoff_lower, const Scalar& cutoff_upper, - const Scalar& max_num_pairs, bool loop, bool include_transpose, bool checkErrors) { - const tensor_list results = Autograd::apply( - positions, batch, box_vectors, use_periodic, cutoff_lower, cutoff_upper, - max_num_pairs, loop, include_transpose, checkErrors); - return std::make_tuple(results[0], results[1], results[2]); + const Scalar& max_num_pairs, bool loop, bool include_transpose) { + const tensor_list results = + Autograd::apply(positions, batch, box_vectors, use_periodic, cutoff_lower, + cutoff_upper, max_num_pairs, loop, include_transpose); + return std::make_tuple(results[0], results[1], results[2], results[3]); }); } From 713dcad3ebccc499e1054eb00535d9ac24e7cae6 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Mon, 8 May 2023 
15:17:34 +0200 Subject: [PATCH 26/76] Make cpu implementation torch only to get autograd --- torchmdnet/neighbors/neighbors_cpu.cpp | 77 +++++++------------------- 1 file changed, 19 insertions(+), 58 deletions(-) diff --git a/torchmdnet/neighbors/neighbors_cpu.cpp b/torchmdnet/neighbors/neighbors_cpu.cpp index 79ae130ac..8f4d8d9da 100644 --- a/torchmdnet/neighbors/neighbors_cpu.cpp +++ b/torchmdnet/neighbors/neighbors_cpu.cpp @@ -51,67 +51,13 @@ forward(const Tensor& positions, const Tensor& batch, const Tensor& box_vectors, } TORCH_CHECK(max_num_pairs.toLong() > 0, "Expected \"max_num_neighbors\" to be positive"); const int n_atoms = positions.size(0); - const int n_batches = batch[n_atoms - 1].item() + 1; - int current_offset = 0; - std::vector batch_i; - int n_pairs = 0; Tensor neighbors = torch::empty({0}, positions.options().dtype(kInt32)); Tensor distances = torch::empty({0}, positions.options()); Tensor deltas = torch::empty({0}, positions.options()); - for (int i = 0; i < n_batches; i++) { - batch_i.clear(); - for (int j = current_offset; j < n_atoms; j++) { - if (batch[j].item() == i) { - batch_i.push_back(j); - } else { - break; - } - } - - const int n_atoms_i = batch_i.size(); - Tensor positions_i = index_select(positions, 0, torch::tensor(batch_i, kInt32)); - Tensor indices_i = - arange(0, n_atoms_i * (n_atoms_i - 1l) / 2l, positions.options().dtype(torch::kLong)); - Tensor rows_i = (((8l * indices_i + 1l).sqrt() + 1l) / 2l).floor().to(torch::kLong); - rows_i -= (rows_i * (rows_i - 1l) > 2l * indices_i).to(torch::kLong); - Tensor columns_i = indices_i - div(rows_i * (rows_i - 1l), 2, "floor"); - Tensor neighbors_i = vstack({rows_i, columns_i}); - Tensor deltas_i = - index_select(positions_i, 0, rows_i) - index_select(positions_i, 0, columns_i); - if (use_periodic) { - deltas_i -= outer(round(deltas_i.index({Slice(), 2}) / box_vectors.index({2, 2})), - box_vectors.index({2})); - deltas_i -= outer(round(deltas_i.index({Slice(), 1}) / 
box_vectors.index({1, 1})), - box_vectors.index({1})); - deltas_i -= outer(round(deltas_i.index({Slice(), 0}) / box_vectors.index({0, 0})), - box_vectors.index({0})); - } - Tensor distances_i = frobenius_norm(deltas_i, 1); - const Tensor mask_upper = distances_i < cutoff_upper; - const Tensor mask_lower = distances_i >= cutoff_lower; - const Tensor mask = mask_upper * mask_lower; - neighbors_i = neighbors_i.index({Slice(), mask}) + current_offset; - // Add the transposed pairs - if (include_transpose) { - neighbors_i = - torch::hstack({neighbors_i, torch::stack({neighbors_i[1], neighbors_i[0]})}); - } - // Add self interaction using batch_i - if (loop) { - const Tensor batch_i_tensor = torch::tensor(batch_i, kInt32); - neighbors_i = - torch::hstack({neighbors_i, torch::stack({batch_i_tensor, batch_i_tensor})}); - } - n_pairs += neighbors_i.size(1); - TORCH_CHECK(n_pairs >= 0, - "The maximum number of pairs has been exceed! Increase \"max_num_neighbors\""); - neighbors = torch::hstack({neighbors, neighbors_i}); - current_offset += n_atoms_i; - } - if (n_batches > 1) { - neighbors = torch::cat(neighbors, 0).to(kInt32); - } - deltas = index_select(positions, 0, neighbors[0]) - index_select(positions, 0, neighbors[1]); + neighbors = torch::vstack((torch::tril_indices(n_atoms,n_atoms, -1, neighbors.options()))); + auto mask = index_select(batch, 0, neighbors.index({0, Slice()})) == index_select(batch, 0, neighbors.index({1, Slice()})); + neighbors = neighbors.index({Slice(), mask}).to(kInt32); + deltas = index_select(positions, 0, neighbors.index({0, Slice()})) - index_select(positions, 0, neighbors.index({1, Slice()})); if (use_periodic) { deltas -= outer(round(deltas.index({Slice(), 2}) / box_vectors.index({2, 2})), box_vectors.index({2})); @@ -121,6 +67,21 @@ forward(const Tensor& positions, const Tensor& batch, const Tensor& box_vectors, box_vectors.index({0})); } distances = frobenius_norm(deltas, 1); + mask = (distances < cutoff_upper)*(distances >= cutoff_lower); 
+ neighbors = neighbors.index({Slice(), mask}); + deltas = deltas.index({mask, Slice()}); + distances = distances.index({mask}); + if (include_transpose) { + neighbors = torch::hstack({neighbors, torch::stack({neighbors[1], neighbors[0]})}); + distances = torch::hstack({distances, distances}); + deltas = torch::vstack({deltas, -deltas}); + } + if(loop) { + const Tensor range = torch::arange(0, n_atoms, torch::kInt32); + neighbors = torch::hstack({neighbors, torch::stack({range, range})}); + distances = torch::hstack({distances, torch::zeros_like(range)}); + deltas = torch::vstack({deltas, torch::zeros({n_atoms,3}, deltas.options())}); + } Tensor num_pairs_found = torch::empty(1, distances.options().dtype(kInt32)); num_pairs_found[0] = distances.size(0); return {neighbors, deltas, distances, num_pairs_found}; From 9e9cfab6323fb5e88835f4ddee853c7bf86b9cf0 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Mon, 8 May 2023 15:18:42 +0200 Subject: [PATCH 27/76] Annotate DistanceCellList Allow batch to be None (defaults to all zeros) --- torchmdnet/models/utils.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/torchmdnet/models/utils.py b/torchmdnet/models/utils.py index e06721367..89d0780e0 100644 --- a/torchmdnet/models/utils.py +++ b/torchmdnet/models/utils.py @@ -1,6 +1,7 @@ import math -from typing import Optional +from typing import Optional, Tuple import torch +from torch import Tensor from torch import nn import torch.nn.functional as F from torch_geometric.nn import MessagePassing @@ -144,13 +145,13 @@ def __init__( raise ValueError("Unknown strategy: {}".format(self.strategy)) self.check_errors = check_errors - def forward(self, pos, batch): - """ + def forward(self, pos: Tensor, batch: Optional[Tensor] = None) -> Tuple[Tensor, Tensor, Optional[Tensor]]: + """ Compute the neighbor list for a given cutoff. 
Parameters ---------- pos : torch.Tensor shape (N, 3) - batch : torch.Tensor + batch : torch.Tensor or None shape (N,) Returns ------- @@ -174,6 +175,8 @@ def forward(self, pos, batch): max_pairs = self.max_num_pairs if self.max_num_pairs < 0: max_pairs = -self.max_num_pairs*pos.shape[0] + if batch is None: + batch = torch.zeros(pos.shape[0], dtype=torch.long, device=pos.device) neighbors, distance_vecs, distances, num_pairs = self.kernel( pos, cutoff_lower=self.cutoff_lower, From 3517d6fc1526931e374afe04cc7469f4b8d78638 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Tue, 9 May 2023 16:15:44 +0200 Subject: [PATCH 28/76] Add backward pass with corresponding tests Use pytorch CUDACachingAllocator with thrust::sort and for temporary memory --- setup.py | 2 +- tests/test_neighbors.py | 84 +++++++++++++++++- torchmdnet/neighbors/__init__.py | 2 +- torchmdnet/neighbors/common.cuh | 8 +- torchmdnet/neighbors/neighbors_cpu.cpp | 26 +++--- torchmdnet/neighbors/neighbors_cuda.cu | 52 +---------- torchmdnet/neighbors/neighbors_cuda_cell.cu | 96 ++++++--------------- 7 files changed, 132 insertions(+), 138 deletions(-) diff --git a/setup.py b/setup.py index dcf242017..2fd821ed0 100644 --- a/setup.py +++ b/setup.py @@ -15,7 +15,7 @@ name="torchmd-net", version=version, packages=find_packages(), - package_data={"torchmdnet": ["neighbors/neighbors*", "neighbors/common.cuh"]}, + package_data={"torchmdnet": ["neighbors/neighbors*", "neighbors/*.cu*"]}, include_package_data=True, entry_points={"console_scripts": ["torchmd-train = torchmdnet.scripts.train:main"]}, ) diff --git a/tests/test_neighbors.py b/tests/test_neighbors.py index d84a59368..bff36a030 100644 --- a/tests/test_neighbors.py +++ b/tests/test_neighbors.py @@ -51,7 +51,8 @@ def compute_ref_neighbors(pos, batch, loop, include_transpose, cutoff, box_vecto @pytest.mark.parametrize("loop", [True, False]) @pytest.mark.parametrize("include_transpose", [True, False]) @pytest.mark.parametrize("box_type", [None, "triclinic", 
"rectangular"]) -def test_neighbors(device, strategy, n_batches, cutoff, loop, include_transpose, box_type): +@pytest.mark.parametrize('dtype', [torch.float32, torch.float64]) +def test_neighbors(device, strategy, n_batches, cutoff, loop, include_transpose, box_type, dtype): if device == "cuda" and not torch.cuda.is_available(): pytest.skip("CUDA not available") if box_type == "triclinic" and strategy == "cell": @@ -63,7 +64,7 @@ def test_neighbors(device, strategy, n_batches, cutoff, loop, include_transpose, batch = torch.repeat_interleave(torch.arange(n_batches, dtype=torch.int64), n_atoms_per_batch).to(device) cumsum = np.cumsum( np.concatenate([[0], n_atoms_per_batch])) lbox=10.0 - pos = torch.rand(cumsum[-1], 3, device=device)*lbox + pos = torch.rand(cumsum[-1], 3, device=device, dtype=dtype)*lbox #Ensure there is at least one pair pos[0,:] = torch.zeros(3) pos[1,:] = torch.zeros(3) @@ -166,7 +167,7 @@ def test_large_size(strategy, n_batches): max_num_pairs = ref_neighbors.shape[1] #Must check without PBC since Distance does not support it - box = None #torch.tensor([[lbox, 0.0, 0.0], [0.0, lbox, 0.0], [0.0, 0.0, lbox]]).to(pos.dtype).to(device) + box = None nl = DistanceCellList(cutoff_lower=0.0, loop=loop, cutoff_upper=cutoff, max_num_pairs=max_num_pairs, strategy=strategy, box=box, return_vecs=True, include_transpose=True, resize_to_fit=True) neighbors, distances, distance_vecs = nl(pos, batch) neighbors = neighbors.cpu().detach().numpy() @@ -176,3 +177,80 @@ def test_large_size(strategy, n_batches): assert np.allclose(neighbors, ref_neighbors) assert np.allclose(distances, ref_distances) assert np.allclose(distance_vecs, ref_distance_vecs) + + +@pytest.mark.parametrize('device', ['cpu', 'cuda']) +@pytest.mark.parametrize("strategy", ["brute", "shared", "cell"]) +@pytest.mark.parametrize("loop", [True, False]) +@pytest.mark.parametrize("include_transpose", [True, False]) +@pytest.mark.parametrize('dtype', [torch.float32, torch.float64]) 
+@pytest.mark.parametrize('num_atoms', [1, 2, 3, 5, 100, 1000]) +@pytest.mark.parametrize('grad', ['deltas', 'distances', 'combined']) +@pytest.mark.parametrize("box_type", [None, "triclinic", "rectangular"]) +def test_neighbor_grads(device, strategy, loop, include_transpose, dtype, num_atoms, grad, box_type): + if not torch.cuda.is_available() and device == 'cuda': + pytest.skip('No GPU') + if device=="cpu" and strategy!="brute": + pytest.skip("Only brute force supported on CPU") + if box_type == "triclinic" and strategy == "cell": + pytest.skip("Triclinic only supported for brute force") + cutoff=4.999999 + lbox=10.0 + torch.random.manual_seed(1234) + np.random.seed(123456) + # Generate random positions + positions = 0.25*lbox * torch.rand(num_atoms, 3, device=device, dtype=dtype) + if(box_type is None): + box = None + else: + box = torch.tensor([[lbox, 0.0, 0.0], [0.0, lbox, 0.0], [0.0, 0.0, lbox]]).to(dtype).to(device) + # Compute reference values using pure pytorch + ref_neighbors = torch.vstack((torch.tril_indices(num_atoms,num_atoms, -1, device=device),)) + if include_transpose: + ref_neighbors = torch.hstack((ref_neighbors, torch.stack((ref_neighbors[1], ref_neighbors[0])))) + if loop: + index = torch.arange(num_atoms, device=device) + ref_neighbors = torch.hstack((ref_neighbors, torch.stack((index, index)))) + ref_positions = positions.clone() + ref_positions.requires_grad_(True) + # Every pair is included, so there is no need to filter out pairs even after PBC + ref_deltas = ref_positions[ref_neighbors[0]] - ref_positions[ref_neighbors[1]] + if box is not None: + ref_box = box.clone() + ref_deltas -= torch.outer(torch.round(ref_deltas[:,2]/ref_box[2,2]), ref_box[2]) + ref_deltas -= torch.outer(torch.round(ref_deltas[:,1]/ref_box[1,1]), ref_box[1]) + ref_deltas -= torch.outer(torch.round(ref_deltas[:,0]/ref_box[0,0]), ref_box[0]) + + if loop: + ref_distances = torch.zeros((ref_deltas.size(0),), device=device, dtype=dtype) + mask = ref_neighbors[0] != 
ref_neighbors[1] + ref_distances[mask] = torch.linalg.norm(ref_deltas[mask], dim=-1) + else: + ref_distances = torch.linalg.norm(ref_deltas, dim=-1) + max_num_pairs = max(ref_neighbors.shape[1],1) + positions.requires_grad_(True) + nl = DistanceCellList(cutoff_upper=cutoff, max_num_pairs=max_num_pairs, strategy=strategy, loop=loop, include_transpose=include_transpose, return_vecs=True, resize_to_fit=True, box=box) + neighbors, distances, deltas = nl(positions) + #Check neighbor pairs are correct + ref_neighbors_sort, _, _ = sort_neighbors(ref_neighbors.clone().cpu().detach().numpy(), ref_deltas.clone().cpu().detach().numpy(), ref_distances.clone().cpu().detach().numpy()) + neighbors_sort, _, _ = sort_neighbors(neighbors.clone().cpu().detach().numpy(), deltas.clone().cpu().detach().numpy(), distances.clone().cpu().detach().numpy()) + assert np.allclose(ref_neighbors_sort, neighbors_sort) + + # Compute gradients + if grad == 'deltas': + ref_deltas.sum().backward() + deltas.sum().backward() + elif grad == 'distances': + ref_distances.sum().backward() + distances.sum().backward() + elif grad == 'combined': + (ref_deltas.sum() + ref_distances.sum()).backward() + (deltas.sum() + distances.sum()).backward() + else: + raise ValueError('grad') + ref_pos_grad_sorted = ref_positions.grad.cpu().detach().numpy() + pos_grad_sorted = positions.grad.cpu().detach().numpy() + if dtype == torch.float32: + assert np.allclose(ref_pos_grad_sorted, pos_grad_sorted, atol=1e-2, rtol=1e-2) + else: + assert np.allclose(ref_pos_grad_sorted, pos_grad_sorted, atol=1e-8, rtol=1e-5) diff --git a/torchmdnet/neighbors/__init__.py b/torchmdnet/neighbors/__init__.py index 2664da064..b93632471 100644 --- a/torchmdnet/neighbors/__init__.py +++ b/torchmdnet/neighbors/__init__.py @@ -3,7 +3,7 @@ from torch.utils import cpp_extension src_dir = os.path.dirname(__file__) -sources = ['neighbors.cpp', 'neighbors_cpu.cpp'] + (['neighbors_cuda.cu', 'neighbors_cuda_cell.cu'] if pt.cuda.is_available() else []) 
+sources = ['neighbors.cpp', 'neighbors_cpu.cpp'] + (['neighbors_cuda.cu', 'neighbors_cuda_cell.cu', 'backwards.cu'] if pt.cuda.is_available() else []) sources = [os.path.join(src_dir, name) for name in sources] cpp_extension.load(name='neighbors', sources=sources, is_python_module=False) diff --git a/torchmdnet/neighbors/common.cuh b/torchmdnet/neighbors/common.cuh index b6962f874..08c41fe55 100644 --- a/torchmdnet/neighbors/common.cuh +++ b/torchmdnet/neighbors/common.cuh @@ -1,7 +1,8 @@ -#pragma once +#ifndef NEIGHBORS_COMMON_CUH +#define NEIGHBORS_COMMON_CUH +#include #include #include -#include #include using c10::cuda::CUDAStreamGuard; @@ -118,3 +119,6 @@ __device__ auto compute_distance(scalar3 pos_i, scalar3 pos_ } } // namespace triclinic + +tensor_list common_backward(AutogradContext* ctx, tensor_list grad_inputs); +#endif diff --git a/torchmdnet/neighbors/neighbors_cpu.cpp b/torchmdnet/neighbors/neighbors_cpu.cpp index 8f4d8d9da..f769fddef 100644 --- a/torchmdnet/neighbors/neighbors_cpu.cpp +++ b/torchmdnet/neighbors/neighbors_cpu.cpp @@ -54,10 +54,12 @@ forward(const Tensor& positions, const Tensor& batch, const Tensor& box_vectors, Tensor neighbors = torch::empty({0}, positions.options().dtype(kInt32)); Tensor distances = torch::empty({0}, positions.options()); Tensor deltas = torch::empty({0}, positions.options()); - neighbors = torch::vstack((torch::tril_indices(n_atoms,n_atoms, -1, neighbors.options()))); - auto mask = index_select(batch, 0, neighbors.index({0, Slice()})) == index_select(batch, 0, neighbors.index({1, Slice()})); + neighbors = torch::vstack((torch::tril_indices(n_atoms, n_atoms, -1, neighbors.options()))); + auto mask = index_select(batch, 0, neighbors.index({0, Slice()})) == + index_select(batch, 0, neighbors.index({1, Slice()})); neighbors = neighbors.index({Slice(), mask}).to(kInt32); - deltas = index_select(positions, 0, neighbors.index({0, Slice()})) - index_select(positions, 0, neighbors.index({1, Slice()})); + deltas = 
index_select(positions, 0, neighbors.index({0, Slice()})) - + index_select(positions, 0, neighbors.index({1, Slice()})); if (use_periodic) { deltas -= outer(round(deltas.index({Slice(), 2}) / box_vectors.index({2, 2})), box_vectors.index({2})); @@ -67,20 +69,20 @@ forward(const Tensor& positions, const Tensor& batch, const Tensor& box_vectors, box_vectors.index({0})); } distances = frobenius_norm(deltas, 1); - mask = (distances < cutoff_upper)*(distances >= cutoff_lower); + mask = (distances < cutoff_upper) * (distances >= cutoff_lower); neighbors = neighbors.index({Slice(), mask}); deltas = deltas.index({mask, Slice()}); distances = distances.index({mask}); if (include_transpose) { - neighbors = torch::hstack({neighbors, torch::stack({neighbors[1], neighbors[0]})}); - distances = torch::hstack({distances, distances}); - deltas = torch::vstack({deltas, -deltas}); + neighbors = torch::hstack({neighbors, torch::stack({neighbors[1], neighbors[0]})}); + distances = torch::hstack({distances, distances}); + deltas = torch::vstack({deltas, -deltas}); } - if(loop) { - const Tensor range = torch::arange(0, n_atoms, torch::kInt32); - neighbors = torch::hstack({neighbors, torch::stack({range, range})}); - distances = torch::hstack({distances, torch::zeros_like(range)}); - deltas = torch::vstack({deltas, torch::zeros({n_atoms,3}, deltas.options())}); + if (loop) { + const Tensor range = torch::arange(0, n_atoms, torch::kInt32); + neighbors = torch::hstack({neighbors, torch::stack({range, range})}); + distances = torch::hstack({distances, torch::zeros_like(range)}); + deltas = torch::vstack({deltas, torch::zeros({n_atoms, 3}, deltas.options())}); } Tensor num_pairs_found = torch::empty(1, distances.options().dtype(kInt32)); num_pairs_found[0] = distances.size(0); diff --git a/torchmdnet/neighbors/neighbors_cuda.cu b/torchmdnet/neighbors/neighbors_cuda.cu index a5120b7a3..3307d246a 100644 --- a/torchmdnet/neighbors/neighbors_cuda.cu +++ b/torchmdnet/neighbors/neighbors_cuda.cu 
@@ -1,6 +1,5 @@ #include "common.cuh" #include -#include #include __device__ uint32_t get_row(uint32_t index) { @@ -89,9 +88,9 @@ __global__ void forward_kernel_shared(uint32_t num_atoms, const Accessor pos_i; int64_t batch_i; - if(active){ - pos_i = {positions[id][0], positions[id][1], positions[id][2]}; - batch_i = batch[id]; + if (active) { + pos_i = {positions[id][0], positions[id][1], positions[id][2]}; + batch_i = batch[id]; } // Distribute the N particles in a group of tiles. Storing in each tile blockDim.x values in // shared memory. This way all threads are accesing the same memory addresses at the same time @@ -150,26 +149,6 @@ __global__ void forward_kernel_shared(uint32_t num_atoms, const Accessor -__global__ void -backward_kernel(const Accessor neighbors, const Accessor deltas, - const Accessor distances, const Accessor grad_distances, - Accessor grad_positions) { - const int32_t i_pair = blockIdx.x * blockDim.x + threadIdx.x; - const int32_t num_pairs = neighbors.size(1); - if (i_pair >= num_pairs) - return; - - const int32_t i_dir = blockIdx.y; - const int32_t i_atom = neighbors[i_dir][i_pair]; - if (i_atom < 0) - return; - - const int32_t i_comp = blockIdx.z; - const scalar_t grad = deltas[i_pair][i_comp] / distances[i_pair] * grad_distances[i_pair]; - atomicAdd(&grad_positions[i_atom][i_comp], (i_dir ? 
-1 : 1) * grad); -} - static void checkInput(const Tensor& positions, const Tensor& batch) { // This version works with batches // Batch contains the molecule index for each atom in positions @@ -280,30 +259,7 @@ public: } static tensor_list backward(AutogradContext* ctx, tensor_list grad_inputs) { - const Tensor grad_distances = grad_inputs[1]; - const int num_atoms = ctx->saved_data["num_atoms"].toInt(); - const int num_pairs = grad_distances.size(0); - const int num_threads = 32; - const int num_blocks_x = std::max((num_pairs + num_threads - 1) / num_threads, 1); - const dim3 blocks(num_blocks_x, 2, 3); - const auto stream = getCurrentCUDAStream(grad_distances.get_device()); - - const tensor_list data = ctx->get_saved_variables(); - const Tensor neighbors = data[0]; - const Tensor deltas = data[1]; - const Tensor distances = data[2]; - const Tensor grad_positions = zeros({num_atoms, 3}, grad_distances.options()); - - AT_DISPATCH_FLOATING_TYPES( - grad_distances.scalar_type(), "get_neighbor_pairs_backward", [&]() { - const CUDAStreamGuard guard(stream); - backward_kernel<<>>( - get_accessor(neighbors), get_accessor(deltas), - get_accessor(distances), get_accessor(grad_distances), - get_accessor(grad_positions)); - }); - - return {grad_positions, Tensor(), Tensor(), Tensor()}; + return common_backward(ctx, grad_inputs); } }; diff --git a/torchmdnet/neighbors/neighbors_cuda_cell.cu b/torchmdnet/neighbors/neighbors_cuda_cell.cu index 0ed3bf5bc..72056fa91 100644 --- a/torchmdnet/neighbors/neighbors_cuda_cell.cu +++ b/torchmdnet/neighbors/neighbors_cuda_cell.cu @@ -2,34 +2,9 @@ */ #include "common.cuh" +#include #include -#include -#include - -template -__global__ void -backward_kernel(const Accessor neighbors, const Accessor deltas, - const Accessor distances, const Accessor grad_distances, - Accessor grad_positions) { - // What the backward kernel does: - // For each pair of atoms, it calculates the gradient of the distance between them - // with respect to the 
positions of the atoms. - // The gradient is then added to the gradient of the positions. - // The gradient of the distance is calculated using the chain rule: - const int32_t i_pair = blockIdx.x * blockDim.x + threadIdx.x; - const int32_t num_pairs = neighbors.size(1); - if (i_pair >= num_pairs) - return; - - const int32_t i_dir = blockIdx.y; - const int32_t i_atom = neighbors[i_dir][i_pair]; - if (i_atom < 0) - return; - - const int32_t i_comp = blockIdx.z; - const scalar_t grad = deltas[i_pair][i_comp] / distances[i_pair] * grad_distances[i_pair]; - atomicAdd(&grad_positions[i_atom][i_comp], (i_dir ? -1 : 1) * grad); -} +#include static void checkInput(const Tensor& positions, const Tensor& batch) { // Batch contains the molecule index for each atom in positions @@ -181,36 +156,37 @@ __global__ void assignHash(const Accessor positions, uint64_t* hash hash_values[i_atom] = i_atom; } -// Adaptor from pytorch cached allocator to thrust -template class CudaAllocator { -public: - using value_type = T; - CudaAllocator() { - } - T* allocate(std::ptrdiff_t num_elements) { - return static_cast( - at::cuda::getCUDADeviceAllocator()->raw_allocate(num_elements * sizeof(T))); +// This is a custom allocator for thrust that uses the caching allocator from pytorch +// Its existence is due to the fact that Pytorch does not support uint64_t as a valid Tensor type +template struct torch_cached_allocator : thrust::device_malloc_allocator { + typedef thrust::device_malloc_allocator super_t; + typedef typename super_t::pointer pointer; + typedef typename super_t::size_type size_type; + + pointer allocate(size_type n) { + auto ptr = static_cast(at::cuda::CUDACachingAllocator::raw_alloc(n * sizeof(T))); + return pointer(ptr); } - void deallocate(T* ptr, size_t) { - at::cuda::getCUDADeviceAllocator()->raw_deallocate(ptr); + + void deallocate(pointer p, size_type n) { + at::cuda::CUDACachingAllocator::raw_delete(p.get()); } }; /* - * @brief Sort the positions by hash, based on the cell 
assigned to each position and the batch + * @brief Sort the positions by hash, first by the cell assigned to each position and the batch * index * @param positions The positions of the atoms * @param batch The batch index of each atom - * @param box_size The size of the box in each dimension + * @param box_size The box vectors * @param cutoff The cutoff * @return A tuple of the sorted positions and the original indices of each atom in the sorted list */ static auto sortPositionsByHash(const Tensor& positions, const Tensor& batch, const Tensor& box_size, const Scalar& cutoff) { const int num_atoms = positions.size(0); - const auto options = positions.options(); - thrust::device_vector hash_keys(num_atoms); - Tensor hash_values = empty({num_atoms}, options.dtype(torch::kInt32)); + thrust::device_vector> hash_keys(num_atoms); + Tensor hash_values = empty({num_atoms}, positions.options().dtype(torch::kInt32)); const int threads = 128; const int blocks = (num_atoms + threads - 1) / threads; auto stream = at::cuda::getCurrentCUDAStream(); @@ -225,8 +201,8 @@ static auto sortPositionsByHash(const Tensor& positions, const Tensor& batch, cutoff_, num_atoms); }); thrust::device_ptr index_ptr(hash_values.data_ptr()); - thrust::sort_by_key(thrust::cuda::par.on(stream), hash_keys.begin(), hash_keys.end(), - index_ptr); + thrust::sort_by_key(thrust::cuda::par(torch_cached_allocator()).on(stream), + hash_keys.begin(), hash_keys.end(), index_ptr); Tensor sorted_positions = positions.index_select(0, hash_values); return std::make_tuple(sorted_positions, hash_values); } @@ -274,7 +250,7 @@ __global__ void fillCellOffsetsD(const Accessor sorted_positions, @param sorted_positions The positions sorted by cell @param sorted_indices The original indices of the sorted positions @param batch The batch index of each position - @param box_size The size of the box + @param box_size The box vectors @param cutoff The cutoff distance @return A tuple of cell_start and cell_end arrays */ @@ -436,6 
+412,7 @@ public: std::tie(sorted_positions, hash_values) = sortPositionsByHash(positions, batch, box_size, cutoff_upper); Tensor cell_start, cell_end; + // Step 3 std::tie(cell_start, cell_end) = fillCellOffsets(sorted_positions, hash_values, batch, box_size, cutoff_upper); const Tensor neighbors = full({2, max_num_pairs_}, -1, options.dtype(kInt32)); @@ -443,7 +420,7 @@ public: const Tensor distances = full(max_num_pairs_, 0, options); const Tensor i_curr_pair = zeros(1, options.dtype(kInt32)); const auto stream = getCurrentCUDAStream(positions.get_device()); - { // Use the cell list for each batch to find the neighbors + { // Traverse the cell list to find the neighbors const CUDAStreamGuard guard(stream); AT_DISPATCH_FLOATING_TYPES(positions.scalar_type(), "forward", [&] { const scalar_t cutoff_upper_ = cutoff_upper.to(); @@ -470,30 +447,7 @@ public: } static tensor_list backward(AutogradContext* ctx, tensor_list grad_inputs) { - const Tensor grad_distances = grad_inputs[1]; - const int num_atoms = ctx->saved_data["num_atoms"].toInt(); - const int num_pairs = grad_distances.size(0); - const int num_threads = 128; - const int num_blocks_x = std::max((num_pairs + num_threads - 1) / num_threads, 1); - const dim3 blocks(num_blocks_x, 2, 3); - const auto stream = getCurrentCUDAStream(grad_distances.get_device()); - - const tensor_list data = ctx->get_saved_variables(); - const Tensor neighbors = data[0]; - const Tensor deltas = data[1]; - const Tensor distances = data[2]; - const Tensor grad_positions = zeros({num_atoms, 3}, grad_distances.options()); - - AT_DISPATCH_FLOATING_TYPES( - grad_distances.scalar_type(), "get_neighbor_pairs_backward", [&]() { - const CUDAStreamGuard guard(stream); - backward_kernel<<>>( - get_accessor(neighbors), get_accessor(deltas), - get_accessor(distances), get_accessor(grad_distances), - get_accessor(grad_positions)); - }); - - return {grad_positions, Tensor(), Tensor(), Tensor()}; + return common_backward(ctx, grad_inputs); } }; From 
24ff0fe8e0863a04dc5647ed10dbdb4bbf9a6f7b Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Tue, 9 May 2023 16:52:12 +0200 Subject: [PATCH 29/76] Actually add the backwards pass code --- torchmdnet/neighbors/backwards.cu | 54 +++++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) create mode 100644 torchmdnet/neighbors/backwards.cu diff --git a/torchmdnet/neighbors/backwards.cu b/torchmdnet/neighbors/backwards.cu new file mode 100644 index 000000000..acd4f3524 --- /dev/null +++ b/torchmdnet/neighbors/backwards.cu @@ -0,0 +1,54 @@ +#include "common.cuh" + +template +__global__ void +backward_kernel(const Accessor neighbors, const Accessor deltas, + const Accessor grad_deltas, const Accessor distances, + const Accessor grad_distances, Accessor grad_positions) { + const int32_t i_pair = blockIdx.x * blockDim.x + threadIdx.x; + const int32_t num_pairs = neighbors.size(1); + if (i_pair >= num_pairs) + return; + const int32_t i_dir = blockIdx.y; + const int32_t i_atom = neighbors[i_dir][i_pair]; + const int32_t i_comp = blockIdx.z; + if (i_atom < 0){ + return; + } + const scalar_t grad_deltas_ = grad_deltas[i_pair][i_comp]; + // Handle self interaction + const scalar_t dist = distances[i_pair]; + const scalar_t grad_distances_ = deltas[i_pair][i_comp] / dist * grad_distances[i_pair]; + const scalar_t grad = + (i_dir ? -1 : 1) * + (i_atom == neighbors[1 - i_dir][i_pair] ? 
scalar_t(0.0) : (grad_deltas_ + grad_distances_)); + atomicAdd(&grad_positions[i_atom][i_comp], grad); +} + +tensor_list common_backward(AutogradContext* ctx, tensor_list grad_inputs) { + const Tensor grad_deltas = grad_inputs[1]; + const Tensor grad_distances = grad_inputs[2]; + const int num_atoms = ctx->saved_data["num_atoms"].toInt(); + const int num_pairs = grad_distances.size(0); + const int num_threads = 128; + const int num_blocks_x = std::max((num_pairs + num_threads - 1) / num_threads, 1); + const dim3 blocks(num_blocks_x, 2, 3); + const auto stream = getCurrentCUDAStream(grad_distances.get_device()); + + const tensor_list data = ctx->get_saved_variables(); + const Tensor neighbors = data[0]; + const Tensor deltas = data[1]; + const Tensor distances = data[2]; + const Tensor grad_positions = zeros({num_atoms, 3}, grad_distances.options()); + + AT_DISPATCH_FLOATING_TYPES(grad_distances.scalar_type(), "getNeighborPairs::backward", [&]() { + const CUDAStreamGuard guard(stream); + backward_kernel<<>>( + get_accessor(neighbors), get_accessor(deltas), + get_accessor(grad_deltas), get_accessor(distances), + get_accessor(grad_distances), get_accessor(grad_positions)); + }); + + return {grad_positions, Tensor(), Tensor(), Tensor(), Tensor(), + Tensor(), Tensor(), Tensor(), Tensor(), Tensor()}; +} From 14e785061e578c7cc8f0b3606274a03feca056b2 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Tue, 9 May 2023 16:53:11 +0200 Subject: [PATCH 30/76] Small modification to make jit.script compatible. Add test for this. 
--- tests/test_neighbors.py | 50 ++++++++++++++++++++++++++++++++++++++ torchmdnet/models/utils.py | 5 ++-- 2 files changed, 52 insertions(+), 3 deletions(-) diff --git a/tests/test_neighbors.py b/tests/test_neighbors.py index bff36a030..1ea7ab345 100644 --- a/tests/test_neighbors.py +++ b/tests/test_neighbors.py @@ -1,6 +1,7 @@ import os import pytest import torch +import torch.jit import numpy as np from torchmdnet.models.utils import Distance, DistanceCellList @@ -254,3 +255,52 @@ def test_neighbor_grads(device, strategy, loop, include_transpose, dtype, num_at assert np.allclose(ref_pos_grad_sorted, pos_grad_sorted, atol=1e-2, rtol=1e-2) else: assert np.allclose(ref_pos_grad_sorted, pos_grad_sorted, atol=1e-8, rtol=1e-5) + + +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +@pytest.mark.parametrize("strategy", ["brute", "shared", "cell"]) +@pytest.mark.parametrize("n_batches", [1, 128]) +@pytest.mark.parametrize("cutoff", [1.0]) +@pytest.mark.parametrize("loop", [True, False]) +@pytest.mark.parametrize("include_transpose", [True, False]) +@pytest.mark.parametrize("box_type", [None, "triclinic", "rectangular"]) +@pytest.mark.parametrize('dtype', [torch.float32]) +def test_jit_script_compatible(device, strategy, n_batches, cutoff, loop, include_transpose, box_type, dtype): + if device == "cuda" and not torch.cuda.is_available(): + pytest.skip("CUDA not available") + if box_type == "triclinic" and strategy == "cell": + pytest.skip("Triclinic only supported for brute force") + if device=="cpu" and strategy!="brute": + pytest.skip("Only brute force supported on CPU") + torch.manual_seed(4321) + n_atoms_per_batch = torch.randint(3, 100, size=(n_batches,)) + batch = torch.repeat_interleave(torch.arange(n_batches, dtype=torch.int64), n_atoms_per_batch).to(device) + cumsum = np.cumsum( np.concatenate([[0], n_atoms_per_batch])) + lbox=10.0 + pos = torch.rand(cumsum[-1], 3, device=device, dtype=dtype)*lbox + #Ensure there is at least one pair + pos[0,:] = 
torch.zeros(3) + pos[1,:] = torch.zeros(3) + pos.requires_grad = True + if(box_type is None): + box = None + else: + box = torch.tensor([[lbox, 0.0, 0.0], [0.0, lbox, 0.0], [0.0, 0.0, lbox]]).to(pos.dtype).to(device) + ref_neighbors, ref_distance_vecs, ref_distances = compute_ref_neighbors(pos, batch, loop, include_transpose, cutoff, box) + max_num_pairs = ref_neighbors.shape[1] + nl = DistanceCellList(cutoff_lower=0.0, loop=loop, cutoff_upper=cutoff, max_num_pairs=max_num_pairs, strategy=strategy, box=box, return_vecs=True, include_transpose=include_transpose) + batch.to(device) + + nl = torch.jit.script(nl) + neighbors, distances, distance_vecs = nl(pos, batch) + neighbors = neighbors.cpu().detach().numpy() + distance_vecs = distance_vecs.cpu().detach().numpy() + distances = distances.cpu().detach().numpy() + neighbors, distance_vecs, distances = sort_neighbors(neighbors, distance_vecs, distances) + assert neighbors.shape == (2, max_num_pairs) + assert distances.shape == (max_num_pairs,) + assert distance_vecs.shape == (max_num_pairs, 3) + + assert np.allclose(neighbors, ref_neighbors) + assert np.allclose(distances, ref_distances) + assert np.allclose(distance_vecs, ref_distance_vecs) diff --git a/torchmdnet/models/utils.py b/torchmdnet/models/utils.py index 89d0780e0..32fcbd5ee 100644 --- a/torchmdnet/models/utils.py +++ b/torchmdnet/models/utils.py @@ -127,7 +127,7 @@ def __init__( self.cutoff_lower = cutoff_lower self.max_num_pairs = max_num_pairs self.strategy = strategy - self.box = box + self.box: Optional[Tensor] = box self.loop = loop self.return_vecs = return_vecs self.include_transpose = include_transpose @@ -135,6 +135,7 @@ def __init__( self.use_periodic = True if self.box is None: self.use_periodic = False + self.box = torch.empty((0,0)) if self.strategy == "cell": #Default the box to 3 times the cutoff, really inefficient for the cell list lbox = cutoff_upper * 3.0 @@ -169,8 +170,6 @@ def forward(self, pos: Tensor, batch: Optional[Tensor] = None) 
-> Tuple[Tensor, otherwise the tensors will have size max_num_pairs, with neighbor pairs (-1, -1) at the end. """ - if self.box is None: - self.box = torch.empty((0, 0), dtype=pos.dtype) self.box = self.box.to(pos.dtype).to(pos.device) max_pairs = self.max_num_pairs if self.max_num_pairs < 0: From 0c5a46d00e2945790cb34f6a6a1363eea7a2d4ca Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Tue, 9 May 2023 19:50:54 +0200 Subject: [PATCH 31/76] Make all GPU operations compatible with CUDA graphs, add test for it. This required moving the box Tensor to the CPU and using a lower level library (CUB) for sorting. Modify benchmark to reflect this Small updates to documentation. --- benchmarks/neighbors.py | 19 +++++-- tests/test_neighbors.py | 63 +++++++++++++++++++++ torchmdnet/models/utils.py | 17 +++--- torchmdnet/neighbors/common.cuh | 58 +++++++++++++++---- torchmdnet/neighbors/neighbors_cuda.cu | 43 +++----------- torchmdnet/neighbors/neighbors_cuda_cell.cu | 60 ++++++++------------ 6 files changed, 165 insertions(+), 95 deletions(-) diff --git a/benchmarks/neighbors.py b/benchmarks/neighbors.py index a2874c08d..325117222 100644 --- a/benchmarks/neighbors.py +++ b/benchmarks/neighbors.py @@ -71,11 +71,20 @@ def benchmark_neighbors(device, strategy, n_batches, total_num_particles, mean_n start = torch.cuda.Event(enable_timing=True) end = torch.cuda.Event(enable_timing=True) - - start.record() - for i in range(nruns): - neighbors, distances, distance_vecs = nl(pos, batch) - end.record() + #record in a cuda graph + if strategy != 'distance': + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph): + neighbors, distances, distance_vecs = nl(pos, batch) + start.record() + for i in range(nruns): + graph.replay() + end.record() + else: + start.record() + for i in range(nruns): + neighbors, distances, distance_vecs = nl(pos, batch) + end.record() if device == 'cuda': torch.cuda.synchronize() #Final time diff --git a/tests/test_neighbors.py b/tests/test_neighbors.py 
index 1ea7ab345..c18a1f762 100644 --- a/tests/test_neighbors.py +++ b/tests/test_neighbors.py @@ -304,3 +304,66 @@ def test_jit_script_compatible(device, strategy, n_batches, cutoff, loop, includ assert np.allclose(neighbors, ref_neighbors) assert np.allclose(distances, ref_distances) assert np.allclose(distance_vecs, ref_distance_vecs) + + + + +@pytest.mark.parametrize("device", ["cuda"]) +@pytest.mark.parametrize("strategy", ["brute", "shared", "cell"]) +@pytest.mark.parametrize("n_batches", [1, 128]) +@pytest.mark.parametrize("cutoff", [1.0]) +@pytest.mark.parametrize("loop", [True, False]) +@pytest.mark.parametrize("include_transpose", [True, False]) +@pytest.mark.parametrize("box_type", [None, "triclinic", "rectangular"]) +@pytest.mark.parametrize('dtype', [torch.float32]) +def test_cuda_graph_compatible(device, strategy, n_batches, cutoff, loop, include_transpose, box_type, dtype): + if device == "cuda" and not torch.cuda.is_available(): + pytest.skip("CUDA not available") + if box_type == "triclinic" and strategy == "cell": + pytest.skip("Triclinic only supported for brute force") + torch.manual_seed(4321) + n_atoms_per_batch = torch.randint(3, 100, size=(n_batches,)) + batch = torch.repeat_interleave(torch.arange(n_batches, dtype=torch.int64), n_atoms_per_batch).to(device) + cumsum = np.cumsum( np.concatenate([[0], n_atoms_per_batch])) + lbox=10.0 + pos = torch.rand(cumsum[-1], 3, device=device, dtype=dtype)*lbox + #Ensure there is at least one pair + pos[0,:] = torch.zeros(3) + pos[1,:] = torch.zeros(3) + pos.requires_grad = True + if(box_type is None): + box = None + else: + box = torch.tensor([[lbox, 0.0, 0.0], [0.0, lbox, 0.0], [0.0, 0.0, lbox]]).to(pos.dtype).to(device) + ref_neighbors, ref_distance_vecs, ref_distances = compute_ref_neighbors(pos, batch, loop, include_transpose, cutoff, box) + max_num_pairs = ref_neighbors.shape[1] + nl = DistanceCellList(cutoff_lower=0.0, loop=loop, cutoff_upper=cutoff, max_num_pairs=max_num_pairs, strategy=strategy, 
box=box, return_vecs=True, include_transpose=include_transpose, check_errors=False, resize_to_fit=False) + batch.to(device) + + + graph = torch.cuda.CUDAGraph() + s = torch.cuda.Stream() + s.wait_stream(torch.cuda.current_stream()) + #Warm up + with torch.cuda.stream(s): + for _ in range(10): + neighbors, distances, distance_vecs = nl(pos, batch) + torch.cuda.synchronize() + #Capture + with torch.cuda.graph(graph): + neighbors, distances, distance_vecs = nl(pos, batch) + neighbors.fill_(0) + graph.replay() + torch.cuda.synchronize() + + neighbors = neighbors.cpu().detach().numpy() + distance_vecs = distance_vecs.cpu().detach().numpy() + distances = distances.cpu().detach().numpy() + neighbors, distance_vecs, distances = sort_neighbors(neighbors, distance_vecs, distances) + assert neighbors.shape == (2, max_num_pairs) + assert distances.shape == (max_num_pairs,) + assert distance_vecs.shape == (max_num_pairs, 3) + + assert np.allclose(neighbors, ref_neighbors) + assert np.allclose(distances, ref_distances) + assert np.allclose(distance_vecs, ref_distance_vecs) diff --git a/torchmdnet/models/utils.py b/torchmdnet/models/utils.py index 32fcbd5ee..3d2831d42 100644 --- a/torchmdnet/models/utils.py +++ b/torchmdnet/models/utils.py @@ -97,6 +97,8 @@ def __init__( ): super(DistanceCellList, self).__init__() """ Compute the neighbor list for a given cutoff. + This operation can be placed inside a CUDA graph in some cases. + In particular, resize_to_fit and check_errors must be False. Parameters ---------- cutoff_lower : float @@ -108,18 +110,19 @@ def __init__( If negative, it is interpreted as (minus) the maximum number of neighbors per atom. strategy : str Strategy to use for computing the neighbor list. Can be one of - ["brute", "cell"]. - box : torch.Tensor - Size of the box shape (3,3) or None. + ["shared", "brute", "cell"]. + box : Optional[torch.Tensor] + Size of the box, shape (3,3) or None. If strategy is "cell", the box must be diagonal. 
loop : bool Whether to include self-interactions. include_transpose : bool Whether to include the transpose of the neighbor list. resize_to_fit : bool - Whether to resize the neighbor list to the actual number of pairs found. + Whether to resize the neighbor list to the actual number of pairs found. When False, the list is padded with (-1,-1) pairs up to max_num_pairs + If this is True the operation is not CUDA graph compatible. check_errors : bool - Whether to check for too many pairs. + Whether to check for too many pairs. If this is True the operation is not CUDA graph compatible. return_vecs : bool Whether to return the distance vectors. """ @@ -140,7 +143,7 @@ def __init__( #Default the box to 3 times the cutoff, really inefficient for the cell list lbox = cutoff_upper * 3.0 self.box = torch.tensor([[lbox, 0, 0], [0, lbox, 0], [0, 0, lbox]]) - + self.box = self.box.cpu() #All strategies expect the box to be in CPU memory self.kernel = self._backends[self.strategy] if self.kernel is None: raise ValueError("Unknown strategy: {}".format(self.strategy)) @@ -170,7 +173,7 @@ def forward(self, pos: Tensor, batch: Optional[Tensor] = None) -> Tuple[Tensor, otherwise the tensors will have size max_num_pairs, with neighbor pairs (-1, -1) at the end. 
""" - self.box = self.box.to(pos.dtype).to(pos.device) + self.box = self.box.to(pos.dtype) max_pairs = self.max_num_pairs if self.max_num_pairs < 0: max_pairs = -self.max_num_pairs*pos.shape[0] diff --git a/torchmdnet/neighbors/common.cuh b/torchmdnet/neighbors/common.cuh index 08c41fe55..63d9e69ba 100644 --- a/torchmdnet/neighbors/common.cuh +++ b/torchmdnet/neighbors/common.cuh @@ -58,6 +58,27 @@ template <> struct vec3 { template using scalar3 = typename vec3::type; +static void checkInput(const Tensor& positions, const Tensor& batch) { + // Batch contains the molecule index for each atom in positions + // Neighbors are only calculated within the same molecule + // Batch is a 1D tensor of size (N_atoms) + // Batch is assumed to be sorted + // Batch is assumed to be contiguous + // Batch is assumed to be of type torch::kLong + TORCH_CHECK(positions.dim() == 2, "Expected \"positions\" to have two dimensions"); + TORCH_CHECK(positions.size(0) > 0, + "Expected the 1nd dimension size of \"positions\" to be more than 0"); + TORCH_CHECK(positions.size(1) == 3, "Expected the 2nd dimension size of \"positions\" to be 3"); + TORCH_CHECK(positions.is_contiguous(), "Expected \"positions\" to be contiguous"); + + TORCH_CHECK(batch.dim() == 1, "Expected \"batch\" to have one dimension"); + TORCH_CHECK(batch.size(0) == positions.size(0), + "Expected the 1st dimension size of \"batch\" to be the same as the 1st dimension " + "size of \"positions\""); + TORCH_CHECK(batch.is_contiguous(), "Expected \"batch\" to be contiguous"); + TORCH_CHECK(batch.dtype() == torch::kInt64, "Expected \"batch\" to be of type torch::kLong"); +} + namespace rect { /* @@ -86,7 +107,20 @@ __device__ auto compute_distance(scalar3 pos_i, scalar3 pos_ } } // namespace rect + namespace triclinic { +template struct Box { + scalar_t size[3][3]; + Box(const Tensor& box_vectors) { + if (box_vectors.size(0) == 3 && box_vectors.size(1) == 3) { + for (int i = 0; i < 3; i++) { + for (int j = 0; j < 3; j++) { + 
size[i][j] = box_vectors[i][j].item(); + } + } + } + } +}; /* * @brief Takes a point to the unit cell using Minimum Image * Convention @@ -95,25 +129,25 @@ namespace triclinic { * @return The point in the unit cell */ template -__device__ auto apply_pbc(scalar3 delta, const Accessor box_vectors) { - scalar_t scale3 = round(delta.z / box_vectors[2][2]); - delta.x -= scale3 * box_vectors[2][0]; - delta.y -= scale3 * box_vectors[2][1]; - delta.z -= scale3 * box_vectors[2][2]; - scalar_t scale2 = round(delta.y / box_vectors[1][1]); - delta.x -= scale2 * box_vectors[1][0]; - delta.y -= scale2 * box_vectors[1][1]; - scalar_t scale1 = round(delta.x / box_vectors[0][0]); - delta.x -= scale1 * box_vectors[0][0]; +__device__ auto apply_pbc(scalar3 delta, const Box& box) { + scalar_t scale3 = round(delta.z / box.size[2][2]); + delta.x -= scale3 * box.size[2][0]; + delta.y -= scale3 * box.size[2][1]; + delta.z -= scale3 * box.size[2][2]; + scalar_t scale2 = round(delta.y / box.size[1][1]); + delta.x -= scale2 * box.size[1][0]; + delta.y -= scale2 * box.size[1][1]; + scalar_t scale1 = round(delta.x / box.size[0][0]); + delta.x -= scale1 * box.size[0][0]; return delta; } template __device__ auto compute_distance(scalar3 pos_i, scalar3 pos_j, - bool use_periodic, const Accessor box_vectors) { + bool use_periodic, const Box& box) { scalar3 delta = {pos_i.x - pos_j.x, pos_i.y - pos_j.y, pos_i.z - pos_j.z}; if (use_periodic) { - delta = apply_pbc(delta, box_vectors); + delta = apply_pbc(delta, box); } return delta; } diff --git a/torchmdnet/neighbors/neighbors_cuda.cu b/torchmdnet/neighbors/neighbors_cuda.cu index 3307d246a..5fe099d4d 100644 --- a/torchmdnet/neighbors/neighbors_cuda.cu +++ b/torchmdnet/neighbors/neighbors_cuda.cu @@ -16,7 +16,7 @@ __global__ void forward_kernel_brute(uint32_t num_all_pairs, const Accessor i_curr_pair, Accessor neighbors, Accessor deltas, Accessor distances, bool use_periodic, - const Accessor box_vectors) { + triclinic::Box box) { const uint32_t 
index = blockIdx.x * blockDim.x + threadIdx.x; if (index >= num_all_pairs) return; @@ -26,7 +26,7 @@ __global__ void forward_kernel_brute(uint32_t num_all_pairs, const Accessor pos_i{positions[row][0], positions[row][1], positions[row][2]}; const scalar3 pos_j{positions[column][0], positions[column][1], positions[column][2]}; - const auto delta = triclinic::compute_distance(pos_i, pos_j, use_periodic, box_vectors); + const auto delta = triclinic::compute_distance(pos_i, pos_j, use_periodic, box); const scalar_t distance2 = delta.x * delta.x + delta.y * delta.y + delta.z * delta.z; if (distance2 < cutoff_upper2 && distance2 >= cutoff_lower2) { const int32_t i_pair = atomicAdd(&i_curr_pair[0], include_transpose ? 2 : 1); @@ -77,7 +77,7 @@ __global__ void forward_kernel_shared(uint32_t num_atoms, const Accessor i_curr_pair, Accessor neighbors, Accessor deltas, Accessor distances, int32_t num_tiles, - bool use_periodic, const Accessor box_vectors) { + bool use_periodic, triclinic::Box box) { // A thread per atom const int id = blockIdx.x * blockDim.x + threadIdx.x; // All threads must pass through __syncthreads, @@ -116,8 +116,7 @@ __global__ void forward_kernel_shared(uint32_t num_atoms, const Accessor= cutoff_lower2) { @@ -149,31 +148,6 @@ __global__ void forward_kernel_shared(uint32_t num_atoms, const Accessor 0, - "Expected the 1nd dimension size of \"positions\" to be more than 0"); - TORCH_CHECK(positions.size(1) == 3, "Expected the 2nd dimension size of \"positions\" to be 3"); - TORCH_CHECK(positions.is_contiguous(), "Expected \"positions\" to be contiguous"); - TORCH_CHECK(positions.size(0) < 1l << 15l, - "Expected the 1st dimension size of \"positions\" to be less than ", 1l << 15l); - TORCH_CHECK(batch.dim() == 1, "Expected \"batch\" to have one dimension"); - TORCH_CHECK(batch.size(0) == positions.size(0), - "Expected the 1st dimension size of \"batch\" to be the same as the 1st dimension " - "size of \"positions\""); - TORCH_CHECK(batch.is_contiguous(), 
"Expected \"batch\" to be contiguous"); - TORCH_CHECK(batch.dtype() == torch::kInt64, "Expected \"batch\" to be of type torch::kLong"); -} - enum class strategy { brute, shared }; class Autograd : public Function { @@ -191,6 +165,7 @@ public: TORCH_CHECK(box_vectors.size(0) == 3 && box_vectors.size(1) == 3, "Expected \"box_vectors\" to have shape (3, 3)"); } + TORCH_CHECK(box_vectors.device() == torch::kCPU, "Expected \"box_vectors\" to be on CPU"); const int num_atoms = positions.size(0); const int num_pairs = max_num_pairs_; const TensorOptions options = positions.options(); @@ -209,6 +184,7 @@ public: std::max((num_all_pairs + num_threads - 1ul) / num_threads, 1ul); AT_DISPATCH_FLOATING_TYPES( positions.scalar_type(), "get_neighbor_pairs_forward", [&]() { + triclinic::Box box(box_vectors); const scalar_t cutoff_upper_ = cutoff_upper.to(); const scalar_t cutoff_lower_ = cutoff_lower.to(); TORCH_CHECK(cutoff_upper_ > 0, "Expected \"cutoff\" to be positive"); @@ -218,8 +194,7 @@ public: cutoff_upper_ * cutoff_upper_, loop, include_transpose, get_accessor(i_curr_pair), get_accessor(neighbors), get_accessor(deltas), - get_accessor(distances), use_periodic, - get_accessor(box_vectors)); + get_accessor(distances), use_periodic, box); if (loop) { const uint64_t num_threads = 128; const uint64_t num_blocks = @@ -237,6 +212,7 @@ public: positions.scalar_type(), "get_neighbor_pairs_shared_forward", [&]() { const scalar_t cutoff_upper_ = cutoff_upper.to(); const scalar_t cutoff_lower_ = cutoff_lower.to(); + triclinic::Box box(box_vectors); TORCH_CHECK(cutoff_upper_ > 0, "Expected \"cutoff\" to be positive"); constexpr int BLOCKSIZE = 64; const int num_blocks = std::max((num_atoms + BLOCKSIZE - 1) / BLOCKSIZE, 1); @@ -248,8 +224,7 @@ public: cutoff_upper_ * cutoff_upper_, loop, include_transpose, get_accessor(i_curr_pair), get_accessor(neighbors), get_accessor(deltas), - get_accessor(distances), num_tiles, use_periodic, - get_accessor(box_vectors)); + 
get_accessor(distances), num_tiles, use_periodic, box); }); } } diff --git a/torchmdnet/neighbors/neighbors_cuda_cell.cu b/torchmdnet/neighbors/neighbors_cuda_cell.cu index 72056fa91..0ae9554aa 100644 --- a/torchmdnet/neighbors/neighbors_cuda_cell.cu +++ b/torchmdnet/neighbors/neighbors_cuda_cell.cu @@ -2,32 +2,9 @@ */ #include "common.cuh" +#include #include #include -#include - -static void checkInput(const Tensor& positions, const Tensor& batch) { - // Batch contains the molecule index for each atom in positions - // Neighbors are only calculated within the same molecule - // Batch is a 1D tensor of size (N_atoms) - // Batch is assumed to be sorted and starts at zero. - // Batch is assumed to be contiguous - // Batch is assumed to be of type torch::kLong - // Batch is assumed to be non-negative - // Each batch can have a different number of atoms - TORCH_CHECK(positions.dim() == 2, "Expected \"positions\" to have two dimensions"); - TORCH_CHECK(positions.size(0) > 0, - "Expected the 1nd dimension size of \"positions\" to be more than 0"); - TORCH_CHECK(positions.size(1) == 3, "Expected the 2nd dimension size of \"positions\" to be 3"); - TORCH_CHECK(positions.is_contiguous(), "Expected \"positions\" to be contiguous"); - - TORCH_CHECK(batch.dim() == 1, "Expected \"batch\" to have one dimension"); - TORCH_CHECK(batch.size(0) == positions.size(0), - "Expected the 1st dimension size of \"batch\" to be the same as the 1st dimension " - "size of \"positions\""); - TORCH_CHECK(batch.is_contiguous(), "Expected \"batch\" to be contiguous"); - TORCH_CHECK(batch.dtype() == torch::kInt64, "Expected \"batch\" to be of type torch::kLong"); -} /* * @brief Encodes an unsigned integer lower than 1024 as a 32 bit integer by filling every third @@ -163,12 +140,12 @@ template struct torch_cached_allocator : thrust::device_malloc_allo typedef typename super_t::pointer pointer; typedef typename super_t::size_type size_type; - pointer allocate(size_type n) { + static pointer 
allocate(size_type n) { auto ptr = static_cast(at::cuda::CUDACachingAllocator::raw_alloc(n * sizeof(T))); return pointer(ptr); } - void deallocate(pointer p, size_type n) { + static void deallocate(pointer p, size_type n) { at::cuda::CUDACachingAllocator::raw_delete(p.get()); } }; @@ -185,7 +162,8 @@ template struct torch_cached_allocator : thrust::device_malloc_allo static auto sortPositionsByHash(const Tensor& positions, const Tensor& batch, const Tensor& box_size, const Scalar& cutoff) { const int num_atoms = positions.size(0); - thrust::device_vector> hash_keys(num_atoms); + torch_cached_allocator alloc; + auto hash_keys = (uint64_t*)(alloc.allocate(num_atoms * sizeof(uint64_t)).get()); Tensor hash_values = empty({num_atoms}, positions.options().dtype(torch::kInt32)); const int threads = 128; const int blocks = (num_atoms + threads - 1) / threads; @@ -196,13 +174,23 @@ static auto sortPositionsByHash(const Tensor& positions, const Tensor& batch, box_size[1][1].item(), box_size[2][2].item()}; assignHash<<>>( - get_accessor(positions), thrust::raw_pointer_cast(hash_keys.data()), - get_accessor(hash_values), get_accessor(batch), box_size_, - cutoff_, num_atoms); + get_accessor(positions), hash_keys, get_accessor(hash_values), + get_accessor(batch), box_size_, cutoff_, num_atoms); }); - thrust::device_ptr index_ptr(hash_values.data_ptr()); - thrust::sort_by_key(thrust::cuda::par(torch_cached_allocator()).on(stream), - hash_keys.begin(), hash_keys.end(), index_ptr); + // I have to use cub directly because thrust::sort_by_key is not compatible with graphs + // and torch::lexsort does not support uint64_t + size_t tmp_storage_bytes = 0; + void* tmp_storage = NULL; + uint64_t* d_keys_out = (uint64_t*)alloc.allocate(num_atoms * sizeof(uint64_t)).get(); + int32_t* d_values_out = (int32_t*)alloc.allocate(num_atoms * sizeof(int32_t)).get(); + int32_t* hash_values_ptr = hash_values.data_ptr(); + cub::DeviceRadixSort::SortPairs(tmp_storage, tmp_storage_bytes, hash_keys, 
d_keys_out, + hash_values_ptr, d_values_out, num_atoms, 0, 64, stream); + tmp_storage = alloc.allocate(tmp_storage_bytes).get(); + cub::DeviceRadixSort::SortPairs(tmp_storage, tmp_storage_bytes, hash_keys, d_keys_out, + hash_values_ptr, d_values_out, num_atoms, 0, 64, stream); + cudaMemcpyAsync(hash_values_ptr, d_values_out, num_atoms * sizeof(int32_t), + cudaMemcpyDeviceToDevice, stream); Tensor sorted_positions = positions.index_select(0, hash_values); return std::make_tuple(sorted_positions, hash_values); } @@ -383,7 +371,7 @@ forward_kernel(const Accessor sorted_positions, class Autograd : public Function { public: static tensor_list forward(AutogradContext* ctx, const Tensor& positions, const Tensor& batch, - const Tensor& box_size_gpu, bool use_periodic, + const Tensor& box_size, bool use_periodic, const Scalar& cutoff_lower, const Scalar& cutoff_upper, const Scalar& max_num_pairs, bool loop, bool include_transpose) { // The algorithm for the cell list construction can be summarized in three separate steps: @@ -394,11 +382,9 @@ public: // 3. Identify where each cell starts and ends in the sorted particle positions // array. 
checkInput(positions, batch); - auto box_size = box_size_gpu.cpu(); TORCH_CHECK(box_size.dim() == 2, "Expected \"box_size\" to have two dimensions"); TORCH_CHECK(box_size.size(0) == 3 && box_size.size(1) == 3, "Expected \"box_size\" to have shape (3, 3)"); - // Ensure that box size has no non-zero values outside of the diagonal TORCH_CHECK(box_size[0][1].item() == 0 && box_size[0][2].item() == 0 && box_size[1][0].item() == 0 && box_size[1][2].item() == 0 && box_size[2][0].item() == 0 && box_size[2][1].item() == 0, @@ -419,7 +405,7 @@ public: const Tensor deltas = empty({max_num_pairs_, 3}, options); const Tensor distances = full(max_num_pairs_, 0, options); const Tensor i_curr_pair = zeros(1, options.dtype(kInt32)); - const auto stream = getCurrentCUDAStream(positions.get_device()); + const auto stream = getCurrentCUDAStream(positions.get_device()); { // Traverse the cell list to find the neighbors const CUDAStreamGuard guard(stream); AT_DISPATCH_FLOATING_TYPES(positions.scalar_type(), "forward", [&] { From 9269d2229848fc3a0b2e3922c0cfe82744625c0e Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Tue, 9 May 2023 20:14:02 +0200 Subject: [PATCH 32/76] Fix comment --- torchmdnet/neighbors/backwards.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmdnet/neighbors/backwards.cu b/torchmdnet/neighbors/backwards.cu index acd4f3524..67ed26195 100644 --- a/torchmdnet/neighbors/backwards.cu +++ b/torchmdnet/neighbors/backwards.cu @@ -16,9 +16,9 @@ backward_kernel(const Accessor neighbors, const Accessor Date: Tue, 9 May 2023 20:36:03 +0200 Subject: [PATCH 33/76] Fix leak in cell list --- torchmdnet/neighbors/neighbors_cuda_cell.cu | 59 +++++++++++---------- 1 file changed, 32 insertions(+), 27 deletions(-) diff --git a/torchmdnet/neighbors/neighbors_cuda_cell.cu b/torchmdnet/neighbors/neighbors_cuda_cell.cu index 0ae9554aa..00229f043 100644 --- a/torchmdnet/neighbors/neighbors_cuda_cell.cu +++ b/torchmdnet/neighbors/neighbors_cuda_cell.cu @@ -133,21 
+133,26 @@ __global__ void assignHash(const Accessor positions, uint64_t* hash hash_values[i_atom] = i_atom; } -// This is a custom allocator for thrust that uses the caching allocator from pytorch -// Its existence is due to the fact that Pytorch does not support uint64_t as a valid Tensor type -template struct torch_cached_allocator : thrust::device_malloc_allocator { - typedef thrust::device_malloc_allocator super_t; - typedef typename super_t::pointer pointer; - typedef typename super_t::size_type size_type; - - static pointer allocate(size_type n) { - auto ptr = static_cast(at::cuda::CUDACachingAllocator::raw_alloc(n * sizeof(T))); - return pointer(ptr); +/* + * @brief A buffer that is allocated and deallocated using the CUDA caching allocator from torch + */ +template struct cached_buffer { + cached_buffer(size_t size) : size_(size) { + ptr_ = static_cast(at::cuda::CUDACachingAllocator::raw_alloc(size * sizeof(T))); } - - static void deallocate(pointer p, size_type n) { - at::cuda::CUDACachingAllocator::raw_delete(p.get()); + ~cached_buffer() { + at::cuda::CUDACachingAllocator::raw_delete(ptr_); } + T* get() { + return ptr_; + } + size_t size() { + return size_; + } + +private: + T* ptr_; + size_t size_; }; /* @@ -162,8 +167,7 @@ template struct torch_cached_allocator : thrust::device_malloc_allo static auto sortPositionsByHash(const Tensor& positions, const Tensor& batch, const Tensor& box_size, const Scalar& cutoff) { const int num_atoms = positions.size(0); - torch_cached_allocator alloc; - auto hash_keys = (uint64_t*)(alloc.allocate(num_atoms * sizeof(uint64_t)).get()); + auto hash_keys = cached_buffer(num_atoms); Tensor hash_values = empty({num_atoms}, positions.options().dtype(torch::kInt32)); const int threads = 128; const int blocks = (num_atoms + threads - 1) / threads; @@ -174,22 +178,23 @@ static auto sortPositionsByHash(const Tensor& positions, const Tensor& batch, box_size[1][1].item(), box_size[2][2].item()}; assignHash<<>>( - 
get_accessor(positions), hash_keys, get_accessor(hash_values), - get_accessor(batch), box_size_, cutoff_, num_atoms); + get_accessor(positions), hash_keys.get(), + get_accessor(hash_values), get_accessor(batch), box_size_, + cutoff_, num_atoms); }); // I have to use cub directly because thrust::sort_by_key is not compatible with graphs // and torch::lexsort does not support uint64_t size_t tmp_storage_bytes = 0; - void* tmp_storage = NULL; - uint64_t* d_keys_out = (uint64_t*)alloc.allocate(num_atoms * sizeof(uint64_t)).get(); - int32_t* d_values_out = (int32_t*)alloc.allocate(num_atoms * sizeof(int32_t)).get(); + auto d_keys_out = cached_buffer(num_atoms); + auto d_values_out = cached_buffer(num_atoms); int32_t* hash_values_ptr = hash_values.data_ptr(); - cub::DeviceRadixSort::SortPairs(tmp_storage, tmp_storage_bytes, hash_keys, d_keys_out, - hash_values_ptr, d_values_out, num_atoms, 0, 64, stream); - tmp_storage = alloc.allocate(tmp_storage_bytes).get(); - cub::DeviceRadixSort::SortPairs(tmp_storage, tmp_storage_bytes, hash_keys, d_keys_out, - hash_values_ptr, d_values_out, num_atoms, 0, 64, stream); - cudaMemcpyAsync(hash_values_ptr, d_values_out, num_atoms * sizeof(int32_t), + cub::DeviceRadixSort::SortPairs(nullptr, tmp_storage_bytes, hash_keys.get(), d_keys_out.get(), + hash_values_ptr, d_values_out.get(), num_atoms, 0, 64, stream); + auto tmp_storage = cached_buffer(tmp_storage_bytes); + cub::DeviceRadixSort::SortPairs(tmp_storage.get(), tmp_storage_bytes, hash_keys.get(), + d_keys_out.get(), hash_values_ptr, d_values_out.get(), + num_atoms, 0, 64, stream); + cudaMemcpyAsync(hash_values_ptr, d_values_out.get(), num_atoms * sizeof(int32_t), cudaMemcpyDeviceToDevice, stream); Tensor sorted_positions = positions.index_select(0, hash_values); return std::make_tuple(sorted_positions, hash_values); @@ -405,7 +410,7 @@ public: const Tensor deltas = empty({max_num_pairs_, 3}, options); const Tensor distances = full(max_num_pairs_, 0, options); const Tensor 
i_curr_pair = zeros(1, options.dtype(kInt32)); - const auto stream = getCurrentCUDAStream(positions.get_device()); + const auto stream = getCurrentCUDAStream(positions.get_device()); { // Traverse the cell list to find the neighbors const CUDAStreamGuard guard(stream); AT_DISPATCH_FLOATING_TYPES(positions.scalar_type(), "forward", [&] { From 69715c2e1ba1226f7e7938bb916a7f5cb836b5d6 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Tue, 9 May 2023 20:43:31 +0200 Subject: [PATCH 34/76] Change hash from int64_t to uint64_t --- torchmdnet/neighbors/neighbors_cuda_cell.cu | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchmdnet/neighbors/neighbors_cuda_cell.cu b/torchmdnet/neighbors/neighbors_cuda_cell.cu index 00229f043..959576d48 100644 --- a/torchmdnet/neighbors/neighbors_cuda_cell.cu +++ b/torchmdnet/neighbors/neighbors_cuda_cell.cu @@ -120,15 +120,15 @@ __global__ void assignHash(const Accessor positions, uint64_t* hash const int32_t i_atom = blockIdx.x * blockDim.x + threadIdx.x; if (i_atom >= num_atoms) return; - const int32_t i_batch = batch[i_atom]; + const uint32_t i_batch = batch[i_atom]; // Move to the unit cell scalar3 pi = {positions[i_atom][0], positions[i_atom][1], positions[i_atom][2]}; auto ci = getCell(pi, box_size, cutoff); // Calculate the hash - const int32_t hash = hashMorton(ci); + const uint32_t hash = hashMorton(ci); // Create a hash combining the Morton hash and the batch index, so that atoms in the same cell // are contiguous - const int64_t hash_final = (static_cast(hash) << 32) | i_batch; + const uint64_t hash_final = (static_cast(hash) << 32) | i_batch; hash_keys[i_atom] = hash_final; hash_values[i_atom] = i_atom; } From 5f0c0664e9ac0f243c8a74811626a7a0664df6c6 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Tue, 9 May 2023 20:43:49 +0200 Subject: [PATCH 35/76] Fix some formatting --- benchmarks/neighbors.py | 166 +++++++++++------ tests/test_neighbors.py | 402 ++++++++++++++++++++++++++++------------ 2 files 
changed, 387 insertions(+), 181 deletions(-) diff --git a/benchmarks/neighbors.py b/benchmarks/neighbors.py index 325117222..361ab003f 100644 --- a/benchmarks/neighbors.py +++ b/benchmarks/neighbors.py @@ -4,10 +4,9 @@ from torchmdnet.models.utils import Distance, DistanceCellList - - - -def benchmark_neighbors(device, strategy, n_batches, total_num_particles, mean_num_neighbors, density): +def benchmark_neighbors( + device, strategy, n_batches, total_num_particles, mean_num_neighbors, density +): """Benchmark the neighbor list generation. Parameters @@ -33,9 +32,11 @@ def benchmark_neighbors(device, strategy, n_batches, total_num_particles, mean_n np.random.seed(43211) num_particles = total_num_particles // n_batches expected_num_neighbors = mean_num_neighbors - cutoff = np.cbrt(3 * expected_num_neighbors / (4 * np.pi * density)); - n_atoms_per_batch = torch.randint(int(num_particles/2), int(num_particles*2), size=(n_batches,),device="cpu") - #Fix so that the total number of particles is correct. Special care if the difference is negative + cutoff = np.cbrt(3 * expected_num_neighbors / (4 * np.pi * density)) + n_atoms_per_batch = torch.randint( + int(num_particles / 2), int(num_particles * 2), size=(n_batches,), device="cpu" + ) + # Fix so that the total number of particles is correct. 
Special care if the difference is negative difference = total_num_particles - n_atoms_per_batch.sum() if difference > 0: while difference > 0: @@ -49,30 +50,45 @@ def benchmark_neighbors(device, strategy, n_batches, total_num_particles, mean_n if n_atoms_per_batch[i] > num_particles: n_atoms_per_batch[i] -= 1 difference += 1 - lbox = np.cbrt(num_particles / density); - batch = torch.repeat_interleave(torch.arange(n_batches, dtype=torch.int64), n_atoms_per_batch).to(device) - cumsum = np.cumsum( np.concatenate([[0], n_atoms_per_batch])) - pos = torch.rand(cumsum[-1], 3, device="cpu").to(device)*lbox - if strategy != 'distance': - max_num_pairs = (expected_num_neighbors * n_atoms_per_batch.sum()).item()*2 - box = torch.eye(3, device=device)*lbox - nl = DistanceCellList(cutoff_upper=cutoff, max_num_pairs=max_num_pairs, strategy=strategy, box=box, loop=False, include_transpose=True, resize_to_fit=False) + lbox = np.cbrt(num_particles / density) + batch = torch.repeat_interleave( + torch.arange(n_batches, dtype=torch.int64), n_atoms_per_batch + ).to(device) + cumsum = np.cumsum(np.concatenate([[0], n_atoms_per_batch])) + pos = torch.rand(cumsum[-1], 3, device="cpu").to(device) * lbox + if strategy != "distance": + max_num_pairs = (expected_num_neighbors * n_atoms_per_batch.sum()).item() * 2 + box = torch.eye(3, device=device) * lbox + nl = DistanceCellList( + cutoff_upper=cutoff, + max_num_pairs=max_num_pairs, + strategy=strategy, + box=box, + loop=False, + include_transpose=True, + resize_to_fit=False, + ) else: - max_num_neighbors = int(expected_num_neighbors*5) - nl = Distance(loop=False, cutoff_lower=0.0, cutoff_upper=cutoff, max_num_neighbors=max_num_neighbors) - #Warmup + max_num_neighbors = int(expected_num_neighbors * 5) + nl = Distance( + loop=False, + cutoff_lower=0.0, + cutoff_upper=cutoff, + max_num_neighbors=max_num_neighbors, + ) + # Warmup for i in range(10): neighbors, distances, distance_vecs = nl(pos, batch) - if device == 'cuda': + if device == 
"cuda": torch.cuda.synchronize() nruns = 50 - if device == 'cuda': + if device == "cuda": torch.cuda.synchronize() start = torch.cuda.Event(enable_timing=True) end = torch.cuda.Event(enable_timing=True) - #record in a cuda graph - if strategy != 'distance': + # record in a cuda graph + if strategy != "distance": graph = torch.cuda.CUDAGraph() with torch.cuda.graph(graph): neighbors, distances, distance_vecs = nl(pos, batch) @@ -85,59 +101,89 @@ def benchmark_neighbors(device, strategy, n_batches, total_num_particles, mean_n for i in range(nruns): neighbors, distances, distance_vecs = nl(pos, batch) end.record() - if device == 'cuda': + if device == "cuda": torch.cuda.synchronize() - #Final time - return (start.elapsed_time(end) / nruns) + # Final time + return start.elapsed_time(end) / nruns + -if __name__ == '__main__': +if __name__ == "__main__": n_particles = 32767 - mean_num_neighbors = min(n_particles, 64); - density=0.5 - print("Benchmarking neighbor list generation for {} particles with {} neighbors on average".format(n_particles, mean_num_neighbors)) + mean_num_neighbors = min(n_particles, 64) + density = 0.5 + print( + "Benchmarking neighbor list generation for {} particles with {} neighbors on average".format( + n_particles, mean_num_neighbors + ) + ) results = {} batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024] - for strategy in ['shared', 'brute', 'cell', 'distance']: + for strategy in ["shared", "brute", "cell", "distance"]: # print("Strategy: {}".format(strategy)) # print("--------") # print("{:<10} {:<10}".format("Batch size", "Time (ms)")) # print("{:<10} {:<10}".format("----------", "---------")) - #Loop over different number of batches, random + # Loop over different number of batches, random for n_batches in batch_sizes: - time = benchmark_neighbors(device='cuda', - strategy=strategy, - n_batches=n_batches, - total_num_particles=n_particles, - mean_num_neighbors=mean_num_neighbors, - density=density - ) - #Store results in a dictionary 
+ time = benchmark_neighbors( + device="cuda", + strategy=strategy, + n_batches=n_batches, + total_num_particles=n_particles, + mean_num_neighbors=mean_num_neighbors, + density=density, + ) + # Store results in a dictionary results[strategy, n_batches] = time - #print("{:<10} {:<10.2f}".format(n_batches, time)) + # print("{:<10} {:<10.2f}".format(n_batches, time)) print("Summary") print("-------") - print("{:<10} {:<21} {:<21} {:<18} {:<10}".format("Batch size", "Shared(ms)", "Brute(ms)", "Cell(ms)", "Distance(ms)")) - print("{:<10} {:<21} {:<21} {:<18} {:<10}".format("----------", "---------", "---------", "---------", "---------")) - #Print a column per strategy, show speedup over Distance in parenthesis + print( + "{:<10} {:<21} {:<21} {:<18} {:<10}".format( + "Batch size", "Shared(ms)", "Brute(ms)", "Cell(ms)", "Distance(ms)" + ) + ) + print( + "{:<10} {:<21} {:<21} {:<18} {:<10}".format( + "----------", "---------", "---------", "---------", "---------" + ) + ) + # Print a column per strategy, show speedup over Distance in parenthesis for n_batches in batch_sizes: - base = results['distance', n_batches] - print("{:<10} {:<4.2f} x{:<14.2f} {:<4.2f} x{:<14.2f} {:<4.2f} x{:<14.2f} {:<10.2f}".format(n_batches, - results['shared', n_batches], - base/results['shared', n_batches], - results['brute', n_batches], - base/results['brute', n_batches], - results['cell', n_batches], - base/results['cell', n_batches], - results['distance', n_batches])) + base = results["distance", n_batches] + print( + "{:<10} {:<4.2f} x{:<14.2f} {:<4.2f} x{:<14.2f} {:<4.2f} x{:<14.2f} {:<10.2f}".format( + n_batches, + results["shared", n_batches], + base / results["shared", n_batches], + results["brute", n_batches], + base / results["brute", n_batches], + results["cell", n_batches], + base / results["cell", n_batches], + results["distance", n_batches], + ) + ) - #Print a second table showing time per atom, show in ns + # Print a second table showing time per atom, show in ns print("\n") 
print("Time per atom") - print("{:<10} {:<10} {:<10} {:<10} {:<10}".format("Batch size", "Shared(ns)", "Brute(ns)", "Cell(ns)", "Distance(ns)")) - print("{:<10} {:<10} {:<10} {:<10} {:<10}".format("----------", "---------", "---------", "---------", "---------")) + print( + "{:<10} {:<10} {:<10} {:<10} {:<10}".format( + "Batch size", "Shared(ns)", "Brute(ns)", "Cell(ns)", "Distance(ns)" + ) + ) + print( + "{:<10} {:<10} {:<10} {:<10} {:<10}".format( + "----------", "---------", "---------", "---------", "---------" + ) + ) for n_batches in batch_sizes: - print("{:<10} {:<10.2f} {:<10.2f} {:<10.2f} {:<10.2f}".format(n_batches, - results['shared', n_batches]/n_particles*1e6, - results['brute', n_batches]/n_particles*1e6, - results['cell', n_batches]/n_particles*1e6, - results['distance', n_batches]/n_particles*1e6)) + print( + "{:<10} {:<10.2f} {:<10.2f} {:<10.2f} {:<10.2f}".format( + n_batches, + results["shared", n_batches] / n_particles * 1e6, + results["brute", n_batches] / n_particles * 1e6, + results["cell", n_batches] / n_particles * 1e6, + results["distance", n_batches] / n_particles * 1e6, + ) + ) diff --git a/tests/test_neighbors.py b/tests/test_neighbors.py index c18a1f762..64a740fc5 100644 --- a/tests/test_neighbors.py +++ b/tests/test_neighbors.py @@ -5,19 +5,20 @@ import numpy as np from torchmdnet.models.utils import Distance, DistanceCellList + def sort_neighbors(neighbors, deltas, distances): i_sorted = np.lexsort(neighbors) return neighbors[:, i_sorted], deltas[i_sorted], distances[i_sorted] def apply_pbc(deltas, box_vectors): - if(box_vectors is None): + if box_vectors is None: return deltas else: ref_vectors = box_vectors.cpu().detach().numpy() - deltas -= np.outer(np.round(deltas[:,2]/ref_vectors[2,2]), ref_vectors[2]) - deltas -= np.outer(np.round(deltas[:,1]/ref_vectors[1,1]), ref_vectors[1]) - deltas -= np.outer(np.round(deltas[:,0]/ref_vectors[0,0]), ref_vectors[0]) + deltas -= np.outer(np.round(deltas[:, 2] / ref_vectors[2, 2]), 
ref_vectors[2]) + deltas -= np.outer(np.round(deltas[:, 1] / ref_vectors[1, 1]), ref_vectors[1]) + deltas -= np.outer(np.round(deltas[:, 0] / ref_vectors[0, 0]), ref_vectors[0]) return deltas @@ -25,26 +26,46 @@ def compute_ref_neighbors(pos, batch, loop, include_transpose, cutoff, box_vecto batch = batch.cpu() n_atoms_per_batch = torch.bincount(batch) n_batches = n_atoms_per_batch.shape[0] - cumsum = torch.cumsum(torch.cat([torch.tensor([0]), n_atoms_per_batch]), dim=0).cpu().detach().numpy() - ref_neighbors = np.concatenate([np.tril_indices(int(n_atoms_per_batch[i]), -1)+cumsum[i] for i in range(n_batches)], axis=1) + cumsum = ( + torch.cumsum(torch.cat([torch.tensor([0]), n_atoms_per_batch]), dim=0) + .cpu() + .detach() + .numpy() + ) + ref_neighbors = np.concatenate( + [ + np.tril_indices(int(n_atoms_per_batch[i]), -1) + cumsum[i] + for i in range(n_batches) + ], + axis=1, + ) # add the upper triangle - if(include_transpose): - ref_neighbors = np.concatenate([ref_neighbors, np.flip(ref_neighbors, axis=0)], axis=1) - if(loop): # Add self interactions - ilist=np.arange(cumsum[-1]) - ref_neighbors = np.concatenate([ref_neighbors, np.stack([ilist, ilist], axis=0)], axis=1) + if include_transpose: + ref_neighbors = np.concatenate( + [ref_neighbors, np.flip(ref_neighbors, axis=0)], axis=1 + ) + if loop: # Add self interactions + ilist = np.arange(cumsum[-1]) + ref_neighbors = np.concatenate( + [ref_neighbors, np.stack([ilist, ilist], axis=0)], axis=1 + ) pos_np = pos.cpu().detach().numpy() - ref_distance_vecs = apply_pbc(pos_np[ref_neighbors[0]] - pos_np[ref_neighbors[1]], box_vectors) + ref_distance_vecs = apply_pbc( + pos_np[ref_neighbors[0]] - pos_np[ref_neighbors[1]], box_vectors + ) ref_distances = np.linalg.norm(ref_distance_vecs, axis=-1) - #remove pairs with distance > cutoff + # remove pairs with distance > cutoff mask = ref_distances < cutoff ref_neighbors = ref_neighbors[:, mask] ref_distance_vecs = ref_distance_vecs[mask] ref_distances = 
ref_distances[mask] - ref_neighbors, ref_distance_vecs, ref_distances = sort_neighbors(ref_neighbors, ref_distance_vecs, ref_distances) + ref_neighbors, ref_distance_vecs, ref_distances = sort_neighbors( + ref_neighbors, ref_distance_vecs, ref_distances + ) return ref_neighbors, ref_distance_vecs, ref_distances + @pytest.mark.parametrize("device", ["cpu", "cuda"]) @pytest.mark.parametrize("strategy", ["brute", "shared", "cell"]) @pytest.mark.parametrize("n_batches", [1, 2, 3, 4, 128]) @@ -52,37 +73,58 @@ def compute_ref_neighbors(pos, batch, loop, include_transpose, cutoff, box_vecto @pytest.mark.parametrize("loop", [True, False]) @pytest.mark.parametrize("include_transpose", [True, False]) @pytest.mark.parametrize("box_type", [None, "triclinic", "rectangular"]) -@pytest.mark.parametrize('dtype', [torch.float32, torch.float64]) -def test_neighbors(device, strategy, n_batches, cutoff, loop, include_transpose, box_type, dtype): +@pytest.mark.parametrize("dtype", [torch.float32, torch.float64]) +def test_neighbors( + device, strategy, n_batches, cutoff, loop, include_transpose, box_type, dtype +): if device == "cuda" and not torch.cuda.is_available(): pytest.skip("CUDA not available") if box_type == "triclinic" and strategy == "cell": pytest.skip("Triclinic only supported for brute force") - if device=="cpu" and strategy!="brute": + if device == "cpu" and strategy != "brute": pytest.skip("Only brute force supported on CPU") torch.manual_seed(4321) n_atoms_per_batch = torch.randint(3, 100, size=(n_batches,)) - batch = torch.repeat_interleave(torch.arange(n_batches, dtype=torch.int64), n_atoms_per_batch).to(device) - cumsum = np.cumsum( np.concatenate([[0], n_atoms_per_batch])) - lbox=10.0 - pos = torch.rand(cumsum[-1], 3, device=device, dtype=dtype)*lbox - #Ensure there is at least one pair - pos[0,:] = torch.zeros(3) - pos[1,:] = torch.zeros(3) + batch = torch.repeat_interleave( + torch.arange(n_batches, dtype=torch.int64), n_atoms_per_batch + ).to(device) + cumsum = 
np.cumsum(np.concatenate([[0], n_atoms_per_batch])) + lbox = 10.0 + pos = torch.rand(cumsum[-1], 3, device=device, dtype=dtype) * lbox + # Ensure there is at least one pair + pos[0, :] = torch.zeros(3) + pos[1, :] = torch.zeros(3) pos.requires_grad = True - if(box_type is None): + if box_type is None: box = None else: - box = torch.tensor([[lbox, 0.0, 0.0], [0.0, lbox, 0.0], [0.0, 0.0, lbox]]).to(pos.dtype).to(device) - ref_neighbors, ref_distance_vecs, ref_distances = compute_ref_neighbors(pos, batch, loop, include_transpose, cutoff, box) + box = ( + torch.tensor([[lbox, 0.0, 0.0], [0.0, lbox, 0.0], [0.0, 0.0, lbox]]) + .to(pos.dtype) + .to(device) + ) + ref_neighbors, ref_distance_vecs, ref_distances = compute_ref_neighbors( + pos, batch, loop, include_transpose, cutoff, box + ) max_num_pairs = ref_neighbors.shape[1] - nl = DistanceCellList(cutoff_lower=0.0, loop=loop, cutoff_upper=cutoff, max_num_pairs=max_num_pairs, strategy=strategy, box=box, return_vecs=True, include_transpose=include_transpose) + nl = DistanceCellList( + cutoff_lower=0.0, + loop=loop, + cutoff_upper=cutoff, + max_num_pairs=max_num_pairs, + strategy=strategy, + box=box, + return_vecs=True, + include_transpose=include_transpose, + ) batch.to(device) neighbors, distances, distance_vecs = nl(pos, batch) neighbors = neighbors.cpu().detach().numpy() distance_vecs = distance_vecs.cpu().detach().numpy() distances = distances.cpu().detach().numpy() - neighbors, distance_vecs, distances = sort_neighbors(neighbors, distance_vecs, distances) + neighbors, distance_vecs, distances = sort_neighbors( + neighbors, distance_vecs, distances + ) assert neighbors.shape == (2, max_num_pairs) assert distances.shape == (max_num_pairs,) assert distance_vecs.shape == (max_num_pairs, 3) @@ -100,43 +142,65 @@ def test_neighbors(device, strategy, n_batches, cutoff, loop, include_transpose, def test_compatible_with_distance(device, strategy, n_batches, cutoff, loop): if device == "cuda" and not torch.cuda.is_available(): 
pytest.skip("CUDA not available") - if device=="cpu" and strategy!="brute": + if device == "cpu" and strategy != "brute": pytest.skip("Only brute force supported on CPU") torch.manual_seed(4321) n_atoms_per_batch = torch.randint(3, 100, size=(n_batches,)) - batch = torch.repeat_interleave(torch.arange(n_batches, dtype=torch.long), n_atoms_per_batch).to(device) - cumsum = np.cumsum( np.concatenate([[0], n_atoms_per_batch])) - lbox=10.0 - pos = torch.rand(cumsum[-1], 3, device=device)*lbox - #Ensure there is at least one pair - pos[0,:] = torch.zeros(3) - pos[1,:] = torch.zeros(3) + batch = torch.repeat_interleave( + torch.arange(n_batches, dtype=torch.long), n_atoms_per_batch + ).to(device) + cumsum = np.cumsum(np.concatenate([[0], n_atoms_per_batch])) + lbox = 10.0 + pos = torch.rand(cumsum[-1], 3, device=device) * lbox + # Ensure there is at least one pair + pos[0, :] = torch.zeros(3) + pos[1, :] = torch.zeros(3) pos.requires_grad = True - ref_neighbors, ref_distance_vecs, ref_distances = compute_ref_neighbors(pos, batch, loop, True, cutoff, None) - #Find the particle appearing in the most pairs - max_num_neighbors = torch.max(torch.bincount(torch.tensor(ref_neighbors[0,:]))) - d = Distance(cutoff_lower=0.0, cutoff_upper=cutoff, loop=loop, max_num_neighbors=max_num_neighbors, return_vecs=True) + ref_neighbors, ref_distance_vecs, ref_distances = compute_ref_neighbors( + pos, batch, loop, True, cutoff, None + ) + # Find the particle appearing in the most pairs + max_num_neighbors = torch.max(torch.bincount(torch.tensor(ref_neighbors[0, :]))) + d = Distance( + cutoff_lower=0.0, + cutoff_upper=cutoff, + loop=loop, + max_num_neighbors=max_num_neighbors, + return_vecs=True, + ) ref_neighbors, ref_distances, ref_distance_vecs = d(pos, batch) ref_neighbors = ref_neighbors.cpu().detach().numpy() ref_distance_vecs = ref_distance_vecs.cpu().detach().numpy() ref_distances = ref_distances.cpu().detach().numpy() - ref_neighbors, ref_distance_vecs, ref_distances = 
sort_neighbors(ref_neighbors, ref_distance_vecs, ref_distances) + ref_neighbors, ref_distance_vecs, ref_distances = sort_neighbors( + ref_neighbors, ref_distance_vecs, ref_distances + ) max_num_pairs = ref_neighbors.shape[1] box = None - nl = DistanceCellList(cutoff_lower=0.0, loop=loop, cutoff_upper=cutoff, max_num_pairs=max_num_pairs, strategy=strategy, box=box, return_vecs=True, include_transpose=True) + nl = DistanceCellList( + cutoff_lower=0.0, + loop=loop, + cutoff_upper=cutoff, + max_num_pairs=max_num_pairs, + strategy=strategy, + box=box, + return_vecs=True, + include_transpose=True, + ) neighbors, distances, distance_vecs = nl(pos, batch) neighbors = neighbors.cpu().detach().numpy() distance_vecs = distance_vecs.cpu().detach().numpy() distances = distances.cpu().detach().numpy() - neighbors, distance_vecs, distances = sort_neighbors(neighbors, distance_vecs, distances) + neighbors, distance_vecs, distances = sort_neighbors( + neighbors, distance_vecs, distances + ) assert np.allclose(neighbors, ref_neighbors) assert np.allclose(distances, ref_distances) assert np.allclose(distance_vecs, ref_distance_vecs) - @pytest.mark.parametrize("strategy", ["brute", "cell", "shared"]) @pytest.mark.parametrize("n_batches", [1, 2, 3, 4]) def test_large_size(strategy, n_batches): @@ -146,69 +210,101 @@ def test_large_size(strategy, n_batches): if device == "cuda" and not torch.cuda.is_available(): pytest.skip("CUDA not available") torch.manual_seed(4321) - num_atoms=int(32000/n_batches) - n_atoms_per_batch = torch.ones(n_batches, dtype=torch.int64)*num_atoms - batch = torch.repeat_interleave(torch.arange(n_batches, dtype=torch.int64), n_atoms_per_batch).to(device) - cumsum = np.cumsum( np.concatenate([[0], n_atoms_per_batch])) - lbox=45.0 - pos = torch.rand(cumsum[-1], 3, device=device)*lbox - #Ensure there is at least one pair - pos[0,:] = torch.zeros(3) - pos[1,:] = torch.zeros(3) + num_atoms = int(32000 / n_batches) + n_atoms_per_batch = torch.ones(n_batches, 
dtype=torch.int64) * num_atoms + batch = torch.repeat_interleave( + torch.arange(n_batches, dtype=torch.int64), n_atoms_per_batch + ).to(device) + cumsum = np.cumsum(np.concatenate([[0], n_atoms_per_batch])) + lbox = 45.0 + pos = torch.rand(cumsum[-1], 3, device=device) * lbox + # Ensure there is at least one pair + pos[0, :] = torch.zeros(3) + pos[1, :] = torch.zeros(3) pos.requires_grad = True - #Find the particle appearing in the most pairs + # Find the particle appearing in the most pairs max_num_neighbors = 64 - d = Distance(cutoff_lower=0.0, cutoff_upper=cutoff, loop=loop, max_num_neighbors=max_num_neighbors, return_vecs=True) + d = Distance( + cutoff_lower=0.0, + cutoff_upper=cutoff, + loop=loop, + max_num_neighbors=max_num_neighbors, + return_vecs=True, + ) ref_neighbors, ref_distances, ref_distance_vecs = d(pos, batch) ref_neighbors = ref_neighbors.cpu().detach().numpy() ref_distance_vecs = ref_distance_vecs.cpu().detach().numpy() ref_distances = ref_distances.cpu().detach().numpy() - ref_neighbors, ref_distance_vecs, ref_distances = sort_neighbors(ref_neighbors, ref_distance_vecs, ref_distances) + ref_neighbors, ref_distance_vecs, ref_distances = sort_neighbors( + ref_neighbors, ref_distance_vecs, ref_distances + ) max_num_pairs = ref_neighbors.shape[1] - #Must check without PBC since Distance does not support it + # Must check without PBC since Distance does not support it box = None - nl = DistanceCellList(cutoff_lower=0.0, loop=loop, cutoff_upper=cutoff, max_num_pairs=max_num_pairs, strategy=strategy, box=box, return_vecs=True, include_transpose=True, resize_to_fit=True) + nl = DistanceCellList( + cutoff_lower=0.0, + loop=loop, + cutoff_upper=cutoff, + max_num_pairs=max_num_pairs, + strategy=strategy, + box=box, + return_vecs=True, + include_transpose=True, + resize_to_fit=True, + ) neighbors, distances, distance_vecs = nl(pos, batch) neighbors = neighbors.cpu().detach().numpy() distance_vecs = distance_vecs.cpu().detach().numpy() distances = 
distances.cpu().detach().numpy() - neighbors, distance_vecs, distances = sort_neighbors(neighbors, distance_vecs, distances) + neighbors, distance_vecs, distances = sort_neighbors( + neighbors, distance_vecs, distances + ) assert np.allclose(neighbors, ref_neighbors) assert np.allclose(distances, ref_distances) assert np.allclose(distance_vecs, ref_distance_vecs) -@pytest.mark.parametrize('device', ['cpu', 'cuda']) +@pytest.mark.parametrize("device", ["cpu", "cuda"]) @pytest.mark.parametrize("strategy", ["brute", "shared", "cell"]) @pytest.mark.parametrize("loop", [True, False]) @pytest.mark.parametrize("include_transpose", [True, False]) -@pytest.mark.parametrize('dtype', [torch.float32, torch.float64]) -@pytest.mark.parametrize('num_atoms', [1, 2, 3, 5, 100, 1000]) -@pytest.mark.parametrize('grad', ['deltas', 'distances', 'combined']) +@pytest.mark.parametrize("dtype", [torch.float32, torch.float64]) +@pytest.mark.parametrize("num_atoms", [1, 2, 3, 5, 100, 1000]) +@pytest.mark.parametrize("grad", ["deltas", "distances", "combined"]) @pytest.mark.parametrize("box_type", [None, "triclinic", "rectangular"]) -def test_neighbor_grads(device, strategy, loop, include_transpose, dtype, num_atoms, grad, box_type): - if not torch.cuda.is_available() and device == 'cuda': - pytest.skip('No GPU') - if device=="cpu" and strategy!="brute": +def test_neighbor_grads( + device, strategy, loop, include_transpose, dtype, num_atoms, grad, box_type +): + if not torch.cuda.is_available() and device == "cuda": + pytest.skip("No GPU") + if device == "cpu" and strategy != "brute": pytest.skip("Only brute force supported on CPU") if box_type == "triclinic" and strategy == "cell": pytest.skip("Triclinic only supported for brute force") - cutoff=4.999999 - lbox=10.0 + cutoff = 4.999999 + lbox = 10.0 torch.random.manual_seed(1234) np.random.seed(123456) # Generate random positions - positions = 0.25*lbox * torch.rand(num_atoms, 3, device=device, dtype=dtype) - if(box_type is None): + 
positions = 0.25 * lbox * torch.rand(num_atoms, 3, device=device, dtype=dtype) + if box_type is None: box = None else: - box = torch.tensor([[lbox, 0.0, 0.0], [0.0, lbox, 0.0], [0.0, 0.0, lbox]]).to(dtype).to(device) + box = ( + torch.tensor([[lbox, 0.0, 0.0], [0.0, lbox, 0.0], [0.0, 0.0, lbox]]) + .to(dtype) + .to(device) + ) # Compute reference values using pure pytorch - ref_neighbors = torch.vstack((torch.tril_indices(num_atoms,num_atoms, -1, device=device),)) + ref_neighbors = torch.vstack( + (torch.tril_indices(num_atoms, num_atoms, -1, device=device),) + ) if include_transpose: - ref_neighbors = torch.hstack((ref_neighbors, torch.stack((ref_neighbors[1], ref_neighbors[0])))) + ref_neighbors = torch.hstack( + (ref_neighbors, torch.stack((ref_neighbors[1], ref_neighbors[0]))) + ) if loop: index = torch.arange(num_atoms, device=device) ref_neighbors = torch.hstack((ref_neighbors, torch.stack((index, index)))) @@ -218,9 +314,15 @@ def test_neighbor_grads(device, strategy, loop, include_transpose, dtype, num_at ref_deltas = ref_positions[ref_neighbors[0]] - ref_positions[ref_neighbors[1]] if box is not None: ref_box = box.clone() - ref_deltas -= torch.outer(torch.round(ref_deltas[:,2]/ref_box[2,2]), ref_box[2]) - ref_deltas -= torch.outer(torch.round(ref_deltas[:,1]/ref_box[1,1]), ref_box[1]) - ref_deltas -= torch.outer(torch.round(ref_deltas[:,0]/ref_box[0,0]), ref_box[0]) + ref_deltas -= torch.outer( + torch.round(ref_deltas[:, 2] / ref_box[2, 2]), ref_box[2] + ) + ref_deltas -= torch.outer( + torch.round(ref_deltas[:, 1] / ref_box[1, 1]), ref_box[1] + ) + ref_deltas -= torch.outer( + torch.round(ref_deltas[:, 0] / ref_box[0, 0]), ref_box[0] + ) if loop: ref_distances = torch.zeros((ref_deltas.size(0),), device=device, dtype=dtype) @@ -228,27 +330,44 @@ def test_neighbor_grads(device, strategy, loop, include_transpose, dtype, num_at ref_distances[mask] = torch.linalg.norm(ref_deltas[mask], dim=-1) else: ref_distances = torch.linalg.norm(ref_deltas, dim=-1) - 
max_num_pairs = max(ref_neighbors.shape[1],1) + max_num_pairs = max(ref_neighbors.shape[1], 1) positions.requires_grad_(True) - nl = DistanceCellList(cutoff_upper=cutoff, max_num_pairs=max_num_pairs, strategy=strategy, loop=loop, include_transpose=include_transpose, return_vecs=True, resize_to_fit=True, box=box) + nl = DistanceCellList( + cutoff_upper=cutoff, + max_num_pairs=max_num_pairs, + strategy=strategy, + loop=loop, + include_transpose=include_transpose, + return_vecs=True, + resize_to_fit=True, + box=box, + ) neighbors, distances, deltas = nl(positions) - #Check neighbor pairs are correct - ref_neighbors_sort, _, _ = sort_neighbors(ref_neighbors.clone().cpu().detach().numpy(), ref_deltas.clone().cpu().detach().numpy(), ref_distances.clone().cpu().detach().numpy()) - neighbors_sort, _, _ = sort_neighbors(neighbors.clone().cpu().detach().numpy(), deltas.clone().cpu().detach().numpy(), distances.clone().cpu().detach().numpy()) + # Check neighbor pairs are correct + ref_neighbors_sort, _, _ = sort_neighbors( + ref_neighbors.clone().cpu().detach().numpy(), + ref_deltas.clone().cpu().detach().numpy(), + ref_distances.clone().cpu().detach().numpy(), + ) + neighbors_sort, _, _ = sort_neighbors( + neighbors.clone().cpu().detach().numpy(), + deltas.clone().cpu().detach().numpy(), + distances.clone().cpu().detach().numpy(), + ) assert np.allclose(ref_neighbors_sort, neighbors_sort) # Compute gradients - if grad == 'deltas': + if grad == "deltas": ref_deltas.sum().backward() deltas.sum().backward() - elif grad == 'distances': + elif grad == "distances": ref_distances.sum().backward() distances.sum().backward() - elif grad == 'combined': + elif grad == "combined": (ref_deltas.sum() + ref_distances.sum()).backward() (deltas.sum() + distances.sum()).backward() else: - raise ValueError('grad') + raise ValueError("grad") ref_pos_grad_sorted = ref_positions.grad.cpu().detach().numpy() pos_grad_sorted = positions.grad.cpu().detach().numpy() if dtype == torch.float32: @@ 
-264,31 +383,50 @@ def test_neighbor_grads(device, strategy, loop, include_transpose, dtype, num_at @pytest.mark.parametrize("loop", [True, False]) @pytest.mark.parametrize("include_transpose", [True, False]) @pytest.mark.parametrize("box_type", [None, "triclinic", "rectangular"]) -@pytest.mark.parametrize('dtype', [torch.float32]) -def test_jit_script_compatible(device, strategy, n_batches, cutoff, loop, include_transpose, box_type, dtype): +@pytest.mark.parametrize("dtype", [torch.float32]) +def test_jit_script_compatible( + device, strategy, n_batches, cutoff, loop, include_transpose, box_type, dtype +): if device == "cuda" and not torch.cuda.is_available(): pytest.skip("CUDA not available") if box_type == "triclinic" and strategy == "cell": pytest.skip("Triclinic only supported for brute force") - if device=="cpu" and strategy!="brute": + if device == "cpu" and strategy != "brute": pytest.skip("Only brute force supported on CPU") torch.manual_seed(4321) n_atoms_per_batch = torch.randint(3, 100, size=(n_batches,)) - batch = torch.repeat_interleave(torch.arange(n_batches, dtype=torch.int64), n_atoms_per_batch).to(device) - cumsum = np.cumsum( np.concatenate([[0], n_atoms_per_batch])) - lbox=10.0 - pos = torch.rand(cumsum[-1], 3, device=device, dtype=dtype)*lbox - #Ensure there is at least one pair - pos[0,:] = torch.zeros(3) - pos[1,:] = torch.zeros(3) + batch = torch.repeat_interleave( + torch.arange(n_batches, dtype=torch.int64), n_atoms_per_batch + ).to(device) + cumsum = np.cumsum(np.concatenate([[0], n_atoms_per_batch])) + lbox = 10.0 + pos = torch.rand(cumsum[-1], 3, device=device, dtype=dtype) * lbox + # Ensure there is at least one pair + pos[0, :] = torch.zeros(3) + pos[1, :] = torch.zeros(3) pos.requires_grad = True - if(box_type is None): + if box_type is None: box = None else: - box = torch.tensor([[lbox, 0.0, 0.0], [0.0, lbox, 0.0], [0.0, 0.0, lbox]]).to(pos.dtype).to(device) - ref_neighbors, ref_distance_vecs, ref_distances = 
compute_ref_neighbors(pos, batch, loop, include_transpose, cutoff, box) + box = ( + torch.tensor([[lbox, 0.0, 0.0], [0.0, lbox, 0.0], [0.0, 0.0, lbox]]) + .to(pos.dtype) + .to(device) + ) + ref_neighbors, ref_distance_vecs, ref_distances = compute_ref_neighbors( + pos, batch, loop, include_transpose, cutoff, box + ) max_num_pairs = ref_neighbors.shape[1] - nl = DistanceCellList(cutoff_lower=0.0, loop=loop, cutoff_upper=cutoff, max_num_pairs=max_num_pairs, strategy=strategy, box=box, return_vecs=True, include_transpose=include_transpose) + nl = DistanceCellList( + cutoff_lower=0.0, + loop=loop, + cutoff_upper=cutoff, + max_num_pairs=max_num_pairs, + strategy=strategy, + box=box, + return_vecs=True, + include_transpose=include_transpose, + ) batch.to(device) nl = torch.jit.script(nl) @@ -296,7 +434,9 @@ def test_jit_script_compatible(device, strategy, n_batches, cutoff, loop, includ neighbors = neighbors.cpu().detach().numpy() distance_vecs = distance_vecs.cpu().detach().numpy() distances = distances.cpu().detach().numpy() - neighbors, distance_vecs, distances = sort_neighbors(neighbors, distance_vecs, distances) + neighbors, distance_vecs, distances = sort_neighbors( + neighbors, distance_vecs, distances + ) assert neighbors.shape == (2, max_num_pairs) assert distances.shape == (max_num_pairs,) assert distance_vecs.shape == (max_num_pairs, 3) @@ -306,8 +446,6 @@ def test_jit_script_compatible(device, strategy, n_batches, cutoff, loop, includ assert np.allclose(distance_vecs, ref_distance_vecs) - - @pytest.mark.parametrize("device", ["cuda"]) @pytest.mark.parametrize("strategy", ["brute", "shared", "cell"]) @pytest.mark.parametrize("n_batches", [1, 128]) @@ -315,41 +453,61 @@ def test_jit_script_compatible(device, strategy, n_batches, cutoff, loop, includ @pytest.mark.parametrize("loop", [True, False]) @pytest.mark.parametrize("include_transpose", [True, False]) @pytest.mark.parametrize("box_type", [None, "triclinic", "rectangular"]) 
-@pytest.mark.parametrize('dtype', [torch.float32]) -def test_cuda_graph_compatible(device, strategy, n_batches, cutoff, loop, include_transpose, box_type, dtype): +@pytest.mark.parametrize("dtype", [torch.float32]) +def test_cuda_graph_compatible( + device, strategy, n_batches, cutoff, loop, include_transpose, box_type, dtype +): if device == "cuda" and not torch.cuda.is_available(): pytest.skip("CUDA not available") if box_type == "triclinic" and strategy == "cell": pytest.skip("Triclinic only supported for brute force") torch.manual_seed(4321) n_atoms_per_batch = torch.randint(3, 100, size=(n_batches,)) - batch = torch.repeat_interleave(torch.arange(n_batches, dtype=torch.int64), n_atoms_per_batch).to(device) - cumsum = np.cumsum( np.concatenate([[0], n_atoms_per_batch])) - lbox=10.0 - pos = torch.rand(cumsum[-1], 3, device=device, dtype=dtype)*lbox - #Ensure there is at least one pair - pos[0,:] = torch.zeros(3) - pos[1,:] = torch.zeros(3) + batch = torch.repeat_interleave( + torch.arange(n_batches, dtype=torch.int64), n_atoms_per_batch + ).to(device) + cumsum = np.cumsum(np.concatenate([[0], n_atoms_per_batch])) + lbox = 10.0 + pos = torch.rand(cumsum[-1], 3, device=device, dtype=dtype) * lbox + # Ensure there is at least one pair + pos[0, :] = torch.zeros(3) + pos[1, :] = torch.zeros(3) pos.requires_grad = True - if(box_type is None): + if box_type is None: box = None else: - box = torch.tensor([[lbox, 0.0, 0.0], [0.0, lbox, 0.0], [0.0, 0.0, lbox]]).to(pos.dtype).to(device) - ref_neighbors, ref_distance_vecs, ref_distances = compute_ref_neighbors(pos, batch, loop, include_transpose, cutoff, box) + box = ( + torch.tensor([[lbox, 0.0, 0.0], [0.0, lbox, 0.0], [0.0, 0.0, lbox]]) + .to(pos.dtype) + .to(device) + ) + ref_neighbors, ref_distance_vecs, ref_distances = compute_ref_neighbors( + pos, batch, loop, include_transpose, cutoff, box + ) max_num_pairs = ref_neighbors.shape[1] - nl = DistanceCellList(cutoff_lower=0.0, loop=loop, cutoff_upper=cutoff, 
max_num_pairs=max_num_pairs, strategy=strategy, box=box, return_vecs=True, include_transpose=include_transpose, check_errors=False, resize_to_fit=False) + nl = DistanceCellList( + cutoff_lower=0.0, + loop=loop, + cutoff_upper=cutoff, + max_num_pairs=max_num_pairs, + strategy=strategy, + box=box, + return_vecs=True, + include_transpose=include_transpose, + check_errors=False, + resize_to_fit=False, + ) batch.to(device) - graph = torch.cuda.CUDAGraph() s = torch.cuda.Stream() s.wait_stream(torch.cuda.current_stream()) - #Warm up + # Warm up with torch.cuda.stream(s): for _ in range(10): neighbors, distances, distance_vecs = nl(pos, batch) torch.cuda.synchronize() - #Capture + # Capture with torch.cuda.graph(graph): neighbors, distances, distance_vecs = nl(pos, batch) neighbors.fill_(0) @@ -359,7 +517,9 @@ def test_cuda_graph_compatible(device, strategy, n_batches, cutoff, loop, includ neighbors = neighbors.cpu().detach().numpy() distance_vecs = distance_vecs.cpu().detach().numpy() distances = distances.cpu().detach().numpy() - neighbors, distance_vecs, distances = sort_neighbors(neighbors, distance_vecs, distances) + neighbors, distance_vecs, distances = sort_neighbors( + neighbors, distance_vecs, distances + ) assert neighbors.shape == (2, max_num_pairs) assert distances.shape == (max_num_pairs,) assert distance_vecs.shape == (max_num_pairs, 3) From d3578f99031dfa252adab065ab90bd799a13ea95 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Tue, 9 May 2023 20:44:48 +0200 Subject: [PATCH 36/76] Fix more formatting --- torchmdnet/models/utils.py | 66 ++++++++++++++++++++++++-------------- 1 file changed, 42 insertions(+), 24 deletions(-) diff --git a/torchmdnet/models/utils.py b/torchmdnet/models/utils.py index 3d2831d42..4b7cfec96 100644 --- a/torchmdnet/models/utils.py +++ b/torchmdnet/models/utils.py @@ -77,23 +77,33 @@ def forward(self, z, x, edge_index, edge_weight, edge_attr): def message(self, x_j, W): return x_j * W -from torchmdnet.neighbors import 
get_neighbor_pairs_brute, get_neighbor_pairs_cell, get_neighbor_pairs_shared -class DistanceCellList(torch.nn.Module): - _backends = { "brute": get_neighbor_pairs_brute, "cell": get_neighbor_pairs_cell, "shared": get_neighbor_pairs_shared } +from torchmdnet.neighbors import ( + get_neighbor_pairs_brute, + get_neighbor_pairs_cell, + get_neighbor_pairs_shared, +) + + +class DistanceCellList(torch.nn.Module): + _backends = { + "brute": get_neighbor_pairs_brute, + "cell": get_neighbor_pairs_cell, + "shared": get_neighbor_pairs_shared, + } def __init__( - self, - cutoff_lower=0.0, - cutoff_upper=5.0, - max_num_pairs=32, - return_vecs=False, - loop=False, - strategy="brute", - include_transpose=True, - resize_to_fit=True, - check_errors=False, - box=None + self, + cutoff_lower=0.0, + cutoff_upper=5.0, + max_num_pairs=32, + return_vecs=False, + loop=False, + strategy="brute", + include_transpose=True, + resize_to_fit=True, + check_errors=False, + box=None, ): super(DistanceCellList, self).__init__() """ Compute the neighbor list for a given cutoff. @@ -138,19 +148,21 @@ def __init__( self.use_periodic = True if self.box is None: self.use_periodic = False - self.box = torch.empty((0,0)) + self.box = torch.empty((0, 0)) if self.strategy == "cell": - #Default the box to 3 times the cutoff, really inefficient for the cell list + # Default the box to 3 times the cutoff, really inefficient for the cell list lbox = cutoff_upper * 3.0 self.box = torch.tensor([[lbox, 0, 0], [0, lbox, 0], [0, 0, lbox]]) - self.box = self.box.cpu() #All strategies expect the box to be in CPU memory + self.box = self.box.cpu() # All strategies expect the box to be in CPU memory self.kernel = self._backends[self.strategy] if self.kernel is None: raise ValueError("Unknown strategy: {}".format(self.strategy)) self.check_errors = check_errors - def forward(self, pos: Tensor, batch: Optional[Tensor] = None) -> Tuple[Tensor, Tensor, Optional[Tensor]]: - """ Compute the neighbor list for a given cutoff. 
+ def forward( + self, pos: Tensor, batch: Optional[Tensor] = None + ) -> Tuple[Tensor, Tensor, Optional[Tensor]]: + """Compute the neighbor list for a given cutoff. Parameters ---------- pos : torch.Tensor @@ -176,7 +188,7 @@ def forward(self, pos: Tensor, batch: Optional[Tensor] = None) -> Tuple[Tensor, self.box = self.box.to(pos.dtype) max_pairs = self.max_num_pairs if self.max_num_pairs < 0: - max_pairs = -self.max_num_pairs*pos.shape[0] + max_pairs = -self.max_num_pairs * pos.shape[0] if batch is None: batch = torch.zeros(pos.shape[0], dtype=torch.long, device=pos.device) neighbors, distance_vecs, distances, num_pairs = self.kernel( @@ -188,22 +200,27 @@ def forward(self, pos: Tensor, batch: Optional[Tensor] = None) -> Tuple[Tensor, max_num_pairs=max_pairs, include_transpose=self.include_transpose, box_vectors=self.box, - use_periodic=self.use_periodic + use_periodic=self.use_periodic, ) if self.check_errors: if num_pairs[0] > self.max_num_pairs: - raise RuntimeError("Found num_pairs({}) > max_num_pairs({})".format(num_pairs[0], self.max_num_pairs)) - #Remove (-1,-1) pairs + raise RuntimeError( + "Found num_pairs({}) > max_num_pairs({})".format( + num_pairs[0], self.max_num_pairs + ) + ) + # Remove (-1,-1) pairs if self.resize_to_fit: mask = neighbors[0] != -1 neighbors = neighbors[:, mask] distances = distances[mask] - distance_vecs = distance_vecs[mask,:] + distance_vecs = distance_vecs[mask, :] if self.return_vecs: return neighbors, distances, distance_vecs else: return neighbors, distances, None + class GaussianSmearing(nn.Module): def __init__(self, cutoff_lower=0.0, cutoff_upper=5.0, num_rbf=50, trainable=True): super(GaussianSmearing, self).__init__() @@ -379,6 +396,7 @@ def forward(self, pos, batch): # Union typing works with TorchScript (https://github.com/pytorch/pytorch/pull/53180) return edge_index, edge_weight, None + class GatedEquivariantBlock(nn.Module): """Gated Equivariant Block as defined in Schütt et al. 
(2021): Equivariant message passing for the prediction of tensorial properties and molecular spectra From f15d6cc07e0db2e6b085a652ef57e7b530a46f62 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Thu, 11 May 2023 09:51:34 +0200 Subject: [PATCH 37/76] Fix max_num_pairs incorrect error checking with negative max_num_pairs --- torchmdnet/models/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchmdnet/models/utils.py b/torchmdnet/models/utils.py index 4b7cfec96..5824a6863 100644 --- a/torchmdnet/models/utils.py +++ b/torchmdnet/models/utils.py @@ -96,7 +96,7 @@ def __init__( self, cutoff_lower=0.0, cutoff_upper=5.0, - max_num_pairs=32, + max_num_pairs=-32, return_vecs=False, loop=False, strategy="brute", @@ -203,10 +203,10 @@ def forward( use_periodic=self.use_periodic, ) if self.check_errors: - if num_pairs[0] > self.max_num_pairs: + if num_pairs[0] > max_pairs: raise RuntimeError( "Found num_pairs({}) > max_num_pairs({})".format( - num_pairs[0], self.max_num_pairs + num_pairs[0], max_pairs ) ) # Remove (-1,-1) pairs From 47d256d51b51f52d754bdd876c93a45a0224d390 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Thu, 11 May 2023 09:52:04 +0200 Subject: [PATCH 38/76] Cast neighbor pairs to long (expected by torch geometric) --- torchmdnet/models/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchmdnet/models/utils.py b/torchmdnet/models/utils.py index 5824a6863..3d3fa3f7e 100644 --- a/torchmdnet/models/utils.py +++ b/torchmdnet/models/utils.py @@ -215,6 +215,7 @@ def forward( neighbors = neighbors[:, mask] distances = distances[mask] distance_vecs = distance_vecs[mask, :] + neighbors = neighbors.to(torch.long) if self.return_vecs: return neighbors, distances, distance_vecs else: From eebaebd092dd0d494ab8ad69deb2ce19d820d9c7 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Thu, 11 May 2023 15:27:49 +0200 Subject: [PATCH 39/76] Add double precision compatibility test with Distance --- tests/test_neighbors.py | 5 +++-- 1 file 
changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/test_neighbors.py b/tests/test_neighbors.py index 64a740fc5..52dd7460a 100644 --- a/tests/test_neighbors.py +++ b/tests/test_neighbors.py @@ -139,7 +139,8 @@ def test_neighbors( @pytest.mark.parametrize("n_batches", [1, 2, 3, 4]) @pytest.mark.parametrize("cutoff", [0.1, 1.0, 1000.0]) @pytest.mark.parametrize("loop", [True, False]) -def test_compatible_with_distance(device, strategy, n_batches, cutoff, loop): +@pytest.mark.parametrize("dtype", [torch.float32, torch.float64]) +def test_compatible_with_distance(device, strategy, n_batches, cutoff, loop, dtype): if device == "cuda" and not torch.cuda.is_available(): pytest.skip("CUDA not available") if device == "cpu" and strategy != "brute": @@ -152,7 +153,7 @@ def test_compatible_with_distance(device, strategy, n_batches, cutoff, loop): ).to(device) cumsum = np.cumsum(np.concatenate([[0], n_atoms_per_batch])) lbox = 10.0 - pos = torch.rand(cumsum[-1], 3, device=device) * lbox + pos = torch.rand(cumsum[-1], 3, device=device, dtype=dtype) * lbox # Ensure there is at least one pair pos[0, :] = torch.zeros(3) pos[1, :] = torch.zeros(3) From dd73537cf0f50f57588d1a678859e10c5e559ddf Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Thu, 11 May 2023 15:38:52 +0200 Subject: [PATCH 40/76] Check errors by default --- torchmdnet/models/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmdnet/models/utils.py b/torchmdnet/models/utils.py index cd0aab29e..a3fd8ec31 100644 --- a/torchmdnet/models/utils.py +++ b/torchmdnet/models/utils.py @@ -102,7 +102,7 @@ def __init__( strategy="brute", include_transpose=True, resize_to_fit=True, - check_errors=False, + check_errors=True, box=None, ): super(DistanceCellList, self).__init__() From ece61e5f13ace722202f7fea4e3367a5f9174abc Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Thu, 11 May 2023 15:39:09 +0200 Subject: [PATCH 41/76] Move extension compilation to DistanceCellList constructor, so a 
compilation is not triggered by a script not using it. --- torchmdnet/models/utils.py | 24 +++++++++++------------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/torchmdnet/models/utils.py b/torchmdnet/models/utils.py index a3fd8ec31..7dadc963e 100644 --- a/torchmdnet/models/utils.py +++ b/torchmdnet/models/utils.py @@ -77,20 +77,7 @@ def forward(self, z, x, edge_index, edge_weight, edge_attr): def message(self, x_j, W): return x_j * W - -from torchmdnet.neighbors import ( - get_neighbor_pairs_brute, - get_neighbor_pairs_cell, - get_neighbor_pairs_shared, -) - - class DistanceCellList(torch.nn.Module): - _backends = { - "brute": get_neighbor_pairs_brute, - "cell": get_neighbor_pairs_cell, - "shared": get_neighbor_pairs_shared, - } def __init__( self, @@ -154,6 +141,17 @@ def __init__( lbox = cutoff_upper * 3.0 self.box = torch.tensor([[lbox, 0, 0], [0, lbox, 0], [0, 0, lbox]]) self.box = self.box.cpu() # All strategies expect the box to be in CPU memory + from torchmdnet.neighbors import ( + get_neighbor_pairs_brute, + get_neighbor_pairs_cell, + get_neighbor_pairs_shared, + ) + self._backends = { + "brute": get_neighbor_pairs_brute, + "cell": get_neighbor_pairs_cell, + "shared": get_neighbor_pairs_shared, + } + self.kernel = self._backends[self.strategy] if self.kernel is None: raise ValueError("Unknown strategy: {}".format(self.strategy)) From cf152331f9d1c94f75f2071e3a1906859c1470a2 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Thu, 11 May 2023 18:41:42 +0200 Subject: [PATCH 42/76] Compile backends only if the DistanceCellList module is initialized --- torchmdnet/models/utils.py | 13 ++----------- torchmdnet/neighbors/__init__.py | 23 ++++++++++++++++------- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/torchmdnet/models/utils.py b/torchmdnet/models/utils.py index 7dadc963e..87b79641a 100644 --- a/torchmdnet/models/utils.py +++ b/torchmdnet/models/utils.py @@ -77,6 +77,7 @@ def forward(self, z, x, edge_index, edge_weight, 
edge_attr): def message(self, x_j, W): return x_j * W +import torchmdnet.neighbors as neighbors class DistanceCellList(torch.nn.Module): def __init__( @@ -141,17 +142,7 @@ def __init__( lbox = cutoff_upper * 3.0 self.box = torch.tensor([[lbox, 0, 0], [0, lbox, 0], [0, 0, lbox]]) self.box = self.box.cpu() # All strategies expect the box to be in CPU memory - from torchmdnet.neighbors import ( - get_neighbor_pairs_brute, - get_neighbor_pairs_cell, - get_neighbor_pairs_shared, - ) - self._backends = { - "brute": get_neighbor_pairs_brute, - "cell": get_neighbor_pairs_cell, - "shared": get_neighbor_pairs_shared, - } - + self._backends = neighbors.get_backends() self.kernel = self._backends[self.strategy] if self.kernel is None: raise ValueError("Unknown strategy: {}".format(self.strategy)) diff --git a/torchmdnet/neighbors/__init__.py b/torchmdnet/neighbors/__init__.py index b93632471..28ab51ab7 100644 --- a/torchmdnet/neighbors/__init__.py +++ b/torchmdnet/neighbors/__init__.py @@ -2,11 +2,20 @@ import torch as pt from torch.utils import cpp_extension -src_dir = os.path.dirname(__file__) -sources = ['neighbors.cpp', 'neighbors_cpu.cpp'] + (['neighbors_cuda.cu', 'neighbors_cuda_cell.cu', 'backwards.cu'] if pt.cuda.is_available() else []) -sources = [os.path.join(src_dir, name) for name in sources] -cpp_extension.load(name='neighbors', sources=sources, is_python_module=False) -get_neighbor_pairs_brute = pt.ops.neighbors.get_neighbor_pairs_brute -get_neighbor_pairs_shared = pt.ops.neighbors.get_neighbor_pairs_shared -get_neighbor_pairs_cell = pt.ops.neighbors.get_neighbor_pairs_cell +def compile_extension(): + src_dir = os.path.dirname(__file__) + sources = ['neighbors.cpp', 'neighbors_cpu.cpp'] + (['neighbors_cuda.cu', 'neighbors_cuda_cell.cu', 'backwards.cu'] if pt.cuda.is_available() else []) + sources = [os.path.join(src_dir, name) for name in sources] + cpp_extension.load(name='neighbors', sources=sources, is_python_module=False) + +def get_backends(): + 
compile_extension() + get_neighbor_pairs_brute = pt.ops.neighbors.get_neighbor_pairs_brute + get_neighbor_pairs_shared = pt.ops.neighbors.get_neighbor_pairs_shared + get_neighbor_pairs_cell = pt.ops.neighbors.get_neighbor_pairs_cell + return { + "brute": get_neighbor_pairs_brute, + "cell": get_neighbor_pairs_cell, + "shared": get_neighbor_pairs_shared, + } From 9f459bfd2841410e78efa120a5fea59d66f16c37 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Fri, 12 May 2023 09:08:03 +0200 Subject: [PATCH 43/76] Force usage of shared strategy for N>32768 --- torchmdnet/neighbors/neighbors_cuda.cu | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/torchmdnet/neighbors/neighbors_cuda.cu b/torchmdnet/neighbors/neighbors_cuda.cu index 5fe099d4d..95189d5aa 100644 --- a/torchmdnet/neighbors/neighbors_cuda.cu +++ b/torchmdnet/neighbors/neighbors_cuda.cu @@ -167,6 +167,10 @@ public: } TORCH_CHECK(box_vectors.device() == torch::kCPU, "Expected \"box_vectors\" to be on CPU"); const int num_atoms = positions.size(0); + if(num_atoms > 32768){ + //The brute force method runs into integer overflow for num_atoms > 32768 + strat = strategy::shared; + } const int num_pairs = max_num_pairs_; const TensorOptions options = positions.options(); const auto stream = getCurrentCUDAStream(positions.get_device()); From 5601b47bcc20a6d71da74e3d4940bf58e8f084f8 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Fri, 12 May 2023 09:08:28 +0200 Subject: [PATCH 44/76] Update benchmark --- benchmarks/neighbors.py | 117 ++++++++++++++++++++++++++++------------ 1 file changed, 82 insertions(+), 35 deletions(-) diff --git a/benchmarks/neighbors.py b/benchmarks/neighbors.py index 361ab003f..610979250 100644 --- a/benchmarks/neighbors.py +++ b/benchmarks/neighbors.py @@ -43,7 +43,6 @@ def benchmark_neighbors( i = np.random.randint(0, n_batches) n_atoms_per_batch[i] += 1 difference -= 1 - else: while difference < 0: i = np.random.randint(0, n_batches) @@ -66,6 +65,7 @@ def benchmark_neighbors( box=box, 
loop=False, include_transpose=True, + check_errors=False, resize_to_fit=False, ) else: @@ -77,19 +77,20 @@ def benchmark_neighbors( max_num_neighbors=max_num_neighbors, ) # Warmup - for i in range(10): - neighbors, distances, distance_vecs = nl(pos, batch) - if device == "cuda": - torch.cuda.synchronize() + s = torch.cuda.Stream() + s.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(s): + for i in range(10): + neighbors, distances, distance_vecs = nl(pos, batch) + torch.cuda.synchronize() nruns = 50 - if device == "cuda": - torch.cuda.synchronize() + torch.cuda.synchronize() start = torch.cuda.Event(enable_timing=True) end = torch.cuda.Event(enable_timing=True) + graph = torch.cuda.CUDAGraph() # record in a cuda graph if strategy != "distance": - graph = torch.cuda.CUDAGraph() with torch.cuda.graph(graph): neighbors, distances, distance_vecs = nl(pos, batch) start.record() @@ -101,8 +102,7 @@ def benchmark_neighbors( for i in range(nruns): neighbors, distances, distance_vecs = nl(pos, batch) end.record() - if device == "cuda": - torch.cuda.synchronize() + torch.cuda.synchronize() # Final time return start.elapsed_time(end) / nruns @@ -117,13 +117,8 @@ def benchmark_neighbors( ) ) results = {} - batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024] + batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512] for strategy in ["shared", "brute", "cell", "distance"]: - # print("Strategy: {}".format(strategy)) - # print("--------") - # print("{:<10} {:<10}".format("Batch size", "Time (ms)")) - # print("{:<10} {:<10}".format("----------", "---------")) - # Loop over different number of batches, random for n_batches in batch_sizes: time = benchmark_neighbors( device="cuda", @@ -135,7 +130,6 @@ def benchmark_neighbors( ) # Store results in a dictionary results[strategy, n_batches] = time - # print("{:<10} {:<10.2f}".format(n_batches, time)) print("Summary") print("-------") print( @@ -163,27 +157,80 @@ def benchmark_neighbors( results["distance", 
n_batches], ) ) + n_particles_list = np.power(2, np.arange(8, 18)) - # Print a second table showing time per atom, show in ns - print("\n") - print("Time per atom") - print( - "{:<10} {:<10} {:<10} {:<10} {:<10}".format( - "Batch size", "Shared(ns)", "Brute(ns)", "Cell(ns)", "Distance(ns)" + for n_batches in [1, 2, 32, 64]: + print( + "Benchmarking neighbor list generation for {} batches with {} neighbors on average".format( + n_batches, mean_num_neighbors + ) ) - ) - print( - "{:<10} {:<10} {:<10} {:<10} {:<10}".format( - "----------", "---------", "---------", "---------", "---------" + results = {} + for strategy in ["shared", "brute", "cell", "distance"]: + for n_particles in n_particles_list: + mean_num_neighbors = min(n_particles, 64) + time = benchmark_neighbors( + device="cuda", + strategy=strategy, + n_batches=n_batches, + total_num_particles=n_particles, + mean_num_neighbors=mean_num_neighbors, + density=density, + ) + # Store results in a dictionary + results[strategy, n_particles] = time + print("Summary") + print("-------") + print( + "{:<10} {:<21} {:<21} {:<18} {:<10}".format( + "N Particles", "Shared(ms)", "Brute(ms)", "Cell(ms)", "Distance(ms)" + ) ) - ) - for n_batches in batch_sizes: print( - "{:<10} {:<10.2f} {:<10.2f} {:<10.2f} {:<10.2f}".format( - n_batches, - results["shared", n_batches] / n_particles * 1e6, - results["brute", n_batches] / n_particles * 1e6, - results["cell", n_batches] / n_particles * 1e6, - results["distance", n_batches] / n_particles * 1e6, + "{:<10} {:<21} {:<21} {:<18} {:<10}".format( + "----------", "---------", "---------", "---------", "---------" ) ) + # Print a column per strategy, show speedup over Distance in parenthesis + for n_particles in n_particles_list: + base = results["distance", n_particles] + brute_speedup = base / results["brute", n_particles] + if n_particles > 32000: + results["brute", n_particles] = 0 + brute_speedup = 0 + print( + "{:<10} {:<4.2f} x{:<14.2f} {:<4.2f} x{:<14.2f} {:<4.2f} x{:<14.2f} 
{:<10.2f}".format( + n_particles, + results["shared", n_particles], + base / results["shared", n_particles], + results["brute", n_particles], + brute_speedup, + results["cell", n_particles], + base / results["cell", n_particles], + results["distance", n_particles], + ) + ) + + # # Print a second table showing time per atom, show in ns + # print("\n") + # print("Time per atom") + # print( + # "{:<10} {:<10} {:<10} {:<10} {:<10}".format( + # "Batch size", "Shared(ns)", "Brute(ns)", "Cell(ns)", "Distance(ns)" + # ) + # ) + # print( + # "{:<10} {:<10} {:<10} {:<10} {:<10}".format( + # "----------", "---------", "---------", "---------", "---------" + # ) + # ) + # for n_batches in batch_sizes: + # print( + # "{:<10} {:<10.2f} {:<10.2f} {:<10.2f} {:<10.2f}".format( + # n_batches, + # results["shared", n_batches] / n_particles * 1e6, + # results["brute", n_batches] / n_particles * 1e6, + # results["cell", n_batches] / n_particles * 1e6, + # results["distance", n_batches] / n_particles * 1e6, + # ) + # ) From 5fb85ddc5274ced22b3e1d23236b7d978d903e4a Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Fri, 12 May 2023 09:12:25 +0200 Subject: [PATCH 45/76] Add algorithm description to doc --- torchmdnet/models/utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/torchmdnet/models/utils.py b/torchmdnet/models/utils.py index 87b79641a..32a532b1e 100644 --- a/torchmdnet/models/utils.py +++ b/torchmdnet/models/utils.py @@ -109,6 +109,9 @@ def __init__( strategy : str Strategy to use for computing the neighbor list. Can be one of ["shared", "brute", "cell"]. + Shared: An O(N^2) algorithm that leverages CUDA shared memory, best for large number of particles. + Brute: A brute force O(N^2) algorithm, best for small number of particles. + Cell: A cell list algorithm, best for large number of particles, low cutoffs and low batch size. box : Optional[torch.Tensor] Size of the box, shape (3,3) or None. If strategy is "cell", the box must be diagonal. 
From 4be5d52194d383dc9eb2a78c5ff68dacdede6d88 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Fri, 12 May 2023 09:14:00 +0200 Subject: [PATCH 46/76] Update formatting --- torchmdnet/neighbors/__init__.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/torchmdnet/neighbors/__init__.py b/torchmdnet/neighbors/__init__.py index 28ab51ab7..b0b60480f 100644 --- a/torchmdnet/neighbors/__init__.py +++ b/torchmdnet/neighbors/__init__.py @@ -5,9 +5,14 @@ def compile_extension(): src_dir = os.path.dirname(__file__) - sources = ['neighbors.cpp', 'neighbors_cpu.cpp'] + (['neighbors_cuda.cu', 'neighbors_cuda_cell.cu', 'backwards.cu'] if pt.cuda.is_available() else []) + sources = ["neighbors.cpp", "neighbors_cpu.cpp"] + ( + ["neighbors_cuda.cu", "neighbors_cuda_cell.cu", "backwards.cu"] + if pt.cuda.is_available() + else [] + ) sources = [os.path.join(src_dir, name) for name in sources] - cpp_extension.load(name='neighbors', sources=sources, is_python_module=False) + cpp_extension.load(name="neighbors", sources=sources, is_python_module=False) + def get_backends(): compile_extension() @@ -15,7 +20,7 @@ def get_backends(): get_neighbor_pairs_shared = pt.ops.neighbors.get_neighbor_pairs_shared get_neighbor_pairs_cell = pt.ops.neighbors.get_neighbor_pairs_cell return { - "brute": get_neighbor_pairs_brute, - "cell": get_neighbor_pairs_cell, - "shared": get_neighbor_pairs_shared, + "brute": get_neighbor_pairs_brute, + "cell": get_neighbor_pairs_cell, + "shared": get_neighbor_pairs_shared, } From 730b0a17c6dc4d4fdc2b1d9b8769c80cfd765210 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Fri, 12 May 2023 09:16:02 +0200 Subject: [PATCH 47/76] Add header comment --- torchmdnet/neighbors/backwards.cu | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/torchmdnet/neighbors/backwards.cu b/torchmdnet/neighbors/backwards.cu index 67ed26195..9ac194db5 100644 --- a/torchmdnet/neighbors/backwards.cu +++ b/torchmdnet/neighbors/backwards.cu 
@@ -1,3 +1,6 @@ +/* Raul P. Pelaez 2023. Backwards pass for the CUDA neighbor list operation. + Computes the gradient of the positions with respect to the distances and deltas. + */ #include "common.cuh" template @@ -12,8 +15,8 @@ backward_kernel(const Accessor neighbors, const Accessor neighbors, const Accessorsaved_data["num_atoms"].toInt(); From 9610a9081069dc683912d72313a781259e1fd9cd Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Mon, 15 May 2023 12:16:50 +0200 Subject: [PATCH 48/76] Remove some unused code. Add some comments. --- torchmdnet/neighbors/backwards.cu | 2 +- torchmdnet/neighbors/common.cuh | 26 ++++++++------------------ 2 files changed, 9 insertions(+), 19 deletions(-) diff --git a/torchmdnet/neighbors/backwards.cu b/torchmdnet/neighbors/backwards.cu index 9ac194db5..407fbc47a 100644 --- a/torchmdnet/neighbors/backwards.cu +++ b/torchmdnet/neighbors/backwards.cu @@ -28,8 +28,8 @@ backward_kernel(const Accessor neighbors, const Accessorsaved_data["num_atoms"].toInt(); diff --git a/torchmdnet/neighbors/common.cuh b/torchmdnet/neighbors/common.cuh index 63d9e69ba..61a9701a0 100644 --- a/torchmdnet/neighbors/common.cuh +++ b/torchmdnet/neighbors/common.cuh @@ -1,3 +1,5 @@ +/* Raul P. Pelaez 2023. Common utilities for the CUDA neighbor operation. 
+ */ #ifndef NEIGHBORS_COMMON_CUH #define NEIGHBORS_COMMON_CUH #include @@ -34,24 +36,14 @@ template <> __device__ __forceinline__ double sqrt_(double x) { return ::sqrt(x); }; -template struct vec4 { - using type = void; -}; -template <> struct vec4 { - using type = float4; -}; -template <> struct vec4 { - using type = double4; -}; - -template using scalar4 = typename vec4::type; - template struct vec3 { using type = void; }; + template <> struct vec3 { using type = float3; }; + template <> struct vec3 { using type = double3; }; @@ -59,12 +51,6 @@ template <> struct vec3 { template using scalar3 = typename vec3::type; static void checkInput(const Tensor& positions, const Tensor& batch) { - // Batch contains the molecule index for each atom in positions - // Neighbors are only calculated within the same molecule - // Batch is a 1D tensor of size (N_atoms) - // Batch is assumed to be sorted - // Batch is assumed to be contiguous - // Batch is assumed to be of type torch::kLong TORCH_CHECK(positions.dim() == 2, "Expected \"positions\" to have two dimensions"); TORCH_CHECK(positions.size(0) > 0, "Expected the 1nd dimension size of \"positions\" to be more than 0"); @@ -154,5 +140,9 @@ __device__ auto compute_distance(scalar3 pos_i, scalar3 pos_ } // namespace triclinic +/* + * Backward pass for the CUDA neighbor list operation. + * Computes the gradient of the positions with respect to the distances and deltas. 
+ */ tensor_list common_backward(AutogradContext* ctx, tensor_list grad_inputs); #endif From 8866416810c087125df6879713b3eb98b85daf4a Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Wed, 17 May 2023 11:19:21 +0200 Subject: [PATCH 49/76] Clarify handling of cutoffs in the documentation --- torchmdnet/models/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchmdnet/models/utils.py b/torchmdnet/models/utils.py index 32a532b1e..6b178a4ea 100644 --- a/torchmdnet/models/utils.py +++ b/torchmdnet/models/utils.py @@ -97,6 +97,7 @@ def __init__( """ Compute the neighbor list for a given cutoff. This operation can be placed inside a CUDA graph in some cases. In particular, resize_to_fit and check_errors must be False. + Note that this module returns neighbors such that distance(i,j) >= cutoff_lower and distance(i,j) < cutoff_upper. Parameters ---------- cutoff_lower : float From b0454baa04be5103bd55f2ed4206b37adcfe71c7 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Wed, 17 May 2023 15:12:16 +0200 Subject: [PATCH 50/76] Several optimizations to the cell list. Fix some formatting issues. 
--- torchmdnet/neighbors/backwards.cu | 15 +- torchmdnet/neighbors/common.cuh | 2 +- torchmdnet/neighbors/neighbors_cuda.cu | 2 +- torchmdnet/neighbors/neighbors_cuda_cell.cu | 340 +++++++++++++------- 4 files changed, 227 insertions(+), 132 deletions(-) diff --git a/torchmdnet/neighbors/backwards.cu b/torchmdnet/neighbors/backwards.cu index 407fbc47a..4875d7e84 100644 --- a/torchmdnet/neighbors/backwards.cu +++ b/torchmdnet/neighbors/backwards.cu @@ -10,8 +10,9 @@ backward_kernel(const Accessor neighbors, const Accessor grad_distances, Accessor grad_positions) { const int32_t i_pair = blockIdx.x * blockDim.x + threadIdx.x; const int32_t num_pairs = neighbors.size(1); - if (i_pair >= num_pairs) + if (i_pair >= num_pairs){ return; + } const int32_t i_dir = blockIdx.y; const int32_t i_atom = neighbors[i_dir][i_pair]; const int32_t i_comp = blockIdx.z; @@ -29,9 +30,9 @@ backward_kernel(const Accessor neighbors, const Accessorsaved_data["num_atoms"].toInt(); const int num_pairs = grad_distances.size(0); const int num_threads = 128; @@ -40,9 +41,9 @@ tensor_list common_backward(AutogradContext* ctx, tensor_list grad_inputs) { const auto stream = getCurrentCUDAStream(grad_distances.get_device()); const tensor_list data = ctx->get_saved_variables(); - const Tensor neighbors = data[0]; - const Tensor deltas = data[1]; - const Tensor distances = data[2]; + const Tensor& neighbors = data[0]; + const Tensor& deltas = data[1]; + const Tensor& distances = data[2]; const Tensor grad_positions = zeros({num_atoms, 3}, grad_distances.options()); AT_DISPATCH_FLOATING_TYPES(grad_distances.scalar_type(), "getNeighborPairs::backward", [&]() { diff --git a/torchmdnet/neighbors/common.cuh b/torchmdnet/neighbors/common.cuh index 61a9701a0..28a163620 100644 --- a/torchmdnet/neighbors/common.cuh +++ b/torchmdnet/neighbors/common.cuh @@ -144,5 +144,5 @@ __device__ auto compute_distance(scalar3 pos_i, scalar3 pos_ * Backward pass for the CUDA neighbor list operation. 
* Computes the gradient of the positions with respect to the distances and deltas. */ -tensor_list common_backward(AutogradContext* ctx, tensor_list grad_inputs); +tensor_list common_backward(AutogradContext* ctx, const tensor_list &grad_inputs); #endif diff --git a/torchmdnet/neighbors/neighbors_cuda.cu b/torchmdnet/neighbors/neighbors_cuda.cu index 95189d5aa..a85ff2d5a 100644 --- a/torchmdnet/neighbors/neighbors_cuda.cu +++ b/torchmdnet/neighbors/neighbors_cuda.cu @@ -237,7 +237,7 @@ public: return {neighbors, deltas, distances, i_curr_pair}; } - static tensor_list backward(AutogradContext* ctx, tensor_list grad_inputs) { + static tensor_list backward(AutogradContext* ctx, const tensor_list &grad_inputs) { return common_backward(ctx, grad_inputs); } }; diff --git a/torchmdnet/neighbors/neighbors_cuda_cell.cu b/torchmdnet/neighbors/neighbors_cuda_cell.cu index 959576d48..deba7160b 100644 --- a/torchmdnet/neighbors/neighbors_cuda_cell.cu +++ b/torchmdnet/neighbors/neighbors_cuda_cell.cu @@ -6,6 +6,17 @@ #include #include +/* + * @brief Get the position of the i'th particle + * @param positions The positions tensor + * @param i The index of the particle + * @return The position of the i'th particle + */ +template +__device__ scalar3 fetchPosition(const Accessor positions, const int i) { + return {positions[i][0], positions[i][1], positions[i][2]}; +} + /* * @brief Encodes an unsigned integer lower than 1024 as a 32 bit integer by filling every third * bit. 
@@ -136,11 +147,11 @@ __global__ void assignHash(const Accessor positions, uint64_t* hash /* * @brief A buffer that is allocated and deallocated using the CUDA caching allocator from torch */ -template struct cached_buffer { - cached_buffer(size_t size) : size_(size) { +template struct CachedBuffer { + explicit CachedBuffer(size_t size) : size_(size) { ptr_ = static_cast(at::cuda::CUDACachingAllocator::raw_alloc(size * sizeof(T))); } - ~cached_buffer() { + ~CachedBuffer() { at::cuda::CUDACachingAllocator::raw_delete(ptr_); } T* get() { @@ -162,12 +173,13 @@ private: * @param batch The batch index of each atom * @param box_size The box vectors * @param cutoff The cutoff - * @return A tuple of the sorted positions and the original indices of each atom in the sorted list + * @return A tuple of the sorted positions, sorted batch indices and the original indices of each + * atom in the sorted list */ static auto sortPositionsByHash(const Tensor& positions, const Tensor& batch, const Tensor& box_size, const Scalar& cutoff) { const int num_atoms = positions.size(0); - auto hash_keys = cached_buffer(num_atoms); + auto hash_keys = CachedBuffer(num_atoms); Tensor hash_values = empty({num_atoms}, positions.options().dtype(torch::kInt32)); const int threads = 128; const int blocks = (num_atoms + threads - 1) / threads; @@ -185,42 +197,40 @@ static auto sortPositionsByHash(const Tensor& positions, const Tensor& batch, // I have to use cub directly because thrust::sort_by_key is not compatible with graphs // and torch::lexsort does not support uint64_t size_t tmp_storage_bytes = 0; - auto d_keys_out = cached_buffer(num_atoms); - auto d_values_out = cached_buffer(num_atoms); - int32_t* hash_values_ptr = hash_values.data_ptr(); + auto d_keys_out = CachedBuffer(num_atoms); + auto d_values_out = CachedBuffer(num_atoms); + auto* hash_values_ptr = hash_values.data_ptr(); cub::DeviceRadixSort::SortPairs(nullptr, tmp_storage_bytes, hash_keys.get(), d_keys_out.get(), hash_values_ptr, 
d_values_out.get(), num_atoms, 0, 64, stream); - auto tmp_storage = cached_buffer(tmp_storage_bytes); + auto tmp_storage = CachedBuffer(tmp_storage_bytes); cub::DeviceRadixSort::SortPairs(tmp_storage.get(), tmp_storage_bytes, hash_keys.get(), d_keys_out.get(), hash_values_ptr, d_values_out.get(), num_atoms, 0, 64, stream); cudaMemcpyAsync(hash_values_ptr, d_values_out.get(), num_atoms * sizeof(int32_t), cudaMemcpyDeviceToDevice, stream); Tensor sorted_positions = positions.index_select(0, hash_values); - return std::make_tuple(sorted_positions, hash_values); + Tensor sorted_batch = batch.index_select(0, hash_values); + return std::make_tuple(sorted_positions, sorted_batch, hash_values); } template __global__ void fillCellOffsetsD(const Accessor sorted_positions, const Accessor sorted_indices, Accessor cell_start, Accessor cell_end, - const Accessor batch, scalar3 box_size, - scalar_t cutoff) { + scalar3 box_size, scalar_t cutoff) { // Since positions are sorted by cell, for a given atom, if the previous atom is in a different // cell, then the current atom is the first atom in its cell We use this fact to fill the // cell_start and cell_end arrays const int32_t i_atom = blockIdx.x * blockDim.x + threadIdx.x; if (i_atom >= sorted_positions.size(0)) return; - const scalar3 pi = {sorted_positions[i_atom][0], sorted_positions[i_atom][1], - sorted_positions[i_atom][2]}; + const auto pi = fetchPosition(sorted_positions, i_atom); const int3 cell_dim = getCellDimensions(box_size, cutoff); const int icell = getCellIndex(getCell(pi, box_size, cutoff), cell_dim); int im1_cell; if (i_atom > 0) { int im1 = i_atom - 1; - const scalar3 pim1 = {sorted_positions[im1][0], sorted_positions[im1][1], - sorted_positions[im1][2]}; + const auto pim1 = fetchPosition(sorted_positions, im1); im1_cell = getCellIndex(getCell(pim1, box_size, cutoff), cell_dim); } else { im1_cell = 0; @@ -228,8 +238,9 @@ __global__ void fillCellOffsetsD(const Accessor sorted_positions, if (icell != im1_cell || 
i_atom == 0) { int n_cells = cell_start.size(0); cell_start[icell] = i_atom; - if (i_atom > 0) + if (i_atom > 0) { cell_end[im1_cell] = i_atom; + } } if (i_atom == sorted_positions.size(0) - 1) { cell_end[icell] = i_atom + 1; @@ -237,9 +248,7 @@ __global__ void fillCellOffsetsD(const Accessor sorted_positions, } /* - @brief - Fill the cell offsets for each batch, identifying the start and end of each cell in the sorted - positions + @brief Fills the cell_start and cell_end arrays, identifying the first and last atom in each cell @param sorted_positions The positions sorted by cell @param sorted_indices The original indices of the sorted positions @param batch The batch index of each position @@ -248,7 +257,7 @@ __global__ void fillCellOffsetsD(const Accessor sorted_positions, @return A tuple of cell_start and cell_end arrays */ static auto fillCellOffsets(const Tensor& sorted_positions, const Tensor& sorted_indices, - const Tensor& batch, const Tensor& box_size, const Scalar& cutoff) { + const Tensor& box_size, const Scalar& cutoff) { const TensorOptions options = sorted_positions.options(); int3 cell_dim; AT_DISPATCH_FLOATING_TYPES(sorted_positions.scalar_type(), "fillCellOffsets", [&] { @@ -271,8 +280,8 @@ static auto fillCellOffsets(const Tensor& sorted_positions, const Tensor& sorted box_size[2][2].item()}; fillCellOffsetsD<<>>( get_accessor(sorted_positions), get_accessor(sorted_indices), - get_accessor(cell_start), get_accessor(cell_end), - get_accessor(batch), box_size_, cutoff_); + get_accessor(cell_start), get_accessor(cell_end), box_size_, + cutoff_); }); return std::make_tuple(cell_start, cell_end); } @@ -294,83 +303,186 @@ __device__ int getNeighborCellIndex(int3 cell_i, int i, int3 cell_dim) { return icellj; } +template struct Particle { + int index; // Index in the sorted arrays + int original_index; // Index in the original arrays + int batch; + scalar3 position; + scalar_t cutoff_upper2, cutoff_lower2; +}; + +struct CellList { + Tensor cell_start, 
cell_end; + Tensor original_indices; + Tensor sorted_positions, sorted_batch; +}; + +struct PairList { + Tensor i_curr_pair; + Tensor neighbors; + Tensor deltas; + Tensor distances; + const bool loop, include_transpose, use_periodic; + PairList(int max_num_pairs, TensorOptions options, bool loop, bool include_transpose, + bool use_periodic) + : i_curr_pair(zeros({1}, options.dtype(torch::kInt))), + neighbors(full({2, max_num_pairs}, -1, options.dtype(torch::kInt))), + deltas(empty({max_num_pairs, 3}, options)), distances(full({max_num_pairs}, 0, options)), + loop(loop), include_transpose(include_transpose), use_periodic(use_periodic) { + } +}; + +CellList constructCellList(const Tensor& positions, const Tensor& batch, const Tensor& box_size, + const Scalar& cutoff) { + // The algorithm for the cell list construction can be summarized in three separate steps: + // 1. Hash (label) the particles according to the cell (bin) they lie in. + // 2. Sort the particles and hashes using the hashes as the ordering label + // (technically this is known as sorting by key). So that particles with positions + // lying in the same cell become contiguous in memory. + // 3. Identify where each cell starts and ends in the sorted particle positions + // array. 
+ const TensorOptions options = positions.options(); + CellList cl; + // Steps 1 and 2 + std::tie(cl.sorted_positions, cl.sorted_batch, cl.original_indices) = + sortPositionsByHash(positions, batch, box_size, cutoff); + // Step 3 + std::tie(cl.cell_start, cl.cell_end) = + fillCellOffsets(cl.sorted_positions, cl.original_indices, box_size, cutoff); + return cl; +} + +template struct CellListAccessor { + Accessor cell_start, cell_end; + Accessor original_indices; + Accessor sorted_positions; + Accessor sorted_batch; + + CellListAccessor(const CellList& cl) + : cell_start(get_accessor(cl.cell_start)), + cell_end(get_accessor(cl.cell_end)), + original_indices(get_accessor(cl.original_indices)), + sorted_positions(get_accessor(cl.sorted_positions)), + sorted_batch(get_accessor(cl.sorted_batch)) { + } +}; + +template struct PairListAccessor { + Accessor i_curr_pair; + Accessor neighbors; + Accessor deltas; + Accessor distances; + bool loop, include_transpose, use_periodic; + PairListAccessor(const PairList& pl) + : i_curr_pair(get_accessor(pl.i_curr_pair)), + neighbors(get_accessor(pl.neighbors)), + deltas(get_accessor(pl.deltas)), + distances(get_accessor(pl.distances)), loop(pl.loop), + include_transpose(pl.include_transpose), use_periodic(pl.use_periodic) { + } +}; + +/* + * @brief Add a pair of particles to the pair list. If necessary, also add the transpose pair. + * @param list The pair list + * @param i The index of the first particle + * @param j The index of the second particle + * @param distance2 The squared distance between the particles + * @param delta The vector between the particles + */ +template +__device__ void addNeighborPair(PairListAccessor& list, const int i, const int j, + scalar_t distance2, const scalar3 delta) { + const bool requires_transpose = list.include_transpose and (j != i); + const int32_t i_pair = atomicAdd(&list.i_curr_pair[0], requires_transpose ? 
2 : 1); + // We handle too many neighbors outside of the kernel + if (i_pair + requires_transpose < list.neighbors.size(1)) { + const int ni = thrust::max(i, j); + const int nj = thrust::min(i, j); + const scalar_t delta_sign = (ni == i) ? scalar_t(1.0) : scalar_t(-1.0); + const scalar_t distance = sqrt_(distance2); + list.neighbors[0][i_pair] = ni; + list.neighbors[1][i_pair] = nj; + list.deltas[i_pair][0] = delta_sign * delta.x; + list.deltas[i_pair][1] = delta_sign * delta.y; + list.deltas[i_pair][2] = delta_sign * delta.z; + list.distances[i_pair] = distance; + if (requires_transpose) { + list.neighbors[0][i_pair + 1] = nj; + list.neighbors[1][i_pair + 1] = ni; + list.deltas[i_pair + 1][0] = -delta_sign * delta.x; + list.deltas[i_pair + 1][1] = -delta_sign * delta.y; + list.deltas[i_pair + 1][2] = -delta_sign * delta.z; + list.distances[i_pair + 1] = distance; + } + } +} + +/* + * @brief Add to the pair list all neighbors of particle i_atom in cell j_cell + * @param i_atom The Information of the particle for which we are adding neighbors + * @param j_cell The index of the cell in which we are looking for neighbors + * @param cl The cell list + * @param box_size The box size + * @param list The pair list + */ +template +__device__ void addNeighborsForCell(const Particle& i_atom, int j_cell, + const CellListAccessor& cl, + scalar3 box_size, PairListAccessor& list) { + + const auto first_particle = cl.cell_start[j_cell]; + if (first_particle != -1) { // Continue only if there are particles in this cell + const auto last_particle = cl.cell_end[j_cell]; + for (int cur_j = first_particle; cur_j < last_particle; cur_j++) { + const auto j_batch = cl.sorted_batch[cur_j]; + // Particles are sorted by batch after cell, so we can break early here + if (j_batch > i_atom.batch) { + break; + } + if ((j_batch == i_atom.batch) and + ((cur_j < i_atom.index) or (list.loop and cur_j == i_atom.index))) { + const auto position_j = fetchPosition(cl.sorted_positions, cur_j); + const 
auto delta = rect::compute_distance(i_atom.position, position_j, + list.use_periodic, box_size); + const scalar_t distance2 = + delta.x * delta.x + delta.y * delta.y + delta.z * delta.z; + if ((distance2 < i_atom.cutoff_upper2 and distance2 >= i_atom.cutoff_lower2) or + (list.loop and cur_j == i_atom.index)) { + const int orj = cl.original_indices[cur_j]; + addNeighborPair(list, i_atom.original_index, orj, distance2, delta); + } // endif + } // endif + } // endfor + } // endif +} + // Traverse the cell list for each atom and find the neighbors template -__global__ void -forward_kernel(const Accessor sorted_positions, - const Accessor original_index, const Accessor batch, - const Accessor cell_start, const Accessor cell_end, - Accessor neighbors, Accessor deltas, - Accessor distances, Accessor i_curr_pair, int num_atoms, - int num_pairs, bool use_periodic, scalar3 box_size, scalar_t cutoff_lower, - scalar_t cutoff_upper, bool loop, bool include_transpose) { +__global__ void traverseCellList(const CellListAccessor cell_list, + PairListAccessor list, int num_atoms, + scalar3 box_size, scalar_t cutoff_lower, + scalar_t cutoff_upper) { // Each atom traverses the cells around it and finds the neighbors // Atoms for all batches are placed in the same cell list, but other batches are ignored while // traversing - const int32_t i_atom = blockIdx.x * blockDim.x + threadIdx.x; - if (i_atom >= num_atoms) + Particle i_atom; + i_atom.index = blockIdx.x * blockDim.x + threadIdx.x; + if (i_atom.index >= num_atoms) { return; - const int ori = original_index[i_atom]; - const auto i_batch = batch[ori]; - const scalar3 pi = {sorted_positions[i_atom][0], sorted_positions[i_atom][1], - sorted_positions[i_atom][2]}; - const int3 cell_i = getCell(pi, box_size, cutoff_upper); + } + i_atom.original_index = cell_list.original_indices[i_atom.index]; + i_atom.batch = cell_list.sorted_batch[i_atom.index]; + i_atom.position = fetchPosition(cell_list.sorted_positions, i_atom.index); + 
i_atom.cutoff_lower2 = cutoff_lower * cutoff_lower; + i_atom.cutoff_upper2 = cutoff_upper * cutoff_upper; + const int3 cell_i = getCell(i_atom.position, box_size, cutoff_upper); const int3 cell_dim = getCellDimensions(box_size, cutoff_upper); // Loop over the 27 cells around the current cell for (int i = 0; i < 27; i++) { - int icellj = getNeighborCellIndex(cell_i, i, cell_dim); - const int firstParticle = cell_start[icellj]; - if (firstParticle != -1) { // Continue only if there are particles in this cell - // Index of the last particle in the cell's list - const int lastParticle = cell_end[icellj]; - const int nincell = lastParticle - firstParticle; - for (int j = 0; j < nincell; j++) { - const int cur_j = j + firstParticle; - const int orj = original_index[cur_j]; - const auto j_batch = batch[orj]; - if (j_batch > - i_batch) // Particles are sorted by batch after cell, so we can break early here - break; - const bool testPair = - (j_batch == i_batch) and ((orj < ori) or (loop and orj == ori)); - if (testPair) { - const scalar3 pj = {sorted_positions[cur_j][0], - sorted_positions[cur_j][1], - sorted_positions[cur_j][2]}; - const auto delta = - rect::compute_distance(pi, pj, use_periodic, box_size); - const scalar_t distance2 = - delta.x * delta.x + delta.y * delta.y + delta.z * delta.z; - const scalar_t cutoff_upper2 = cutoff_upper * cutoff_upper; - const scalar_t cutoff_lower2 = cutoff_lower * cutoff_lower; - if ((distance2 < cutoff_upper2 and distance2 >= cutoff_lower2) or - (loop and orj == ori)) { - const bool requires_transpose = include_transpose and (orj != ori); - const int32_t i_pair = - atomicAdd(&i_curr_pair[0], requires_transpose ? 
2 : 1); - // We handle too many neighbors outside of the kernel - if (i_pair + requires_transpose < neighbors.size(1)) { - const scalar_t distance = sqrt_(distance2); - neighbors[0][i_pair] = ori; - neighbors[1][i_pair] = orj; - deltas[i_pair][0] = delta.x; - deltas[i_pair][1] = delta.y; - deltas[i_pair][2] = delta.z; - distances[i_pair] = distance; - if (requires_transpose) { - neighbors[0][i_pair + 1] = orj; - neighbors[1][i_pair + 1] = ori; - deltas[i_pair + 1][0] = -delta.x; - deltas[i_pair + 1][1] = -delta.y; - deltas[i_pair + 1][2] = -delta.z; - distances[i_pair + 1] = distance; - } - } // endif - } // endif - } // endfor - } // endif - } // endfor - } // endfor + const int neighbor_cell = getNeighborCellIndex(cell_i, i, cell_dim); + addNeighborsForCell(i_atom, neighbor_cell, cell_list, box_size, list); + } } class Autograd : public Function { @@ -379,13 +491,9 @@ public: const Tensor& box_size, bool use_periodic, const Scalar& cutoff_lower, const Scalar& cutoff_upper, const Scalar& max_num_pairs, bool loop, bool include_transpose) { - // The algorithm for the cell list construction can be summarized in three separate steps: - // 1. Hash (label) the particles according to the cell (bin) they lie in. - // 2. Sort the particles and hashes using the hashes as the ordering label - // (technically this is known as sorting by key). So that particles with positions - // lying in the same cell become contiguous in memory. - // 3. Identify where each cell starts and ends in the sorted particle positions - // array. + // This module computes the pair list for a given set of particles, which may be in multiple + // batches. The strategy is to first compute a cell list for all particles, and then + // traverse the cell list for each particle to construct a pair list. 
checkInput(positions, batch); TORCH_CHECK(box_size.dim() == 2, "Expected \"box_size\" to have two dimensions"); TORCH_CHECK(box_size.size(0) == 3 && box_size.size(1) == 3, @@ -397,19 +505,8 @@ public: const auto max_num_pairs_ = max_num_pairs.toInt(); TORCH_CHECK(max_num_pairs_ > 0, "Expected \"max_num_neighbors\" to be positive"); const int num_atoms = positions.size(0); - const TensorOptions options = positions.options(); - // Steps 1 and 2 - Tensor sorted_positions, hash_values; - std::tie(sorted_positions, hash_values) = - sortPositionsByHash(positions, batch, box_size, cutoff_upper); - Tensor cell_start, cell_end; - // Step 3 - std::tie(cell_start, cell_end) = - fillCellOffsets(sorted_positions, hash_values, batch, box_size, cutoff_upper); - const Tensor neighbors = full({2, max_num_pairs_}, -1, options.dtype(kInt32)); - const Tensor deltas = empty({max_num_pairs_, 3}, options); - const Tensor distances = full(max_num_pairs_, 0, options); - const Tensor i_curr_pair = zeros(1, options.dtype(kInt32)); + const auto cell_list = constructCellList(positions, batch, box_size, cutoff_upper); + PairList list(max_num_pairs_, positions.options(), loop, include_transpose, use_periodic); const auto stream = getCurrentCUDAStream(positions.get_device()); { // Traverse the cell list to find the neighbors const CUDAStreamGuard guard(stream); @@ -420,24 +517,21 @@ public: const scalar3 box_size_ = {box_size[0][0].item(), box_size[1][1].item(), box_size[2][2].item()}; - const int threads = 128; + PairListAccessor list_accessor(list); + CellListAccessor cell_list_accessor(cell_list); + const int threads = 256; const int blocks = (num_atoms + threads - 1) / threads; - forward_kernel<<>>( - get_accessor(sorted_positions), - get_accessor(hash_values), get_accessor(batch), - get_accessor(cell_start), get_accessor(cell_end), - get_accessor(neighbors), get_accessor(deltas), - get_accessor(distances), get_accessor(i_curr_pair), - num_atoms, max_num_pairs_, use_periodic, box_size_, 
cutoff_lower_, - cutoff_upper_, loop, include_transpose); + traverseCellList<<>>(cell_list_accessor, list_accessor, + num_atoms, box_size_, + cutoff_lower_, cutoff_upper_); }); } - ctx->save_for_backward({neighbors, deltas, distances}); + ctx->save_for_backward({list.neighbors, list.deltas, list.distances}); ctx->saved_data["num_atoms"] = num_atoms; - return {neighbors, deltas, distances, i_curr_pair}; + return {list.neighbors, list.deltas, list.distances, list.i_curr_pair}; } - static tensor_list backward(AutogradContext* ctx, tensor_list grad_inputs) { + static tensor_list backward(AutogradContext* ctx, const tensor_list& grad_inputs) { return common_backward(ctx, grad_inputs); } }; From 9d5028e3a03bce324263b1f893d2a5289a96781f Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Wed, 17 May 2023 15:12:44 +0200 Subject: [PATCH 51/76] Update tests to omit certain redundant combinations (like the shared strategy in the CPU) --- tests/test_neighbors.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/tests/test_neighbors.py b/tests/test_neighbors.py index 52dd7460a..8b517db1f 100644 --- a/tests/test_neighbors.py +++ b/tests/test_neighbors.py @@ -66,8 +66,7 @@ def compute_ref_neighbors(pos, batch, loop, include_transpose, cutoff, box_vecto return ref_neighbors, ref_distance_vecs, ref_distances -@pytest.mark.parametrize("device", ["cpu", "cuda"]) -@pytest.mark.parametrize("strategy", ["brute", "shared", "cell"]) +@pytest.mark.parametrize(("device", "strategy"), [("cpu", "brute"), ("cuda", "brute"), ("cuda", "shared"), ("cuda", "cell")]) @pytest.mark.parametrize("n_batches", [1, 2, 3, 4, 128]) @pytest.mark.parametrize("cutoff", [0.1, 1.0, 3.0, 4.9]) @pytest.mark.parametrize("loop", [True, False]) @@ -80,7 +79,7 @@ def test_neighbors( if device == "cuda" and not torch.cuda.is_available(): pytest.skip("CUDA not available") if box_type == "triclinic" and strategy == "cell": - pytest.skip("Triclinic only supported for brute force") + 
pytest.skip("Triclinic not supported for cell") if device == "cpu" and strategy != "brute": pytest.skip("Only brute force supported on CPU") torch.manual_seed(4321) @@ -134,8 +133,7 @@ def test_neighbors( assert np.allclose(distance_vecs, ref_distance_vecs) -@pytest.mark.parametrize("device", ["cpu", "cuda"]) -@pytest.mark.parametrize("strategy", ["brute", "cell", "shared"]) +@pytest.mark.parametrize(("device", "strategy"), [("cpu", "brute"), ("cuda", "brute"), ("cuda", "shared"), ("cuda", "cell")]) @pytest.mark.parametrize("n_batches", [1, 2, 3, 4]) @pytest.mark.parametrize("cutoff", [0.1, 1.0, 1000.0]) @pytest.mark.parametrize("loop", [True, False]) @@ -267,8 +265,7 @@ def test_large_size(strategy, n_batches): assert np.allclose(distance_vecs, ref_distance_vecs) -@pytest.mark.parametrize("device", ["cpu", "cuda"]) -@pytest.mark.parametrize("strategy", ["brute", "shared", "cell"]) +@pytest.mark.parametrize(("device", "strategy"), [("cpu", "brute"), ("cuda", "brute"), ("cuda", "shared"), ("cuda", "cell")]) @pytest.mark.parametrize("loop", [True, False]) @pytest.mark.parametrize("include_transpose", [True, False]) @pytest.mark.parametrize("dtype", [torch.float32, torch.float64]) @@ -377,8 +374,7 @@ def test_neighbor_grads( assert np.allclose(ref_pos_grad_sorted, pos_grad_sorted, atol=1e-8, rtol=1e-5) -@pytest.mark.parametrize("device", ["cpu", "cuda"]) -@pytest.mark.parametrize("strategy", ["brute", "shared", "cell"]) +@pytest.mark.parametrize(("device", "strategy"), [("cpu", "brute"), ("cuda", "brute"), ("cuda", "shared"), ("cuda", "cell")]) @pytest.mark.parametrize("n_batches", [1, 128]) @pytest.mark.parametrize("cutoff", [1.0]) @pytest.mark.parametrize("loop", [True, False]) From 990f23f1a70cf13f921cf18fbdf1a8732841054b Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Thu, 18 May 2023 10:51:49 +0200 Subject: [PATCH 52/76] Change import location --- torchmdnet/models/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git 
a/torchmdnet/models/utils.py b/torchmdnet/models/utils.py index 6b178a4ea..b608dd8e0 100644 --- a/torchmdnet/models/utils.py +++ b/torchmdnet/models/utils.py @@ -6,6 +6,7 @@ import torch.nn.functional as F from torch_geometric.nn import MessagePassing from torch_cluster import radius_graph +import torchmdnet.neighbors as neighbors import warnings @@ -77,7 +78,7 @@ def forward(self, z, x, edge_index, edge_weight, edge_attr): def message(self, x_j, W): return x_j * W -import torchmdnet.neighbors as neighbors + class DistanceCellList(torch.nn.Module): def __init__( From 2f81243284bcd4a6570860822aad0435a3f18441 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Thu, 18 May 2023 10:53:39 +0200 Subject: [PATCH 53/76] Change name to OptimizedDistance --- benchmarks/neighbors.py | 4 ++-- tests/test_neighbors.py | 14 +++++++------- torchmdnet/models/utils.py | 4 ++-- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/benchmarks/neighbors.py b/benchmarks/neighbors.py index 610979250..43df0cd41 100644 --- a/benchmarks/neighbors.py +++ b/benchmarks/neighbors.py @@ -1,7 +1,7 @@ import os import torch import numpy as np -from torchmdnet.models.utils import Distance, DistanceCellList +from torchmdnet.models.utils import Distance, OptimizedDistance def benchmark_neighbors( @@ -58,7 +58,7 @@ def benchmark_neighbors( if strategy != "distance": max_num_pairs = (expected_num_neighbors * n_atoms_per_batch.sum()).item() * 2 box = torch.eye(3, device=device) * lbox - nl = DistanceCellList( + nl = OptimizedDistance( cutoff_upper=cutoff, max_num_pairs=max_num_pairs, strategy=strategy, diff --git a/tests/test_neighbors.py b/tests/test_neighbors.py index 8b517db1f..06707056d 100644 --- a/tests/test_neighbors.py +++ b/tests/test_neighbors.py @@ -3,7 +3,7 @@ import torch import torch.jit import numpy as np -from torchmdnet.models.utils import Distance, DistanceCellList +from torchmdnet.models.utils import Distance, OptimizedDistance def sort_neighbors(neighbors, deltas, distances): 
@@ -106,7 +106,7 @@ def test_neighbors( pos, batch, loop, include_transpose, cutoff, box ) max_num_pairs = ref_neighbors.shape[1] - nl = DistanceCellList( + nl = OptimizedDistance( cutoff_lower=0.0, loop=loop, cutoff_upper=cutoff, @@ -178,7 +178,7 @@ def test_compatible_with_distance(device, strategy, n_batches, cutoff, loop, dty max_num_pairs = ref_neighbors.shape[1] box = None - nl = DistanceCellList( + nl = OptimizedDistance( cutoff_lower=0.0, loop=loop, cutoff_upper=cutoff, @@ -242,7 +242,7 @@ def test_large_size(strategy, n_batches): # Must check without PBC since Distance does not support it box = None - nl = DistanceCellList( + nl = OptimizedDistance( cutoff_lower=0.0, loop=loop, cutoff_upper=cutoff, @@ -330,7 +330,7 @@ def test_neighbor_grads( ref_distances = torch.linalg.norm(ref_deltas, dim=-1) max_num_pairs = max(ref_neighbors.shape[1], 1) positions.requires_grad_(True) - nl = DistanceCellList( + nl = OptimizedDistance( cutoff_upper=cutoff, max_num_pairs=max_num_pairs, strategy=strategy, @@ -414,7 +414,7 @@ def test_jit_script_compatible( pos, batch, loop, include_transpose, cutoff, box ) max_num_pairs = ref_neighbors.shape[1] - nl = DistanceCellList( + nl = OptimizedDistance( cutoff_lower=0.0, loop=loop, cutoff_upper=cutoff, @@ -482,7 +482,7 @@ def test_cuda_graph_compatible( pos, batch, loop, include_transpose, cutoff, box ) max_num_pairs = ref_neighbors.shape[1] - nl = DistanceCellList( + nl = OptimizedDistance( cutoff_lower=0.0, loop=loop, cutoff_upper=cutoff, diff --git a/torchmdnet/models/utils.py b/torchmdnet/models/utils.py index b608dd8e0..47d9bdfdb 100644 --- a/torchmdnet/models/utils.py +++ b/torchmdnet/models/utils.py @@ -79,7 +79,7 @@ def message(self, x_j, W): return x_j * W -class DistanceCellList(torch.nn.Module): +class OptimizedDistance(torch.nn.Module): def __init__( self, @@ -94,7 +94,7 @@ def __init__( check_errors=True, box=None, ): - super(DistanceCellList, self).__init__() + super(OptimizedDistance, self).__init__() """ Compute 
the neighbor list for a given cutoff. This operation can be placed inside a CUDA graph in some cases. In particular, resize_to_fit and check_errors must be False. From 866d1644ea7c1a24617f4864a7ec8c02655e45f4 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Thu, 18 May 2023 10:56:44 +0200 Subject: [PATCH 54/76] Clarify meaning of max_num_pairs argument --- torchmdnet/models/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchmdnet/models/utils.py b/torchmdnet/models/utils.py index 47d9bdfdb..239461e42 100644 --- a/torchmdnet/models/utils.py +++ b/torchmdnet/models/utils.py @@ -106,7 +106,8 @@ def __init__( cutoff_upper : float Upper cutoff for the neighbor list. max_num_pairs : int - Maximum number of pairs to store. + Maximum number of pairs to store, if the number of pairs found is less than this, the list is padded with (-1,-1) pairs up to max_num_pairs unless resize_to_fit is True, in which case the list is resized to the actual number of pairs found. + If the number of pairs found is larger than this, the pairs are randomly sampled. When check_errors is True, an exception is raised in this case. If negative, it is interpreted as (minus) the maximum number of neighbors per atom. strategy : str Strategy to use for computing the neighbor list. Can be one of From 7aa4e807d2397a928428345df82077b9f91ea62a Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Thu, 18 May 2023 15:29:51 +0200 Subject: [PATCH 55/76] Move some code around. Separate shared and brute to different files. 
--- torchmdnet/neighbors/__init__.py | 2 +- torchmdnet/neighbors/common.cuh | 68 ++++- torchmdnet/neighbors/neighbors_cuda.cu | 277 ++---------------- torchmdnet/neighbors/neighbors_cuda_brute.cuh | 115 ++++++++ ...s_cuda_cell.cu => neighbors_cuda_cell.cuh} | 93 +----- .../neighbors/neighbors_cuda_shared.cuh | 119 ++++++++ 6 files changed, 344 insertions(+), 330 deletions(-) create mode 100644 torchmdnet/neighbors/neighbors_cuda_brute.cuh rename torchmdnet/neighbors/{neighbors_cuda_cell.cu => neighbors_cuda_cell.cuh} (85%) create mode 100644 torchmdnet/neighbors/neighbors_cuda_shared.cuh diff --git a/torchmdnet/neighbors/__init__.py b/torchmdnet/neighbors/__init__.py index b0b60480f..29fa7af06 100644 --- a/torchmdnet/neighbors/__init__.py +++ b/torchmdnet/neighbors/__init__.py @@ -6,7 +6,7 @@ def compile_extension(): src_dir = os.path.dirname(__file__) sources = ["neighbors.cpp", "neighbors_cpu.cpp"] + ( - ["neighbors_cuda.cu", "neighbors_cuda_cell.cu", "backwards.cu"] + ["neighbors_cuda.cu", "backwards.cu"] if pt.cuda.is_available() else [] ) diff --git a/torchmdnet/neighbors/common.cuh b/torchmdnet/neighbors/common.cuh index 28a163620..ec1a6661b 100644 --- a/torchmdnet/neighbors/common.cuh +++ b/torchmdnet/neighbors/common.cuh @@ -50,6 +50,72 @@ template <> struct vec3 { template using scalar3 = typename vec3::type; +/* + * @brief Get the position of the i'th particle + * @param positions The positions tensor + * @param i The index of the particle + * @return The position of the i'th particle + */ +template +__device__ scalar3 fetchPosition(const Accessor positions, const int i) { + return {positions[i][0], positions[i][1], positions[i][2]}; +} + +struct PairList { + Tensor i_curr_pair; + Tensor neighbors; + Tensor deltas; + Tensor distances; + const bool loop, include_transpose, use_periodic; + PairList(int max_num_pairs, TensorOptions options, bool loop, bool include_transpose, + bool use_periodic) + : i_curr_pair(zeros({1}, options.dtype(torch::kInt))), + 
neighbors(full({2, max_num_pairs}, -1, options.dtype(torch::kInt))), + deltas(empty({max_num_pairs, 3}, options)), distances(full({max_num_pairs}, 0, options)), + loop(loop), include_transpose(include_transpose), use_periodic(use_periodic) { + } +}; + +template struct PairListAccessor { + Accessor i_curr_pair; + Accessor neighbors; + Accessor deltas; + Accessor distances; + bool loop, include_transpose, use_periodic; + explicit PairListAccessor(const PairList& pl) + : i_curr_pair(get_accessor(pl.i_curr_pair)), + neighbors(get_accessor(pl.neighbors)), + deltas(get_accessor(pl.deltas)), + distances(get_accessor(pl.distances)), loop(pl.loop), + include_transpose(pl.include_transpose), use_periodic(pl.use_periodic) { + } +}; + +template +__device__ void writeAtomPair(PairListAccessor& list, int i, int j, + scalar3 delta, scalar_t distance, int i_pair) { + list.neighbors[0][i_pair] = i; + list.neighbors[1][i_pair] = j; + list.deltas[i_pair][0] = delta.x; + list.deltas[i_pair][1] = delta.y; + list.deltas[i_pair][2] = delta.z; + list.distances[i_pair] = distance; +} + +template +__device__ void addAtomPairToList(PairListAccessor& list, int i, int j, + scalar3 delta, scalar_t distance, + bool add_transpose) { + const int32_t i_pair = atomicAdd(&list.i_curr_pair[0], add_transpose ? 2 : 1); + // Neighbors after the max number of pairs are ignored, although the pair is counted + if (i_pair + add_transpose < list.neighbors.size(1)) { + writeAtomPair(list, i, j, delta, distance, i_pair); + if (add_transpose) { + writeAtomPair(list, j, i, {-delta.x, -delta.y, -delta.z}, distance, i_pair + 1); + } + } +} + static void checkInput(const Tensor& positions, const Tensor& batch) { TORCH_CHECK(positions.dim() == 2, "Expected \"positions\" to have two dimensions"); TORCH_CHECK(positions.size(0) > 0, @@ -144,5 +210,5 @@ __device__ auto compute_distance(scalar3 pos_i, scalar3 pos_ * Backward pass for the CUDA neighbor list operation. 
* Computes the gradient of the positions with respect to the distances and deltas. */ -tensor_list common_backward(AutogradContext* ctx, const tensor_list &grad_inputs); +tensor_list common_backward(AutogradContext* ctx, const tensor_list& grad_inputs); #endif diff --git a/torchmdnet/neighbors/neighbors_cuda.cu b/torchmdnet/neighbors/neighbors_cuda.cu index a85ff2d5a..deaa90e72 100644 --- a/torchmdnet/neighbors/neighbors_cuda.cu +++ b/torchmdnet/neighbors/neighbors_cuda.cu @@ -1,264 +1,45 @@ -#include "common.cuh" -#include -#include - -__device__ uint32_t get_row(uint32_t index) { - uint32_t row = floor((sqrtf(8 * index + 1) + 1) / 2); - if (row * (row - 1) > 2 * index) - row--; - return row; -} - -template -__global__ void forward_kernel_brute(uint32_t num_all_pairs, const Accessor positions, - const Accessor batch, scalar_t cutoff_lower2, - scalar_t cutoff_upper2, bool loop, bool include_transpose, - Accessor i_curr_pair, - Accessor neighbors, Accessor deltas, - Accessor distances, bool use_periodic, - triclinic::Box box) { - const uint32_t index = blockIdx.x * blockDim.x + threadIdx.x; - if (index >= num_all_pairs) - return; - const uint32_t row = get_row(index); - const uint32_t column = (index - row * (row - 1) / 2); - if (batch[row] == batch[column]) { - const scalar3 pos_i{positions[row][0], positions[row][1], positions[row][2]}; - const scalar3 pos_j{positions[column][0], positions[column][1], - positions[column][2]}; - const auto delta = triclinic::compute_distance(pos_i, pos_j, use_periodic, box); - const scalar_t distance2 = delta.x * delta.x + delta.y * delta.y + delta.z * delta.z; - if (distance2 < cutoff_upper2 && distance2 >= cutoff_lower2) { - const int32_t i_pair = atomicAdd(&i_curr_pair[0], include_transpose ? 
2 : 1); - // We handle too many neighbors outside of the kernel - if (i_pair + include_transpose < neighbors.size(1)) { - const scalar_t r2 = sqrt_(distance2); - neighbors[0][i_pair] = row; - neighbors[1][i_pair] = column; - deltas[i_pair][0] = delta.x; - deltas[i_pair][1] = delta.y; - deltas[i_pair][2] = delta.z; - distances[i_pair] = r2; - if (include_transpose) { - neighbors[0][i_pair + 1] = column; - neighbors[1][i_pair + 1] = row; - deltas[i_pair + 1][0] = -delta.x; - deltas[i_pair + 1][1] = -delta.y; - deltas[i_pair + 1][2] = -delta.z; - distances[i_pair + 1] = r2; - } - } - } - } -} - -template -__global__ void add_self_kernel(const int num_atoms, Accessor positions, - Accessor i_curr_pair, Accessor neighbors, - Accessor deltas, Accessor distances) { - const int32_t i_atom = blockIdx.x * blockDim.x + threadIdx.x; - if (i_atom >= num_atoms) - return; - const int32_t i_pair = atomicAdd(&i_curr_pair[0], 1); - if (i_pair < neighbors.size(1)) { - neighbors[0][i_pair] = i_atom; - neighbors[1][i_pair] = i_atom; - deltas[i_pair][0] = 0; - deltas[i_pair][1] = 0; - deltas[i_pair][2] = 0; - distances[i_pair] = 0; - } -} - -template -__global__ void forward_kernel_shared(uint32_t num_atoms, const Accessor positions, - const Accessor batch, scalar_t cutoff_lower2, - scalar_t cutoff_upper2, bool loop, bool include_transpose, - Accessor i_curr_pair, - Accessor neighbors, Accessor deltas, - Accessor distances, int32_t num_tiles, - bool use_periodic, triclinic::Box box) { - // A thread per atom - const int id = blockIdx.x * blockDim.x + threadIdx.x; - // All threads must pass through __syncthreads, - // but when N is not a multiple of 32 some threads are assigned a particle i>N. 
- // This threads cant return, so they are masked to not do any work - const bool active = id < num_atoms; - __shared__ scalar3 sh_pos[BLOCKSIZE]; - __shared__ int64_t sh_batch[BLOCKSIZE]; - scalar3 pos_i; - int64_t batch_i; - if (active) { - pos_i = {positions[id][0], positions[id][1], positions[id][2]}; - batch_i = batch[id]; - } - // Distribute the N particles in a group of tiles. Storing in each tile blockDim.x values in - // shared memory. This way all threads are accesing the same memory addresses at the same time - for (int tile = 0; tile < num_tiles; tile++) { - // Load this tiles particles values to shared memory - const int i_load = tile * blockDim.x + threadIdx.x; - if (i_load < num_atoms) { // Even if im not active, my thread may load a value each tile to - // shared memory. - sh_pos[threadIdx.x] = {positions[i_load][0], positions[i_load][1], - positions[i_load][2]}; - sh_batch[threadIdx.x] = batch[i_load]; - } - // Wait for all threads to arrive - __syncthreads(); - // Go through all the particles in the current tile -#pragma unroll 8 - for (int counter = 0; counter < blockDim.x; counter++) { - if (!active) - break; // An out of bounds thread must be masked - const int cur_j = tile * blockDim.x + counter; - const bool testPair = cur_j < num_atoms and (cur_j < id or (loop and cur_j == id)); - if (testPair) { - const auto batch_j = sh_batch[counter]; - if (batch_i == batch_j) { - const auto pos_j = sh_pos[counter]; - const auto delta = triclinic::compute_distance(pos_i, pos_j, use_periodic, box); - const scalar_t distance2 = - delta.x * delta.x + delta.y * delta.y + delta.z * delta.z; - if (distance2 < cutoff_upper2 && distance2 >= cutoff_lower2) { - const bool requires_transpose = include_transpose && !(cur_j == id); - const int32_t i_pair = - atomicAdd(&i_curr_pair[0], requires_transpose ? 
2 : 1); - if (i_pair + requires_transpose < neighbors.size(1)) { - const auto distance = sqrt_(distance2); - neighbors[0][i_pair] = id; - neighbors[1][i_pair] = cur_j; - deltas[i_pair][0] = delta.x; - deltas[i_pair][1] = delta.y; - deltas[i_pair][2] = delta.z; - distances[i_pair] = distance; - if (requires_transpose) { - neighbors[0][i_pair + 1] = cur_j; - neighbors[1][i_pair + 1] = id; - deltas[i_pair + 1][0] = -delta.x; - deltas[i_pair + 1][1] = -delta.y; - deltas[i_pair + 1][2] = -delta.z; - distances[i_pair + 1] = distance; - } - } - } - } - } - } - __syncthreads(); - } -} - -enum class strategy { brute, shared }; - -class Autograd : public Function { -public: - static tensor_list forward(AutogradContext* ctx, const Tensor& positions, const Tensor& batch, - const Scalar& cutoff_lower, const Scalar& cutoff_upper, - const Tensor& box_vectors, bool use_periodic, - const Scalar& max_num_pairs, bool loop, bool include_transpose, - strategy strat) { - checkInput(positions, batch); - const auto max_num_pairs_ = max_num_pairs.toLong(); - TORCH_CHECK(max_num_pairs_ > 0, "Expected \"max_num_neighbors\" to be positive"); - if (use_periodic) { - TORCH_CHECK(box_vectors.dim() == 2, "Expected \"box_vectors\" to have two dimensions"); - TORCH_CHECK(box_vectors.size(0) == 3 && box_vectors.size(1) == 3, - "Expected \"box_vectors\" to have shape (3, 3)"); - } - TORCH_CHECK(box_vectors.device() == torch::kCPU, "Expected \"box_vectors\" to be on CPU"); - const int num_atoms = positions.size(0); - if(num_atoms > 32768){ - //The brute force method runs into integer overflow for num_atoms > 32768 - strat = strategy::shared; - } - const int num_pairs = max_num_pairs_; - const TensorOptions options = positions.options(); - const auto stream = getCurrentCUDAStream(positions.get_device()); - const Tensor neighbors = full({2, num_pairs}, -1, options.dtype(kInt32)); - const Tensor deltas = empty({num_pairs, 3}, options); - const Tensor distances = full(num_pairs, 0, options); - const 
Tensor i_curr_pair = zeros(1, options.dtype(kInt32)); - { - const CUDAStreamGuard guard(stream); - const int32_t num_atoms = positions.size(0); - if (strat == strategy::brute) { - const uint64_t num_all_pairs = num_atoms * (num_atoms - 1ul) / 2ul; - const uint64_t num_threads = 128; - const uint64_t num_blocks = - std::max((num_all_pairs + num_threads - 1ul) / num_threads, 1ul); - AT_DISPATCH_FLOATING_TYPES( - positions.scalar_type(), "get_neighbor_pairs_forward", [&]() { - triclinic::Box box(box_vectors); - const scalar_t cutoff_upper_ = cutoff_upper.to(); - const scalar_t cutoff_lower_ = cutoff_lower.to(); - TORCH_CHECK(cutoff_upper_ > 0, "Expected \"cutoff\" to be positive"); - forward_kernel_brute<<>>( - num_all_pairs, get_accessor(positions), - get_accessor(batch), cutoff_lower_ * cutoff_lower_, - cutoff_upper_ * cutoff_upper_, loop, include_transpose, - get_accessor(i_curr_pair), - get_accessor(neighbors), get_accessor(deltas), - get_accessor(distances), use_periodic, box); - if (loop) { - const uint64_t num_threads = 128; - const uint64_t num_blocks = - std::max((num_atoms + num_threads - 1ul) / num_threads, 1ul); - add_self_kernel<<>>( - num_atoms, get_accessor(positions), - get_accessor(i_curr_pair), - get_accessor(neighbors), - get_accessor(deltas), - get_accessor(distances)); - } - }); - } else if (strat == strategy::shared) { - AT_DISPATCH_FLOATING_TYPES( - positions.scalar_type(), "get_neighbor_pairs_shared_forward", [&]() { - const scalar_t cutoff_upper_ = cutoff_upper.to(); - const scalar_t cutoff_lower_ = cutoff_lower.to(); - triclinic::Box box(box_vectors); - TORCH_CHECK(cutoff_upper_ > 0, "Expected \"cutoff\" to be positive"); - constexpr int BLOCKSIZE = 64; - const int num_blocks = std::max((num_atoms + BLOCKSIZE - 1) / BLOCKSIZE, 1); - const int num_threads = BLOCKSIZE; - const int num_tiles = num_blocks; - forward_kernel_shared<<>>( - num_atoms, get_accessor(positions), - get_accessor(batch), cutoff_lower_ * cutoff_lower_, - cutoff_upper_ * 
cutoff_upper_, loop, include_transpose, - get_accessor(i_curr_pair), - get_accessor(neighbors), get_accessor(deltas), - get_accessor(distances), num_tiles, use_periodic, box); - }); - } - } - ctx->save_for_backward({neighbors, deltas, distances}); - ctx->saved_data["num_atoms"] = num_atoms; - return {neighbors, deltas, distances, i_curr_pair}; - } - - static tensor_list backward(AutogradContext* ctx, const tensor_list &grad_inputs) { - return common_backward(ctx, grad_inputs); - } -}; +/* Raul P. Pelaez 2023 + Connection between the neighbor CUDA implementations and the torch extension. + See neighbors.cpp for the definition of the torch extension functions. + */ +#include "neighbors_cuda_brute.cuh" +#include "neighbors_cuda_cell.cuh" +#include "neighbors_cuda_shared.cuh" TORCH_LIBRARY_IMPL(neighbors, AutogradCUDA, m) { m.impl("get_neighbor_pairs_brute", [](const Tensor& positions, const Tensor& batch, const Tensor& box_vectors, bool use_periodic, const Scalar& cutoff_lower, const Scalar& cutoff_upper, const Scalar& max_num_pairs, bool loop, bool include_transpose) { - const tensor_list results = Autograd::apply( - positions, batch, cutoff_lower, cutoff_upper, box_vectors, use_periodic, - max_num_pairs, loop, include_transpose, strategy::brute); + tensor_list results; + if (positions.size(0) >= 32768) { + // Revert to shared if there are too many particles, which brute can't handle + results = AutogradSharedCUDA::apply(positions, batch, cutoff_lower, cutoff_upper, + box_vectors, use_periodic, max_num_pairs, + loop, include_transpose); + } else { + results = AutogradBruteCUDA::apply(positions, batch, cutoff_lower, cutoff_upper, + box_vectors, use_periodic, max_num_pairs, + loop, include_transpose); + } return std::make_tuple(results[0], results[1], results[2], results[3]); }); m.impl("get_neighbor_pairs_shared", [](const Tensor& positions, const Tensor& batch, const Tensor& box_vectors, bool use_periodic, const Scalar& cutoff_lower, const Scalar& cutoff_upper, const 
Scalar& max_num_pairs, bool loop, bool include_transpose) { - const tensor_list results = Autograd::apply( + const tensor_list results = AutogradSharedCUDA::apply( positions, batch, cutoff_lower, cutoff_upper, box_vectors, use_periodic, - max_num_pairs, loop, include_transpose, strategy::shared); + max_num_pairs, loop, include_transpose); + return std::make_tuple(results[0], results[1], results[2], results[3]); + }); + m.impl("get_neighbor_pairs_cell", + [](const Tensor& positions, const Tensor& batch, const Tensor& box_vectors, + bool use_periodic, const Scalar& cutoff_lower, const Scalar& cutoff_upper, + const Scalar& max_num_pairs, bool loop, bool include_transpose) { + const tensor_list results = AutogradCellCUDA::apply( + positions, batch, box_vectors, use_periodic, cutoff_lower, cutoff_upper, + max_num_pairs, loop, include_transpose); return std::make_tuple(results[0], results[1], results[2], results[3]); }); } diff --git a/torchmdnet/neighbors/neighbors_cuda_brute.cuh b/torchmdnet/neighbors/neighbors_cuda_brute.cuh new file mode 100644 index 000000000..3756062aa --- /dev/null +++ b/torchmdnet/neighbors/neighbors_cuda_brute.cuh @@ -0,0 +1,115 @@ +/* Raul P. Pelaez 2023. Brute force neighbor list construction in CUDA. + + A brute force approach that assigns a thread per each possible pair of particles in the system. + Based on an implementation by Raimondas Galvelis. + Works fantastically for small (less than 10K atoms) systems, but cannot handle more than 32K atoms. 
+ */ +#ifndef NEIGHBORS_BRUTE_CUH +#define NEIGHBORS_BRUTE_CUH +#include "common.cuh" +#include +#include +#include + +__device__ uint32_t get_row(uint32_t index) { + uint32_t row = floor((sqrtf(8 * index + 1) + 1) / 2); + if (row * (row - 1) > 2 * index) + row--; + return row; +} + +template +__global__ void forward_kernel_brute(uint32_t num_all_pairs, const Accessor positions, + const Accessor batch, scalar_t cutoff_lower2, + scalar_t cutoff_upper2, PairListAccessor list, + triclinic::Box box) { + const uint32_t index = blockIdx.x * blockDim.x + threadIdx.x; + if (index >= num_all_pairs) + return; + const uint32_t row = get_row(index); + const uint32_t column = (index - row * (row - 1) / 2); + if (batch[row] == batch[column]) { + const auto pos_i = fetchPosition(positions, row); + const auto pos_j = fetchPosition(positions, column); + const auto delta = triclinic::compute_distance(pos_i, pos_j, list.use_periodic, box); + const scalar_t distance2 = delta.x * delta.x + delta.y * delta.y + delta.z * delta.z; + if (distance2 < cutoff_upper2 && distance2 >= cutoff_lower2) { + const scalar_t r2 = sqrt_(distance2); + addAtomPairToList(list, row, column, delta, r2, list.include_transpose); + } + } +} + +template +__global__ void add_self_kernel(const int num_atoms, Accessor positions, + PairListAccessor list) { + const int32_t i_atom = blockIdx.x * blockDim.x + threadIdx.x; + if (i_atom >= num_atoms) + return; + __shared__ int i_pair; + if (threadIdx.x == 0) { // Each block adds blockDim.x pairs to the list. 
+ // Handle the last block, so that only num_atoms are added in total + i_pair = atomicAdd(&list.i_curr_pair[0], + thrust::min(blockDim.x, num_atoms - blockIdx.x * blockDim.x)); + } + __syncthreads(); + scalar3 delta{}; + scalar_t distance = 0; + writeAtomPair(list, i_atom, i_atom, delta, distance, i_pair + threadIdx.x); +} + +class AutogradBruteCUDA : public Function { +public: + static tensor_list forward(AutogradContext* ctx, const Tensor& positions, const Tensor& batch, + const Scalar& cutoff_lower, const Scalar& cutoff_upper, + const Tensor& box_vectors, bool use_periodic, + const Scalar& max_num_pairs, bool loop, bool include_transpose) { + checkInput(positions, batch); + const auto max_num_pairs_ = max_num_pairs.toLong(); + TORCH_CHECK(max_num_pairs_ > 0, "Expected \"max_num_neighbors\" to be positive"); + if (use_periodic) { + TORCH_CHECK(box_vectors.dim() == 2, "Expected \"box_vectors\" to have two dimensions"); + TORCH_CHECK(box_vectors.size(0) == 3 && box_vectors.size(1) == 3, + "Expected \"box_vectors\" to have shape (3, 3)"); + } + TORCH_CHECK(box_vectors.device() == torch::kCPU, "Expected \"box_vectors\" to be on CPU"); + const int num_atoms = positions.size(0); + TORCH_CHECK(num_atoms < 32768, + "The brute strategy fails with \"num_atoms\" larger than 32768"); + const int num_pairs = max_num_pairs_; + const TensorOptions options = positions.options(); + const auto stream = getCurrentCUDAStream(positions.get_device()); + PairList list(num_pairs, positions.options(), loop, include_transpose, use_periodic); + const CUDAStreamGuard guard(stream); + const uint64_t num_all_pairs = num_atoms * (num_atoms - 1UL) / 2UL; + const uint64_t num_threads = 128; + const uint64_t num_blocks = + std::max((num_all_pairs + num_threads - 1UL) / num_threads, 1UL); + AT_DISPATCH_FLOATING_TYPES(positions.scalar_type(), "get_neighbor_pairs_forward", [&]() { + PairListAccessor list_accessor(list); + triclinic::Box box(box_vectors); + const scalar_t cutoff_upper_ = 
cutoff_upper.to(); + const scalar_t cutoff_lower_ = cutoff_lower.to(); + TORCH_CHECK(cutoff_upper_ > 0, "Expected \"cutoff\" to be positive"); + forward_kernel_brute<<>>( + num_all_pairs, get_accessor(positions), + get_accessor(batch), cutoff_lower_ * cutoff_lower_, + cutoff_upper_ * cutoff_upper_, list_accessor, box); + if (loop) { + const uint32_t num_threads = 256; + const uint32_t num_blocks = + std::max((num_atoms + num_threads - 1U) / num_threads, 1U); + add_self_kernel<<>>( + num_atoms, get_accessor(positions), list_accessor); + } + }); + ctx->save_for_backward({list.neighbors, list.deltas, list.distances}); + ctx->saved_data["num_atoms"] = num_atoms; + return {list.neighbors, list.deltas, list.distances, list.i_curr_pair}; + } + + static tensor_list backward(AutogradContext* ctx, const tensor_list& grad_inputs) { + return common_backward(ctx, grad_inputs); + } +}; +#endif diff --git a/torchmdnet/neighbors/neighbors_cuda_cell.cu b/torchmdnet/neighbors/neighbors_cuda_cell.cuh similarity index 85% rename from torchmdnet/neighbors/neighbors_cuda_cell.cu rename to torchmdnet/neighbors/neighbors_cuda_cell.cuh index deba7160b..a6a065418 100644 --- a/torchmdnet/neighbors/neighbors_cuda_cell.cu +++ b/torchmdnet/neighbors/neighbors_cuda_cell.cuh @@ -1,22 +1,13 @@ /* Raul P. Pelaez 2023. Batched cell list neighbor list implementation for CUDA. */ +#ifndef NEIGHBOR_CUDA_CELL_H +#define NEIGHBOR_CUDA_CELL_H #include "common.cuh" #include #include #include -/* - * @brief Get the position of the i'th particle - * @param positions The positions tensor - * @param i The index of the particle - * @return The position of the i'th particle - */ -template -__device__ scalar3 fetchPosition(const Accessor positions, const int i) { - return {positions[i][0], positions[i][1], positions[i][2]}; -} - /* * @brief Encodes an unsigned integer lower than 1024 as a 32 bit integer by filling every third * bit. 
@@ -317,21 +308,6 @@ struct CellList { Tensor sorted_positions, sorted_batch; }; -struct PairList { - Tensor i_curr_pair; - Tensor neighbors; - Tensor deltas; - Tensor distances; - const bool loop, include_transpose, use_periodic; - PairList(int max_num_pairs, TensorOptions options, bool loop, bool include_transpose, - bool use_periodic) - : i_curr_pair(zeros({1}, options.dtype(torch::kInt))), - neighbors(full({2, max_num_pairs}, -1, options.dtype(torch::kInt))), - deltas(empty({max_num_pairs, 3}, options)), distances(full({max_num_pairs}, 0, options)), - loop(loop), include_transpose(include_transpose), use_periodic(use_periodic) { - } -}; - CellList constructCellList(const Tensor& positions, const Tensor& batch, const Tensor& box_size, const Scalar& cutoff) { // The algorithm for the cell list construction can be summarized in three separate steps: @@ -358,7 +334,7 @@ template struct CellListAccessor { Accessor sorted_positions; Accessor sorted_batch; - CellListAccessor(const CellList& cl) + explicit CellListAccessor(const CellList& cl) : cell_start(get_accessor(cl.cell_start)), cell_end(get_accessor(cl.cell_end)), original_indices(get_accessor(cl.original_indices)), @@ -367,21 +343,6 @@ template struct CellListAccessor { } }; -template struct PairListAccessor { - Accessor i_curr_pair; - Accessor neighbors; - Accessor deltas; - Accessor distances; - bool loop, include_transpose, use_periodic; - PairListAccessor(const PairList& pl) - : i_curr_pair(get_accessor(pl.i_curr_pair)), - neighbors(get_accessor(pl.neighbors)), - deltas(get_accessor(pl.deltas)), - distances(get_accessor(pl.distances)), loop(pl.loop), - include_transpose(pl.include_transpose), use_periodic(pl.use_periodic) { - } -}; - /* * @brief Add a pair of particles to the pair list. If necessary, also add the transpose pair. 
* @param list The pair list @@ -392,30 +353,14 @@ template struct PairListAccessor { */ template __device__ void addNeighborPair(PairListAccessor& list, const int i, const int j, - scalar_t distance2, const scalar3 delta) { + scalar_t distance2, scalar3 delta) { const bool requires_transpose = list.include_transpose and (j != i); - const int32_t i_pair = atomicAdd(&list.i_curr_pair[0], requires_transpose ? 2 : 1); - // We handle too many neighbors outside of the kernel - if (i_pair + requires_transpose < list.neighbors.size(1)) { - const int ni = thrust::max(i, j); - const int nj = thrust::min(i, j); - const scalar_t delta_sign = (ni == i) ? scalar_t(1.0) : scalar_t(-1.0); - const scalar_t distance = sqrt_(distance2); - list.neighbors[0][i_pair] = ni; - list.neighbors[1][i_pair] = nj; - list.deltas[i_pair][0] = delta_sign * delta.x; - list.deltas[i_pair][1] = delta_sign * delta.y; - list.deltas[i_pair][2] = delta_sign * delta.z; - list.distances[i_pair] = distance; - if (requires_transpose) { - list.neighbors[0][i_pair + 1] = nj; - list.neighbors[1][i_pair + 1] = ni; - list.deltas[i_pair + 1][0] = -delta_sign * delta.x; - list.deltas[i_pair + 1][1] = -delta_sign * delta.y; - list.deltas[i_pair + 1][2] = -delta_sign * delta.z; - list.distances[i_pair + 1] = distance; - } - } + const int ni = thrust::max(i, j); + const int nj = thrust::min(i, j); + const scalar_t delta_sign = (ni == i) ? 
scalar_t(1.0) : scalar_t(-1.0); + const scalar_t distance = sqrt_(distance2); + delta = {delta_sign * delta.x, delta_sign * delta.y, delta_sign * delta.z}; + addAtomPairToList(list, ni, nj, delta, distance, requires_transpose); } /* @@ -430,7 +375,6 @@ template __device__ void addNeighborsForCell(const Particle& i_atom, int j_cell, const CellListAccessor& cl, scalar3 box_size, PairListAccessor& list) { - const auto first_particle = cl.cell_start[j_cell]; if (first_particle != -1) { // Continue only if there are particles in this cell const auto last_particle = cl.cell_end[j_cell]; @@ -485,7 +429,7 @@ __global__ void traverseCellList(const CellListAccessor cell_list, } } -class Autograd : public Function { +class AutogradCellCUDA : public Function { public: static tensor_list forward(AutogradContext* ctx, const Tensor& positions, const Tensor& batch, const Tensor& box_size, bool use_periodic, @@ -519,7 +463,7 @@ public: box_size[2][2].item()}; PairListAccessor list_accessor(list); CellListAccessor cell_list_accessor(cell_list); - const int threads = 256; + const int threads = 64; const int blocks = (num_atoms + threads - 1) / threads; traverseCellList<<>>(cell_list_accessor, list_accessor, num_atoms, box_size_, @@ -535,15 +479,4 @@ public: return common_backward(ctx, grad_inputs); } }; - -TORCH_LIBRARY_IMPL(neighbors, AutogradCUDA, m) { - m.impl("get_neighbor_pairs_cell", - [](const Tensor& positions, const Tensor& batch, const Tensor& box_vectors, - bool use_periodic, const Scalar& cutoff_lower, const Scalar& cutoff_upper, - const Scalar& max_num_pairs, bool loop, bool include_transpose) { - const tensor_list results = - Autograd::apply(positions, batch, box_vectors, use_periodic, cutoff_lower, - cutoff_upper, max_num_pairs, loop, include_transpose); - return std::make_tuple(results[0], results[1], results[2], results[3]); - }); -} +#endif diff --git a/torchmdnet/neighbors/neighbors_cuda_shared.cuh b/torchmdnet/neighbors/neighbors_cuda_shared.cuh new file mode 
100644 index 000000000..01861e901 --- /dev/null +++ b/torchmdnet/neighbors/neighbors_cuda_shared.cuh @@ -0,0 +1,119 @@ +/* Raul P. Pelaez 2023. Shared memory neighbor list construction for CUDA. + This brute force approach checks all pairs of atoms by collaboratively loading and processing + tiles of atoms into shared memory. + This approach is typically slower than the brute force approach, but can handle an arbitrarily + large number of atoms. + */ +#ifndef NEIGHBORS_SHARED_CUH +#define NEIGHBORS_SHARED_CUH +#include "common.cuh" +#include +#include +#include + +template +__global__ void forward_kernel_shared(uint32_t num_atoms, const Accessor positions, + const Accessor batch, scalar_t cutoff_lower2, + scalar_t cutoff_upper2, PairListAccessor list, + int32_t num_tiles, triclinic::Box box) { + // A thread per atom + const int id = blockIdx.x * blockDim.x + threadIdx.x; + // All threads must pass through __syncthreads, + // but when N is not a multiple of 32 some threads are assigned a particle i>N. + // These threads can't return, so they are masked to not do any work + const bool active = id < num_atoms; + __shared__ scalar3 sh_pos[BLOCKSIZE]; + __shared__ int64_t sh_batch[BLOCKSIZE]; + scalar3 pos_i; + int64_t batch_i; + if (active) { + pos_i = fetchPosition(positions, id); + batch_i = batch[id]; + } + // Distribute the N particles in a group of tiles. Storing in each tile blockDim.x values in + // shared memory. This way all threads are accessing the same memory addresses at the same time + for (int tile = 0; tile < num_tiles; tile++) { + // Load this tile's particle values to shared memory + const int i_load = tile * blockDim.x + threadIdx.x; + if (i_load < num_atoms) { // Even if I'm not active, my thread may load a value each tile to + // shared memory. 
+ sh_pos[threadIdx.x] = fetchPosition(positions, i_load); + sh_batch[threadIdx.x] = batch[i_load]; + } + // Wait for all threads to arrive + __syncthreads(); + // Go through all the particles in the current tile +#pragma unroll 8 + for (int counter = 0; counter < blockDim.x; counter++) { + if (!active) + break; // An out of bounds thread must be masked + const int cur_j = tile * blockDim.x + counter; + const bool testPair = cur_j < num_atoms and (cur_j < id or (list.loop and cur_j == id)); + if (testPair) { + const auto batch_j = sh_batch[counter]; + if (batch_i == batch_j) { + const auto pos_j = sh_pos[counter]; + const auto delta = + triclinic::compute_distance(pos_i, pos_j, list.use_periodic, box); + const scalar_t distance2 = + delta.x * delta.x + delta.y * delta.y + delta.z * delta.z; + if (distance2 < cutoff_upper2 && distance2 >= cutoff_lower2) { + const bool requires_transpose = list.include_transpose && !(cur_j == id); + const auto distance = sqrt_(distance2); + addAtomPairToList(list, id, cur_j, delta, distance, requires_transpose); + } + } + } + } + __syncthreads(); + } +} + +class AutogradSharedCUDA : public Function { +public: + static tensor_list forward(AutogradContext* ctx, const Tensor& positions, const Tensor& batch, + const Scalar& cutoff_lower, const Scalar& cutoff_upper, + const Tensor& box_vectors, bool use_periodic, + const Scalar& max_num_pairs, bool loop, bool include_transpose) { + checkInput(positions, batch); + const auto max_num_pairs_ = max_num_pairs.toLong(); + TORCH_CHECK(max_num_pairs_ > 0, "Expected \"max_num_neighbors\" to be positive"); + if (use_periodic) { + TORCH_CHECK(box_vectors.dim() == 2, "Expected \"box_vectors\" to have two dimensions"); + TORCH_CHECK(box_vectors.size(0) == 3 && box_vectors.size(1) == 3, + "Expected \"box_vectors\" to have shape (3, 3)"); + } + TORCH_CHECK(box_vectors.device() == torch::kCPU, "Expected \"box_vectors\" to be on CPU"); + const int num_atoms = positions.size(0); + const int num_pairs = 
max_num_pairs_; + const TensorOptions options = positions.options(); + const auto stream = getCurrentCUDAStream(positions.get_device()); + PairList list(num_pairs, positions.options(), loop, include_transpose, use_periodic); + const CUDAStreamGuard guard(stream); + AT_DISPATCH_FLOATING_TYPES( + positions.scalar_type(), "get_neighbor_pairs_shared_forward", [&]() { + const scalar_t cutoff_upper_ = cutoff_upper.to(); + const scalar_t cutoff_lower_ = cutoff_lower.to(); + triclinic::Box box(box_vectors); + TORCH_CHECK(cutoff_upper_ > 0, "Expected \"cutoff\" to be positive"); + constexpr int BLOCKSIZE = 64; + const int num_blocks = std::max((num_atoms + BLOCKSIZE - 1) / BLOCKSIZE, 1); + const int num_threads = BLOCKSIZE; + const int num_tiles = num_blocks; + PairListAccessor list_accessor(list); + forward_kernel_shared<<>>( + num_atoms, get_accessor(positions), + get_accessor(batch), cutoff_lower_ * cutoff_lower_, + cutoff_upper_ * cutoff_upper_, list_accessor, num_tiles, box); + }); + ctx->save_for_backward({list.neighbors, list.deltas, list.distances}); + ctx->saved_data["num_atoms"] = num_atoms; + return {list.neighbors, list.deltas, list.distances, list.i_curr_pair}; + } + + static tensor_list backward(AutogradContext* ctx, const tensor_list& grad_inputs) { + return common_backward(ctx, grad_inputs); + } +}; + +#endif From a345612ad910fd151f075a0fce8c5193fb702de8 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Mon, 22 May 2023 10:08:15 +0200 Subject: [PATCH 56/76] Add tests for non-diagonal triclinic boxes --- tests/test_neighbors.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tests/test_neighbors.py b/tests/test_neighbors.py index 06707056d..d017f0904 100644 --- a/tests/test_neighbors.py +++ b/tests/test_neighbors.py @@ -96,12 +96,18 @@ def test_neighbors( pos.requires_grad = True if box_type is None: box = None - else: + elif box_type == "rectangular": box = ( torch.tensor([[lbox, 0.0, 0.0], [0.0, lbox, 0.0], [0.0, 0.0, lbox]]) 
.to(pos.dtype) .to(device) ) + elif box_type == "triclinic": + box = ( + torch.tensor([[lbox, 0.0, 0.0], [0.1, lbox, 0.0], [0.3, 0.2, lbox]]) + .to(pos.dtype) + .to(device) + ) ref_neighbors, ref_distance_vecs, ref_distances = compute_ref_neighbors( pos, batch, loop, include_transpose, cutoff, box ) From eb8ab5cbc912cdbf3a2ed09be3e88c26e7f2c11a Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Mon, 22 May 2023 10:21:08 +0200 Subject: [PATCH 57/76] Improve documentation of OptimizedDistance --- torchmdnet/models/utils.py | 36 ++++++++++++++++++++++++++++-------- 1 file changed, 28 insertions(+), 8 deletions(-) diff --git a/torchmdnet/models/utils.py b/torchmdnet/models/utils.py index 239461e42..2ad1fd382 100644 --- a/torchmdnet/models/utils.py +++ b/torchmdnet/models/utils.py @@ -99,6 +99,20 @@ def __init__( This operation can be placed inside a CUDA graph in some cases. In particular, resize_to_fit and check_errors must be False. Note that this module returns neighbors such that distance(i,j) >= cutoff_lower and distance(i,j) < cutoff_upper. + This function optionally supports periodic boundary conditions with + arbitrary triclinic boxes. The box vectors `a`, `b`, and `c` must satisfy + certain requirements: + + `a[1] = a[2] = b[2] = 0` + `a[0] >= 2*cutoff, b[1] >= 2*cutoff, c[2] >= 2*cutoff` + `a[0] >= 2*b[0]` + `a[0] >= 2*c[0]` + `b[1] >= 2*c[1]` + + These requirements correspond to a particular rotation of the system and + reduced form of the vectors, as well as the requirement that the cutoff be + no larger than half the box width. + Parameters ---------- cutoff_lower : float @@ -115,20 +129,26 @@ def __init__( Shared: An O(N^2) algorithm that leverages CUDA shared memory, best for large number of particles. Brute: A brute force O(N^2) algorithm, best for small number of particles. Cell: A cell list algorithm, best for large number of particles, low cutoffs and low batch size. - box : Optional[torch.Tensor] - Size of the box, shape (3,3) or None. 
- If strategy is "cell", the box must be diagonal. - loop : bool + box : torch.Tensor, optional + The vectors defining the periodic box. This must have shape `(3, 3)`, + where `box_vectors[0] = a`, `box_vectors[1] = b`, and `box_vectors[2] = c`. + If this is omitted, periodic boundary conditions are not applied. + loop : bool, optional Whether to include self-interactions. - include_transpose : bool + Default: False + include_transpose : bool, optional Whether to include the transpose of the neighbor list. - resize_to_fit : bool + Default: True + resize_to_fit : bool, optional Whether to resize the neighbor list to the actual number of pairs found. When False, the list is padded with (-1,-1) pairs up to max_num_pairs + Default: True If this is True the operation is not CUDA graph compatible. - check_errors : bool + check_errors : bool, optional Whether to check for too many pairs. If this is True the operation is not CUDA graph compatible. - return_vecs : bool + Default: True + return_vecs : bool, optional Whether to return the distance vectors. 
+ Default: False """ self.cutoff_upper = cutoff_upper self.cutoff_lower = cutoff_lower From 6b941ffa3c21c270214e64e9f938638827cdf206 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Mon, 22 May 2023 11:47:26 +0200 Subject: [PATCH 58/76] Small changes --- torchmdnet/neighbors/neighbors_cuda_cell.cuh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchmdnet/neighbors/neighbors_cuda_cell.cuh b/torchmdnet/neighbors/neighbors_cuda_cell.cuh index a6a065418..9c5e1fe9c 100644 --- a/torchmdnet/neighbors/neighbors_cuda_cell.cuh +++ b/torchmdnet/neighbors/neighbors_cuda_cell.cuh @@ -124,8 +124,8 @@ __global__ void assignHash(const Accessor positions, uint64_t* hash return; const uint32_t i_batch = batch[i_atom]; // Move to the unit cell - scalar3 pi = {positions[i_atom][0], positions[i_atom][1], positions[i_atom][2]}; - auto ci = getCell(pi, box_size, cutoff); + const auto pi = fetchPosition(positions, i_atom); + const auto ci = getCell(pi, box_size, cutoff); // Calculate the hash const uint32_t hash = hashMorton(ci); // Create a hash combining the Morton hash and the batch index, so that atoms in the same cell From 526d366540603add8de34dcf44458ce20313b1d2 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Mon, 22 May 2023 11:47:37 +0200 Subject: [PATCH 59/76] Remove unnecessary thrust headers --- torchmdnet/neighbors/neighbors_cuda_cell.cuh | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/torchmdnet/neighbors/neighbors_cuda_cell.cuh b/torchmdnet/neighbors/neighbors_cuda_cell.cuh index 9c5e1fe9c..7d09ed8df 100644 --- a/torchmdnet/neighbors/neighbors_cuda_cell.cuh +++ b/torchmdnet/neighbors/neighbors_cuda_cell.cuh @@ -5,9 +5,7 @@ #define NEIGHBOR_CUDA_CELL_H #include "common.cuh" #include -#include -#include - +#include /* * @brief Encodes an unsigned integer lower than 1024 as a 32 bit integer by filling every third * bit. 
From a75a521c94f05878d5be37e67717eff3d16c60f0 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Mon, 22 May 2023 15:32:48 +0200 Subject: [PATCH 60/76] Change name of extension so it does not collide with NNPOps one --- torchmdnet/neighbors/__init__.py | 8 ++++---- torchmdnet/neighbors/neighbors.cpp | 2 +- torchmdnet/neighbors/neighbors_cpu.cpp | 2 +- torchmdnet/neighbors/neighbors_cuda.cu | 3 ++- 4 files changed, 8 insertions(+), 7 deletions(-) diff --git a/torchmdnet/neighbors/__init__.py b/torchmdnet/neighbors/__init__.py index 29fa7af06..ccb6b0efc 100644 --- a/torchmdnet/neighbors/__init__.py +++ b/torchmdnet/neighbors/__init__.py @@ -11,14 +11,14 @@ def compile_extension(): else [] ) sources = [os.path.join(src_dir, name) for name in sources] - cpp_extension.load(name="neighbors", sources=sources, is_python_module=False) + cpp_extension.load(name="torchmdnet_neighbors", sources=sources, is_python_module=False, verbose=True) def get_backends(): compile_extension() - get_neighbor_pairs_brute = pt.ops.neighbors.get_neighbor_pairs_brute - get_neighbor_pairs_shared = pt.ops.neighbors.get_neighbor_pairs_shared - get_neighbor_pairs_cell = pt.ops.neighbors.get_neighbor_pairs_cell + get_neighbor_pairs_brute = pt.ops.torchmdnet_neighbors.get_neighbor_pairs_brute + get_neighbor_pairs_shared = pt.ops.torchmdnet_neighbors.get_neighbor_pairs_shared + get_neighbor_pairs_cell = pt.ops.torchmdnet_neighbors.get_neighbor_pairs_cell return { "brute": get_neighbor_pairs_brute, "cell": get_neighbor_pairs_cell, diff --git a/torchmdnet/neighbors/neighbors.cpp b/torchmdnet/neighbors/neighbors.cpp index e8c3dd950..5c407ef3c 100644 --- a/torchmdnet/neighbors/neighbors.cpp +++ b/torchmdnet/neighbors/neighbors.cpp @@ -1,6 +1,6 @@ #include -TORCH_LIBRARY(neighbors, m) { +TORCH_LIBRARY(torchmdnet_neighbors, m) { m.def("get_neighbor_pairs_brute(Tensor positions, Tensor batch, Tensor box_vectors, bool use_periodic, Scalar cutoff_lower, Scalar cutoff_upper, Scalar max_num_pairs, bool loop, bool 
include_transpose) -> (Tensor neighbors, Tensor distances, Tensor distance_vecs, Tensor num_pairs)"); m.def("get_neighbor_pairs_shared(Tensor positions, Tensor batch, Tensor box_vectors, bool use_periodic, Scalar cutoff_lower, Scalar cutoff_upper, Scalar max_num_pairs, bool loop, bool include_transpose) -> (Tensor neighbors, Tensor distances, Tensor distance_vecs, Tensor num_pairs)"); m.def("get_neighbor_pairs_cell(Tensor positions, Tensor batch, Tensor box_vectors, bool use_periodic, Scalar cutoff_lower, Scalar cutoff_upper, Scalar max_num_pairs, bool loop, bool include_transpose) -> (Tensor neighbors, Tensor distances, Tensor distance_vecs, Tensor num_pairs)"); diff --git a/torchmdnet/neighbors/neighbors_cpu.cpp b/torchmdnet/neighbors/neighbors_cpu.cpp index f769fddef..dac6cfc69 100644 --- a/torchmdnet/neighbors/neighbors_cpu.cpp +++ b/torchmdnet/neighbors/neighbors_cpu.cpp @@ -89,7 +89,7 @@ forward(const Tensor& positions, const Tensor& batch, const Tensor& box_vectors, return {neighbors, deltas, distances, num_pairs_found}; } -TORCH_LIBRARY_IMPL(neighbors, CPU, m) { +TORCH_LIBRARY_IMPL(torchmdnet_neighbors, CPU, m) { m.impl("get_neighbor_pairs_brute", &forward); m.impl("get_neighbor_pairs_shared", &forward); m.impl("get_neighbor_pairs_cell", &forward); diff --git a/torchmdnet/neighbors/neighbors_cuda.cu b/torchmdnet/neighbors/neighbors_cuda.cu index deaa90e72..83f43d985 100644 --- a/torchmdnet/neighbors/neighbors_cuda.cu +++ b/torchmdnet/neighbors/neighbors_cuda.cu @@ -2,11 +2,12 @@ Connection between the neighbor CUDA implementations and the torch extension. See neighbors.cpp for the definition of the torch extension functions. 
*/ +#include #include "neighbors_cuda_brute.cuh" #include "neighbors_cuda_cell.cuh" #include "neighbors_cuda_shared.cuh" -TORCH_LIBRARY_IMPL(neighbors, AutogradCUDA, m) { +TORCH_LIBRARY_IMPL(torchmdnet_neighbors, AutogradCUDA, m) { m.impl("get_neighbor_pairs_brute", [](const Tensor& positions, const Tensor& batch, const Tensor& box_vectors, bool use_periodic, const Scalar& cutoff_lower, const Scalar& cutoff_upper, From 59e442f71849f627597fda8163538f96376312b3 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Mon, 22 May 2023 15:36:11 +0200 Subject: [PATCH 61/76] Remove verbose --- torchmdnet/neighbors/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmdnet/neighbors/__init__.py b/torchmdnet/neighbors/__init__.py index ccb6b0efc..b19c22606 100644 --- a/torchmdnet/neighbors/__init__.py +++ b/torchmdnet/neighbors/__init__.py @@ -11,7 +11,7 @@ def compile_extension(): else [] ) sources = [os.path.join(src_dir, name) for name in sources] - cpp_extension.load(name="torchmdnet_neighbors", sources=sources, is_python_module=False, verbose=True) + cpp_extension.load(name="torchmdnet_neighbors", sources=sources, is_python_module=False) def get_backends(): From 8f9c30ff3e6c84136ef7a63f452d74a7c05b9f75 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Wed, 24 May 2023 10:44:56 +0200 Subject: [PATCH 62/76] Do not include batch in hash. This allows to simplify the code and use torch::sort for minimal performance loss. 
Change block size to 128 --- torchmdnet/neighbors/neighbors_cuda_cell.cuh | 78 +++++--------------- 1 file changed, 17 insertions(+), 61 deletions(-) diff --git a/torchmdnet/neighbors/neighbors_cuda_cell.cuh b/torchmdnet/neighbors/neighbors_cuda_cell.cuh index 7d09ed8df..f3964a0c8 100644 --- a/torchmdnet/neighbors/neighbors_cuda_cell.cuh +++ b/torchmdnet/neighbors/neighbors_cuda_cell.cuh @@ -27,7 +27,7 @@ inline __host__ __device__ uint encodeMorton(const uint& i) { * @param ci The cell index * @return The Morton hash */ -inline __host__ __device__ uint hashMorton(int3 ci) { +inline __host__ __device__ int hashMorton(int3 ci, int3 cell_dim) { return encodeMorton(ci.x) | (encodeMorton(ci.y) << 1) | (encodeMorton(ci.z) << 2); } @@ -114,47 +114,19 @@ __device__ int3 getPeriodicCell(int3 cell, int3 cell_dim) { // Assign a hash to each atom based on its position and batch. // This hash is such that atoms in the same cell and batch have the same hash. template -__global__ void assignHash(const Accessor positions, uint64_t* hash_keys, - Accessor hash_values, const Accessor batch, +__global__ void assignHash(const Accessor positions, Accessor hash_keys, scalar3 box_size, scalar_t cutoff, int32_t num_atoms) { const int32_t i_atom = blockIdx.x * blockDim.x + threadIdx.x; if (i_atom >= num_atoms) return; - const uint32_t i_batch = batch[i_atom]; // Move to the unit cell - const auto pi = fetchPosition(positions, i_atom); + const auto pi = fetchPosition(positions, i_atom); const auto ci = getCell(pi, box_size, cutoff); // Calculate the hash - const uint32_t hash = hashMorton(ci); - // Create a hash combining the Morton hash and the batch index, so that atoms in the same cell - // are contiguous - const uint64_t hash_final = (static_cast(hash) << 32) | i_batch; - hash_keys[i_atom] = hash_final; - hash_values[i_atom] = i_atom; + const int32_t hash = hashMorton(ci, getCellDimensions(box_size, cutoff)); + hash_keys[i_atom] = hash; } -/* - * @brief A buffer that is allocated and 
deallocated using the CUDA caching allocator from torch - */ -template struct CachedBuffer { - explicit CachedBuffer(size_t size) : size_(size) { - ptr_ = static_cast(at::cuda::CUDACachingAllocator::raw_alloc(size * sizeof(T))); - } - ~CachedBuffer() { - at::cuda::CUDACachingAllocator::raw_delete(ptr_); - } - T* get() { - return ptr_; - } - size_t size() { - return size_; - } - -private: - T* ptr_; - size_t size_; -}; - /* * @brief Sort the positions by hash, first by the cell assigned to each position and the batch * index @@ -168,8 +140,7 @@ private: static auto sortPositionsByHash(const Tensor& positions, const Tensor& batch, const Tensor& box_size, const Scalar& cutoff) { const int num_atoms = positions.size(0); - auto hash_keys = CachedBuffer(num_atoms); - Tensor hash_values = empty({num_atoms}, positions.options().dtype(torch::kInt32)); + Tensor hash_keys = empty({num_atoms}, positions.options().dtype(torch::kInt32)); const int threads = 128; const int blocks = (num_atoms + threads - 1) / threads; auto stream = at::cuda::getCurrentCUDAStream(); @@ -178,28 +149,17 @@ static auto sortPositionsByHash(const Tensor& positions, const Tensor& batch, scalar3 box_size_ = {box_size[0][0].item(), box_size[1][1].item(), box_size[2][2].item()}; - assignHash<<>>( - get_accessor(positions), hash_keys.get(), - get_accessor(hash_values), get_accessor(batch), box_size_, - cutoff_, num_atoms); + assignHash<<>>(get_accessor(positions), + get_accessor(hash_keys), box_size_, + cutoff_, num_atoms); }); - // I have to use cub directly because thrust::sort_by_key is not compatible with graphs - // and torch::lexsort does not support uint64_t - size_t tmp_storage_bytes = 0; - auto d_keys_out = CachedBuffer(num_atoms); - auto d_values_out = CachedBuffer(num_atoms); - auto* hash_values_ptr = hash_values.data_ptr(); - cub::DeviceRadixSort::SortPairs(nullptr, tmp_storage_bytes, hash_keys.get(), d_keys_out.get(), - hash_values_ptr, d_values_out.get(), num_atoms, 0, 64, stream); - auto 
tmp_storage = CachedBuffer(tmp_storage_bytes); - cub::DeviceRadixSort::SortPairs(tmp_storage.get(), tmp_storage_bytes, hash_keys.get(), - d_keys_out.get(), hash_values_ptr, d_values_out.get(), - num_atoms, 0, 64, stream); - cudaMemcpyAsync(hash_values_ptr, d_values_out.get(), num_atoms * sizeof(int32_t), - cudaMemcpyDeviceToDevice, stream); - Tensor sorted_positions = positions.index_select(0, hash_values); - Tensor sorted_batch = batch.index_select(0, hash_values); - return std::make_tuple(sorted_positions, sorted_batch, hash_values); + // Sort the hash values by the hash keys + torch::Tensor sorted_hash_values; + torch::Tensor sorted_hash_keys; + std::tie(sorted_hash_keys, sorted_hash_values) = torch::sort(hash_keys); + Tensor sorted_positions = positions.index_select(0, sorted_hash_values); + Tensor sorted_batch = batch.index_select(0, sorted_hash_values); + return std::make_tuple(sorted_positions, sorted_batch, sorted_hash_values.to(torch::kInt32)); } template @@ -378,10 +338,6 @@ __device__ void addNeighborsForCell(const Particle& i_atom, int j_cell const auto last_particle = cl.cell_end[j_cell]; for (int cur_j = first_particle; cur_j < last_particle; cur_j++) { const auto j_batch = cl.sorted_batch[cur_j]; - // Particles are sorted by batch after cell, so we can break early here - if (j_batch > i_atom.batch) { - break; - } if ((j_batch == i_atom.batch) and ((cur_j < i_atom.index) or (list.loop and cur_j == i_atom.index))) { const auto position_j = fetchPosition(cl.sorted_positions, cur_j); @@ -461,7 +417,7 @@ public: box_size[2][2].item()}; PairListAccessor list_accessor(list); CellListAccessor cell_list_accessor(cell_list); - const int threads = 64; + const int threads = 128; const int blocks = (num_atoms + threads - 1) / threads; traverseCellList<<>>(cell_list_accessor, list_accessor, num_atoms, box_size_, From 95e6669cd020dc9440601b3f1f825abe7e3e8ade Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Wed, 24 May 2023 10:54:53 +0200 Subject: [PATCH 63/76] 
Update benchmark --- benchmarks/neighbors.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/neighbors.py b/benchmarks/neighbors.py index 43df0cd41..9904ad39b 100644 --- a/benchmarks/neighbors.py +++ b/benchmarks/neighbors.py @@ -110,7 +110,7 @@ def benchmark_neighbors( if __name__ == "__main__": n_particles = 32767 mean_num_neighbors = min(n_particles, 64) - density = 0.5 + density = 0.8 print( "Benchmarking neighbor list generation for {} particles with {} neighbors on average".format( n_particles, mean_num_neighbors From 48e40e340f08449739f9440676740a977c9d8c4b Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Wed, 24 May 2023 10:55:54 +0200 Subject: [PATCH 64/76] Add ninja to build the C++ extension --- environment.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/environment.yml b/environment.yml index f6413e3d5..83dbca760 100644 --- a/environment.yml +++ b/environment.yml @@ -18,3 +18,4 @@ dependencies: - flake8 - pytest - psutil + - ninja From 0d99bd75fd34b7f6f2f8786d23d91dc513d280e4 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Wed, 24 May 2023 14:15:44 +0200 Subject: [PATCH 65/76] Use cell index as hash --- torchmdnet/neighbors/neighbors_cuda_cell.cuh | 33 ++------------------ 1 file changed, 3 insertions(+), 30 deletions(-) diff --git a/torchmdnet/neighbors/neighbors_cuda_cell.cuh b/torchmdnet/neighbors/neighbors_cuda_cell.cuh index f3964a0c8..f05c439b0 100644 --- a/torchmdnet/neighbors/neighbors_cuda_cell.cuh +++ b/torchmdnet/neighbors/neighbors_cuda_cell.cuh @@ -4,32 +4,7 @@ #ifndef NEIGHBOR_CUDA_CELL_H #define NEIGHBOR_CUDA_CELL_H #include "common.cuh" -#include #include -/* - * @brief Encodes an unsigned integer lower than 1024 as a 32 bit integer by filling every third - * bit. 
- * @param i The integer to encode - * @return The encoded integer - */ -inline __host__ __device__ uint encodeMorton(const uint& i) { - uint x = i; - x &= 0x3ff; - x = (x | x << 16) & 0x30000ff; - x = (x | x << 8) & 0x300f00f; - x = (x | x << 4) & 0x30c30c3; - x = (x | x << 2) & 0x9249249; - return x; -} - -/* - * @brief Interleave three 10 bit numbers in 32 bits, producing a Z order Morton hash - * @param ci The cell index - * @return The Morton hash - */ -inline __host__ __device__ int hashMorton(int3 ci, int3 cell_dim) { - return encodeMorton(ci.x) | (encodeMorton(ci.y) << 1) | (encodeMorton(ci.z) << 2); -} /* * @brief Calculates the cell dimensions for a given box size and cutoff @@ -111,19 +86,17 @@ __device__ int3 getPeriodicCell(int3 cell, int3 cell_dim) { return periodic_cell; } -// Assign a hash to each atom based on its position and batch. -// This hash is such that atoms in the same cell and batch have the same hash. +// Assign a hash to each atom based on its position. +// This hash is such that atoms in the same cell have the same hash. 
template __global__ void assignHash(const Accessor positions, Accessor hash_keys, scalar3 box_size, scalar_t cutoff, int32_t num_atoms) { const int32_t i_atom = blockIdx.x * blockDim.x + threadIdx.x; if (i_atom >= num_atoms) return; - // Move to the unit cell const auto pi = fetchPosition(positions, i_atom); const auto ci = getCell(pi, box_size, cutoff); - // Calculate the hash - const int32_t hash = hashMorton(ci, getCellDimensions(box_size, cutoff)); + const int32_t hash = getCellIndex(ci, getCellDimensions(box_size, cutoff)); hash_keys[i_atom] = hash; } From 81a9bf96c38fb3ba1665725bad06439d92eba35d Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Fri, 26 May 2023 09:32:55 +0200 Subject: [PATCH 66/76] Use min/max CUDA builtins instead of thrust --- torchmdnet/neighbors/neighbors_cuda_brute.cuh | 3 +-- torchmdnet/neighbors/neighbors_cuda_cell.cuh | 11 +++++------ torchmdnet/neighbors/neighbors_cuda_shared.cuh | 1 - 3 files changed, 6 insertions(+), 9 deletions(-) diff --git a/torchmdnet/neighbors/neighbors_cuda_brute.cuh b/torchmdnet/neighbors/neighbors_cuda_brute.cuh index 3756062aa..963165d94 100644 --- a/torchmdnet/neighbors/neighbors_cuda_brute.cuh +++ b/torchmdnet/neighbors/neighbors_cuda_brute.cuh @@ -8,7 +8,6 @@ #define NEIGHBORS_BRUTE_CUH #include "common.cuh" #include -#include #include __device__ uint32_t get_row(uint32_t index) { @@ -50,7 +49,7 @@ __global__ void add_self_kernel(const int num_atoms, Accessor posit if (threadIdx.x == 0) { // Each block adds blockDim.x pairs to the list. 
// Handle the last block, so that only num_atoms are added in total i_pair = atomicAdd(&list.i_curr_pair[0], - thrust::min(blockDim.x, num_atoms - blockIdx.x * blockDim.x)); + min(blockDim.x, num_atoms - blockIdx.x * blockDim.x)); } __syncthreads(); scalar3 delta{}; diff --git a/torchmdnet/neighbors/neighbors_cuda_cell.cuh b/torchmdnet/neighbors/neighbors_cuda_cell.cuh index f05c439b0..28afe99b3 100644 --- a/torchmdnet/neighbors/neighbors_cuda_cell.cuh +++ b/torchmdnet/neighbors/neighbors_cuda_cell.cuh @@ -4,7 +4,6 @@ #ifndef NEIGHBOR_CUDA_CELL_H #define NEIGHBOR_CUDA_CELL_H #include "common.cuh" -#include /* * @brief Calculates the cell dimensions for a given box size and cutoff @@ -16,9 +15,9 @@ template __host__ __device__ int3 getCellDimensions(scalar3 box_size, scalar_t cutoff) { int3 cell_dim = make_int3(box_size.x / cutoff, box_size.y / cutoff, box_size.z / cutoff); // Minimum 3 cells in each dimension - cell_dim.x = thrust::max(cell_dim.x, 3); - cell_dim.y = thrust::max(cell_dim.y, 3); - cell_dim.z = thrust::max(cell_dim.z, 3); + cell_dim.x = max(cell_dim.x, 3); + cell_dim.y = max(cell_dim.y, 3); + cell_dim.z = max(cell_dim.z, 3); // In the host, throw if there are more than 1024 cells in any dimension #ifndef __CUDA_ARCH__ if (cell_dim.x > 1024 || cell_dim.y > 1024 || cell_dim.z > 1024) { @@ -286,8 +285,8 @@ template __device__ void addNeighborPair(PairListAccessor& list, const int i, const int j, scalar_t distance2, scalar3 delta) { const bool requires_transpose = list.include_transpose and (j != i); - const int ni = thrust::max(i, j); - const int nj = thrust::min(i, j); + const int ni = max(i, j); + const int nj = min(i, j); const scalar_t delta_sign = (ni == i) ? 
scalar_t(1.0) : scalar_t(-1.0); const scalar_t distance = sqrt_(distance2); delta = {delta_sign * delta.x, delta_sign * delta.y, delta_sign * delta.z}; diff --git a/torchmdnet/neighbors/neighbors_cuda_shared.cuh b/torchmdnet/neighbors/neighbors_cuda_shared.cuh index 01861e901..fd9693a82 100644 --- a/torchmdnet/neighbors/neighbors_cuda_shared.cuh +++ b/torchmdnet/neighbors/neighbors_cuda_shared.cuh @@ -8,7 +8,6 @@ #define NEIGHBORS_SHARED_CUH #include "common.cuh" #include -#include #include template From 5459db63b3c8f0334476498eac3a610c67e66a51 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Fri, 26 May 2023 12:18:35 +0200 Subject: [PATCH 67/76] Simplify getCell --- torchmdnet/neighbors/neighbors_cuda_cell.cuh | 18 ++++-------------- 1 file changed, 4 insertions(+), 14 deletions(-) diff --git a/torchmdnet/neighbors/neighbors_cuda_cell.cuh b/torchmdnet/neighbors/neighbors_cuda_cell.cuh index 28afe99b3..87fde6afb 100644 --- a/torchmdnet/neighbors/neighbors_cuda_cell.cuh +++ b/torchmdnet/neighbors/neighbors_cuda_cell.cuh @@ -36,20 +36,10 @@ __host__ __device__ int3 getCellDimensions(scalar3 box_size, scalar_t */ template __device__ int3 getCell(scalar3 p, scalar3 box_size, scalar_t cutoff) { - p = rect::apply_pbc(p, box_size); - // Take to the [0, box_size] range and divide by cutoff (which is the cell size) - int cx = floorf((p.x + scalar_t(0.5) * box_size.x) / cutoff); - int cy = floorf((p.y + scalar_t(0.5) * box_size.y) / cutoff); - int cz = floorf((p.z + scalar_t(0.5) * box_size.z) / cutoff); - int3 cell_dim = getCellDimensions(box_size, cutoff); - // Wrap around. If the position of a particle is exactly box_size, it will be in the last cell, - // which results in an illegal access down the line. 
- if (cx == cell_dim.x) - cx = 0; - if (cy == cell_dim.y) - cy = 0; - if (cz == cell_dim.z) - cz = 0; + const int3 cell_dim = getCellDimensions(box_size, cutoff); + const int cx = fmodf(floorf((p.x + scalar_t(0.5) * box_size.x) / cutoff), cell_dim.x); + const int cy = fmodf(floorf((p.y + scalar_t(0.5) * box_size.y) / cutoff), cell_dim.y); + const int cz = fmodf(floorf((p.z + scalar_t(0.5) * box_size.z) / cutoff), cell_dim.z); return make_int3(cx, cy, cz); } From 27adbbe876ba3d6b1ebceb4b5a527c5f0aae9a06 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Fri, 26 May 2023 12:18:44 +0200 Subject: [PATCH 68/76] Remove unused variable, add some consts --- torchmdnet/neighbors/neighbors_cuda_cell.cuh | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/torchmdnet/neighbors/neighbors_cuda_cell.cuh b/torchmdnet/neighbors/neighbors_cuda_cell.cuh index 87fde6afb..0e6750c03 100644 --- a/torchmdnet/neighbors/neighbors_cuda_cell.cuh +++ b/torchmdnet/neighbors/neighbors_cuda_cell.cuh @@ -140,14 +140,13 @@ __global__ void fillCellOffsetsD(const Accessor sorted_positions, const int icell = getCellIndex(getCell(pi, box_size, cutoff), cell_dim); int im1_cell; if (i_atom > 0) { - int im1 = i_atom - 1; + const int im1 = i_atom - 1; const auto pim1 = fetchPosition(sorted_positions, im1); im1_cell = getCellIndex(getCell(pim1, box_size, cutoff), cell_dim); } else { im1_cell = 0; } if (icell != im1_cell || i_atom == 0) { - int n_cells = cell_start.size(0); cell_start[icell] = i_atom; if (i_atom > 0) { cell_end[im1_cell] = i_atom; @@ -210,7 +209,7 @@ __device__ int getNeighborCellIndex(int3 cell_i, int i, int3 cell_dim) { cell_j.y += (i / 3) % 3 - 1; cell_j.z += i / 9 - 1; cell_j = getPeriodicCell(cell_j, cell_dim); - int icellj = getCellIndex(cell_j, cell_dim); + const int icellj = getCellIndex(cell_j, cell_dim); return icellj; } From 27a3e8789e1e0f5ee0f824084aeda8e79b5bda83 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Fri, 26 May 2023 13:35:58 +0200 Subject: [PATCH 
69/76] Simplify code by assuming particles are always sorted by cell index --- torchmdnet/neighbors/neighbors_cuda_cell.cuh | 140 ++++++++----------- 1 file changed, 60 insertions(+), 80 deletions(-) diff --git a/torchmdnet/neighbors/neighbors_cuda_cell.cuh b/torchmdnet/neighbors/neighbors_cuda_cell.cuh index 0e6750c03..ae09741ab 100644 --- a/torchmdnet/neighbors/neighbors_cuda_cell.cuh +++ b/torchmdnet/neighbors/neighbors_cuda_cell.cuh @@ -35,8 +35,8 @@ __host__ __device__ int3 getCellDimensions(scalar3 box_size, scalar_t * @return The cell index */ template -__device__ int3 getCell(scalar3 p, scalar3 box_size, scalar_t cutoff) { - const int3 cell_dim = getCellDimensions(box_size, cutoff); +__device__ int3 getCell(scalar3 p, scalar3 box_size, scalar_t cutoff, + int3 cell_dim) { const int cx = fmodf(floorf((p.x + scalar_t(0.5) * box_size.x) / cutoff), cell_dim.x); const int cy = fmodf(floorf((p.y + scalar_t(0.5) * box_size.y) / cutoff), cell_dim.y); const int cz = fmodf(floorf((p.z + scalar_t(0.5) * box_size.z) / cutoff), cell_dim.z); @@ -75,34 +75,31 @@ __device__ int3 getPeriodicCell(int3 cell, int3 cell_dim) { return periodic_cell; } -// Assign a hash to each atom based on its position. -// This hash is such that atoms in the same cell have the same hash. +// Computes and stores the cell index of each atom. 
template -__global__ void assignHash(const Accessor positions, Accessor hash_keys, - scalar3 box_size, scalar_t cutoff, int32_t num_atoms) { +__global__ void assignCellIndex(const Accessor positions, + Accessor cell_indices, scalar3 box_size, + scalar_t cutoff) { const int32_t i_atom = blockIdx.x * blockDim.x + threadIdx.x; - if (i_atom >= num_atoms) + if (i_atom >= positions.size(0)) return; const auto pi = fetchPosition(positions, i_atom); - const auto ci = getCell(pi, box_size, cutoff); - const int32_t hash = getCellIndex(ci, getCellDimensions(box_size, cutoff)); - hash_keys[i_atom] = hash; + const auto cell_dim = getCellDimensions(box_size, cutoff); + const auto ci = getCell(pi, box_size, cutoff, cell_dim); + cell_indices[i_atom] = getCellIndex(ci, cell_dim); } /* - * @brief Sort the positions by hash, first by the cell assigned to each position and the batch - * index + * @brief Sort the positions by cell index * @param positions The positions of the atoms - * @param batch The batch index of each atom * @param box_size The box vectors * @param cutoff The cutoff - * @return A tuple of the sorted positions, sorted batch indices and the original indices of each - * atom in the sorted list + * @return A tuple of the sorted indices and cell indices */ -static auto sortPositionsByHash(const Tensor& positions, const Tensor& batch, - const Tensor& box_size, const Scalar& cutoff) { +static auto sortAtomsByCellIndex(const Tensor& positions, const Tensor& box_size, + const Scalar& cutoff) { const int num_atoms = positions.size(0); - Tensor hash_keys = empty({num_atoms}, positions.options().dtype(torch::kInt32)); + Tensor cell_index = empty({num_atoms}, positions.options().dtype(torch::kInt32)); const int threads = 128; const int blocks = (num_atoms + threads - 1) / threads; auto stream = at::cuda::getCurrentCUDAStream(); @@ -111,38 +108,30 @@ static auto sortPositionsByHash(const Tensor& positions, const Tensor& batch, scalar3 box_size_ = {box_size[0][0].item(), 
box_size[1][1].item(), box_size[2][2].item()}; - assignHash<<>>(get_accessor(positions), - get_accessor(hash_keys), box_size_, - cutoff_, num_atoms); + assignCellIndex<<>>(get_accessor(positions), + get_accessor(cell_index), + box_size_, cutoff_); }); - // Sort the hash values by the hash keys - torch::Tensor sorted_hash_values; - torch::Tensor sorted_hash_keys; - std::tie(sorted_hash_keys, sorted_hash_values) = torch::sort(hash_keys); - Tensor sorted_positions = positions.index_select(0, sorted_hash_values); - Tensor sorted_batch = batch.index_select(0, sorted_hash_values); - return std::make_tuple(sorted_positions, sorted_batch, sorted_hash_values.to(torch::kInt32)); + // Sort the atom indices by cell index + Tensor sorted_atom_index; + Tensor sorted_cell_index; + std::tie(sorted_cell_index, sorted_atom_index) = torch::sort(cell_index); + return std::make_tuple(sorted_atom_index.to(torch::kInt32), sorted_cell_index); } -template -__global__ void fillCellOffsetsD(const Accessor sorted_positions, - const Accessor sorted_indices, - Accessor cell_start, Accessor cell_end, - scalar3 box_size, scalar_t cutoff) { +__global__ void fillCellOffsetsD(const Accessor sorted_cell_indices, + Accessor cell_start, Accessor cell_end) { // Since positions are sorted by cell, for a given atom, if the previous atom is in a different // cell, then the current atom is the first atom in its cell We use this fact to fill the // cell_start and cell_end arrays const int32_t i_atom = blockIdx.x * blockDim.x + threadIdx.x; - if (i_atom >= sorted_positions.size(0)) + if (i_atom >= sorted_cell_indices.size(0)) return; - const auto pi = fetchPosition(sorted_positions, i_atom); - const int3 cell_dim = getCellDimensions(box_size, cutoff); - const int icell = getCellIndex(getCell(pi, box_size, cutoff), cell_dim); + const int icell = sorted_cell_indices[i_atom]; int im1_cell; if (i_atom > 0) { const int im1 = i_atom - 1; - const auto pim1 = fetchPosition(sorted_positions, im1); - im1_cell = 
getCellIndex(getCell(pim1, box_size, cutoff), cell_dim); + im1_cell = sorted_cell_indices[im1]; } else { im1_cell = 0; } @@ -152,47 +141,28 @@ __global__ void fillCellOffsetsD(const Accessor sorted_positions, cell_end[im1_cell] = i_atom; } } - if (i_atom == sorted_positions.size(0) - 1) { + if (i_atom == sorted_cell_indices.size(0) - 1) { cell_end[icell] = i_atom + 1; } } /* @brief Fills the cell_start and cell_end arrays, identifying the first and last atom in each cell - @param sorted_positions The positions sorted by cell - @param sorted_indices The original indices of the sorted positions - @param batch The batch index of each position - @param box_size The box vectors - @param cutoff The cutoff distance + @param sorted_cell_indices The cell indices of each position + @param cell_dim The dimensions of the cell grid @return A tuple of cell_start and cell_end arrays */ -static auto fillCellOffsets(const Tensor& sorted_positions, const Tensor& sorted_indices, - const Tensor& box_size, const Scalar& cutoff) { - const TensorOptions options = sorted_positions.options(); - int3 cell_dim; - AT_DISPATCH_FLOATING_TYPES(sorted_positions.scalar_type(), "fillCellOffsets", [&] { - scalar_t cutoff_ = cutoff.to(); - scalar3 box_size_ = {box_size[0][0].item(), - box_size[1][1].item(), - box_size[2][2].item()}; - cell_dim = getCellDimensions(box_size_, cutoff_); - }); +static auto fillCellOffsets(const Tensor& sorted_cell_indices, int3 cell_dim) { + const TensorOptions options = sorted_cell_indices.options(); const int num_cells = cell_dim.x * cell_dim.y * cell_dim.z; const Tensor cell_start = full({num_cells}, -1, options.dtype(torch::kInt)); const Tensor cell_end = empty({num_cells}, options.dtype(torch::kInt)); const int threads = 128; - const int blocks = (sorted_positions.size(0) + threads - 1) / threads; - AT_DISPATCH_FLOATING_TYPES(sorted_positions.scalar_type(), "fillCellOffsets", [&] { - auto stream = at::cuda::getCurrentCUDAStream(); - scalar_t cutoff_ = cutoff.to(); - 
scalar3 box_size_ = {box_size[0][0].item(), - box_size[1][1].item(), - box_size[2][2].item()}; - fillCellOffsetsD<<>>( - get_accessor(sorted_positions), get_accessor(sorted_indices), - get_accessor(cell_start), get_accessor(cell_end), box_size_, - cutoff_); - }); + const int blocks = (sorted_cell_indices.size(0) + threads - 1) / threads; + auto stream = at::cuda::getCurrentCUDAStream(); + fillCellOffsetsD<<>>(get_accessor(sorted_cell_indices), + get_accessor(cell_start), + get_accessor(cell_end)); return std::make_tuple(cell_start, cell_end); } @@ -223,40 +193,50 @@ template struct Particle { struct CellList { Tensor cell_start, cell_end; - Tensor original_indices; + Tensor sorted_indices; Tensor sorted_positions, sorted_batch; }; CellList constructCellList(const Tensor& positions, const Tensor& batch, const Tensor& box_size, const Scalar& cutoff) { // The algorithm for the cell list construction can be summarized in three separate steps: - // 1. Hash (label) the particles according to the cell (bin) they lie in. - // 2. Sort the particles and hashes using the hashes as the ordering label + // 1. Label the particles according to the cell (bin) they lie in. + // 2. Sort the particles using the cell index as the ordering label // (technically this is known as sorting by key). So that particles with positions // lying in the same cell become contiguous in memory. // 3. Identify where each cell starts and ends in the sorted particle positions // array. 
const TensorOptions options = positions.options(); CellList cl; + Tensor sorted_cell_indices; // Steps 1 and 2 - std::tie(cl.sorted_positions, cl.sorted_batch, cl.original_indices) = - sortPositionsByHash(positions, batch, box_size, cutoff); + std::tie(cl.sorted_indices, sorted_cell_indices) = + sortAtomsByCellIndex(positions, box_size, cutoff); + cl.sorted_positions = positions.index_select(0, cl.sorted_indices); + cl.sorted_batch = batch.index_select(0, cl.sorted_indices); // Step 3 - std::tie(cl.cell_start, cl.cell_end) = - fillCellOffsets(cl.sorted_positions, cl.original_indices, box_size, cutoff); + int3 cell_dim; + AT_DISPATCH_FLOATING_TYPES(positions.scalar_type(), "computeCellDim", [&] { + scalar_t cutoff_ = cutoff.to(); + scalar3 box_size_ = {box_size[0][0].item(), + box_size[1][1].item(), + box_size[2][2].item()}; + cell_dim = getCellDimensions(box_size_, cutoff_); + }); + std::tie(cl.cell_start, cl.cell_end) = fillCellOffsets(sorted_cell_indices, cell_dim); return cl; } template struct CellListAccessor { Accessor cell_start, cell_end; - Accessor original_indices; + Accessor sorted_indices; Accessor sorted_positions; Accessor sorted_batch; explicit CellListAccessor(const CellList& cl) : cell_start(get_accessor(cl.cell_start)), cell_end(get_accessor(cl.cell_end)), - original_indices(get_accessor(cl.original_indices)), + sorted_indices(get_accessor(cl.sorted_indices)), sorted_positions(get_accessor(cl.sorted_positions)), sorted_batch(get_accessor(cl.sorted_batch)) { } @@ -308,7 +288,7 @@ __device__ void addNeighborsForCell(const Particle& i_atom, int j_cell delta.x * delta.x + delta.y * delta.y + delta.z * delta.z; if ((distance2 < i_atom.cutoff_upper2 and distance2 >= i_atom.cutoff_lower2) or (list.loop and cur_j == i_atom.index)) { - const int orj = cl.original_indices[cur_j]; + const int orj = cl.sorted_indices[cur_j]; addNeighborPair(list, i_atom.original_index, orj, distance2, delta); } // endif } // endif @@ -330,13 +310,13 @@ __global__ void 
traverseCellList(const CellListAccessor cell_list, if (i_atom.index >= num_atoms) { return; } - i_atom.original_index = cell_list.original_indices[i_atom.index]; + i_atom.original_index = cell_list.sorted_indices[i_atom.index]; i_atom.batch = cell_list.sorted_batch[i_atom.index]; i_atom.position = fetchPosition(cell_list.sorted_positions, i_atom.index); i_atom.cutoff_lower2 = cutoff_lower * cutoff_lower; i_atom.cutoff_upper2 = cutoff_upper * cutoff_upper; - const int3 cell_i = getCell(i_atom.position, box_size, cutoff_upper); const int3 cell_dim = getCellDimensions(box_size, cutoff_upper); + const int3 cell_i = getCell(i_atom.position, box_size, cutoff_upper, cell_dim); // Loop over the 27 cells around the current cell for (int i = 0; i < 27; i++) { const int neighbor_cell = getNeighborCellIndex(cell_i, i, cell_dim); From 5dad9a9a46318502d838a6c7b6a3b64e74798674 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Fri, 26 May 2023 13:39:22 +0200 Subject: [PATCH 70/76] Remove and/or keywords --- torchmdnet/neighbors/neighbors_cuda_cell.cuh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchmdnet/neighbors/neighbors_cuda_cell.cuh b/torchmdnet/neighbors/neighbors_cuda_cell.cuh index ae09741ab..bf052b41f 100644 --- a/torchmdnet/neighbors/neighbors_cuda_cell.cuh +++ b/torchmdnet/neighbors/neighbors_cuda_cell.cuh @@ -280,14 +280,14 @@ __device__ void addNeighborsForCell(const Particle& i_atom, int j_cell for (int cur_j = first_particle; cur_j < last_particle; cur_j++) { const auto j_batch = cl.sorted_batch[cur_j]; if ((j_batch == i_atom.batch) and - ((cur_j < i_atom.index) or (list.loop and cur_j == i_atom.index))) { + ((cur_j < i_atom.index) || (list.loop and cur_j == i_atom.index))) { const auto position_j = fetchPosition(cl.sorted_positions, cur_j); const auto delta = rect::compute_distance(i_atom.position, position_j, list.use_periodic, box_size); const scalar_t distance2 = delta.x * delta.x + delta.y * delta.y + delta.z * delta.z; - if 
((distance2 < i_atom.cutoff_upper2 and distance2 >= i_atom.cutoff_lower2) or - (list.loop and cur_j == i_atom.index)) { + if ((distance2 < i_atom.cutoff_upper2 && distance2 >= i_atom.cutoff_lower2) or + (list.loop && cur_j == i_atom.index)) { const int orj = cl.sorted_indices[cur_j]; addNeighborPair(list, i_atom.original_index, orj, distance2, delta); } // endif From c18d2c7d73322c1b4d5d65b221ead4960068d490 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Fri, 26 May 2023 13:44:55 +0200 Subject: [PATCH 71/76] Small changes to common --- torchmdnet/neighbors/common.cuh | 39 +++++++++++++++------------------ 1 file changed, 18 insertions(+), 21 deletions(-) diff --git a/torchmdnet/neighbors/common.cuh b/torchmdnet/neighbors/common.cuh index ec1a6661b..c96881595 100644 --- a/torchmdnet/neighbors/common.cuh +++ b/torchmdnet/neighbors/common.cuh @@ -93,27 +93,26 @@ template struct PairListAccessor { template __device__ void writeAtomPair(PairListAccessor& list, int i, int j, - scalar3 delta, scalar_t distance, int i_pair) { - list.neighbors[0][i_pair] = i; - list.neighbors[1][i_pair] = j; - list.deltas[i_pair][0] = delta.x; - list.deltas[i_pair][1] = delta.y; - list.deltas[i_pair][2] = delta.z; - list.distances[i_pair] = distance; + scalar3 delta, scalar_t distance, int i_pair) { + list.neighbors[0][i_pair] = i; + list.neighbors[1][i_pair] = j; + list.deltas[i_pair][0] = delta.x; + list.deltas[i_pair][1] = delta.y; + list.deltas[i_pair][2] = delta.z; + list.distances[i_pair] = distance; } template __device__ void addAtomPairToList(PairListAccessor& list, int i, int j, - scalar3 delta, scalar_t distance, - bool add_transpose) { - const int32_t i_pair = atomicAdd(&list.i_curr_pair[0], add_transpose ? 
2 : 1); - // Neighbors after the max number of pairs are ignored, although the pair is counted - if (i_pair + add_transpose < list.neighbors.size(1)) { - writeAtomPair(list, i, j, delta, distance, i_pair); - if (add_transpose) { - writeAtomPair(list, j, i, {-delta.x, -delta.y, -delta.z}, distance, i_pair + 1); + scalar3 delta, scalar_t distance, bool add_transpose) { + const int32_t i_pair = atomicAdd(&list.i_curr_pair[0], add_transpose ? 2 : 1); + // Neighbors after the max number of pairs are ignored, although the pair is counted + if (i_pair + add_transpose < list.neighbors.size(1)) { + writeAtomPair(list, i, j, delta, distance, i_pair); + if (add_transpose) { + writeAtomPair(list, j, i, {-delta.x, -delta.y, -delta.z}, distance, i_pair + 1); + } } - } } static void checkInput(const Tensor& positions, const Tensor& batch) { @@ -164,11 +163,9 @@ namespace triclinic { template struct Box { scalar_t size[3][3]; Box(const Tensor& box_vectors) { - if (box_vectors.size(0) == 3 && box_vectors.size(1) == 3) { - for (int i = 0; i < 3; i++) { - for (int j = 0; j < 3; j++) { - size[i][j] = box_vectors[i][j].item(); - } + for (int i = 0; i < 3; i++) { + for (int j = 0; j < 3; j++) { + size[i][j] = box_vectors[i][j].item(); } } } From c89b987d6e4748e50f1b3c45324b094c3cfad3ce Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Fri, 26 May 2023 13:59:01 +0200 Subject: [PATCH 72/76] Change a comment --- torchmdnet/neighbors/neighbors_cuda_cell.cuh | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torchmdnet/neighbors/neighbors_cuda_cell.cuh b/torchmdnet/neighbors/neighbors_cuda_cell.cuh index bf052b41f..bde20dce1 100644 --- a/torchmdnet/neighbors/neighbors_cuda_cell.cuh +++ b/torchmdnet/neighbors/neighbors_cuda_cell.cuh @@ -28,11 +28,12 @@ __host__ __device__ int3 getCellDimensions(scalar3 box_size, scalar_t } /* - * @brief Get the cell index of a point + * @brief Get the cell coordinates of a point * @param p The point position * @param box_size The size of 
the box in each dimension * @param cutoff The cutoff - * @return The cell index + * @param cell_dim The number of cells in each dimension + * @return The cell coordinates */ template __device__ int3 getCell(scalar3 p, scalar3 box_size, scalar_t cutoff, From cd6b6f949ba5d6eec2f0bbc1ef4b5b1001938654 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Fri, 26 May 2023 14:06:23 +0200 Subject: [PATCH 73/76] Fix bug introduced in previous commit --- torchmdnet/neighbors/common.cuh | 10 ++++++---- torchmdnet/neighbors/neighbors_cuda_brute.cuh | 2 +- torchmdnet/neighbors/neighbors_cuda_shared.cuh | 2 +- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/torchmdnet/neighbors/common.cuh b/torchmdnet/neighbors/common.cuh index c96881595..dce5e6d4d 100644 --- a/torchmdnet/neighbors/common.cuh +++ b/torchmdnet/neighbors/common.cuh @@ -162,10 +162,12 @@ __device__ auto compute_distance(scalar3 pos_i, scalar3 pos_ namespace triclinic { template struct Box { scalar_t size[3][3]; - Box(const Tensor& box_vectors) { - for (int i = 0; i < 3; i++) { - for (int j = 0; j < 3; j++) { - size[i][j] = box_vectors[i][j].item(); + Box(const Tensor& box_vectors, bool use_periodic) { + if (use_periodic) { + for (int i = 0; i < 3; i++) { + for (int j = 0; j < 3; j++) { + size[i][j] = box_vectors[i][j].item(); + } } } } diff --git a/torchmdnet/neighbors/neighbors_cuda_brute.cuh b/torchmdnet/neighbors/neighbors_cuda_brute.cuh index 963165d94..098b6a88c 100644 --- a/torchmdnet/neighbors/neighbors_cuda_brute.cuh +++ b/torchmdnet/neighbors/neighbors_cuda_brute.cuh @@ -86,7 +86,7 @@ public: std::max((num_all_pairs + num_threads - 1UL) / num_threads, 1UL); AT_DISPATCH_FLOATING_TYPES(positions.scalar_type(), "get_neighbor_pairs_forward", [&]() { PairListAccessor list_accessor(list); - triclinic::Box box(box_vectors); + triclinic::Box box(box_vectors, use_periodic); const scalar_t cutoff_upper_ = cutoff_upper.to(); const scalar_t cutoff_lower_ = cutoff_lower.to(); TORCH_CHECK(cutoff_upper_ > 0, 
"Expected \"cutoff\" to be positive"); diff --git a/torchmdnet/neighbors/neighbors_cuda_shared.cuh b/torchmdnet/neighbors/neighbors_cuda_shared.cuh index fd9693a82..aebf91cd4 100644 --- a/torchmdnet/neighbors/neighbors_cuda_shared.cuh +++ b/torchmdnet/neighbors/neighbors_cuda_shared.cuh @@ -93,7 +93,7 @@ public: positions.scalar_type(), "get_neighbor_pairs_shared_forward", [&]() { const scalar_t cutoff_upper_ = cutoff_upper.to(); const scalar_t cutoff_lower_ = cutoff_lower.to(); - triclinic::Box box(box_vectors); + triclinic::Box box(box_vectors, use_periodic); TORCH_CHECK(cutoff_upper_ > 0, "Expected \"cutoff\" to be positive"); constexpr int BLOCKSIZE = 64; const int num_blocks = std::max((num_atoms + BLOCKSIZE - 1) / BLOCKSIZE, 1); From 438274160cc19f8c253ab27e7ee308c4105a5ff0 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Fri, 26 May 2023 14:19:00 +0200 Subject: [PATCH 74/76] Revert "Simplify getCell" This reverts commit 5459db63b3c8f0334476498eac3a610c67e66a51. --- torchmdnet/neighbors/neighbors_cuda_cell.cuh | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/torchmdnet/neighbors/neighbors_cuda_cell.cuh b/torchmdnet/neighbors/neighbors_cuda_cell.cuh index bde20dce1..e0b49ed45 100644 --- a/torchmdnet/neighbors/neighbors_cuda_cell.cuh +++ b/torchmdnet/neighbors/neighbors_cuda_cell.cuh @@ -38,9 +38,19 @@ __host__ __device__ int3 getCellDimensions(scalar3 box_size, scalar_t template __device__ int3 getCell(scalar3 p, scalar3 box_size, scalar_t cutoff, int3 cell_dim) { - const int cx = fmodf(floorf((p.x + scalar_t(0.5) * box_size.x) / cutoff), cell_dim.x); - const int cy = fmodf(floorf((p.y + scalar_t(0.5) * box_size.y) / cutoff), cell_dim.y); - const int cz = fmodf(floorf((p.z + scalar_t(0.5) * box_size.z) / cutoff), cell_dim.z); + p = rect::apply_pbc(p, box_size); + // Take to the [0, box_size] range and divide by cutoff (which is the cell size) + int cx = floorf((p.x + scalar_t(0.5) * box_size.x) / cutoff); + int cy = 
floorf((p.y + scalar_t(0.5) * box_size.y) / cutoff); + int cz = floorf((p.z + scalar_t(0.5) * box_size.z) / cutoff); + // Wrap around. If the position of a particle is exactly box_size, it will be in the last cell, + // which results in an illegal access down the line. + if (cx == cell_dim.x) + cx = 0; + if (cy == cell_dim.y) + cy = 0; + if (cz == cell_dim.z) + cz = 0; return make_int3(cx, cy, cz); } From fb554c935bf075734d9d7fcc2fe3baaab3dab867 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Fri, 26 May 2023 14:20:18 +0200 Subject: [PATCH 75/76] Move CoM of particles in neighbor tests outside of the main box to check PBC correctness --- tests/test_neighbors.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_neighbors.py b/tests/test_neighbors.py index d017f0904..848489a46 100644 --- a/tests/test_neighbors.py +++ b/tests/test_neighbors.py @@ -89,7 +89,7 @@ def test_neighbors( ).to(device) cumsum = np.cumsum(np.concatenate([[0], n_atoms_per_batch])) lbox = 10.0 - pos = torch.rand(cumsum[-1], 3, device=device, dtype=dtype) * lbox + pos = torch.rand(cumsum[-1], 3, device=device, dtype=dtype) * lbox - 10.0*lbox # Ensure there is at least one pair pos[0, :] = torch.zeros(3) pos[1, :] = torch.zeros(3) From 605a4b5baa2faa1c5b2ee0e736bacdf8ed31b8d3 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Mon, 29 May 2023 11:21:27 +0200 Subject: [PATCH 76/76] Remove commented-out code --- benchmarks/neighbors.py | 24 ------------------------ 1 file changed, 24 deletions(-) diff --git a/benchmarks/neighbors.py b/benchmarks/neighbors.py index 9904ad39b..a9578283a 100644 --- a/benchmarks/neighbors.py +++ b/benchmarks/neighbors.py @@ -210,27 +210,3 @@ def benchmark_neighbors( results["distance", n_particles], ) ) - - # # Print a second table showing time per atom, show in ns - # print("\n") - # print("Time per atom") - # print( - # "{:<10} {:<10} {:<10} {:<10} {:<10}".format( - # "Batch size", "Shared(ns)", "Brute(ns)", "Cell(ns)", "Distance(ns)" - # ) - # 
) - # print( - # "{:<10} {:<10} {:<10} {:<10} {:<10}".format( - # "----------", "---------", "---------", "---------", "---------" - # ) - # ) - # for n_batches in batch_sizes: - # print( - # "{:<10} {:<10.2f} {:<10.2f} {:<10.2f} {:<10.2f}".format( - # n_batches, - # results["shared", n_batches] / n_particles * 1e6, - # results["brute", n_batches] / n_particles * 1e6, - # results["cell", n_batches] / n_particles * 1e6, - # results["distance", n_batches] / n_particles * 1e6, - # ) - # )