Skip to content

Commit

Permalink
Adopt new message protocol
Browse files Browse the repository at this point in the history
  • Loading branch information
unkcpz committed Dec 20, 2024
1 parent 17f5d62 commit 582787e
Show file tree
Hide file tree
Showing 6 changed files with 58 additions and 62 deletions.
20 changes: 6 additions & 14 deletions src/plumpy/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,7 @@
from .utils import PID_TYPE

__all__ = [
'KILL_MSG',
'PAUSE_MSG',
'PLAY_MSG',
'STATUS_MSG',
'MessageBuilder',
'ProcessLauncher',
'create_continue_body',
'create_launch_body',
Expand All @@ -27,7 +24,7 @@
from .processes import Process

INTENT_KEY = 'intent'
MESSAGE_KEY = 'message'
MESSAGE_TEXT_KEY = 'message'
FORCE_KILL_KEY = 'force_kill'


Expand All @@ -40,11 +37,6 @@ class Intent:
STATUS: str = 'status'


PAUSE_MSG = {INTENT_KEY: Intent.PAUSE}
PLAY_MSG = {INTENT_KEY: Intent.PLAY}
KILL_MSG = {INTENT_KEY: Intent.KILL}
STATUS_MSG = {INTENT_KEY: Intent.STATUS}

TASK_KEY = 'task'
TASK_ARGS = 'args'
PERSIST_KEY = 'persist'
Expand Down Expand Up @@ -74,23 +66,23 @@ def play(cls, text: str | None = None) -> MessageType:
"""The play message send over coordinator."""
return {
INTENT_KEY: Intent.PLAY,
MESSAGE_KEY: text,
MESSAGE_TEXT_KEY: text,
}

@classmethod
def pause(cls, text: str | None = None) -> MessageType:
"""The pause message send over coordinator."""
return {
INTENT_KEY: Intent.PAUSE,
MESSAGE_KEY: text,
MESSAGE_TEXT_KEY: text,
}

@classmethod
def kill(cls, text: str | None = None, force_kill: bool = False) -> MessageType:
"""The kill message send over coordinator."""
return {
INTENT_KEY: Intent.KILL,
MESSAGE_KEY: text,
MESSAGE_TEXT_KEY: text,
FORCE_KILL_KEY: force_kill,
}

Expand All @@ -99,7 +91,7 @@ def status(cls, text: str | None = None) -> MessageType:
"""The status message send over coordinator."""
return {
INTENT_KEY: Intent.STATUS,
MESSAGE_KEY: text,
MESSAGE_TEXT_KEY: text,
}


Expand Down
11 changes: 7 additions & 4 deletions src/plumpy/process_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,16 +52,19 @@ class Interruption(Exception): # noqa: N818


class KillInterruption(Interruption):
def __init__(self, msg: MessageType | None):
def __init__(self, msg_text: str | None):
super().__init__()
if msg is None:
msg = MessageBuilder.kill()
msg = MessageBuilder.kill(msg_text)

self.msg: MessageType = msg


class PauseInterruption(Interruption):
pass
def __init__(self, msg_text: str | None):
super().__init__()
msg = MessageBuilder.pause(text=msg_text)

self.msg: MessageType = msg


# region Commands
Expand Down
45 changes: 26 additions & 19 deletions src/plumpy/processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
from .base.utils import call_with_super_check, super_check
from .event_helper import EventHelper
from .futures import CancellableAction, capture_exceptions
from .message import MESSAGE_KEY, MessageBuilder, MessageType
from .message import MESSAGE_TEXT_KEY, MessageBuilder, MessageType
from .process_listener import ProcessListener
from .process_spec import ProcessSpec
from .utils import PID_TYPE, SAVED_STATE_TYPE, protected
Expand Down Expand Up @@ -339,8 +339,7 @@ def init(self) -> None:

def try_killing(future: asyncio.Future) -> None:
if future.cancelled():
msg = MessageBuilder.kill(text='Killed by future being cancelled')
if not self.kill(msg):
if not self.kill('Killed by future being cancelled'):
self.logger.warning(
'Process<%s>: Failed to kill process on future cancel',
self.pid,
Expand Down Expand Up @@ -901,7 +900,7 @@ def on_kill(self, msg: Optional[MessageType]) -> None:
if msg is None:
msg_txt = ''
else:
msg_txt = msg[MESSAGE_KEY] or ''
msg_txt = msg[MESSAGE_TEXT_KEY] or ''

self.set_status(msg_txt)
self.future().set_exception(exceptions.KilledError(msg_txt))
Expand Down Expand Up @@ -942,7 +941,7 @@ def _fire_event(self, evt: Callable[..., Any], *args: Any, **kwargs: Any) -> Non

# region Communication

def message_receive(self, _comm: Coordinator, msg: Dict[str, Any]) -> Any:
def message_receive(self, _comm: Coordinator, msg: MessageType) -> Any:
"""
Coroutine called when the process receives a message from the communicator
Expand All @@ -962,9 +961,9 @@ def message_receive(self, _comm: Coordinator, msg: Dict[str, Any]) -> Any:
if intent == message.Intent.PLAY:
return self._schedule_rpc(self.play)
if intent == message.Intent.PAUSE:
return self._schedule_rpc(self.pause, msg=msg.get(message.MESSAGE_KEY, None))
return self._schedule_rpc(self.pause, msg_text=msg.get(MESSAGE_TEXT_KEY, None))
if intent == message.Intent.KILL:
return self._schedule_rpc(self.kill, msg=msg.get(message.MESSAGE_KEY, None))
return self._schedule_rpc(self.kill, msg_text=msg.get(MESSAGE_TEXT_KEY, None))
if intent == message.Intent.STATUS:
status_info: Dict[str, Any] = {}
self.get_status_info(status_info)
Expand All @@ -974,7 +973,7 @@ def message_receive(self, _comm: Coordinator, msg: Dict[str, Any]) -> Any:
raise RuntimeError('Unknown intent')

def broadcast_receive(
self, _comm: Coordinator, body: Any, sender: Any, subject: Any, correlation_id: Any
self, _comm: Coordinator, msg: MessageType, sender: Any, subject: Any, correlation_id: Any
) -> Optional[concurrent.futures.Future]:
"""
Coroutine called when the process receives a message from the communicator
Expand All @@ -988,16 +987,16 @@ def broadcast_receive(
self.pid,
subject,
_comm,
body,
msg,
)
# If we get a message we recognise then action it, otherwise ignore
fn = None
if subject == message.Intent.PLAY:
fn = self._schedule_rpc(self.play)
elif subject == message.Intent.PAUSE:
fn = self._schedule_rpc(self.pause, msg=body)
fn = self._schedule_rpc(self.pause, msg_text=msg.get(MESSAGE_TEXT_KEY, None))
elif subject == message.Intent.KILL:
fn = self._schedule_rpc(self.kill, msg=body)
fn = self._schedule_rpc(self.kill, msg_text=msg.get(MESSAGE_TEXT_KEY, None))

if fn is None:
self.logger.warning(
Expand Down Expand Up @@ -1078,7 +1077,7 @@ def transition_failed(
)
self.transition_to(new_state)

def pause(self, msg: Union[str, None] = None) -> Union[bool, CancellableAction]:
def pause(self, msg_text: str | None = None) -> Union[bool, CancellableAction]:
"""Pause the process.
:param msg: an optional message to set as the status. The current status will be saved in the private
Expand All @@ -1102,22 +1101,29 @@ def pause(self, msg: Union[str, None] = None) -> Union[bool, CancellableAction]:
if self._stepping:
# Ask the step function to pause by setting this flag and giving the
# caller back a future
interrupt_exception = process_states.PauseInterruption(msg)
interrupt_exception = process_states.PauseInterruption(msg_text)
self._set_interrupt_action_from_exception(interrupt_exception)
self._pausing = self._interrupt_action
# Try to interrupt the state
self._state.interrupt(interrupt_exception)
return cast(CancellableAction, self._interrupt_action)

msg = MessageBuilder.pause(msg_text)
return self._do_pause(msg)

def _do_pause(self, state_msg: Optional[str], next_state: Optional[process_states.State] = None) -> bool:
def _do_pause(self, state_msg: Optional[MessageType], next_state: Optional[process_states.State] = None) -> bool:
"""Carry out the pause procedure, optionally transitioning to the next state first"""
try:
if next_state is not None:
self.transition_to(next_state)
call_with_super_check(self.on_pausing, state_msg)
call_with_super_check(self.on_paused, state_msg)

if state_msg is None:
msg_text = ''
else:
msg_text = state_msg[MESSAGE_TEXT_KEY]

call_with_super_check(self.on_pausing, msg_text)
call_with_super_check(self.on_paused, msg_text)
finally:
self._pausing = None

Expand All @@ -1132,7 +1138,7 @@ def _create_interrupt_action(self, exception: process_states.Interruption) -> Ca
"""
if isinstance(exception, process_states.PauseInterruption):
do_pause = functools.partial(self._do_pause, str(exception))
do_pause = functools.partial(self._do_pause, exception.msg)
return CancellableAction(do_pause, cookie=exception)

if isinstance(exception, process_states.KillInterruption):
Expand Down Expand Up @@ -1197,7 +1203,7 @@ def fail(self, exception: Optional[BaseException], trace_back: Optional[Tracebac
)
self.transition_to(new_state)

def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future]:
def kill(self, msg_text: Optional[str] = None) -> Union[bool, asyncio.Future]:
"""
Kill the process
:param msg: An optional kill message
Expand All @@ -1217,12 +1223,13 @@ def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future]
if self._stepping:
# Ask the step function to pause by setting this flag and giving the
# caller back a future
interrupt_exception = process_states.KillInterruption(msg)
interrupt_exception = process_states.KillInterruption(msg_text)
self._set_interrupt_action_from_exception(interrupt_exception)
self._killing = self._interrupt_action
self._state.interrupt(interrupt_exception)
return cast(CancellableAction, self._interrupt_action)

msg = MessageBuilder.kill(msg_text)
new_state = self._create_state_instance(process_states.ProcessState.KILLED, msg=msg)
self.transition_to(new_state)
return True
Expand Down
32 changes: 14 additions & 18 deletions src/plumpy/rmq/process_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ async def get_status(self, pid: 'PID_TYPE') -> 'ProcessStatus':
result = await asyncio.wrap_future(future)
return result

async def pause_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> 'ProcessResult':
async def pause_process(self, pid: 'PID_TYPE', msg: Optional[str] = None) -> 'ProcessResult':
"""
Pause the process
Expand Down Expand Up @@ -77,16 +77,15 @@ async def play_process(self, pid: 'PID_TYPE') -> 'ProcessResult':
result = await asyncio.wrap_future(future)
return result

async def kill_process(self, pid: 'PID_TYPE', msg: Optional[MessageType] = None) -> 'ProcessResult':
async def kill_process(self, pid: 'PID_TYPE', msg_text: Optional[str] = None) -> 'ProcessResult':
"""
Kill the process
:param pid: the pid of the process to kill
:param msg: optional kill message
:return: True if killed, False otherwise
"""
if msg is None:
msg = MessageBuilder.kill()
msg = MessageBuilder.kill(msg_text)

# Wait for the communication to go through
kill_future = self._coordinator.rpc_send(pid, msg)
Expand Down Expand Up @@ -212,7 +211,7 @@ def get_status(self, pid: 'PID_TYPE') -> kiwipy.Future:
"""
return self._coordinator.rpc_send(pid, MessageBuilder.status())

def pause_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> kiwipy.Future:
def pause_process(self, pid: 'PID_TYPE', msg_text: Optional[str] = None) -> kiwipy.Future:
"""
Pause the process
Expand All @@ -221,16 +220,18 @@ def pause_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> kiwipy.Fu
:return: a response future from the process to be paused
"""
msg = MessageBuilder.pause(text=msg)
msg = MessageBuilder.pause(text=msg_text)

return self._coordinator.rpc_send(pid, msg)

def pause_all(self, msg: Any) -> None:
def pause_all(self, msg_text: Optional[str]) -> None:
"""
Pause all processes that are subscribed to the same coordinator
:param msg: an optional pause message
"""
msg = MessageBuilder.pause(text=msg_text)

self._coordinator.broadcast_send(msg, subject=Intent.PAUSE)

def play_process(self, pid: 'PID_TYPE') -> kiwipy.Future:
Expand All @@ -249,7 +250,7 @@ def play_all(self) -> None:
"""
self._coordinator.broadcast_send(None, subject=Intent.PLAY)

def kill_process(self, pid: 'PID_TYPE', msg: Optional[MessageType] = None) -> kiwipy.Future:
def kill_process(self, pid: 'PID_TYPE', msg_text: Optional[str] = None) -> kiwipy.Future:
"""
Kill the process
Expand All @@ -258,31 +259,26 @@ def kill_process(self, pid: 'PID_TYPE', msg: Optional[MessageType] = None) -> ki
:return: a response future from the process to be killed
"""
if msg is None:
msg = MessageBuilder.kill()
msg = MessageBuilder.kill(text=msg_text)

return self._coordinator.rpc_send(pid, msg)

def kill_all(self, msg: Optional[MessageType]) -> None:
def kill_all(self, msg_text: Optional[str]) -> None:
"""
Kill all processes that are subscribed to the same coordinator
:param msg: an optional pause message
"""
if msg is None:
msg = MessageBuilder.kill()
msg = MessageBuilder.kill(msg_text)

self._coordinator.broadcast_send(msg, subject=Intent.KILL)

def notify_all(self, msg: MessageType | None, sender: Hashable | None = None, subject: str | None = None) -> None:
def notify_msg(self, msg: MessageType, sender: Hashable | None = None, subject: str | None = None) -> None:
"""
Notify all processes by broadcasting
Notify all processes by broadcasting of a msg
:param msg: an optional pause message
"""
if msg is None:
msg = MessageBuilder.kill()

self._coordinator.broadcast_send(msg, sender=sender, subject=subject)

def continue_process(
Expand Down
4 changes: 1 addition & 3 deletions tests/rmq/test_process_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,9 +195,7 @@ async def test_kill_all(self, _coordinator, sync_controller):
for _ in range(10):
procs.append(utils.WaitForSignalProcess(coordinator=_coordinator))

msg = process_control.MessageBuilder.kill(text='bang bang, I shot you down')

sync_controller.kill_all(msg)
sync_controller.kill_all(msg_text='bang bang, I shot you down')
await utils.wait_util(lambda: all([proc.killed() for proc in procs]))
assert all([proc.state == plumpy.ProcessState.KILLED for proc in procs])

Expand Down
8 changes: 4 additions & 4 deletions tests/test_processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import plumpy
from plumpy import BundleKeys, Process, ProcessState
from plumpy.message import MessageBuilder
from plumpy.message import MESSAGE_TEXT_KEY, MessageBuilder
from plumpy.utils import AttributesFrozendict
from . import utils

Expand Down Expand Up @@ -322,10 +322,10 @@ def run(self, **kwargs):
def test_kill(self):
proc: Process = utils.DummyProcess()

msg = MessageBuilder.kill(text='Farewell!')
proc.kill(msg)
msg_text = 'Farewell!'
proc.kill(msg_text=msg_text)
self.assertTrue(proc.killed())
self.assertEqual(proc.killed_msg(), msg)
self.assertEqual(proc.killed_msg()[MESSAGE_TEXT_KEY], msg_text)
self.assertEqual(proc.state, ProcessState.KILLED)

def test_wait_continue(self):
Expand Down

0 comments on commit 582787e

Please sign in to comment.