Skip to content

Commit

Permalink
update progress bar in molecular_dynamics/md.py for consistency, fix …
Browse files Browse the repository at this point in the history
…spacing, fix bug in examples/molecular_dynamics.py
  • Loading branch information
shinkle-lanl committed Jun 24, 2024
1 parent a020912 commit b0dc259
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 9 deletions.
4 changes: 2 additions & 2 deletions examples/molecular_dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,8 @@ def update(self, diff_steps, data):
def print(self, diff_steps=None, data=None):
time_per_atom_step = self.update(diff_steps, data)
"""Function to print the potential, kinetic and total energy"""
atoms.set_positions(data["position_position"][-1])
atoms.set_velocities(data["position_velocity"][-1])
atoms.set_positions(np.array(data["position_position"][-1]))
atoms.set_velocities(np.array(data["position_velocity"][-1]))
print(
"Performance:",
round(1e6 * time_per_atom_step, 1),
Expand Down
3 changes: 1 addition & 2 deletions hippynn/layers/physics.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,8 +205,7 @@ def forward(self, pair_dist, radius):

class LocalDampingCosine(AlphaScreening):
""" Local damping using complement of the hipnn cutoff function. ('glue-on' method)
g = 1 if pair_dist > R_cutoff
1 - [cos(pi/2 * dist * R_cutoff)]^2 otherwise
g = 1 if pair_dist > R_cutoff, 1 - [cos(pi/2 * dist * R_cutoff)]^2 otherwise
"""
def __init__(self, alpha):
"""
Expand Down
8 changes: 3 additions & 5 deletions hippynn/molecular_dynamics/md.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@

import numpy as np
import torch

from tqdm.autonotebook import trange
import ase

from ..tools import progress_bar
from ..graphs import Predictor
from ..layers.pairs.periodic import wrap_systems_torch

Expand Down Expand Up @@ -359,8 +358,7 @@ def __init__(
device: torch.device = None,
dtype: torch.dtype = None,
):
"""_summary_
"""
:param variables: list of Variable objects which will be tracked during simulation
:type variables: list[Variable]
:param model: HIPNN Predictor
Expand Down Expand Up @@ -506,7 +504,7 @@ def run(self, dt: float, n_steps: int, record_every: int = None):
:type record_every: int, optional
"""

for i in trange(n_steps):
for i in progress_bar(range(n_steps)):
model_outputs = self._step(dt)
if record_every is not None and (i + 1) % record_every == 0:
self._update_data(model_outputs)
Expand Down

0 comments on commit b0dc259

Please sign in to comment.