diff --git a/khl/__init__.py b/khl/__init__.py index c0ad453..a28dd4c 100644 --- a/khl/__init__.py +++ b/khl/__init__.py @@ -19,6 +19,7 @@ from .cert import Cert from .receiver import Receiver, WebhookReceiver, WebsocketReceiver from .requester import HTTPRequester +from .ratelimiter import RateLimiter from .gateway import Gateway, Requestable from .client import Client diff --git a/khl/bot/bot.py b/khl/bot/bot.py index b5e382f..688baeb 100644 --- a/khl/bot/bot.py +++ b/khl/bot/bot.py @@ -6,7 +6,7 @@ from typing import Dict, Callable, List, Optional, Union, Coroutine, IO from .. import AsyncRunnable # interfaces -from .. import Cert, HTTPRequester, WebhookReceiver, WebsocketReceiver, Gateway, Client # net related +from .. import Cert, HTTPRequester, RateLimiter, WebhookReceiver, WebsocketReceiver, Gateway, Client # net related from .. import MessageTypes, EventTypes, SlowModeTypes, SoftwareTypes # types from .. import User, Channel, PublicChannel, Guild, Event, Message # concepts from ..command import CommandManager @@ -102,7 +102,7 @@ def _init_client(self, cert: Cert, client: Client, gate: Gateway, out: HTTPReque return # client and gate not in args, build them - _out = out if out else HTTPRequester(cert) + _out = out if out else HTTPRequester(cert, RateLimiter()) if cert.type == Cert.Types.WEBSOCKET: _in = WebsocketReceiver(cert, compress) elif cert.type == Cert.Types.WEBHOOK: diff --git a/khl/ratelimiter.py b/khl/ratelimiter.py new file mode 100644 index 0000000..29236d5 --- /dev/null +++ b/khl/ratelimiter.py @@ -0,0 +1,94 @@ +import asyncio +import logging +from typing import Dict + +log = logging.getLogger(__name__) + + +class RateLimiter: + """rate limit control""" + + def __init__(self): + self._ratelimit_info: Dict[str, RateLimiter.RateLimitData] = {} + self._api_bucket_mapping: Dict[str, str] = {} + self._lock = asyncio.Lock() + + def push_api_bucket_mapping(self, api: str, bucket: str): + """ + when finished request, associate bucket that api returned with api route + to avoid that bucket and api router are not the same + """ + + api = api.lower() + bucket = bucket.lower() + + async with self._lock: + if api not in self._api_bucket_mapping: + self._api_bucket_mapping[api] = bucket + + def get_bucket(self, api: str): + """get bucket name by api route""" + + api = api.lower() + + async with self._lock: + if api not in self._api_bucket_mapping: + return api + + return self._api_bucket_mapping[api] + + def update_ratelimit(self, bucket: str, remaining: int, reset: int): + """update rate limit info""" + + bucket = bucket.lower() + async with self._lock: + if bucket not in self._ratelimit_info: + self._ratelimit_info[bucket] = self.RateLimitData(remaining, reset) + else: + self._ratelimit_info[bucket].remaining = remaining + self._ratelimit_info[bucket].reset = reset + + def get_delay(self, bucket: str) -> float: + """get request delay time, seconds""" + + bucket = bucket.lower() + async with self._lock: + if bucket not in self._ratelimit_info: + return 0 + + if self._ratelimit_info[bucket].reset == 0: + return 0 + + if self._ratelimit_info[bucket].remaining == 0: + return self._ratelimit_info[bucket].reset + + delay = self._ratelimit_info[bucket].reset / self._ratelimit_info[bucket].remaining + + return delay + + async def wait_for_rate(self, route): + bucket = self.get_bucket(route) + delay = self.get_delay(bucket) + log.debug(f'ratelimiter: {route} req bucket: {bucket} delay: {delay: .3f}s') + await asyncio.sleep(delay) + + @staticmethod + def extract_xrate_header(headers): + bucket = headers['X-Rate-Limit-Bucket'] + remaining = int(headers['X-Rate-Limit-Remaining']) + reset = int(headers['X-Rate-Limit-Reset']) + return bucket, remaining, reset + + def update(self, route, headers): + if 'X-Rate-Limit-Limit' in headers: + bucket, remaining, reset = self.extract_xrate_header(headers) + self.push_api_bucket_mapping(route, bucket) + self.update_ratelimit(bucket, remaining, reset) + log.debug(f'ratelimiter: {route} rsp ratelimit info: {bucket} {remaining} {reset}s') + + class RateLimitData: + """to save single bucket rate limit""" + + def __init__(self, remaining: int = 120, reset: int = 0): + self.remaining = remaining + self.reset = reset diff --git a/khl/requester.py b/khl/requester.py index 268770a..30167a6 100644 --- a/khl/requester.py +++ b/khl/requester.py @@ -4,6 +4,7 @@ from aiohttp import ClientSession +from .ratelimiter import RateLimiter from .api import _Req from .cert import Cert @@ -15,9 +16,10 @@ class HTTPRequester: """wrap raw requests, handle boilerplate param filling works""" - def __init__(self, cert: Cert): + def __init__(self, cert: Cert, ratelimiter: RateLimiter): self._cert = cert self._cs: Union[ClientSession, None] = None + self._ratelimiter = ratelimiter def __del__(self): if self._cs is not None: @@ -29,6 +31,10 @@ async def request(self, method: str, route: str, **params) -> Union[dict, list, params['headers'] = headers log.debug(f'{method} {route}: req: {params}') # token is excluded + + if self._ratelimiter is not None: + await self._ratelimiter.wait_for_rate(route) + headers['Authorization'] = f'Bot {self._cert.token}' if self._cs is None: # lazy init self._cs = ClientSession() @@ -40,6 +46,10 @@ async def request(self, method: str, route: str, **params) -> Union[dict, list, rsp = rsp['data'] else: rsp = await res.read() + + if self._ratelimiter is not None: + self._ratelimiter.update(route, res.headers) + log.debug(f'{method} {route}: rsp: {rsp}') return rsp