From 0805ce257d196c88b75f37b59d627f15f4a3bad0 Mon Sep 17 00:00:00 2001 From: Emily Shinkle Date: Mon, 16 Sep 2024 13:19:23 -0600 Subject: [PATCH] remove fix to delete large files --- hippynn/molecular_dynamics/md.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/hippynn/molecular_dynamics/md.py b/hippynn/molecular_dynamics/md.py index b16cc552..4509a428 100644 --- a/hippynn/molecular_dynamics/md.py +++ b/hippynn/molecular_dynamics/md.py @@ -1,5 +1,6 @@ from __future__ import annotations from functools import singledispatchmethod +from copy import copy import numpy as np import torch @@ -9,7 +10,6 @@ from ..graphs import Predictor from ..layers.pairs.periodic import wrap_systems_torch - class Variable: """ Tracks the state of a quantity (eg. position, cell, species, @@ -307,10 +307,12 @@ def pre_step(self, dt): self.variable.data["position"] = self.variable.data["position"] + self.variable.data["velocity"] * dt - try: - _, self.variable.data["position"], *_ = wrap_systems_torch(coords=self.variable.data["position"], cell=self.variable.data["cell"], cutoff=0) # cutoff only used for discarded outputs; can be set arbitrarily - except KeyError: - pass + if "cell" in self.variable.data.keys(): + _, self.variable.data["position"], *_ = wrap_systems_torch(coords=self.variable.data["position"], cell=self.variable.data["cell"], cutoff=0) # cutoff only impacts unused outputs; can be set arbitrarily + try: + self.variable.data["unwrapped_position"] = self.variable.data["unwrapped_position"] + self.variable.data["velocity"] * dt + except KeyError: + self.variable.data["unwrapped_position"] = copy(self.variable.data["position"]) def post_step(self, dt, model_outputs): """Updates to variables performed during each step of MD simulation after HIPNN model evaluation @@ -406,6 +408,7 @@ def model(self, model): + f" Entries in the 'model_input_map' should have the form 'hipnn-db_name: variable-data-key' where 'hipnn-db_name'" + f" refers to the db_name of an input for the hippynn Predictor model," + f" and 'variable-data-key' corresponds to a key in the 'data' dictionary of one of the Variables." + + f" Currently assigned db_names are: {variable_data_db_names}." ) self._model = model @@ -482,6 +485,9 @@ def _update_data(self, model_outputs: dict): except KeyError: self._data[f"output_{key}"] = [value.cpu().detach()[0]] + for key, value in variable._data.items(): + print(variable.name, key, value.shape) + def run(self, dt: float, n_steps: int, record_every: int = None): """Run `n_steps` of MD algorithm.