Skip to content

Commit

Permalink
localpdistidns: better performance and gpu support through batching
Browse files Browse the repository at this point in the history
  • Loading branch information
axsk committed Aug 26, 2024
1 parent af3e998 commit 5dd1142
Showing 1 changed file with 3 additions and 6 deletions.
9 changes: 3 additions & 6 deletions src/pairdists.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,9 @@ Given `coords` of shape ( 3n x frames ) return the pairs of indices whose minima
"""
function localpdistinds(coords::AbstractMatrix, radius)
traj = reshape(coords, 3, :, size(coords, 2))
elmin(x, y) = min.(x, y)
d = mapreduce(elmin, eachslice(traj, dims=3)) do coords
UpperTriangular(pairwise(Euclidean(), coords, dims=2))
end
pairs = Tuple.(findall(0 .< d .<= radius))
return pairs
ds = sqpairdist(traj)
mds = dropdims(minimum(ds, dims=3), dims=3)
pairs = Tuple.(findall(0 .< UpperTriangular(mds) .<= radius^2))
end

localpdistinds(coords::Vector, radius) = localpdistinds(hcat(coords), radius)
Expand Down

0 comments on commit 5dd1142

Please sign in to comment.