Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Message interface is hidden inside function call #301

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 12 additions & 16 deletions src/plumpy/process_comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,15 +200,15 @@ async def get_status(self, pid: 'PID_TYPE') -> 'ProcessStatus':
result = await asyncio.wrap_future(future)
return result

async def pause_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> 'ProcessResult':
async def pause_process(self, pid: 'PID_TYPE', msg_text: Optional[str] = None) -> 'ProcessResult':
"""
Pause the process

:param pid: the pid of the process to pause
:param msg: optional pause message
:return: True if paused, False otherwise
"""
msg = MessageBuilder.pause(text=msg)
msg = MessageBuilder.pause(text=msg_text)

pause_future = self._communicator.rpc_send(pid, msg)
# rpc_send return a thread future from communicator
Expand All @@ -229,16 +229,15 @@ async def play_process(self, pid: 'PID_TYPE') -> 'ProcessResult':
result = await asyncio.wrap_future(future)
return result

async def kill_process(self, pid: 'PID_TYPE', msg: Optional[MessageType] = None) -> 'ProcessResult':
async def kill_process(self, pid: 'PID_TYPE', msg_text: Optional[str] = None) -> 'ProcessResult':
"""
Kill the process

:param pid: the pid of the process to kill
:param msg: optional kill message
:return: True if killed, False otherwise
"""
if msg is None:
msg = MessageBuilder.kill()
msg = MessageBuilder.kill(text=msg_text)

# Wait for the communication to go through
kill_future = self._communicator.rpc_send(pid, msg)
Expand Down Expand Up @@ -364,7 +363,7 @@ def get_status(self, pid: 'PID_TYPE') -> kiwipy.Future:
"""
return self._communicator.rpc_send(pid, MessageBuilder.status())

def pause_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> kiwipy.Future:
def pause_process(self, pid: 'PID_TYPE', msg_text: Optional[str] = None) -> kiwipy.Future:
"""
Pause the process

Expand All @@ -373,16 +372,17 @@ def pause_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> kiwipy.Fu
:return: a response future from the process to be paused

"""
msg = MessageBuilder.pause(text=msg)
msg = MessageBuilder.pause(text=msg_text)

return self._communicator.rpc_send(pid, msg)

def pause_all(self, msg: Any) -> None:
def pause_all(self, msg_text: Optional[str]) -> None:
"""
Pause all processes that are subscribed to the same communicator

:param msg: an optional pause message
"""
msg = MessageBuilder.pause(text=msg_text)
self._communicator.broadcast_send(msg, subject=Intent.PAUSE)

def play_process(self, pid: 'PID_TYPE') -> kiwipy.Future:
Expand All @@ -401,28 +401,24 @@ def play_all(self) -> None:
"""
self._communicator.broadcast_send(None, subject=Intent.PLAY)

def kill_process(self, pid: 'PID_TYPE', msg: Optional[MessageType] = None) -> kiwipy.Future:
def kill_process(self, pid: 'PID_TYPE', msg_text: Optional[str] = None, force_kill: bool = False) -> kiwipy.Future:
unkcpz marked this conversation as resolved.
Show resolved Hide resolved
"""
Kill the process

:param pid: the pid of the process to kill
:param msg: optional kill message
:return: a response future from the process to be killed

"""
if msg is None:
msg = MessageBuilder.kill()

msg = MessageBuilder.kill(text=msg_text, force_kill=force_kill)
return self._communicator.rpc_send(pid, msg)

def kill_all(self, msg: Optional[MessageType]) -> None:
def kill_all(self, msg_text: Optional[str]) -> None:
"""
Kill all processes that are subscribed to the same communicator

:param msg: an optional pause message
"""
if msg is None:
msg = MessageBuilder.kill()
msg = MessageBuilder.kill(msg_text)

self._communicator.broadcast_send(msg, subject=Intent.KILL)

Expand Down
11 changes: 7 additions & 4 deletions src/plumpy/process_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,16 +52,19 @@ class Interruption(Exception): # noqa: N818


class KillInterruption(Interruption):
def __init__(self, msg: MessageType | None):
def __init__(self, msg_text: str | None):
super().__init__()
if msg is None:
msg = MessageBuilder.kill()
msg = MessageBuilder.kill(text=msg_text)

self.msg: MessageType = msg


class PauseInterruption(Interruption):
pass
def __init__(self, msg_text: str | None):
super().__init__()
msg = MessageBuilder.pause(text=msg_text)

self.msg: MessageType = msg


# region Commands
Expand Down
43 changes: 25 additions & 18 deletions src/plumpy/processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,8 +344,7 @@ def init(self) -> None:

def try_killing(future: futures.Future) -> None:
if future.cancelled():
msg = MessageBuilder.kill(text='Killed by future being cancelled')
if not self.kill(msg):
if not self.kill('Killed by future being cancelled'):
self.logger.warning(
'Process<%s>: Failed to kill process on future cancel',
self.pid,
Expand Down Expand Up @@ -944,7 +943,7 @@ def _fire_event(self, evt: Callable[..., Any], *args: Any, **kwargs: Any) -> Non

# region Communication

def message_receive(self, _comm: kiwipy.Communicator, msg: Dict[str, Any]) -> Any:
def message_receive(self, _comm: kiwipy.Communicator, msg: MessageType) -> Any:
"""
Coroutine called when the process receives a message from the communicator

Expand All @@ -964,9 +963,9 @@ def message_receive(self, _comm: kiwipy.Communicator, msg: Dict[str, Any]) -> An
if intent == process_comms.Intent.PLAY:
return self._schedule_rpc(self.play)
if intent == process_comms.Intent.PAUSE:
return self._schedule_rpc(self.pause, msg=msg.get(process_comms.MESSAGE_KEY, None))
return self._schedule_rpc(self.pause, msg_text=msg.get(process_comms.MESSAGE_KEY, None))
if intent == process_comms.Intent.KILL:
return self._schedule_rpc(self.kill, msg=msg)
return self._schedule_rpc(self.kill, msg_text=msg.get(process_comms.MESSAGE_KEY, None))
if intent == process_comms.Intent.STATUS:
status_info: Dict[str, Any] = {}
self.get_status_info(status_info)
Expand All @@ -976,7 +975,7 @@ def message_receive(self, _comm: kiwipy.Communicator, msg: Dict[str, Any]) -> An
raise RuntimeError('Unknown intent')

def broadcast_receive(
self, _comm: kiwipy.Communicator, body: Any, sender: Any, subject: Any, correlation_id: Any
self, _comm: kiwipy.Communicator, msg: MessageType, sender: Any, subject: Any, correlation_id: Any
) -> Optional[kiwipy.Future]:
"""
Coroutine called when the process receives a message from the communicator
Expand All @@ -990,16 +989,16 @@ def broadcast_receive(
self.pid,
subject,
_comm,
body,
msg,
)

# If we get a message we recognise then action it, otherwise ignore
if subject == process_comms.Intent.PLAY:
return self._schedule_rpc(self.play)
if subject == process_comms.Intent.PAUSE:
return self._schedule_rpc(self.pause, msg=body)
return self._schedule_rpc(self.pause, msg_text=msg.get(process_comms.MESSAGE_KEY, None))
unkcpz marked this conversation as resolved.
Show resolved Hide resolved
if subject == process_comms.Intent.KILL:
return self._schedule_rpc(self.kill, msg=body)
return self._schedule_rpc(self.kill, msg_text=msg.get(process_comms.MESSAGE_KEY, None))
return None

def _schedule_rpc(self, callback: Callable[..., Any], *args: Any, **kwargs: Any) -> kiwipy.Future:
Expand Down Expand Up @@ -1071,7 +1070,7 @@ def transition_failed(
)
self.transition_to(new_state)

def pause(self, msg: Union[str, None] = None) -> Union[bool, futures.CancellableAction]:
def pause(self, msg_text: Optional[str] = None) -> Union[bool, futures.CancellableAction]:
"""Pause the process.

:param msg: an optional message to set as the status. The current status will be saved in the private
Expand All @@ -1095,22 +1094,29 @@ def pause(self, msg: Union[str, None] = None) -> Union[bool, futures.Cancellable
if self._stepping:
# Ask the step function to pause by setting this flag and giving the
# caller back a future
interrupt_exception = process_states.PauseInterruption(msg)
interrupt_exception = process_states.PauseInterruption(msg_text)
self._set_interrupt_action_from_exception(interrupt_exception)
self._pausing = self._interrupt_action
# Try to interrupt the state
self._state.interrupt(interrupt_exception)
return cast(futures.CancellableAction, self._interrupt_action)

return self._do_pause(msg)
msg = MessageBuilder.pause(msg_text)
return self._do_pause(state_msg=msg)

def _do_pause(self, state_msg: Optional[str], next_state: Optional[process_states.State] = None) -> bool:
def _do_pause(self, state_msg: Optional[MessageType], next_state: Optional[process_states.State] = None) -> bool:
"""Carry out the pause procedure, optionally transitioning to the next state first"""
try:
if next_state is not None:
self.transition_to(next_state)
call_with_super_check(self.on_pausing, state_msg)
call_with_super_check(self.on_paused, state_msg)

if state_msg is None:
msg_text = ''
else:
msg_text = state_msg[MESSAGE_KEY]
unkcpz marked this conversation as resolved.
Show resolved Hide resolved

call_with_super_check(self.on_pausing, msg_text)
call_with_super_check(self.on_paused, msg_text)
finally:
self._pausing = None

Expand All @@ -1125,7 +1131,7 @@ def _create_interrupt_action(self, exception: process_states.Interruption) -> fu

"""
if isinstance(exception, process_states.PauseInterruption):
do_pause = functools.partial(self._do_pause, str(exception))
do_pause = functools.partial(self._do_pause, exception.msg)
unkcpz marked this conversation as resolved.
Show resolved Hide resolved
return futures.CancellableAction(do_pause, cookie=exception)

if isinstance(exception, process_states.KillInterruption):
Expand Down Expand Up @@ -1190,7 +1196,7 @@ def fail(self, exception: Optional[BaseException], trace_back: Optional[Tracebac
)
self.transition_to(new_state)

def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future]:
def kill(self, msg_text: Optional[str] = None) -> Union[bool, asyncio.Future]:
"""
Kill the process
:param msg: An optional kill message
Expand All @@ -1210,12 +1216,13 @@ def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future]
if self._stepping:
# Ask the step function to pause by setting this flag and giving the
# caller back a future
interrupt_exception = process_states.KillInterruption(msg)
interrupt_exception = process_states.KillInterruption(msg_text)
self._set_interrupt_action_from_exception(interrupt_exception)
self._killing = self._interrupt_action
self._state.interrupt(interrupt_exception)
return cast(futures.CancellableAction, self._interrupt_action)

msg = MessageBuilder.kill(msg_text)
new_state = self._create_state_instance(process_states.ProcessState.KILLED, msg=msg)
self.transition_to(new_state)
return True
Expand Down
4 changes: 1 addition & 3 deletions tests/rmq/test_process_comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,9 +195,7 @@ async def test_kill_all(self, thread_communicator, sync_controller):
for _ in range(10):
procs.append(utils.WaitForSignalProcess(communicator=thread_communicator))

msg = process_comms.MessageBuilder.kill(text='bang bang, I shot you down')

sync_controller.kill_all(msg)
sync_controller.kill_all(msg_text='bang bang, I shot you down')
await utils.wait_util(lambda: all([proc.killed() for proc in procs]))
assert all([proc.state == plumpy.ProcessState.KILLED for proc in procs])

Expand Down
8 changes: 4 additions & 4 deletions tests/test_processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import plumpy
from plumpy import BundleKeys, Process, ProcessState
from plumpy.process_comms import MessageBuilder
from plumpy.process_comms import MESSAGE_KEY, MessageBuilder
from plumpy.utils import AttributesFrozendict
from tests import utils

Expand Down Expand Up @@ -322,10 +322,10 @@ def run(self, **kwargs):
def test_kill(self):
proc: Process = utils.DummyProcess()

msg = MessageBuilder.kill(text='Farewell!')
proc.kill(msg)
msg_text = 'Farewell!'
proc.kill(msg_text=msg_text)
self.assertTrue(proc.killed())
self.assertEqual(proc.killed_msg(), msg)
self.assertEqual(proc.killed_msg()[MESSAGE_KEY], msg_text)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand this change 🤔 MESSAGE_KEY is imported and passed as key, why is that?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is basically the same as previous test. The previous one build the msg which is a dict and I use it to assert. Now, the msg is not build, so I directly compare the msg_text. The different here is with new change, the proc.kill only require passing the msg_text, rather than construct the message type first.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As we discussed in the office, let's make it this way:

Suggested change
self.assertEqual(proc.killed_msg()[MESSAGE_KEY], msg_text)
self.assertEqual(proc.killed_msg()[MessageBuilder.MESSAGE_TEXT_KEY], msg_text)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@khsrali also prefer having directly using proc.killed_msg()['message'] which I don't agree with both.

I agree with changing the name to MESSAGE_TEXT_KEY can be a bit more clear. But I don't like putting this as a class attribute.

@sphuber what you think? It is mostly a matter of taste and we need the third one to settle it.

from plumpy.process_comms import MESSAGE_TEXT_KEY

msg = proc.killed_msg()
msg_text = msg[MESSAGE_TEXT_KEY]
from plumpy.process_comms import MessageBuilder

msg = proc.killed_msg()
msg_text = msg[MessageBuilder.MESSAGE_TEXT_KEY]
msg = proc.killed_msg()
msg_text = msg['message']

Copy link
Contributor

@khsrali khsrali Dec 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The point is, why should these keys be outside of the only class that defines them? they are defined as MessageBuilder keys.

from plumpy.process_comms import MESSAGE_TEXT_KEY, FORCE_KILL_KEY, INTENT_KEY

msg = proc.killed_msg()
msg_text = msg[MESSAGE_TEXT_KEY]
intent = msg[INTENT_KEY]
kill_flag = msg[FORCE_KILL_KEY]

whilst you could:

from plumpy.process_comms import MessageBuilder as mb

msg = proc.killed_msg()
msg_text = msg[mb.MESSAGE_TEXT_KEY]
intent = msg[mb.INTENT_KEY]
kill_flag = msg[mb.FORCE_KILL_KEY]

or even:

msg = proc.killed_msg()
msg_text = msg['MESSAGE_TEXT']
intent = msg['INTENT']
kill_flag = msg['FORCE_KILL']

Copy link
Member Author

@unkcpz unkcpz Dec 31, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

msg = proc.killed_msg()
msg_text = msg['MESSAGE_TEXT']

I think this is not good since the typo in the str won't be checked by the pre-commit or modern IDE. Meanwhile, we did struggle a lot by that the exception from plumpy is not well propagated to the aiida log which makes such error even hard to find and debug. Usually such exception can cause the test in aiida-core hang up rather than raising exceptions.

msg = proc.killed_msg()
msg_text = msg[mb.MESSAGE_TEXT_KEY]

mb can be quite confuse, it looks more like megabytes from first glimpse.

The proper solution will be defining in plumpy.process_comms a standard message protocol and hasing message defined as a dataclass then used in plumpy and se/de -serialize the message through for example msgpack. It is a generic pattern, which can be found such as in redis or bit-torrent.
But change the message to dataclass require also change the de/se function which is beyond the scope of this PR.

TBH, I think this is a very small problem and only the matter of taste. As you can see from https://github.com/aiidateam/aiida-core/pull/6668/files, the MESSAGE_TEXT_KEY is not used there but only in plumpy, it defined as a part of message protocol of current design (which using dict as the message body).

self.assertEqual(proc.state, ProcessState.KILLED)

def test_wait_continue(self):
Expand Down
Loading