Skip to content

Commit

Permalink
support generator send method
Browse files Browse the repository at this point in the history
  • Loading branch information
Nanguage committed Sep 28, 2024
1 parent 02db959 commit 669d71f
Show file tree
Hide file tree
Showing 8 changed files with 168 additions and 21 deletions.
27 changes: 27 additions & 0 deletions docs/getting-started.md
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,33 @@ asyncio.run(main())
Do not use `engine.wait()` to wait the generator job done,
because the generator job's future is done only when the generator is exhausted.

Generator support the `send` method, you can use this feature to pass data to the generator, it allow you communicate with another thread/process:

```python
import asyncio
from executor.engine import Engine, ProcessJob

def calculator():
res = None
while True:
expr = yield res
res = eval(expr)


async def main():
with Engine() as engine:
job = ProcessJob(calculator)
await engine.submit_async(job)
await job.wait_until_status("running")
g = job.result()
g.send(None) # initialize the generator
print(g.send("1 + 2")) # 3
print(g.send("3 * 4")) # 12
print(g.send("(1 + 2) * 4")) # 12

asyncio.run(main())
```

## Engine

`executor.engine` provides a `Engine` class for managing jobs.
Expand Down
2 changes: 1 addition & 1 deletion executor/engine/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from .core import Engine, EngineSetting
from .job import LocalJob, ThreadJob, ProcessJob

__version__ = '0.2.6'
__version__ = '0.2.7'

__all__ = [
'Engine', 'EngineSetting',
Expand Down
4 changes: 2 additions & 2 deletions executor/engine/job/dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from dask.distributed import Client, LocalCluster

from .base import Job
from .utils import GeneratorWrapper
from .utils import create_generator_wrapper
from ..utils import PortManager


Expand Down Expand Up @@ -69,7 +69,7 @@ async def run_generator(self):
func = functools.partial(self.func, *self.args, **self.kwargs)
fut = client.submit(func)
self._executor = client.get_executor(pure=False)
result = GeneratorWrapper(self, fut)
result = create_generator_wrapper(self, fut)
return result

async def cancel(self):
Expand Down
4 changes: 2 additions & 2 deletions executor/engine/job/local.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .base import Job
from .utils import GeneratorWrapper
from .utils import create_generator_wrapper


class LocalJob(Job):
Expand All @@ -10,4 +10,4 @@ async def run_function(self):

async def run_generator(self):
"""Run job as a generator."""
return GeneratorWrapper(self)
return create_generator_wrapper(self)
4 changes: 2 additions & 2 deletions executor/engine/job/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from loky.process_executor import ProcessPoolExecutor

from .base import Job
from .utils import _gen_initializer, GeneratorWrapper
from .utils import _gen_initializer, create_generator_wrapper


class ProcessJob(Job):
Expand Down Expand Up @@ -56,7 +56,7 @@ async def run_generator(self):
func = functools.partial(self.func, *self.args, **self.kwargs)
self._executor = ProcessPoolExecutor(
1, initializer=_gen_initializer, initargs=(func,))
result = GeneratorWrapper(self)
result = create_generator_wrapper(self)
return result

async def cancel(self):
Expand Down
4 changes: 2 additions & 2 deletions executor/engine/job/thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from concurrent.futures import ThreadPoolExecutor

from .base import Job
from .utils import _gen_initializer, GeneratorWrapper
from .utils import _gen_initializer, create_generator_wrapper


class ThreadJob(Job):
Expand Down Expand Up @@ -55,7 +55,7 @@ async def run_generator(self):
func = functools.partial(self.func, *self.args, **self.kwargs)
self._executor = ThreadPoolExecutor(
1, initializer=_gen_initializer, initargs=(func,))
result = GeneratorWrapper(self)
result = create_generator_wrapper(self)
return result

async def cancel(self):
Expand Down
71 changes: 59 additions & 12 deletions executor/engine/job/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import typing as T
import asyncio
import inspect
from datetime import datetime
from concurrent.futures import Future
import threading
Expand Down Expand Up @@ -49,20 +50,28 @@ def _gen_initializer(gen_func, args=tuple(), kwargs={}): # pragma: no cover
_thread_locals._generator = gen_func(*args, **kwargs)


def _gen_next(fut=None): # pragma: no cover
def _gen_next(send_value=None, fut=None): # pragma: no cover
global _thread_locals
if fut is None:
return next(_thread_locals._generator)
g = _thread_locals._generator
else:
return next(fut)
g = fut
if send_value is None:
return next(g)
else:
return g.send(send_value)


def _gen_anext(fut=None): # pragma: no cover
def _gen_anext(send_value=None, fut=None): # pragma: no cover
global _thread_locals
if fut is None:
return asyncio.run(_thread_locals._generator.__anext__())
g = _thread_locals._generator
else:
g = fut
if send_value is None:
return asyncio.run(g.__anext__())
else:
return asyncio.run(fut.__anext__())
return asyncio.run(g.asend(send_value))


class GeneratorWrapper():
Expand All @@ -75,19 +84,28 @@ def __init__(self, job: "Job", fut: T.Optional[Future] = None):
self._fut = fut
self._local_res = None


class SyncGeneratorWrapper(GeneratorWrapper):
"""
wrap a generator in executor pool
"""
def __iter__(self):
return self

def __next__(self):
def _next(self, send_value=None):
try:
if self._job._executor is not None:
return self._job._executor.submit(
_gen_next, self._fut).result()
_gen_next, send_value, self._fut).result()
else:
# create local generator
if self._local_res is None:
self._local_res = self._job.func(
*self._job.args, **self._job.kwargs)
return next(self._local_res)
if send_value is not None:
return self._local_res.send(send_value)
else:
return next(self._local_res)
except Exception as e:
engine = self._job.engine
if engine is None:
Expand All @@ -102,23 +120,52 @@ def __next__(self):
fut.result()
raise e

def __next__(self):
return self._next()

def send(self, value):
return self._next(value)


class AsyncGeneratorWrapper(GeneratorWrapper):
"""
wrap a generator in executor pool
"""
def __aiter__(self):
return self

async def __anext__(self):
async def _anext(self, send_value=None):
try:
if self._job._executor is not None:
fut = self._job._executor.submit(_gen_anext, self._fut)
fut = self._job._executor.submit(
_gen_anext, send_value, self._fut)
res = await asyncio.wrap_future(fut)
return res
else:
if self._local_res is None:
self._local_res = self._job.func(
*self._job.args, **self._job.kwargs)
return await self._local_res.__anext__()
if send_value is not None:
return await self._local_res.asend(send_value)
else:
return await self._local_res.__anext__()
except Exception as e:
if isinstance(e, StopAsyncIteration):
await self._job.on_done(self)
else:
await self._job.on_failed(e)
raise e

async def __anext__(self):
return await self._anext()

async def asend(self, value):
return await self._anext(value)


def create_generator_wrapper(
job: "Job", fut: T.Optional[Future] = None) -> GeneratorWrapper:
if inspect.isasyncgenfunction(job.func):
return AsyncGeneratorWrapper(job, fut)
else:
return SyncGeneratorWrapper(job, fut)
73 changes: 73 additions & 0 deletions tests/test_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,3 +286,76 @@ async def gen_error():
async for i in job.result():
assert job.status == "running"
assert job.status == "failed"


@pytest.mark.asyncio
async def test_generator_send():
with Engine() as engine:
def gen():
res = 0
for _ in range(3):
res += yield res

job = ProcessJob(gen)
await engine.submit_async(job)
await job.wait_until_status("running")
assert job.status == "running"
g = job.result()
assert g.send(None) == 0
assert g.send(1) == 1
assert g.send(2) == 3
with pytest.raises(StopIteration):
g.send(3)
assert job.status == "done"

async def gen_async():
res = 0
for _ in range(3):
res += yield res

job = ProcessJob(gen_async)
await engine.submit_async(job)
await job.wait_until_status("running")
assert job.status == "running"
g = job.result()
assert await g.asend(None) == 0
assert await g.asend(1) == 1
assert await g.asend(2) == 3
with pytest.raises(StopAsyncIteration):
await g.asend(3)
assert job.status == "done"


@pytest.mark.asyncio
async def test_generator_send_localjob():
with Engine() as engine:
def gen():
res = 0
for _ in range(3):
res += yield res

job = LocalJob(gen)
engine.submit(job)
await job.wait_until_status("running")
g = job.result()
assert g.send(None) == 0
assert g.send(1) == 1
assert g.send(2) == 3
with pytest.raises(StopIteration):
g.send(3)

# test async generator
async def gen_async():
res = 0
for _ in range(3):
res += yield res

job = LocalJob(gen_async)
engine.submit(job)
await job.wait_until_status("running")
g = job.result()
assert await g.asend(None) == 0
assert await g.asend(1) == 1
assert await g.asend(2) == 3
with pytest.raises(StopAsyncIteration):
await g.asend(3)

0 comments on commit 669d71f

Please sign in to comment.