Skip to content

Commit

Permalink
fix: pyright errors
Browse files Browse the repository at this point in the history
  • Loading branch information
phi-friday committed May 19, 2024
1 parent 940d4a5 commit 90ebc67
Show file tree
Hide file tree
Showing 8 changed files with 87 additions and 76 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,4 @@ pythonVersion = '3.8'
pythonPlatform = 'Linux'
diagnostic = 'strict'
reportPrivateUsage = false
reportMissingTypeStubs = false
38 changes: 19 additions & 19 deletions src/timeout_executor/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,19 +35,19 @@
T2 = TypeVar("T2", infer_variance=True)


class Executor(Callback, Generic[P, T]):
class Executor(Callback[P, T], Generic[P, T]):
def __init__(
self,
timeout: float,
func: Callable[P, T],
callbacks: Callable[[], Iterable[ProcessCallback]] | None = None,
callbacks: Callable[[], Iterable[ProcessCallback[P, T]]] | None = None,
) -> None:
self._timeout = timeout
self._func = func
self._func_name = func_name(func)
self._unique_id = uuid4()
self._init_callbacks = callbacks
self._callbacks: deque[ProcessCallback] = deque()
self._callbacks: deque[ProcessCallback[P, T]] = deque()

@property
def unique_id(self) -> UUID:
Expand Down Expand Up @@ -75,7 +75,7 @@ def _dump_args(
) -> bytes:
input_args = (self._func, args, kwargs, output_file)
logger.debug("%r before dump input args", self)
input_args_as_bytes = cloudpickle.dumps(input_args)
input_args_as_bytes = cloudpickle.dumps(input_args) # pyright: ignore[reportUnknownMemberType]
logger.debug(
"%r after dump input args :: size: %d", self, len(input_args_as_bytes)
)
Expand All @@ -100,8 +100,8 @@ def _create_executor_args(
self,
input_file: Path | anyio.Path,
output_file: Path | anyio.Path,
terminator: Terminator,
) -> ExecutorArgs:
terminator: Terminator[P, T],
) -> ExecutorArgs[P, T]:
return ExecutorArgs(
executor=self,
func_name=self._func_name,
Expand All @@ -116,20 +116,20 @@ def _init_process(
input_file: Path | anyio.Path,
output_file: Path | anyio.Path,
stacklevel: int = 2,
) -> AsyncResult[T]:
) -> AsyncResult[P, T]:
logger.debug("%r before init process", self, stacklevel=stacklevel)
executor_args_builder = partial(
self._create_executor_args, input_file, output_file
)
terminator = Terminator(executor_args_builder, self.callbacks)
process = self._create_process(input_file, stacklevel=stacklevel + 1)
result = AsyncResult(process, terminator.executor_args)
result: AsyncResult[P, T] = AsyncResult(process, terminator.executor_args)
terminator.callback_args = CallbackArgs(process=process, result=result)
terminator.start()
logger.debug("%r after init process", self, stacklevel=stacklevel)
return result

def apply(self, *args: P.args, **kwargs: P.kwargs) -> AsyncResult[T]:
def apply(self, *args: P.args, **kwargs: P.kwargs) -> AsyncResult[P, T]:
input_file, output_file = self._create_temp_files()
input_args_as_bytes = self._dump_args(output_file, *args, **kwargs)

Expand All @@ -140,7 +140,7 @@ def apply(self, *args: P.args, **kwargs: P.kwargs) -> AsyncResult[T]:

return self._init_process(input_file, output_file)

async def delay(self, *args: P.args, **kwargs: P.kwargs) -> AsyncResult[T]:
async def delay(self, *args: P.args, **kwargs: P.kwargs) -> AsyncResult[P, T]:
input_file, output_file = self._create_temp_files()
input_file, output_file = anyio.Path(input_file), anyio.Path(output_file)
input_args_as_bytes = self._dump_args(output_file, *args, **kwargs)
Expand All @@ -156,18 +156,18 @@ def __repr__(self) -> str:
return f"<{type(self).__name__}: {self._func_name}>"

@override
def callbacks(self) -> Iterable[ProcessCallback]:
def callbacks(self) -> Iterable[ProcessCallback[P, T]]:
if self._init_callbacks is None:
return self._callbacks.copy()
return chain(self._init_callbacks(), self._callbacks.copy())

@override
def add_callback(self, callback: ProcessCallback) -> Self:
def add_callback(self, callback: ProcessCallback[P, T]) -> Self:
self._callbacks.append(callback)
return self

@override
def remove_callback(self, callback: ProcessCallback) -> Self:
def remove_callback(self, callback: ProcessCallback[P, T]) -> Self:
with suppress(ValueError):
self._callbacks.remove(callback)
return self
Expand All @@ -179,7 +179,7 @@ def apply_func(
func: Callable[P2, Coroutine[Any, Any, T2]],
*args: P2.args,
**kwargs: P2.kwargs,
) -> AsyncResult[T2]: ...
) -> AsyncResult[P2, T2]: ...


@overload
Expand All @@ -188,15 +188,15 @@ def apply_func(
func: Callable[P2, T2],
*args: P2.args,
**kwargs: P2.kwargs,
) -> AsyncResult[T2]: ...
) -> AsyncResult[P2, T2]: ...


def apply_func(
timeout_or_executor: float | TimeoutExecutor,
func: Callable[P2, Any],
*args: P2.args,
**kwargs: P2.kwargs,
) -> AsyncResult[Any]:
) -> AsyncResult[P2, Any]:
"""run function with deadline
Args:
Expand All @@ -221,7 +221,7 @@ async def delay_func(
func: Callable[P2, Coroutine[Any, Any, T2]],
*args: P2.args,
**kwargs: P2.kwargs,
) -> AsyncResult[T2]: ...
) -> AsyncResult[P2, T2]: ...


@overload
Expand All @@ -230,15 +230,15 @@ async def delay_func(
func: Callable[P2, T2],
*args: P2.args,
**kwargs: P2.kwargs,
) -> AsyncResult[T2]: ...
) -> AsyncResult[P2, T2]: ...


async def delay_func(
timeout_or_executor: float | TimeoutExecutor,
func: Callable[P2, Any],
*args: P2.args,
**kwargs: P2.kwargs,
) -> AsyncResult[Any]:
) -> AsyncResult[P2, Any]:
"""run function with deadline
Args:
Expand Down
31 changes: 16 additions & 15 deletions src/timeout_executor/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from collections import deque
from contextlib import suppress
from typing import TYPE_CHECKING, Any, Callable, Coroutine, Iterable, overload
from typing import TYPE_CHECKING, Any, Callable, Coroutine, Generic, Iterable, overload

from typing_extensions import ParamSpec, Self, TypeVar, override

Expand All @@ -16,14 +16,15 @@

P = ParamSpec("P")
T = TypeVar("T", infer_variance=True)
AnyT = TypeVar("AnyT", infer_variance=True, default=Any)


class TimeoutExecutor(Callback):
class TimeoutExecutor(Callback[..., AnyT], Generic[AnyT]):
"""timeout executor"""

def __init__(self, timeout: float) -> None:
self._timeout = timeout
self._callbacks: deque[ProcessCallback] = deque()
self._callbacks: deque[ProcessCallback[..., AnyT]] = deque()

@property
def timeout(self) -> float:
Expand All @@ -36,14 +37,14 @@ def apply(
func: Callable[P, Coroutine[Any, Any, T]],
*args: P.args,
**kwargs: P.kwargs,
) -> AsyncResult[T]: ...
) -> AsyncResult[P, T]: ...
@overload
def apply(
self, func: Callable[P, T], *args: P.args, **kwargs: P.kwargs
) -> AsyncResult[T]: ...
) -> AsyncResult[P, T]: ...
def apply(
self, func: Callable[P, Any], *args: P.args, **kwargs: P.kwargs
) -> AsyncResult[Any]:
) -> AsyncResult[P, Any]:
"""run function with deadline
Args:
Expand All @@ -60,14 +61,14 @@ async def delay(
func: Callable[P, Coroutine[Any, Any, T]],
*args: P.args,
**kwargs: P.kwargs,
) -> AsyncResult[T]: ...
) -> AsyncResult[P, T]: ...
@overload
async def delay(
self, func: Callable[P, T], *args: P.args, **kwargs: P.kwargs
) -> AsyncResult[T]: ...
) -> AsyncResult[P, T]: ...
async def delay(
self, func: Callable[P, Any], *args: P.args, **kwargs: P.kwargs
) -> AsyncResult[Any]:
) -> AsyncResult[P, Any]:
"""run function with deadline
Args:
Expand All @@ -84,14 +85,14 @@ async def apply_async(
func: Callable[P, Coroutine[Any, Any, T]],
*args: P.args,
**kwargs: P.kwargs,
) -> AsyncResult[T]: ...
) -> AsyncResult[P, T]: ...
@overload
async def apply_async(
self, func: Callable[P, T], *args: P.args, **kwargs: P.kwargs
) -> AsyncResult[T]: ...
) -> AsyncResult[P, T]: ...
async def apply_async(
self, func: Callable[P, Any], *args: P.args, **kwargs: P.kwargs
) -> AsyncResult[Any]:
) -> AsyncResult[P, Any]:
"""run function with deadline.
alias of `delay`
Expand All @@ -108,16 +109,16 @@ def __repr__(self) -> str:
return f"<{type(self).__name__}, timeout: {self.timeout:.2f}s>"

@override
def callbacks(self) -> Iterable[ProcessCallback]:
def callbacks(self) -> Iterable[ProcessCallback[..., AnyT]]:
return self._callbacks.copy()

@override
def add_callback(self, callback: ProcessCallback) -> Self:
def add_callback(self, callback: ProcessCallback[..., AnyT]) -> Self:
self._callbacks.append(callback)
return self

@override
def remove_callback(self, callback: ProcessCallback) -> Self:
def remove_callback(self, callback: ProcessCallback[..., AnyT]) -> Self:
with suppress(ValueError):
self._callbacks.remove(callback)
return self
15 changes: 8 additions & 7 deletions src/timeout_executor/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import anyio
import cloudpickle
from async_wrapper import async_to_sync, sync_to_async
from typing_extensions import Self, TypeVar, override
from typing_extensions import ParamSpec, Self, TypeVar, override

from timeout_executor.logging import logger
from timeout_executor.serde import SerializedError, loads_error
Expand All @@ -22,18 +22,19 @@

__all__ = ["AsyncResult"]

P = ParamSpec("P")
T = TypeVar("T", infer_variance=True)

SENTINEL = object()


class AsyncResult(Callback, Generic[T]):
class AsyncResult(Callback[P, T], Generic[P, T]):
"""async result container"""

_result: Any

def __init__(
self, process: subprocess.Popen[str], executor_args: ExecutorArgs
self, process: subprocess.Popen[str], executor_args: ExecutorArgs[P, T]
) -> None:
self._process = process

Expand All @@ -58,7 +59,7 @@ def _func_name(self) -> str:
return self._executor_args.func_name

@property
def _terminator(self) -> Terminator:
def _terminator(self) -> Terminator[P, T]:
return self._executor_args.terminator

def result(self, timeout: float | None = None) -> T:
Expand Down Expand Up @@ -128,17 +129,17 @@ def __repr__(self) -> str:
return f"<{type(self).__name__}: {self._func_name}>"

@override
def add_callback(self, callback: ProcessCallback) -> Self:
def add_callback(self, callback: ProcessCallback[P, T]) -> Self:
self._terminator.add_callback(callback)
return self

@override
def remove_callback(self, callback: ProcessCallback) -> Self:
def remove_callback(self, callback: ProcessCallback[P, T]) -> Self:
self._terminator.remove_callback(callback)
return self

@override
def callbacks(self) -> Iterable[ProcessCallback]:
def callbacks(self) -> Iterable[ProcessCallback[P, T]]:
return self._terminator.callbacks()


Expand Down
14 changes: 7 additions & 7 deletions src/timeout_executor/serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@

import cloudpickle
from tblib.pickling_support import (
pickle_exception,
pickle_traceback,
unpickle_exception,
unpickle_traceback,
pickle_exception, # pyright: ignore[reportUnknownVariableType]
pickle_traceback, # pyright: ignore[reportUnknownVariableType]
unpickle_exception, # pyright: ignore[reportUnknownVariableType]
unpickle_traceback, # pyright: ignore[reportUnknownVariableType]
)

__all__ = ["dumps_error", "loads_error", "serialize_error", "deserialize_error"]
Expand All @@ -34,7 +34,7 @@ class SerializedError:


def serialize_traceback(traceback: TracebackType) -> tuple[Any, ...]:
return pickle_traceback(traceback)
return pickle_traceback(traceback) # pyright: ignore[reportUnknownVariableType]


def serialize_error(error: BaseException) -> SerializedError:
Expand Down Expand Up @@ -84,15 +84,15 @@ def deserialize_error(error: SerializedError) -> BaseException:
traceback = unpickle_traceback(*value)
result.insert(index + salt, traceback)

return unpickle_exception(*arg_exception, *exception)
return unpickle_exception(*arg_exception, *exception) # pyright: ignore[reportUnknownVariableType]


def dumps_error(error: BaseException | SerializedError) -> bytes:
"""serialize exception as bytes"""
if not isinstance(error, SerializedError):
error = serialize_error(error)

return cloudpickle.dumps(error)
return cloudpickle.dumps(error) # pyright: ignore[reportUnknownMemberType]


def loads_error(error: bytes | SerializedError) -> BaseException:
Expand Down
2 changes: 1 addition & 1 deletion src/timeout_executor/subprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def run_in_subprocess() -> None:
def dumps_value(value: Any) -> bytes:
if isinstance(value, BaseException):
return dumps_error(value)
return cloudpickle.dumps(value)
return cloudpickle.dumps(value) # pyright: ignore[reportUnknownMemberType]


def output_to_file(
Expand Down
Loading

0 comments on commit 90ebc67

Please sign in to comment.