From b69a6df6e605c0a9e7833d9264b2d6e160f0556e Mon Sep 17 00:00:00 2001 From: Xander Song Date: Tue, 8 Oct 2024 18:46:47 -0700 Subject: [PATCH] fix(playground): authenticate websockets (#4924) --- src/phoenix/server/app.py | 33 ++++++++++++++++++++++++++++--- src/phoenix/server/bearer_auth.py | 16 +++++++++++---- 2 files changed, 42 insertions(+), 7 deletions(-) diff --git a/src/phoenix/server/app.py b/src/phoenix/server/app.py index 708e6bcde7..9209e16895 100644 --- a/src/phoenix/server/app.py +++ b/src/phoenix/server/app.py @@ -34,15 +34,16 @@ from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker from starlette.datastructures import State as StarletteState -from starlette.exceptions import HTTPException +from starlette.exceptions import HTTPException, WebSocketException from starlette.middleware import Middleware from starlette.middleware.authentication import AuthenticationMiddleware from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint from starlette.requests import Request -from starlette.responses import PlainTextResponse, Response +from starlette.responses import JSONResponse, PlainTextResponse, Response from starlette.staticfiles import StaticFiles from starlette.templating import Jinja2Templates from starlette.types import Scope, StatefulLifespan +from starlette.websockets import WebSocket from strawberry.fastapi import GraphQLRouter from strawberry.schema import BaseSchema from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL @@ -607,6 +608,29 @@ async def plain_text_http_exception_handler(request: Request, exc: HTTPException return PlainTextResponse(str(exc.detail), status_code=exc.status_code, headers=headers) +async def websocket_denial_response_handler(websocket: WebSocket, exc: WebSocketException) -> None: + """ + Overrides the default exception handler for WebSocketException to ensure + that the HTTP response returned when a WebSocket connection is denied has + the same status code as the raised exception. This is in keeping with the + WebSocket Denial Response Extension of the ASGI specificiation described + below. + + "Websocket connections start with the client sending a HTTP request + containing the appropriate upgrade headers. On receipt of this request a + server can choose to either upgrade the connection or respond with an HTTP + response (denying the upgrade). The core ASGI specification does not allow + for any control over the denial response, instead specifying that the HTTP + status code 403 should be returned, whereas this extension allows an ASGI + framework to control the denial response." + + For details, see: + - https://asgi.readthedocs.io/en/latest/extensions.html#websocket-denial-response + """ + assert isinstance(exc, WebSocketException) + await websocket.send_denial_response(JSONResponse(status_code=exc.code, content=exc.reason)) + + def create_app( db: DbSessionFactory, export_path: Path, @@ -733,7 +757,10 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: scaffolder_config=scaffolder_config, ), middleware=middlewares, - exception_handlers={HTTPException: plain_text_http_exception_handler}, + exception_handlers={ + HTTPException: plain_text_http_exception_handler, + WebSocketException: websocket_denial_response_handler, # type: ignore[dict-item] + }, debug=debug, swagger_ui_parameters={ "defaultModelsExpandDepth": -1, # hides the schema section in the Swagger UI diff --git a/src/phoenix/server/bearer_auth.py b/src/phoenix/server/bearer_auth.py index 3c47171c3e..7e0b1447b6 100644 --- a/src/phoenix/server/bearer_auth.py +++ b/src/phoenix/server/bearer_auth.py @@ -7,10 +7,11 @@ Callable, Optional, Tuple, + cast, ) import grpc -from fastapi import HTTPException, Request +from fastapi import HTTPException, Request, WebSocket, WebSocketException from grpc_interceptor import AsyncServerInterceptor from grpc_interceptor.exceptions import Unauthenticated from starlette.authentication import AuthCredentials, AuthenticationBackend, BaseUser @@ -116,12 +117,19 @@ async def intercept( raise Unauthenticated() -async def is_authenticated(request: Request) -> None: +async def is_authenticated( + # fastapi dependencies require non-optional types + request: Request = cast(Request, None), + websocket: WebSocket = cast(WebSocket, None), +) -> None: """ - Raises a 401 if the request is not authenticated. + Raises a 401 if the request or websocket connection is not authenticated. """ - if not isinstance((user := request.user), PhoenixUser): + assert request or websocket + if request and not isinstance((user := request.user), PhoenixUser): raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Invalid token") + if websocket and not isinstance((user := websocket.user), PhoenixUser): + raise WebSocketException(code=HTTP_401_UNAUTHORIZED, reason="Invalid token") claims = user.claims if claims.status is ClaimSetStatus.EXPIRED: raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Expired token")