From 0b31c8c0c726a32ab3e0a7679602fa8d0b3388a0 Mon Sep 17 00:00:00 2001 From: BobTheBuidler <70677534+BobTheBuidler@users.noreply.github.com> Date: Sun, 15 Dec 2024 21:00:34 -0400 Subject: [PATCH] feat: optimize queue.py (#477) * feat: optimize queue.py * chore: `black .` * Update queue.py * Update queue.py * chore: `black .` --------- Co-authored-by: github-actions[bot] --- a_sync/primitives/queue.py | 122 ++++++++++++++++++++++--------------- 1 file changed, 73 insertions(+), 49 deletions(-) diff --git a/a_sync/primitives/queue.py b/a_sync/primitives/queue.py index da053fac..72a538e7 100644 --- a/a_sync/primitives/queue.py +++ b/a_sync/primitives/queue.py @@ -19,9 +19,11 @@ import logging import sys import weakref +from asyncio import InvalidStateError, QueueEmpty, gather import a_sync.asyncio -from a_sync import _smart +from a_sync._smart import SmartFuture, create_future +from a_sync._smart import _Key as _SmartKey from a_sync._typing import * logger = logging.getLogger(__name__) @@ -156,7 +158,7 @@ async def get_all(self) -> List[T]: """ try: return self.get_all_nowait() - except asyncio.QueueEmpty: + except QueueEmpty: return [await self.get()] def get_all_nowait(self) -> List[T]: @@ -173,13 +175,16 @@ def get_all_nowait(self) -> List[T]: >>> tasks = queue.get_all_nowait() >>> print(tasks) """ + get_nowait = self.get_nowait values: List[T] = [] + append = values.append + while True: try: - values.append(self.get_nowait()) - except asyncio.QueueEmpty as e: + append(get_nowait()) + except QueueEmpty as e: if not values: - raise asyncio.QueueEmpty from e + raise QueueEmpty from e return values async def get_multi(self, i: int, can_return_less: bool = False) -> List[T]: @@ -198,12 +203,16 @@ async def get_multi(self, i: int, can_return_less: bool = False) -> List[T]: >>> print(tasks) """ _validate_args(i, can_return_less) + get_next = self.get + get_multi = self.get_multi_nowait + items = [] + extend = items.extend while len(items) < i and not can_return_less: try: - items.extend(self.get_multi_nowait(i - len(items), can_return_less=True)) - except asyncio.QueueEmpty: - items = [await self.get()] + extend(get_multi(i - len(items), can_return_less=True)) + except QueueEmpty: + items = [await get_next()] return items def get_multi_nowait(self, i: int, can_return_less: bool = False) -> List[T]: @@ -222,17 +231,22 @@ def get_multi_nowait(self, i: int, can_return_less: bool = False) -> List[T]: >>> print(tasks) """ _validate_args(i, can_return_less) + + get_nowait = self.get_nowait + items = [] + append = items.append for _ in range(i): try: - items.append(self.get_nowait()) - except asyncio.QueueEmpty: + append(get_nowait()) + except QueueEmpty: if items and can_return_less: return items # put these back in the queue since we didn't return them + put_nowait = self.put_nowait for value in items: - self.put_nowait(value) - raise asyncio.QueueEmpty from None + put_nowait(value) + raise QueueEmpty from None return items @@ -359,7 +373,8 @@ def __del__(self) -> None: context = { "message": f"{self} was destroyed but has work pending!", } - asyncio.get_event_loop().call_exception_handler(context) + if loop := asyncio.events._get_running_loop(): + loop.call_exception_handler(context) @property def name(self) -> str: @@ -426,7 +441,7 @@ def put_nowait(self, *args: P.args, **kwargs: P.kwargs) -> "asyncio.Future[V]": def _create_future(self) -> "asyncio.Future[V]": """Creates a future for the task.""" - return asyncio.get_event_loop().create_future() + return asyncio.events._get_running_loop().create_future() def _ensure_workers(self) -> None: """Ensures that the worker tasks are running.""" @@ -448,16 +463,16 @@ def _ensure_workers(self) -> None: def _workers(self) -> "asyncio.Task[NoReturn]": """Creates and manages the worker tasks for the queue.""" logger.debug("starting worker task for %s", self) - workers = [ + workers = tuple( a_sync.asyncio.create_task( coro=self._worker_coro(), name=f"{self.name} [Task-{i}]", log_destroy_pending=False, ) for i in range(self.num_workers) - ] + ) task = a_sync.asyncio.create_task( - asyncio.gather(*workers), + gather(*workers), name=f"{self.name} worker main Task", log_destroy_pending=False, ) @@ -468,49 +483,53 @@ async def __worker_coro(self) -> NoReturn: """ The coroutine executed by worker tasks to process the queue. """ + get_next_job = self.get + func = self.func + task_done = self.task_done + args: P.args kwargs: P.kwargs if self._no_futs: while True: try: - args, kwargs = await self.get() - await self.func(*args, **kwargs) + args, kwargs = await get_next_job() + await func(*args, **kwargs) except Exception as e: logger.error("%s in worker for %s!", type(e).__name__, self) logger.exception(e) - self.task_done() + task_done() else: fut: asyncio.Future[V] while True: try: - args, kwargs, fut = await self.get() + args, kwargs, fut = await get_next_job() try: if fut is None: # the weakref was already cleaned up, we don't need to process this item - self.task_done() + task_done() continue - result = await self.func(*args, **kwargs) + result = await func(*args, **kwargs) fut.set_result(result) - except asyncio.exceptions.InvalidStateError: + except InvalidStateError: logger.error( "cannot set result for %s %s: %s", - self.func.__name__, + func.__name__, fut, result, ) except Exception as e: try: fut.set_exception(e) - except asyncio.exceptions.InvalidStateError: + except InvalidStateError: logger.error( "cannot set exception for %s %s: %s", - self.func.__name__, + func.__name__, fut, e, ) - self.task_done() + task_done() except Exception as e: - logger.error("%s for %s is broken!!!", type(self).__name__, self.func) + logger.error("%s for %s is broken!!!", type(self).__name__, func) logger.exception(e) raise @@ -540,10 +559,10 @@ def _validate_args(i: int, can_return_less: bool) -> None: class _SmartFutureRef(weakref.ref, Generic[T]): """ - Weak reference for :class:`~_smart.SmartFuture` objects used in priority queues. + Weak reference for :class:`~SmartFuture` objects used in priority queues. See Also: - :class:`~_smart.SmartFuture` + :class:`~SmartFuture` """ def __lt__(self, other: "_SmartFutureRef[T]") -> bool: @@ -713,7 +732,7 @@ def _get(self, heappop=heapq.heappop): # `self._queue` will always be in proper order for next call to `self._get`. return heappop(self._queue) - def _get_key(self, *args, **kwargs) -> _smart._Key: + def _get_key(self, *args, **kwargs) -> _SmartKey: """ Generates a unique key for task identification based on arguments. @@ -773,7 +792,7 @@ class SmartProcessingQueue(_VariablePriorityQueueMixin[T], ProcessingQueue[Conca _no_futs = False """Whether smart futures are used.""" - _futs: "weakref.WeakValueDictionary[_smart._Key[T], _smart.SmartFuture[T]]" + _futs: "weakref.WeakValueDictionary[_SmartKey[T], SmartFuture[T]]" """ Weak reference dictionary for managing smart futures. """ @@ -802,7 +821,7 @@ def __init__( super().__init__(func, num_workers, return_data=True, name=name, loop=loop) self._futs = weakref.WeakValueDictionary() - async def put(self, *args: P.args, **kwargs: P.kwargs) -> _smart.SmartFuture[V]: + async def put(self, *args: P.args, **kwargs: P.kwargs) -> SmartFuture[V]: """ Asynchronously adds a task with smart future handling to the queue. @@ -826,7 +845,7 @@ async def put(self, *args: P.args, **kwargs: P.kwargs) -> _smart.SmartFuture[V]: await Queue.put(self, (_SmartFutureRef(fut), args, kwargs)) return fut - def put_nowait(self, *args: P.args, **kwargs: P.kwargs) -> _smart.SmartFuture[V]: + def put_nowait(self, *args: P.args, **kwargs: P.kwargs) -> SmartFuture[V]: """ Immediately adds a task with smart future handling to the queue without waiting. @@ -850,9 +869,9 @@ def put_nowait(self, *args: P.args, **kwargs: P.kwargs) -> _smart.SmartFuture[V] Queue.put_nowait(self, (_SmartFutureRef(fut), args, kwargs)) return fut - def _create_future(self, key: _smart._Key) -> "asyncio.Future[V]": + def _create_future(self, key: _SmartKey) -> "asyncio.Future[V]": """Creates a smart future for the task.""" - return _smart.create_future(queue=self, key=key, loop=self._loop) + return create_future(queue=self, key=key, loop=self._loop) def _get(self): """ @@ -880,40 +899,45 @@ async def __worker_coro(self) -> NoReturn: Example: >>> await queue.__worker_coro() """ + get_next_job = self.get + func = self.func + task_done = self.task_done + log_debug = logger.debug + args: P.args kwargs: P.kwargs - fut: _smart.SmartFuture[V] + fut: SmartFuture[V] while True: try: try: - args, kwargs, fut = await self.get() + args, kwargs, fut = await get_next_job() if fut is None: # the weakref was already cleaned up, we don't need to process this item - self.task_done() + task_done() continue - logger.debug("processing %s", fut) - result = await self.func(*args, **kwargs) + log_debug("processing %s", fut) + result = await func(*args, **kwargs) fut.set_result(result) - except asyncio.exceptions.InvalidStateError: + except InvalidStateError: logger.error( "cannot set result for %s %s: %s", - self.func.__name__, + func.__name__, fut, result, ) except Exception as e: - logger.debug("%s: %s", type(e).__name__, e) + log_debug("%s: %s", type(e).__name__, e) try: fut.set_exception(e) - except asyncio.exceptions.InvalidStateError: + except InvalidStateError: logger.error( "cannot set exception for %s %s: %s", - self.func.__name__, + func.__name__, fut, e, ) - self.task_done() + task_done() except Exception as e: - logger.error("%s for %s is broken!!!", type(self).__name__, self.func) + logger.error("%s for %s is broken!!!", type(self).__name__, func) logger.exception(e) raise