Skip to content

Commit

Permalink
feat: optimize queue.py (#477)
Browse files Browse the repository at this point in the history
* feat: optimize queue.py

* chore: `black .`

* Update queue.py

* Update queue.py

* chore: `black .`

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
  • Loading branch information
BobTheBuidler and github-actions[bot] authored Dec 16, 2024
1 parent 39882d0 commit 0b31c8c
Showing 1 changed file with 73 additions and 49 deletions.
122 changes: 73 additions & 49 deletions a_sync/primitives/queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,11 @@
import logging
import sys
import weakref
from asyncio import InvalidStateError, QueueEmpty, gather

import a_sync.asyncio
from a_sync import _smart
from a_sync._smart import SmartFuture, create_future
from a_sync._smart import _Key as _SmartKey
from a_sync._typing import *

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -156,7 +158,7 @@ async def get_all(self) -> List[T]:
"""
try:
return self.get_all_nowait()
except asyncio.QueueEmpty:
except QueueEmpty:
return [await self.get()]

def get_all_nowait(self) -> List[T]:
Expand All @@ -173,13 +175,16 @@ def get_all_nowait(self) -> List[T]:
>>> tasks = queue.get_all_nowait()
>>> print(tasks)
"""
get_nowait = self.get_nowait
values: List[T] = []
append = values.append

while True:
try:
values.append(self.get_nowait())
except asyncio.QueueEmpty as e:
append(get_nowait())
except QueueEmpty as e:
if not values:
raise asyncio.QueueEmpty from e
raise QueueEmpty from e
return values

async def get_multi(self, i: int, can_return_less: bool = False) -> List[T]:
Expand All @@ -198,12 +203,16 @@ async def get_multi(self, i: int, can_return_less: bool = False) -> List[T]:
>>> print(tasks)
"""
_validate_args(i, can_return_less)
get_next = self.get
get_multi = self.get_multi_nowait

items = []
extend = items.extend
while len(items) < i and not can_return_less:
try:
items.extend(self.get_multi_nowait(i - len(items), can_return_less=True))
except asyncio.QueueEmpty:
items = [await self.get()]
extend(get_multi(i - len(items), can_return_less=True))
except QueueEmpty:
items = [await get_next()]
return items

def get_multi_nowait(self, i: int, can_return_less: bool = False) -> List[T]:
Expand All @@ -222,17 +231,22 @@ def get_multi_nowait(self, i: int, can_return_less: bool = False) -> List[T]:
>>> print(tasks)
"""
_validate_args(i, can_return_less)

get_nowait = self.get_nowait

items = []
append = items.append
for _ in range(i):
try:
items.append(self.get_nowait())
except asyncio.QueueEmpty:
append(get_nowait())
except QueueEmpty:
if items and can_return_less:
return items
# put these back in the queue since we didn't return them
put_nowait = self.put_nowait
for value in items:
self.put_nowait(value)
raise asyncio.QueueEmpty from None
put_nowait(value)
raise QueueEmpty from None
return items


Expand Down Expand Up @@ -359,7 +373,8 @@ def __del__(self) -> None:
context = {
"message": f"{self} was destroyed but has work pending!",
}
asyncio.get_event_loop().call_exception_handler(context)
if loop := asyncio.events._get_running_loop():
loop.call_exception_handler(context)

@property
def name(self) -> str:
Expand Down Expand Up @@ -426,7 +441,7 @@ def put_nowait(self, *args: P.args, **kwargs: P.kwargs) -> "asyncio.Future[V]":

def _create_future(self) -> "asyncio.Future[V]":
"""Creates a future for the task."""
return asyncio.get_event_loop().create_future()
return asyncio.events._get_running_loop().create_future()

def _ensure_workers(self) -> None:
"""Ensures that the worker tasks are running."""
Expand All @@ -448,16 +463,16 @@ def _ensure_workers(self) -> None:
def _workers(self) -> "asyncio.Task[NoReturn]":
"""Creates and manages the worker tasks for the queue."""
logger.debug("starting worker task for %s", self)
workers = [
workers = tuple(
a_sync.asyncio.create_task(
coro=self._worker_coro(),
name=f"{self.name} [Task-{i}]",
log_destroy_pending=False,
)
for i in range(self.num_workers)
]
)
task = a_sync.asyncio.create_task(
asyncio.gather(*workers),
gather(*workers),
name=f"{self.name} worker main Task",
log_destroy_pending=False,
)
Expand All @@ -468,49 +483,53 @@ async def __worker_coro(self) -> NoReturn:
"""
The coroutine executed by worker tasks to process the queue.
"""
get_next_job = self.get
func = self.func
task_done = self.task_done

args: P.args
kwargs: P.kwargs
if self._no_futs:
while True:
try:
args, kwargs = await self.get()
await self.func(*args, **kwargs)
args, kwargs = await get_next_job()
await func(*args, **kwargs)
except Exception as e:
logger.error("%s in worker for %s!", type(e).__name__, self)
logger.exception(e)
self.task_done()
task_done()
else:
fut: asyncio.Future[V]
while True:
try:
args, kwargs, fut = await self.get()
args, kwargs, fut = await get_next_job()
try:
if fut is None:
# the weakref was already cleaned up, we don't need to process this item
self.task_done()
task_done()
continue
result = await self.func(*args, **kwargs)
result = await func(*args, **kwargs)
fut.set_result(result)
except asyncio.exceptions.InvalidStateError:
except InvalidStateError:
logger.error(
"cannot set result for %s %s: %s",
self.func.__name__,
func.__name__,
fut,
result,
)
except Exception as e:
try:
fut.set_exception(e)
except asyncio.exceptions.InvalidStateError:
except InvalidStateError:
logger.error(
"cannot set exception for %s %s: %s",
self.func.__name__,
func.__name__,
fut,
e,
)
self.task_done()
task_done()
except Exception as e:
logger.error("%s for %s is broken!!!", type(self).__name__, self.func)
logger.error("%s for %s is broken!!!", type(self).__name__, func)
logger.exception(e)
raise

Expand Down Expand Up @@ -540,10 +559,10 @@ def _validate_args(i: int, can_return_less: bool) -> None:

class _SmartFutureRef(weakref.ref, Generic[T]):
"""
Weak reference for :class:`~_smart.SmartFuture` objects used in priority queues.
Weak reference for :class:`~SmartFuture` objects used in priority queues.
See Also:
:class:`~_smart.SmartFuture`
:class:`~SmartFuture`
"""

def __lt__(self, other: "_SmartFutureRef[T]") -> bool:
Expand Down Expand Up @@ -713,7 +732,7 @@ def _get(self, heappop=heapq.heappop):
# `self._queue` will always be in proper order for next call to `self._get`.
return heappop(self._queue)

def _get_key(self, *args, **kwargs) -> _smart._Key:
def _get_key(self, *args, **kwargs) -> _SmartKey:
"""
Generates a unique key for task identification based on arguments.
Expand Down Expand Up @@ -773,7 +792,7 @@ class SmartProcessingQueue(_VariablePriorityQueueMixin[T], ProcessingQueue[Conca
_no_futs = False
"""Whether smart futures are used."""

_futs: "weakref.WeakValueDictionary[_smart._Key[T], _smart.SmartFuture[T]]"
_futs: "weakref.WeakValueDictionary[_SmartKey[T], SmartFuture[T]]"
"""
Weak reference dictionary for managing smart futures.
"""
Expand Down Expand Up @@ -802,7 +821,7 @@ def __init__(
super().__init__(func, num_workers, return_data=True, name=name, loop=loop)
self._futs = weakref.WeakValueDictionary()

async def put(self, *args: P.args, **kwargs: P.kwargs) -> _smart.SmartFuture[V]:
async def put(self, *args: P.args, **kwargs: P.kwargs) -> SmartFuture[V]:
"""
Asynchronously adds a task with smart future handling to the queue.
Expand All @@ -826,7 +845,7 @@ async def put(self, *args: P.args, **kwargs: P.kwargs) -> _smart.SmartFuture[V]:
await Queue.put(self, (_SmartFutureRef(fut), args, kwargs))
return fut

def put_nowait(self, *args: P.args, **kwargs: P.kwargs) -> _smart.SmartFuture[V]:
def put_nowait(self, *args: P.args, **kwargs: P.kwargs) -> SmartFuture[V]:
"""
Immediately adds a task with smart future handling to the queue without waiting.
Expand All @@ -850,9 +869,9 @@ def put_nowait(self, *args: P.args, **kwargs: P.kwargs) -> _smart.SmartFuture[V]
Queue.put_nowait(self, (_SmartFutureRef(fut), args, kwargs))
return fut

def _create_future(self, key: _smart._Key) -> "asyncio.Future[V]":
def _create_future(self, key: _SmartKey) -> "asyncio.Future[V]":
"""Creates a smart future for the task."""
return _smart.create_future(queue=self, key=key, loop=self._loop)
return create_future(queue=self, key=key, loop=self._loop)

def _get(self):
"""
Expand Down Expand Up @@ -880,40 +899,45 @@ async def __worker_coro(self) -> NoReturn:
Example:
>>> await queue.__worker_coro()
"""
get_next_job = self.get
func = self.func
task_done = self.task_done
log_debug = logger.debug

args: P.args
kwargs: P.kwargs
fut: _smart.SmartFuture[V]
fut: SmartFuture[V]
while True:
try:
try:
args, kwargs, fut = await self.get()
args, kwargs, fut = await get_next_job()
if fut is None:
# the weakref was already cleaned up, we don't need to process this item
self.task_done()
task_done()
continue
logger.debug("processing %s", fut)
result = await self.func(*args, **kwargs)
log_debug("processing %s", fut)
result = await func(*args, **kwargs)
fut.set_result(result)
except asyncio.exceptions.InvalidStateError:
except InvalidStateError:
logger.error(
"cannot set result for %s %s: %s",
self.func.__name__,
func.__name__,
fut,
result,
)
except Exception as e:
logger.debug("%s: %s", type(e).__name__, e)
log_debug("%s: %s", type(e).__name__, e)
try:
fut.set_exception(e)
except asyncio.exceptions.InvalidStateError:
except InvalidStateError:
logger.error(
"cannot set exception for %s %s: %s",
self.func.__name__,
func.__name__,
fut,
e,
)
self.task_done()
task_done()
except Exception as e:
logger.error("%s for %s is broken!!!", type(self).__name__, self.func)
logger.error("%s for %s is broken!!!", type(self).__name__, func)
logger.exception(e)
raise

0 comments on commit 0b31c8c

Please sign in to comment.