From 72ce27427c7ae4ac4ed852796dc4fc7f70992b5b Mon Sep 17 00:00:00 2001 From: Emily Shinkle Date: Tue, 3 Sep 2024 16:01:46 -0600 Subject: [PATCH] track unwrapped and wrapped positions in MD code when cell is present, fix typo --- hippynn/experiment/serialization.py | 2 +- hippynn/molecular_dynamics/md.py | 14 ++++++++------ 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/hippynn/experiment/serialization.py b/hippynn/experiment/serialization.py index c4d73c1a..190fee39 100644 --- a/hippynn/experiment/serialization.py +++ b/hippynn/experiment/serialization.py @@ -138,7 +138,7 @@ def load_checkpoint( :param structure_fname: name of the structure file :param state_fname: name of the state file - :param restart_db: restore database or not, defaults to True + :param restart_db: restore database or not, defaults to False :param map_location: device mapping argument for ``torch.load``, defaults to None :param model_device: automatically handle device mapping. Defaults to None, defaults to None :return: experiment structure diff --git a/hippynn/molecular_dynamics/md.py b/hippynn/molecular_dynamics/md.py index b16cc552..8752990c 100644 --- a/hippynn/molecular_dynamics/md.py +++ b/hippynn/molecular_dynamics/md.py @@ -1,5 +1,6 @@ from __future__ import annotations from functools import singledispatchmethod +from copy import copy import numpy as np import torch @@ -9,7 +10,6 @@ from ..graphs import Predictor from ..layers.pairs.periodic import wrap_systems_torch - class Variable: """ Tracks the state of a quantity (eg. position, cell, species, @@ -307,11 +307,12 @@ def pre_step(self, dt): self.variable.data["position"] = self.variable.data["position"] + self.variable.data["velocity"] * dt - try: - _, self.variable.data["position"], *_ = wrap_systems_torch(coords=self.variable.data["position"], cell=self.variable.data["cell"], cutoff=0) # cutoff only used for discarded outputs; can be set arbitrarily - except KeyError: - pass - + if "cell" in self.variable.data.keys(): + _, self.variable.data["position"], *_ = wrap_systems_torch(coords=self.variable.data["position"], cell=self.variable.data["cell"], cutoff=0) # cutoff only impacts unused outputs; can be set arbitrarily + try: + self.variable.data["unwrapped_position"] = self.variable.data["unwrapped_position"] + self.variable.data["velocity"] * dt + except KeyError: + self.variable.data["unwrapped_position"] = copy(self.variable.data["position"]) def post_step(self, dt, model_outputs): """Updates to variables performed during each step of MD simulation after HIPNN model evaluation @@ -406,6 +407,7 @@ def model(self, model): + f" Entries in the 'model_input_map' should have the form 'hipnn-db_name: variable-data-key' where 'hipnn-db_name'" + f" refers to the db_name of an input for the hippynn Predictor model," + f" and 'variable-data-key' corresponds to a key in the 'data' dictionary of one of the Variables." + + f" Currently assigned db_names are: {variable_data_db_names}." ) self._model = model