Skip to content

Commit

Permalink
Allows for custom writer initialization
Browse files Browse the repository at this point in the history
  • Loading branch information
craabreu committed Apr 10, 2024
1 parent 1ed650e commit 166f6f1
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 8 deletions.
15 changes: 13 additions & 2 deletions cvpack/reporting/custom_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,22 @@ class CustomWriter(t.Protocol):
An abstract class for StateDataReporter writers
"""

def initialize(self, context: mm.Context) -> None:
"""
Initializes the writer. This method is called before the first report and
can be used to perform any necessary setup.
Parameters
----------
context
The context object.
"""

def getHeaders(self) -> t.List[str]:
"""
Gets a list of strigs containing the headers to be added to the report.
"""
raise NotImplementedError("Method getHeaders not implemented")
raise NotImplementedError("Method 'getHeaders' not implemented")

def getValues(self, context: mm.Context) -> t.List[float]:
"""
Expand All @@ -35,4 +46,4 @@ def getValues(self, context: mm.Context) -> t.List[float]:
state
The state object.
"""
raise NotImplementedError("Method getValues not implemented")
raise NotImplementedError("Method 'getValues' not implemented")
26 changes: 20 additions & 6 deletions cvpack/reporting/state_data_reporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,24 +24,32 @@ class StateDataReporter(mmapp.StateDataReporter):
.. _openmm.app.StateDataReporter: http://docs.openmm.org/latest/api-python/
generated/openmm.app.statedatareporter.StateDataReporter.html
A custom writer is an object that includes the methods two particular methods.
The first one is ``getHeaders``, which returns a list of strings containing the
headers to be added to the report. It has the following signature:
A custom writer is an object that includes the following methods:
1. **getHeaders**: returns a list of strings containing the headers to be added
to the report. It must have the following signature:
.. code-block::
def getHeaders(self) -> List[str]:
pass
The second method is ``getValues``, which accepts an :OpenMM:`Context` as
argument and returns a list of floats containing the values to be added to the
report. It has the following signature:
2. **getValues**: returns a list of floats containing the values to be added to
the report at a given time step. It must have the following signature:
.. code-block::
def getValues(self, context: openmm.Context) -> List[float]:
pass
3. **initialize** (optional): performs any necessary setup before the first report.
If present, it must have the following signature:
.. code-block::
def initialize(self, context: openmm.Context) -> None:
pass
Parameters
----------
file
Expand Down Expand Up @@ -133,6 +141,12 @@ def _expand(self, sequence: list, addition: t.Iterable) -> list:
pos = len(sequence) - self._back_steps
return sum(addition, sequence[:pos]) + sequence[pos:]

def _initializeConstants(self, simulation: mmapp.Simulation) -> None:
super()._initializeConstants(simulation)
for writer in self._writers:
if hasattr(writer, "initialize"):
writer.initialize(simulation.context)

def _constructHeaders(self) -> t.List[str]:
return self._expand(
super()._constructHeaders(),
Expand Down

0 comments on commit 166f6f1

Please sign in to comment.