Skip to content

Commit

Permalink
fix: propagating Authorization header to the upstream
Browse files Browse the repository at this point in the history
  • Loading branch information
adubovik committed Sep 10, 2024
1 parent 8808f93 commit b5dc07e
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 5 deletions.
2 changes: 1 addition & 1 deletion aidial_interceptors_sdk/chat_completion/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@ class Impl(DialChatCompletion):
async def chat_completion(
self, request: DialRequest, response: DialResponse
) -> None:

dial_client = await DialClient.create(
api_key=request.api_key,
api_version=request.api_version,
authorization=request.jwt,
)

interceptor = cls(
Expand Down
18 changes: 15 additions & 3 deletions aidial_interceptors_sdk/dial_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,23 @@ def dial_url(self) -> str:

@classmethod
async def create(
cls, api_key: str | None, api_version: str | None
cls,
api_key: str | None,
authorization: str | None,
api_version: str | None,
) -> "DialClient":
if not api_key:
raise InvalidRequestError("The 'api-key' request header is missing")

extra_headers = {}
if authorization is not None:
extra_headers["Authorization"] = authorization

client = AsyncAzureOpenAI(
azure_endpoint=DIAL_URL,
azure_deployment="interceptor",
# NOTE: DIAL SDK takes care of propagating auth headers
api_key="dummy",
# NOTE: DIAL SDK takes care of propagating api-key header
api_key="-",
# NOTE: api-version query parameter is not required in the chat completions DIAL API.
# However, it is required in Azure OpenAI API, that's why the openai library fails when it's missing:
# https://github.com/openai/openai-python/blob/9850c169c4126fd04dc6796e4685f1b9e4924aa4/src/openai/lib/azure.py#L174-L177
Expand All @@ -47,6 +54,11 @@ async def create(
api_version=api_version or "",
max_retries=0,
http_client=get_http_client(),
# NOTE: if Authorization header was provided in the request,
# then propagate it to the upstream.
# Whether interceptor gets the header or not, is determined by
# `forwardAuthToken` option set for the interceptor in the DIAL Core config.
default_headers=extra_headers,
)

storage = FileStorage(dial_url=DIAL_URL, api_key=api_key)
Expand Down
3 changes: 2 additions & 1 deletion aidial_interceptors_sdk/embeddings/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@ def interceptor_to_embeddings_handler(cls: Type[EmbeddingsInterceptor]):
async def _handler(request: Request) -> dict:

dial_client = await DialClient.create(
api_key=request.headers.get("api-key", None),
api_key=request.headers.get("api-key"),
api_version=request.query_params.get("api-version"),
authorization=request.headers.get("authorization"),
)

interceptor = cls(dial_client=dial_client, **request.path_params)
Expand Down

0 comments on commit b5dc07e

Please sign in to comment.