def benchmark_neighbors(
    device, strategy, n_batches, total_num_particles, mean_num_neighbors, density
):
    """Benchmark the neighbor list generation.

    Random positions are generated for ``n_batches`` systems whose sizes are
    drawn at random and then adjusted so that they add up to exactly
    ``total_num_particles``. The neighbor list is then evaluated repeatedly
    and timed with CUDA events.

    Parameters
    ----------
    device : str
        Device to use for the benchmark.
    strategy : str
        Strategy to use for the neighbor list generation. ``"distance"``
        benchmarks the ``Distance`` module; any other value (e.g. ``"cell"``,
        ``"brute"``, ``"shared"``) is forwarded to ``OptimizedDistance`` and
        timed through a CUDA graph replay.
    n_batches : int
        Number of batches to generate.
    total_num_particles : int
        Total number of particles.
    mean_num_neighbors : int
        Mean number of neighbors per particle.
    density : float
        Density of the system.

    Returns
    -------
    float
        Average time per neighbor list evaluation, in milliseconds
        (``torch.cuda.Event.elapsed_time`` reports milliseconds).
    """
    # Fixed seeds so that every strategy is timed on the same systems.
    torch.random.manual_seed(12344)
    np.random.seed(43211)
    num_particles = total_num_particles // n_batches
    expected_num_neighbors = mean_num_neighbors
    # Radius of the sphere that contains, on average, the requested number of
    # neighbors at the given density.
    cutoff = np.cbrt(3 * expected_num_neighbors / (4 * np.pi * density))
    n_atoms_per_batch = torch.randint(
        int(num_particles / 2), int(num_particles * 2), size=(n_batches,), device="cpu"
    )
    # Fix so that the total number of particles is correct.
    difference = int(total_num_particles) - int(n_atoms_per_batch.sum())
    while difference > 0:
        i = np.random.randint(0, n_batches)
        n_atoms_per_batch[i] += 1
        difference -= 1
    # Special care if the difference is negative: only shrink batches that
    # are still above the mean size.
    while difference < 0:
        i = np.random.randint(0, n_batches)
        if n_atoms_per_batch[i] > num_particles:
            n_atoms_per_batch[i] -= 1
            difference += 1
    lbox = np.cbrt(num_particles / density)
    batch = torch.repeat_interleave(
        torch.arange(n_batches, dtype=torch.int64), n_atoms_per_batch
    ).to(device)
    num_atoms = int(n_atoms_per_batch.sum())
    pos = torch.rand(num_atoms, 3, device="cpu").to(device) * lbox
    if strategy != "distance":
        # Twice the expected pair count, as a plain Python int.
        max_num_pairs = int(expected_num_neighbors * num_atoms * 2)
        box = torch.eye(3, device=device) * lbox
        # check_errors and resize_to_fit are disabled so that the module can
        # be captured in a CUDA graph below.
        nl = OptimizedDistance(
            cutoff_upper=cutoff,
            max_num_pairs=max_num_pairs,
            strategy=strategy,
            box=box,
            loop=False,
            include_transpose=True,
            check_errors=False,
            resize_to_fit=False,
        )
    else:
        max_num_neighbors = int(expected_num_neighbors * 5)
        nl = Distance(
            loop=False,
            cutoff_lower=0.0,
            cutoff_upper=cutoff,
            max_num_neighbors=max_num_neighbors,
        )
    # Warmup on a side stream, before any graph capture
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(10):
            neighbors, distances, distance_vecs = nl(pos, batch)
    torch.cuda.synchronize()
    nruns = 50

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    graph = torch.cuda.CUDAGraph()
    if strategy != "distance":
        # Record a single evaluation in a CUDA graph and time its replays
        with torch.cuda.graph(graph):
            neighbors, distances, distance_vecs = nl(pos, batch)
        start.record()
        for _ in range(nruns):
            graph.replay()
        end.record()
    else:
        start.record()
        for _ in range(nruns):
            neighbors, distances, distance_vecs = nl(pos, batch)
        end.record()
    torch.cuda.synchronize()
    # Final time (milliseconds per evaluation)
    return start.elapsed_time(end) / nruns
n_batches=n_batches, + total_num_particles=n_particles, + mean_num_neighbors=mean_num_neighbors, + density=density, + ) + # Store results in a dictionary + results[strategy, n_particles] = time + print("Summary") + print("-------") + print( + "{:<10} {:<21} {:<21} {:<18} {:<10}".format( + "N Particles", "Shared(ms)", "Brute(ms)", "Cell(ms)", "Distance(ms)" + ) + ) + print( + "{:<10} {:<21} {:<21} {:<18} {:<10}".format( + "----------", "---------", "---------", "---------", "---------" + ) + ) + # Print a column per strategy, show speedup over Distance in parenthesis + for n_particles in n_particles_list: + base = results["distance", n_particles] + brute_speedup = base / results["brute", n_particles] + if n_particles > 32000: + results["brute", n_particles] = 0 + brute_speedup = 0 + print( + "{:<10} {:<4.2f} x{:<14.2f} {:<4.2f} x{:<14.2f} {:<4.2f} x{:<14.2f} {:<10.2f}".format( + n_particles, + results["shared", n_particles], + base / results["shared", n_particles], + results["brute", n_particles], + brute_speedup, + results["cell", n_particles], + base / results["cell", n_particles], + results["distance", n_particles], + ) + ) diff --git a/environment.yml b/environment.yml index 610d7c044..aeac149e3 100644 --- a/environment.yml +++ b/environment.yml @@ -18,3 +18,4 @@ dependencies: - flake8 - pytest - psutil + - ninja diff --git a/setup.py b/setup.py index 3ace81d3f..2fd821ed0 100644 --- a/setup.py +++ b/setup.py @@ -15,5 +15,7 @@ name="torchmd-net", version=version, packages=find_packages(), + package_data={"torchmdnet": ["neighbors/neighbors*", "neighbors/*.cu*"]}, + include_package_data=True, entry_points={"console_scripts": ["torchmd-train = torchmdnet.scripts.train:main"]}, ) diff --git a/tests/test_neighbors.py b/tests/test_neighbors.py new file mode 100644 index 000000000..848489a46 --- /dev/null +++ b/tests/test_neighbors.py @@ -0,0 +1,532 @@ +import os +import pytest +import torch +import torch.jit +import numpy as np +from torchmdnet.models.utils import 
def sort_neighbors(neighbors, deltas, distances):
    """Return the pair list and its per-pair data in a canonical order.

    Pairs are ordered with ``np.lexsort``, i.e. by the second row of
    ``neighbors`` first and by the first row to break ties, so that two
    neighbor lists holding the same pairs can be compared element-wise.

    Parameters
    ----------
    neighbors : np.ndarray
        Pair indices, shape (2, num_pairs).
    deltas : np.ndarray
        Per-pair vectors, indexed along the first axis.
    distances : np.ndarray
        Per-pair distances, indexed along the first axis.

    Returns
    -------
    tuple
        ``(neighbors, deltas, distances)`` reordered consistently.
    """
    i_sorted = np.lexsort(neighbors)
    return neighbors[:, i_sorted], deltas[i_sorted], distances[i_sorted]


def apply_pbc(deltas, box_vectors):
    """Wrap displacement vectors into the periodic box.

    Parameters
    ----------
    deltas : np.ndarray
        Displacement vectors, shape (num_pairs, 3). The input is left
        untouched; the wrapped vectors are returned in a new array.
    box_vectors : torch.Tensor or None
        Box vectors, indexed as ``box_vectors[i, j]``. If None, no wrapping
        is applied and ``deltas`` is returned as is.

    Returns
    -------
    np.ndarray
        The wrapped displacement vectors.
    """
    if box_vectors is None:
        return deltas
    ref_vectors = box_vectors.cpu().detach().numpy()
    # Work on a copy so the caller's array is not modified behind its back.
    wrapped = deltas.copy()
    # Reduce along the third, second and first box vector, in that order
    # (each step uses the components already updated by the previous one).
    for axis in (2, 1, 0):
        wrapped -= np.outer(
            np.round(wrapped[:, axis] / ref_vectors[axis, axis]), ref_vectors[axis]
        )
    return wrapped


def compute_ref_neighbors(pos, batch, loop, include_transpose, cutoff, box_vectors):
    """Build a reference neighbor list by enumerating every pair in each batch.

    Parameters
    ----------
    pos : torch.Tensor
        Positions, shape (N, 3).
    batch : torch.Tensor
        Batch index of each atom, shape (N,). Atoms of the same batch are
        expected to be contiguous — the index offsets are derived from
        ``torch.bincount(batch)``.
    loop : bool
        Whether to include self-interactions (i, i).
    include_transpose : bool
        Whether to include (j, i) in addition to (i, j).
    cutoff : float
        Only pairs with distance strictly below this value are kept.
    box_vectors : torch.Tensor or None
        Periodic box, forwarded to ``apply_pbc``.

    Returns
    -------
    tuple
        ``(neighbors, distance_vecs, distances)`` as numpy arrays, sorted
        with ``sort_neighbors``.
    """
    batch = batch.cpu()
    n_atoms_per_batch = torch.bincount(batch)
    n_batches = n_atoms_per_batch.shape[0]
    # Offset of the first atom of each batch in the flattened arrays
    cumsum = (
        torch.cumsum(torch.cat([torch.tensor([0]), n_atoms_per_batch]), dim=0)
        .cpu()
        .detach()
        .numpy()
    )
    # Strict lower triangle of each batch (i > j), shifted to global indices
    ref_neighbors = np.concatenate(
        [
            np.stack(np.tril_indices(int(n_atoms_per_batch[i]), -1)) + cumsum[i]
            for i in range(n_batches)
        ],
        axis=1,
    )
    # add the upper triangle
    if include_transpose:
        ref_neighbors = np.concatenate(
            [ref_neighbors, np.flip(ref_neighbors, axis=0)], axis=1
        )
    if loop:  # Add self interactions
        ilist = np.arange(cumsum[-1])
        ref_neighbors = np.concatenate(
            [ref_neighbors, np.stack([ilist, ilist], axis=0)], axis=1
        )
    pos_np = pos.cpu().detach().numpy()
    ref_distance_vecs = apply_pbc(
        pos_np[ref_neighbors[0]] - pos_np[ref_neighbors[1]], box_vectors
    )
    ref_distances = np.linalg.norm(ref_distance_vecs, axis=-1)

    # keep only pairs with distance strictly below the cutoff
    mask = ref_distances < cutoff
    ref_neighbors = ref_neighbors[:, mask]
    ref_distance_vecs = ref_distance_vecs[mask]
    ref_distances = ref_distances[mask]
    ref_neighbors, ref_distance_vecs, ref_distances = sort_neighbors(
        ref_neighbors, ref_distance_vecs, ref_distances
    )
    return ref_neighbors, ref_distance_vecs, ref_distances
"strategy"), [("cpu", "brute"), ("cuda", "brute"), ("cuda", "shared"), ("cuda", "cell")]) +@pytest.mark.parametrize("n_batches", [1, 2, 3, 4, 128]) +@pytest.mark.parametrize("cutoff", [0.1, 1.0, 3.0, 4.9]) +@pytest.mark.parametrize("loop", [True, False]) +@pytest.mark.parametrize("include_transpose", [True, False]) +@pytest.mark.parametrize("box_type", [None, "triclinic", "rectangular"]) +@pytest.mark.parametrize("dtype", [torch.float32, torch.float64]) +def test_neighbors( + device, strategy, n_batches, cutoff, loop, include_transpose, box_type, dtype +): + if device == "cuda" and not torch.cuda.is_available(): + pytest.skip("CUDA not available") + if box_type == "triclinic" and strategy == "cell": + pytest.skip("Triclinic not supported for cell") + if device == "cpu" and strategy != "brute": + pytest.skip("Only brute force supported on CPU") + torch.manual_seed(4321) + n_atoms_per_batch = torch.randint(3, 100, size=(n_batches,)) + batch = torch.repeat_interleave( + torch.arange(n_batches, dtype=torch.int64), n_atoms_per_batch + ).to(device) + cumsum = np.cumsum(np.concatenate([[0], n_atoms_per_batch])) + lbox = 10.0 + pos = torch.rand(cumsum[-1], 3, device=device, dtype=dtype) * lbox - 10.0*lbox + # Ensure there is at least one pair + pos[0, :] = torch.zeros(3) + pos[1, :] = torch.zeros(3) + pos.requires_grad = True + if box_type is None: + box = None + elif box_type == "rectangular": + box = ( + torch.tensor([[lbox, 0.0, 0.0], [0.0, lbox, 0.0], [0.0, 0.0, lbox]]) + .to(pos.dtype) + .to(device) + ) + elif box_type == "triclinic": + box = ( + torch.tensor([[lbox, 0.0, 0.0], [0.1, lbox, 0.0], [0.3, 0.2, lbox]]) + .to(pos.dtype) + .to(device) + ) + ref_neighbors, ref_distance_vecs, ref_distances = compute_ref_neighbors( + pos, batch, loop, include_transpose, cutoff, box + ) + max_num_pairs = ref_neighbors.shape[1] + nl = OptimizedDistance( + cutoff_lower=0.0, + loop=loop, + cutoff_upper=cutoff, + max_num_pairs=max_num_pairs, + strategy=strategy, + box=box, + 
return_vecs=True, + include_transpose=include_transpose, + ) + batch.to(device) + neighbors, distances, distance_vecs = nl(pos, batch) + neighbors = neighbors.cpu().detach().numpy() + distance_vecs = distance_vecs.cpu().detach().numpy() + distances = distances.cpu().detach().numpy() + neighbors, distance_vecs, distances = sort_neighbors( + neighbors, distance_vecs, distances + ) + assert neighbors.shape == (2, max_num_pairs) + assert distances.shape == (max_num_pairs,) + assert distance_vecs.shape == (max_num_pairs, 3) + + assert np.allclose(neighbors, ref_neighbors) + assert np.allclose(distances, ref_distances) + assert np.allclose(distance_vecs, ref_distance_vecs) + + +@pytest.mark.parametrize(("device", "strategy"), [("cpu", "brute"), ("cuda", "brute"), ("cuda", "shared"), ("cuda", "cell")]) +@pytest.mark.parametrize("n_batches", [1, 2, 3, 4]) +@pytest.mark.parametrize("cutoff", [0.1, 1.0, 1000.0]) +@pytest.mark.parametrize("loop", [True, False]) +@pytest.mark.parametrize("dtype", [torch.float32, torch.float64]) +def test_compatible_with_distance(device, strategy, n_batches, cutoff, loop, dtype): + if device == "cuda" and not torch.cuda.is_available(): + pytest.skip("CUDA not available") + if device == "cpu" and strategy != "brute": + pytest.skip("Only brute force supported on CPU") + + torch.manual_seed(4321) + n_atoms_per_batch = torch.randint(3, 100, size=(n_batches,)) + batch = torch.repeat_interleave( + torch.arange(n_batches, dtype=torch.long), n_atoms_per_batch + ).to(device) + cumsum = np.cumsum(np.concatenate([[0], n_atoms_per_batch])) + lbox = 10.0 + pos = torch.rand(cumsum[-1], 3, device=device, dtype=dtype) * lbox + # Ensure there is at least one pair + pos[0, :] = torch.zeros(3) + pos[1, :] = torch.zeros(3) + pos.requires_grad = True + ref_neighbors, ref_distance_vecs, ref_distances = compute_ref_neighbors( + pos, batch, loop, True, cutoff, None + ) + # Find the particle appearing in the most pairs + max_num_neighbors = 
torch.max(torch.bincount(torch.tensor(ref_neighbors[0, :]))) + d = Distance( + cutoff_lower=0.0, + cutoff_upper=cutoff, + loop=loop, + max_num_neighbors=max_num_neighbors, + return_vecs=True, + ) + ref_neighbors, ref_distances, ref_distance_vecs = d(pos, batch) + ref_neighbors = ref_neighbors.cpu().detach().numpy() + ref_distance_vecs = ref_distance_vecs.cpu().detach().numpy() + ref_distances = ref_distances.cpu().detach().numpy() + ref_neighbors, ref_distance_vecs, ref_distances = sort_neighbors( + ref_neighbors, ref_distance_vecs, ref_distances + ) + + max_num_pairs = ref_neighbors.shape[1] + box = None + nl = OptimizedDistance( + cutoff_lower=0.0, + loop=loop, + cutoff_upper=cutoff, + max_num_pairs=max_num_pairs, + strategy=strategy, + box=box, + return_vecs=True, + include_transpose=True, + ) + neighbors, distances, distance_vecs = nl(pos, batch) + neighbors = neighbors.cpu().detach().numpy() + distance_vecs = distance_vecs.cpu().detach().numpy() + distances = distances.cpu().detach().numpy() + neighbors, distance_vecs, distances = sort_neighbors( + neighbors, distance_vecs, distances + ) + assert np.allclose(neighbors, ref_neighbors) + assert np.allclose(distances, ref_distances) + assert np.allclose(distance_vecs, ref_distance_vecs) + + +@pytest.mark.parametrize("strategy", ["brute", "cell", "shared"]) +@pytest.mark.parametrize("n_batches", [1, 2, 3, 4]) +def test_large_size(strategy, n_batches): + device = "cuda" + cutoff = 1.76 + loop = False + if device == "cuda" and not torch.cuda.is_available(): + pytest.skip("CUDA not available") + torch.manual_seed(4321) + num_atoms = int(32000 / n_batches) + n_atoms_per_batch = torch.ones(n_batches, dtype=torch.int64) * num_atoms + batch = torch.repeat_interleave( + torch.arange(n_batches, dtype=torch.int64), n_atoms_per_batch + ).to(device) + cumsum = np.cumsum(np.concatenate([[0], n_atoms_per_batch])) + lbox = 45.0 + pos = torch.rand(cumsum[-1], 3, device=device) * lbox + # Ensure there is at least one pair + pos[0, 
:] = torch.zeros(3) + pos[1, :] = torch.zeros(3) + pos.requires_grad = True + # Find the particle appearing in the most pairs + max_num_neighbors = 64 + d = Distance( + cutoff_lower=0.0, + cutoff_upper=cutoff, + loop=loop, + max_num_neighbors=max_num_neighbors, + return_vecs=True, + ) + ref_neighbors, ref_distances, ref_distance_vecs = d(pos, batch) + ref_neighbors = ref_neighbors.cpu().detach().numpy() + ref_distance_vecs = ref_distance_vecs.cpu().detach().numpy() + ref_distances = ref_distances.cpu().detach().numpy() + ref_neighbors, ref_distance_vecs, ref_distances = sort_neighbors( + ref_neighbors, ref_distance_vecs, ref_distances + ) + + max_num_pairs = ref_neighbors.shape[1] + + # Must check without PBC since Distance does not support it + box = None + nl = OptimizedDistance( + cutoff_lower=0.0, + loop=loop, + cutoff_upper=cutoff, + max_num_pairs=max_num_pairs, + strategy=strategy, + box=box, + return_vecs=True, + include_transpose=True, + resize_to_fit=True, + ) + neighbors, distances, distance_vecs = nl(pos, batch) + neighbors = neighbors.cpu().detach().numpy() + distance_vecs = distance_vecs.cpu().detach().numpy() + distances = distances.cpu().detach().numpy() + neighbors, distance_vecs, distances = sort_neighbors( + neighbors, distance_vecs, distances + ) + assert np.allclose(neighbors, ref_neighbors) + assert np.allclose(distances, ref_distances) + assert np.allclose(distance_vecs, ref_distance_vecs) + + +@pytest.mark.parametrize(("device", "strategy"), [("cpu", "brute"), ("cuda", "brute"), ("cuda", "shared"), ("cuda", "cell")]) +@pytest.mark.parametrize("loop", [True, False]) +@pytest.mark.parametrize("include_transpose", [True, False]) +@pytest.mark.parametrize("dtype", [torch.float32, torch.float64]) +@pytest.mark.parametrize("num_atoms", [1, 2, 3, 5, 100, 1000]) +@pytest.mark.parametrize("grad", ["deltas", "distances", "combined"]) +@pytest.mark.parametrize("box_type", [None, "triclinic", "rectangular"]) +def test_neighbor_grads( + device, strategy, 
loop, include_transpose, dtype, num_atoms, grad, box_type +): + if not torch.cuda.is_available() and device == "cuda": + pytest.skip("No GPU") + if device == "cpu" and strategy != "brute": + pytest.skip("Only brute force supported on CPU") + if box_type == "triclinic" and strategy == "cell": + pytest.skip("Triclinic only supported for brute force") + cutoff = 4.999999 + lbox = 10.0 + torch.random.manual_seed(1234) + np.random.seed(123456) + # Generate random positions + positions = 0.25 * lbox * torch.rand(num_atoms, 3, device=device, dtype=dtype) + if box_type is None: + box = None + else: + box = ( + torch.tensor([[lbox, 0.0, 0.0], [0.0, lbox, 0.0], [0.0, 0.0, lbox]]) + .to(dtype) + .to(device) + ) + # Compute reference values using pure pytorch + ref_neighbors = torch.vstack( + (torch.tril_indices(num_atoms, num_atoms, -1, device=device),) + ) + if include_transpose: + ref_neighbors = torch.hstack( + (ref_neighbors, torch.stack((ref_neighbors[1], ref_neighbors[0]))) + ) + if loop: + index = torch.arange(num_atoms, device=device) + ref_neighbors = torch.hstack((ref_neighbors, torch.stack((index, index)))) + ref_positions = positions.clone() + ref_positions.requires_grad_(True) + # Every pair is included, so there is no need to filter out pairs even after PBC + ref_deltas = ref_positions[ref_neighbors[0]] - ref_positions[ref_neighbors[1]] + if box is not None: + ref_box = box.clone() + ref_deltas -= torch.outer( + torch.round(ref_deltas[:, 2] / ref_box[2, 2]), ref_box[2] + ) + ref_deltas -= torch.outer( + torch.round(ref_deltas[:, 1] / ref_box[1, 1]), ref_box[1] + ) + ref_deltas -= torch.outer( + torch.round(ref_deltas[:, 0] / ref_box[0, 0]), ref_box[0] + ) + + if loop: + ref_distances = torch.zeros((ref_deltas.size(0),), device=device, dtype=dtype) + mask = ref_neighbors[0] != ref_neighbors[1] + ref_distances[mask] = torch.linalg.norm(ref_deltas[mask], dim=-1) + else: + ref_distances = torch.linalg.norm(ref_deltas, dim=-1) + max_num_pairs = 
max(ref_neighbors.shape[1], 1) + positions.requires_grad_(True) + nl = OptimizedDistance( + cutoff_upper=cutoff, + max_num_pairs=max_num_pairs, + strategy=strategy, + loop=loop, + include_transpose=include_transpose, + return_vecs=True, + resize_to_fit=True, + box=box, + ) + neighbors, distances, deltas = nl(positions) + # Check neighbor pairs are correct + ref_neighbors_sort, _, _ = sort_neighbors( + ref_neighbors.clone().cpu().detach().numpy(), + ref_deltas.clone().cpu().detach().numpy(), + ref_distances.clone().cpu().detach().numpy(), + ) + neighbors_sort, _, _ = sort_neighbors( + neighbors.clone().cpu().detach().numpy(), + deltas.clone().cpu().detach().numpy(), + distances.clone().cpu().detach().numpy(), + ) + assert np.allclose(ref_neighbors_sort, neighbors_sort) + + # Compute gradients + if grad == "deltas": + ref_deltas.sum().backward() + deltas.sum().backward() + elif grad == "distances": + ref_distances.sum().backward() + distances.sum().backward() + elif grad == "combined": + (ref_deltas.sum() + ref_distances.sum()).backward() + (deltas.sum() + distances.sum()).backward() + else: + raise ValueError("grad") + ref_pos_grad_sorted = ref_positions.grad.cpu().detach().numpy() + pos_grad_sorted = positions.grad.cpu().detach().numpy() + if dtype == torch.float32: + assert np.allclose(ref_pos_grad_sorted, pos_grad_sorted, atol=1e-2, rtol=1e-2) + else: + assert np.allclose(ref_pos_grad_sorted, pos_grad_sorted, atol=1e-8, rtol=1e-5) + + +@pytest.mark.parametrize(("device", "strategy"), [("cpu", "brute"), ("cuda", "brute"), ("cuda", "shared"), ("cuda", "cell")]) +@pytest.mark.parametrize("n_batches", [1, 128]) +@pytest.mark.parametrize("cutoff", [1.0]) +@pytest.mark.parametrize("loop", [True, False]) +@pytest.mark.parametrize("include_transpose", [True, False]) +@pytest.mark.parametrize("box_type", [None, "triclinic", "rectangular"]) +@pytest.mark.parametrize("dtype", [torch.float32]) +def test_jit_script_compatible( + device, strategy, n_batches, cutoff, loop, 
include_transpose, box_type, dtype +): + if device == "cuda" and not torch.cuda.is_available(): + pytest.skip("CUDA not available") + if box_type == "triclinic" and strategy == "cell": + pytest.skip("Triclinic only supported for brute force") + if device == "cpu" and strategy != "brute": + pytest.skip("Only brute force supported on CPU") + torch.manual_seed(4321) + n_atoms_per_batch = torch.randint(3, 100, size=(n_batches,)) + batch = torch.repeat_interleave( + torch.arange(n_batches, dtype=torch.int64), n_atoms_per_batch + ).to(device) + cumsum = np.cumsum(np.concatenate([[0], n_atoms_per_batch])) + lbox = 10.0 + pos = torch.rand(cumsum[-1], 3, device=device, dtype=dtype) * lbox + # Ensure there is at least one pair + pos[0, :] = torch.zeros(3) + pos[1, :] = torch.zeros(3) + pos.requires_grad = True + if box_type is None: + box = None + else: + box = ( + torch.tensor([[lbox, 0.0, 0.0], [0.0, lbox, 0.0], [0.0, 0.0, lbox]]) + .to(pos.dtype) + .to(device) + ) + ref_neighbors, ref_distance_vecs, ref_distances = compute_ref_neighbors( + pos, batch, loop, include_transpose, cutoff, box + ) + max_num_pairs = ref_neighbors.shape[1] + nl = OptimizedDistance( + cutoff_lower=0.0, + loop=loop, + cutoff_upper=cutoff, + max_num_pairs=max_num_pairs, + strategy=strategy, + box=box, + return_vecs=True, + include_transpose=include_transpose, + ) + batch.to(device) + + nl = torch.jit.script(nl) + neighbors, distances, distance_vecs = nl(pos, batch) + neighbors = neighbors.cpu().detach().numpy() + distance_vecs = distance_vecs.cpu().detach().numpy() + distances = distances.cpu().detach().numpy() + neighbors, distance_vecs, distances = sort_neighbors( + neighbors, distance_vecs, distances + ) + assert neighbors.shape == (2, max_num_pairs) + assert distances.shape == (max_num_pairs,) + assert distance_vecs.shape == (max_num_pairs, 3) + + assert np.allclose(neighbors, ref_neighbors) + assert np.allclose(distances, ref_distances) + assert np.allclose(distance_vecs, ref_distance_vecs) + + 
+@pytest.mark.parametrize("device", ["cuda"]) +@pytest.mark.parametrize("strategy", ["brute", "shared", "cell"]) +@pytest.mark.parametrize("n_batches", [1, 128]) +@pytest.mark.parametrize("cutoff", [1.0]) +@pytest.mark.parametrize("loop", [True, False]) +@pytest.mark.parametrize("include_transpose", [True, False]) +@pytest.mark.parametrize("box_type", [None, "triclinic", "rectangular"]) +@pytest.mark.parametrize("dtype", [torch.float32]) +def test_cuda_graph_compatible( + device, strategy, n_batches, cutoff, loop, include_transpose, box_type, dtype +): + if device == "cuda" and not torch.cuda.is_available(): + pytest.skip("CUDA not available") + if box_type == "triclinic" and strategy == "cell": + pytest.skip("Triclinic only supported for brute force") + torch.manual_seed(4321) + n_atoms_per_batch = torch.randint(3, 100, size=(n_batches,)) + batch = torch.repeat_interleave( + torch.arange(n_batches, dtype=torch.int64), n_atoms_per_batch + ).to(device) + cumsum = np.cumsum(np.concatenate([[0], n_atoms_per_batch])) + lbox = 10.0 + pos = torch.rand(cumsum[-1], 3, device=device, dtype=dtype) * lbox + # Ensure there is at least one pair + pos[0, :] = torch.zeros(3) + pos[1, :] = torch.zeros(3) + pos.requires_grad = True + if box_type is None: + box = None + else: + box = ( + torch.tensor([[lbox, 0.0, 0.0], [0.0, lbox, 0.0], [0.0, 0.0, lbox]]) + .to(pos.dtype) + .to(device) + ) + ref_neighbors, ref_distance_vecs, ref_distances = compute_ref_neighbors( + pos, batch, loop, include_transpose, cutoff, box + ) + max_num_pairs = ref_neighbors.shape[1] + nl = OptimizedDistance( + cutoff_lower=0.0, + loop=loop, + cutoff_upper=cutoff, + max_num_pairs=max_num_pairs, + strategy=strategy, + box=box, + return_vecs=True, + include_transpose=include_transpose, + check_errors=False, + resize_to_fit=False, + ) + batch.to(device) + + graph = torch.cuda.CUDAGraph() + s = torch.cuda.Stream() + s.wait_stream(torch.cuda.current_stream()) + # Warm up + with torch.cuda.stream(s): + for _ in 
class OptimizedDistance(torch.nn.Module):
    """Compute the neighbor list for a given cutoff.

    This operation can be placed inside a CUDA graph in some cases.
    In particular, resize_to_fit and check_errors must be False.
    Note that this module returns neighbors such that
    distance(i,j) >= cutoff_lower and distance(i,j) < cutoff_upper.

    This function optionally supports periodic boundary conditions with
    arbitrary triclinic boxes. The box vectors `a`, `b`, and `c` must satisfy
    certain requirements:

    `a[1] = a[2] = b[2] = 0`
    `a[0] >= 2*cutoff, b[1] >= 2*cutoff, c[2] >= 2*cutoff`
    `a[0] >= 2*b[0]`
    `a[0] >= 2*c[0]`
    `b[1] >= 2*c[1]`

    These requirements correspond to a particular rotation of the system and
    reduced form of the vectors, as well as the requirement that the cutoff be
    no larger than half the box width.

    Parameters
    ----------
    cutoff_lower : float
        Lower cutoff for the neighbor list.
    cutoff_upper : float
        Upper cutoff for the neighbor list.
    max_num_pairs : int
        Maximum number of pairs to store, if the number of pairs found is less
        than this, the list is padded with (-1,-1) pairs up to max_num_pairs
        unless resize_to_fit is True, in which case the list is resized to the
        actual number of pairs found.
        If the number of pairs found is larger than this, the pairs are
        randomly sampled. When check_errors is True, an exception is raised in
        this case.
        If negative, it is interpreted as (minus) the maximum number of
        neighbors per atom.
    return_vecs : bool, optional
        Whether to return the distance vectors.
        Default: False
    loop : bool, optional
        Whether to include self-interactions.
        Default: False
    strategy : str
        Strategy to use for computing the neighbor list. Can be one of
        ["shared", "brute", "cell"].
        Shared: An O(N^2) algorithm that leverages CUDA shared memory, best
        for large number of particles.
        Brute: A brute force O(N^2) algorithm, best for small number of
        particles.
        Cell: A cell list algorithm, best for large number of particles, low
        cutoffs and low batch size.
    include_transpose : bool, optional
        Whether to include the transpose of the neighbor list.
        Default: True
    resize_to_fit : bool, optional
        Whether to resize the neighbor list to the actual number of pairs
        found. When False, the list is padded with (-1,-1) pairs up to
        max_num_pairs. If this is True the operation is not CUDA graph
        compatible.
        Default: True
    check_errors : bool, optional
        Whether to check for too many pairs. If this is True the operation is
        not CUDA graph compatible.
        Default: True
    box : torch.Tensor, optional
        The vectors defining the periodic box. This must have shape `(3, 3)`,
        where `box_vectors[0] = a`, `box_vectors[1] = b`, and
        `box_vectors[2] = c`.
        If this is omitted, periodic boundary conditions are not applied.

    Raises
    ------
    ValueError
        If `strategy` is not one of the available backends.
    """

    def __init__(
        self,
        cutoff_lower=0.0,
        cutoff_upper=5.0,
        max_num_pairs=-32,
        return_vecs=False,
        loop=False,
        strategy="brute",
        include_transpose=True,
        resize_to_fit=True,
        check_errors=True,
        box=None,
    ):
        super(OptimizedDistance, self).__init__()
        self.cutoff_upper = cutoff_upper
        self.cutoff_lower = cutoff_lower
        self.max_num_pairs = max_num_pairs
        self.strategy = strategy
        self.box: Optional[Tensor] = box
        self.loop = loop
        self.return_vecs = return_vecs
        self.include_transpose = include_transpose
        self.resize_to_fit = resize_to_fit
        self.use_periodic = True
        if self.box is None:
            self.use_periodic = False
            # Placeholder tensor so that the kernel always receives a box
            self.box = torch.empty((0, 0))
            if self.strategy == "cell":
                # Default the box to 3 times the cutoff, really inefficient for the cell list
                lbox = cutoff_upper * 3.0
                self.box = torch.tensor([[lbox, 0, 0], [0, lbox, 0], [0, 0, lbox]])
        self.box = self.box.cpu()  # All strategies expect the box to be in CPU memory
        self._backends = neighbors.get_backends()
        # Use .get so that an unknown strategy reaches the ValueError below
        # instead of escaping as a bare KeyError from the dictionary lookup.
        self.kernel = self._backends.get(self.strategy)
        if self.kernel is None:
            raise ValueError("Unknown strategy: {}".format(self.strategy))
        self.check_errors = check_errors

    def forward(
        self, pos: Tensor, batch: Optional[Tensor] = None
    ) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
        """Compute the neighbor list for a given cutoff.

        Parameters
        ----------
        pos : torch.Tensor
            shape (N, 3)
        batch : torch.Tensor or None
            shape (N,). If None, all atoms are assigned to batch 0.

        Returns
        -------
        neighbors : torch.Tensor
            List of neighbors for each atom in the batch.
            shape (2, num_found_pairs or max_num_pairs)
        distances : torch.Tensor
            List of distances for each atom in the batch.
            shape (num_found_pairs or max_num_pairs,)
        distance_vecs : torch.Tensor or None
            List of distance vectors for each atom in the batch.
            shape (num_found_pairs or max_num_pairs, 3)
            None when return_vecs is False.

        If resize_to_fit is True, the tensors will be trimmed to the actual
        number of pairs found, otherwise the tensors will have size
        max_num_pairs, with neighbor pairs (-1, -1) at the end.

        Raises
        ------
        RuntimeError
            If check_errors is True and more than max_num_pairs pairs are
            found.
        """
        # The box is kept in the same dtype as the positions
        self.box = self.box.to(pos.dtype)
        max_pairs = self.max_num_pairs
        if self.max_num_pairs < 0:
            # Negative values mean "neighbors per atom"
            max_pairs = -self.max_num_pairs * pos.shape[0]
        if batch is None:
            batch = torch.zeros(pos.shape[0], dtype=torch.long, device=pos.device)
        neighbors, distance_vecs, distances, num_pairs = self.kernel(
            pos,
            cutoff_lower=self.cutoff_lower,
            cutoff_upper=self.cutoff_upper,
            loop=self.loop,
            batch=batch,
            max_num_pairs=max_pairs,
            include_transpose=self.include_transpose,
            box_vectors=self.box,
            use_periodic=self.use_periodic,
        )
        if self.check_errors:
            if num_pairs[0] > max_pairs:
                raise RuntimeError(
                    "Found num_pairs({}) > max_num_pairs({})".format(
                        num_pairs[0], max_pairs
                    )
                )
        # Remove (-1,-1) pairs
        if self.resize_to_fit:
            mask = neighbors[0] != -1
            neighbors = neighbors[:, mask]
            distances = distances[mask]
            distance_vecs = distance_vecs[mask, :]
        neighbors = neighbors.to(torch.long)
        if self.return_vecs:
            return neighbors, distances, distance_vecs
        else:
            return neighbors, distances, None
"backwards.cu"] + if pt.cuda.is_available() + else [] + ) + sources = [os.path.join(src_dir, name) for name in sources] + cpp_extension.load(name="torchmdnet_neighbors", sources=sources, is_python_module=False) + + +def get_backends(): + compile_extension() + get_neighbor_pairs_brute = pt.ops.torchmdnet_neighbors.get_neighbor_pairs_brute + get_neighbor_pairs_shared = pt.ops.torchmdnet_neighbors.get_neighbor_pairs_shared + get_neighbor_pairs_cell = pt.ops.torchmdnet_neighbors.get_neighbor_pairs_cell + return { + "brute": get_neighbor_pairs_brute, + "cell": get_neighbor_pairs_cell, + "shared": get_neighbor_pairs_shared, + } diff --git a/torchmdnet/neighbors/backwards.cu b/torchmdnet/neighbors/backwards.cu new file mode 100644 index 000000000..4875d7e84 --- /dev/null +++ b/torchmdnet/neighbors/backwards.cu @@ -0,0 +1,59 @@ +/* Raul P. Pelaez 2023. Backwards pass for the CUDA neighbor list operation. + Computes the gradient of the positions with respect to the distances and deltas. + */ +#include "common.cuh" + +template +__global__ void +backward_kernel(const Accessor neighbors, const Accessor deltas, + const Accessor grad_deltas, const Accessor distances, + const Accessor grad_distances, Accessor grad_positions) { + const int32_t i_pair = blockIdx.x * blockDim.x + threadIdx.x; + const int32_t num_pairs = neighbors.size(1); + if (i_pair >= num_pairs){ + return; + } + const int32_t i_dir = blockIdx.y; + const int32_t i_atom = neighbors[i_dir][i_pair]; + const int32_t i_comp = blockIdx.z; + if (i_atom < 0) { + return; + } + const scalar_t grad_deltas_ = grad_deltas[i_pair][i_comp]; + const scalar_t dist = distances[i_pair]; + const scalar_t grad_distances_ = deltas[i_pair][i_comp] / dist * grad_distances[i_pair]; + // Handle self interaction + const scalar_t grad = + (i_dir ? -1 : 1) * + (i_atom == neighbors[1 - i_dir][i_pair] ? 
/*
 * @brief Backward pass shared by the CUDA neighbor list operations.
 * Accumulates into the position gradient the contributions of the gradients received for the
 * distance vectors (grad_inputs[1]) and for the distances (grad_inputs[2]).
 * @param ctx Autograd context holding the saved {neighbors, deltas, distances} and "num_atoms"
 * @param grad_inputs Gradients with respect to the outputs of the forward pass
 * @return A list of ten gradients; only the first one (positions) is defined
 */
tensor_list common_backward(AutogradContext* ctx, const tensor_list& grad_inputs) {
    // Tensors stored by the forward pass
    const tensor_list saved = ctx->get_saved_variables();
    const Tensor& neighbors = saved[0];
    const Tensor& deltas = saved[1];
    const Tensor& distances = saved[2];
    const int num_atoms = ctx->saved_data["num_atoms"].toInt();

    const Tensor& grad_deltas = grad_inputs[1];
    const Tensor& grad_distances = grad_inputs[2];

    // One thread per (pair, direction, component): x covers the pairs, y the two atoms of the
    // pair and z the three spatial components. At least one block is always launched.
    const int threads_per_block = 128;
    const int num_pairs = grad_distances.size(0);
    const int pair_blocks = std::max((num_pairs + threads_per_block - 1) / threads_per_block, 1);
    const dim3 grid(pair_blocks, 2, 3);
    const auto stream = getCurrentCUDAStream(grad_distances.get_device());

    const Tensor grad_positions = zeros({num_atoms, 3}, grad_distances.options());

    AT_DISPATCH_FLOATING_TYPES(grad_distances.scalar_type(), "getNeighborPairs::backward", [&]() {
        const CUDAStreamGuard guard(stream);
        backward_kernel<<<grid, threads_per_block, 0, stream>>>(
            get_accessor<int32_t, 2>(neighbors), get_accessor<scalar_t, 2>(deltas),
            get_accessor<scalar_t, 2>(grad_deltas), get_accessor<scalar_t, 1>(distances),
            get_accessor<scalar_t, 1>(grad_distances), get_accessor<scalar_t, 2>(grad_positions));
    });

    // Only the positions receive a gradient; the remaining nine entries stay undefined.
    tensor_list grads(10);
    grads[0] = grad_positions;
    return grads;
}
/*
 * @brief Square root that dispatches to the CUDA math function matching the scalar type.
 * Only float and double are provided. The primary template is deleted so that any other type is
 * rejected at compile time instead of silently returning an unspecified value.
 * @param x The value
 * @return The square root of x
 */
template <typename scalar_t> __device__ __forceinline__ scalar_t sqrt_(scalar_t x) = delete;

template <> __device__ __forceinline__ float sqrt_(float x) {
    return ::sqrtf(x);
}

template <> __device__ __forceinline__ double sqrt_(double x) {
    return ::sqrt(x);
}
/*
 * @brief Store the pair (i, j) in slot i_pair of the pair list.
 * No bounds check is performed here; the caller must provide a valid slot.
 * @param list The pair list
 * @param i Index of the first particle
 * @param j Index of the second particle
 * @param delta The distance vector of the pair
 * @param distance The distance of the pair
 * @param i_pair The slot to write
 */
template <typename scalar_t>
__device__ void writeAtomPair(PairListAccessor<scalar_t>& list, int i, int j,
                              scalar3<scalar_t> delta, scalar_t distance, int i_pair) {
    const int pair[2] = {i, j};
    for (int side = 0; side < 2; side++) {
        list.neighbors[side][i_pair] = pair[side];
    }
    const scalar_t components[3] = {delta.x, delta.y, delta.z};
    for (int k = 0; k < 3; k++) {
        list.deltas[i_pair][k] = components[k];
    }
    list.distances[i_pair] = distance;
}
/*
 * @brief Takes a point to the unit cell in the range [-0.5, 0.5]*box_size using Minimum Image
 * Convention
 * @param p The point position
 * @param box_size The box size
 * @return The point in the unit cell
 */
template <typename scalar_t>
__device__ auto apply_pbc(scalar3<scalar_t> p, scalar3<scalar_t> box_size) {
    // floor is overloaded for float and double, so the number of box lengths to subtract is
    // computed in the precision of scalar_t (floorf would truncate doubles to single precision).
    p.x = p.x - floor(p.x / box_size.x + scalar_t(0.5)) * box_size.x;
    p.y = p.y - floor(p.y / box_size.y + scalar_t(0.5)) * box_size.y;
    p.z = p.z - floor(p.z / box_size.z + scalar_t(0.5)) * box_size.z;
    return p;
}
/*
 * @brief Host-side copy of the 3x3 box vectors, stored in a plain array.
 * @tparam scalar_t Floating point type of the stored box
 */
template <typename scalar_t> struct Box {
    // Zero-initialized so that the object never carries indeterminate values when it is copied,
    // even if the box is not read from the tensor.
    scalar_t size[3][3] = {};

    /*
     * @param box_vectors The box vectors (3x3 matrix). Only read when use_periodic is true
     * @param use_periodic Whether periodic boundary conditions are in use
     */
    Box(const Tensor& box_vectors, bool use_periodic) {
        if (use_periodic) {
            for (int i = 0; i < 3; i++) {
                for (int j = 0; j < 3; j++) {
                    size[i][j] = box_vectors[i][j].item<scalar_t>();
                }
            }
        }
    }
};
+ */ +tensor_list common_backward(AutogradContext* ctx, const tensor_list& grad_inputs); +#endif diff --git a/torchmdnet/neighbors/neighbors.cpp b/torchmdnet/neighbors/neighbors.cpp new file mode 100644 index 000000000..5c407ef3c --- /dev/null +++ b/torchmdnet/neighbors/neighbors.cpp @@ -0,0 +1,7 @@ +#include + +TORCH_LIBRARY(torchmdnet_neighbors, m) { + m.def("get_neighbor_pairs_brute(Tensor positions, Tensor batch, Tensor box_vectors, bool use_periodic, Scalar cutoff_lower, Scalar cutoff_upper, Scalar max_num_pairs, bool loop, bool include_transpose) -> (Tensor neighbors, Tensor distances, Tensor distance_vecs, Tensor num_pairs)"); + m.def("get_neighbor_pairs_shared(Tensor positions, Tensor batch, Tensor box_vectors, bool use_periodic, Scalar cutoff_lower, Scalar cutoff_upper, Scalar max_num_pairs, bool loop, bool include_transpose) -> (Tensor neighbors, Tensor distances, Tensor distance_vecs, Tensor num_pairs)"); + m.def("get_neighbor_pairs_cell(Tensor positions, Tensor batch, Tensor box_vectors, bool use_periodic, Scalar cutoff_lower, Scalar cutoff_upper, Scalar max_num_pairs, bool loop, bool include_transpose) -> (Tensor neighbors, Tensor distances, Tensor distance_vecs, Tensor num_pairs)"); +} diff --git a/torchmdnet/neighbors/neighbors_cpu.cpp b/torchmdnet/neighbors/neighbors_cpu.cpp new file mode 100644 index 000000000..dac6cfc69 --- /dev/null +++ b/torchmdnet/neighbors/neighbors_cpu.cpp @@ -0,0 +1,96 @@ +#include +#include + +using std::tuple; +using torch::arange; +using torch::div; +using torch::frobenius_norm; +using torch::full; +using torch::hstack; +using torch::index_select; +using torch::kInt32; +using torch::outer; +using torch::round; +using torch::Scalar; +using torch::Tensor; +using torch::vstack; +using torch::indexing::Slice; + +static tuple +forward(const Tensor& positions, const Tensor& batch, const Tensor& box_vectors, bool use_periodic, + const Scalar& cutoff_lower, const Scalar& cutoff_upper, const Scalar& max_num_pairs, + bool loop, 
bool include_transpose) { + TORCH_CHECK(positions.dim() == 2, "Expected \"positions\" to have two dimensions"); + TORCH_CHECK(positions.size(0) > 0, + "Expected the 1nd dimension size of \"positions\" to be more than 0"); + TORCH_CHECK(positions.size(1) == 3, "Expected the 2nd dimension size of \"positions\" to be 3"); + TORCH_CHECK(positions.is_contiguous(), "Expected \"positions\" to be contiguous"); + + TORCH_CHECK(cutoff_upper.to() > 0, "Expected \"cutoff\" to be positive"); + if (use_periodic) { + TORCH_CHECK(box_vectors.dim() == 2, "Expected \"box_vectors\" to have two dimensions"); + TORCH_CHECK(box_vectors.size(0) == 3 && box_vectors.size(1) == 3, + "Expected \"box_vectors\" to have shape (3, 3)"); + double v[3][3]; + for (int i = 0; i < 3; i++) + for (int j = 0; j < 3; j++) + v[i][j] = box_vectors[i][j].item(); + double c = cutoff_upper.to(); + TORCH_CHECK(v[0][1] == 0, "Invalid box vectors: box_vectors[0][1] != 0"); + TORCH_CHECK(v[0][2] == 0, "Invalid box vectors: box_vectors[0][2] != 0"); + TORCH_CHECK(v[1][2] == 0, "Invalid box vectors: box_vectors[1][2] != 0"); + TORCH_CHECK(v[0][0] >= 2 * c, "Invalid box vectors: box_vectors[0][0] < 2*cutoff"); + TORCH_CHECK(v[1][1] >= 2 * c, "Invalid box vectors: box_vectors[1][1] < 2*cutoff"); + TORCH_CHECK(v[2][2] >= 2 * c, "Invalid box vectors: box_vectors[2][2] < 2*cutoff"); + TORCH_CHECK(v[0][0] >= 2 * v[1][0], + "Invalid box vectors: box_vectors[0][0] < 2*box_vectors[1][0]"); + TORCH_CHECK(v[0][0] >= 2 * v[2][0], + "Invalid box vectors: box_vectors[0][0] < 2*box_vectors[1][0]"); + TORCH_CHECK(v[1][1] >= 2 * v[2][1], + "Invalid box vectors: box_vectors[1][1] < 2*box_vectors[2][1]"); + } + TORCH_CHECK(max_num_pairs.toLong() > 0, "Expected \"max_num_neighbors\" to be positive"); + const int n_atoms = positions.size(0); + Tensor neighbors = torch::empty({0}, positions.options().dtype(kInt32)); + Tensor distances = torch::empty({0}, positions.options()); + Tensor deltas = torch::empty({0}, positions.options()); + 
neighbors = torch::vstack((torch::tril_indices(n_atoms, n_atoms, -1, neighbors.options()))); + auto mask = index_select(batch, 0, neighbors.index({0, Slice()})) == + index_select(batch, 0, neighbors.index({1, Slice()})); + neighbors = neighbors.index({Slice(), mask}).to(kInt32); + deltas = index_select(positions, 0, neighbors.index({0, Slice()})) - + index_select(positions, 0, neighbors.index({1, Slice()})); + if (use_periodic) { + deltas -= outer(round(deltas.index({Slice(), 2}) / box_vectors.index({2, 2})), + box_vectors.index({2})); + deltas -= outer(round(deltas.index({Slice(), 1}) / box_vectors.index({1, 1})), + box_vectors.index({1})); + deltas -= outer(round(deltas.index({Slice(), 0}) / box_vectors.index({0, 0})), + box_vectors.index({0})); + } + distances = frobenius_norm(deltas, 1); + mask = (distances < cutoff_upper) * (distances >= cutoff_lower); + neighbors = neighbors.index({Slice(), mask}); + deltas = deltas.index({mask, Slice()}); + distances = distances.index({mask}); + if (include_transpose) { + neighbors = torch::hstack({neighbors, torch::stack({neighbors[1], neighbors[0]})}); + distances = torch::hstack({distances, distances}); + deltas = torch::vstack({deltas, -deltas}); + } + if (loop) { + const Tensor range = torch::arange(0, n_atoms, torch::kInt32); + neighbors = torch::hstack({neighbors, torch::stack({range, range})}); + distances = torch::hstack({distances, torch::zeros_like(range)}); + deltas = torch::vstack({deltas, torch::zeros({n_atoms, 3}, deltas.options())}); + } + Tensor num_pairs_found = torch::empty(1, distances.options().dtype(kInt32)); + num_pairs_found[0] = distances.size(0); + return {neighbors, deltas, distances, num_pairs_found}; +} + +TORCH_LIBRARY_IMPL(torchmdnet_neighbors, CPU, m) { + m.impl("get_neighbor_pairs_brute", &forward); + m.impl("get_neighbor_pairs_shared", &forward); + m.impl("get_neighbor_pairs_cell", &forward); +} diff --git a/torchmdnet/neighbors/neighbors_cuda.cu b/torchmdnet/neighbors/neighbors_cuda.cu new 
file mode 100644 index 000000000..83f43d985 --- /dev/null +++ b/torchmdnet/neighbors/neighbors_cuda.cu @@ -0,0 +1,46 @@ +/* Raul P. Pelaez 2023 + Connection between the neighbor CUDA implementations and the torch extension. + See neighbors.cpp for the definition of the torch extension functions. + */ +#include +#include "neighbors_cuda_brute.cuh" +#include "neighbors_cuda_cell.cuh" +#include "neighbors_cuda_shared.cuh" + +TORCH_LIBRARY_IMPL(torchmdnet_neighbors, AutogradCUDA, m) { + m.impl("get_neighbor_pairs_brute", + [](const Tensor& positions, const Tensor& batch, const Tensor& box_vectors, + bool use_periodic, const Scalar& cutoff_lower, const Scalar& cutoff_upper, + const Scalar& max_num_pairs, bool loop, bool include_transpose) { + tensor_list results; + if (positions.size(0) >= 32768) { + // Revert to shared if there are too many particles, which brute can't handle + results = AutogradSharedCUDA::apply(positions, batch, cutoff_lower, cutoff_upper, + box_vectors, use_periodic, max_num_pairs, + loop, include_transpose); + } else { + results = AutogradBruteCUDA::apply(positions, batch, cutoff_lower, cutoff_upper, + box_vectors, use_periodic, max_num_pairs, + loop, include_transpose); + } + return std::make_tuple(results[0], results[1], results[2], results[3]); + }); + m.impl("get_neighbor_pairs_shared", + [](const Tensor& positions, const Tensor& batch, const Tensor& box_vectors, + bool use_periodic, const Scalar& cutoff_lower, const Scalar& cutoff_upper, + const Scalar& max_num_pairs, bool loop, bool include_transpose) { + const tensor_list results = AutogradSharedCUDA::apply( + positions, batch, cutoff_lower, cutoff_upper, box_vectors, use_periodic, + max_num_pairs, loop, include_transpose); + return std::make_tuple(results[0], results[1], results[2], results[3]); + }); + m.impl("get_neighbor_pairs_cell", + [](const Tensor& positions, const Tensor& batch, const Tensor& box_vectors, + bool use_periodic, const Scalar& cutoff_lower, const Scalar& cutoff_upper, + 
/*
 * @brief Get the row of a pair from its flattened index in a strictly lower triangular matrix.
 * The inverse relation used by the caller is column = index - row * (row - 1) / 2.
 * @param index The flattened pair index
 * @return The row of the pair
 */
__device__ uint32_t get_row(uint32_t index) {
    // Closed-form estimate of the row, evaluated in single precision
    uint32_t row = floor((sqrtf(8 * index + 1) + 1) / 2);
    // Step back one row if the estimate starts past the requested index
    if (row * (row - 1) > 2 * index)
        row--;
    return row;
}
/*
 * @brief Append the self pair (i, i), with zero distance, for every atom.
 * Each block reserves a contiguous range of slots with a single atomic operation and each thread
 * then writes its own slot. Slots beyond the capacity of the list are counted but not written,
 * which matches the behavior of addAtomPairToList.
 * @param num_atoms The number of atoms
 * @param positions The positions (unused, only fixes scalar_t)
 * @param list The pair list
 */
template <typename scalar_t>
__global__ void add_self_kernel(const int num_atoms, Accessor<scalar_t, 2> positions,
                                PairListAccessor<scalar_t> list) {
    const int32_t i_atom = blockIdx.x * blockDim.x + threadIdx.x;
    __shared__ int i_pair;
    if (threadIdx.x == 0) { // Each block adds blockDim.x pairs to the list.
        // Handle the last block, so that only num_atoms are added in total
        i_pair = atomicAdd(&list.i_curr_pair[0],
                           min(blockDim.x, num_atoms - blockIdx.x * blockDim.x));
    }
    // Every thread of the block reaches the barrier; out-of-range threads leave only afterwards.
    __syncthreads();
    if (i_atom >= num_atoms)
        return;
    const int32_t slot = i_pair + threadIdx.x;
    // Never write past the end of the list. The pair is still counted in i_curr_pair.
    if (slot < list.neighbors.size(1)) {
        scalar3<scalar_t> delta{};
        scalar_t distance = 0;
        writeAtomPair(list, i_atom, i_atom, delta, distance, slot);
    }
}
const uint64_t num_all_pairs = num_atoms * (num_atoms - 1UL) / 2UL; + const uint64_t num_threads = 128; + const uint64_t num_blocks = + std::max((num_all_pairs + num_threads - 1UL) / num_threads, 1UL); + AT_DISPATCH_FLOATING_TYPES(positions.scalar_type(), "get_neighbor_pairs_forward", [&]() { + PairListAccessor list_accessor(list); + triclinic::Box box(box_vectors, use_periodic); + const scalar_t cutoff_upper_ = cutoff_upper.to(); + const scalar_t cutoff_lower_ = cutoff_lower.to(); + TORCH_CHECK(cutoff_upper_ > 0, "Expected \"cutoff\" to be positive"); + forward_kernel_brute<<>>( + num_all_pairs, get_accessor(positions), + get_accessor(batch), cutoff_lower_ * cutoff_lower_, + cutoff_upper_ * cutoff_upper_, list_accessor, box); + if (loop) { + const uint32_t num_threads = 256; + const uint32_t num_blocks = + std::max((num_atoms + num_threads - 1U) / num_threads, 1U); + add_self_kernel<<>>( + num_atoms, get_accessor(positions), list_accessor); + } + }); + ctx->save_for_backward({list.neighbors, list.deltas, list.distances}); + ctx->saved_data["num_atoms"] = num_atoms; + return {list.neighbors, list.deltas, list.distances, list.i_curr_pair}; + } + + static tensor_list backward(AutogradContext* ctx, const tensor_list& grad_inputs) { + return common_backward(ctx, grad_inputs); + } +}; +#endif diff --git a/torchmdnet/neighbors/neighbors_cuda_cell.cuh b/torchmdnet/neighbors/neighbors_cuda_cell.cuh new file mode 100644 index 000000000..e0b49ed45 --- /dev/null +++ b/torchmdnet/neighbors/neighbors_cuda_cell.cuh @@ -0,0 +1,388 @@ +/* Raul P. Pelaez 2023. Batched cell list neighbor list implementation for CUDA. 
/*
 * @brief Get the cell coordinates of a point
 * @param p The point position
 * @param box_size The size of the box in each dimension
 * @param cutoff The cutoff
 * @param cell_dim The number of cells in each dimension
 * @return The cell coordinates
 */
template <typename scalar_t>
__device__ int3 getCell(scalar3<scalar_t> p, scalar3<scalar_t> box_size, scalar_t cutoff,
                        int3 cell_dim) {
    p = rect::apply_pbc<scalar_t>(p, box_size);
    // Take to the [0, box_size] range and divide by cutoff (which is the cell size).
    // floor is overloaded for float and double, so the quotient keeps the precision of scalar_t.
    int cx = floor((p.x + scalar_t(0.5) * box_size.x) / cutoff);
    int cy = floor((p.y + scalar_t(0.5) * box_size.y) / cutoff);
    int cz = floor((p.z + scalar_t(0.5) * box_size.z) / cutoff);
    // Wrap around. If the position of a particle is exactly box_size, it will be in the last cell,
    // which results in an illegal access down the line.
    if (cx == cell_dim.x)
        cx = 0;
    if (cy == cell_dim.y)
        cy = 0;
    if (cz == cell_dim.z)
        cz = 0;
    return make_int3(cx, cy, cz);
}
/*
  @brief Shift a single cell coordinate by one grid length when it lies outside [0, num_cells)
  @param coordinate The cell coordinate along one dimension
  @param num_cells The number of cells along that dimension
  @return The shifted coordinate
*/
__device__ inline int foldCellCoordinate(int coordinate, int num_cells) {
    int folded = coordinate;
    if (coordinate < 0)
        folded += num_cells;
    if (coordinate >= num_cells)
        folded -= num_cells;
    return folded;
}

/*
  @brief Fold a cell coordinate to the range [0, cell_dim)
  Each component is shifted by at most one grid length.
  @param cell The cell coordinate
  @param cell_dim The dimensions of the grid
  @return The folded cell coordinate
*/
__device__ int3 getPeriodicCell(int3 cell, int3 cell_dim) {
    return make_int3(foldCellCoordinate(cell.x, cell_dim.x),
                     foldCellCoordinate(cell.y, cell_dim.y),
                     foldCellCoordinate(cell.z, cell_dim.z));
}
/*
  @brief Kernel that records, for each cell, the index of its first atom and one past its last atom
  Atoms are expected to be sorted by cell index, so a cell boundary is found wherever an atom lies
  in a different cell than the atom before it. One thread handles one atom.
  @param sorted_cell_indices The cell index of each atom, sorted
  @param cell_start Output, index of the first atom of each cell
  @param cell_end Output, one past the index of the last atom of each cell
*/
__global__ void fillCellOffsetsD(const Accessor<int32_t, 1> sorted_cell_indices,
                                 Accessor<int32_t, 1> cell_start, Accessor<int32_t, 1> cell_end) {
    const int32_t num_atoms = sorted_cell_indices.size(0);
    const int32_t tid = blockIdx.x * blockDim.x + threadIdx.x;
    if (tid >= num_atoms)
        return;
    const bool is_first_atom = (tid == 0);
    const int this_cell = sorted_cell_indices[tid];
    const int previous_cell = is_first_atom ? 0 : sorted_cell_indices[tid - 1];
    if (is_first_atom || this_cell != previous_cell) {
        // This atom opens its cell...
        cell_start[this_cell] = tid;
        // ...and closes the cell of the previous atom, if there is one.
        if (!is_first_atom) {
            cell_end[previous_cell] = tid;
        }
    }
    // The last atom closes its own cell.
    if (tid == num_atoms - 1) {
        cell_end[this_cell] = tid + 1;
    }
}
/*
  @brief Get the cell index of the i'th neighboring cell for a given cell
  The index i is decoded as three base-3 digits (x fastest), each mapped to an offset in {-1, 0, 1}.
  @param cell_i The cell coordinates
  @param i The index of the neighboring cell, from 0 to 26
  @param cell_dim The dimensions of the cell grid
  @return The cell index of the i'th neighboring cell
*/
__device__ int getNeighborCellIndex(int3 cell_i, int i, int3 cell_dim) {
    const int offset_x = i % 3 - 1;
    const int offset_y = (i / 3) % 3 - 1;
    const int offset_z = i / 9 - 1;
    const int3 shifted = make_int3(cell_i.x + offset_x, cell_i.y + offset_y, cell_i.z + offset_z);
    return getCellIndex(getPeriodicCell(shifted, cell_dim), cell_dim);
}
/*
 * @brief Bundle of accessors to the tensors of a CellList.
 * @tparam scalar_t Floating point type of the sorted positions
 */
template <class scalar_t> struct CellListAccessor {
    // First atom and one-past-last atom of each cell
    Accessor<int32_t, 1> cell_start, cell_end;
    // Indices of the atoms after sorting
    Accessor<int32_t, 1> sorted_indices;
    // Positions and batch indices gathered with sorted_indices
    Accessor<scalar_t, 2> sorted_positions;
    Accessor<int64_t, 1> sorted_batch;

    explicit CellListAccessor(const CellList& cell_list)
        : cell_start(get_accessor<int32_t, 1>(cell_list.cell_start)),
          cell_end(get_accessor<int32_t, 1>(cell_list.cell_end)),
          sorted_indices(get_accessor<int32_t, 1>(cell_list.sorted_indices)),
          sorted_positions(get_accessor<scalar_t, 2>(cell_list.sorted_positions)),
          sorted_batch(get_accessor<int64_t, 1>(cell_list.sorted_batch)) {
    }
};
scalar_t(1.0) : scalar_t(-1.0); + const scalar_t distance = sqrt_(distance2); + delta = {delta_sign * delta.x, delta_sign * delta.y, delta_sign * delta.z}; + addAtomPairToList(list, ni, nj, delta, distance, requires_transpose); +} + +/* + * @brief Add to the pair list all neighbors of particle i_atom in cell j_cell + * @param i_atom The Information of the particle for which we are adding neighbors + * @param j_cell The index of the cell in which we are looking for neighbors + * @param cl The cell list + * @param box_size The box size + * @param list The pair list + */ +template +__device__ void addNeighborsForCell(const Particle& i_atom, int j_cell, + const CellListAccessor& cl, + scalar3 box_size, PairListAccessor& list) { + const auto first_particle = cl.cell_start[j_cell]; + if (first_particle != -1) { // Continue only if there are particles in this cell + const auto last_particle = cl.cell_end[j_cell]; + for (int cur_j = first_particle; cur_j < last_particle; cur_j++) { + const auto j_batch = cl.sorted_batch[cur_j]; + if ((j_batch == i_atom.batch) and + ((cur_j < i_atom.index) || (list.loop and cur_j == i_atom.index))) { + const auto position_j = fetchPosition(cl.sorted_positions, cur_j); + const auto delta = rect::compute_distance(i_atom.position, position_j, + list.use_periodic, box_size); + const scalar_t distance2 = + delta.x * delta.x + delta.y * delta.y + delta.z * delta.z; + if ((distance2 < i_atom.cutoff_upper2 && distance2 >= i_atom.cutoff_lower2) or + (list.loop && cur_j == i_atom.index)) { + const int orj = cl.sorted_indices[cur_j]; + addNeighborPair(list, i_atom.original_index, orj, distance2, delta); + } // endif + } // endif + } // endfor + } // endif +} + +// Traverse the cell list for each atom and find the neighbors +template +__global__ void traverseCellList(const CellListAccessor cell_list, + PairListAccessor list, int num_atoms, + scalar3 box_size, scalar_t cutoff_lower, + scalar_t cutoff_upper) { + // Each atom traverses the cells around it 
and finds the neighbors + // Atoms for all batches are placed in the same cell list, but other batches are ignored while + // traversing + Particle i_atom; + i_atom.index = blockIdx.x * blockDim.x + threadIdx.x; + if (i_atom.index >= num_atoms) { + return; + } + i_atom.original_index = cell_list.sorted_indices[i_atom.index]; + i_atom.batch = cell_list.sorted_batch[i_atom.index]; + i_atom.position = fetchPosition(cell_list.sorted_positions, i_atom.index); + i_atom.cutoff_lower2 = cutoff_lower * cutoff_lower; + i_atom.cutoff_upper2 = cutoff_upper * cutoff_upper; + const int3 cell_dim = getCellDimensions(box_size, cutoff_upper); + const int3 cell_i = getCell(i_atom.position, box_size, cutoff_upper, cell_dim); + // Loop over the 27 cells around the current cell + for (int i = 0; i < 27; i++) { + const int neighbor_cell = getNeighborCellIndex(cell_i, i, cell_dim); + addNeighborsForCell(i_atom, neighbor_cell, cell_list, box_size, list); + } +} + +class AutogradCellCUDA : public Function { +public: + static tensor_list forward(AutogradContext* ctx, const Tensor& positions, const Tensor& batch, + const Tensor& box_size, bool use_periodic, + const Scalar& cutoff_lower, const Scalar& cutoff_upper, + const Scalar& max_num_pairs, bool loop, bool include_transpose) { + // This module computes the pair list for a given set of particles, which may be in multiple + // batches. The strategy is to first compute a cell list for all particles, and then + // traverse the cell list for each particle to construct a pair list. 
+ checkInput(positions, batch); + TORCH_CHECK(box_size.dim() == 2, "Expected \"box_size\" to have two dimensions"); + TORCH_CHECK(box_size.size(0) == 3 && box_size.size(1) == 3, + "Expected \"box_size\" to have shape (3, 3)"); + TORCH_CHECK(box_size[0][1].item() == 0 && box_size[0][2].item() == 0 && + box_size[1][0].item() == 0 && box_size[1][2].item() == 0 && + box_size[2][0].item() == 0 && box_size[2][1].item() == 0, + "Expected \"box_size\" to be diagonal"); + const auto max_num_pairs_ = max_num_pairs.toInt(); + TORCH_CHECK(max_num_pairs_ > 0, "Expected \"max_num_neighbors\" to be positive"); + const int num_atoms = positions.size(0); + const auto cell_list = constructCellList(positions, batch, box_size, cutoff_upper); + PairList list(max_num_pairs_, positions.options(), loop, include_transpose, use_periodic); + const auto stream = getCurrentCUDAStream(positions.get_device()); + { // Traverse the cell list to find the neighbors + const CUDAStreamGuard guard(stream); + AT_DISPATCH_FLOATING_TYPES(positions.scalar_type(), "forward", [&] { + const scalar_t cutoff_upper_ = cutoff_upper.to(); + TORCH_CHECK(cutoff_upper_ > 0, "Expected cutoff_upper to be positive"); + const scalar_t cutoff_lower_ = cutoff_lower.to(); + const scalar3 box_size_ = {box_size[0][0].item(), + box_size[1][1].item(), + box_size[2][2].item()}; + PairListAccessor list_accessor(list); + CellListAccessor cell_list_accessor(cell_list); + const int threads = 128; + const int blocks = (num_atoms + threads - 1) / threads; + traverseCellList<<>>(cell_list_accessor, list_accessor, + num_atoms, box_size_, + cutoff_lower_, cutoff_upper_); + }); + } + ctx->save_for_backward({list.neighbors, list.deltas, list.distances}); + ctx->saved_data["num_atoms"] = num_atoms; + return {list.neighbors, list.deltas, list.distances, list.i_curr_pair}; + } + + static tensor_list backward(AutogradContext* ctx, const tensor_list& grad_inputs) { + return common_backward(ctx, grad_inputs); + } +}; +#endif diff --git 
a/torchmdnet/neighbors/neighbors_cuda_shared.cuh b/torchmdnet/neighbors/neighbors_cuda_shared.cuh new file mode 100644 index 000000000..aebf91cd4 --- /dev/null +++ b/torchmdnet/neighbors/neighbors_cuda_shared.cuh @@ -0,0 +1,118 @@ +/* Raul P. Pelaez 2023. Shared memory neighbor list construction for CUDA. + This brute force approach checks all pairs of atoms by collaboratively loading and processing + tiles of atoms into shared memory. + This approach is typically slower than the brute force approach, but can handle an arbitrarily + large number of atoms. + */ +#ifndef NEIGHBORS_SHARED_CUH +#define NEIGHBORS_SHARED_CUH +#include "common.cuh" +#include +#include + +template +__global__ void forward_kernel_shared(uint32_t num_atoms, const Accessor positions, + const Accessor batch, scalar_t cutoff_lower2, + scalar_t cutoff_upper2, PairListAccessor list, + int32_t num_tiles, triclinic::Box box) { + // A thread per atom + const int id = blockIdx.x * blockDim.x + threadIdx.x; + // All threads must pass through __syncthreads, + // but when N is not a multiple of the block size some threads are assigned a particle i>=N. + // These threads can't return, so they are masked to not do any work + const bool active = id < num_atoms; + __shared__ scalar3 sh_pos[BLOCKSIZE]; + __shared__ int64_t sh_batch[BLOCKSIZE]; + scalar3 pos_i; + int64_t batch_i; + if (active) { + pos_i = fetchPosition(positions, id); + batch_i = batch[id]; + } + // Distribute the N particles in a group of tiles. Storing in each tile blockDim.x values in + // shared memory. This way all threads are accessing the same memory addresses at the same time + for (int tile = 0; tile < num_tiles; tile++) { + // Load this tile's particle values to shared memory + const int i_load = tile * blockDim.x + threadIdx.x; + if (i_load < num_atoms) { // Even if I'm not active, my thread may load a value each tile to + // shared memory. 
+ sh_pos[threadIdx.x] = fetchPosition(positions, i_load); + sh_batch[threadIdx.x] = batch[i_load]; + } + // Wait for all threads to arrive, so the tile is fully loaded before it is read + __syncthreads(); + // Go through all the particles in the current tile +#pragma unroll 8 + for (int counter = 0; counter < blockDim.x; counter++) { + if (!active) + break; // An out of bounds thread must be masked + const int cur_j = tile * blockDim.x + counter; + const bool testPair = cur_j < num_atoms and (cur_j < id or (list.loop and cur_j == id)); + if (testPair) { + const auto batch_j = sh_batch[counter]; + if (batch_i == batch_j) { + const auto pos_j = sh_pos[counter]; + const auto delta = + triclinic::compute_distance(pos_i, pos_j, list.use_periodic, box); + const scalar_t distance2 = + delta.x * delta.x + delta.y * delta.y + delta.z * delta.z; + if (distance2 < cutoff_upper2 && distance2 >= cutoff_lower2) { + const bool requires_transpose = list.include_transpose && !(cur_j == id); + const auto distance = sqrt_(distance2); + addAtomPairToList(list, id, cur_j, delta, distance, requires_transpose); + } + } + } + } + __syncthreads(); + } +} + +class AutogradSharedCUDA : public Function { +public: + static tensor_list forward(AutogradContext* ctx, const Tensor& positions, const Tensor& batch, + const Scalar& cutoff_lower, const Scalar& cutoff_upper, + const Tensor& box_vectors, bool use_periodic, + const Scalar& max_num_pairs, bool loop, bool include_transpose) { + checkInput(positions, batch); + const auto max_num_pairs_ = max_num_pairs.toLong(); + TORCH_CHECK(max_num_pairs_ > 0, "Expected \"max_num_neighbors\" to be positive"); + if (use_periodic) { + TORCH_CHECK(box_vectors.dim() == 2, "Expected \"box_vectors\" to have two dimensions"); + TORCH_CHECK(box_vectors.size(0) == 3 && box_vectors.size(1) == 3, + "Expected \"box_vectors\" to have shape (3, 3)"); + } + TORCH_CHECK(box_vectors.device() == torch::kCPU, "Expected \"box_vectors\" to be on CPU"); + const int num_atoms = 
positions.size(0); + const int num_pairs = 
+ max_num_pairs_; + const TensorOptions options = positions.options(); + const auto stream = getCurrentCUDAStream(positions.get_device()); + PairList list(num_pairs, positions.options(), loop, include_transpose, use_periodic); + const CUDAStreamGuard guard(stream); + AT_DISPATCH_FLOATING_TYPES( + positions.scalar_type(), "get_neighbor_pairs_shared_forward", [&]() { + const scalar_t cutoff_upper_ = cutoff_upper.to(); + const scalar_t cutoff_lower_ = cutoff_lower.to(); + triclinic::Box box(box_vectors, use_periodic); + TORCH_CHECK(cutoff_upper_ > 0, "Expected \"cutoff\" to be positive"); + constexpr int BLOCKSIZE = 64; + const int num_blocks = std::max((num_atoms + BLOCKSIZE - 1) / BLOCKSIZE, 1); + const int num_threads = BLOCKSIZE; + const int num_tiles = num_blocks; + PairListAccessor list_accessor(list); + forward_kernel_shared<<>>( + num_atoms, get_accessor(positions), + get_accessor(batch), cutoff_lower_ * cutoff_lower_, + cutoff_upper_ * cutoff_upper_, list_accessor, num_tiles, box); + }); + ctx->save_for_backward({list.neighbors, list.deltas, list.distances}); + ctx->saved_data["num_atoms"] = num_atoms; + return {list.neighbors, list.deltas, list.distances, list.i_curr_pair}; + } + + static tensor_list backward(AutogradContext* ctx, const tensor_list& grad_inputs) { + return common_backward(ctx, grad_inputs); + } +}; + +#endif