diff --git a/src/plumpy/process_comms.py b/src/plumpy/process_comms.py index 2d6b3bf4..5049dad1 100644 --- a/src/plumpy/process_comms.py +++ b/src/plumpy/process_comms.py @@ -229,7 +229,9 @@ async def play_process(self, pid: 'PID_TYPE') -> 'ProcessResult': result = await asyncio.wrap_future(future) return result - async def kill_process(self, pid: 'PID_TYPE', msg_text: Optional[str] = None) -> 'ProcessResult': + async def kill_process( + self, pid: 'PID_TYPE', msg_text: Optional[str] = None, force_kill: bool = False + ) -> 'ProcessResult': """ Kill the process @@ -237,7 +239,7 @@ async def kill_process(self, pid: 'PID_TYPE', msg_text: Optional[str] = None) -> :param msg: optional kill message :return: True if killed, False otherwise """ - msg = MessageBuilder.kill(text=msg_text) + msg = MessageBuilder.kill(text=msg_text, force_kill=force_kill) # Wait for the communication to go through kill_future = self._communicator.rpc_send(pid, msg) @@ -401,7 +403,7 @@ def play_all(self) -> None: """ self._communicator.broadcast_send(None, subject=Intent.PLAY) - def kill_process(self, pid: 'PID_TYPE', msg_text: Optional[str] = None) -> kiwipy.Future: + def kill_process(self, pid: 'PID_TYPE', msg_text: Optional[str] = None, force_kill: bool = False) -> kiwipy.Future: """ Kill the process @@ -409,7 +411,7 @@ def kill_process(self, pid: 'PID_TYPE', msg_text: Optional[str] = None) -> kiwip :param msg: optional kill message :return: a response future from the process to be killed """ - msg = MessageBuilder.kill(text=msg_text) + msg = MessageBuilder.kill(text=msg_text, force_kill=force_kill) return self._communicator.rpc_send(pid, msg) def kill_all(self, msg_text: Optional[str]) -> None: diff --git a/src/plumpy/process_states.py b/src/plumpy/process_states.py index 931dbc5e..7d452b24 100644 --- a/src/plumpy/process_states.py +++ b/src/plumpy/process_states.py @@ -31,10 +31,11 @@ 'Excepted', 'Finished', 'Interruption', + 'Killed', # Commands 'Kill', 'KillInterruption', - 'Killed', + 
'ForceKillInterruption', 'PauseInterruption', 'ProcessState', 'Running', @@ -59,6 +60,14 @@ def __init__(self, msg_text: str | None): self.msg: MessageType = msg +class ForceKillInterruption(Interruption): + def __init__(self, msg_text: str | None): + super().__init__() + msg = MessageBuilder.kill(text=msg_text) + + self.msg: MessageType = msg + + class PauseInterruption(Interruption): def __init__(self, msg_text: str | None): super().__init__() diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index 409374d0..7eca4411 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -965,7 +965,11 @@ def message_receive(self, _comm: kiwipy.Communicator, msg: MessageType) -> Any: if intent == process_comms.Intent.PAUSE: return self._schedule_rpc(self.pause, msg_text=msg.get(process_comms.MESSAGE_TEXT_KEY, None)) if intent == process_comms.Intent.KILL: - return self._schedule_rpc(self.kill, msg_text=msg.get(process_comms.MESSAGE_TEXT_KEY, None)) + return self._schedule_rpc( + self.kill, + msg_text=msg.get(process_comms.MESSAGE_TEXT_KEY, None), + force_kill=msg.get(process_comms.FORCE_KILL_KEY), + ) if intent == process_comms.Intent.STATUS: status_info: Dict[str, Any] = {} self.get_status_info(status_info) @@ -998,7 +1002,11 @@ def broadcast_receive( if subject == process_comms.Intent.PAUSE: return self._schedule_rpc(self.pause, msg_text=msg.get(process_comms.MESSAGE_TEXT_KEY, None)) if subject == process_comms.Intent.KILL: - return self._schedule_rpc(self.kill, msg_text=msg.get(process_comms.MESSAGE_TEXT_KEY, None)) + return self._schedule_rpc( + self.kill, + msg_text=msg.get(process_comms.MESSAGE_TEXT_KEY, None), + force_kill=msg.get(process_comms.FORCE_KILL_KEY), + ) return None def _schedule_rpc(self, callback: Callable[..., Any], *args: Any, **kwargs: Any) -> kiwipy.Future: @@ -1160,7 +1168,7 @@ def _create_interrupt_action(self, exception: process_states.Interruption) -> fu do_pause = functools.partial(self._do_pause, exception.msg) return 
futures.CancellableAction(do_pause, cookie=exception) - if isinstance(exception, process_states.KillInterruption): + if isinstance(exception, (process_states.KillInterruption, process_states.ForceKillInterruption)): def do_kill(_next_state: process_states.State) -> Any: try: @@ -1222,11 +1230,13 @@ def fail(self, exception: Optional[BaseException], trace_back: Optional[Tracebac ) self.transition_to(new_state) - def kill(self, msg_text: Optional[str] = None) -> Union[bool, asyncio.Future]: + def kill(self, msg_text: Optional[str] = None, force_kill: bool = False) -> Union[bool, asyncio.Future]: """ Kill the process :param msg: An optional kill message + :param force_kill: Whether to force kill the process (optional, defaults to False) """ + if self.state == process_states.ProcessState.KILLED: # Already killed return True @@ -1235,10 +1245,27 @@ def kill(self, msg_text: Optional[str] = None) -> Union[bool, asyncio.Future]: # Can't kill return False - if self._killing: + if self._killing and not force_kill: # Already killing return self._killing + if force_kill: + # Skip interrupting the state and go straight to killed + interrupt_exception = process_states.ForceKillInterruption(msg_text) + # XXX: the call commented out below was not in Ali's PR, but to match the _stepping branch + # it seems to be needed so that _interrupt_action is set before it is used just after. + # Needs further checking against the aiida-core PR.
+ # + # self._set_interrupt_action_from_exception(interrupt_exception) + # + self._killing = self._interrupt_action + self._state.interrupt(interrupt_exception) + + msg = MessageBuilder.kill(msg_text, force_kill=True) + new_state = self._create_state_instance(process_states.ProcessState.KILLED, msg=msg) + self.transition_to(new_state) + return True + if self._stepping: # Ask the step function to pause by setting this flag and giving the # caller back a future @@ -1246,9 +1273,10 @@ def kill(self, msg_text: Optional[str] = None) -> Union[bool, asyncio.Future]: self._set_interrupt_action_from_exception(interrupt_exception) self._killing = self._interrupt_action self._state.interrupt(interrupt_exception) + return cast(futures.CancellableAction, self._interrupt_action) - msg = MessageBuilder.kill(msg_text) + msg = MessageBuilder.kill(msg_text, force_kill=False) new_state = self._create_state_instance(process_states.ProcessState.KILLED, msg=msg) self.transition_to(new_state) return True