From 41e7833033a4c4e29c5a6b2bbef95c4f5c14d193 Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Tue, 17 Dec 2024 12:17:10 +0100 Subject: [PATCH] Interface change from communicator -> coordinator --- src/plumpy/coordinator.py | 13 +++++++++++++ src/plumpy/processes.py | 14 +++++++------- src/plumpy/workchains.py | 6 +++--- tests/rmq/test_process_comms.py | 24 ++++++++++++------------ tests/test_processes.py | 2 +- 5 files changed, 36 insertions(+), 23 deletions(-) diff --git a/src/plumpy/coordinator.py b/src/plumpy/coordinator.py index 1daaf1f8..cd66a883 100644 --- a/src/plumpy/coordinator.py +++ b/src/plumpy/coordinator.py @@ -17,3 +17,16 @@ def remove_rpc_subscriber(self, identifier): ... def remove_broadcast_subscriber(self, identifier): ... def broadcast_send(self, body, sender=None, subject=None, correlation_id=None) -> bool: ... + +class Coordinator(Protocol): + def add_rpc_subscriber(self, subscriber: RpcSubscriber, identifier=None) -> Any: ... + + def add_broadcast_subscriber( + self, subscriber: BroadcastSubscriber, subject_filter: str | Pattern[str] | None = None, identifier=None + ) -> Any: ... + + def remove_rpc_subscriber(self, identifier): ... + + def remove_broadcast_subscriber(self, identifier): ... + + def broadcast_send(self, body, sender=None, subject=None, correlation_id=None) -> bool: ... diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index 9cf3302b..47506b3f 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -32,7 +32,7 @@ cast, ) -from plumpy.coordinator import Communicator +from plumpy.coordinator import Coordinator try: from aiocontextvars import ContextVar @@ -266,7 +266,7 @@ def __init__( pid: Optional[PID_TYPE] = None, logger: Optional[logging.Logger] = None, loop: Optional[asyncio.AbstractEventLoop] = None, - communicator: Optional[Communicator] = None, + coordinator: Optional[Coordinator] = None, ) -> None: """ The signature of the constructor should not be changed by subclassing processes. @@ -305,7 +305,7 @@ def __init__( self._future = persistence.SavableFuture(loop=self._loop) self._event_helper = EventHelper(ProcessListener) self._logger = logger - self._communicator = communicator + self._communicator = coordinator @super_check def init(self) -> None: @@ -449,7 +449,7 @@ def launch( pid=pid, logger=logger, loop=self.loop, - communicator=self._communicator, + coordinator=self._communicator, ) self.loop.create_task(process.step_until_terminated()) return process @@ -654,7 +654,7 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi self._state: process_states.State = self.recreate_state(saved_state['_state']) - if 'communicator' in load_context: + if 'coordinator' in load_context: self._communicator = load_context.communicator if 'logger' in load_context: @@ -938,7 +938,7 @@ def _fire_event(self, evt: Callable[..., Any], *args: Any, **kwargs: Any) -> Non # region Communication - def message_receive(self, _comm: Communicator, msg: Dict[str, Any]) -> Any: + def message_receive(self, _comm: Coordinator, msg: Dict[str, Any]) -> Any: """ Coroutine called when the process receives a message from the communicator @@ -970,7 +970,7 @@ def message_receive(self, _comm: Communicator, msg: Dict[str, Any]) -> Any: raise RuntimeError('Unknown intent') def broadcast_receive( - self, _comm: Communicator, body: Any, sender: Any, subject: Any, correlation_id: Any + self, _comm: Coordinator, body: Any, sender: Any, subject: Any, correlation_id: Any ) -> Optional[concurrent.futures.Future]: """ Coroutine called when the process receives a message from the communicator diff --git a/src/plumpy/workchains.py b/src/plumpy/workchains.py index 7e67253f..5df20bf4 100644 --- a/src/plumpy/workchains.py +++ b/src/plumpy/workchains.py @@ -23,7 +23,7 @@ cast, ) -from plumpy.coordinator import Communicator +from plumpy.coordinator import Coordinator from . import lang, mixins, persistence, process_states, processes from .utils import PID_TYPE, SAVED_STATE_TYPE @@ -128,9 +128,9 @@ def __init__( pid: Optional[PID_TYPE] = None, logger: Optional[logging.Logger] = None, loop: Optional[asyncio.AbstractEventLoop] = None, - communicator: Optional[Communicator] = None, + coordinator: Optional[Coordinator] = None, ) -> None: - super().__init__(inputs=inputs, pid=pid, logger=logger, loop=loop, communicator=communicator) + super().__init__(inputs=inputs, pid=pid, logger=logger, loop=loop, coordinator=coordinator) self._stepper: Optional[Stepper] = None self._awaitables: Dict[Union[asyncio.Future, processes.Process], str] = {} diff --git a/tests/rmq/test_process_comms.py b/tests/rmq/test_process_comms.py index 4af9a484..9de211ee 100644 --- a/tests/rmq/test_process_comms.py +++ b/tests/rmq/test_process_comms.py @@ -45,7 +45,7 @@ def sync_controller(thread_communicator: rmq.RmqThreadCommunicator): class TestRemoteProcessController: @pytest.mark.asyncio async def test_pause(self, thread_communicator, async_controller): - proc = utils.WaitForSignalProcess(communicator=thread_communicator) + proc = utils.WaitForSignalProcess(coordinator=thread_communicator) # Run the process in the background asyncio.ensure_future(proc.step_until_terminated()) # Send a pause message @@ -57,7 +57,7 @@ async def test_pause(self, thread_communicator, async_controller): @pytest.mark.asyncio async def test_play(self, thread_communicator, async_controller): - proc = utils.WaitForSignalProcess(communicator=thread_communicator) + proc = utils.WaitForSignalProcess(coordinator=thread_communicator) # Run the process in the background asyncio.ensure_future(proc.step_until_terminated()) assert proc.pause() @@ -75,7 +75,7 @@ async def test_play(self, thread_communicator, async_controller): @pytest.mark.asyncio async def test_kill(self, thread_communicator, async_controller): - proc = utils.WaitForSignalProcess(communicator=thread_communicator) + proc = utils.WaitForSignalProcess(coordinator=thread_communicator) # Run the process in the event loop asyncio.ensure_future(proc.step_until_terminated()) @@ -88,7 +88,7 @@ async def test_kill(self, thread_communicator, async_controller): @pytest.mark.asyncio async def test_status(self, thread_communicator, async_controller): - proc = utils.WaitForSignalProcess(communicator=thread_communicator) + proc = utils.WaitForSignalProcess(coordinator=thread_communicator) # Run the process in the background asyncio.ensure_future(proc.step_until_terminated()) @@ -108,7 +108,7 @@ def on_broadcast_receive(**msg): thread_communicator.add_broadcast_subscriber(on_broadcast_receive) - proc = utils.DummyProcess(communicator=thread_communicator) + proc = utils.DummyProcess(coordinator=thread_communicator) proc.execute() expected_subjects = [] @@ -123,7 +123,7 @@ def on_broadcast_receive(**msg): class TestRemoteProcessThreadController: @pytest.mark.asyncio async def test_pause(self, thread_communicator, sync_controller): - proc = utils.WaitForSignalProcess(communicator=thread_communicator) + proc = utils.WaitForSignalProcess(coordinator=thread_communicator) # Send a pause message pause_future = sync_controller.pause_process(proc.pid) @@ -140,7 +140,7 @@ async def test_pause_all(self, thread_communicator, sync_controller): """Test pausing all processes on a communicator""" procs = [] for _ in range(10): - procs.append(utils.WaitForSignalProcess(communicator=thread_communicator)) + procs.append(utils.WaitForSignalProcess(coordinator=thread_communicator)) sync_controller.pause_all("Slow yo' roll") # Wait until they are all paused @@ -151,7 +151,7 @@ async def test_play_all(self, thread_communicator, sync_controller): """Test pausing all processes on a communicator""" procs = [] for _ in range(10): - proc = utils.WaitForSignalProcess(communicator=thread_communicator) + proc = utils.WaitForSignalProcess(coordinator=thread_communicator) procs.append(proc) proc.pause('hold tight') @@ -162,7 +162,7 @@ async def test_play_all(self, thread_communicator, sync_controller): @pytest.mark.asyncio async def test_play(self, thread_communicator, sync_controller): - proc = utils.WaitForSignalProcess(communicator=thread_communicator) + proc = utils.WaitForSignalProcess(coordinator=thread_communicator) assert proc.pause() # Send a play message @@ -176,7 +176,7 @@ async def test_play(self, thread_communicator, sync_controller): @pytest.mark.asyncio async def test_kill(self, thread_communicator, sync_controller): - proc = utils.WaitForSignalProcess(communicator=thread_communicator) + proc = utils.WaitForSignalProcess(coordinator=thread_communicator) # Send a kill message kill_future = sync_controller.kill_process(proc.pid) @@ -193,7 +193,7 @@ async def test_kill_all(self, thread_communicator, sync_controller): """Test pausing all processes on a communicator""" procs = [] for _ in range(10): - procs.append(utils.WaitForSignalProcess(communicator=thread_communicator)) + procs.append(utils.WaitForSignalProcess(coordinator=thread_communicator)) msg = process_comms.MessageBuilder.kill(text='bang bang, I shot you down') @@ -203,7 +203,7 @@ async def test_kill_all(self, thread_communicator, sync_controller): @pytest.mark.asyncio async def test_status(self, thread_communicator, sync_controller): - proc = utils.WaitForSignalProcess(communicator=thread_communicator) + proc = utils.WaitForSignalProcess(coordinator=thread_communicator) # Run the process in the background asyncio.ensure_future(proc.step_until_terminated()) diff --git a/tests/test_processes.py b/tests/test_processes.py index 7b232689..927d42b3 100644 --- a/tests/test_processes.py +++ b/tests/test_processes.py @@ -1075,7 +1075,7 @@ def on_broadcast_receive(_comm, body, sender, subject, correlation_id): messages.append({'body': body, 'subject': subject, 'sender': sender, 'correlation_id': correlation_id}) communicator.add_broadcast_subscriber(on_broadcast_receive) - proc = utils.DummyProcess(communicator=communicator) + proc = utils.DummyProcess(coordinator=communicator) proc.execute() expected_subjects = []