Skip to content

Commit

Permalink
Correct outputs for describeNextReport and add testing.
Browse files Browse the repository at this point in the history
  • Loading branch information
orionarcher committed Apr 4, 2024
1 parent 6d7887d commit acae341
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 4 deletions.
26 changes: 23 additions & 3 deletions mdareporter/mdareporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,8 @@
USE OR OTHER DEALINGS IN THE SOFTWARE.
"""



import MDAnalysis as mda
from MDAnalysis.lib.util import get_ext
from MDAnalysis.lib.mdamath import triclinic_box
from openmm import unit
import numpy as np
Expand Down Expand Up @@ -102,7 +101,28 @@ def describeNextReport(self, simulation):
positions should be wrapped to lie in a single periodic box.
"""
steps = self._reportInterval - simulation.currentStep%self._reportInterval
return (steps, True, False, False, False, self._enforcePeriodicBox)

try:
# this will be called again inside mda.Writer but we need the ext here
root, ext = get_ext(self._filename)
except (TypeError, AttributeError):
errmsg = f'File format could not be guessed from "{self._filename}"'
raise ValueError(errmsg) from None

if ext in ["trr"]:
positions, velocities, forces = True, True, True
elif ext in ["ncdf", "nc"]:
positions = True
velocities = self._writer_kwargs.get("positions", False)
forces = self._writer_kwargs.get("forces", False)
elif ext in ["h5md"]:
positions = self._writer_kwargs.get("positions", True)
velocities = self._writer_kwargs.get("velocities", True)
forces = self._writer_kwargs.get("forces", True)
else:
positions, velocities, forces = True, False, False

return steps, positions, velocities, forces, False, self._enforcePeriodicBox

def report(self, simulation, state):
"""Generate a report.
Expand Down
20 changes: 19 additions & 1 deletion mdareporter/tests/test_mdareporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,4 +165,22 @@ def test_mdareporter_boxvectors(file_ext, simulation):

print("max error in positions is ",np.max(max_errors)," Angstrom")



@pytest.mark.parametrize("file_ext, expected_returns", [
("dcd", (1, True, False, False)),
("ncdf", (1, True, False, False)),
("trr", (1, True, True, True)),
("nc", (1, True, False, False)),
("h5md", (1, True, True, True)),
])
def test_describenextreport(file_ext, expected_returns, simulation):
report_interval = 1
reporter = MDAReporter(f"output.{file_ext}", report_interval, enforcePeriodicBox=True)

simulation.currentStep = 0
steps, positions, velocities, forces, _, _ = reporter.describeNextReport(simulation)

assert steps == report_interval
assert positions == expected_returns[1]
assert velocities == expected_returns[2]
assert forces == expected_returns[3]

0 comments on commit acae341

Please sign in to comment.