Patch summary (text before the first "diff --git" line is ignored by
git-apply and patch):
  * fps_cpu.cpp:  draw the random start index modulo y.size(0) instead of
                  src.size(0).
  * fps_cuda.cu:  THREADS 1024 -> 256; keep the clamped distance in the
                  local `dd` and reuse it instead of re-reading dist[n].
  * test_fps.py:  add test_random_fps, which checks that the returned
                  indices stay within [0, 2 * N) for a two-example batch.
  * setup.py, torch_cluster/__init__.py: version 1.5.2 -> 1.5.3.
NOTE(review): line breaks, indentation and blank context lines were
reconstructed from a flattened copy of this patch; confirm them against the
target files (or apply with --ignore-whitespace) before relying on it.

diff --git a/csrc/cpu/fps_cpu.cpp b/csrc/cpu/fps_cpu.cpp
index 263f8912..a60dd7d0 100644
--- a/csrc/cpu/fps_cpu.cpp
+++ b/csrc/cpu/fps_cpu.cpp
@@ -35,7 +35,7 @@ torch::Tensor fps_cpu(torch::Tensor src, torch::Tensor ptr, double ratio,
 
     int64_t start_idx = 0;
     if (random_start) {
-      start_idx = rand() % src.size(0);
+      start_idx = rand() % y.size(0);
     }
 
     out_data[out_start] = src_start + start_idx;
diff --git a/csrc/cuda/fps_cuda.cu b/csrc/cuda/fps_cuda.cu
index eb3bc4c5..59f8a80e 100644
--- a/csrc/cuda/fps_cuda.cu
+++ b/csrc/cuda/fps_cuda.cu
@@ -4,7 +4,7 @@
 
 #include "utils.cuh"
 
-#define THREADS 1024
+#define THREADS 256
 
 template <typename scalar_t>
 __global__ void fps_kernel(const scalar_t *src, const int64_t *ptr,
@@ -31,15 +31,15 @@ __global__ void fps_kernel(const scalar_t *src, const int64_t *ptr,
     int64_t best_idx = 0;
 
     for (int64_t n = start_idx + thread_idx; n < end_idx; n += THREADS) {
-      scalar_t tmp;
-      scalar_t dd = (scalar_t)0.;
+      scalar_t tmp, dd = (scalar_t)0.;
       for (int64_t d = 0; d < dim; d++) {
         tmp = src[dim * old + d] - src[dim * n + d];
         dd += tmp * tmp;
       }
-      dist[n] = min(dist[n], dd);
-      if (dist[n] > best) {
-        best = dist[n];
+      dd = min(dist[n], dd);
+      dist[n] = dd;
+      if (dd > best) {
+        best = dd;
         best_idx = n;
       }
     }
diff --git a/setup.py b/setup.py
index 5666ba05..1af96e87 100644
--- a/setup.py
+++ b/setup.py
@@ -63,7 +63,7 @@ def get_extensions():
 
 setup(
     name='torch_cluster',
-    version='1.5.2',
+    version='1.5.3',
     author='Matthias Fey',
     author_email='matthias.fey@tu-dortmund.de',
     url='https://github.com/rusty1s/pytorch_cluster',
diff --git a/test/test_fps.py b/test/test_fps.py
index 5416bd50..9eb44291 100644
--- a/test/test_fps.py
+++ b/test/test_fps.py
@@ -26,3 +26,15 @@ def test_fps(dtype, device):
 
     out = fps(x, ratio=0.5, random_start=False)
     assert out.sort()[0].tolist() == [0, 5, 6, 7]
+
+
+@pytest.mark.parametrize('device', devices)
+def test_random_fps(device):
+    N = 1024
+    for _ in range(5):
+        pos = torch.randn((2 * N, 3), device=device)
+        batch_1 = torch.zeros(N, dtype=torch.long, device=device)
+        batch_2 = torch.ones(N, dtype=torch.long, device=device)
+        batch = torch.cat([batch_1, batch_2])
+        idx = fps(pos, batch, ratio=0.5)
+        assert idx.min() >= 0 and idx.max() < 2 * N
diff --git a/torch_cluster/__init__.py b/torch_cluster/__init__.py
index f5166e67..9a870cb6 100644
--- a/torch_cluster/__init__.py
+++ b/torch_cluster/__init__.py
@@ -3,7 +3,7 @@
 
 import torch
 
-__version__ = '1.5.2'
+__version__ = '1.5.3'
 expected_torch_version = (1, 4)
 
 try: