Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Message passing with more information #291

Merged
merged 17 commits into from
Dec 13, 2024
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/source/tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@
" def continue_fn(self):\n",
" print('continuing')\n",
" # message is stored in the process status\n",
" return plumpy.Kill('I was killed')\n",
" return plumpy.Kill(plumpy.KillMessage.build('I was killed'))\n",
"\n",
"\n",
"process = ContinueProcess()\n",
Expand Down Expand Up @@ -1118,7 +1118,7 @@
"\n",
"process = SimpleProcess(communicator=communicator)\n",
"\n",
"pprint(communicator.rpc_send(str(process.pid), plumpy.STATUS_MSG).result())"
"pprint(communicator.rpc_send(str(process.pid), plumpy.StatusMessage.build()).result())"
]
},
{
Expand Down
68 changes: 42 additions & 26 deletions src/plumpy/base/state_machine.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,28 @@
# -*- coding: utf-8 -*-
"""The state machine for processes"""

from __future__ import annotations

import enum
import functools
import inspect
import logging
import os
import sys
from types import TracebackType
from typing import Any, Callable, Dict, Hashable, Iterable, List, Optional, Sequence, Set, Type, Union, cast
from typing import (
Any,
Callable,
Dict,
Hashable,
Iterable,
List,
Optional,
Sequence,
Set,
Type,
Union,
)

from plumpy.futures import Future

Expand All @@ -31,7 +45,7 @@ class StateEntryFailed(Exception): # noqa: N818
Failed to enter a state, can provide the next state to go to via this exception
"""

def __init__(self, state: Hashable = None, *args: Any, **kwargs: Any) -> None:
def __init__(self, state: State, *args: Any, **kwargs: Any) -> None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is this change and what does it mean?
also please note the default value None is taken away, is this backward compatible?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function inside plumpy is only called:

    @super_check
    def on_finish(self, result: Any, successful: bool) -> None:
        """Entering the FINISHED state."""
        if successful:
            validation_error = self.spec().outputs.validate(self.outputs)
            if validation_error:
                state_cls = self.get_states_map()[process_states.ProcessState.FINISHED]
                finished_state = state_cls(self, result=result, successful=False)
                raise StateEntryFailed(finished_state)

I did the same to create the state first and then pass to the function for consistency as transition_to.

also please note the default value None is taken away, is this backward compatible?

It depends, I think we don't have a clear public API list for plumpy. Since aiida-core didn't use this directly, it does not break backward compatibility of aiida-core.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It depends, I think we don't have a clear public API list for plumpy. Since aiida-core didn't use this directly, it does not break backward compatibility of aiida-core.

ok, agreed, I'm also not so strict in this case, since so far aiida-core is the only known dependent of this plumpy

super().__init__('failed to enter state')
self.state = state
self.args = args
Expand Down Expand Up @@ -187,7 +201,7 @@ def __call__(cls, *args: Any, **kwargs: Any) -> 'StateMachine':
:param kwargs: Any keyword arguments to be passed to the constructor
:return: An instance of the state machine
"""
inst = super().__call__(*args, **kwargs)
inst: StateMachine = super().__call__(*args, **kwargs)
inst.transition_to(inst.create_initial_state())
call_with_super_check(inst.init)
return inst
Expand Down Expand Up @@ -300,16 +314,23 @@ def _fire_state_event(self, hook: Hashable, state: Optional[State]) -> None:
def on_terminated(self) -> None:
"""Called when a terminal state is entered"""

def transition_to(self, new_state: Union[Hashable, State, Type[State]], *args: Any, **kwargs: Any) -> None:
def transition_to(self, new_state: State | None, **kwargs: Any) -> None:
"""Transite to the new state.

The new target state will be create lazily when the state is not yet instantiated,
which will happened for states not in the expect path such as pause and kill.
The arguments are passed to the state class to create state instance.
(process arg does not need to pass since it will always call with 'self' as process)
"""
assert not self._transitioning, 'Cannot call transition_to when already transitioning state'

if new_state is None:
return None
unkcpz marked this conversation as resolved.
Show resolved Hide resolved

unkcpz marked this conversation as resolved.
Show resolved Hide resolved
initial_state_label = self._state.LABEL if self._state is not None else None
label = None
try:
self._transitioning = True

# Make sure we have a state instance
new_state = self._create_state_instance(new_state, *args, **kwargs)
label = new_state.LABEL

# If the previous transition failed, do not try to exit it but go straight to next state
Expand All @@ -319,8 +340,7 @@ def transition_to(self, new_state: Union[Hashable, State, Type[State]], *args: A
try:
self._enter_next_state(new_state)
except StateEntryFailed as exception:
# Make sure we have a state instance
new_state = self._create_state_instance(exception.state, *exception.args, **exception.kwargs)
new_state = exception.state
label = new_state.LABEL
self._exit_current_state(new_state)
self._enter_next_state(new_state)
Expand All @@ -338,7 +358,11 @@ def transition_to(self, new_state: Union[Hashable, State, Type[State]], *args: A
self._transitioning = False

def transition_failed(
self, initial_state: Hashable, final_state: Hashable, exception: Exception, trace: TracebackType
self,
initial_state: Hashable,
final_state: Hashable,
exception: Exception,
trace: TracebackType,
) -> None:
"""Called when a state transitions fails.

Expand All @@ -355,6 +379,10 @@ def set_debug(self, enabled: bool) -> None:
self._debug: bool = enabled

def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> State:
# XXX: this method create state from label, which is duplicate as _create_state_instance and less generic
# because the label is defined after the state and required to be know before calling this function.
# This method should be replaced by `_create_state_instance`.
# aiida-core using this method for its Waiting state override.
try:
return self.get_states_map()[state_label](self, *args, **kwargs)
except KeyError:
Expand Down Expand Up @@ -383,20 +411,8 @@ def _enter_next_state(self, next_state: State) -> None:
self._state = next_state
self._fire_state_event(StateEventHook.ENTERED_STATE, last_state)

def _create_state_instance(self, state: Union[Hashable, State, Type[State]], *args: Any, **kwargs: Any) -> State:
if isinstance(state, State):
# It's already a state instance
return state

# OK, have to create it
state_cls = self._ensure_state_class(state)
return state_cls(self, *args, **kwargs)

def _ensure_state_class(self, state: Union[Hashable, Type[State]]) -> Type[State]:
if inspect.isclass(state) and issubclass(state, State):
return state
def _create_state_instance(self, state_cls: type[State], **kwargs: Any) -> State:
if state_cls.LABEL not in self.get_states_map():
raise ValueError(f'{state_cls.LABEL} is not a valid state')
unkcpz marked this conversation as resolved.
Show resolved Hide resolved

try:
return self.get_states_map()[cast(Hashable, state)]
except KeyError:
raise ValueError(f'{state} is not a valid state')
return state_cls(self, **kwargs)
98 changes: 66 additions & 32 deletions src/plumpy/process_comms.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
# -*- coding: utf-8 -*-
"""Module for process level communication functions and classes"""

from __future__ import annotations

import asyncio
import copy
import logging
from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Union, cast

Expand All @@ -12,13 +13,13 @@
from .utils import PID_TYPE

__all__ = [
'KILL_MSG',
'PAUSE_MSG',
'PLAY_MSG',
'STATUS_MSG',
'KillMessage',
'PauseMessage',
'PlayMessage',
'ProcessLauncher',
'RemoteProcessController',
'RemoteProcessThreadController',
'StatusMessage',
'create_continue_body',
'create_launch_body',
]
Expand All @@ -31,6 +32,7 @@

INTENT_KEY = 'intent'
MESSAGE_KEY = 'message'
FORCE_KILL_KEY = 'force_kill'


class Intent:
Expand All @@ -42,10 +44,45 @@ class Intent:
STATUS: str = 'status'


PAUSE_MSG = {INTENT_KEY: Intent.PAUSE}
PLAY_MSG = {INTENT_KEY: Intent.PLAY}
KILL_MSG = {INTENT_KEY: Intent.KILL}
STATUS_MSG = {INTENT_KEY: Intent.STATUS}
MessageType = Dict[str, Any]


class PlayMessage:
@classmethod
def build(cls, message: str | None = None) -> MessageType:
return {
INTENT_KEY: Intent.PLAY,
MESSAGE_KEY: message,
}


class PauseMessage:
@classmethod
def build(cls, message: str | None = None) -> MessageType:
return {
INTENT_KEY: Intent.PAUSE,
MESSAGE_KEY: message,
}


class KillMessage:
@classmethod
def build(cls, message: str | None = None, force: bool = False) -> MessageType:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd still call it force_kill for clarity

Suggested change
def build(cls, message: str | None = None, force: bool = False) -> MessageType:
def build(cls, message: str | None = None, force_kill: bool = False) -> MessageType:

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@unkcpz , maybe you forgot to update this 😅

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note, aiida-core sends a kill message here and then here, how the flag should be send over?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

message = 'Killed through `verdi process kill`'
control.kill_processes(processes, all_entries=all_entries, timeout=timeout, wait=wait, message=message)

For this one, the message should be build first in to a KillMessage and then send over.

message = KillMessage.build(message= 'Killed through `verdi process kill`')
control.kill_processes(processes, all_entries=all_entries, timeout=timeout, wait=wait, message=message)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the one in _perform_action, it is the same, but I'd prefer to have the argument passing close to the function call of kill_process

_perform_actions(processes, controller.kill_process, 'kill', 'killing', timeout, wait, msg=message)

change it to

message = KillMessage.build(message=message, force=force) 
_perform_actions(processes, functool.partial(controller.kill_process, msg=message), 'kill', 'killing', timeout, wait)

Copy link
Contributor

@khsrali khsrali Dec 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

but now aiida-core needs to import KillMessage from plumpy only for adding a flag. wouldn't be easier that msg itself could be a dictionary?

message = {'msg' : 'Killed through `verdi process kill`', 
           'options' : { 
                      'force_kill':True
                       }}
control.kill_processes(processes, controller.kill_process, .., msg=message)

and then plumpy would well receive all options listed, if any.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

but now aiida-core needs to import KillMessage from plumpy only for adding a flag

Why it is a problem for importing KillMessage? I think message is a good interface that can flexibly support backward compatibility. We make all four message type public APIs of plumpy and when using process communicator it requires sending over through this message. In the future even if we want to change the real message to other type we can do it without break users' code.

wouldn't be easier that msg itself could be a dictionary?

It is a dictionary returned by build but construct a dictionary by hand is error-prone. Using the build function of the message, the function signature will tell developers which options are allowed to be passed. We can make this works also by dataclass, but then that is overkill for a message in my opinion.

Copy link
Contributor

@khsrali khsrali Dec 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

well this is the twisted things that we were always complaining about, import from plumpy to only make a dictionary only for an input of a plumpy function itself...
I mean you see my point 😄 🤌

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My 2¢'s: I think it is rarely preferable to have a free-form dictionary to specify inputs to an API, especially if it is nested, because you will always have to start reading the docs/code to know how it is structured. Using the other approach, (data) classes, explicit function/method arguments etc. these can be auto inspected by an IDE and are much easier to use. The fact that you have to import something is in my eyes not really a problem. Anyway to use a library you have to import something to do anything.

Now in this case, ideally the force_kill argument would have been directly available alongside the message to whatever method/function is called. But if this is not possible with the current design and it has to be incorporated in the message itself, than the KillMessage approach would be a good alternative. I would perhaps just consider making it a proper dataclass to get automatic type validation, e.g.

from dataclass import dataclass

@dataclass
class KillMessage:

    message: str | None = None
    force_kill: bool = False

    @classmethod
    def build(...):
        ....

Actually, if using a dataclass, why not just use the constructor to build the instance instead of using the build classmethod. KillMessage(message="something, force_kill=True) is nicer than KillMessage.build(message="something, force_kill=True) I think

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

import from plumpy to only make a dictionary only for an input of a plumpy function itself

I won't say "only", encapsulate it in a function as a contract also serve as an input validator.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry @sphuber, I just saw your comment, I had my browser open for during the weekend and I directly reply to Ali's comment without refresh.

Yes, thought about dataclass, and I think it can work. I use build into a dict mostly for the reason to keep the message the same as what was passing as before, since this is the message that should also go through de/serialize by the en/decoder and I am not 100% confident with that.

I think msgpack will take care of the dataclass for sure, but since the en/decoder used was the yaml one, I am not so sure.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The same here 😄 I had the page open without refreshing, so I didn't notice Sebastiaan's comment.

KillMessage(message="something", force_kill=True) is nicer than KillMessage.build(message="something", force_kill=True) I think

I see and agree with this suggestion, ✔️

return {
INTENT_KEY: Intent.KILL,
MESSAGE_KEY: message,
FORCE_KILL_KEY: force,
}


class StatusMessage:
@classmethod
def build(cls, message: str | None = None) -> MessageType:
return {
INTENT_KEY: Intent.STATUS,
MESSAGE_KEY: message,
}

unkcpz marked this conversation as resolved.
Show resolved Hide resolved

TASK_KEY = 'task'
TASK_ARGS = 'args'
Expand Down Expand Up @@ -162,7 +199,7 @@ async def get_status(self, pid: 'PID_TYPE') -> 'ProcessStatus':
:param pid: the process id
:return: the status response from the process
"""
future = self._communicator.rpc_send(pid, STATUS_MSG)
future = self._communicator.rpc_send(pid, StatusMessage.build())
result = await asyncio.wrap_future(future)
return result

Expand All @@ -174,11 +211,9 @@ async def pause_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> 'Pr
:param msg: optional pause message
:return: True if paused, False otherwise
"""
message = copy.copy(PAUSE_MSG)
if msg is not None:
message[MESSAGE_KEY] = msg
msg = PauseMessage.build(message=msg)

pause_future = self._communicator.rpc_send(pid, message)
pause_future = self._communicator.rpc_send(pid, msg)
# rpc_send return a thread future from communicator
future = await asyncio.wrap_future(pause_future)
# future is just returned from rpc call which return a kiwipy future
Expand All @@ -192,25 +227,24 @@ async def play_process(self, pid: 'PID_TYPE') -> 'ProcessResult':
:param pid: the pid of the process to play
:return: True if played, False otherwise
"""
play_future = self._communicator.rpc_send(pid, PLAY_MSG)
play_future = self._communicator.rpc_send(pid, PlayMessage.build())
future = await asyncio.wrap_future(play_future)
result = await asyncio.wrap_future(future)
return result

async def kill_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> 'ProcessResult':
async def kill_process(self, pid: 'PID_TYPE', msg: Optional[MessageType] = None) -> 'ProcessResult':
"""
Kill the process

:param pid: the pid of the process to kill
:param msg: optional kill message
:return: True if killed, False otherwise
"""
message = copy.copy(KILL_MSG)
if msg is not None:
message[MESSAGE_KEY] = msg
if msg is None:
msg = KillMessage.build()

# Wait for the communication to go through
kill_future = self._communicator.rpc_send(pid, message)
kill_future = self._communicator.rpc_send(pid, msg)
future = await asyncio.wrap_future(kill_future)
# Now wait for the kill to be enacted
result = await asyncio.wrap_future(future)
Expand Down Expand Up @@ -331,7 +365,7 @@ def get_status(self, pid: 'PID_TYPE') -> kiwipy.Future:
:param pid: the process id
:return: the status response from the process
"""
return self._communicator.rpc_send(pid, STATUS_MSG)
return self._communicator.rpc_send(pid, StatusMessage.build())

def pause_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> kiwipy.Future:
"""
Expand All @@ -342,11 +376,9 @@ def pause_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> kiwipy.Fu
:return: a response future from the process to be paused

"""
message = copy.copy(PAUSE_MSG)
if msg is not None:
message[MESSAGE_KEY] = msg
msg = PauseMessage.build(message=msg)

return self._communicator.rpc_send(pid, message)
return self._communicator.rpc_send(pid, msg)

def pause_all(self, msg: Any) -> None:
"""
Expand All @@ -364,15 +396,15 @@ def play_process(self, pid: 'PID_TYPE') -> kiwipy.Future:
:return: a response future from the process to be played

"""
return self._communicator.rpc_send(pid, PLAY_MSG)
return self._communicator.rpc_send(pid, PlayMessage.build())

def play_all(self) -> None:
"""
Play all processes that are subscribed to the same communicator
"""
self._communicator.broadcast_send(None, subject=Intent.PLAY)

def kill_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> kiwipy.Future:
def kill_process(self, pid: 'PID_TYPE', msg: Optional[MessageType] = None) -> kiwipy.Future:
"""
Kill the process

Expand All @@ -381,18 +413,20 @@ def kill_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> kiwipy.Fut
:return: a response future from the process to be killed

"""
message = copy.copy(KILL_MSG)
if msg is not None:
message[MESSAGE_KEY] = msg
if msg is None:
msg = KillMessage.build()

return self._communicator.rpc_send(pid, message)
return self._communicator.rpc_send(pid, msg)

def kill_all(self, msg: Optional[Any]) -> None:
def kill_all(self, msg: Optional[MessageType]) -> None:
"""
Kill all processes that are subscribed to the same communicator

:param msg: an optional pause message
"""
if msg is None:
msg = KillMessage.build()

self._communicator.broadcast_send(msg, subject=Intent.KILL)

def continue_process(
Expand Down
Loading
Loading