diff --git a/cvpack/reporting/custom_writer.py b/cvpack/reporting/custom_writer.py index 48859e34..caea4c8d 100644 --- a/cvpack/reporting/custom_writer.py +++ b/cvpack/reporting/custom_writer.py @@ -18,11 +18,22 @@ class CustomWriter(t.Protocol): An abstract class for StateDataReporter writers """ + def initialize(self, context: mm.Context) -> None: + """ + Initializes the writer. This method is called before the first report and + can be used to perform any necessary setup. + + Parameters + ---------- + context + The context object. + """ + def getHeaders(self) -> t.List[str]: """ Gets a list of strigs containing the headers to be added to the report. """ - raise NotImplementedError("Method getHeaders not implemented") + raise NotImplementedError("Method 'getHeaders' not implemented") def getValues(self, context: mm.Context) -> t.List[float]: """ @@ -35,4 +46,4 @@ def getValues(self, context: mm.Context) -> t.List[float]: state The state object. """ - raise NotImplementedError("Method getValues not implemented") + raise NotImplementedError("Method 'getValues' not implemented") diff --git a/cvpack/reporting/state_data_reporter.py b/cvpack/reporting/state_data_reporter.py index 1923638d..e3cdb5b4 100644 --- a/cvpack/reporting/state_data_reporter.py +++ b/cvpack/reporting/state_data_reporter.py @@ -24,24 +24,32 @@ class StateDataReporter(mmapp.StateDataReporter): .. _openmm.app.StateDataReporter: http://docs.openmm.org/latest/api-python/ generated/openmm.app.statedatareporter.StateDataReporter.html - A custom writer is an object that includes the methods two particular methods. - The first one is ``getHeaders``, which returns a list of strings containing the - headers to be added to the report. It has the following signature: + A custom writer is an object that includes the following methods: + + 1. **getHeaders**: returns a list of strings containing the headers to be added + to the report. It must have the following signature: .. code-block:: def getHeaders(self) -> List[str]: pass - The second method is ``getValues``, which accepts an :OpenMM:`Context` as - argument and returns a list of floats containing the values to be added to the - report. It has the following signature: + 2. **getValues**: returns a list of floats containing the values to be added to + the report at a given time step. It must have the following signature: .. code-block:: def getValues(self, context: openmm.Context) -> List[float]: pass + 3. **initialize** (optional): performs any necessary setup before the first report. + If present, it must have the following signature: + + .. code-block:: + + def initialize(self, context: openmm.Context) -> None: + pass + Parameters ---------- file @@ -133,6 +141,12 @@ def _expand(self, sequence: list, addition: t.Iterable) -> list: pos = len(sequence) - self._back_steps return sum(addition, sequence[:pos]) + sequence[pos:] + def _initializeConstants(self, simulation: mmapp.Simulation) -> None: + super()._initializeConstants(simulation) + for writer in self._writers: + if hasattr(writer, "initialize"): + writer.initialize(simulation.context) + def _constructHeaders(self) -> t.List[str]: return self._expand( super()._constructHeaders(),