Skip to content

Commit

Permalink
feat: use_jinja
Browse files Browse the repository at this point in the history
  • Loading branch information
phi-friday committed Sep 27, 2024
1 parent d872e6a commit 5fdf0fd
Show file tree
Hide file tree
Showing 3 changed files with 213 additions and 6 deletions.
108 changes: 104 additions & 4 deletions src/timeout_executor/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,81 @@ def remove_callback(self, callback: ProcessCallback[P, T]) -> Self:
return self


class JinjaExecutor(Executor[P, T], Generic[P, T]):
__slots__ = (*Executor.__slots__, "_j2_script")
_j2_script: Path | None

def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self._j2_script = None

def _cleanup(self, *args: Any, **kwargs: Any) -> None: # noqa: ARG002
if self._j2_script is None:
return
with suppress(FileNotFoundError):
self._j2_script.unlink()
self._j2_script = None

@override
def _dump_args(
self, output_file: Path | anyio.Path, *args: P.args, **kwargs: P.kwargs
) -> bytes:
"""dump args and output file path to input file"""
input_args = (None, args, kwargs, str(output_file))
logger.debug("%r before dump input args", self)
input_args_as_bytes = cloudpickle.dumps(input_args)
logger.debug(
"%r after dump input args :: size: %d", self, len(input_args_as_bytes)
)
return input_args_as_bytes

@override
def _dump_initializer(self) -> bytes | None:
if self._initializer is None:
logger.debug("%r initializer is None", self)
return None
init_args = (None, self._initializer.args, self._initializer.kwargs)
logger.debug("%r before dump initializer", self)
init_args_as_bytes = cloudpickle.dumps(init_args)
logger.debug(
"%r after dump initializer :: size: %d", self, len(init_args_as_bytes)
)
return init_args_as_bytes

@override
def callbacks(self) -> Iterable[Callable[[CallbackArgs[P, T]], Any]]:
callbacks = super().callbacks()
return chain([self._cleanup], callbacks)

@override
def _command(self, stacklevel: int = 2) -> list[str]:
j2_script = self._render_jinja_subprocess()
self._j2_script = Path(tempfile.gettempdir()) / str(uuid4())
with self._j2_script.open("w+") as file:
file.write(j2_script)
command = f"{sys.executable} {self._j2_script}"
logger.debug("%r command: %s", self, command, stacklevel=stacklevel)
return shlex.split(command)

def _render_jinja_subprocess(self) -> str:
import jinja2

if self._initializer is None:
init_func_code = " def empty_initializer(): pass"
init_func_name = "empty_initializer"
else:
init_func_code, init_func_name = parse_func_code(self._initializer.function)
func_code, func_name = parse_func_code(self._func)
with Path(__file__).with_name("subprocess_jinja.py.j2").open("r") as file:
source = file.read()
return jinja2.Template(source).render(
func_code=func_code,
func_name=func_name,
init_func_code=init_func_code,
init_func_name=init_func_name,
)


@overload
def apply_func(
timeout_or_executor: float | TimeoutExecutor,
Expand Down Expand Up @@ -305,10 +380,17 @@ def apply_func(
Returns:
async result container
"""
executor_type = (
JinjaExecutor
if not isinstance(timeout_or_executor, (int, float))
and timeout_or_executor.use_jinja
else Executor
)

if isinstance(timeout_or_executor, (float, int)):
executor = Executor(timeout_or_executor, func)
executor = executor_type(timeout_or_executor, func)
else:
executor = Executor(
executor = executor_type(
timeout_or_executor.timeout,
func,
timeout_or_executor.callbacks,
Expand Down Expand Up @@ -352,10 +434,17 @@ async def delay_func(
Returns:
async result container
"""
executor_type = (
JinjaExecutor
if not isinstance(timeout_or_executor, (int, float))
and timeout_or_executor.use_jinja
else Executor
)

if isinstance(timeout_or_executor, (float, int)):
executor = Executor(timeout_or_executor, func)
executor = executor_type(timeout_or_executor, func)
else:
executor = Executor(
executor = executor_type(
timeout_or_executor.timeout, func, timeout_or_executor.callbacks
)
return await executor.delay(*args, **kwargs)
Expand All @@ -374,3 +463,14 @@ def is_class(obj: Any) -> bool:

meta = type(obj)
return issubclass(meta, type)


def parse_func_code(func: Callable[..., Any]) -> tuple[str, str]:
import inspect
import textwrap

source = inspect.getsource(func)
source = textwrap.dedent(source)
source = textwrap.indent(source, " ")

return source, func.__name__
5 changes: 3 additions & 2 deletions src/timeout_executor/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,13 @@
class TimeoutExecutor(Callback[Any, AnyT], Generic[AnyT]):
"""timeout executor"""

__slots__ = ("_timeout", "_callbacks", "initializer")
__slots__ = ("_timeout", "_callbacks", "initializer", "use_jinja")

def __init__(self, timeout: float) -> None:
def __init__(self, timeout: float, *, use_jinja: bool = False) -> None:
self._timeout = timeout
self._callbacks: deque[ProcessCallback[..., AnyT]] = deque()
self.initializer: InitializerArgs[..., Any] | None = None
self.use_jinja = use_jinja

@property
def timeout(self) -> float:
Expand Down
106 changes: 106 additions & 0 deletions src/timeout_executor/subprocess_jinja.py.j2
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
"""only using in subprocess"""

from __future__ import annotations

from functools import partial
from inspect import isawaitable
from os import environ
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable

import anyio
import cloudpickle
from anyio.lowlevel import checkpoint

from timeout_executor.const import (
TIMEOUT_EXECUTOR_INIT_FILE,
TIMEOUT_EXECUTOR_INPUT_FILE,
)

if TYPE_CHECKING:
from typing_extensions import ParamSpec, TypeVar

P = ParamSpec("P")
T = TypeVar("T", infer_variance=True)

__all__ = []


def run_in_subprocess() -> None:
init_file = environ.get(TIMEOUT_EXECUTOR_INIT_FILE, "")
if init_file:
with Path(init_file).open("rb") as file_io:
_, init_args, init_kwargs = cloudpickle.load(file_io)
init_func(*init_args, **init_kwargs)

input_file = Path(environ.get(TIMEOUT_EXECUTOR_INPUT_FILE, ""))
with input_file.open("rb") as file_io:
_, args, kwargs, output_file = cloudpickle.load(file_io)

new_func = output_to_file(output_file)(func)
new_func(*args, **kwargs)


def dumps_value(value: Any) -> bytes:
if isinstance(value, BaseException):
from timeout_executor.serde import dumps_error

return dumps_error(value)
return cloudpickle.dumps(value)


def output_to_file(file: str) -> Callable[[Callable[P, T]], Callable[P, T]]:
def wrapper(func: Callable[P, T]) -> Callable[P, T]:
func = wrap_function_as_sync(func)

def inner(*args: P.args, **kwargs: P.kwargs) -> T:
dump = b""
try:
result = func(*args, **kwargs)
except BaseException as exc:
dump = dumps_value(exc)
raise
else:
dump = dumps_value(result)
return result
finally:
with open(file, "wb+") as file_io: # noqa: PTH123
file_io.write(dump)

return inner

return wrapper


def wrap_function_as_async(func: Callable[P, Any]) -> Callable[P, Any]:
async def wrapped(*args: P.args, **kwargs: P.kwargs) -> Any:
await checkpoint()
result = func(*args, **kwargs)
if isawaitable(result):
return await result
return result

return wrapped


def wrap_function_as_sync(func: Callable[P, Any]) -> Callable[P, Any]:
async_wrapped = wrap_function_as_async(func)

def wrapped(*args: P.args, **kwargs: P.kwargs) -> Any:
new_func = partial(async_wrapped, *args, **kwargs)
return anyio.run(new_func)

return wrapped

###

def init_func(*args: Any, **kwargs: Any) -> None:
{{ init_func_code }}
{{ init_func_name }}(*args, **kwargs)

def func(*args: Any, **kwargs: Any) -> Any:
{{ func_code }}
return {{ func_name }}(*args, **kwargs)

if __name__ == "__main__":
run_in_subprocess()

0 comments on commit 5fdf0fd

Please sign in to comment.