Skip to content

Commit

Permalink
feat: rate limit control (#217)
Browse files Browse the repository at this point in the history
  • Loading branch information
hank9999 authored Nov 7, 2023
1 parent 25d26b4 commit 3002faf
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 7 deletions.
1 change: 1 addition & 0 deletions khl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from .cert import Cert
from .receiver import Receiver, WebhookReceiver, WebsocketReceiver
from .requester import HTTPRequester
from .ratelimiter import RateLimiter
from .gateway import Gateway, Requestable
from .client import Client

Expand Down
12 changes: 7 additions & 5 deletions khl/bot/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import Dict, Callable, List, Optional, Union, Coroutine, IO

from .. import AsyncRunnable # interfaces
from .. import Cert, HTTPRequester, WebhookReceiver, WebsocketReceiver, Gateway, Client # net related
from .. import Cert, HTTPRequester, RateLimiter, WebhookReceiver, WebsocketReceiver, Gateway, Client # net related
from .. import MessageTypes, EventTypes, SlowModeTypes, SoftwareTypes # types
from .. import User, Channel, PublicChannel, Guild, Event, Message # concepts
from ..command import CommandManager
Expand Down Expand Up @@ -49,7 +49,8 @@ def __init__(self,
out: HTTPRequester = None,
compress: bool = True,
port=5000,
route='/khl-wh'):
route='/khl-wh',
ratelimiter: Optional[RateLimiter] = RateLimiter(start=80)):
"""
The most common usage: ``Bot(token='xxxxxx')``
Expand All @@ -66,7 +67,7 @@ def __init__(self,
if not token and not cert:
raise ValueError('require token or cert')

self._init_client(cert or Cert(token=token), client, gate, out, compress, port, route)
self._init_client(cert or Cert(token=token), client, gate, out, compress, port, route, ratelimiter)
self._register_client_handler()

self.command = CommandManager()
Expand All @@ -78,7 +79,8 @@ def __init__(self,
self._startup_index = []
self._shutdown_index = []

def _init_client(self, cert: Cert, client: Client, gate: Gateway, out: HTTPRequester, compress: bool, port, route):
def _init_client(self, cert: Cert, client: Client, gate: Gateway, out: HTTPRequester, compress: bool, port, route,
ratelimiter):
"""
construct self.client from args.
Expand All @@ -102,7 +104,7 @@ def _init_client(self, cert: Cert, client: Client, gate: Gateway, out: HTTPReque
return

# client and gate not in args, build them
_out = out if out else HTTPRequester(cert)
_out = out if out else HTTPRequester(cert, ratelimiter)
if cert.type == Cert.Types.WEBSOCKET:
_in = WebsocketReceiver(cert, compress)
elif cert.type == Cert.Types.WEBHOOK:
Expand Down
106 changes: 106 additions & 0 deletions khl/ratelimiter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
import asyncio
import logging
from typing import Dict

log = logging.getLogger(__name__)


class RateLimiter:
"""rate limit control
@param start: when the remain reach this number, start ratelimit
"""

def __init__(self, start: int = 120):
self._ratelimit_info: Dict[str, RateLimiter.RateLimitData] = {}
self._api_bucket_mapping: Dict[str, str] = {}
self._lock = asyncio.Lock()
self._start = start

async def wait_for_rate(self, route):
"""get and wait delay"""

bucket = await self.get_bucket(route)
delay = await self.get_delay(bucket)
log.debug(f'ratelimiter: {route} req bucket: {bucket} delay: {delay: .3f}s')
await asyncio.sleep(delay)

async def update(self, route, headers):
"""get values and update ratelimit information"""

if 'X-Rate-Limit-Limit' in headers:
bucket, remaining, reset = self.extract_xrate_header(headers)
await self.push_api_bucket_mapping(route, bucket)
await self.update_ratelimit(bucket, remaining, reset)
log.debug(f'ratelimiter: {route} rsp ratelimit: bucket: {bucket} remaining: {remaining} reset: {reset}s')

async def push_api_bucket_mapping(self, api: str, bucket: str):
"""
when finished request, associate bucket that api returned with api route
to avoid that bucket and api router are not the same
"""

api = api.lower()
bucket = bucket.lower()

async with self._lock:
if api not in self._api_bucket_mapping:
self._api_bucket_mapping[api] = bucket

async def get_bucket(self, api: str):
"""get bucket name by api route"""

api = api.lower()

async with self._lock:
if api not in self._api_bucket_mapping:
return api

return self._api_bucket_mapping[api]

async def update_ratelimit(self, bucket: str, remaining: int, reset: int):
"""update rate limit info"""

bucket = bucket.lower()
async with self._lock:
if bucket not in self._ratelimit_info:
self._ratelimit_info[bucket] = self.RateLimitData(remaining, reset)
else:
self._ratelimit_info[bucket].remaining = remaining
self._ratelimit_info[bucket].reset = reset

async def get_delay(self, bucket: str) -> float:
"""get request delay time, seconds"""

bucket = bucket.lower()
async with self._lock:
if bucket not in self._ratelimit_info:
return 0

if self._ratelimit_info[bucket].reset == 0:
return 0

if self._ratelimit_info[bucket].remaining == 0:
return self._ratelimit_info[bucket].reset

if self._ratelimit_info[bucket].remaining > self._start:
return 0

delay = self._ratelimit_info[bucket].reset / self._ratelimit_info[bucket].remaining

return delay

@staticmethod
def extract_xrate_header(headers):
"""get bucket, remaining, reset values from headers"""

bucket = headers['X-Rate-Limit-Bucket']
remaining = int(headers['X-Rate-Limit-Remaining'])
reset = int(headers['X-Rate-Limit-Reset'])
return bucket, remaining, reset

class RateLimitData:
"""to save single bucket rate limit"""

def __init__(self, remaining: int = 120, reset: int = 0):
self.remaining = remaining
self.reset = reset
14 changes: 12 additions & 2 deletions khl/requester.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import asyncio
import logging
from typing import Union, List
from typing import Union, List, Optional

from aiohttp import ClientSession

from .ratelimiter import RateLimiter
from .api import _Req
from .cert import Cert

Expand All @@ -15,9 +16,10 @@
class HTTPRequester:
"""wrap raw requests, handle boilerplate param filling works"""

def __init__(self, cert: Cert):
def __init__(self, cert: Cert, ratelimiter: Optional[RateLimiter]):
self._cert = cert
self._cs: Union[ClientSession, None] = None
self._ratelimiter = ratelimiter

def __del__(self):
if self._cs is not None:
Expand All @@ -29,6 +31,10 @@ async def request(self, method: str, route: str, **params) -> Union[dict, list,
params['headers'] = headers

log.debug(f'{method} {route}: req: {params}') # token is excluded

if self._ratelimiter is not None:
await self._ratelimiter.wait_for_rate(route)

headers['Authorization'] = f'Bot {self._cert.token}'
if self._cs is None: # lazy init
self._cs = ClientSession()
Expand All @@ -40,6 +46,10 @@ async def request(self, method: str, route: str, **params) -> Union[dict, list,
rsp = rsp['data']
else:
rsp = await res.read()

if self._ratelimiter is not None:
await self._ratelimiter.update(route, res.headers)

log.debug(f'{method} {route}: rsp: {rsp}')
return rsp

Expand Down

0 comments on commit 3002faf

Please sign in to comment.