diff --git a/src/plumpy/base/state_machine.py b/src/plumpy/base/state_machine.py index 2035d4ab..07c27076 100644 --- a/src/plumpy/base/state_machine.py +++ b/src/plumpy/base/state_machine.py @@ -314,7 +314,7 @@ def _fire_state_event(self, hook: Hashable, state: Optional[State]) -> None: def on_terminated(self) -> None: """Called when a terminal state is entered""" - def transition_to(self, new_state: State | type[State] | None, **kwargs: Any) -> None: + def transition_to(self, new_state: State | None, **kwargs: Any) -> None: """Transite to the new state. The new target state will be create lazily when the state is not yet instantiated, @@ -332,9 +332,9 @@ def transition_to(self, new_state: State | type[State] | None, **kwargs: Any) -> try: self._transitioning = True - if not isinstance(new_state, State): - # Make sure we have a state instance - new_state = self._create_state_instance(new_state, **kwargs) + # if not isinstance(new_state, State): + # # Make sure we have a state instance + # new_state = self._create_state_instance(new_state, **kwargs) label = new_state.LABEL diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index ef558fa1..aef49acd 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -1064,7 +1064,9 @@ def transition_failed( if final_state == process_states.ProcessState.CREATED: raise exception.with_traceback(trace) - self.transition_to(process_states.Excepted, exception=exception, trace_back=trace) + state_class = self.get_states_map()[process_states.ProcessState.EXCEPTED] + new_state = self._create_state_instance(state_class, exception=exception, trace_back=trace) + self.transition_to(new_state) def pause(self, msg: Union[str, None] = None) -> Union[bool, futures.CancellableAction]: """Pause the process. @@ -1127,7 +1129,9 @@ def _create_interrupt_action(self, exception: process_states.Interruption) -> fu def do_kill(_next_state: process_states.State) -> Any: try: - self.transition_to(process_states.Killed, msg=exception.msg) + state_class = self.get_states_map()[process_states.ProcessState.KILLED] + new_state = self._create_state_instance(state_class, msg=exception.msg) + self.transition_to(new_state) return True finally: self._killing = None @@ -1179,7 +1183,9 @@ def fail(self, exception: Optional[BaseException], trace_back: Optional[Tracebac :param exception: The exception that caused the failure :param trace_back: Optional exception traceback """ - self.transition_to(process_states.Excepted, exception=exception, trace_back=trace_back) + state_class = self.get_states_map()[process_states.ProcessState.EXCEPTED] + new_state = self._create_state_instance(state_class, exception=exception, trace_back=trace_back) + self.transition_to(new_state) def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future]: """ @@ -1207,7 +1213,9 @@ def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future] self._state.interrupt(interrupt_exception) return cast(futures.CancellableAction, self._interrupt_action) - self.transition_to(process_states.Killed, msg=msg) + state_class = self.get_states_map()[process_states.ProcessState.KILLED] + new_state = self._create_state_instance(state_class, msg=msg) + self.transition_to(new_state) return True @property diff --git a/tests/base/test_statemachine.py b/tests/base/test_statemachine.py index 5b4b73d8..3a1621a2 100644 --- a/tests/base/test_statemachine.py +++ b/tests/base/test_statemachine.py @@ -57,6 +57,7 @@ class Paused(state_machine.State): def __init__(self, player, playing_state): assert isinstance(playing_state, Playing), 'Must provide the playing state to pause' super().__init__(player) + self._player = player self.playing_state = playing_state def __str__(self): @@ -64,7 +65,7 @@ def __str__(self): def play(self, track=None): if track is not None: - self.state_machine.transition_to(Playing, track=track) + self.state_machine.transition_to(Playing(player=self.state_machine, track=track)) else: self.state_machine.transition_to(self.playing_state) @@ -80,7 +81,7 @@ def __str__(self): return '[]' def play(self, track): - self.state_machine.transition_to(Playing, track=track) + self.state_machine.transition_to(Playing(self.state_machine, track=track)) class CdPlayer(state_machine.StateMachine): @@ -107,12 +108,12 @@ def play(self, track=None): @state_machine.event(from_states=Playing, to_states=Paused) def pause(self): - self.transition_to(Paused, playing_state=self._state) + self.transition_to(Paused(self, playing_state=self._state)) return True @state_machine.event(from_states=(Playing, Paused), to_states=Stopped) def stop(self): - self.transition_to(Stopped) + self.transition_to(Stopped(self)) class TestStateMachine(unittest.TestCase):