Skip to content

Commit

Permalink
Add sequential chatbot client for length-1 messages.
Browse files Browse the repository at this point in the history
  • Loading branch information
zh-plus committed Jan 29, 2024
1 parent e548c03 commit 72a2c03
Show file tree
Hide file tree
Showing 4 changed files with 131 additions and 16 deletions.
103 changes: 91 additions & 12 deletions openlrc/chatbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,49 @@
from typing import List, Union, Dict, Callable

import openai
from openai import AsyncClient
from openai import AsyncClient, Client

from openlrc.exceptions import ChatBotException
from openlrc.exceptions import ChatBotException, LengthExceedException
from openlrc.logger import logger
from openlrc.utils import get_messages_token_number, get_text_token_number


# def retry_on_openai_failure(retry_num):
# """
# Exception handling wrapper.
# :param retry_num:
# """
#
# def decorate(f):
# def applicator(*args, **kwargs):
# try:
# return f(*args, **kwargs)
# except openai.RateLimitError:
# sleep_time = random.randint(30, 60)
# logger.warning(f'Rate limit exceeded. Wait {sleep_time}s before retry. Retry num: {retry_num + 1}.')
# time.sleep(sleep_time)
# except openai.APITimeoutError:
# logger.warning(f'Timeout. Wait 3 before retry. Retry num: {retry_num + 1}.')
# time.sleep(3)
# except openai.APIConnectionError:
# logger.warning(f'API connection error. Wait 15s before retry. Retry num: {retry_num + 1}.')
# time.sleep(15)
# except openai.APIError:
# logger.warning(f'API error. Wait 15s before retry. Retry num: {retry_num + 1}.')
# time.sleep(15)
#
# return OpenaiFailureException()
#
# return applicator
#
# return decorate


class GPTBot:
def __init__(self, model='gpt-3.5-turbo-1106', temperature=1, top_p=1, retry=8, max_async=16, json_mode=False,
fee_limit=0.05):
self.client = AsyncClient(api_key=os.environ['OPENAI_API_KEY'])
self.async_client = AsyncClient(api_key=os.environ['OPENAI_API_KEY'])
self.seq_client = Client(api_key=os.environ['OPENAI_API_KEY'])

# Pricing for 1k tokens, info from https://openai.com/pricing
self.pricing = {
Expand Down Expand Up @@ -75,13 +107,63 @@ def update_fee(self, response):

self.api_fees[-1] += (prompt_tokens * prompt_price + completion_tokens * completion_price) / 1000

def _create_chat(self, messages: List[Dict], output_checker: Callable = lambda *args, **kw: True):
    """
    Synchronously request a single chat completion, retrying on transient OpenAI errors.

    :param messages: Chat messages in the OpenAI format ({'role': ..., 'content': ...}).
    :param output_checker: Callable ``(messages, content) -> bool`` validating the response
        text; a falsy result triggers another retry.
    :return: The chat completion response object (the last one received if the format
        check never passes within the retry budget).
    :raises LengthExceedException: If the completion was truncated by the token limit.
    :raises ChatBotException: If no response at all could be obtained after ``self.retry`` attempts.
    """
    logger.debug(f'Raw content: {messages}')

    response = None
    for i in range(self.retry):
        try:
            response = self.seq_client.chat.completions.create(
                model=self.model,
                messages=messages,
                temperature=self.temperature,
                top_p=self.top_p,
                response_format={'type': 'json_object' if self.json_mode else 'text'}
            )
            # Record the cost before validating, since tokens were spent either way.
            self.update_fee(response)
            if response.choices[0].finish_reason == 'length':
                raise LengthExceedException(response)
            if not output_checker(messages, response.choices[0].message.content):
                logger.warning(f'Invalid response format. Retry num: {i + 1}.')
                continue

            break
        except openai.RateLimitError:
            # Randomized back-off to avoid thundering-herd retries.
            sleep_time = random.randint(30, 60)
            logger.warning(f'Rate limit exceeded. Wait {sleep_time}s before retry. Retry num: {i + 1}.')
            time.sleep(sleep_time)
        except openai.APITimeoutError:
            logger.warning(f'Timeout. Wait 3s before retry. Retry num: {i + 1}.')
            time.sleep(3)
        except openai.APIConnectionError:
            logger.warning(f'API connection error. Wait 15s before retry. Retry num: {i + 1}.')
            time.sleep(15)
        except openai.APIError:
            # Catch-all for other API-side failures; must come after the specific subclasses.
            logger.warning(f'API error. Wait 15s before retry. Retry num: {i + 1}.')
            time.sleep(15)

    if response is None:
        raise ChatBotException('Failed to create a chat.')

    return response

async def _create_achat(self, messages: List[Dict], output_checker: Callable = lambda *args, **kw: True):
logger.debug(f'Raw content: {messages}')

response = None
for i in range(self.retry):
try:
response = await self.client.chat.completions.create(
response = await self.async_client.chat.completions.create(
model=self.model,
messages=messages,
temperature=self.temperature,
Expand All @@ -90,13 +172,7 @@ async def _create_achat(self, messages: List[Dict], output_checker: Callable = l
)
self.update_fee(response)
if response.choices[0].finish_reason == 'length':
raise ChatBotException(
f'Failed to get completion. Exceed max token length. '
f'Prompt tokens: {response.usage.prompt_tokens}, '
f'Completion tokens: {response.usage.completion_tokens}, '
f'Total tokens: {response.usage.total_tokens} '
f'Reduce chunk_size may help.'
)
raise LengthExceedException(response)
if not output_checker(messages, response.choices[0].message.content):
logger.warning(f'Invalid response format. Retry num: {i + 1}.')
continue
Expand Down Expand Up @@ -156,7 +232,10 @@ def message(self, messages_list: Union[List[Dict], List[List[Dict]]],
f'exceeds the limit: {self.fee_limit}$.')

try:
results = asyncio.run(self._amessage(messages_list, output_checker=output_checker))
if len(messages_list) == 1:
results = [self._create_chat(messages_list[0], output_checker=output_checker)]
else:
results = asyncio.run(self._amessage(messages_list, output_checker=output_checker))
except ChatBotException as e:
logger.error(f'Failed to message with GPT. Error: {e}')
raise e
Expand Down
28 changes: 27 additions & 1 deletion openlrc/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Copyright (C) 2023. Hao Zheng
# Copyright (C) 2024. Hao Zheng
# All rights reserved.
from openai.types.chat import ChatCompletion


class SameLanguageException(Exception):
"""
Expand All @@ -22,6 +24,30 @@ def __init__(self, message):
super().__init__(message)


class LengthExceedException(ChatBotException):
    """
    Raised when the generated completion was cut off because it hit the model's max token limit.
    """

    def __init__(self, response: ChatCompletion):
        usage = response.usage
        message = (
            f'Failed to get completion. Exceed max token length. '
            f'Prompt tokens: {usage.prompt_tokens}, '
            f'Completion tokens: {usage.completion_tokens}, '
            f'Total tokens: {usage.total_tokens} '
            f'Reduce chunk_size may help.'
        )
        super().__init__(message)


class OpenaiFailureException(Exception):
    """Raised when the OpenAI API fails to generate a response."""

    def __init__(self):
        message = 'OpenAI API failed to generate response.'
        super().__init__(message)


class FfmpegException(Exception):
    """
    Raised for ffmpeg-related failures.

    The previous explicit ``__init__(self, message)`` only forwarded to
    ``Exception.__init__`` — that is already the default behavior, so it is
    dropped; ``FfmpegException(message)`` still works exactly as before.
    """
Expand Down
2 changes: 1 addition & 1 deletion openlrc/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def translate(self, texts: Union[str, List[str]], src_lang, target_lang, audio_t

translations.extend(translated)
summaries.append(summary)
logger.info(f'Translating {title}: {i}/{len(chunks)}')
logger.info(f'Translated {title}: {i}/{len(chunks)}')
logger.info(f'summary: {summary}')
logger.info(f'scene: {scene}')

Expand Down
14 changes: 12 additions & 2 deletions tests/test_chatbot.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (C) 2023. Hao Zheng
# Copyright (C) 2024. Hao Zheng
# All rights reserved.

import unittest
Expand Down Expand Up @@ -48,7 +48,7 @@ def test_update_fee(self):

assert bot.api_fees == [0.0005, 0.001, 0.0015]

def test_message(self):
def test_message_async(self):
bot = self.bot
messages_list = [
[
Expand All @@ -60,3 +60,13 @@ def test_message(self):
]
results = bot.message(messages_list)
assert all(['hello' in r.choices[0].message.content.lower() for r in results])

def test_message_seq(self):
    """A single message list should still produce an echo of 'hello' via bot.message."""
    prompt = [{'role': 'user', 'content': 'Echo hello:'}]
    results = self.bot.message([prompt])
    response_text = results[0].choices[0].message.content
    assert 'hello' in response_text.lower()

0 comments on commit 72a2c03

Please sign in to comment.