Skip to content

Commit

Permalink
🎨 style: make more error handling (#58)
Browse files Browse the repository at this point in the history
* make more readable

* more error handling
  • Loading branch information
KenyonY authored Aug 12, 2023
1 parent dcb16a7 commit 7c18fae
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 44 deletions.
2 changes: 1 addition & 1 deletion openai_forward/content/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,5 +70,5 @@ def _parse_one_line(line: str):
except KeyError:
return ""

def add_chat(self, chat_info: dict):
def log_chat(self, chat_info: dict):
self.logger.debug(f"{chat_info}")
113 changes: 70 additions & 43 deletions openai_forward/forwarding/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import traceback
import uuid
from itertools import cycle
from typing import Any, AsyncGenerator
from typing import Any, AsyncGenerator, List

import httpx
from fastapi import HTTPException, Request, status
Expand Down Expand Up @@ -46,10 +47,19 @@ async def aiter_bytes(
async for chunk in r.aiter_bytes():
bytes_ += chunk
yield chunk
await r.aclose()
cls.extrasaver.add_log(bytes_)

async def try_send(self, req: httpx.Request, request: Request):
async def try_send(self, client_config: dict, request: Request):
try:
req = self.client.build_request(
method=request.method,
url=client_config['url'],
headers=client_config["headers"],
content=request.stream(),
timeout=self.timeout,
)

r = await self.client.send(req, stream=True)
return r
except (httpx.ConnectError, httpx.ConnectTimeout) as e:
Expand Down Expand Up @@ -96,14 +106,7 @@ async def reverse_proxy(self, request: Request):
assert self.client is not None
client_config = self.prepare_client(request)

req = self.client.build_request(
request.method,
client_config['url'],
headers=client_config["headers"],
content=request.stream(),
timeout=self.timeout,
)
r = await self.try_send(req, request)
r = await self.try_send(client_config, request)

return StreamingResponse(
self.aiter_bytes(r),
Expand All @@ -120,60 +123,84 @@ class OpenaiBase(ForwardingBase):
chatsaver = ChatSaver()
whispersaver = WhisperSaver()

async def aiter_bytes(
self, r: httpx.Response, request: Request, route_path: str, uid: str
def _add_result_log(
self, byte_list: List[bytes], uid: str, route_path: str, request_method: str
):
byte_list = []
async for chunk in r.aiter_bytes():
byte_list.append(chunk)
yield chunk
await r.aclose()
try:
if LOG_CHAT and request.method == "POST":
if LOG_CHAT and request_method == "POST":
if route_path == "/v1/chat/completions":
target_info = self.chatsaver.parse_iter_bytes(byte_list)
self.chatsaver.add_chat(
self.chatsaver.log_chat(
{target_info["role"]: target_info["content"], "uid": uid}
)

elif route_path.startswith("/v1/audio/"):
self.whispersaver.add_log(b"".join([_ for _ in byte_list]))
except Exception as e:
logger.debug(f"log chat (not) error:\n{traceback.format_exc()}")

async def reverse_proxy(self, request: Request):
client_config = self.prepare_client(request)
auth_prefix = "Bearer "
auth = client_config["auth"]
auth_headers_dict = client_config["headers"]
url_path = client_config["url_path"]
if self._no_auth_mode or auth and auth[len(auth_prefix) :] in FWD_KEY:
auth = auth_prefix + next(self._cycle_api_key)
auth_headers_dict["Authorization"] = auth
else:
...
except Exception as e:
logger.warning(f"log chat (not) error:\n{traceback.format_exc()}")

async def _add_payload_log(self, request: Request, url_path: str):
uid = None
if LOG_CHAT and request.method == "POST":
try:
if url_path == "/v1/chat/completions":
chat_info = await self.chatsaver.parse_payload(request)
uid = chat_info.get("uid")
if chat_info:
self.chatsaver.add_chat(chat_info)
uid = chat_info.get("uid")
self.chatsaver.log_chat(chat_info)

elif url_path.startswith("/v1/audio/"):
uid = uuid.uuid4().__str__()

else:
...

except Exception as e:
logger.debug(
logger.warning(
f"log chat error:\nhost:{request.client.host} method:{request.method}: {traceback.format_exc()}"
)
return uid

async def aiter_bytes(
self, r: httpx.Response, request: Request, route_path: str, uid: str
):
byte_list = []
async for chunk in r.aiter_bytes():
byte_list.append(chunk)
yield chunk

await r.aclose()

if uid:
if r.is_success:
self._add_result_log(byte_list, uid, route_path, request.method)
else:
response_info = b"".join([_ for _ in byte_list])
logger.warning(f'uid: {uid}\n' f'{response_info}')

async def reverse_proxy(self, request: Request):
client_config = self.prepare_client(request)
url_path = client_config["url_path"]

def set_apikey_from_preset():
nonlocal client_config
auth_prefix = "Bearer "
auth = client_config["auth"]
if self._no_auth_mode or auth and auth[len(auth_prefix) :] in FWD_KEY:
auth = auth_prefix + next(self._cycle_api_key)
client_config["headers"]["Authorization"] = auth

set_apikey_from_preset()

uid = await self._add_payload_log(request, url_path)

r = await self.try_send(client_config, request)

req = self.client.build_request(
request.method,
client_config['url'],
headers=auth_headers_dict,
content=request.stream(),
timeout=self.timeout,
)
r = await self.try_send(req, request)
aiter_bytes = self.aiter_bytes(r, request, url_path, uid)
return StreamingResponse(
aiter_bytes,
self.aiter_bytes(r, request, url_path, uid),
status_code=r.status_code,
media_type=r.headers.get("content-type"),
)

0 comments on commit 7c18fae

Please sign in to comment.