diff --git a/dank_mids/helpers/_session.py b/dank_mids/helpers/_session.py index 5c3de08e..a8032b39 100644 --- a/dank_mids/helpers/_session.py +++ b/dank_mids/helpers/_session.py @@ -104,6 +104,17 @@ def __new__(cls, value, phrase, description=""): HTTPStatusExtended.CLOUDFLARE_TIMEOUT, # type: ignore [attr-defined] } + +def _get_status_enum(error: ClientResponseError) -> HTTPStatusExtended: + try: + return HTTPStatusExtended(error.status) + except ValueError as ve: + if str(ve).endswith("is not a valid HTTPStatusExtended"): + # we still want the original exc to raise + raise error from ve + raise + + # default is 50 requests/second limiters = defaultdict(lambda: AsyncLimiter(1, 1 / ENVS.REQUESTS_PER_SECOND)) @@ -195,34 +206,31 @@ async def post(self, endpoint: str, *args, loads: JSONDecoder = DEFAULT_JSON_DEC _logger_debug("received response %s", response_data) return response_data except ClientResponseError as ce: - if ce.status == HTTPStatusExtended.TOO_MANY_REQUESTS: # type: ignore [attr-defined] - await self.handle_too_many_requests(endpoint, ce) - else: - try: - if ce.status not in RETRY_FOR_CODES or tried >= 5: - _logger_debug( - "response failed with status %s", - HTTPStatusExtended(ce.status), - ) - raise ce - except ValueError as ve: - if str(ve).endswith("is not a valid HTTPStatusExtended"): - raise ce from ve - raise + status = ce.status + if status == HTTPStatusExtended.TOO_MANY_REQUESTS: # type: ignore [attr-defined] + await self._handle_too_many_requests(endpoint, ce) + + elif status in RETRY_FOR_CODES and tried < 5: + tried += 1 + if debug_logs_enabled: + sleep_for = random() + _logger_log( + DEBUG, + "response failed with status %s, retrying in %.f2s", + (HTTPStatusExtended(status), sleep_for), + ) + await sleep(sleep_for) else: - tried += 1 - if debug_logs_enabled: - sleep_for = random() - _logger_log( - DEBUG, - "response failed with status %s, retrying in %ss", - (HTTPStatusExtended(ce.status), round(sleep_for, 2)), - ) - await sleep(sleep_for) - else: - await sleep(random()) - - async def handle_too_many_requests(self, endpoint: str, error: ClientResponseError) -> None: + await sleep(random()) + + else: + if debug_logs_enabled: + _logger_log( + DEBUG, "response failed with status %s", (_get_status_enum(ce),) + ) + raise + + async def _handle_too_many_requests(self, endpoint: str, error: ClientResponseError) -> None: limiter = limiters[endpoint] if (now := time()) > getattr(limiter, "_last_updated_at", 0) + 10: current_rate = limiter._rate_per_sec