Skip to content

Commit

Permalink
Add handling for backend error
Browse files Browse the repository at this point in the history
  • Loading branch information
Bharat23 committed Jun 6, 2024
1 parent 11b00dc commit 5eaed8c
Show file tree
Hide file tree
Showing 6 changed files with 106 additions and 12 deletions.
29 changes: 18 additions & 11 deletions ratelimit/backends/redis.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import json

from redis.asyncio import StrictRedis
from redis.exceptions import ConnectionError

from ..exceptions import BackendConnectionException
from ..rule import Rule
from . import BaseBackend

Expand Down Expand Up @@ -41,17 +43,22 @@ async def is_blocking(self, user: str) -> int:
return int(await self._redis.ttl(f"blocking:{user}"))

async def retry_after(self, path: str, user: str, rule: Rule) -> int:
block_time = await self.is_blocking(user)
if block_time > 0:
return block_time
try:
block_time = await self.is_blocking(user)
if block_time > 0:
return block_time

ruleset = rule.ruleset(path, user)
retry_after = int(
await self.lua_script(keys=list(ruleset.keys()), args=[json.dumps(ruleset)])
)
ruleset = rule.ruleset(path, user)
retry_after = int(
await self.lua_script(
keys=list(ruleset.keys()), args=[json.dumps(ruleset)]
)
)

if retry_after > 0 and rule.block_time:
await self.set_block_time(user, rule.block_time)
retry_after = rule.block_time
if retry_after > 0 and rule.block_time:
await self.set_block_time(user, rule.block_time)
retry_after = rule.block_time

return retry_after
return retry_after
except ConnectionError as ce:
raise BackendConnectionException(f"Error connecting to Redis: {ce}")
21 changes: 20 additions & 1 deletion ratelimit/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple

from .backends import BaseBackend
from .exceptions import BaseBackendException
from .rule import RULENAMES, Rule
from .types import ASGIApp, Receive, Scope, Send

Expand All @@ -23,6 +24,19 @@ async def default_429(scope: Scope, receive: Receive, send: Send) -> None:
return default_429


def _on_backend_error(err) -> ASGIApp:
async def default_503(scope: Scope, receive: Receive, send: Send) -> None:
await send(
{
"type": "http.response.start",
"status": 503,
}
)
await send({"type": "http.response.body", "body": b"", "more_body": False})

return default_503


class RateLimitMiddleware:
"""
rate limit middleware
Expand All @@ -37,6 +51,7 @@ def __init__(
*,
on_auth_error: Optional[Callable[[Exception], Awaitable[ASGIApp]]] = None,
on_blocked: Callable[[int], ASGIApp] = _on_blocked,
on_backend_error: Callable[[int], ASGIApp] = _on_backend_error,
) -> None:
self.app = app
self.authenticate = authenticate
Expand All @@ -53,6 +68,7 @@ def __init__(

self.on_auth_error = on_auth_error
self.on_blocked = on_blocked
self.on_backend_error = on_backend_error

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] != "http": # pragma: no cover
Expand Down Expand Up @@ -90,7 +106,10 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
return await self.app(scope, receive, send)

path: str = url_path if rule.zone is None else rule.zone
retry_after = await self.backend.retry_after(path, user, rule)
try:
retry_after = await self.backend.retry_after(path, user, rule)
except BaseBackendException as be:
return await self.on_backend_error(be)(scope, receive, send)
if retry_after == 0:
return await self.app(scope, receive, send)

Expand Down
3 changes: 3 additions & 0 deletions ratelimit/exceptions/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# flake8: noqa: F401
from .backend_connection import BackendConnectionException
from .base_backend import BaseBackendException
9 changes: 9 additions & 0 deletions ratelimit/exceptions/backend_connection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from .base_backend import BaseBackendException


class BackendConnectionException(BaseBackendException):
"""
Backend exception for ConnectionError
"""

pass
6 changes: 6 additions & 0 deletions ratelimit/exceptions/base_backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
class BaseBackendException(Exception):
"""
Base class for exception raised by Backends
"""

pass
50 changes: 50 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,56 @@ async def inside_yourself_429(scope: Scope, receive: Receive, send: Send) -> Non
return inside_yourself_429


@pytest.mark.asyncio
async def test_on_backend_error():
# use incorrect port to force connection error
rate_limit = RateLimitMiddleware(
hello_world,
authenticate=auth_func,
backend=RedisBackend(StrictRedis(port=6369)),
config={r"/": [Rule(second=1), Rule(group="admin")]},
)

async with httpx.AsyncClient(
app=rate_limit, base_url="http://testserver"
) as client: # type: httpx.AsyncClient
response = await client.get("/", headers={"user": "user", "group": "default"})
assert response.status_code == 503


@pytest.mark.asyncio
async def test_custom_on_backend_error():
# use incorrect port to force connection error
rate_limit = RateLimitMiddleware(
hello_world,
authenticate=auth_func,
backend=RedisBackend(StrictRedis(port=6369)),
config={r"/": [Rule(second=1), Rule(group="admin")]},
on_backend_error=yourself_503,
)

async with httpx.AsyncClient(
app=rate_limit, base_url="http://testserver"
) as client: # type: httpx.AsyncClient
response = await client.get("/", headers={"user": "user", "group": "default"})
assert response.status_code == 503
assert response.text == "custom 503 page"


def yourself_503(retry_after: int):
async def inside_yourself_503(scope: Scope, receive: Receive, send: Send) -> None:
await send({"type": "http.response.start", "status": 503})
await send(
{
"type": "http.response.body",
"body": b"custom 503 page",
"more_body": False,
}
)

return inside_yourself_503


@pytest.mark.asyncio
async def test_custom_blocked():
rate_limit = RateLimitMiddleware(
Expand Down

0 comments on commit 5eaed8c

Please sign in to comment.