diff --git a/hippynn/layers/pairs/dispatch.py b/hippynn/layers/pairs/dispatch.py index ea5a5f43..d4cce4a5 100644 --- a/hippynn/layers/pairs/dispatch.py +++ b/hippynn/layers/pairs/dispatch.py @@ -138,7 +138,7 @@ def neighbor_list_torch(cutoff: float, coords, cell): def neighbor_list_kdtree(cutoff, coords, cell): ''' Use KD Tree implementation from scipy.spatial to find pairs under periodic boundary conditions - with an orthonormal cell. + with an orthorhombic cell. ''' # Verify that cell is orthorhombic @@ -180,10 +180,16 @@ def neighbor_list_kdtree(cutoff, coords, cell): pairs = torch.as_tensor(pairs, device=coords.device) pair_first, pair_second = torch.unbind(pairs, dim=1) + # Wrap coordinates into cell and keep track of how they were translated + inv_cell = torch.linalg.inv(cell) + coords, wrapped_offset = wrap_points_torch(coords, cell, inv_cell) + # Find difference vector between pairs without considering the MIC pair_diff = torch.sub(coords[pair_first], coords[pair_second]) # Possible adjacent offset directions for images of the difference vector + # More is not needed because of the restriction that the cutoff is less than the length of + # each side of the cell offset_range = torch.tensor(list(product([-1, 0, 1], repeat=3)), device=coords.device) # All adjacent offsets @@ -198,8 +204,9 @@ def neighbor_list_kdtree(cutoff, coords, cell): # Index of shortest offset image pair_diff = torch.argmin(pair_diff, dim=1) - # Offset direction corresponding to shortest offset image + # Offset direction corresponding to shortest offset image plus accounting for the wrapping done earlier pair_image = offset_range[pair_diff] + pair_image -= (wrapped_offset[pair_first] - wrapped_offset[pair_second]) # KDTree only returns each pair once (eg. (1,2) but not (2,1)) doubled_pair_first = torch.concat((pair_first, pair_second))