diff --git a/releasenotes/notes/allow-async-callbacks-e75b66109e759c3b.yaml b/releasenotes/notes/allow-async-callbacks-e75b66109e759c3b.yaml new file mode 100644 index 00000000..d906b758 --- /dev/null +++ b/releasenotes/notes/allow-async-callbacks-e75b66109e759c3b.yaml @@ -0,0 +1,3 @@ +--- +features: + - Allow asynchronous callbacks for `before`, `after`, `retry_error_callback`, `wait`, and `before_sleep` parameters. diff --git a/tenacity/_asyncio.py b/tenacity/_asyncio.py index 374ef206..484651a2 100644 --- a/tenacity/_asyncio.py +++ b/tenacity/_asyncio.py @@ -17,35 +17,32 @@ import functools import sys -import typing -from asyncio import sleep +from asyncio import sleep as aio_sleep +from collections.abc import Awaitable +from inspect import iscoroutinefunction +from typing import Union, Callable, Any, TypeVar -from tenacity import AttemptManager -from tenacity import BaseRetrying -from tenacity import DoAttempt -from tenacity import DoSleep -from tenacity import RetryCallState +from tenacity import AttemptManager, BaseRetrying, DoAttempt, DoSleep, RetryCallState, RetryAction, TryAgain -WrappedFn = typing.TypeVar("WrappedFn", bound=typing.Callable) -_RetValT = typing.TypeVar("_RetValT") +WrappedFn = TypeVar("WrappedFn", bound=Callable) +_RetValT = TypeVar("_RetValT") class AsyncRetrying(BaseRetrying): - def __init__(self, sleep: typing.Callable[[float], typing.Awaitable] = sleep, **kwargs: typing.Any) -> None: + def __init__(self, sleep: Callable[[float], Awaitable] = aio_sleep, **kwargs: Any) -> None: super().__init__(**kwargs) self.sleep = sleep async def __call__( # type: ignore # Change signature from supertype self, - fn: typing.Callable[..., typing.Awaitable[_RetValT]], - *args: typing.Any, - **kwargs: typing.Any, + fn: Callable[..., Awaitable[_RetValT]], + *args: Any, + **kwargs: Any, ) -> _RetValT: self.begin() - retry_state = RetryCallState(retry_object=self, fn=fn, args=args, kwargs=kwargs) while True: - do = self.iter(retry_state=retry_state) + do = await self.iter(retry_state=retry_state) if isinstance(do, DoAttempt): try: result = await fn(*args, **kwargs) @@ -64,9 +61,9 @@ def __aiter__(self) -> "AsyncRetrying": self._retry_state = RetryCallState(self, fn=None, args=(), kwargs={}) return self - async def __anext__(self) -> typing.Union[AttemptManager, typing.Any]: + async def __anext__(self) -> Union[AttemptManager, Any]: while True: - do = self.iter(retry_state=self._retry_state) + do = await self.iter(retry_state=self._retry_state) if do is None: raise StopAsyncIteration elif isinstance(do, DoAttempt): @@ -82,7 +79,7 @@ def wraps(self, fn: WrappedFn) -> WrappedFn: # Ensure wrapper is recognized as a coroutine function. @functools.wraps(fn) - async def async_wrapped(*args: typing.Any, **kwargs: typing.Any) -> typing.Any: + async def async_wrapped(*args: Any, **kwargs: Any) -> Any: return await fn(*args, **kwargs) # Preserve attributes @@ -90,3 +87,46 @@ async def async_wrapped(*args: typing.Any, **kwargs: typing.Any) -> typing.Any: async_wrapped.retry_with = fn.retry_with return async_wrapped + + @staticmethod + async def handle_custom_function(func: Union[Callable, Awaitable], retry_state: RetryCallState) -> Any: + if iscoroutinefunction(func): + return await func(retry_state) + return func(retry_state) + + async def iter(self, retry_state: "RetryCallState") -> Union[DoAttempt, DoSleep, Any]: + fut = retry_state.outcome + if fut is None: + if self.before is not None: + await self.handle_custom_function(self.before, retry_state) + return DoAttempt() + + is_explicit_retry = retry_state.outcome.failed and isinstance(retry_state.outcome.exception(), TryAgain) + if not (is_explicit_retry or self.retry(retry_state=retry_state)): + return fut.result() + + if self.after is not None: + await self.handle_custom_function(self.after, retry_state) + + self.statistics["delay_since_first_attempt"] = retry_state.seconds_since_start + if self.stop(retry_state=retry_state): + if self.retry_error_callback: + return await self.handle_custom_function(self.retry_error_callback, retry_state) + retry_exc = self.retry_error_cls(fut) + if self.reraise: + raise retry_exc.reraise() + raise retry_exc from fut.exception() + + if self.wait: + sleep = await self.handle_custom_function(self.wait, retry_state=retry_state) + else: + sleep = 0.0 + retry_state.next_action = RetryAction(sleep) + retry_state.idle_for += sleep + self.statistics["idle_for"] += sleep + self.statistics["attempt_number"] += 1 + + if self.before_sleep is not None: + await self.handle_custom_function(self.before_sleep, retry_state) + + return DoSleep(sleep) diff --git a/tests/test_asyncio.py b/tests/test_asyncio.py index b370e29c..4d69f68a 100644 --- a/tests/test_asyncio.py +++ b/tests/test_asyncio.py @@ -15,15 +15,15 @@ import asyncio import inspect +import logging import unittest from functools import wraps -from tenacity import AsyncRetrying, RetryError +import tenacity from tenacity import _asyncio as tasyncio from tenacity import retry, stop_after_attempt from tenacity.wait import wait_fixed - -from .test_tenacity import NoIOErrorAfterCount, current_time_ms +from .test_tenacity import CapturingHandler, NoneReturnUntilAfterCount, NoIOErrorAfterCount, current_time_ms def asynctest(callable_): @@ -67,7 +67,7 @@ async def test_iscoroutinefunction(self): @asynctest async def test_retry_using_async_retying(self): thing = NoIOErrorAfterCount(5) - retrying = AsyncRetrying() + retrying = tenacity.AsyncRetrying() await retrying(_async_function, thing) assert thing.counter == thing.count @@ -76,7 +76,7 @@ async def test_stop_after_attempt(self): thing = NoIOErrorAfterCount(2) try: await _retryable_coroutine_with_2_attempts(thing) - except RetryError: + except tenacity.RetryError: assert thing.counter == 2 def test_repr(self): @@ -86,6 +86,31 @@ def test_retry_attributes(self): assert hasattr(_retryable_coroutine, "retry") assert hasattr(_retryable_coroutine, "retry_with") + @asynctest + async def test_async_retry_error_callback_handler(self): + num_attempts = 3 + self.attempt_counter = 0 + + async def _retry_error_callback_handler(retry_state: tenacity.RetryCallState): + _retry_error_callback_handler.called_times += 1 + return retry_state.outcome + + _retry_error_callback_handler.called_times = 0 + + @retry( + stop=stop_after_attempt(num_attempts), + retry_error_callback=_retry_error_callback_handler, + ) + async def _foobar(): + self.attempt_counter += 1 + raise Exception("This exception should not be raised") + + result = await _foobar() + + self.assertEqual(_retry_error_callback_handler.called_times, 1) + self.assertEqual(num_attempts, self.attempt_counter) + self.assertIsInstance(result, tenacity.Future) + @asynctest async def test_attempt_number_is_correct_for_interleaved_coroutines(self): @@ -125,7 +150,7 @@ async def test_do_max_attempts(self): with attempt: attempts += 1 raise Exception - except RetryError: + except tenacity.RetryError: pass assert attempts == 3 @@ -151,11 +176,110 @@ async def test_sleeps(self): async for attempt in tasyncio.AsyncRetrying(stop=stop_after_attempt(1), wait=wait_fixed(1)): with attempt: raise Exception() - except RetryError: + except tenacity.RetryError: pass t = current_time_ms() - start self.assertLess(t, 1.1) +class TestAsyncBeforeAfterAttempts(unittest.TestCase): + _attempt_number = 0 + + @asynctest + async def test_before_attempts(self): + TestAsyncBeforeAfterAttempts._attempt_number = 0 + + async def _before(retry_state): + TestAsyncBeforeAfterAttempts._attempt_number = retry_state.attempt_number + + @retry( + wait=tenacity.wait_fixed(1), + stop=tenacity.stop_after_attempt(1), + before=_before, + ) + async def _test_before(): + pass + + await _test_before() + + self.assertTrue(TestAsyncBeforeAfterAttempts._attempt_number == 1) + + @asynctest + async def test_after_attempts(self): + TestAsyncBeforeAfterAttempts._attempt_number = 0 + + async def _after(retry_state): + TestAsyncBeforeAfterAttempts._attempt_number = retry_state.attempt_number + + @retry( + wait=tenacity.wait_fixed(0.1), + stop=tenacity.stop_after_attempt(3), + after=_after, + ) + async def _test_after(): + if TestAsyncBeforeAfterAttempts._attempt_number < 2: + raise Exception("testing after_attempts handler") + else: + pass + + await _test_after() + + self.assertTrue(TestAsyncBeforeAfterAttempts._attempt_number == 2) + + @asynctest + async def test_before_sleep(self): + async def _before_sleep(retry_state): + self.assertGreater(retry_state.next_action.sleep, 0) + _before_sleep.attempt_number = retry_state.attempt_number + + _before_sleep.attempt_number = 0 + + @retry( + wait=tenacity.wait_fixed(0.01), + stop=tenacity.stop_after_attempt(3), + before_sleep=_before_sleep, + ) + async def _test_before_sleep(): + if _before_sleep.attempt_number < 2: + raise Exception("testing before_sleep_attempts handler") + + await _test_before_sleep() + self.assertEqual(_before_sleep.attempt_number, 2) + + async def _test_before_sleep_log_returns(self, exc_info): + thing = NoneReturnUntilAfterCount(2) + logger = logging.getLogger(self.id()) + logger.propagate = False + logger.setLevel(logging.INFO) + handler = CapturingHandler() + logger.addHandler(handler) + try: + _before_sleep = tenacity.before_sleep_log(logger, logging.INFO, exc_info=exc_info) + _retry = tenacity.retry_if_result(lambda result: result is None) + retrying = tenacity.AsyncRetrying( + wait=tenacity.wait_fixed(0.01), + stop=tenacity.stop_after_attempt(3), + retry=_retry, + before_sleep=_before_sleep, + ) + await retrying(_async_function, thing) + finally: + logger.removeHandler(handler) + + etalon_re = r"^Retrying .* in 0\.01 seconds as it returned None\.$" + self.assertEqual(len(handler.records), 2) + fmt = logging.Formatter().format + self.assertRegex(fmt(handler.records[0]), etalon_re) + self.assertRegex(fmt(handler.records[1]), etalon_re) + + @asynctest + async def test_before_sleep_log_returns_without_exc_info(self): + await self._test_before_sleep_log_returns(exc_info=False) + + @asynctest + async def test_before_sleep_log_returns_with_exc_info(self): + await self._test_before_sleep_log_returns(exc_info=True) + + if __name__ == "__main__": unittest.main() diff --git a/tests/test_tenacity.py b/tests/test_tenacity.py index b6f6bbb0..454733f5 100644 --- a/tests/test_tenacity.py +++ b/tests/test_tenacity.py @@ -1093,6 +1093,8 @@ def _before_sleep(retry_state): self.assertGreater(retry_state.next_action.sleep, 0) _before_sleep.attempt_number = retry_state.attempt_number + _before_sleep.attempt_number = 0 + @retry( wait=tenacity.wait_fixed(0.01), stop=tenacity.stop_after_attempt(3),