Skip to content

Commit

Permalink
minor changes to MD code & its documentation (#81)
Browse files Browse the repository at this point in the history
* Make VariableUpdater in molecular_dynamics/md.py public and fix documentation formatting in same file
  • Loading branch information
shinkle-lanl authored Jun 24, 2024
1 parent 7e6614e commit 00e0094
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 25 deletions.
2 changes: 1 addition & 1 deletion hippynn/molecular_dynamics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
Molecular dynamics driver with great flexibility and customizability regarding which quantities which are evolved
and what algorithms are used to evolve them. Calls a hippynn `Predictor` on current state during each MD step.
"""
from .md import Variable, NullUpdater, VelocityVerlet, LangevinDynamics, MolecularDynamics
from .md import *
39 changes: 15 additions & 24 deletions hippynn/molecular_dynamics/md.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def __init__(
name: str,
data: dict[str, torch.Tensor],
model_input_map: dict[str, str] = dict(),
updater: _VariableUpdater = None,
updater: VariableUpdater = None,
device: torch.device = None,
dtype: torch.dtype = None,
) -> None:
Expand All @@ -38,7 +38,7 @@ def __init__(
:type model_input_map: dict[str, str], optional
:param updater: object which will update the data of the Variable
over the course of the MD simulation, defaults to None
:type updater: _VariableUpdater, optional
:type updater: VariableUpdater, optional
:param device: device on which to keep data, defaults to None
:type device: torch.device, optional
:param dtype: dtype for float type data, defaults to None
Expand Down Expand Up @@ -136,14 +136,11 @@ def _(self, arg: torch.dtype):
self.dtype = arg


class _VariableUpdater:
class VariableUpdater:
"""
Parent class for algorithms that make updates to the data of a Variable during
each step of an MD simulation.
Parent class for algorithms that make updates to the data of a Variable during each step of an MD simulation.
Subclasses should redefine __init__, pre_step, post_step, and
required_variable_data as needed. The inputs to pre_step and post_step
should not be changed.
Subclasses should redefine __init__, pre_step, post_step, and required_variable_data as needed. The inputs to pre_step and post_step should not be changed.
"""

# A list of keys which must appear in Variable.data for any Variable that will be updated by objects of this class.
Expand All @@ -167,17 +164,15 @@ def variable(self, variable):
self._variable = variable

def pre_step(self, dt):
"""Updates to variables performed during each step of MD simulation
before HIPNN model evaluation
"""Updates to variables performed during each step of MD simulation before HIPNN model evaluation
:param dt: timestep
:type dt: float
"""
pass

def post_step(self, dt, model_outputs):
"""Updates to variables performed during each step of MD simulation
after HIPNN model evaluation
"""Updates to variables performed during each step of MD simulation after HIPNN model evaluation
:param dt: timestep
:type dt: float
Expand All @@ -187,7 +182,7 @@ def post_step(self, dt, model_outputs):
pass


class NullUpdater(_VariableUpdater):
class NullUpdater(VariableUpdater):
"""
Makes no change to the variable data at each step of MD.
"""
Expand All @@ -198,7 +193,7 @@ def pre_step(self, dt):
def post_step(self, dt, model_outputs):
pass

class VelocityVerlet(_VariableUpdater):
class VelocityVerlet(VariableUpdater):
"""
Implements the Velocity Verlet algorithm
"""
Expand Down Expand Up @@ -228,8 +223,7 @@ def __init__(
self.force_factor = units_force / units_acc

def pre_step(self, dt):
"""Updates to variables performed during each step of MD simulation
before HIPNN model evaluation
"""Updates to variables performed during each step of MD simulation before HIPNN model evaluation
:param dt: timestep
:type dt: float
Expand All @@ -242,8 +236,7 @@ def pre_step(self, dt):
pass

def post_step(self, dt, model_outputs):
"""Updates to variables performed during each step of MD simulation
after HIPNN model evaluation
"""Updates to variables performed during each step of MD simulation after HIPNN model evaluation
:param dt: timestep
:type dt: float
Expand All @@ -260,7 +253,7 @@ def post_step(self, dt, model_outputs):
self.variable.data["velocity"] = self.variable.data["velocity"] + 0.5 * dt * self.variable.data["acceleration"]


class LangevinDynamics(_VariableUpdater):
class LangevinDynamics(VariableUpdater):
"""
Implements the Langevin algorithm
"""
Expand Down Expand Up @@ -306,8 +299,7 @@ def __init__(
torch.manual_seed(seed)

def pre_step(self, dt):
"""Updates to variables performed during each step of MD simulation
before HIPNN model evaluation
"""Updates to variables performed during each step of MD simulation before HIPNN model evaluation
:param dt: timestep
:type dt: float
Expand All @@ -321,8 +313,7 @@ def pre_step(self, dt):
pass

def post_step(self, dt, model_outputs):
"""Updates to variables performed during each step of MD simulation
after HIPNN model evaluation
"""Updates to variables performed during each step of MD simulation after HIPNN model evaluation
:param dt: timestep
:type dt: float
Expand Down Expand Up @@ -386,7 +377,7 @@ def variables(self, variables):
variables = [variables]
for variable in variables:
if variable.updater is None:
raise ValueError(f"Variable with name {variable.name} does not have a _VariableUpdater set.")
raise ValueError(f"Variable with name {variable.name} does not have a VariableUpdater set.")

variable_names = [variable.name for variable in variables]
if len(variable_names) != len(set(variable_names)):
Expand Down

0 comments on commit 00e0094

Please sign in to comment.