diff --git a/hippynn/graphs/nodes/pairs.py b/hippynn/graphs/nodes/pairs.py index ae2d6fb7..e209449f 100644 --- a/hippynn/graphs/nodes/pairs.py +++ b/hippynn/graphs/nodes/pairs.py @@ -68,7 +68,23 @@ def __init__(self, name, parents, dist_hard_max, module="auto", module_kwargs=No parents = self.expand_parents(parents) super().__init__(name, parents, module=module, **kwargs) -class PeriodicPairIndexerMemory(PeriodicPairIndexer): +class Memory: + @property + def skin(self): + return self.torch_module.skin + + @skin.setter + def skin(self, skin): + self.torch_module.skin = skin + + @property + def reuse_percentage(self): + return self.torch_module.reuse_percentage + + def reset_reuse_percentage(self): + self.torch_module.reset_reuse_percentage() + +class PeriodicPairIndexerMemory(PeriodicPairIndexer, Memory): ''' Implementation of PeriodicPairIndexer with additional memory component. @@ -86,9 +102,9 @@ class PeriodicPairIndexerMemory(PeriodicPairIndexer): def __init__(self, name, parents, dist_hard_max, skin, module="auto", module_kwargs=None, **kwargs): if module_kwargs is None: module_kwargs = {} - module_kwargs = {"skin": skin, **module_kwargs} + self.expand0module_kwargs = {"skin": skin, **module_kwargs} - super().__init__(name, parents, dist_hard_max, module=module, module_kwargs=module_kwargs, **kwargs) + super().__init__(name, parents, dist_hard_max, module=module, module_kwargs=self.module_kwargs, **kwargs) class ExternalNeighborIndexer(ExpandParents, PairIndexer, AutoKw, MultiNode): @@ -379,7 +395,7 @@ class KDTreePairs(_DispatchNeighbors): ''' _auto_module_class = pairs_modules.dispatch.KDTreeNeighbors -class KDTreePairsMemory(_DispatchNeighbors): +class KDTreePairsMemory(_DispatchNeighbors, Memory): ''' Implementation of KDTreePairs with an added memory component. diff --git a/hippynn/layers/indexers.py b/hippynn/layers/indexers.py index c3d6d1f9..b19cfc45 100644 --- a/hippynn/layers/indexers.py +++ b/hippynn/layers/indexers.py @@ -2,6 +2,7 @@ Layers for encoding, decoding, index states, besides pairs """ +import warnings import torch @@ -249,7 +250,23 @@ def __init__(self, length, vmin, vmax): self.bins = torch.nn.Parameter(torch.linspace(vmin, vmax, length), requires_grad=False) self.sigma = (vmax - vmin) / length + self.vmin = vmin + self.vmax = vmax + def forward(self, values): + # Warn user if provided values lie outside the range of the histogram bins + values_out_of_range = (values < self.vmin) + (values > self.vmax) + + if values_out_of_range.sum() > 0: + perc_out_of_range = values_out_of_range.float().mean() + warnings.warn( + "Values out of range for FuzzyHistogrammer\n" + f"Number of values out of range: {values_out_of_range.sum()}\n" + f"Percentage of values out of range: {perc_out_of_range * 100:.2f}%\n" + f"Set range for FuzzyHistogrammer: {(self.vmin, self.vmax)}\n" + f"Range of values: ({values.min().item():.2f}, {values.max().item():.2f})" + ) + if values.shape[-1] != 1: values = values[...,None] x = values - self.bins diff --git a/hippynn/layers/pairs/dispatch.py b/hippynn/layers/pairs/dispatch.py index f034a516..4f43aba7 100644 --- a/hippynn/layers/pairs/dispatch.py +++ b/hippynn/layers/pairs/dispatch.py @@ -159,6 +159,8 @@ def neighbor_list_kdtree(cutoff, coords, cell): new_cell = cell.clone() new_coords = coords.clone() + new_coords = new_coords % torch.diag(new_cell) # KD Tree will not work if positions are outside of periodic box + # Find pair indices tree = KDTree( data=new_coords.detach().cpu().numpy(),