Skip to content

Commit

Permalink
plumpy.ProcessListener made persistent
Browse files Browse the repository at this point in the history
solves aiidateam#273

We implement the persistence of ProcessListener by deriving the classes
ProcessListener and EventHelper from persistence.Savable.
The class EventHelper is moved to a new file because of a circular
import between utils and persistence.

Fixing the test

There was a circular reference issue in the test listener, which was
storing a reference to the process inside it, making its serialization
impossible. To fix the tests an ugly hack was used: storing the
reference to the process outside the class in a global dict using id as
keys. Some more ugly hacks are needed to correctly check the equality of
two processes: we must ignore the fact that the instances of the
listener are different.

We call del on the items of the global dicts used by the ProcessListener implemented in the test suite
to clean up the global variables.

addressed issues in aiidateam#274
  • Loading branch information
rikigigi committed Nov 10, 2023
1 parent 44d27d1 commit a830eef
Show file tree
Hide file tree
Showing 7 changed files with 166 additions and 63 deletions.
2 changes: 2 additions & 0 deletions src/plumpy/base/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ def wrapper(self: Any, *args: Any, **kwargs: Any) -> None:
wrapped(self, *args, **kwargs)
self._called -= 1

#the following is to show the correct name later in the call_with_super_check error message
wrapper.__name__ = wrapped.__name__
return wrapper


Expand Down
52 changes: 52 additions & 0 deletions src/plumpy/event_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# -*- coding: utf-8 -*-
import logging
from typing import TYPE_CHECKING, Any, Callable, Set, Type

from . import persistence

if TYPE_CHECKING:
from .process_listener import ProcessListener # pylint: disable=cyclic-import

_LOGGER = logging.getLogger(__name__)


@persistence.auto_persist('_listeners', '_listener_type')
class EventHelper(persistence.Savable):
    """Hold a set of listeners and dispatch event calls to each of them.

    The ``_listeners`` and ``_listener_type`` attributes are registered with
    ``persistence.auto_persist``.
    """

    def __init__(self, listener_type: 'Type[ProcessListener]'):
        """
        :param listener_type: the class every registered listener must be an instance of
        """
        assert listener_type is not None, 'Must provide valid listener type'

        self._listeners: 'Set[ProcessListener]' = set()
        self._listener_type = listener_type

    def add_listener(self, listener: 'ProcessListener') -> None:
        """Register a listener; it must be an instance of the configured listener type.

        :param listener: the listener to add
        """
        assert isinstance(listener, self._listener_type), 'Listener is not of right type'
        self._listeners.add(listener)

    def remove_listener(self, listener: 'ProcessListener') -> None:
        """Unregister a listener; a listener that is not registered is ignored.

        :param listener: the listener to remove
        """
        self._listeners.discard(listener)

    def remove_all_listeners(self) -> None:
        """Unregister every listener."""
        self._listeners.clear()

    @property
    def listeners(self) -> 'Set[ProcessListener]':
        """The internal set of registered listeners (not a copy)."""
        return self._listeners

    def fire_event(self, event_function: Callable[..., Any], *args: Any, **kwargs: Any) -> None:
        """Call an event method on all listeners.

        An exception raised while notifying one listener is logged and does not
        prevent the remaining listeners from being notified.

        :param event_function: the method of the ProcessListener
        :param args: arguments to pass to the method
        :param kwargs: keyword arguments to pass to the method
        :raises ValueError: if ``event_function`` is None
        """
        if event_function is None:
            raise ValueError('Must provide valid event method')

        # Iterate over a snapshot so that a callback may add or remove listeners safely
        for target in tuple(self.listeners):
            try:
                # Look the method up by name on the listener instance itself
                handler = getattr(target, event_function.__name__)
                handler(*args, **kwargs)
            except Exception as exc:  # pylint: disable=broad-except
                _LOGGER.error("Listener '%s' produced an exception:\n%s", target, exc)
26 changes: 24 additions & 2 deletions src/plumpy/process_listener.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,36 @@
# -*- coding: utf-8 -*-
import abc
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Dict, Optional

from . import persistence
from .utils import SAVED_STATE_TYPE, protected

__all__ = ['ProcessListener']

if TYPE_CHECKING:
from .processes import Process # pylint: disable=cyclic-import


class ProcessListener(metaclass=abc.ABCMeta):
@persistence.auto_persist('_params')
class ProcessListener(persistence.Savable, metaclass=abc.ABCMeta):

# region Persistence methods

def __init__(self) -> None:
super().__init__()
self._params: Dict[str, Any] = {}

def init(self, **kwargs: Any) -> None:
self._params = kwargs

@protected
def load_instance_state(
self, saved_state: SAVED_STATE_TYPE, load_context: Optional[persistence.LoadSaveContext]
) -> None:
super().load_instance_state(saved_state, load_context)
self.init(**saved_state['_params'])

# endregion

def on_process_created(self, process: 'Process') -> None:
"""
Expand Down
17 changes: 10 additions & 7 deletions src/plumpy/processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from .base import state_machine
from .base.state_machine import StateEntryFailed, StateMachine, TransitionFailed, event
from .base.utils import call_with_super_check, super_check
from .event_helper import EventHelper
from .process_listener import ProcessListener
from .process_spec import ProcessSpec
from .utils import PID_TYPE, SAVED_STATE_TYPE, protected
Expand Down Expand Up @@ -91,7 +92,9 @@ def func_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
return func_wrapper


@persistence.auto_persist('_pid', '_creation_time', '_future', '_paused', '_status', '_pre_paused_status')
@persistence.auto_persist(
'_pid', '_creation_time', '_future', '_paused', '_status', '_pre_paused_status', '_event_helper'
)
class Process(StateMachine, persistence.Savable, metaclass=ProcessStateMachineMeta):
"""
The Process class is the base for any unit of work in plumpy.
Expand Down Expand Up @@ -289,7 +292,7 @@ def __init__(

# Runtime variables
self._future = persistence.SavableFuture(loop=self._loop)
self.__event_helper = utils.EventHelper(ProcessListener)
self._event_helper = EventHelper(ProcessListener)
self._logger = logger
self._communicator = communicator

Expand Down Expand Up @@ -612,7 +615,7 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi

# Runtime variables, set initial states
self._future = persistence.SavableFuture()
self.__event_helper = utils.EventHelper(ProcessListener)
self._event_helper = EventHelper(ProcessListener)
self._logger = None
self._communicator = None

Expand Down Expand Up @@ -661,11 +664,11 @@ def add_process_listener(self, listener: ProcessListener) -> None:
"""
assert (listener != self), 'Cannot listen to yourself!' # type: ignore
self.__event_helper.add_listener(listener)
self._event_helper.add_listener(listener)

def remove_process_listener(self, listener: ProcessListener) -> None:
"""Remove a process listener from the process."""
self.__event_helper.remove_listener(listener)
self._event_helper.remove_listener(listener)

@protected
def set_logger(self, logger: logging.Logger) -> None:
Expand Down Expand Up @@ -778,7 +781,7 @@ def on_output_emitting(self, output_port: str, value: Any) -> None:
"""Output is about to be emitted."""

def on_output_emitted(self, output_port: str, value: Any, dynamic: bool) -> None:
self.__event_helper.fire_event(ProcessListener.on_output_emitted, self, output_port, value, dynamic)
self._event_helper.fire_event(ProcessListener.on_output_emitted, self, output_port, value, dynamic)

@super_check
def on_wait(self, awaitables: Sequence[Awaitable]) -> None:
Expand Down Expand Up @@ -891,7 +894,7 @@ def on_close(self) -> None:
self._closed = True

def _fire_event(self, evt: Callable[..., Any], *args: Any, **kwargs: Any) -> None:
self.__event_helper.fire_event(evt, self, *args, **kwargs)
self._event_helper.fire_event(evt, self, *args, **kwargs)

# endregion

Expand Down
41 changes: 0 additions & 41 deletions src/plumpy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,47 +27,6 @@
PID_TYPE = Hashable # pylint: disable=invalid-name


class EventHelper:
    """Hold a set of listeners and dispatch event calls to each of them."""

    def __init__(self, listener_type: 'Type[ProcessListener]'):
        """
        :param listener_type: the class every registered listener must be an instance of
        """
        assert listener_type is not None, 'Must provide valid listener type'

        self._listeners: 'Set[ProcessListener]' = set()
        self._listener_type = listener_type

    def add_listener(self, listener: 'ProcessListener') -> None:
        """Register a listener; it must be an instance of the configured listener type.

        :param listener: the listener to add
        """
        assert isinstance(listener, self._listener_type), 'Listener is not of right type'
        self._listeners.add(listener)

    def remove_listener(self, listener: 'ProcessListener') -> None:
        """Unregister a listener; a listener that is not registered is ignored.

        :param listener: the listener to remove
        """
        self._listeners.discard(listener)

    def remove_all_listeners(self) -> None:
        """Unregister every listener."""
        self._listeners.clear()

    @property
    def listeners(self) -> 'Set[ProcessListener]':
        """The internal set of registered listeners (not a copy)."""
        return self._listeners

    def fire_event(self, event_function: Callable[..., Any], *args: Any, **kwargs: Any) -> None:
        """Call an event method on all listeners.

        An exception raised while notifying one listener is logged and does not
        prevent the remaining listeners from being notified.

        :param event_function: the method of the ProcessListener
        :param args: arguments to pass to the method
        :param kwargs: keyword arguments to pass to the method
        :raises ValueError: if ``event_function`` is None
        """
        if event_function is None:
            raise ValueError('Must provide valid event method')

        # Iterate over a snapshot so that a callback may add or remove listeners safely
        for target in tuple(self.listeners):
            try:
                # Look the method up by name on the listener instance itself
                handler = getattr(target, event_function.__name__)
                handler(*args, **kwargs)
            except Exception as exc:  # pylint: disable=broad-except
                _LOGGER.error("Listener '%s' produced an exception:\n%s", target, exc)


class Frozendict(Mapping):
"""
An immutable wrapper around dictionaries that implements the complete :py:class:`collections.abc.Mapping`
Expand Down
7 changes: 5 additions & 2 deletions test/test_processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -800,7 +800,9 @@ def test_instance_state_with_outputs(self):
# Check that it is a copy
self.assertIsNot(outputs, bundle.get(BundleKeys.OUTPUTS, {}))
# Check the contents are the same
self.assertDictEqual(outputs, bundle.get(BundleKeys.OUTPUTS, {}))
#we remove the ProcessSaver instance that is an object used only for testing
utils.compare_dictionaries(None, None, outputs, bundle.get(BundleKeys.OUTPUTS, {}), exclude={'_listeners'})
#self.assertDictEqual(outputs, bundle.get(BundleKeys.OUTPUTS, {}))

self.assertIsNot(proc.outputs, saver.snapshots[-1].get(BundleKeys.OUTPUTS, {}))

Expand Down Expand Up @@ -875,7 +877,8 @@ def _check_round_trip(self, proc1):
bundle2 = plumpy.Bundle(proc2)

self.assertEqual(proc1.pid, proc2.pid)
self.assertDictEqual(bundle1, bundle2)
#self.assertDictEqual(bundle1, bundle2)
utils.compare_dictionaries(None, None, bundle1, bundle2, exclude={'_listeners'})


class TestProcessNamespace(unittest.TestCase):
Expand Down
84 changes: 73 additions & 11 deletions test/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def run(self):
self.out('test', 5)
return process_states.Continue(self.middle_step)

def middle_step(self,):
def middle_step(self):
return process_states.Continue(self.last_step)

def last_step(self):
Expand Down Expand Up @@ -260,25 +260,72 @@ def _save(self, p):
self.outputs.append(p.outputs.copy())


class ProcessSaver(plumpy.ProcessListener, Saver):
_ProcessSaverProcReferences = {}
_ProcessSaver_Saver = {}


class ProcessSaver(plumpy.ProcessListener):
"""
Save the instance state of a process each time it is about to enter a new state
Save the instance state of a process each time it is about to enter a new state.
NB: this is not a general purpose saver, it is only intended to be used for testing
The listener instances inside a process are persisted, so if we store a process
reference in the ProcessSaver instance, we will have a circular reference that cannot be
persisted. So we store the Saver instance in a global dictionary with the key the id of the
ProcessSaver instance.
In the init_not_persistent method we initialize the instances that cannot be persisted,
like the saver instance. The __del__ method is used to clean up the global dictionaries
(note there is no guarantee that __del__ will be called)
"""

def __del__(self):
global _ProcessSaver_Saver
global _ProcessSaverProcReferences
if _ProcessSaverProcReferences is not None and id(self) in _ProcessSaverProcReferences:
del _ProcessSaverProcReferences[id(self)]
if _ProcessSaver_Saver is not None and id(self) in _ProcessSaver_Saver:
del _ProcessSaver_Saver[id(self)]

def get_process(self):
global _ProcessSaverProcReferences
return _ProcessSaverProcReferences[id(self)]

def _save(self, p):
global _ProcessSaver_Saver
_ProcessSaver_Saver[id(self)]._save(p)

def set_process(self, process):
global _ProcessSaverProcReferences
_ProcessSaverProcReferences[id(self)] = process

def __init__(self, proc):
plumpy.ProcessListener.__init__(self)
Saver.__init__(self)
self.process = proc
super().__init__()
proc.add_process_listener(self)
self.init_not_persistent(proc)

def init_not_persistent(self, proc):
global _ProcessSaver_Saver
_ProcessSaver_Saver[id(self)] = Saver()
self.set_process(proc)

def capture(self):
self._save(self.process)
if not self.process.has_terminated():
self._save(self.get_process())
if not self.get_process().has_terminated():
try:
self.process.execute()
self.get_process().execute()
except Exception:
pass

@property
def snapshots(self):
global _ProcessSaver_Saver
return _ProcessSaver_Saver[id(self)].snapshots

@property
def outputs(self):
global _ProcessSaver_Saver
return _ProcessSaver_Saver[id(self)].outputs

@utils.override
def on_process_running(self, process):
self._save(process)
Expand Down Expand Up @@ -335,7 +382,13 @@ def check_process_against_snapshots(loop, proc_class, snapshots):
"""
for i, bundle in zip(list(range(0, len(snapshots))), snapshots):
loaded = bundle.unbundle(plumpy.LoadSaveContext(loop=loop))
saver = ProcessSaver(loaded)
# the process listeners are persisted
saver = list(loaded._event_helper._listeners)[0]
assert isinstance(saver, ProcessSaver)
# the process reference inside this particular implementation of process listener
# cannot be persisted because of a circular reference. So we load it there
# also the saver is not persisted for the same reason. We load it manually
saver.init_not_persistent(loaded)
saver.capture()

# Now check going backwards until running that the saved states match
Expand All @@ -345,7 +398,11 @@ def check_process_against_snapshots(loop, proc_class, snapshots):
break

compare_dictionaries(
snapshots[-j], saver.snapshots[-j], snapshots[-j], saver.snapshots[-j], exclude={'exception'}
snapshots[-j],
saver.snapshots[-j],
snapshots[-j],
saver.snapshots[-j],
exclude={'exception', '_listeners'}
)
j += 1

Expand Down Expand Up @@ -376,6 +433,11 @@ def compare_value(bundle1, bundle2, v1, v2, exclude=None):
elif isinstance(v1, list) and isinstance(v2, list):
for vv1, vv2 in zip(v1, v2):
compare_value(bundle1, bundle2, vv1, vv2, exclude)
elif isinstance(v1, set) and isinstance(v2, set) and len(v1) == len(v2) and len(v1) <= 1:
# TODO: implement sets with more than one element
compare_value(bundle1, bundle2, list(v1), list(v2), exclude)
elif isinstance(v1, set) and isinstance(v2, set):
raise NotImplementedError('Comparison between sets not implemented')
else:
if v1 != v2:
raise ValueError(f'Dict values mismatch for :\n{v1} != {v2}')
Expand Down

0 comments on commit a830eef

Please sign in to comment.