diff --git a/pyproject.toml b/pyproject.toml index 10fa7c1..7fb8924 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,3 +64,4 @@ pythonVersion = '3.8' pythonPlatform = 'Linux' diagnostic = 'strict' reportPrivateUsage = false +reportMissingTypeStubs = false diff --git a/src/timeout_executor/executor.py b/src/timeout_executor/executor.py index 06a003f..b28ec2d 100644 --- a/src/timeout_executor/executor.py +++ b/src/timeout_executor/executor.py @@ -35,19 +35,19 @@ T2 = TypeVar("T2", infer_variance=True) -class Executor(Callback, Generic[P, T]): +class Executor(Callback[P, T], Generic[P, T]): def __init__( self, timeout: float, func: Callable[P, T], - callbacks: Callable[[], Iterable[ProcessCallback]] | None = None, + callbacks: Callable[[], Iterable[ProcessCallback[P, T]]] | None = None, ) -> None: self._timeout = timeout self._func = func self._func_name = func_name(func) self._unique_id = uuid4() self._init_callbacks = callbacks - self._callbacks: deque[ProcessCallback] = deque() + self._callbacks: deque[ProcessCallback[P, T]] = deque() @property def unique_id(self) -> UUID: @@ -75,7 +75,7 @@ def _dump_args( ) -> bytes: input_args = (self._func, args, kwargs, output_file) logger.debug("%r before dump input args", self) - input_args_as_bytes = cloudpickle.dumps(input_args) + input_args_as_bytes = cloudpickle.dumps(input_args) # pyright: ignore[reportUnknownMemberType] logger.debug( "%r after dump input args :: size: %d", self, len(input_args_as_bytes) ) @@ -100,8 +100,8 @@ def _create_executor_args( self, input_file: Path | anyio.Path, output_file: Path | anyio.Path, - terminator: Terminator, - ) -> ExecutorArgs: + terminator: Terminator[P, T], + ) -> ExecutorArgs[P, T]: return ExecutorArgs( executor=self, func_name=self._func_name, @@ -116,20 +116,20 @@ def _init_process( input_file: Path | anyio.Path, output_file: Path | anyio.Path, stacklevel: int = 2, - ) -> AsyncResult[T]: + ) -> AsyncResult[P, T]: logger.debug("%r before init process", self, stacklevel=stacklevel) executor_args_builder = partial( self._create_executor_args, input_file, output_file ) terminator = Terminator(executor_args_builder, self.callbacks) process = self._create_process(input_file, stacklevel=stacklevel + 1) - result = AsyncResult(process, terminator.executor_args) + result: AsyncResult[P, T] = AsyncResult(process, terminator.executor_args) terminator.callback_args = CallbackArgs(process=process, result=result) terminator.start() logger.debug("%r after init process", self, stacklevel=stacklevel) return result - def apply(self, *args: P.args, **kwargs: P.kwargs) -> AsyncResult[T]: + def apply(self, *args: P.args, **kwargs: P.kwargs) -> AsyncResult[P, T]: input_file, output_file = self._create_temp_files() input_args_as_bytes = self._dump_args(output_file, *args, **kwargs) @@ -140,7 +140,7 @@ def apply(self, *args: P.args, **kwargs: P.kwargs) -> AsyncResult[T]: return self._init_process(input_file, output_file) - async def delay(self, *args: P.args, **kwargs: P.kwargs) -> AsyncResult[T]: + async def delay(self, *args: P.args, **kwargs: P.kwargs) -> AsyncResult[P, T]: input_file, output_file = self._create_temp_files() input_file, output_file = anyio.Path(input_file), anyio.Path(output_file) input_args_as_bytes = self._dump_args(output_file, *args, **kwargs) @@ -156,18 +156,18 @@ def __repr__(self) -> str: return f"<{type(self).__name__}: {self._func_name}>" @override - def callbacks(self) -> Iterable[ProcessCallback]: + def callbacks(self) -> Iterable[ProcessCallback[P, T]]: if self._init_callbacks is None: return self._callbacks.copy() return chain(self._init_callbacks(), self._callbacks.copy()) @override - def add_callback(self, callback: ProcessCallback) -> Self: + def add_callback(self, callback: ProcessCallback[P, T]) -> Self: self._callbacks.append(callback) return self @override - def remove_callback(self, callback: ProcessCallback) -> Self: + def remove_callback(self, callback: ProcessCallback[P, T]) -> Self: with suppress(ValueError): self._callbacks.remove(callback) return self @@ -179,7 +179,7 @@ def apply_func( func: Callable[P2, Coroutine[Any, Any, T2]], *args: P2.args, **kwargs: P2.kwargs, -) -> AsyncResult[T2]: ... +) -> AsyncResult[P2, T2]: ... @overload @@ -188,7 +188,7 @@ def apply_func( func: Callable[P2, T2], *args: P2.args, **kwargs: P2.kwargs, -) -> AsyncResult[T2]: ... +) -> AsyncResult[P2, T2]: ... def apply_func( @@ -196,7 +196,7 @@ def apply_func( func: Callable[P2, Any], *args: P2.args, **kwargs: P2.kwargs, -) -> AsyncResult[Any]: +) -> AsyncResult[P2, Any]: """run function with deadline Args: @@ -221,7 +221,7 @@ async def delay_func( func: Callable[P2, Coroutine[Any, Any, T2]], *args: P2.args, **kwargs: P2.kwargs, -) -> AsyncResult[T2]: ... +) -> AsyncResult[P2, T2]: ... @overload @@ -230,7 +230,7 @@ async def delay_func( func: Callable[P2, T2], *args: P2.args, **kwargs: P2.kwargs, -) -> AsyncResult[T2]: ... +) -> AsyncResult[P2, T2]: ... async def delay_func( @@ -238,7 +238,7 @@ async def delay_func( func: Callable[P2, Any], *args: P2.args, **kwargs: P2.kwargs, -) -> AsyncResult[Any]: +) -> AsyncResult[P2, Any]: """run function with deadline Args: diff --git a/src/timeout_executor/main.py b/src/timeout_executor/main.py index 3ab22ff..2292196 100644 --- a/src/timeout_executor/main.py +++ b/src/timeout_executor/main.py @@ -2,7 +2,7 @@ from collections import deque from contextlib import suppress -from typing import TYPE_CHECKING, Any, Callable, Coroutine, Iterable, overload +from typing import TYPE_CHECKING, Any, Callable, Coroutine, Generic, Iterable, overload from typing_extensions import ParamSpec, Self, TypeVar, override @@ -16,14 +16,15 @@ P = ParamSpec("P") T = TypeVar("T", infer_variance=True) +AnyT = TypeVar("AnyT", infer_variance=True, default=Any) -class TimeoutExecutor(Callback): +class TimeoutExecutor(Callback[..., AnyT], Generic[AnyT]): """timeout executor""" def __init__(self, timeout: float) -> None: self._timeout = timeout - self._callbacks: deque[ProcessCallback] = deque() + self._callbacks: deque[ProcessCallback[..., AnyT]] = deque() @property def timeout(self) -> float: @@ -36,14 +37,14 @@ def apply( func: Callable[P, Coroutine[Any, Any, T]], *args: P.args, **kwargs: P.kwargs, - ) -> AsyncResult[T]: ... + ) -> AsyncResult[P, T]: ... @overload def apply( self, func: Callable[P, T], *args: P.args, **kwargs: P.kwargs - ) -> AsyncResult[T]: ... + ) -> AsyncResult[P, T]: ... def apply( self, func: Callable[P, Any], *args: P.args, **kwargs: P.kwargs - ) -> AsyncResult[Any]: + ) -> AsyncResult[P, Any]: """run function with deadline Args: @@ -60,14 +61,14 @@ async def delay( func: Callable[P, Coroutine[Any, Any, T]], *args: P.args, **kwargs: P.kwargs, - ) -> AsyncResult[T]: ... + ) -> AsyncResult[P, T]: ... @overload async def delay( self, func: Callable[P, T], *args: P.args, **kwargs: P.kwargs - ) -> AsyncResult[T]: ... + ) -> AsyncResult[P, T]: ... async def delay( self, func: Callable[P, Any], *args: P.args, **kwargs: P.kwargs - ) -> AsyncResult[Any]: + ) -> AsyncResult[P, Any]: """run function with deadline Args: @@ -84,14 +85,14 @@ async def apply_async( func: Callable[P, Coroutine[Any, Any, T]], *args: P.args, **kwargs: P.kwargs, - ) -> AsyncResult[T]: ... + ) -> AsyncResult[P, T]: ... @overload async def apply_async( self, func: Callable[P, T], *args: P.args, **kwargs: P.kwargs - ) -> AsyncResult[T]: ... + ) -> AsyncResult[P, T]: ... async def apply_async( self, func: Callable[P, Any], *args: P.args, **kwargs: P.kwargs - ) -> AsyncResult[Any]: + ) -> AsyncResult[P, Any]: """run function with deadline. alias of `delay` @@ -108,16 +109,16 @@ def __repr__(self) -> str: return f"<{type(self).__name__}, timeout: {self.timeout:.2f}s>" @override - def callbacks(self) -> Iterable[ProcessCallback]: + def callbacks(self) -> Iterable[ProcessCallback[..., AnyT]]: return self._callbacks.copy() @override - def add_callback(self, callback: ProcessCallback) -> Self: + def add_callback(self, callback: ProcessCallback[..., AnyT]) -> Self: self._callbacks.append(callback) return self @override - def remove_callback(self, callback: ProcessCallback) -> Self: + def remove_callback(self, callback: ProcessCallback[..., AnyT]) -> Self: with suppress(ValueError): self._callbacks.remove(callback) return self diff --git a/src/timeout_executor/result.py b/src/timeout_executor/result.py index 2e86e2c..7ff2236 100644 --- a/src/timeout_executor/result.py +++ b/src/timeout_executor/result.py @@ -7,7 +7,7 @@ import anyio import cloudpickle from async_wrapper import async_to_sync, sync_to_async -from typing_extensions import Self, TypeVar, override +from typing_extensions import ParamSpec, Self, TypeVar, override from timeout_executor.logging import logger from timeout_executor.serde import SerializedError, loads_error @@ -22,18 +22,19 @@ __all__ = ["AsyncResult"] +P = ParamSpec("P") T = TypeVar("T", infer_variance=True) SENTINEL = object() -class AsyncResult(Callback, Generic[T]): +class AsyncResult(Callback[P, T], Generic[P, T]): """async result container""" _result: Any def __init__( - self, process: subprocess.Popen[str], executor_args: ExecutorArgs + self, process: subprocess.Popen[str], executor_args: ExecutorArgs[P, T] ) -> None: self._process = process @@ -58,7 +59,7 @@ def _func_name(self) -> str: return self._executor_args.func_name @property - def _terminator(self) -> Terminator: + def _terminator(self) -> Terminator[P, T]: return self._executor_args.terminator def result(self, timeout: float | None = None) -> T: @@ -128,17 +129,17 @@ def __repr__(self) -> str: return f"<{type(self).__name__}: {self._func_name}>" @override - def add_callback(self, callback: ProcessCallback) -> Self: + def add_callback(self, callback: ProcessCallback[P, T]) -> Self: self._terminator.add_callback(callback) return self @override - def remove_callback(self, callback: ProcessCallback) -> Self: + def remove_callback(self, callback: ProcessCallback[P, T]) -> Self: self._terminator.remove_callback(callback) return self @override - def callbacks(self) -> Iterable[ProcessCallback]: + def callbacks(self) -> Iterable[ProcessCallback[P, T]]: return self._terminator.callbacks() diff --git a/src/timeout_executor/serde.py b/src/timeout_executor/serde.py index 65b7e3d..045923f 100644 --- a/src/timeout_executor/serde.py +++ b/src/timeout_executor/serde.py @@ -9,10 +9,10 @@ import cloudpickle from tblib.pickling_support import ( - pickle_exception, - pickle_traceback, - unpickle_exception, - unpickle_traceback, + pickle_exception, # pyright: ignore[reportUnknownVariableType] + pickle_traceback, # pyright: ignore[reportUnknownVariableType] + unpickle_exception, # pyright: ignore[reportUnknownVariableType] + unpickle_traceback, # pyright: ignore[reportUnknownVariableType] ) __all__ = ["dumps_error", "loads_error", "serialize_error", "deserialize_error"] @@ -34,7 +34,7 @@ class SerializedError: def serialize_traceback(traceback: TracebackType) -> tuple[Any, ...]: - return pickle_traceback(traceback) + return pickle_traceback(traceback) # pyright: ignore[reportUnknownVariableType] def serialize_error(error: BaseException) -> SerializedError: @@ -84,7 +84,7 @@ def deserialize_error(error: SerializedError) -> BaseException: traceback = unpickle_traceback(*value) result.insert(index + salt, traceback) - return unpickle_exception(*arg_exception, *exception) + return unpickle_exception(*arg_exception, *exception) # pyright: ignore[reportUnknownVariableType] def dumps_error(error: BaseException | SerializedError) -> bytes: @@ -92,7 +92,7 @@ def dumps_error(error: BaseException | SerializedError) -> bytes: if not isinstance(error, SerializedError): error = serialize_error(error) - return cloudpickle.dumps(error) + return cloudpickle.dumps(error) # pyright: ignore[reportUnknownMemberType] def loads_error(error: bytes | SerializedError) -> BaseException: diff --git a/src/timeout_executor/subprocess.py b/src/timeout_executor/subprocess.py index e3d0f8e..53e559c 100644 --- a/src/timeout_executor/subprocess.py +++ b/src/timeout_executor/subprocess.py @@ -37,7 +37,7 @@ def run_in_subprocess() -> None: def dumps_value(value: Any) -> bytes: if isinstance(value, BaseException): return dumps_error(value) - return cloudpickle.dumps(value) + return cloudpickle.dumps(value) # pyright: ignore[reportUnknownMemberType] def output_to_file( diff --git a/src/timeout_executor/terminate.py b/src/timeout_executor/terminate.py index 1570178..fc722ed 100644 --- a/src/timeout_executor/terminate.py +++ b/src/timeout_executor/terminate.py @@ -6,49 +6,52 @@ from collections import deque from contextlib import suppress from itertools import chain -from typing import Callable, Iterable +from typing import Any, Callable, Generic, Iterable from psutil import pid_exists -from typing_extensions import Self, override +from typing_extensions import ParamSpec, Self, TypeVar, override from timeout_executor.logging import logger from timeout_executor.types import Callback, CallbackArgs, ExecutorArgs, ProcessCallback __all__ = [] +P = ParamSpec("P") +T = TypeVar("T", infer_variance=True) -class Terminator(Callback): + +class Terminator(Callback[P, T], Generic[P, T]): _process: subprocess.Popen[str] | None _callback_thread: threading.Thread | None _terminator_thread: threading.Thread | None def __init__( self, - executor_args_factory: Callable[[Terminator], ExecutorArgs], - callbacks: Callable[[], Iterable[ProcessCallback]] | None = None, + executor_args_factory: Callable[[Terminator[P, T]], ExecutorArgs[P, T]], + callbacks: Callable[[], Iterable[ProcessCallback[P, T]]] | None = None, ) -> None: self._is_active = False self._executor_args = executor_args_factory(self) self._init_callbacks = callbacks - self._callbacks: deque[ProcessCallback] = deque() + self._callbacks: deque[ProcessCallback[P, T]] = deque() self._callback_thread = None self._terminator_thread = None - self._callback_args = None + self._callback_args: CallbackArgs[P, T] | None = None @property - def executor_args(self) -> ExecutorArgs: + def executor_args(self) -> ExecutorArgs[P, T]: return self._executor_args @property - def callback_args(self) -> CallbackArgs: + def callback_args(self) -> CallbackArgs[P, T]: if self._callback_args is None: raise AttributeError("there is no callback args") return self._callback_args @callback_args.setter - def callback_args(self, value: CallbackArgs) -> None: + def callback_args(self, value: CallbackArgs[P, T]) -> None: if self._callback_args is not None: raise AttributeError("already has callback args") self._callback_args = value @@ -147,13 +150,13 @@ def func_name(self) -> str: return self._executor_args.func_name @override - def callbacks(self) -> Iterable[ProcessCallback]: + def callbacks(self) -> Iterable[ProcessCallback[P, T]]: if self._init_callbacks is None: return self._callbacks.copy() return chain(self._init_callbacks(), self._callbacks.copy()) @override - def add_callback(self, callback: ProcessCallback) -> Self: + def add_callback(self, callback: ProcessCallback[P, T]) -> Self: if ( self.is_active or self.callback_args.process.returncode is not None @@ -165,13 +168,13 @@ def add_callback(self, callback: ProcessCallback) -> Self: return self @override - def remove_callback(self, callback: ProcessCallback) -> Self: + def remove_callback(self, callback: ProcessCallback[P, T]) -> Self: with suppress(ValueError): self._callbacks.remove(callback) return self -def terminate(terminator: Terminator) -> None: +def terminate(terminator: Terminator[Any, Any]) -> None: try: with suppress(TimeoutError, subprocess.TimeoutExpired): terminator.callback_args.process.wait(terminator.timeout) @@ -179,7 +182,7 @@ def terminate(terminator: Terminator) -> None: terminator.close("terminator thread") -def callback(terminator: Terminator) -> None: +def callback(terminator: Terminator[Any, Any]) -> None: try: terminator.callback_args.process.wait() finally: diff --git a/src/timeout_executor/types.py b/src/timeout_executor/types.py index 5006586..d93e8fa 100644 --- a/src/timeout_executor/types.py +++ b/src/timeout_executor/types.py @@ -4,7 +4,9 @@ from abc import ABC, abstractmethod from collections import deque from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Callable, Iterable +from typing import TYPE_CHECKING, Any, Callable, Generic, Iterable + +from typing_extensions import ParamSpec, TypeVar from timeout_executor.logging import logger @@ -31,14 +33,17 @@ _DATACLASS_FROZEN_KWARGS.update({"kw_only": True, "slots": True}) _DATACLASS_NON_FROZEN_KWARGS.update({"kw_only": True, "slots": True}) +P = ParamSpec("P") +T = TypeVar("T", infer_variance=True) + @dataclass(**_DATACLASS_FROZEN_KWARGS) -class ExecutorArgs: +class ExecutorArgs(Generic[P, T]): """executor args""" - executor: Executor + executor: Executor[P, T] func_name: str - terminator: Terminator + terminator: Terminator[P, T] input_file: Path | anyio.Path output_file: Path | anyio.Path timeout: float @@ -50,30 +55,30 @@ class State: @dataclass(**_DATACLASS_FROZEN_KWARGS) -class CallbackArgs: +class CallbackArgs(Generic[P, T]): """callback args""" process: subprocess.Popen[str] - result: AsyncResult + result: AsyncResult[P, T] state: State = field(init=False, default_factory=State) -class Callback(ABC): +class Callback(ABC, Generic[P, T]): """callback api interface""" @abstractmethod - def callbacks(self) -> Iterable[ProcessCallback]: + def callbacks(self) -> Iterable[ProcessCallback[P, T]]: """return callbacks""" @abstractmethod - def add_callback(self, callback: ProcessCallback) -> Self: + def add_callback(self, callback: ProcessCallback[P, T]) -> Self: """add callback""" @abstractmethod - def remove_callback(self, callback: ProcessCallback) -> Self: + def remove_callback(self, callback: ProcessCallback[P, T]) -> Self: """remove callback if exists""" - def run_callbacks(self, callback_args: CallbackArgs, func_name: str) -> None: + def run_callbacks(self, callback_args: CallbackArgs[P, T], func_name: str) -> None: """run all callbacks""" logger.debug("%r start callbacks", self) errors: deque[Exception] = deque() @@ -90,4 +95,4 @@ def run_callbacks(self, callback_args: CallbackArgs, func_name: str) -> None: raise ExceptionGroup(error_msg, errors) -ProcessCallback: TypeAlias = "Callable[[CallbackArgs], Any]" +ProcessCallback: TypeAlias = "Callable[[CallbackArgs[P, T]], Any]"