From 4f74fa7e35afc0e18a29212981be5c8a34db24fd Mon Sep 17 00:00:00 2001 From: yym68686 Date: Sat, 2 Dec 2023 17:32:40 +0800 Subject: [PATCH] 1. Fixed bug: gpt-3.5-turbo-1106 maximum tokens. 2. Optimize keyword extraction prompt --- test/test_keyword.py | 12 +++++++++--- utils/agent.py | 11 +++++++++-- utils/chatgpt2api.py | 7 +++++-- 3 files changed, 23 insertions(+), 7 deletions(-) diff --git a/test/test_keyword.py b/test/test_keyword.py index 3bedfd54..f8ae9511 100644 --- a/test/test_keyword.py +++ b/test/test_keyword.py @@ -17,7 +17,7 @@ def getgooglesearchurl(result, numresults=1): urls.append(i["link"]) return urls -chainllm = ChatOpenAI(temperature=0.5, openai_api_base=os.environ.get('API_URL', None).split("chat")[0], model_name="gpt-4-1106-preview", openai_api_key=os.environ.get('API', None)) +chainllm = ChatOpenAI(temperature=0.5, openai_api_base=os.environ.get('API_URL', None).split("chat")[0], model_name="gpt-3.5-turbo-1106", openai_api_key=os.environ.get('API', None)) # keyword_prompt = PromptTemplate( # input_variables=["source"], # # template="*{source}*, ——我想通过网页搜索引擎,获取上述问题的可能答案。请你提取上述问题相关的关键词作为搜索用词(用空格隔开),直接给我结果(不要多余符号)。", @@ -29,7 +29,7 @@ def getgooglesearchurl(result, numresults=1): keyword_prompt = PromptTemplate( input_variables=["source"], template=( - "根据我的问题,总结最少的关键词概括,给出三行不同的关键词组合,每行的关键词用空格连接,至少有一行关键词里面有中文,至少有一行关键词里面有英文,不要出现其他符号。" + "根据我的问题,总结最少的关键词概括,给出三行不同的关键词组合,每行的关键词用空格连接,至少有一行关键词里面有中文,至少有一行关键词里面有英文。只要直接给出这三行关键词,不需要其他任何解释,不要出现其他符号。" "下面是示例:" "问题1:How much does the 'zeabur' software service cost per month? Is it free to use? Any limitations?" "三行关键词是:" @@ -41,11 +41,17 @@ def getgooglesearchurl(result, numresults=1): "pplx API demo" "pplx API" "pplx API 使用方法" + "问题3:以色列哈马斯的最新情况" + "三行关键词是:" + "以色列 哈马斯 最新情况" + "Israel Hamas situation" + "哈马斯 以色列 冲突" "这是我的问题:{source}" ), ) key_chain = LLMChain(llm=chainllm, prompt=keyword_prompt) -result = key_chain.run("今天的微博热搜有哪些?").split('\n') +result = key_chain.run("以色列哈马斯的最新情况").split('\n') +# result = key_chain.run("今天的微博热搜有哪些?").split('\n') # result = key_chain.run("鸿蒙是安卓套壳吗?") # result = key_chain.run("How much does the 'zeabur' software service cost per month? Is it free to use? Any limitations?") print(result) diff --git a/utils/agent.py b/utils/agent.py index 7823d5e1..d26056ff 100644 --- a/utils/agent.py +++ b/utils/agent.py @@ -324,7 +324,7 @@ def get_search_url(prompt, chainllm): keyword_prompt = PromptTemplate( input_variables=["source"], template=( - "根据我的问题,总结最少的关键词概括,给出三行不同的关键词组合,每行的关键词用空格连接,至少有一行关键词里面有中文,至少有一行关键词里面有英文,不要出现其他符号。" + "根据我的问题,总结最少的关键词概括,给出三行不同的关键词组合,每行的关键词用空格连接,至少有一行关键词里面有中文,至少有一行关键词里面有英文。只要直接给出这三行关键词,不需要其他任何解释,不要出现其他符号。" "下面是示例:" "问题1:How much does the 'zeabur' software service cost per month? Is it free to use? Any limitations?" "三行关键词是:" @@ -336,14 +336,21 @@ def get_search_url(prompt, chainllm): "pplx API demo" "pplx API" "pplx API 使用方法" + "问题3:以色列哈马斯的最新情况" + "三行关键词是:" + "以色列 哈马斯 最新情况" + "Israel Hamas situation" + "哈马斯 以色列 冲突" "这是我的问题:{source}" ), ) key_chain = LLMChain(llm=chainllm, prompt=keyword_prompt) keyword_google_search_thread = ThreadWithReturnValue(target=key_chain.run, args=({"source": prompt},)) keyword_google_search_thread.start() - keywords = keyword_google_search_thread.join().split('\n') + keywords = keyword_google_search_thread.join().split('\n')[-3:] print("keywords", keywords) + keywords = [item.replace("三行关键词是:", "") for item in keywords if "\\x" not in item] + print("select keywords", keywords) search_threads = [] urls_set = [] diff --git a/utils/chatgpt2api.py b/utils/chatgpt2api.py index a4236ff5..ce8c6cd6 100644 --- a/utils/chatgpt2api.py +++ b/utils/chatgpt2api.py @@ -276,7 +276,7 @@ def __init__( if "gpt-4-32k" in engine else 7000 if "gpt-4" in engine - else 4000 + else 4096 if "gpt-3.5-turbo-1106" in engine else 15000 if "gpt-3.5-turbo-16k" in engine @@ -401,7 +401,7 @@ def get_max_tokens(self, convo_id: str) -> int: """ Get max tokens """ - # print(self.max_tokens, self.get_token_count(convo_id)) + # print("self.max_tokens, self.get_token_count(convo_id)", self.max_tokens, self.get_token_count(convo_id)) return self.max_tokens - self.get_token_count(convo_id) def ask_stream( @@ -429,6 +429,8 @@ def ask_stream( if self.engine == "gpt-4-1106-preview": model_max_tokens = kwargs.get("max_tokens", self.max_tokens) + elif self.engine == "gpt-3.5-turbo-1106": + model_max_tokens = min(kwargs.get("max_tokens", self.max_tokens), self.truncate_limit - self.get_token_count(convo_id)) else: model_max_tokens = min(self.get_max_tokens(convo_id=convo_id) - 500, kwargs.get("max_tokens", self.max_tokens)) json_post = { @@ -505,6 +507,7 @@ def ask_stream( response_role = "function" if function_call_name == "get_search_results": prompt = json.loads(full_response)["prompt"] + # print(self.truncate_limit, self.get_token_count(convo_id), max_context_tokens) function_response = eval(function_call_name)(prompt, max_context_tokens) function_response = "web search results: \n" + function_response yield from self.ask_stream(function_response, response_role, convo_id=convo_id, function_name=function_call_name)