Skip to content

Commit

Permalink
track unwrapped and wrapped positions in MD code when cell is present…
Browse files Browse the repository at this point in the history
…, fix typo
  • Loading branch information
shinkle-lanl committed Sep 3, 2024
1 parent 85a91a6 commit 72ce274
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 7 deletions.
2 changes: 1 addition & 1 deletion hippynn/experiment/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def load_checkpoint(
:param structure_fname: name of the structure file
:param state_fname: name of the state file
:param restart_db: restore database or not, defaults to True
:param restart_db: restore database or not, defaults to False
:param map_location: device mapping argument for ``torch.load``, defaults to None
:param model_device: automatically handle device mapping. Defaults to None, defaults to None
:return: experiment structure
Expand Down
14 changes: 8 additions & 6 deletions hippynn/molecular_dynamics/md.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations
from functools import singledispatchmethod
from copy import copy

import numpy as np
import torch
Expand All @@ -9,7 +10,6 @@
from ..graphs import Predictor
from ..layers.pairs.periodic import wrap_systems_torch


class Variable:
"""
Tracks the state of a quantity (eg. position, cell, species,
Expand Down Expand Up @@ -307,11 +307,12 @@ def pre_step(self, dt):

self.variable.data["position"] = self.variable.data["position"] + self.variable.data["velocity"] * dt

try:
_, self.variable.data["position"], *_ = wrap_systems_torch(coords=self.variable.data["position"], cell=self.variable.data["cell"], cutoff=0) # cutoff only used for discarded outputs; can be set arbitrarily
except KeyError:
pass

if "cell" in self.variable.data.keys():
_, self.variable.data["position"], *_ = wrap_systems_torch(coords=self.variable.data["position"], cell=self.variable.data["cell"], cutoff=0) # cutoff only impacts unused outputs; can be set arbitrarily
try:
self.variable.data["unwrapped_position"] = self.variable.data["unwrapped_position"] + self.variable.data["velocity"] * dt
except KeyError:
self.variable.data["unwrapped_position"] = copy(self.variable.data["position"])
def post_step(self, dt, model_outputs):
"""Updates to variables performed during each step of MD simulation after HIPNN model evaluation
Expand Down Expand Up @@ -406,6 +407,7 @@ def model(self, model):
+ f" Entries in the 'model_input_map' should have the form 'hipnn-db_name: variable-data-key' where 'hipnn-db_name'"
+ f" refers to the db_name of an input for the hippynn Predictor model,"
+ f" and 'variable-data-key' corresponds to a key in the 'data' dictionary of one of the Variables."
+ f" Currently assigned db_names are: {variable_data_db_names}."
)
self._model = model

Expand Down

0 comments on commit 72ce274

Please sign in to comment.