diff --git a/src/plumpy/base/state_machine.py b/src/plumpy/base/state_machine.py index 3400026a..c9a08848 100644 --- a/src/plumpy/base/state_machine.py +++ b/src/plumpy/base/state_machine.py @@ -315,7 +315,10 @@ def transition_to(self, new_state: Union[Hashable, State, Type[State]], *args: A # Make sure we have a state instance new_state = self._create_state_instance(new_state, *args, **kwargs) label = new_state.LABEL - self._exit_current_state(new_state) + + # If the previous transition failed, do not try to exit it but go straight to next state + if not self._transition_failing: + self._exit_current_state(new_state) try: self._enter_next_state(new_state) diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index 523fe2ea..41bd3609 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -828,6 +828,13 @@ def on_except(self, exc_info: Tuple[Any, Exception, TracebackType]) -> None: """Entering the EXCEPTED state.""" exception = exc_info[1] exception.__traceback__ = exc_info[2] + + # It is possible that we already got into a finished state and the future result was set, in which case, we + # should reset it before setting the exception or else ``asyncio`` will raise an exception. + future = self.future() + + if future.done(): + self._future = persistence.SavableFuture(loop=self._loop) self.future().set_exception(exception) @super_check diff --git a/test/test_processes.py b/test/test_processes.py index 7df907ed..e357fc4d 100644 --- a/test/test_processes.py +++ b/test/test_processes.py @@ -654,6 +654,40 @@ def test_execute_twice(self): with self.assertRaises(plumpy.ClosedError): proc.execute() + def test_exception_during_on_entered(self): + """Test that an exception raised during ``on_entered`` will cause the process to be excepted.""" + + class RaisingProcess(Process): + + def on_entered(self, from_state): + if from_state is not None and from_state.label == ProcessState.RUNNING: + raise RuntimeError('exception during on_entered') + super().on_entered(from_state) + + process = RaisingProcess() + + with self.assertRaises(RuntimeError): + process.execute() + + assert not process.is_successful + assert process.is_excepted + assert str(process.exception()) == 'exception during on_entered' + + def test_exception_during_run(self): + + class RaisingProcess(Process): + + def run(self): + raise RuntimeError('exception during run') + + process = RaisingProcess() + + with self.assertRaises(RuntimeError): + process.execute() + + assert process.is_excepted + assert str(process.exception()) == 'exception during run' + @plumpy.auto_persist('steps_ran') class SavePauseProc(plumpy.Process):