From e5765555bebdcb81bce6df3612009c5b742343ae Mon Sep 17 00:00:00 2001 From: Sakib Matin <83463357+sakibmatin@users.noreply.github.com> Date: Tue, 9 Apr 2024 10:18:28 -0600 Subject: [PATCH] Updated Cusp Regularization (#59) * Larger Cusp Regularization The regularization term for the hipnnvec and hipnnquad are changed to `1e-6`. * add cusp_reg kwarg for hipnn class. The new default cusp regularization term is larger to avoid numerical issues with force calculations. * add cusp_reg for interact layers The csup regularization is only needed for InteractLayerQuad and InteractLayerVec. * backward compatibility for cusp regularization The backwards compatibility function mutates the reloaded model graph in place. This function can be updated in the future based on which older versions of hippynn to support. * Passing cusp_reg to Hipnn base class * passing cusp_reg args * Move backwards compatbility hooks to the affected layer * apply code formatter * remove unused imports --------- Co-authored-by: Nicholas Lubbers Co-authored-by: Nicholas Lubbers <56895592+lubbersnick@users.noreply.github.com> --- hippynn/experiment/serialization.py | 4 +- hippynn/layers/hiplayers.py | 74 ++++++++++++++++++++++++----- hippynn/networks/hipnn.py | 8 +++- 3 files changed, 73 insertions(+), 13 deletions(-) diff --git a/hippynn/experiment/serialization.py b/hippynn/experiment/serialization.py index 17248bf9..132a0c30 100644 --- a/hippynn/experiment/serialization.py +++ b/hippynn/experiment/serialization.py @@ -5,6 +5,7 @@ from typing import Tuple, Union import torch +import warnings from ..databases import Database from ..databases.restarter import Restartable @@ -75,7 +76,6 @@ def restore_checkpoint(structure: dict, state: dict, restore_db=True) -> dict: :return: experiment structure """ - structure["training_modules"][0].load_state_dict(state["model"]) structure["controller"].load_state_dict(state["controller"]) @@ -197,4 +197,6 @@ def load_model_from_cwd(map_location=None, model_device=None, **kwargs) -> Graph model.load_state_dict(state) if map_location == None and model_device != None and model_device != "cpu": model = model.to(model_device) + return model + diff --git a/hippynn/layers/hiplayers.py b/hippynn/layers/hiplayers.py index be70943e..2f62c07d 100644 --- a/hippynn/layers/hiplayers.py +++ b/hippynn/layers/hiplayers.py @@ -16,7 +16,7 @@ def warn_if_under(distance, threshold): if dmin < threshold: d_count = distance < threshold d_frac = d_count.to(distance.dtype).mean() - d_sum = (d_count.sum()/2).to(torch.int) + d_sum = (d_count.sum() / 2).to(torch.int) warnings.warn( "Provided distances are underneath sensitivity range!\n" f"Minimum distance in current batch: {dmin}\n" @@ -139,7 +139,7 @@ class InteractLayer(torch.nn.Module): Hipnn's interaction layer """ - def __init__(self, nf_in, nf_out, n_dist, mind_soft, maxd_soft, hard_cutoff, sensitivity_module): + def __init__(self, nf_in, nf_out, n_dist, mind_soft, maxd_soft, hard_cutoff, sensitivity_module, cusp_reg=None): """ Constructor @@ -150,9 +150,13 @@ def __init__(self, nf_in, nf_out, n_dist, mind_soft, maxd_soft, hard_cutoff, sen :param maxd_soft: maximum distance for initial sensitivities :param hard_cutoff: maximum distance for cutoff function :param sensitivity_module: class or callable that builds sensitivity functions, should return nn.Module + :param cusp_reg: ignored, only provided with compatibility for tensor sensitivity API """ super().__init__() + if type(self) is InteractLayer and cusp_reg is not None: + # Parameter is not used in this class. + warnings.warn(f"Parameter `cusp_reg`={cusp_reg} is ignored in this class, and is only provided for API compatibility.") self.n_dist = n_dist self.nf_in = nf_in self.nf_out = nf_out @@ -211,11 +215,60 @@ def forward(self, in_features, pair_first, pair_second, dist_pairs): class InteractLayerVec(InteractLayer): - def __init__(self, nf_in, nf_out, n_dist, mind_soft, maxd_soft, hard_cutoff, sensitivity_module): - super().__init__(nf_in, nf_out, n_dist, mind_soft, maxd_soft, hard_cutoff, sensitivity_module) - + def __init__(self, nf_in, nf_out, n_dist, mind_soft, maxd_soft, hard_cutoff, sensitivity_module, cusp_reg): + super().__init__(nf_in, nf_out, n_dist, mind_soft, maxd_soft, hard_cutoff, sensitivity_module, cusp_reg) self.vecscales = torch.nn.Parameter(torch.Tensor(nf_out)) torch.nn.init.normal_(self.vecscales.data) + self.cusp_reg = cusp_reg + + def __setstate__(self, state): + output = super().__setstate__(state) + if not hasattr(self, "cusp_reg"): + # The layer was created before the cusp regularization was a parameter. + # Add a patch that if a state dict is loaded in with no cusp parameter, + # use the pre-introduction static value. + warnings.warn( + "Loading a module which does not contain the 'cusp_reg' parameter. " + "In the future, this behavior will cause an error. " + "To avoid this warning, re-save this model to disk. " + ) + self.handle = self.register_load_state_dict_post_hook(self.compatibility_hook) + return output + + @staticmethod + def compatibility_hook(self, incompatible_keys): + missing = incompatible_keys.missing_keys + if not missing: + # No need for compatibility! + return + + if len(missing) != 1: + warnings.warn("Backwards compatibility hook may have failed due to the presence of multiple missing keys!") + return + + for m in missing: + if m.endswith("_extra_state"): + break + else: + # Python reminder: The mysterious "else" clause of the for loop + # activates when python does not break out of the for loop. + return # No _extra_state type variable was missing: just return. + + DEPRECATED_CUSP_REG = 1e-30 + warnings.warn( + f"Loaded state does not contain 'cusp_reg' parameter. " + f"Using deprecated value of 1e-30. " + f"This compatibility behavior will be removed in the future. " + f"To avoid this warning, re-save this model." + ) + self.set_extra_state({"cusp_reg": DEPRECATED_CUSP_REG}) + missing.remove(m) + + def get_extra_state(self): + return {"cusp_reg": self.cusp_reg} + + def set_extra_state(self, state): + self.cusp_reg = state["cusp_reg"] def forward(self, in_features, pair_first, pair_second, dist_pairs, coord_pairs): @@ -235,7 +288,7 @@ def forward(self, in_features, pair_first, pair_second, dist_pairs, coord_pairs) env_features_vec = env_features_vec.reshape(n_atoms_real * 3, self.n_dist * self.nf_in) features_out_vec = torch.mm(env_features_vec, weights_rs) features_out_vec = features_out_vec.reshape(n_atoms_real, 3, self.nf_out) - features_out_vec = torch.square(features_out_vec).sum(dim=1) + 1e-30 + features_out_vec = torch.square(features_out_vec).sum(dim=1) + self.cusp_reg features_out_vec = torch.sqrt(features_out_vec) features_out_vec = features_out_vec * self.vecscales.unsqueeze(0) @@ -248,9 +301,8 @@ def forward(self, in_features, pair_first, pair_second, dist_pairs, coord_pairs) class InteractLayerQuad(InteractLayerVec): - def __init__(self, nf_in, nf_out, n_dist, mind_soft, maxd_soft, hard_cutoff, sensitivity_module): - super().__init__(nf_in, nf_out, n_dist, mind_soft, maxd_soft, hard_cutoff, sensitivity_module) - + def __init__(self, nf_in, nf_out, n_dist, mind_soft, maxd_soft, hard_cutoff, sensitivity_module, cusp_reg): + super().__init__(nf_in, nf_out, n_dist, mind_soft, maxd_soft, hard_cutoff, sensitivity_module, cusp_reg) self.quadscales = torch.nn.Parameter(torch.Tensor(nf_out)) torch.nn.init.normal_(self.quadscales.data) # upper indices of flattened 3x3 array minus the (3,3) component @@ -280,7 +332,7 @@ def forward(self, in_features, pair_first, pair_second, dist_pairs, coord_pairs) features_out_vec = torch.mm(env_features_vec, weights_rs) # Norm and scale features_out_vec = features_out_vec.reshape(n_atoms_real, 3, self.nf_out) - features_out_vec = torch.square(features_out_vec).sum(dim=1) + 1e-30 + features_out_vec = torch.square(features_out_vec).sum(dim=1) + self.cusp_reg features_out_vec = torch.sqrt(features_out_vec) features_out_vec = features_out_vec * self.vecscales.unsqueeze(0) @@ -303,7 +355,7 @@ def forward(self, in_features, pair_first, pair_second, dist_pairs, coord_pairs) quadfirst = torch.square(features_out_quad).sum(dim=1) quadsecond = features_out_quad[:, 0, :] * features_out_quad[:, 3, :] features_out_quad = 2 * (quadfirst + quadsecond) - features_out_quad = torch.sqrt(features_out_quad + 1e-30) + features_out_quad = torch.sqrt(features_out_quad + self.cusp_reg) # Scales features_out_quad = features_out_quad * self.quadscales.unsqueeze(0) diff --git a/hippynn/networks/hipnn.py b/hippynn/networks/hipnn.py index b9b629b6..5c5ece3d 100644 --- a/hippynn/networks/hipnn.py +++ b/hippynn/networks/hipnn.py @@ -80,6 +80,7 @@ def __init__( sensitivity_type="inverse", resnet=True, activation=torch.nn.Softplus, + cusp_reg=None, ): """ @@ -97,6 +98,7 @@ def __init__( 'inverse' is what is in hip-nn original paper. :param resnet: bool or int, if int, size of internal resnet width :param activation: activation function or subclass of nn.module. + :param cusp_reg: Used for API compatibility, but ignored in vanilla HIP-NN. Note: only one of possible_species or n_input_features is needed. If both are supplied, they must be consistent with each other. @@ -178,7 +180,7 @@ def __init__( # Add interaction layer lay = self._interaction_class( - in_size, middle_size, n_sensitivities, dist_soft_min, dist_soft_max, dist_hard_max, sensitivity_type + in_size, middle_size, n_sensitivities, dist_soft_min, dist_soft_max, dist_hard_max, sensitivity_type, cusp_reg ) if self.resnet: lay = ResNetWrapper(lay, in_size, middle_size, out_size, self.activation) @@ -246,6 +248,10 @@ class HipnnVec(Hipnn): _interaction_class = InteractLayerVec + def __init__(self, *args, cusp_reg=1e-6, **kwargs): + # cusp regularization for tensor sensitivity l>0. Defaults to 1e-6. + super().__init__(*args, cusp_reg=cusp_reg, **kwargs) + def forward(self, features, pair_first, pair_second, pair_dist, pair_coord): features = features.to(pair_dist.dtype) # Convert one-hot features to floating point features.