From 96c5842bdc009a8ed741e9a7a21fce7d48731c8a Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Sun, 1 Dec 2024 17:35:40 +0100 Subject: [PATCH] Amend --- src/plumpy/base/state_machine.py | 84 +++++-------- src/plumpy/process_comms.py | 4 +- src/plumpy/process_states.py | 139 ++++++++------------- src/plumpy/processes.py | 205 +++++++++++-------------------- tests/test_processes.py | 1 - 5 files changed, 153 insertions(+), 280 deletions(-) diff --git a/src/plumpy/base/state_machine.py b/src/plumpy/base/state_machine.py index a371084a..499612e0 100644 --- a/src/plumpy/base/state_machine.py +++ b/src/plumpy/base/state_machine.py @@ -43,7 +43,7 @@ class StateEntryFailed(Exception): # noqa: N818 Failed to enter a state, can provide the next state to go to via this exception """ - def __init__(self, state: Hashable = None, *args: Any, **kwargs: Any) -> None: + def __init__(self, state: type['State'], *args: Any, **kwargs: Any) -> None: super().__init__('failed to enter state') self.state = state self.args = args @@ -72,10 +72,10 @@ def __init__( super().__init__(self._format_msg()) def _format_msg(self) -> str: - msg = [f"{self.initial_state} -> {self.final_state}"] + msg = [f'{self.initial_state} -> {self.final_state}'] if self.traceback_str is not None: msg.append(self.traceback_str) - return "\n".join(msg) + return '\n'.join(msg) def event( @@ -83,16 +83,16 @@ def event( to_states: Union[str, Type['State'], Iterable[Type['State']]] = '*', ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: """A decorator to check for correct transitions, raising ``EventError`` on invalid transitions.""" - if from_states != "*": + if from_states != '*': if inspect.isclass(from_states): from_states = (from_states,) if not all(issubclass(state, State) for state in from_states): # type: ignore - raise TypeError(f"from_states: {from_states}") - if to_states != "*": + raise TypeError(f'from_states: {from_states}') + if to_states != '*': if inspect.isclass(to_states): to_states = (to_states,) if not all(issubclass(state, State) for state in to_states): # type: ignore - raise TypeError(f"to_states: {to_states}") + raise TypeError(f'to_states: {to_states}') def wrapper(wrapped: Callable[..., Any]) -> Callable[..., Any]: evt_label = wrapped.__name__ @@ -101,20 +101,14 @@ def wrapper(wrapped: Callable[..., Any]) -> Callable[..., Any]: def transition(self: Any, *a: Any, **kw: Any) -> Any: initial = self._state - if from_states != "*" and not any( - isinstance(self._state, state) for state in from_states - ): # type: ignore - raise EventError( - evt_label, f"Event {evt_label} invalid in state {initial.LABEL}" - ) + if from_states != '*' and not any(isinstance(self._state, state) for state in from_states): # type: ignore + raise EventError(evt_label, f'Event {evt_label} invalid in state {initial.LABEL}') result = wrapped(self, *a, **kw) if not (result is False or isinstance(result, Future)): - if to_states != "*" and not any( - isinstance(self._state, state) for state in to_states - ): # type: ignore + if to_states != '*' and not any(isinstance(self._state, state) for state in to_states): # type: ignore if self._state == initial: - raise EventError(evt_label, "Machine did not transition") + raise EventError(evt_label, 'Machine did not transition') raise EventError( evt_label, @@ -160,7 +154,7 @@ def label(self) -> LABEL_TYPE: def enter(self) -> None: """Entering the state""" - def execute(self) -> Optional["State"]: + def execute(self) -> Optional['State']: """ Execute the state, performing the actions that this state is responsible for. :returns: a state to transition to or None if finished. @@ -170,9 +164,9 @@ def execute(self) -> Optional["State"]: def exit(self) -> None: """Exiting the state""" if self.is_terminal(): - raise InvalidStateError(f"Cannot exit a terminal state {self.LABEL}") + raise InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') - def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> "State": + def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> 'State': return self.state_machine.create_state(state_label, *args, **kwargs) def do_enter(self) -> None: @@ -229,7 +223,7 @@ def get_states(cls) -> Sequence[Type[State]]: if cls.STATES is not None: return cls.STATES - raise RuntimeError("States not defined") + raise RuntimeError('States not defined') @classmethod def initial_state_label(cls) -> LABEL_TYPE: @@ -247,7 +241,7 @@ def get_state_class(cls, label: LABEL_TYPE) -> Type[State]: def __ensure_built(cls) -> None: try: # Check if it's already been built (and therefore sealed) - if cls.__getattribute__(cls, "sealed"): + if cls.__getattribute__(cls, 'sealed'): return except AttributeError: pass @@ -271,9 +265,7 @@ def __init__(self) -> None: self.__ensure_built() self._state: Optional[State] = None self._exception_handler = None # Note this appears to never be used - self.set_debug( - (not sys.flags.ignore_environment and bool(os.environ.get("PYTHONSMDEBUG"))) - ) + self.set_debug((not sys.flags.ignore_environment and bool(os.environ.get('PYTHONSMDEBUG')))) self._transitioning = False self._event_callbacks: Dict[Hashable, List[EVENT_CALLBACK_TYPE]] = {} @@ -282,7 +274,7 @@ def init(self) -> None: """Called after entering initial state in `__call__` method of `StateMachineMeta`""" def __str__(self) -> str: - return f"<{self.__class__.__name__}> ({self.state})" + return f'<{self.__class__.__name__}> ({self.state})' def create_initial_state(self) -> State: return self.get_state_class(self.initial_state_label())(self) @@ -293,9 +285,7 @@ def state(self) -> Optional[LABEL_TYPE]: return None return self._state.LABEL - def add_state_event_callback( - self, hook: Hashable, callback: EVENT_CALLBACK_TYPE - ) -> None: + def add_state_event_callback(self, hook: Hashable, callback: EVENT_CALLBACK_TYPE) -> None: """ Add a callback to be called on a particular state event hook. The callback should have form fn(state_machine, hook, state) @@ -305,10 +295,8 @@ def add_state_event_callback( """ self._event_callbacks.setdefault(hook, []).append(callback) - def remove_state_event_callback( - self, hook: Hashable, callback: EVENT_CALLBACK_TYPE - ) -> None: - if getattr(self, "_closed", False): + def remove_state_event_callback(self, hook: Hashable, callback: EVENT_CALLBACK_TYPE) -> None: + if getattr(self, '_closed', False): # if the process is closed, then all callbacks have already been removed return None try: @@ -324,19 +312,15 @@ def _fire_state_event(self, hook: Hashable, state: Optional[State]) -> None: def on_terminated(self) -> None: """Called when a terminal state is entered""" - def transition_to( - self, new_state: Union[State, Type[State]], **kwargs: Any - ) -> None: + def transition_to(self, new_state: Union[State, Type[State]], **kwargs: Any) -> None: """Transite to the new state. - The new target state will be create lazily when the state is not yet instantiated, + The new target state will be create lazily when the state is not yet instantiated, which will happened for states not in the expect path such as pause and kill. The arguments are passed to the state class to create state instance. (process arg does not need to pass since it will always call with 'self' as process) """ - assert ( - not self._transitioning - ), "Cannot call transition_to when already transitioning state" + assert not self._transitioning, 'Cannot call transition_to when already transitioning state' initial_state_label = self._state.LABEL if self._state is not None else None label = None @@ -358,9 +342,7 @@ def transition_to( except StateEntryFailed as exception: # Make sure we have a state instance if not isinstance(exception.state, State): - new_state = self._create_state_instance( - exception.state, **exception.kwargs - ) + new_state = self._create_state_instance(exception.state, **exception.kwargs) label = new_state.LABEL self._exit_current_state(new_state) self._enter_next_state(new_state) @@ -406,7 +388,7 @@ def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> Stat try: return self.get_states_map()[state_label](self, *args, **kwargs) except KeyError: - raise ValueError(f"{state_label} is not a valid state") + raise ValueError(f'{state_label} is not a valid state') def _exit_current_state(self, next_state: State) -> None: """Exit the given state""" @@ -415,15 +397,11 @@ def _exit_current_state(self, next_state: State) -> None: # in which case check the new state is the initial state if self._state is None: if next_state.label != self.initial_state_label(): - raise RuntimeError( - f"Cannot enter state '{next_state}' as the initial state" - ) + raise RuntimeError(f"Cannot enter state '{next_state}' as the initial state") return # Nothing to exit if next_state.LABEL not in self._state.ALLOWED: - raise RuntimeError( - f"Cannot transition from {self._state.LABEL} to {next_state.label}" - ) + raise RuntimeError(f'Cannot transition from {self._state.LABEL} to {next_state.label}') self._fire_state_event(StateEventHook.EXITING_STATE, next_state) self._state.do_exit() @@ -435,10 +413,8 @@ def _enter_next_state(self, next_state: State) -> None: self._state = next_state self._fire_state_event(StateEventHook.ENTERED_STATE, last_state) - def _create_state_instance( - self, state_cls: type[State], **kwargs: Any - ) -> State: + def _create_state_instance(self, state_cls: type[State], **kwargs: Any) -> State: if state_cls.LABEL not in self.get_states_map(): - raise ValueError(f"{state_cls.LABEL} is not a valid state") + raise ValueError(f'{state_cls.LABEL} is not a valid state') return state_cls(self, **kwargs) diff --git a/src/plumpy/process_comms.py b/src/plumpy/process_comms.py index 9e1e4110..773a9742 100644 --- a/src/plumpy/process_comms.py +++ b/src/plumpy/process_comms.py @@ -33,6 +33,7 @@ MESSAGE_KEY = 'message' FORCE_KILL_KEY = 'force_kill' + class Intent: """Intent constants for a process message""" @@ -41,9 +42,10 @@ class Intent: KILL: str = 'kill' STATUS: str = 'status' + MessageType = dict[str, Any] -PAUSE_MSG: MessageType= {INTENT_KEY: Intent.PAUSE, MESSAGE_KEY: None} +PAUSE_MSG: MessageType = {INTENT_KEY: Intent.PAUSE, MESSAGE_KEY: None} PLAY_MSG: MessageType = {INTENT_KEY: Intent.PLAY, MESSAGE_KEY: None} KILL_MSG: MessageType = {INTENT_KEY: Intent.KILL, MESSAGE_KEY: None, FORCE_KILL_KEY: False} STATUS_MSG: MessageType = {INTENT_KEY: Intent.STATUS, MESSAGE_KEY: None} diff --git a/src/plumpy/process_states.py b/src/plumpy/process_states.py index 46f29d8f..ede846e4 100644 --- a/src/plumpy/process_states.py +++ b/src/plumpy/process_states.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -import sys import copy +import sys import traceback from enum import Enum from types import TracebackType @@ -70,7 +70,7 @@ class Command(persistence.Savable): pass -@auto_persist("msg") +@auto_persist('msg') class Kill(Command): def __init__(self, msg: Optional[MessageType] = None): super().__init__() @@ -81,7 +81,7 @@ class Pause(Command): pass -@auto_persist("msg", "data") +@auto_persist('msg', 'data') class Wait(Command): def __init__( self, @@ -95,7 +95,7 @@ def __init__( self.data = data -@auto_persist("result") +@auto_persist('result') class Stop(Command): def __init__(self, result: Any, successful: bool) -> None: super().__init__() @@ -103,9 +103,9 @@ def __init__(self, result: Any, successful: bool) -> None: self.successful = successful -@auto_persist("args", "kwargs") +@auto_persist('args', 'kwargs') class Continue(Command): - CONTINUE_FN = "continue_fn" + CONTINUE_FN = 'continue_fn' def __init__(self, continue_fn: Callable[..., Any], *args: Any, **kwargs: Any): super().__init__() @@ -113,15 +113,11 @@ def __init__(self, continue_fn: Callable[..., Any], *args: Any, **kwargs: Any): self.args = args self.kwargs = kwargs - def save_instance_state( - self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext - ) -> None: + def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: super().save_instance_state(out_state, save_context) out_state[self.CONTINUE_FN] = self.continue_fn.__name__ - def load_instance_state( - self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext - ) -> None: + def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) try: self.continue_fn = utils.load_function(saved_state[self.CONTINUE_FN]) @@ -148,7 +144,7 @@ class ProcessState(Enum): KILLED: str = 'killed' -@auto_persist("in_state") +@auto_persist('in_state') class State(state_machine.State, persistence.Savable): @property def process(self) -> state_machine.StateMachine: @@ -157,9 +153,7 @@ def process(self) -> state_machine.StateMachine: """ return self.state_machine - def load_instance_state( - self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext - ) -> None: + def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) self.state_machine = load_context.process @@ -167,41 +161,33 @@ def interrupt(self, reason: Any) -> None: pass -@auto_persist("args", "kwargs") +@auto_persist('args', 'kwargs') class Created(State): LABEL = ProcessState.CREATED ALLOWED = {ProcessState.RUNNING, ProcessState.KILLED, ProcessState.EXCEPTED} - RUN_FN = "run_fn" + RUN_FN = 'run_fn' - def __init__( - self, process: "Process", run_fn: Callable[..., Any], *args: Any, **kwargs: Any - ) -> None: + def __init__(self, process: 'Process', run_fn: Callable[..., Any], *args: Any, **kwargs: Any) -> None: super().__init__(process) assert run_fn is not None self.run_fn = run_fn self.args = args self.kwargs = kwargs - def save_instance_state( - self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext - ) -> None: + def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: super().save_instance_state(out_state, save_context) out_state[self.RUN_FN] = self.run_fn.__name__ - def load_instance_state( - self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext - ) -> None: + def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) self.run_fn = getattr(self.process, saved_state[self.RUN_FN]) def execute(self) -> state_machine.State: - return self.create_state( - ProcessState.RUNNING, self.run_fn, *self.args, **self.kwargs - ) + return self.create_state(ProcessState.RUNNING, self.run_fn, *self.args, **self.kwargs) -@auto_persist("args", "kwargs") +@auto_persist('args', 'kwargs') class Running(State): LABEL = ProcessState.RUNNING ALLOWED = { @@ -212,17 +198,15 @@ class Running(State): ProcessState.EXCEPTED, } - RUN_FN = "run_fn" # The key used to store the function to run - COMMAND = "command" # The key used to store an upcoming command + RUN_FN = 'run_fn' # The key used to store the function to run + COMMAND = 'command' # The key used to store an upcoming command # Class level defaults _command: Union[None, Kill, Stop, Wait, Continue] = None _running: bool = False _run_handle = None - def __init__( - self, process: "Process", run_fn: Callable[..., Any], *args: Any, **kwargs: Any - ) -> None: + def __init__(self, process: 'Process', run_fn: Callable[..., Any], *args: Any, **kwargs: Any) -> None: super().__init__(process) assert run_fn is not None self.run_fn = run_fn @@ -230,23 +214,17 @@ def __init__( self.kwargs = kwargs self._run_handle = None - def save_instance_state( - self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext - ) -> None: + def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: super().save_instance_state(out_state, save_context) out_state[self.RUN_FN] = self.run_fn.__name__ if self._command is not None: out_state[self.COMMAND] = self._command.save() - def load_instance_state( - self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext - ) -> None: + def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) self.run_fn = getattr(self.process, saved_state[self.RUN_FN]) if self.COMMAND in saved_state: - self._command = persistence.Savable.load( - saved_state[self.COMMAND], load_context - ) # type: ignore + self._command = persistence.Savable.load(saved_state[self.COMMAND], load_context) # type: ignore def interrupt(self, reason: Any) -> None: pass @@ -286,24 +264,18 @@ def _action_command(self, command: Union[Kill, Stop, Wait, Continue]) -> State: # elif isinstance(command, Pause): # self.pause() elif isinstance(command, Stop): - state = self.create_state( - ProcessState.FINISHED, command.result, command.successful - ) + state = self.create_state(ProcessState.FINISHED, command.result, command.successful) elif isinstance(command, Wait): - state = self.create_state( - ProcessState.WAITING, command.continue_fn, command.msg, command.data - ) + state = self.create_state(ProcessState.WAITING, command.continue_fn, command.msg, command.data) elif isinstance(command, Continue): - state = self.create_state( - ProcessState.RUNNING, command.continue_fn, *command.args - ) + state = self.create_state(ProcessState.RUNNING, command.continue_fn, *command.args) else: - raise ValueError("Unrecognised command") + raise ValueError('Unrecognised command') return cast(State, state) # casting from base.State to process.State -@auto_persist("msg", "data") +@auto_persist('msg', 'data') class Waiting(State): LABEL = ProcessState.WAITING ALLOWED = { @@ -314,19 +286,19 @@ class Waiting(State): ProcessState.FINISHED, } - DONE_CALLBACK = "DONE_CALLBACK" + DONE_CALLBACK = 'DONE_CALLBACK' _interruption = None def __str__(self) -> str: state_info = super().__str__() if self.msg is not None: - state_info += f" ({self.msg})" + state_info += f' ({self.msg})' return state_info def __init__( self, - process: "Process", + process: 'Process', done_callback: Optional[Callable[..., Any]], msg: Optional[str] = None, data: Optional[Any] = None, @@ -337,16 +309,12 @@ def __init__( self.data = data self._waiting_future: futures.Future = futures.Future() - def save_instance_state( - self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext - ) -> None: + def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: super().save_instance_state(out_state, save_context) if self.done_callback is not None: out_state[self.DONE_CALLBACK] = self.done_callback.__name__ - def load_instance_state( - self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext - ) -> None: + def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) callback_name = saved_state.get(self.DONE_CALLBACK, None) if callback_name is not None: @@ -372,14 +340,12 @@ async def execute(self) -> State: # type: ignore if result == NULL: next_state = self.create_state(ProcessState.RUNNING, self.done_callback) else: - next_state = self.create_state( - ProcessState.RUNNING, self.done_callback, result - ) + next_state = self.create_state(ProcessState.RUNNING, self.done_callback, result) return cast(State, next_state) # casting from base.State to process.State def resume(self, value: Any = NULL) -> None: - assert self._waiting_future is not None, "Not yet waiting" + assert self._waiting_future is not None, 'Not yet waiting' if self._waiting_future.done(): return @@ -397,12 +363,12 @@ class Excepted(State): LABEL = ProcessState.EXCEPTED - EXC_VALUE = "ex_value" - TRACEBACK = "traceback" + EXC_VALUE = 'ex_value' + TRACEBACK = 'traceback' def __init__( self, - process: "Process", + process: 'Process', exception: Optional[BaseException], trace_back: Optional[TracebackType] = None, ): @@ -416,22 +382,16 @@ def __init__( self.traceback = trace_back def __str__(self) -> str: - exception = traceback.format_exception_only( - type(self.exception) if self.exception else None, self.exception - )[0] - return super().__str__() + f"({exception})" + exception = traceback.format_exception_only(type(self.exception) if self.exception else None, self.exception)[0] + return super().__str__() + f'({exception})' - def save_instance_state( - self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext - ) -> None: + def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: super().save_instance_state(out_state, save_context) out_state[self.EXC_VALUE] = yaml.dump(self.exception) if self.traceback is not None: - out_state[self.TRACEBACK] = "".join(traceback.format_tb(self.traceback)) + out_state[self.TRACEBACK] = ''.join(traceback.format_tb(self.traceback)) - def load_instance_state( - self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext - ) -> None: + def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) self.exception = yaml.load(saved_state[self.EXC_VALUE], Loader=Loader) if _HAS_TBLIB: @@ -444,9 +404,7 @@ def load_instance_state( def get_exc_info( self, - ) -> Tuple[ - Optional[Type[BaseException]], Optional[BaseException], Optional[TracebackType] - ]: + ) -> Tuple[Optional[Type[BaseException]], Optional[BaseException], Optional[TracebackType]]: """ Recreate the exc_info tuple and return it """ @@ -457,22 +415,23 @@ def get_exc_info( ) -@auto_persist("result", "successful") +@auto_persist('result', 'successful') class Finished(State): """State for process is finished. :param result: The result of process :param successful: Boolean for the exit code is ``0`` the process is successful. """ + LABEL = ProcessState.FINISHED - def __init__(self, process: "Process", result: Any, successful: bool) -> None: + def __init__(self, process: 'Process', result: Any, successful: bool) -> None: super().__init__(process) self.result = result self.successful = successful -@auto_persist("msg") +@auto_persist('msg') class Killed(State): """ Represents a state where a process has been killed. @@ -485,7 +444,7 @@ class Killed(State): LABEL = ProcessState.KILLED - def __init__(self, process: "Process", msg: Optional[MessageType]): + def __init__(self, process: 'Process', msg: Optional[MessageType]): """ :param process: The associated process :param msg: Optional kill message diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index 07e2d20c..9358d927 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -15,6 +15,7 @@ import warnings from types import TracebackType from typing import ( + TYPE_CHECKING, Any, Awaitable, Callable, @@ -53,15 +54,18 @@ from .base.state_machine import StateEntryFailed, StateMachine, TransitionFailed, event from .base.utils import call_with_super_check, super_check from .event_helper import EventHelper +from .process_comms import KILL_MSG, MESSAGE_KEY, MessageType from .process_listener import ProcessListener from .process_spec import ProcessSpec from .utils import PID_TYPE, SAVED_STATE_TYPE, protected -from .process_comms import KILL_MSG, MESSAGE_KEY, MessageType + +if TYPE_CHECKING: + from .process_states import State __all__ = ['BundleKeys', 'Process', 'ProcessSpec', 'TransitionFailed'] _LOGGER = logging.getLogger(__name__) -PROCESS_STACK = ContextVar("process stack", default=[]) +PROCESS_STACK = ContextVar('process stack', default=[]) class BundleKeys: @@ -94,20 +98,20 @@ def ensure_not_closed(func: Callable[..., Any]) -> Callable[..., Any]: @functools.wraps(func) def func_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: if self._closed: - raise exceptions.ClosedError("Process is closed") + raise exceptions.ClosedError('Process is closed') return func(self, *args, **kwargs) return func_wrapper @persistence.auto_persist( - "_pid", - "_creation_time", - "_future", - "_paused", - "_status", - "_pre_paused_status", - "_event_helper", + '_pid', + '_creation_time', + '_future', + '_paused', + '_status', + '_pre_paused_status', + '_event_helper', ) class Process(StateMachine, persistence.Savable, metaclass=ProcessStateMachineMeta): """ @@ -161,7 +165,7 @@ class Process(StateMachine, persistence.Savable, metaclass=ProcessStateMachineMe __called: bool = False @classmethod - def current(cls) -> Optional["Process"]: + def current(cls) -> Optional['Process']: """ Get the currently running process i.e. the one at the top of the stack @@ -197,15 +201,15 @@ def get_state_classes(cls) -> Dict[Hashable, Type[process_states.State]]: @classmethod def spec(cls) -> ProcessSpec: try: - return cls.__getattribute__(cls, "_spec") + return cls.__getattribute__(cls, '_spec') except AttributeError: try: cls._spec: ProcessSpec = cls._spec_class() # type: ignore cls.__called: bool = False # type: ignore cls.define(cls._spec) # type: ignore assert cls.__called, ( - f"Process.define() was not called by {cls}\nHint: Did you forget to call the superclass method in " - "your define? Try: super().define(spec)" + f'Process.define() was not called by {cls}\nHint: Did you forget to call the superclass method in ' + 'your define? Try: super().define(spec)' ) return cls._spec # type: ignore except Exception: @@ -237,11 +241,11 @@ def get_description(cls) -> Dict[str, Any]: description: Dict[str, Any] = {} if cls.__doc__: - description["description"] = cls.__doc__.strip() + description['description'] = cls.__doc__.strip() spec_description = cls.spec().get_description() if spec_description: - description["spec"] = spec_description + description['spec'] = spec_description return description @@ -250,7 +254,7 @@ def recreate_from( cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[persistence.LoadSaveContext] = None, - ) -> "Process": + ) -> 'Process': """ Recreate a process from a saved state, passing any positional and keyword arguments on to load_instance_state @@ -298,9 +302,7 @@ def __init__( self._paused = None # Input/output - self._raw_inputs = ( - None if inputs is None else utils.AttributesFrozendict(inputs) - ) + self._raw_inputs = None if inputs is None else utils.AttributesFrozendict(inputs) self._pid = pid self._parsed_inputs: Optional[utils.AttributesFrozendict] = None self._outputs: Dict[str, Any] = {} @@ -323,35 +325,19 @@ def init(self) -> None: if self._communicator is not None: try: - identifier = self._communicator.add_rpc_subscriber( - self.message_receive, identifier=str(self.pid) - ) - self.add_cleanup( - functools.partial( - self._communicator.remove_rpc_subscriber, identifier - ) - ) + identifier = self._communicator.add_rpc_subscriber(self.message_receive, identifier=str(self.pid)) + self.add_cleanup(functools.partial(self._communicator.remove_rpc_subscriber, identifier)) except kiwipy.TimeoutError: - self.logger.exception( - "Process<%s>: failed to register as an RPC subscriber", self.pid - ) + self.logger.exception('Process<%s>: failed to register as an RPC subscriber', self.pid) try: # filter out state change broadcasts - subscriber = kiwipy.BroadcastFilter( - self.broadcast_receive, subject=re.compile(r"^(?!state_changed).*") - ) - identifier = self._communicator.add_broadcast_subscriber( - subscriber, identifier=str(self.pid) - ) - self.add_cleanup( - functools.partial( - self._communicator.remove_broadcast_subscriber, identifier - ) - ) + subscriber = kiwipy.BroadcastFilter(self.broadcast_receive, subject=re.compile(r'^(?!state_changed).*')) + identifier = self._communicator.add_broadcast_subscriber(subscriber, identifier=str(self.pid)) + self.add_cleanup(functools.partial(self._communicator.remove_broadcast_subscriber, identifier)) except kiwipy.TimeoutError: self.logger.exception( - "Process<%s>: failed to register as a broadcast subscriber", + 'Process<%s>: failed to register as a broadcast subscriber', self.pid, ) @@ -360,10 +346,10 @@ def init(self) -> None: def try_killing(future: futures.Future) -> None: if future.cancelled(): msg = copy.copy(KILL_MSG) - msg[MESSAGE_KEY] = "Killed by future being cancelled" + msg[MESSAGE_KEY] = 'Killed by future being cancelled' if not self.kill(msg): self.logger.warning( - "Process<%s>: Failed to kill process on future cancel", + 'Process<%s>: Failed to kill process on future cancel', self.pid, ) @@ -460,7 +446,7 @@ def future(self) -> persistence.SavableFuture: @ensure_not_closed def launch( self, - process_class: Type["Process"], + process_class: Type['Process'], inputs: Optional[dict] = None, pid: Optional[PID_TYPE] = None, logger: Optional[logging.Logger] = None, @@ -498,7 +484,7 @@ def result(self) -> Any: if isinstance(self._state, process_states.Killed): raise exceptions.KilledError(self._state.msg) if isinstance(self._state, process_states.Excepted): - raise (self._state.exception or Exception("process excepted")) + raise (self._state.exception or Exception('process excepted')) raise exceptions.InvalidStateError @@ -510,9 +496,7 @@ def successful(self) -> bool: try: return self._state.successful # type: ignore except AttributeError as exception: - raise exceptions.InvalidStateError( - "process is not in the finished state" - ) from exception + raise exceptions.InvalidStateError('process is not in the finished state') from exception @property def is_successful(self) -> bool: @@ -534,7 +518,7 @@ def killed_msg(self) -> Optional[MessageType]: if isinstance(self._state, process_states.Killed): return self._state.msg - raise exceptions.InvalidStateError("Has not been killed") + raise exceptions.InvalidStateError('Has not been killed') def exception(self) -> Optional[BaseException]: """Return exception, if the process is terminated in excepted state.""" @@ -569,9 +553,7 @@ def loop(self) -> asyncio.AbstractEventLoop: """Return the event loop of the process.""" return self._loop - def call_soon( - self, callback: Callable[..., Any], *args: Any, **kwargs: Any - ) -> events.ProcessCallback: + def call_soon(self, callback: Callable[..., Any], *args: Any, **kwargs: Any) -> events.ProcessCallback: """ Schedule a callback to what is considered an internal process function (this needn't be a method). @@ -605,16 +587,14 @@ def _process_scope(self) -> Generator[None, None, None]: yield None finally: assert Process.current() is self, ( - "Somehow, the process at the top of the stack is not me, but another process! " - f"({self} != {Process.current()})" + 'Somehow, the process at the top of the stack is not me, but another process! ' + f'({self} != {Process.current()})' ) stack_copy = PROCESS_STACK.get().copy() stack_copy.pop() PROCESS_STACK.set(stack_copy) - async def _run_task( - self, callback: Callable[..., Any], *args: Any, **kwargs: Any - ) -> Any: + async def _run_task(self, callback: Callable[..., Any], *args: Any, **kwargs: Any) -> Any: """ This method should be used to run all Process related functions and coroutines. If there is an exception the process will enter the EXCEPTED state. @@ -647,7 +627,7 @@ def save_instance_state( """ super().save_instance_state(out_state, save_context) - out_state["_state"] = self._state.save() + out_state['_state'] = self._state.save() # Inputs/outputs if self.raw_inputs is not None: @@ -660,9 +640,7 @@ def save_instance_state( out_state[BundleKeys.OUTPUTS] = self.encode_input_args(self.outputs) @protected - def load_instance_state( - self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext - ) -> None: + def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: """Load the process from its saved instance state. :param saved_state: A bundle to load the state from @@ -680,17 +658,17 @@ def load_instance_state( self._logger = None self._communicator = None - if "loop" in load_context: + if 'loop' in load_context: self._loop = load_context.loop else: self._loop = asyncio.get_event_loop() - self._state: process_states.State = self.recreate_state(saved_state["_state"]) + self._state: process_states.State = self.recreate_state(saved_state['_state']) - if "communicator" in load_context: + if 'communicator' in load_context: self._communicator = load_context.communicator - if "logger" in load_context: + if 'logger' in load_context: self._logger = load_context.logger # Need to call this here as things downstream may rely on us having the runtime variable above @@ -739,7 +717,7 @@ def set_logger(self, logger: logging.Logger) -> None: @protected def log_with_pid(self, level: int, msg: str) -> None: """Log the message with the process pid.""" - self.logger.log(level, "%s: %s", self.pid, msg) + self.logger.log(level, '%s: %s', self.pid, msg) # region Events @@ -774,24 +752,16 @@ def on_entered(self, from_state: Optional[process_states.State]) -> None: call_with_super_check(self.on_killed) if self._communicator and isinstance(self.state, enum.Enum): - from_label = ( - cast(enum.Enum, from_state.LABEL).value - if from_state is not None - else None - ) - subject = f"state_changed.{from_label}.{self.state.value}" - self.logger.info( - "Process<%s>: Broadcasting state change: %s", self.pid, subject - ) + from_label = cast(enum.Enum, from_state.LABEL).value if from_state is not None else None + subject = f'state_changed.{from_label}.{self.state.value}' + self.logger.info('Process<%s>: Broadcasting state change: %s', self.pid, subject) try: - self._communicator.broadcast_send( - body=None, sender=self.pid, subject=subject - ) + self._communicator.broadcast_send(body=None, sender=self.pid, subject=subject) except (ConnectionClosed, ChannelInvalidStateError): - message = "Process<%s>: no connection available to broadcast state change from %s to %s" + message = 'Process<%s>: no connection available to broadcast state change from %s to %s' self.logger.warning(message, self.pid, from_label, self.state.value) except kiwipy.TimeoutError: - message = "Process<%s>: sending broadcast of state change from %s to %s timed out" + message = 'Process<%s>: sending broadcast of state change from %s to %s timed out' self.logger.warning(message, self.pid, from_label, self.state.value) def on_exiting(self) -> None: @@ -809,10 +779,7 @@ def on_create(self) -> None: def recursively_copy_dictionaries(value: Any) -> Any: """Recursively copy the mapping but only create copies of the dictionaries not the values.""" if isinstance(value, dict): - return { - key: recursively_copy_dictionaries(subvalue) - for key, subvalue in value.items() - } + return {key: recursively_copy_dictionaries(subvalue) for key, subvalue in value.items()} return value # This will parse the inputs with respect to the input portnamespace of the spec and validate them. The @@ -820,11 +787,7 @@ def recursively_copy_dictionaries(value: Any) -> Any: # ``_raw_inputs`` should not be modified, we pass a clone of it. Note that we only need a clone of the nested # dictionaries, so we don't use ``copy.deepcopy`` (which might seem like the obvious choice) as that will also # create a clone of the values, which we don't want. - raw_inputs = ( - recursively_copy_dictionaries(dict(self._raw_inputs)) - if self._raw_inputs - else {} - ) + raw_inputs = recursively_copy_dictionaries(dict(self._raw_inputs)) if self._raw_inputs else {} self._parsed_inputs = self.spec().inputs.pre_process(raw_inputs) result = self.spec().inputs.validate(self._parsed_inputs) @@ -857,9 +820,7 @@ def on_output_emitting(self, output_port: str, value: Any) -> None: """Output is about to be emitted.""" def on_output_emitted(self, output_port: str, value: Any, dynamic: bool) -> None: - self._event_helper.fire_event( - ProcessListener.on_output_emitted, self, output_port, value, dynamic - ) + self._event_helper.fire_event(ProcessListener.on_output_emitted, self, output_port, value, dynamic) @super_check def on_wait(self, awaitables: Sequence[Awaitable]) -> None: @@ -908,9 +869,7 @@ def on_finish(self, result: Any, successful: bool) -> None: if successful: validation_error = self.spec().outputs.validate(self.outputs) if validation_error: - raise StateEntryFailed( - process_states.Finished, result=result, successful=False - ) + raise StateEntryFailed(process_states.Finished, result=result, successful=False) self.future().set_result(self.outputs) @@ -936,17 +895,15 @@ def on_except(self, exc_info: Tuple[Any, Exception, TracebackType]) -> None: @super_check def on_excepted(self) -> None: """Entered the EXCEPTED state.""" - self._fire_event( - ProcessListener.on_process_excepted, str(self.future().exception()) - ) + self._fire_event(ProcessListener.on_process_excepted, str(self.future().exception())) @super_check def on_kill(self, msg: Optional[MessageType]) -> None: """Entering the KILLED state.""" if msg is None: - msg_txt = "" + msg_txt = '' else: - msg_txt = msg[MESSAGE_KEY] or "" + msg_txt = msg[MESSAGE_KEY] or '' self.set_status(msg_txt) self.future().set_exception(exceptions.KilledError(msg_txt)) @@ -1007,9 +964,7 @@ def message_receive(self, _comm: kiwipy.Communicator, msg: Dict[str, Any]) -> An if intent == process_comms.Intent.PLAY: return self._schedule_rpc(self.play) if intent == process_comms.Intent.PAUSE: - return self._schedule_rpc( - self.pause, msg=msg.get(process_comms.MESSAGE_KEY, None) - ) + return self._schedule_rpc(self.pause, msg=msg.get(process_comms.MESSAGE_KEY, None)) if intent == process_comms.Intent.KILL: return self._schedule_rpc(self.kill, msg=msg) if intent == process_comms.Intent.STATUS: @@ -1018,7 +973,7 @@ def message_receive(self, _comm: kiwipy.Communicator, msg: Dict[str, Any]) -> An return status_info # Didn't match any known intents - raise RuntimeError("Unknown intent") + raise RuntimeError('Unknown intent') def broadcast_receive( self, _comm: kiwipy.Communicator, body: Any, sender: Any, subject: Any, correlation_id: Any @@ -1047,9 +1002,7 @@ def broadcast_receive( return self._schedule_rpc(self.kill, msg=body) return None - def _schedule_rpc( - self, callback: Callable[..., Any], *args: Any, **kwargs: Any - ) -> kiwipy.Future: + def _schedule_rpc(self, callback: Callable[..., Any], *args: Any, **kwargs: Any) -> kiwipy.Future: """ Schedule a call to a callback as a result of an RPC communication call, this will return a future that resolves to the final result (even after one or more layer of futures being @@ -1113,13 +1066,9 @@ def transition_failed( if final_state == process_states.ProcessState.CREATED: raise exception.with_traceback(trace) - self.transition_to( - process_states.Excepted, exception=exception, trace_back=trace - ) + self.transition_to(process_states.Excepted, exception=exception, trace_back=trace) - def pause( - self, msg: Union[str, None] = None - ) -> Union[bool, futures.CancellableAction]: + def pause(self, msg: Union[str, None] = None) -> Union[bool, futures.CancellableAction]: """Pause the process. :param msg: an optional message to set as the status. The current status will be saved in the private @@ -1164,9 +1113,7 @@ def _do_pause(self, state_msg: Optional[str], next_state: Optional[process_state return True - def _create_interrupt_action( - self, exception: process_states.Interruption - ) -> futures.CancellableAction: + def _create_interrupt_action(self, exception: process_states.Interruption) -> futures.CancellableAction: """ Create an interrupt action from the corresponding interrupt exception @@ -1191,9 +1138,7 @@ def do_kill(_next_state: process_states.State) -> Any: raise ValueError(f"Got unknown interruption type '{type(exception)}'") - def _set_interrupt_action( - self, new_action: Optional[futures.CancellableAction] - ) -> None: + def _set_interrupt_action(self, new_action: Optional[futures.CancellableAction]) -> None: """ Set the interrupt action cancelling the current one if it exists :param new_action: The new interrupt action to set @@ -1230,17 +1175,13 @@ def resume(self, *args: Any) -> None: return self._state.resume(*args) # type: ignore @event(to_states=process_states.Excepted) - def fail( - self, exception: Optional[BaseException], trace_back: Optional[TracebackType] - ) -> None: + def fail(self, exception: Optional[BaseException], trace_back: Optional[TracebackType]) -> None: """ Fail the process in response to an exception :param exception: The exception that caused the failure :param trace_back: Optional exception traceback """ - self.transition_to( - process_states.Excepted, exception=exception, trace_back=trace_back - ) + self.transition_to(process_states.Excepted, exception=exception, trace_back=trace_back) def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future]: """ @@ -1298,9 +1239,7 @@ def recreate_state(self, saved_state: persistence.Bundle) -> process_states.Stat :return: An instance of the object with its state loaded from the save state. """ load_context = persistence.LoadSaveContext(process=self) - return cast( - process_states.State, persistence.Savable.load(saved_state, load_context) - ) + return cast(process_states.State, persistence.Savable.load(saved_state, load_context)) # endregion @@ -1333,7 +1272,7 @@ async def step(self) -> None: The execute function running in this method is dependent on the state of the process. """ - assert not self.has_terminated(), "Cannot step, already terminated" + assert not self.has_terminated(), 'Cannot step, already terminated' if self.paused and self._paused is not None: await self._paused @@ -1358,9 +1297,7 @@ async def step(self) -> None: raise except Exception: # Overwrite the next state to go to excepted directly - next_state = self.create_state( - process_states.ProcessState.EXCEPTED, *sys.exc_info()[1:] - ) + next_state = self.create_state(process_states.ProcessState.EXCEPTED, *sys.exc_info()[1:]) self._set_interrupt_action(None) if self._interrupt_action: diff --git a/tests/test_processes.py b/tests/test_processes.py index b4526403..2780f41f 100644 --- a/tests/test_processes.py +++ b/tests/test_processes.py @@ -5,7 +5,6 @@ import copy import enum from plumpy.process_comms import KILL_MSG, MESSAGE_KEY -from test import utils import unittest import kiwipy