Skip to content

Commit

Permalink
Async Redis (#3618)
Browse files Browse the repository at this point in the history
* k

* update configs for clarity

* typing

* update
  • Loading branch information
pablonyx authored Jan 7, 2025
1 parent d9e9c69 commit 5b5c116
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 26 deletions.
2 changes: 1 addition & 1 deletion backend/onyx/configs/app_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,6 @@

REDIS_AUTH_KEY_PREFIX = "fastapi_users_token:"


# Rate limiting for auth endpoints
RATE_LIMIT_WINDOW_SECONDS: int | None = None
_rate_limit_window_seconds_str = os.environ.get("RATE_LIMIT_WINDOW_SECONDS")
Expand All @@ -213,6 +212,7 @@
except ValueError:
pass

AUTH_RATE_LIMITING_ENABLED = RATE_LIMIT_MAX_REQUESTS and RATE_LIMIT_WINDOW_SECONDS
# Used for general redis things
REDIS_DB_NUMBER = int(os.environ.get("REDIS_DB_NUMBER", 0))

Expand Down
15 changes: 8 additions & 7 deletions backend/onyx/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from onyx.configs.app_configs import APP_API_PREFIX
from onyx.configs.app_configs import APP_HOST
from onyx.configs.app_configs import APP_PORT
from onyx.configs.app_configs import AUTH_RATE_LIMITING_ENABLED
from onyx.configs.app_configs import AUTH_TYPE
from onyx.configs.app_configs import DISABLE_GENERATIVE_AI
from onyx.configs.app_configs import LOG_ENDPOINT_LATENCY
Expand Down Expand Up @@ -74,9 +75,9 @@
from onyx.server.manage.slack_bot import router as slack_bot_management_router
from onyx.server.manage.users import router as user_router
from onyx.server.middleware.latency_logging import add_latency_logging_middleware
from onyx.server.middleware.rate_limiting import close_limiter
from onyx.server.middleware.rate_limiting import close_auth_limiter
from onyx.server.middleware.rate_limiting import get_auth_rate_limiters
from onyx.server.middleware.rate_limiting import setup_limiter
from onyx.server.middleware.rate_limiting import setup_auth_limiter
from onyx.server.onyx_api.ingestion import router as onyx_api_router
from onyx.server.openai_assistants_api.full_openai_assistants_api import (
get_full_openai_assistants_api_router,
Expand Down Expand Up @@ -174,7 +175,7 @@ def include_auth_router_with_prefix(


@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator:
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
# Set recursion limit
if SYSTEM_RECURSION_LIMIT is not None:
sys.setrecursionlimit(SYSTEM_RECURSION_LIMIT)
Expand Down Expand Up @@ -216,13 +217,13 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:

optional_telemetry(record_type=RecordType.VERSION, data={"version": __version__})

# Set up rate limiter
await setup_limiter()
if AUTH_RATE_LIMITING_ENABLED:
await setup_auth_limiter()

yield

# Close rate limiter
await close_limiter()
if AUTH_RATE_LIMITING_ENABLED:
await close_auth_limiter()


def log_http_error(_: Request, exc: Exception) -> JSONResponse:
Expand Down
55 changes: 42 additions & 13 deletions backend/onyx/redis/redis_pool.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import functools
import json
import ssl
import threading
from collections.abc import Callable
from typing import Any
Expand Down Expand Up @@ -194,10 +195,6 @@ def create_pool(
redis_pool = RedisPool()


def get_redis_client(*, tenant_id: str | None) -> Redis:
return redis_pool.get_client(tenant_id)


# # Usage example
# redis_pool = RedisPool()
# redis_client = redis_pool.get_client()
Expand All @@ -207,6 +204,18 @@ def get_redis_client(*, tenant_id: str | None) -> Redis:
# value = redis_client.get('key')
# print(value.decode()) # Output: 'value'


def get_redis_client(*, tenant_id: str | None) -> Redis:
return redis_pool.get_client(tenant_id)


SSL_CERT_REQS_MAP = {
"none": ssl.CERT_NONE,
"optional": ssl.CERT_OPTIONAL,
"required": ssl.CERT_REQUIRED,
}


_async_redis_connection: aioredis.Redis | None = None
_async_lock = asyncio.Lock()

Expand All @@ -224,15 +233,35 @@ async def get_async_redis_connection() -> aioredis.Redis:
async with _async_lock:
# Double-check inside the lock to avoid race conditions
if _async_redis_connection is None:
scheme = "rediss" if REDIS_SSL else "redis"
url = f"{scheme}://{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER}"

# Create a new Redis connection (or connection pool) from the URL
_async_redis_connection = aioredis.from_url(
url,
password=REDIS_PASSWORD,
max_connections=REDIS_POOL_MAX_CONNECTIONS,
)
# Load env vars or your config variables

connection_kwargs: dict[str, Any] = {
"host": REDIS_HOST,
"port": REDIS_PORT,
"db": REDIS_DB_NUMBER,
"password": REDIS_PASSWORD,
"max_connections": REDIS_POOL_MAX_CONNECTIONS,
"health_check_interval": REDIS_HEALTH_CHECK_INTERVAL,
"socket_keepalive": True,
"socket_keepalive_options": REDIS_SOCKET_KEEPALIVE_OPTIONS,
}

if REDIS_SSL:
ssl_context = ssl.create_default_context()

if REDIS_SSL_CA_CERTS:
ssl_context.load_verify_locations(REDIS_SSL_CA_CERTS)
ssl_context.check_hostname = False

# Map your string to the proper ssl.CERT_* constant
ssl_context.verify_mode = SSL_CERT_REQS_MAP.get(
REDIS_SSL_CERT_REQS, ssl.CERT_NONE
)

connection_kwargs["ssl"] = ssl_context

# Create a new Redis connection (or connection pool) with SSL configuration
_async_redis_connection = aioredis.Redis(**connection_kwargs)

# Return the established connection (or pool) for all future operations
return _async_redis_connection
Expand Down
11 changes: 6 additions & 5 deletions backend/onyx/server/middleware/rate_limiting.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,19 @@
from fastapi_limiter import FastAPILimiter
from fastapi_limiter.depends import RateLimiter

from onyx.configs.app_configs import AUTH_RATE_LIMITING_ENABLED
from onyx.configs.app_configs import RATE_LIMIT_MAX_REQUESTS
from onyx.configs.app_configs import RATE_LIMIT_WINDOW_SECONDS
from onyx.redis.redis_pool import get_async_redis_connection


async def setup_limiter() -> None:
async def setup_auth_limiter() -> None:
# Use the centralized async Redis connection
redis = await get_async_redis_connection()
await FastAPILimiter.init(redis)


async def close_limiter() -> None:
async def close_auth_limiter() -> None:
# This closes the FastAPILimiter connection so we don't leave open connections to Redis.
await FastAPILimiter.close()

Expand All @@ -32,14 +33,14 @@ async def rate_limit_key(request: Request) -> str:


def get_auth_rate_limiters() -> List[Callable]:
if not (RATE_LIMIT_MAX_REQUESTS and RATE_LIMIT_WINDOW_SECONDS):
if not AUTH_RATE_LIMITING_ENABLED:
return []

return [
Depends(
RateLimiter(
times=RATE_LIMIT_MAX_REQUESTS,
seconds=RATE_LIMIT_WINDOW_SECONDS,
times=RATE_LIMIT_MAX_REQUESTS or 100,
seconds=RATE_LIMIT_WINDOW_SECONDS or 60,
# Use the custom key function to distinguish users
identifier=rate_limit_key,
)
Expand Down

0 comments on commit 5b5c116

Please sign in to comment.