diff --git a/src/plumpy/message.py b/src/plumpy/message.py index b2c66d50..009f1b26 100644 --- a/src/plumpy/message.py +++ b/src/plumpy/message.py @@ -14,10 +14,7 @@ from .utils import PID_TYPE __all__ = [ - 'KILL_MSG', - 'PAUSE_MSG', - 'PLAY_MSG', - 'STATUS_MSG', + 'MessageBuilder', 'ProcessLauncher', 'create_continue_body', 'create_launch_body', @@ -27,7 +24,7 @@ from .processes import Process INTENT_KEY = 'intent' -MESSAGE_KEY = 'message' +MESSAGE_TEXT_KEY = 'message' FORCE_KILL_KEY = 'force_kill' @@ -40,11 +37,6 @@ class Intent: STATUS: str = 'status' -PAUSE_MSG = {INTENT_KEY: Intent.PAUSE} -PLAY_MSG = {INTENT_KEY: Intent.PLAY} -KILL_MSG = {INTENT_KEY: Intent.KILL} -STATUS_MSG = {INTENT_KEY: Intent.STATUS} - TASK_KEY = 'task' TASK_ARGS = 'args' PERSIST_KEY = 'persist' @@ -74,7 +66,7 @@ def play(cls, text: str | None = None) -> MessageType: """The play message send over coordinator.""" return { INTENT_KEY: Intent.PLAY, - MESSAGE_KEY: text, + MESSAGE_TEXT_KEY: text, } @classmethod @@ -82,7 +74,7 @@ def pause(cls, text: str | None = None) -> MessageType: """The pause message send over coordinator.""" return { INTENT_KEY: Intent.PAUSE, - MESSAGE_KEY: text, + MESSAGE_TEXT_KEY: text, } @classmethod @@ -90,7 +82,7 @@ def kill(cls, text: str | None = None, force_kill: bool = False) -> MessageType: """The kill message send over coordinator.""" return { INTENT_KEY: Intent.KILL, - MESSAGE_KEY: text, + MESSAGE_TEXT_KEY: text, FORCE_KILL_KEY: force_kill, } @@ -99,7 +91,7 @@ def status(cls, text: str | None = None) -> MessageType: """The status message send over coordinator.""" return { INTENT_KEY: Intent.STATUS, - MESSAGE_KEY: text, + MESSAGE_TEXT_KEY: text, } diff --git a/src/plumpy/process_states.py b/src/plumpy/process_states.py index 48f73f77..7b5952ab 100644 --- a/src/plumpy/process_states.py +++ b/src/plumpy/process_states.py @@ -52,16 +52,19 @@ class Interruption(Exception): # noqa: N818 class KillInterruption(Interruption): - def __init__(self, msg: MessageType | None): + def __init__(self, msg_text: str | None): super().__init__() - if msg is None: - msg = MessageBuilder.kill() + msg = MessageBuilder.kill(msg_text) self.msg: MessageType = msg class PauseInterruption(Interruption): - pass + def __init__(self, msg_text: str | None): + super().__init__() + msg = MessageBuilder.pause(text=msg_text) + + self.msg: MessageType = msg # region Commands diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index 52cf875b..feeb3c4e 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -49,7 +49,7 @@ from .base.utils import call_with_super_check, super_check from .event_helper import EventHelper from .futures import CancellableAction, capture_exceptions -from .message import MESSAGE_KEY, MessageBuilder, MessageType +from .message import MESSAGE_TEXT_KEY, MessageBuilder, MessageType from .process_listener import ProcessListener from .process_spec import ProcessSpec from .utils import PID_TYPE, SAVED_STATE_TYPE, protected @@ -339,8 +339,7 @@ def init(self) -> None: def try_killing(future: asyncio.Future) -> None: if future.cancelled(): - msg = MessageBuilder.kill(text='Killed by future being cancelled') - if not self.kill(msg): + if not self.kill('Killed by future being cancelled'): self.logger.warning( 'Process<%s>: Failed to kill process on future cancel', self.pid, @@ -901,7 +900,7 @@ def on_kill(self, msg: Optional[MessageType]) -> None: if msg is None: msg_txt = '' else: - msg_txt = msg[MESSAGE_KEY] or '' + msg_txt = msg[MESSAGE_TEXT_KEY] or '' self.set_status(msg_txt) self.future().set_exception(exceptions.KilledError(msg_txt)) @@ -942,7 +941,7 @@ def _fire_event(self, evt: Callable[..., Any], *args: Any, **kwargs: Any) -> Non # region Communication - def message_receive(self, _comm: Coordinator, msg: Dict[str, Any]) -> Any: + def message_receive(self, _comm: Coordinator, msg: MessageType) -> Any: """ Coroutine called when the process receives a message from the communicator @@ -962,9 +961,9 @@ def message_receive(self, _comm: Coordinator, msg: Dict[str, Any]) -> Any: if intent == message.Intent.PLAY: return self._schedule_rpc(self.play) if intent == message.Intent.PAUSE: - return self._schedule_rpc(self.pause, msg=msg.get(message.MESSAGE_KEY, None)) + return self._schedule_rpc(self.pause, msg_text=msg.get(MESSAGE_TEXT_KEY, None)) if intent == message.Intent.KILL: - return self._schedule_rpc(self.kill, msg=msg.get(message.MESSAGE_KEY, None)) + return self._schedule_rpc(self.kill, msg_text=msg.get(MESSAGE_TEXT_KEY, None)) if intent == message.Intent.STATUS: status_info: Dict[str, Any] = {} self.get_status_info(status_info) @@ -974,7 +973,7 @@ def message_receive(self, _comm: Coordinator, msg: Dict[str, Any]) -> Any: raise RuntimeError('Unknown intent') def broadcast_receive( - self, _comm: Coordinator, body: Any, sender: Any, subject: Any, correlation_id: Any + self, _comm: Coordinator, msg: MessageType, sender: Any, subject: Any, correlation_id: Any ) -> Optional[concurrent.futures.Future]: """ Coroutine called when the process receives a message from the communicator @@ -988,16 +987,16 @@ def broadcast_receive( self.pid, subject, _comm, - body, + msg, ) # If we get a message we recognise then action it, otherwise ignore fn = None if subject == message.Intent.PLAY: fn = self._schedule_rpc(self.play) elif subject == message.Intent.PAUSE: - fn = self._schedule_rpc(self.pause, msg=body) + fn = self._schedule_rpc(self.pause, msg_text=msg.get(MESSAGE_TEXT_KEY, None)) elif subject == message.Intent.KILL: - fn = self._schedule_rpc(self.kill, msg=body) + fn = self._schedule_rpc(self.kill, msg_text=msg.get(MESSAGE_TEXT_KEY, None)) if fn is None: self.logger.warning( @@ -1078,7 +1077,7 @@ def transition_failed( ) self.transition_to(new_state) - def pause(self, msg: Union[str, None] = None) -> Union[bool, CancellableAction]: + def pause(self, msg_text: str | None = None) -> Union[bool, CancellableAction]: """Pause the process. :param msg: an optional message to set as the status. The current status will be saved in the private @@ -1102,22 +1101,29 @@ def pause(self, msg: Union[str, None] = None) -> Union[bool, CancellableAction]: if self._stepping: # Ask the step function to pause by setting this flag and giving the # caller back a future - interrupt_exception = process_states.PauseInterruption(msg) + interrupt_exception = process_states.PauseInterruption(msg_text) self._set_interrupt_action_from_exception(interrupt_exception) self._pausing = self._interrupt_action # Try to interrupt the state self._state.interrupt(interrupt_exception) return cast(CancellableAction, self._interrupt_action) + msg = MessageBuilder.pause(msg_text) return self._do_pause(msg) - def _do_pause(self, state_msg: Optional[str], next_state: Optional[process_states.State] = None) -> bool: + def _do_pause(self, state_msg: Optional[MessageType], next_state: Optional[process_states.State] = None) -> bool: """Carry out the pause procedure, optionally transitioning to the next state first""" try: if next_state is not None: self.transition_to(next_state) - call_with_super_check(self.on_pausing, state_msg) - call_with_super_check(self.on_paused, state_msg) + + if state_msg is None: + msg_text = '' + else: + msg_text = state_msg[MESSAGE_TEXT_KEY] + + call_with_super_check(self.on_pausing, msg_text) + call_with_super_check(self.on_paused, msg_text) finally: self._pausing = None @@ -1132,7 +1138,7 @@ def _create_interrupt_action(self, exception: process_states.Interruption) -> Ca """ if isinstance(exception, process_states.PauseInterruption): - do_pause = functools.partial(self._do_pause, str(exception)) + do_pause = functools.partial(self._do_pause, exception.msg) return CancellableAction(do_pause, cookie=exception) if isinstance(exception, process_states.KillInterruption): @@ -1197,7 +1203,7 @@ def fail(self, exception: Optional[BaseException], trace_back: Optional[Tracebac ) self.transition_to(new_state) - def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future]: + def kill(self, msg_text: Optional[str] = None) -> Union[bool, asyncio.Future]: """ Kill the process :param msg: An optional kill message @@ -1217,12 +1223,13 @@ def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future] if self._stepping: # Ask the step function to pause by setting this flag and giving the # caller back a future - interrupt_exception = process_states.KillInterruption(msg) + interrupt_exception = process_states.KillInterruption(msg_text) self._set_interrupt_action_from_exception(interrupt_exception) self._killing = self._interrupt_action self._state.interrupt(interrupt_exception) return cast(CancellableAction, self._interrupt_action) + msg = MessageBuilder.kill(msg_text) new_state = self._create_state_instance(process_states.ProcessState.KILLED, msg=msg) self.transition_to(new_state) return True diff --git a/src/plumpy/rmq/process_control.py b/src/plumpy/rmq/process_control.py index fe040aa5..74595dec 100644 --- a/src/plumpy/rmq/process_control.py +++ b/src/plumpy/rmq/process_control.py @@ -48,7 +48,7 @@ async def get_status(self, pid: 'PID_TYPE') -> 'ProcessStatus': result = await asyncio.wrap_future(future) return result - async def pause_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> 'ProcessResult': + async def pause_process(self, pid: 'PID_TYPE', msg: Optional[str] = None) -> 'ProcessResult': """ Pause the process @@ -77,7 +77,7 @@ async def play_process(self, pid: 'PID_TYPE') -> 'ProcessResult': result = await asyncio.wrap_future(future) return result - async def kill_process(self, pid: 'PID_TYPE', msg: Optional[MessageType] = None) -> 'ProcessResult': + async def kill_process(self, pid: 'PID_TYPE', msg_text: Optional[str] = None) -> 'ProcessResult': """ Kill the process @@ -85,8 +85,7 @@ async def kill_process(self, pid: 'PID_TYPE', msg: Optional[MessageType] = None) :param msg: optional kill message :return: True if killed, False otherwise """ - if msg is None: - msg = MessageBuilder.kill() + msg = MessageBuilder.kill(msg_text) # Wait for the communication to go through kill_future = self._coordinator.rpc_send(pid, msg) @@ -212,7 +211,7 @@ def get_status(self, pid: 'PID_TYPE') -> kiwipy.Future: """ return self._coordinator.rpc_send(pid, MessageBuilder.status()) - def pause_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> kiwipy.Future: + def pause_process(self, pid: 'PID_TYPE', msg_text: Optional[str] = None) -> kiwipy.Future: """ Pause the process @@ -221,16 +220,18 @@ def pause_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> kiwipy.Fu :return: a response future from the process to be paused """ - msg = MessageBuilder.pause(text=msg) + msg = MessageBuilder.pause(text=msg_text) return self._coordinator.rpc_send(pid, msg) - def pause_all(self, msg: Any) -> None: + def pause_all(self, msg_text: Optional[str]) -> None: """ Pause all processes that are subscribed to the same coordinator :param msg: an optional pause message """ + msg = MessageBuilder.pause(text=msg_text) + self._coordinator.broadcast_send(msg, subject=Intent.PAUSE) def play_process(self, pid: 'PID_TYPE') -> kiwipy.Future: @@ -249,7 +250,7 @@ def play_all(self) -> None: """ self._coordinator.broadcast_send(None, subject=Intent.PLAY) - def kill_process(self, pid: 'PID_TYPE', msg: Optional[MessageType] = None) -> kiwipy.Future: + def kill_process(self, pid: 'PID_TYPE', msg_text: Optional[str] = None) -> kiwipy.Future: """ Kill the process @@ -258,31 +259,26 @@ def kill_process(self, pid: 'PID_TYPE', msg: Optional[MessageType] = None) -> ki :return: a response future from the process to be killed """ - if msg is None: - msg = MessageBuilder.kill() + msg = MessageBuilder.kill(text=msg_text) return self._coordinator.rpc_send(pid, msg) - def kill_all(self, msg: Optional[MessageType]) -> None: + def kill_all(self, msg_text: Optional[str]) -> None: """ Kill all processes that are subscribed to the same coordinator :param msg: an optional pause message """ - if msg is None: - msg = MessageBuilder.kill() + msg = MessageBuilder.kill(msg_text) self._coordinator.broadcast_send(msg, subject=Intent.KILL) - def notify_all(self, msg: MessageType | None, sender: Hashable | None = None, subject: str | None = None) -> None: + def notify_msg(self, msg: MessageType, sender: Hashable | None = None, subject: str | None = None) -> None: """ - Notify all processes by broadcasting + Notify all processes by broadcasting of a msg :param msg: an optional pause message """ - if msg is None: - msg = MessageBuilder.kill() - self._coordinator.broadcast_send(msg, sender=sender, subject=subject) def continue_process( diff --git a/tests/rmq/test_process_control.py b/tests/rmq/test_process_control.py index 9e917b37..672995c8 100644 --- a/tests/rmq/test_process_control.py +++ b/tests/rmq/test_process_control.py @@ -195,9 +195,7 @@ async def test_kill_all(self, _coordinator, sync_controller): for _ in range(10): procs.append(utils.WaitForSignalProcess(coordinator=_coordinator)) - msg = process_control.MessageBuilder.kill(text='bang bang, I shot you down') - - sync_controller.kill_all(msg) + sync_controller.kill_all(msg_text='bang bang, I shot you down') await utils.wait_util(lambda: all([proc.killed() for proc in procs])) assert all([proc.state == plumpy.ProcessState.KILLED for proc in procs]) diff --git a/tests/test_processes.py b/tests/test_processes.py index 52a456dc..a05d09a3 100644 --- a/tests/test_processes.py +++ b/tests/test_processes.py @@ -10,7 +10,7 @@ import plumpy from plumpy import BundleKeys, Process, ProcessState -from plumpy.message import MessageBuilder +from plumpy.message import MESSAGE_TEXT_KEY, MessageBuilder from plumpy.utils import AttributesFrozendict from . import utils @@ -322,10 +322,10 @@ def run(self, **kwargs): def test_kill(self): proc: Process = utils.DummyProcess() - msg = MessageBuilder.kill(text='Farewell!') - proc.kill(msg) + msg_text = 'Farewell!' + proc.kill(msg_text=msg_text) self.assertTrue(proc.killed()) - self.assertEqual(proc.killed_msg(), msg) + self.assertEqual(proc.killed_msg()[MESSAGE_TEXT_KEY], msg_text) self.assertEqual(proc.state, ProcessState.KILLED) def test_wait_continue(self):