Skip to content

Commit

Permalink
add user warning and position wrapping to KDTreePairs+, attribute acc…
Browse files Browse the repository at this point in the history
…ess to Memory pair-finders
  • Loading branch information
shinkle-lanl committed Dec 13, 2023
1 parent c76e83b commit 7d06f91
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 4 deletions.
24 changes: 20 additions & 4 deletions hippynn/graphs/nodes/pairs.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,23 @@ def __init__(self, name, parents, dist_hard_max, module="auto", module_kwargs=No
parents = self.expand_parents(parents)
super().__init__(name, parents, module=module, **kwargs)

class PeriodicPairIndexerMemory(PeriodicPairIndexer):
class Memory:
@property
def skin(self):
return self.torch_module.skin

@skin.setter
def skin(self, skin):
self.torch_module.skin = skin

@property
def reuse_percentage(self):
return self.torch_module.reuse_percentage

def reset_reuse_percentage(self):
self.torch_module.reset_reuse_percentage()

class PeriodicPairIndexerMemory(PeriodicPairIndexer, Memory):
'''
Implementation of PeriodicPairIndexer with additional memory component.
Expand All @@ -86,9 +102,9 @@ class PeriodicPairIndexerMemory(PeriodicPairIndexer):
def __init__(self, name, parents, dist_hard_max, skin, module="auto", module_kwargs=None, **kwargs):
if module_kwargs is None:
module_kwargs = {}
module_kwargs = {"skin": skin, **module_kwargs}
self.expand0module_kwargs = {"skin": skin, **module_kwargs}

super().__init__(name, parents, dist_hard_max, module=module, module_kwargs=module_kwargs, **kwargs)
super().__init__(name, parents, dist_hard_max, module=module, module_kwargs=self.module_kwargs, **kwargs)


class ExternalNeighborIndexer(ExpandParents, PairIndexer, AutoKw, MultiNode):
Expand Down Expand Up @@ -379,7 +395,7 @@ class KDTreePairs(_DispatchNeighbors):
'''
_auto_module_class = pairs_modules.dispatch.KDTreeNeighbors

class KDTreePairsMemory(_DispatchNeighbors):
class KDTreePairsMemory(_DispatchNeighbors, Memory):
'''
Implementation of KDTreePairs with an added memory component.
Expand Down
17 changes: 17 additions & 0 deletions hippynn/layers/indexers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Layers for encoding, decoding, index states, besides pairs
"""

import warnings
import torch


Expand Down Expand Up @@ -249,7 +250,23 @@ def __init__(self, length, vmin, vmax):
self.bins = torch.nn.Parameter(torch.linspace(vmin, vmax, length), requires_grad=False)
self.sigma = (vmax - vmin) / length

self.vmin = vmin
self.vmax = vmax

def forward(self, values):
# Warn user if provided values lie outside the range of the histogram bins
values_out_of_range = (values < self.vmin) + (values > self.vmax)

if values_out_of_range.sum() > 0:
perc_out_of_range = values_out_of_range.float().mean()
warnings.warn(
"Values out of range for FuzzyHistogrammer\n"
f"Number of values out of range: {values_out_of_range.sum()}\n"
f"Percentage of values out of range: {perc_out_of_range * 100:.2f}%\n"
f"Set range for FuzzyHistogrammer: {(self.vmin, self.vmax)}\n"
f"Range of values: ({values.min().item():.2f}, {values.max().item():.2f})"
)

if values.shape[-1] != 1:
values = values[...,None]
x = values - self.bins
Expand Down
2 changes: 2 additions & 0 deletions hippynn/layers/pairs/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,8 @@ def neighbor_list_kdtree(cutoff, coords, cell):
new_cell = cell.clone()
new_coords = coords.clone()

new_coords = new_coords % torch.diag(new_cell) # KD Tree will not work if positions are outside of periodic box

# Find pair indices
tree = KDTree(
data=new_coords.detach().cpu().numpy(),
Expand Down

0 comments on commit 7d06f91

Please sign in to comment.