Skip to content

Commit

Permalink
1. Fixed bug: gpt-3.5-turbo-1106 maximum tokens.
Browse files Browse the repository at this point in the history
2. Optimize keyword extraction prompt
  • Loading branch information
yym68686 committed Dec 2, 2023
1 parent 63ea99e commit 4f74fa7
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 7 deletions.
12 changes: 9 additions & 3 deletions test/test_keyword.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def getgooglesearchurl(result, numresults=1):
urls.append(i["link"])
return urls

chainllm = ChatOpenAI(temperature=0.5, openai_api_base=os.environ.get('API_URL', None).split("chat")[0], model_name="gpt-4-1106-preview", openai_api_key=os.environ.get('API', None))
chainllm = ChatOpenAI(temperature=0.5, openai_api_base=os.environ.get('API_URL', None).split("chat")[0], model_name="gpt-3.5-turbo-1106", openai_api_key=os.environ.get('API', None))
# keyword_prompt = PromptTemplate(
# input_variables=["source"],
# # template="*{source}*, ——我想通过网页搜索引擎,获取上述问题的可能答案。请你提取上述问题相关的关键词作为搜索用词(用空格隔开),直接给我结果(不要多余符号)。",
Expand All @@ -29,7 +29,7 @@ def getgooglesearchurl(result, numresults=1):
keyword_prompt = PromptTemplate(
input_variables=["source"],
template=(
"根据我的问题,总结最少的关键词概括,给出三行不同的关键词组合,每行的关键词用空格连接,至少有一行关键词里面有中文,至少有一行关键词里面有英文,不要出现其他符号。"
"根据我的问题,总结最少的关键词概括,给出三行不同的关键词组合,每行的关键词用空格连接,至少有一行关键词里面有中文,至少有一行关键词里面有英文。只要直接给出这三行关键词,不需要其他任何解释,不要出现其他符号。"
"下面是示例:"
"问题1:How much does the 'zeabur' software service cost per month? Is it free to use? Any limitations?"
"三行关键词是:"
Expand All @@ -41,11 +41,17 @@ def getgooglesearchurl(result, numresults=1):
"pplx API demo"
"pplx API"
"pplx API 使用方法"
"问题3:以色列哈马斯的最新情况"
"三行关键词是:"
"以色列 哈马斯 最新情况"
"Israel Hamas situation"
"哈马斯 以色列 冲突"
"这是我的问题:{source}"
),
)
key_chain = LLMChain(llm=chainllm, prompt=keyword_prompt)
result = key_chain.run("今天的微博热搜有哪些?").split('\n')
result = key_chain.run("以色列哈马斯的最新情况").split('\n')
# result = key_chain.run("今天的微博热搜有哪些?").split('\n')
# result = key_chain.run("鸿蒙是安卓套壳吗?")
# result = key_chain.run("How much does the 'zeabur' software service cost per month? Is it free to use? Any limitations?")
print(result)
Expand Down
11 changes: 9 additions & 2 deletions utils/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ def get_search_url(prompt, chainllm):
keyword_prompt = PromptTemplate(
input_variables=["source"],
template=(
"根据我的问题,总结最少的关键词概括,给出三行不同的关键词组合,每行的关键词用空格连接,至少有一行关键词里面有中文,至少有一行关键词里面有英文,不要出现其他符号。"
"根据我的问题,总结最少的关键词概括,给出三行不同的关键词组合,每行的关键词用空格连接,至少有一行关键词里面有中文,至少有一行关键词里面有英文。只要直接给出这三行关键词,不需要其他任何解释,不要出现其他符号。"
"下面是示例:"
"问题1:How much does the 'zeabur' software service cost per month? Is it free to use? Any limitations?"
"三行关键词是:"
Expand All @@ -336,14 +336,21 @@ def get_search_url(prompt, chainllm):
"pplx API demo"
"pplx API"
"pplx API 使用方法"
"问题3:以色列哈马斯的最新情况"
"三行关键词是:"
"以色列 哈马斯 最新情况"
"Israel Hamas situation"
"哈马斯 以色列 冲突"
"这是我的问题:{source}"
),
)
key_chain = LLMChain(llm=chainllm, prompt=keyword_prompt)
keyword_google_search_thread = ThreadWithReturnValue(target=key_chain.run, args=({"source": prompt},))
keyword_google_search_thread.start()
keywords = keyword_google_search_thread.join().split('\n')
keywords = keyword_google_search_thread.join().split('\n')[-3:]
print("keywords", keywords)
keywords = [item.replace("三行关键词是:", "") for item in keywords if "\\x" not in item]
print("select keywords", keywords)

search_threads = []
urls_set = []
Expand Down
7 changes: 5 additions & 2 deletions utils/chatgpt2api.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ def __init__(
if "gpt-4-32k" in engine
else 7000
if "gpt-4" in engine
else 4000
else 4096
if "gpt-3.5-turbo-1106" in engine
else 15000
if "gpt-3.5-turbo-16k" in engine
Expand Down Expand Up @@ -401,7 +401,7 @@ def get_max_tokens(self, convo_id: str) -> int:
"""
Get max tokens
"""
# print(self.max_tokens, self.get_token_count(convo_id))
# print("self.max_tokens, self.get_token_count(convo_id)", self.max_tokens, self.get_token_count(convo_id))
return self.max_tokens - self.get_token_count(convo_id)

def ask_stream(
Expand Down Expand Up @@ -429,6 +429,8 @@ def ask_stream(

if self.engine == "gpt-4-1106-preview":
model_max_tokens = kwargs.get("max_tokens", self.max_tokens)
elif self.engine == "gpt-3.5-turbo-1106":
model_max_tokens = min(kwargs.get("max_tokens", self.max_tokens), self.truncate_limit - self.get_token_count(convo_id))
else:
model_max_tokens = min(self.get_max_tokens(convo_id=convo_id) - 500, kwargs.get("max_tokens", self.max_tokens))
json_post = {
Expand Down Expand Up @@ -505,6 +507,7 @@ def ask_stream(
response_role = "function"
if function_call_name == "get_search_results":
prompt = json.loads(full_response)["prompt"]
# print(self.truncate_limit, self.get_token_count(convo_id), max_context_tokens)
function_response = eval(function_call_name)(prompt, max_context_tokens)
function_response = "web search results: \n" + function_response
yield from self.ask_stream(function_response, response_role, convo_id=convo_id, function_name=function_call_name)
Expand Down

0 comments on commit 4f74fa7

Please sign in to comment.