From 72a2c037d8e73860622acdc844dd70d509fa573e Mon Sep 17 00:00:00 2001 From: pluuus Date: Mon, 29 Jan 2024 16:50:12 +0800 Subject: [PATCH] Add sequential chatbot client for length-1 messages. --- openlrc/chatbot.py | 103 +++++++++++++++++++++++++++++++++++++----- openlrc/exceptions.py | 28 +++++++++++- openlrc/translate.py | 2 +- tests/test_chatbot.py | 14 +++++- 4 files changed, 131 insertions(+), 16 deletions(-) diff --git a/openlrc/chatbot.py b/openlrc/chatbot.py index fdeaa7c..ffc3617 100644 --- a/openlrc/chatbot.py +++ b/openlrc/chatbot.py @@ -8,17 +8,49 @@ from typing import List, Union, Dict, Callable import openai -from openai import AsyncClient +from openai import AsyncClient, Client -from openlrc.exceptions import ChatBotException +from openlrc.exceptions import ChatBotException, LengthExceedException from openlrc.logger import logger from openlrc.utils import get_messages_token_number, get_text_token_number +# def retry_on_openai_failure(retry_num): +# """ +# Exception handling wrapper. +# :param retry_num: +# """ +# +# def decorate(f): +# def applicator(*args, **kwargs): +# try: +# return f(*args, **kwargs) +# except openai.RateLimitError: +# sleep_time = random.randint(30, 60) +# logger.warning(f'Rate limit exceeded. Wait {sleep_time}s before retry. Retry num: {retry_num + 1}.') +# time.sleep(sleep_time) +# except openai.APITimeoutError: +# logger.warning(f'Timeout. Wait 3 before retry. Retry num: {retry_num + 1}.') +# time.sleep(3) +# except openai.APIConnectionError: +# logger.warning(f'API connection error. Wait 15s before retry. Retry num: {retry_num + 1}.') +# time.sleep(15) +# except openai.APIError: +# logger.warning(f'API error. Wait 15s before retry. Retry num: {retry_num + 1}.') +# time.sleep(15) +# +# return OpenaiFailureException() +# +# return applicator +# +# return decorate + + class GPTBot: def __init__(self, model='gpt-3.5-turbo-1106', temperature=1, top_p=1, retry=8, max_async=16, json_mode=False, fee_limit=0.05): - self.client = AsyncClient(api_key=os.environ['OPENAI_API_KEY']) + self.async_client = AsyncClient(api_key=os.environ['OPENAI_API_KEY']) + self.seq_client = Client(api_key=os.environ['OPENAI_API_KEY']) # Pricing for 1k tokens, info from https://openai.com/pricing self.pricing = { @@ -75,13 +107,63 @@ def update_fee(self, response): self.api_fees[-1] += (prompt_tokens * prompt_price + completion_tokens * completion_price) / 1000 + def _create_chat(self, messages: List[Dict], output_checker: Callable = lambda *args, **kw: True): + logger.debug(f'Raw content: {messages}') + + # @retry_on_openai_failure + # def create_chat(): + # return self.seq_client.chat.completions.create( + # model=self.model, + # messages=messages, + # temperature=self.temperature, + # top_p=self.top_p, + # response_format={'type': 'json_object' if self.json_mode else 'text'} + # ) + + response = None + for i in range(self.retry): + try: + response = self.seq_client.chat.completions.create( + model=self.model, + messages=messages, + temperature=self.temperature, + top_p=self.top_p, + response_format={'type': 'json_object' if self.json_mode else 'text'} + ) + self.update_fee(response) + if response.choices[0].finish_reason == 'length': + raise LengthExceedException(response) + if not output_checker(messages, response.choices[0].message.content): + logger.warning(f'Invalid response format. Retry num: {i + 1}.') + continue + + break + except openai.RateLimitError: + sleep_time = random.randint(30, 60) + logger.warning(f'Rate limit exceeded. Wait {sleep_time}s before retry. Retry num: {i + 1}.') + time.sleep(sleep_time) + except openai.APITimeoutError: + logger.warning(f'Timeout. Wait 3 before retry. Retry num: {i + 1}.') + time.sleep(3) + except openai.APIConnectionError: + logger.warning(f'API connection error. Wait 15s before retry. Retry num: {i + 1}.') + time.sleep(15) + except openai.APIError: + logger.warning(f'API error. Wait 15s before retry. Retry num: {i + 1}.') + time.sleep(15) + + if response is None: + raise ChatBotException('Failed to create a chat.') + + return response + async def _create_achat(self, messages: List[Dict], output_checker: Callable = lambda *args, **kw: True): logger.debug(f'Raw content: {messages}') response = None for i in range(self.retry): try: - response = await self.client.chat.completions.create( + response = await self.async_client.chat.completions.create( model=self.model, messages=messages, temperature=self.temperature, @@ -90,13 +172,7 @@ async def _create_achat(self, messages: List[Dict], output_checker: Callable = l ) self.update_fee(response) if response.choices[0].finish_reason == 'length': - raise ChatBotException( - f'Failed to get completion. Exceed max token length. ' - f'Prompt tokens: {response.usage.prompt_tokens}, ' - f'Completion tokens: {response.usage.completion_tokens}, ' - f'Total tokens: {response.usage.total_tokens} ' - f'Reduce chunk_size may help.' - ) + raise LengthExceedException(response) if not output_checker(messages, response.choices[0].message.content): logger.warning(f'Invalid response format. Retry num: {i + 1}.') continue @@ -156,7 +232,10 @@ def message(self, messages_list: Union[List[Dict], List[List[Dict]]], f'exceeds the limit: {self.fee_limit}$.') try: - results = asyncio.run(self._amessage(messages_list, output_checker=output_checker)) + if len(messages_list) == 1: + results = [self._create_chat(messages_list[0], output_checker=output_checker)] + else: + results = asyncio.run(self._amessage(messages_list, output_checker=output_checker)) except ChatBotException as e: logger.error(f'Failed to message with GPT. Error: {e}') raise e diff --git a/openlrc/exceptions.py b/openlrc/exceptions.py index e99e624..ac8a42b 100644 --- a/openlrc/exceptions.py +++ b/openlrc/exceptions.py @@ -1,5 +1,7 @@ -# Copyright (C) 2023. Hao Zheng +# Copyright (C) 2024. Hao Zheng # All rights reserved. +from openai.types.chat import ChatCompletion + class SameLanguageException(Exception): """ @@ -22,6 +24,30 @@ def __init__(self, message): super().__init__(message) +class LengthExceedException(ChatBotException): + """ + Raised when the length of generated response exceeds the limit. + """ + + def __init__(self, response: ChatCompletion): + super().__init__( + f'Failed to get completion. Exceed max token length. ' + f'Prompt tokens: {response.usage.prompt_tokens}, ' + f'Completion tokens: {response.usage.completion_tokens}, ' + f'Total tokens: {response.usage.total_tokens} ' + f'Reduce chunk_size may help.' + ) + + +class OpenaiFailureException(Exception): + """ + Raised when OpenAI API fails to generate response. + """ + + def __init__(self): + super().__init__('OpenAI API failed to generate response.') + + class FfmpegException(Exception): def __init__(self, message): super().__init__(message) diff --git a/openlrc/translate.py b/openlrc/translate.py index e93fa2b..b2699b1 100644 --- a/openlrc/translate.py +++ b/openlrc/translate.py @@ -143,7 +143,7 @@ def translate(self, texts: Union[str, List[str]], src_lang, target_lang, audio_t translations.extend(translated) summaries.append(summary) - logger.info(f'Translating {title}: {i}/{len(chunks)}') + logger.info(f'Translated {title}: {i}/{len(chunks)}') logger.info(f'summary: {summary}') logger.info(f'scene: {scene}') diff --git a/tests/test_chatbot.py b/tests/test_chatbot.py index 6cfc4f0..aedbfc4 100644 --- a/tests/test_chatbot.py +++ b/tests/test_chatbot.py @@ -1,4 +1,4 @@ -# Copyright (C) 2023. Hao Zheng +# Copyright (C) 2024. Hao Zheng # All rights reserved. import unittest @@ -48,7 +48,7 @@ def test_update_fee(self): assert bot.api_fees == [0.0005, 0.001, 0.0015] - def test_message(self): + def test_message_async(self): bot = self.bot messages_list = [ [ @@ -60,3 +60,13 @@ def test_message(self): ] results = bot.message(messages_list) assert all(['hello' in r.choices[0].message.content.lower() for r in results]) + + def test_message_seq(self): + bot = self.bot + messages_list = [ + [ + {'role': 'user', 'content': 'Echo hello:'} + ] + ] + results = bot.message(messages_list) + assert 'hello' in results[0].choices[0].message.content.lower()