From ebc222160dd209ce17de17b38b900fbc3e3f36a6 Mon Sep 17 00:00:00 2001 From: Emily Shinkle Date: Thu, 17 Oct 2024 10:38:09 -0600 Subject: [PATCH] fix bug with batched MD saving --- hippynn/molecular_dynamics/__init__.py | 4 ++-- hippynn/molecular_dynamics/md.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/hippynn/molecular_dynamics/__init__.py b/hippynn/molecular_dynamics/__init__.py index 0807e466..a7c6e4d0 100644 --- a/hippynn/molecular_dynamics/__init__.py +++ b/hippynn/molecular_dynamics/__init__.py @@ -5,7 +5,7 @@ """ -from .md import MolecularDynamics, Variable, NullUpdater, VelocityVerlet, LangevinDynamics +from .md import MolecularDynamics, Variable, NullUpdater, VelocityVerlet, LangevinDynamics, VariableUpdater -__all__ = ["MolecularDynamics", "Variable", "NullUpdater", "VelocityVerlet", "LangevinDynamics"] +__all__ = ["MolecularDynamics", "Variable", "VariableUpdater", "NullUpdater", "VelocityVerlet", "LangevinDynamics"] diff --git a/hippynn/molecular_dynamics/md.py b/hippynn/molecular_dynamics/md.py index 58bca979..ebb18a77 100644 --- a/hippynn/molecular_dynamics/md.py +++ b/hippynn/molecular_dynamics/md.py @@ -450,14 +450,14 @@ def _update_data(self, model_outputs: dict): for variable in self.variables: for key, value in variable.data.items(): try: - self._data[f"{variable.name}_{key}"].append(value.cpu().detach()[0]) + self._data[f"{variable.name}_{key}"].append(value.cpu().detach()) except KeyError: - self._data[f"{variable.name}_{key}"] = [value.cpu().detach()[0]] + self._data[f"{variable.name}_{key}"] = [value.cpu().detach()] for key, value in model_outputs.items(): try: - self._data[f"output_{key}"].append(value.cpu().detach()[0]) + self._data[f"output_{key}"].append(value.cpu().detach()) except KeyError: - self._data[f"output_{key}"] = [value.cpu().detach()[0]] + self._data[f"output_{key}"] = [value.cpu().detach()] def run(self, dt: float, n_steps: int, record_every: Optional[int] = None): """Run `n_steps` of MD algorithm.