Skip to content

Commit

Permalink
fix coordinate wrapping issue in KDTree paifinders (#73)
Browse files Browse the repository at this point in the history
  • Loading branch information
shinkle-lanl authored Jun 4, 2024
1 parent 8d5fd8a commit 0796fe4
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions hippynn/layers/pairs/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def neighbor_list_torch(cutoff: float, coords, cell):
def neighbor_list_kdtree(cutoff, coords, cell):
'''
Use KD Tree implementation from scipy.spatial to find pairs under periodic boundary conditions
with an orthonormal cell.
with an orthorhombic cell.
'''

# Verify that cell is orthorhombic
Expand Down Expand Up @@ -180,10 +180,16 @@ def neighbor_list_kdtree(cutoff, coords, cell):
pairs = torch.as_tensor(pairs, device=coords.device)
pair_first, pair_second = torch.unbind(pairs, dim=1)

# Wrap coordinates into cell and keep track of how they were translated
inv_cell = torch.linalg.inv(cell)
coords, wrapped_offset = wrap_points_torch(coords, cell, inv_cell)

# Find difference vector between pairs without considering the MIC
pair_diff = torch.sub(coords[pair_first], coords[pair_second])

# Possible adjacent offset directions for images of the difference vector
# More is not needed because of the restriction that the cutoff is less than the length of
# each side of the cell
offset_range = torch.tensor(list(product([-1, 0, 1], repeat=3)), device=coords.device)

# All adjacent offsets
Expand All @@ -198,8 +204,9 @@ def neighbor_list_kdtree(cutoff, coords, cell):
# Index of shortest offset image
pair_diff = torch.argmin(pair_diff, dim=1)

# Offset direction corresponding to shortest offset image
# Offset direction corresponding to shortest offset image plus accounting for the wrapping done earlier
pair_image = offset_range[pair_diff]
pair_image -= (wrapped_offset[pair_first] - wrapped_offset[pair_second])

# KDTree only returns each pair once (eg. (1,2) but not (2,1))
doubled_pair_first = torch.concat((pair_first, pair_second))
Expand Down

0 comments on commit 0796fe4

Please sign in to comment.