Skip to content

Commit

Permalink
fix fps implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s committed Mar 24, 2020
1 parent aff91e0 commit 9c33077
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 9 deletions.
2 changes: 1 addition & 1 deletion csrc/cpu/fps_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ torch::Tensor fps_cpu(torch::Tensor src, torch::Tensor ptr, double ratio,

int64_t start_idx = 0;
if (random_start) {
start_idx = rand() % src.size(0);
start_idx = rand() % y.size(0);
}

out_data[out_start] = src_start + start_idx;
Expand Down
12 changes: 6 additions & 6 deletions csrc/cuda/fps_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

#include "utils.cuh"

#define THREADS 1024
#define THREADS 256

template <typename scalar_t>
__global__ void fps_kernel(const scalar_t *src, const int64_t *ptr,
Expand All @@ -31,15 +31,15 @@ __global__ void fps_kernel(const scalar_t *src, const int64_t *ptr,
int64_t best_idx = 0;

for (int64_t n = start_idx + thread_idx; n < end_idx; n += THREADS) {
scalar_t tmp;
scalar_t dd = (scalar_t)0.;
scalar_t tmp, dd = (scalar_t)0.;
for (int64_t d = 0; d < dim; d++) {
tmp = src[dim * old + d] - src[dim * n + d];
dd += tmp * tmp;
}
dist[n] = min(dist[n], dd);
if (dist[n] > best) {
best = dist[n];
dd = min(dist[n], dd);
dist[n] = dd;
if (dd > best) {
best = dd;
best_idx = n;
}
}
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def get_extensions():

setup(
name='torch_cluster',
version='1.5.2',
version='1.5.3',
author='Matthias Fey',
author_email='matthias.fey@tu-dortmund.de',
url='https://github.com/rusty1s/pytorch_cluster',
Expand Down
12 changes: 12 additions & 0 deletions test/test_fps.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,15 @@ def test_fps(dtype, device):

out = fps(x, ratio=0.5, random_start=False)
assert out.sort()[0].tolist() == [0, 5, 6, 7]


@pytest.mark.parametrize('device', devices)
def test_random_fps(device):
    """Check that ``fps`` only returns indices that address rows of its input.

    Five trials are run. Each trial draws fresh random 3-D points for a
    batch of two examples (1024 points per example) and samples with
    ``ratio=0.5``. ``random_start`` is not passed, so the call relies on
    the default of ``fps`` (presumably a random start, given the test
    name -- confirm against the ``fps`` signature).
    """
    points_per_example = 1024
    total_points = 2 * points_per_example

    for _ in range(5):
        pos = torch.randn((total_points, 3), device=device)
        # Batch assignment vector: first half is example 0, second half is
        # example 1.
        batch = torch.cat([
            torch.zeros(points_per_example, dtype=torch.long, device=device),
            torch.ones(points_per_example, dtype=torch.long, device=device),
        ])

        idx = fps(pos, batch, ratio=0.5)

        # Every sampled index must be a valid row index into ``pos``.
        assert idx.min() >= 0
        assert idx.max() < total_points
2 changes: 1 addition & 1 deletion torch_cluster/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import torch

__version__ = '1.5.2'
__version__ = '1.5.3'
expected_torch_version = (1, 4)

try:
Expand Down

0 comments on commit 9c33077

Please sign in to comment.