Commit
Updated Cusp Regularization (#59)
* Larger Cusp Regularization

The regularization term for HipnnVec and HipnnQuad is changed to `1e-6`.

* add cusp_reg kwarg for hipnn class.

The new default cusp regularization term is larger to avoid numerical issues with force calculations; a standalone sketch of the effect follows the changed-file summary below.

* add cusp_reg for interact layers

The cusp regularization is only needed for InteractLayerQuad and InteractLayerVec.

* backward compatibility for cusp regularization

The backwards-compatibility function mutates the reloaded model graph in place. This function can be updated in the future depending on which older versions of hippynn need to be supported. A toy sketch of the loading mechanism follows the hiplayers.py diff below.

* Passing cusp_reg to Hipnn base class

* passing cusp_reg args

* Move backwards-compatibility hooks to the affected layer

* apply code formatter

* remove unused imports

---------

Co-authored-by: Nicholas Lubbers <hippynn@lanl.gov>
Co-authored-by: Nicholas Lubbers <56895592+lubbersnick@users.noreply.github.com>
3 people authored Apr 9, 2024
1 parent e5ad4e2 commit e576555
Showing 3 changed files with 73 additions and 13 deletions.
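Background sketch (not part of the commit): the interaction layers compute sqrt(sum(v**2) + cusp_reg), whose curvature at v = 0 is 1/sqrt(cusp_reg). That curvature enters the second-derivative graphs used when training to forces, so the old value 1e-30 admits terms near 1e15, while 1e-6 caps them near 1e3. A minimal standalone PyTorch check:

import torch

# Curvature of sqrt(v**2 + cusp_reg) at v = 0 is 1/sqrt(cusp_reg);
# this quantity appears in force-training (double-backward) graphs.
for cusp_reg in (1e-30, 1e-6):
    v = torch.zeros(1, dtype=torch.float64, requires_grad=True)
    f = torch.sqrt(torch.square(v).sum() + cusp_reg)
    (g,) = torch.autograd.grad(f, v, create_graph=True)  # the "force" term
    (h,) = torch.autograd.grad(g.sum(), v)                # its derivative
    print(f"cusp_reg={cusp_reg:.0e}  curvature={h.item():.1e}")  # ~1e15 vs ~1e3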
4 changes: 3 additions & 1 deletion hippynn/experiment/serialization.py
@@ -5,6 +5,7 @@
from typing import Tuple, Union

import torch
import warnings

from ..databases import Database
from ..databases.restarter import Restartable
@@ -75,7 +76,6 @@ def restore_checkpoint(structure: dict, state: dict, restore_db=True) -> dict:
:return: experiment structure
"""

structure["training_modules"][0].load_state_dict(state["model"])
structure["controller"].load_state_dict(state["controller"])

@@ -197,4 +197,6 @@ def load_model_from_cwd(map_location=None, model_device=None, **kwargs) -> Graph
model.load_state_dict(state)
if map_location == None and model_device != None and model_device != "cpu":
model = model.to(model_device)

return model
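
For orientation, a hedged usage sketch of the loader changed above (the file name and device are illustrative, not taken from the commit): load_model_from_cwd reads a trained model from the current working directory, and model_device moves it after the state dict is applied, per the map_location/model_device logic in the diff. Re-saving the model afterwards stores the new cusp_reg extra state, which silences the backwards-compatibility warning described below for hiplayers.py.

import torch
from hippynn.experiment.serialization import load_model_from_cwd

# Sketch: reload a trained model saved in the current working directory.
# model_device moves the model after loading, per the diff above.
model = load_model_from_cwd(model_device="cuda:0")

# Re-saving stores the new 'cusp_reg' extra state, so the compatibility
# warning for pre-cusp_reg checkpoints is not raised on the next load.
torch.save(model, "model_resaved.pt")  # file name is illustrative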

74 changes: 63 additions & 11 deletions hippynn/layers/hiplayers.py
@@ -16,7 +16,7 @@ def warn_if_under(distance, threshold):
if dmin < threshold:
d_count = distance < threshold
d_frac = d_count.to(distance.dtype).mean()
d_sum = (d_count.sum()/2).to(torch.int)
d_sum = (d_count.sum() / 2).to(torch.int)
warnings.warn(
"Provided distances are underneath sensitivity range!\n"
f"Minimum distance in current batch: {dmin}\n"
@@ -139,7 +139,7 @@ class InteractLayer(torch.nn.Module):
Hipnn's interaction layer
"""

def __init__(self, nf_in, nf_out, n_dist, mind_soft, maxd_soft, hard_cutoff, sensitivity_module):
def __init__(self, nf_in, nf_out, n_dist, mind_soft, maxd_soft, hard_cutoff, sensitivity_module, cusp_reg=None):
"""
Constructor
@@ -150,9 +150,13 @@ def __init__(self, nf_in, nf_out, n_dist, mind_soft, maxd_soft, hard_cutoff, sen
:param maxd_soft: maximum distance for initial sensitivities
:param hard_cutoff: maximum distance for cutoff function
:param sensitivity_module: class or callable that builds sensitivity functions, should return nn.Module
:param cusp_reg: ignored; only provided for compatibility with the tensor sensitivity API
"""
super().__init__()

if type(self) is InteractLayer and cusp_reg is not None:
# Parameter is not used in this class.
warnings.warn(f"Parameter `cusp_reg`={cusp_reg} is ignored in this class, and is only provided for API compatibility.")
self.n_dist = n_dist
self.nf_in = nf_in
self.nf_out = nf_out
@@ -211,11 +215,60 @@ def forward(self, in_features, pair_first, pair_second, dist_pairs):


class InteractLayerVec(InteractLayer):
def __init__(self, nf_in, nf_out, n_dist, mind_soft, maxd_soft, hard_cutoff, sensitivity_module):
super().__init__(nf_in, nf_out, n_dist, mind_soft, maxd_soft, hard_cutoff, sensitivity_module)

def __init__(self, nf_in, nf_out, n_dist, mind_soft, maxd_soft, hard_cutoff, sensitivity_module, cusp_reg):
super().__init__(nf_in, nf_out, n_dist, mind_soft, maxd_soft, hard_cutoff, sensitivity_module, cusp_reg)
self.vecscales = torch.nn.Parameter(torch.Tensor(nf_out))
torch.nn.init.normal_(self.vecscales.data)
self.cusp_reg = cusp_reg

def __setstate__(self, state):
output = super().__setstate__(state)
if not hasattr(self, "cusp_reg"):
# The layer was created before the cusp regularization was a parameter.
# Register a patch: if a state dict is loaded without the cusp parameter,
# fall back to the pre-introduction static value.
warnings.warn(
"Loading a module which does not contain the 'cusp_reg' parameter. "
"In the future, this behavior will cause an error. "
"To avoid this warning, re-save this model to disk. "
)
self.handle = self.register_load_state_dict_post_hook(self.compatibility_hook)
return output

@staticmethod
def compatibility_hook(self, incompatible_keys):
missing = incompatible_keys.missing_keys
if not missing:
# No need for compatibility!
return

if len(missing) != 1:
warnings.warn("Backwards compatibility hook may have failed due to the presence of multiple missing keys!")
return

for m in missing:
if m.endswith("_extra_state"):
break
else:
# Python reminder: The mysterious "else" clause of the for loop
# activates when python does not break out of the for loop.
return # No _extra_state type variable was missing: just return.

DEPRECATED_CUSP_REG = 1e-30
warnings.warn(
f"Loaded state does not contain 'cusp_reg' parameter. "
f"Using deprecated value of 1e-30. "
f"This compatibility behavior will be removed in the future. "
f"To avoid this warning, re-save this model."
)
self.set_extra_state({"cusp_reg": DEPRECATED_CUSP_REG})
missing.remove(m)

def get_extra_state(self):
return {"cusp_reg": self.cusp_reg}

def set_extra_state(self, state):
self.cusp_reg = state["cusp_reg"]

def forward(self, in_features, pair_first, pair_second, dist_pairs, coord_pairs):

@@ -235,7 +288,7 @@ def forward(self, in_features, pair_first, pair_second, dist_pairs, coord_pairs)
env_features_vec = env_features_vec.reshape(n_atoms_real * 3, self.n_dist * self.nf_in)
features_out_vec = torch.mm(env_features_vec, weights_rs)
features_out_vec = features_out_vec.reshape(n_atoms_real, 3, self.nf_out)
features_out_vec = torch.square(features_out_vec).sum(dim=1) + 1e-30
features_out_vec = torch.square(features_out_vec).sum(dim=1) + self.cusp_reg
features_out_vec = torch.sqrt(features_out_vec)
features_out_vec = features_out_vec * self.vecscales.unsqueeze(0)

@@ -248,9 +301,8 @@ def forward(self, in_features, pair_first, pair_second, dist_pairs, coord_pairs)


class InteractLayerQuad(InteractLayerVec):
def __init__(self, nf_in, nf_out, n_dist, mind_soft, maxd_soft, hard_cutoff, sensitivity_module):
super().__init__(nf_in, nf_out, n_dist, mind_soft, maxd_soft, hard_cutoff, sensitivity_module)

def __init__(self, nf_in, nf_out, n_dist, mind_soft, maxd_soft, hard_cutoff, sensitivity_module, cusp_reg):
super().__init__(nf_in, nf_out, n_dist, mind_soft, maxd_soft, hard_cutoff, sensitivity_module, cusp_reg)
self.quadscales = torch.nn.Parameter(torch.Tensor(nf_out))
torch.nn.init.normal_(self.quadscales.data)
# upper indices of flattened 3x3 array minus the (3,3) component
@@ -280,7 +332,7 @@ def forward(self, in_features, pair_first, pair_second, dist_pairs, coord_pairs)
features_out_vec = torch.mm(env_features_vec, weights_rs)
# Norm and scale
features_out_vec = features_out_vec.reshape(n_atoms_real, 3, self.nf_out)
features_out_vec = torch.square(features_out_vec).sum(dim=1) + 1e-30
features_out_vec = torch.square(features_out_vec).sum(dim=1) + self.cusp_reg
features_out_vec = torch.sqrt(features_out_vec)
features_out_vec = features_out_vec * self.vecscales.unsqueeze(0)

@@ -303,7 +355,7 @@ def forward(self, in_features, pair_first, pair_second, dist_pairs, coord_pairs)
quadfirst = torch.square(features_out_quad).sum(dim=1)
quadsecond = features_out_quad[:, 0, :] * features_out_quad[:, 3, :]
features_out_quad = 2 * (quadfirst + quadsecond)
features_out_quad = torch.sqrt(features_out_quad + 1e-30)
features_out_quad = torch.sqrt(features_out_quad + self.cusp_reg)
# Scales
features_out_quad = features_out_quad * self.quadscales.unsqueeze(0)

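The compatibility machinery above is standard PyTorch: get_extra_state/set_extra_state route a non-tensor value through the state dict under an "_extra_state" key, and a load-state-dict post-hook can backfill that key when an old checkpoint is loaded. A minimal toy sketch of the same pattern (toy module, not hippynn code):

import warnings
import torch

class Toy(torch.nn.Module):
    """Toy module using the same extra-state pattern as InteractLayerVec."""

    def __init__(self, cusp_reg=1e-6):
        super().__init__()
        self.cusp_reg = cusp_reg

    # Saved in the state dict under the key "<prefix>_extra_state".
    def get_extra_state(self):
        return {"cusp_reg": self.cusp_reg}

    def set_extra_state(self, state):
        self.cusp_reg = state["cusp_reg"]

def compatibility_hook(module, incompatible_keys):
    # Runs after load_state_dict(); removing the key from missing_keys
    # suppresses the strict-mode error for the value we backfill.
    missing = incompatible_keys.missing_keys
    for key in missing:
        if key.endswith("_extra_state"):
            warnings.warn("Old checkpoint: falling back to cusp_reg=1e-30.")
            module.set_extra_state({"cusp_reg": 1e-30})
            missing.remove(key)
            return

toy = Toy()
toy.register_load_state_dict_post_hook(compatibility_hook)
toy.load_state_dict({})       # a state dict saved before cusp_reg existed
assert toy.cusp_reg == 1e-30  # the legacy value was substituted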
8 changes: 7 additions & 1 deletion hippynn/networks/hipnn.py
@@ -80,6 +80,7 @@ def __init__(
sensitivity_type="inverse",
resnet=True,
activation=torch.nn.Softplus,
cusp_reg=None,
):
"""
@@ -97,6 +98,7 @@
'inverse' is what is used in the original HIP-NN paper.
:param resnet: bool or int, if int, size of internal resnet width
:param activation: activation function or subclass of nn.module.
:param cusp_reg: Used for API compatibility, but ignored in vanilla HIP-NN.
Note: only one of possible_species or n_input_features is needed. If both are supplied,
they must be consistent with each other.
@@ -178,7 +180,7 @@ def __init__(

# Add interaction layer
lay = self._interaction_class(
in_size, middle_size, n_sensitivities, dist_soft_min, dist_soft_max, dist_hard_max, sensitivity_type
in_size, middle_size, n_sensitivities, dist_soft_min, dist_soft_max, dist_hard_max, sensitivity_type, cusp_reg
)
if self.resnet:
lay = ResNetWrapper(lay, in_size, middle_size, out_size, self.activation)
@@ -246,6 +248,10 @@ class HipnnVec(Hipnn):

_interaction_class = InteractLayerVec

def __init__(self, *args, cusp_reg=1e-6, **kwargs):
# cusp regularization for tensor sensitivity l>0. Defaults to 1e-6.
super().__init__(*args, cusp_reg=cusp_reg, **kwargs)

def forward(self, features, pair_first, pair_second, pair_dist, pair_coord):
features = features.to(pair_dist.dtype) # Convert one-hot features to floating point features.

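A sketch of where the new keyword surfaces for users; the other hyperparameter names follow the style of hippynn's example scripts and should be read as illustrative rather than exact:

from hippynn.networks.hipnn import HipnnVec

# Illustrative hyperparameters in the style of hippynn's examples;
# only cusp_reg is new in this commit.
network_params = {
    "possible_species": [0, 1, 6, 7, 8],
    "n_features": 20,
    "n_sensitivities": 20,
    "dist_soft_min": 0.75,
    "dist_soft_max": 5.0,
    "dist_hard_max": 6.5,
    "n_interaction_layers": 2,
    "n_atom_layers": 3,
    "cusp_reg": 1e-6,  # the new keyword; this is also the default
}
net = HipnnVec(**network_params)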
