Skip to content

Commit

Permalink
Message builder for constructing message with carrying more informati…
Browse files Browse the repository at this point in the history
…on (#291)

The messages passed over RabbitMQ for process control are now built dynamically and are able to carry more information. In the old implementation, the messages were global dictionary variables that had to be copied whenever a message needed to be changed, which is error-prone.
This commit introduces the `MessageBuilder` class, with class methods for creating kill/pause/status/play messages. For the "kill" message, I also added support for passing the `force_kill` option.
  • Loading branch information
unkcpz authored Dec 13, 2024
1 parent 4611154 commit f760b4a
Show file tree
Hide file tree
Showing 13 changed files with 273 additions and 119 deletions.
4 changes: 2 additions & 2 deletions docs/source/tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@
" def continue_fn(self):\n",
" print('continuing')\n",
" # message is stored in the process status\n",
" return plumpy.Kill('I was killed')\n",
" return plumpy.Kill(plumpy.MessageBuilder.kill('I was killed'))\n",
"\n",
"\n",
"process = ContinueProcess()\n",
Expand Down Expand Up @@ -1118,7 +1118,7 @@
"\n",
"process = SimpleProcess(communicator=communicator)\n",
"\n",
"pprint(communicator.rpc_send(str(process.pid), plumpy.STATUS_MSG).result())"
"pprint(communicator.rpc_send(str(process.pid), plumpy.MessageBuilder.status()).result())"
]
},
{
Expand Down
70 changes: 45 additions & 25 deletions src/plumpy/base/state_machine.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,28 @@
# -*- coding: utf-8 -*-
"""The state machine for processes"""

from __future__ import annotations

import enum
import functools
import inspect
import logging
import os
import sys
from types import TracebackType
from typing import Any, Callable, Dict, Hashable, Iterable, List, Optional, Sequence, Set, Type, Union, cast
from typing import (
Any,
Callable,
Dict,
Hashable,
Iterable,
List,
Optional,
Sequence,
Set,
Type,
Union,
)

from plumpy.futures import Future

Expand All @@ -31,7 +45,7 @@ class StateEntryFailed(Exception): # noqa: N818
Failed to enter a state, can provide the next state to go to via this exception
"""

def __init__(self, state: Hashable = None, *args: Any, **kwargs: Any) -> None:
def __init__(self, state: State, *args: Any, **kwargs: Any) -> None:
super().__init__('failed to enter state')
self.state = state
self.args = args
Expand Down Expand Up @@ -187,7 +201,7 @@ def __call__(cls, *args: Any, **kwargs: Any) -> 'StateMachine':
:param kwargs: Any keyword arguments to be passed to the constructor
:return: An instance of the state machine
"""
inst = super().__call__(*args, **kwargs)
inst: StateMachine = super().__call__(*args, **kwargs)
inst.transition_to(inst.create_initial_state())
call_with_super_check(inst.init)
return inst
Expand Down Expand Up @@ -300,16 +314,25 @@ def _fire_state_event(self, hook: Hashable, state: Optional[State]) -> None:
def on_terminated(self) -> None:
"""Called when a terminal state is entered"""

def transition_to(self, new_state: Union[Hashable, State, Type[State]], *args: Any, **kwargs: Any) -> None:
def transition_to(self, new_state: State | None, **kwargs: Any) -> None:
"""Transite to the new state.
The new target state will be created lazily when the state is not yet instantiated,
which will happen for states not in the expected path, such as pause and kill.
The arguments are passed to the state class to create the state instance.
(The process argument does not need to be passed, since the state class is always called with 'self' as the process.)
"""
assert not self._transitioning, 'Cannot call transition_to when already transitioning state'

if new_state is None:
# early return if the new state is `None`
# this can happen when transitioning from a terminal state
return None

initial_state_label = self._state.LABEL if self._state is not None else None
label = None
try:
self._transitioning = True

# Make sure we have a state instance
new_state = self._create_state_instance(new_state, *args, **kwargs)
label = new_state.LABEL

# If the previous transition failed, do not try to exit it but go straight to next state
Expand All @@ -319,8 +342,7 @@ def transition_to(self, new_state: Union[Hashable, State, Type[State]], *args: A
try:
self._enter_next_state(new_state)
except StateEntryFailed as exception:
# Make sure we have a state instance
new_state = self._create_state_instance(exception.state, *exception.args, **exception.kwargs)
new_state = exception.state
label = new_state.LABEL
self._exit_current_state(new_state)
self._enter_next_state(new_state)
Expand All @@ -338,7 +360,11 @@ def transition_to(self, new_state: Union[Hashable, State, Type[State]], *args: A
self._transitioning = False

def transition_failed(
self, initial_state: Hashable, final_state: Hashable, exception: Exception, trace: TracebackType
self,
initial_state: Hashable,
final_state: Hashable,
exception: Exception,
trace: TracebackType,
) -> None:
"""Called when a state transitions fails.
Expand All @@ -355,6 +381,10 @@ def set_debug(self, enabled: bool) -> None:
self._debug: bool = enabled

def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> State:
# XXX: this method creates a state from its label, which duplicates `_create_state_instance` and is less generic,
# because the label is defined after the state and is required to be known before calling this function.
# This method should be replaced by `_create_state_instance`.
# aiida-core uses this method for its Waiting state override.
try:
return self.get_states_map()[state_label](self, *args, **kwargs)
except KeyError:
Expand Down Expand Up @@ -383,20 +413,10 @@ def _enter_next_state(self, next_state: State) -> None:
self._state = next_state
self._fire_state_event(StateEventHook.ENTERED_STATE, last_state)

def _create_state_instance(self, state: Union[Hashable, State, Type[State]], *args: Any, **kwargs: Any) -> State:
if isinstance(state, State):
# It's already a state instance
return state

# OK, have to create it
state_cls = self._ensure_state_class(state)
return state_cls(self, *args, **kwargs)
def _create_state_instance(self, state_cls: Hashable, **kwargs: Any) -> State:
if state_cls not in self.get_states_map():
raise ValueError(f'{state_cls} is not a valid state')

def _ensure_state_class(self, state: Union[Hashable, Type[State]]) -> Type[State]:
if inspect.isclass(state) and issubclass(state, State):
return state
cls = self.get_states_map()[state_cls]

try:
return self.get_states_map()[cast(Hashable, state)]
except KeyError:
raise ValueError(f'{state} is not a valid state')
return cls(self, **kwargs)
95 changes: 63 additions & 32 deletions src/plumpy/process_comms.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
# -*- coding: utf-8 -*-
"""Module for process level communication functions and classes"""

from __future__ import annotations

import asyncio
import copy
import logging
from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Union, cast

Expand All @@ -12,10 +13,7 @@
from .utils import PID_TYPE

__all__ = [
'KILL_MSG',
'PAUSE_MSG',
'PLAY_MSG',
'STATUS_MSG',
'MessageBuilder',
'ProcessLauncher',
'RemoteProcessController',
'RemoteProcessThreadController',
Expand All @@ -31,6 +29,7 @@

INTENT_KEY = 'intent'
MESSAGE_KEY = 'message'
FORCE_KILL_KEY = 'force_kill'


class Intent:
Expand All @@ -42,10 +41,45 @@ class Intent:
STATUS: str = 'status'


PAUSE_MSG = {INTENT_KEY: Intent.PAUSE}
PLAY_MSG = {INTENT_KEY: Intent.PLAY}
KILL_MSG = {INTENT_KEY: Intent.KILL}
STATUS_MSG = {INTENT_KEY: Intent.STATUS}
MessageType = Dict[str, Any]


class MessageBuilder:
    """Factory for the control messages that are sent over the communicator.

    Each class method returns a newly created dictionary, so the result can be
    modified by the caller without affecting any other message.
    """

    @classmethod
    def play(cls, text: str | None = None) -> MessageType:
        """Build a message carrying the play intent.

        :param text: optional text stored under the message key
        :return: a new message dictionary
        """
        message = {INTENT_KEY: Intent.PLAY}
        message[MESSAGE_KEY] = text
        return message

    @classmethod
    def pause(cls, text: str | None = None) -> MessageType:
        """Build a message carrying the pause intent.

        :param text: optional text stored under the message key
        :return: a new message dictionary
        """
        message = {INTENT_KEY: Intent.PAUSE}
        message[MESSAGE_KEY] = text
        return message

    @classmethod
    def kill(cls, text: str | None = None, force_kill: bool = False) -> MessageType:
        """Build a message carrying the kill intent.

        :param text: optional text stored under the message key
        :param force_kill: flag stored under the force-kill key
        :return: a new message dictionary
        """
        message = {INTENT_KEY: Intent.KILL}
        message[MESSAGE_KEY] = text
        # Added last so the key order matches intent, message, force_kill.
        message[FORCE_KILL_KEY] = force_kill
        return message

    @classmethod
    def status(cls, text: str | None = None) -> MessageType:
        """Build a message carrying the status intent.

        :param text: optional text stored under the message key
        :return: a new message dictionary
        """
        message = {INTENT_KEY: Intent.STATUS}
        message[MESSAGE_KEY] = text
        return message


TASK_KEY = 'task'
TASK_ARGS = 'args'
Expand Down Expand Up @@ -162,7 +196,7 @@ async def get_status(self, pid: 'PID_TYPE') -> 'ProcessStatus':
:param pid: the process id
:return: the status response from the process
"""
future = self._communicator.rpc_send(pid, STATUS_MSG)
future = self._communicator.rpc_send(pid, MessageBuilder.status())
result = await asyncio.wrap_future(future)
return result

Expand All @@ -174,11 +208,9 @@ async def pause_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> 'Pr
:param msg: optional pause message
:return: True if paused, False otherwise
"""
message = copy.copy(PAUSE_MSG)
if msg is not None:
message[MESSAGE_KEY] = msg
msg = MessageBuilder.pause(text=msg)

pause_future = self._communicator.rpc_send(pid, message)
pause_future = self._communicator.rpc_send(pid, msg)
# rpc_send return a thread future from communicator
future = await asyncio.wrap_future(pause_future)
# future is just returned from rpc call which return a kiwipy future
Expand All @@ -192,25 +224,24 @@ async def play_process(self, pid: 'PID_TYPE') -> 'ProcessResult':
:param pid: the pid of the process to play
:return: True if played, False otherwise
"""
play_future = self._communicator.rpc_send(pid, PLAY_MSG)
play_future = self._communicator.rpc_send(pid, MessageBuilder.play())
future = await asyncio.wrap_future(play_future)
result = await asyncio.wrap_future(future)
return result

async def kill_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> 'ProcessResult':
async def kill_process(self, pid: 'PID_TYPE', msg: Optional[MessageType] = None) -> 'ProcessResult':
"""
Kill the process
:param pid: the pid of the process to kill
:param msg: optional kill message
:return: True if killed, False otherwise
"""
message = copy.copy(KILL_MSG)
if msg is not None:
message[MESSAGE_KEY] = msg
if msg is None:
msg = MessageBuilder.kill()

# Wait for the communication to go through
kill_future = self._communicator.rpc_send(pid, message)
kill_future = self._communicator.rpc_send(pid, msg)
future = await asyncio.wrap_future(kill_future)
# Now wait for the kill to be enacted
result = await asyncio.wrap_future(future)
Expand Down Expand Up @@ -331,7 +362,7 @@ def get_status(self, pid: 'PID_TYPE') -> kiwipy.Future:
:param pid: the process id
:return: the status response from the process
"""
return self._communicator.rpc_send(pid, STATUS_MSG)
return self._communicator.rpc_send(pid, MessageBuilder.status())

def pause_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> kiwipy.Future:
"""
Expand All @@ -342,11 +373,9 @@ def pause_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> kiwipy.Fu
:return: a response future from the process to be paused
"""
message = copy.copy(PAUSE_MSG)
if msg is not None:
message[MESSAGE_KEY] = msg
msg = MessageBuilder.pause(text=msg)

return self._communicator.rpc_send(pid, message)
return self._communicator.rpc_send(pid, msg)

def pause_all(self, msg: Any) -> None:
"""
Expand All @@ -364,15 +393,15 @@ def play_process(self, pid: 'PID_TYPE') -> kiwipy.Future:
:return: a response future from the process to be played
"""
return self._communicator.rpc_send(pid, PLAY_MSG)
return self._communicator.rpc_send(pid, MessageBuilder.play())

def play_all(self) -> None:
"""
Play all processes that are subscribed to the same communicator
"""
self._communicator.broadcast_send(None, subject=Intent.PLAY)

def kill_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> kiwipy.Future:
def kill_process(self, pid: 'PID_TYPE', msg: Optional[MessageType] = None) -> kiwipy.Future:
"""
Kill the process
Expand All @@ -381,18 +410,20 @@ def kill_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> kiwipy.Fut
:return: a response future from the process to be killed
"""
message = copy.copy(KILL_MSG)
if msg is not None:
message[MESSAGE_KEY] = msg
if msg is None:
msg = MessageBuilder.kill()

return self._communicator.rpc_send(pid, message)
return self._communicator.rpc_send(pid, msg)

def kill_all(self, msg: Optional[Any]) -> None:
def kill_all(self, msg: Optional[MessageType]) -> None:
"""
Kill all processes that are subscribed to the same communicator
:param msg: an optional pause message
"""
if msg is None:
msg = MessageBuilder.kill()

self._communicator.broadcast_send(msg, subject=Intent.KILL)

def continue_process(
Expand Down
Loading

0 comments on commit f760b4a

Please sign in to comment.