diff --git a/backend/onyx/configs/app_configs.py b/backend/onyx/configs/app_configs.py index b6a905d4abd..e79378c63a7 100644 --- a/backend/onyx/configs/app_configs.py +++ b/backend/onyx/configs/app_configs.py @@ -195,7 +195,6 @@ REDIS_AUTH_KEY_PREFIX = "fastapi_users_token:" - # Rate limiting for auth endpoints RATE_LIMIT_WINDOW_SECONDS: int | None = None _rate_limit_window_seconds_str = os.environ.get("RATE_LIMIT_WINDOW_SECONDS") @@ -213,6 +212,7 @@ except ValueError: pass +AUTH_RATE_LIMITING_ENABLED = RATE_LIMIT_MAX_REQUESTS and RATE_LIMIT_WINDOW_SECONDS # Used for general redis things REDIS_DB_NUMBER = int(os.environ.get("REDIS_DB_NUMBER", 0)) diff --git a/backend/onyx/main.py b/backend/onyx/main.py index 93baba26f1b..099b154ed3d 100644 --- a/backend/onyx/main.py +++ b/backend/onyx/main.py @@ -30,6 +30,7 @@ from onyx.configs.app_configs import APP_API_PREFIX from onyx.configs.app_configs import APP_HOST from onyx.configs.app_configs import APP_PORT +from onyx.configs.app_configs import AUTH_RATE_LIMITING_ENABLED from onyx.configs.app_configs import AUTH_TYPE from onyx.configs.app_configs import DISABLE_GENERATIVE_AI from onyx.configs.app_configs import LOG_ENDPOINT_LATENCY @@ -74,9 +75,9 @@ from onyx.server.manage.slack_bot import router as slack_bot_management_router from onyx.server.manage.users import router as user_router from onyx.server.middleware.latency_logging import add_latency_logging_middleware -from onyx.server.middleware.rate_limiting import close_limiter +from onyx.server.middleware.rate_limiting import close_auth_limiter from onyx.server.middleware.rate_limiting import get_auth_rate_limiters -from onyx.server.middleware.rate_limiting import setup_limiter +from onyx.server.middleware.rate_limiting import setup_auth_limiter from onyx.server.onyx_api.ingestion import router as onyx_api_router from onyx.server.openai_assistants_api.full_openai_assistants_api import ( get_full_openai_assistants_api_router, @@ -174,7 +175,7 @@ def include_auth_router_with_prefix( @asynccontextmanager -async def lifespan(app: FastAPI) -> AsyncGenerator: +async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # Set recursion limit if SYSTEM_RECURSION_LIMIT is not None: sys.setrecursionlimit(SYSTEM_RECURSION_LIMIT) @@ -216,13 +217,13 @@ async def lifespan(app: FastAPI) -> AsyncGenerator: optional_telemetry(record_type=RecordType.VERSION, data={"version": __version__}) - # Set up rate limiter - await setup_limiter() + if AUTH_RATE_LIMITING_ENABLED: + await setup_auth_limiter() yield - # Close rate limiter - await close_limiter() + if AUTH_RATE_LIMITING_ENABLED: + await close_auth_limiter() def log_http_error(_: Request, exc: Exception) -> JSONResponse: diff --git a/backend/onyx/redis/redis_pool.py b/backend/onyx/redis/redis_pool.py index acca2db8a56..83f7d010376 100644 --- a/backend/onyx/redis/redis_pool.py +++ b/backend/onyx/redis/redis_pool.py @@ -1,6 +1,7 @@ import asyncio import functools import json +import ssl import threading from collections.abc import Callable from typing import Any @@ -194,10 +195,6 @@ def create_pool( redis_pool = RedisPool() -def get_redis_client(*, tenant_id: str | None) -> Redis: - return redis_pool.get_client(tenant_id) - - # # Usage example # redis_pool = RedisPool() # redis_client = redis_pool.get_client() @@ -207,6 +204,18 @@ def get_redis_client(*, tenant_id: str | None) -> Redis: # value = redis_client.get('key') # print(value.decode()) # Output: 'value' + +def get_redis_client(*, tenant_id: str | None) -> Redis: + return redis_pool.get_client(tenant_id) + + +SSL_CERT_REQS_MAP = { + "none": ssl.CERT_NONE, + "optional": ssl.CERT_OPTIONAL, + "required": ssl.CERT_REQUIRED, +} + + _async_redis_connection: aioredis.Redis | None = None _async_lock = asyncio.Lock() @@ -224,15 +233,35 @@ async def get_async_redis_connection() -> aioredis.Redis: async with _async_lock: # Double-check inside the lock to avoid race conditions if _async_redis_connection is None: - scheme = "rediss" if REDIS_SSL else "redis" - url = f"{scheme}://{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER}" - - # Create a new Redis connection (or connection pool) from the URL - _async_redis_connection = aioredis.from_url( - url, - password=REDIS_PASSWORD, - max_connections=REDIS_POOL_MAX_CONNECTIONS, - ) + # Load env vars or your config variables + + connection_kwargs: dict[str, Any] = { + "host": REDIS_HOST, + "port": REDIS_PORT, + "db": REDIS_DB_NUMBER, + "password": REDIS_PASSWORD, + "max_connections": REDIS_POOL_MAX_CONNECTIONS, + "health_check_interval": REDIS_HEALTH_CHECK_INTERVAL, + "socket_keepalive": True, + "socket_keepalive_options": REDIS_SOCKET_KEEPALIVE_OPTIONS, + } + + if REDIS_SSL: + ssl_context = ssl.create_default_context() + + if REDIS_SSL_CA_CERTS: + ssl_context.load_verify_locations(REDIS_SSL_CA_CERTS) + ssl_context.check_hostname = False + + # Map your string to the proper ssl.CERT_* constant + ssl_context.verify_mode = SSL_CERT_REQS_MAP.get( + REDIS_SSL_CERT_REQS, ssl.CERT_NONE + ) + + connection_kwargs["ssl"] = ssl_context + + # Create a new Redis connection (or connection pool) with SSL configuration + _async_redis_connection = aioredis.Redis(**connection_kwargs) # Return the established connection (or pool) for all future operations return _async_redis_connection diff --git a/backend/onyx/server/middleware/rate_limiting.py b/backend/onyx/server/middleware/rate_limiting.py index c3a02079811..98d2dcd674d 100644 --- a/backend/onyx/server/middleware/rate_limiting.py +++ b/backend/onyx/server/middleware/rate_limiting.py @@ -6,18 +6,19 @@ from fastapi_limiter import FastAPILimiter from fastapi_limiter.depends import RateLimiter +from onyx.configs.app_configs import AUTH_RATE_LIMITING_ENABLED from onyx.configs.app_configs import RATE_LIMIT_MAX_REQUESTS from onyx.configs.app_configs import RATE_LIMIT_WINDOW_SECONDS from onyx.redis.redis_pool import get_async_redis_connection -async def setup_limiter() -> None: +async def setup_auth_limiter() -> None: # Use the centralized async Redis connection redis = await get_async_redis_connection() await FastAPILimiter.init(redis) -async def close_limiter() -> None: +async def close_auth_limiter() -> None: # This closes the FastAPILimiter connection so we don't leave open connections to Redis. await FastAPILimiter.close() @@ -32,14 +33,14 @@ async def rate_limit_key(request: Request) -> str: def get_auth_rate_limiters() -> List[Callable]: - if not (RATE_LIMIT_MAX_REQUESTS and RATE_LIMIT_WINDOW_SECONDS): + if not AUTH_RATE_LIMITING_ENABLED: return [] return [ Depends( RateLimiter( - times=RATE_LIMIT_MAX_REQUESTS, - seconds=RATE_LIMIT_WINDOW_SECONDS, + times=RATE_LIMIT_MAX_REQUESTS or 100, + seconds=RATE_LIMIT_WINDOW_SECONDS or 60, # Use the custom key function to distinguish users identifier=rate_limit_key, )