Skip to content

Commit

Permalink
1. 通过滑动窗口发送不同的内容给机器人,避免 function call 循环时重复发送相同的网页内容给 API,加快网页搜索速度
Browse files Browse the repository at this point in the history
2. 增加网页总结提示词,优化网页搜索效果

3. 删除网页pdf总结

4. 重构搜索代码
  • Loading branch information
yym68686 committed Dec 14, 2023
1 parent 990362f commit b62e628
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 76 deletions.
42 changes: 21 additions & 21 deletions bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ async def delete_message(update, context, messageid, delay=10):
],
[
InlineKeyboardButton("搜索已打开", callback_data="搜索"),
InlineKeyboardButton("联网解析PDF已打开", callback_data="pdf"),
# InlineKeyboardButton("联网解析PDF已打开", callback_data="pdf"),
],
[
InlineKeyboardButton("🇨🇳 中文", callback_data="language"),
Expand Down Expand Up @@ -372,26 +372,26 @@ async def button_press(update, context):
reply_markup=InlineKeyboardMarkup(first_buttons),
parse_mode='MarkdownV2'
)
elif "pdf" in data:
config.PDF_EMBEDDING = not config.PDF_EMBEDDING
if config.PDF_EMBEDDING == False:
first_buttons[2][1] = InlineKeyboardButton("联网解析PDF已关闭", callback_data="pdf")
else:
first_buttons[2][1] = InlineKeyboardButton("联网解析PDF已打开", callback_data="pdf")

info_message = (
f"`Hi, {update.effective_user.username}!`\n\n"
f"**Default engine:** `{config.GPT_ENGINE}`\n"
f"**temperature:** `{config.temperature}`\n"
f"**API_URL:** `{config.API_URL}`\n\n"
f"**API:** `{replace_with_asterisk(config.API)}`\n\n"
f"**WEB_HOOK:** `{config.WEB_HOOK}`\n\n"
)
message = await callback_query.edit_message_text(
text=escape(info_message),
reply_markup=InlineKeyboardMarkup(first_buttons),
parse_mode='MarkdownV2'
)
# elif "pdf" in data:
# config.PDF_EMBEDDING = not config.PDF_EMBEDDING
# if config.PDF_EMBEDDING == False:
# first_buttons[2][1] = InlineKeyboardButton("联网解析PDF已关闭", callback_data="pdf")
# else:
# first_buttons[2][1] = InlineKeyboardButton("联网解析PDF已打开", callback_data="pdf")

# info_message = (
# f"`Hi, {update.effective_user.username}!`\n\n"
# f"**Default engine:** `{config.GPT_ENGINE}`\n"
# f"**temperature:** `{config.temperature}`\n"
# f"**API_URL:** `{config.API_URL}`\n\n"
# f"**API:** `{replace_with_asterisk(config.API)}`\n\n"
# f"**WEB_HOOK:** `{config.WEB_HOOK}`\n\n"
# )
# message = await callback_query.edit_message_text(
# text=escape(info_message),
# reply_markup=InlineKeyboardMarkup(first_buttons),
# parse_mode='MarkdownV2'
# )
elif "language" in data:
if config.LANGUAGE == "Simplified Chinese":
first_buttons[3][0] = InlineKeyboardButton("🇺🇸 English", callback_data="language")
Expand Down
2 changes: 1 addition & 1 deletion config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# DEFAULT_SEARCH_MODEL = os.environ.get('DEFAULT_SEARCH_MODEL', 'gpt-3.5-turbo-1106') gpt-3.5-turbo-16k
# Enable GPT-backed web search unless the env var is explicitly "False".
# (Any other value, including unset, enables it.)
SEARCH_USE_GPT = os.environ.get('SEARCH_USE_GPT', "True") != "False"
API_URL = os.environ.get('API_URL', 'https://api.openai.com/v1/chat/completions')
# PDF_EMBEDDING = (os.environ.get('PDF_EMBEDDING', "True") == "False") == False
LANGUAGE = os.environ.get('LANGUAGE', 'Simplified Chinese')


Expand Down
81 changes: 44 additions & 37 deletions utils/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ def Web_crawler(url: str, isSearch=False) -> str:
print("error url", url)
print("error", e)
print('\033[0m')
print("url content", result + "\n\n")
# print("url content", result + "\n\n")
return result

def getddgsearchurl(result, numresults=3):
Expand Down Expand Up @@ -386,14 +386,20 @@ def get_search_url(prompt, chainllm):

search_threads = []
urls_set = []
if len(keywords) == 3:
search_url_num = 8
if len(keywords) == 2:
search_url_num = 12
if len(keywords) == 1:
search_url_num = 24
if config.USE_GOOGLE:
search_thread = ThreadWithReturnValue(target=getgooglesearchurl, args=(keywords[0],4,))
search_thread = ThreadWithReturnValue(target=getgooglesearchurl, args=(keywords[0],search_url_num,))
search_thread.start()
search_threads.append(search_thread)
keywords = keywords[1:]
keywords = keywords.pop(0)

for keyword in keywords:
search_thread = ThreadWithReturnValue(target=getddgsearchurl, args=(keyword,4,))
search_thread = ThreadWithReturnValue(target=getddgsearchurl, args=(keyword,search_url_num,))
search_thread.start()
search_threads.append(search_thread)

Expand All @@ -406,10 +412,11 @@ def get_search_url(prompt, chainllm):
return url_set_list, url_pdf_set_list

def concat_url(threads):
    """Join every crawler thread and collect its result.

    Each element of *threads* is a ThreadWithReturnValue whose ``join()``
    returns the crawled page text (possibly an empty string on failure).

    Returns a list of the non-empty texts, in thread order.  Empty results
    are dropped so downstream joining/token-counting sees only real content.
    """
    url_result = []
    for t in threads:
        tmp = t.join()
        # Skip pages that produced no text (crawl errors return "").
        if tmp:
            url_result.append(tmp)
    return url_result

def summary_each_url(threads, chainllm):
Expand Down Expand Up @@ -440,57 +447,57 @@ def summary_each_url(threads, chainllm):
url_result += "\n\n" + tmp
return url_result

def get_url_text_list(prompt):
    """Search the web for *prompt* and return the crawled page texts as a list.

    Uses the configured LLM to derive search keywords, gathers result URLs,
    crawls them concurrently, and returns the non-empty page texts.
    Timing and the resolved URL lists are printed for debugging.
    """
    start_time = record_time.time()

    # Pick the LLM used to turn the prompt into search keywords.
    if config.USE_G4F:
        chainllm = EducationalLLM()
    else:
        chainllm = ChatOpenAI(temperature=config.temperature, openai_api_base=config.bot_api_url.v1_url, model_name=config.GPT_ENGINE, openai_api_key=config.API)

    url_set_list, url_pdf_set_list = get_search_url(prompt, chainllm)

    # Crawl every result URL concurrently; isSearch=True trims page noise.
    threads = []
    for url in url_set_list:
        url_search_thread = ThreadWithReturnValue(target=Web_crawler, args=(url,True,))
        url_search_thread.start()
        threads.append(url_search_thread)

    url_text_list = concat_url(threads)
    # url_text_list = summary_each_url(threads, chainllm)

    end_time = record_time.time()
    run_time = end_time - start_time
    print("urls", url_set_list)
    print("pdf", url_pdf_set_list)
    print(f"搜索用时:{run_time}秒")

    return url_text_list

def get_text_token_len(text):
    """Return the number of tokens in *text* for the configured engine."""
    try:
        encoding = tiktoken.encoding_for_model(config.GPT_ENGINE)
    except KeyError:
        # Unknown model name: fall back to the generic cl100k_base tokenizer.
        # (The original called get_encoding("cl100k_base") and discarded the
        # result — a dead statement; this makes the intended fallback real.)
        encoding = tiktoken.get_encoding("cl100k_base")
    return len(encoding.encode(text))

def cut_message(message: str, max_tokens: int):
    """Truncate *message* so it fits within *max_tokens* tokens.

    Returns a ``(message, token_count)`` tuple where ``token_count`` is the
    token length of the (possibly truncated) message.
    """
    try:
        encoding = tiktoken.encoding_for_model(config.GPT_ENGINE)
    except KeyError:
        # Unknown model name: fall back to the generic cl100k_base tokenizer.
        # (The original called get_encoding("cl100k_base") and discarded the
        # result — a dead statement; this makes the intended fallback real.)
        encoding = tiktoken.get_encoding("cl100k_base")
    encode_text = encoding.encode(message)
    if len(encode_text) > max_tokens:
        encode_text = encode_text[:max_tokens]
        message = encoding.decode(encode_text)
        # Re-encode: decode/encode does not always round-trip token-for-token,
        # so recompute the real length of the truncated text.
        encode_text = encoding.encode(message)
    return message, len(encode_text)

def get_search_results(prompt: str, context_max_tokens: int):
    """Run a web search for *prompt* and return the combined page texts,
    truncated to fit within *context_max_tokens* tokens."""
    page_texts = get_url_text_list(prompt)
    useful_source_text = "\n\n".join(page_texts)

    useful_source_text, search_tokens_len = cut_message(useful_source_text, context_max_tokens)
    print("search tokens len", search_tokens_len, "\n\n")

    return useful_source_text

if __name__ == "__main__":
Expand Down
39 changes: 22 additions & 17 deletions utils/chatgpt2api.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from typing import Set

import config
from utils.agent import Web_crawler, get_search_results
from utils.agent import Web_crawler, get_search_results, cut_message, get_url_text_list, get_text_token_len
from utils.function_call import function_call_list

def get_filtered_keys_from_object(obj: object, *keys: str) -> Set[str]:
Expand Down Expand Up @@ -338,6 +338,7 @@ def __init__(
}
self.function_calls_counter = {}
self.function_call_max_loop = 3
self.encode_web_text_list = []

if self.get_token_count("default") > self.max_tokens:
raise t.ActionRefuseError("System prompt is too long")
Expand Down Expand Up @@ -478,15 +479,6 @@ def get_message_token(self, url, json_post):
else:
raise Exception("Unknown error")

def cut_message(self, message: str, max_tokens: int):
    """Truncate *message* to at most *max_tokens* tokens for ``self.engine``."""
    try:
        encoding = tiktoken.encoding_for_model(self.engine)
    except KeyError:
        # Unknown model name: fall back to the generic cl100k_base tokenizer.
        # (The original called get_encoding("cl100k_base") and discarded the
        # result — a dead statement; this makes the intended fallback real.)
        encoding = tiktoken.get_encoding("cl100k_base")
    encode_text = encoding.encode(message)
    if len(encode_text) > max_tokens:
        message = encoding.decode(encode_text[:max_tokens])
    return message

def get_post_body(
self,
prompt: str,
Expand Down Expand Up @@ -606,7 +598,7 @@ def ask_stream(
else:
self.function_calls_counter[function_call_name] += 1
if self.function_calls_counter[function_call_name] <= self.function_call_max_loop:
function_call_max_tokens = self.truncate_limit - message_token["total"] - 100
function_call_max_tokens = self.truncate_limit - message_token["total"] - 1000
if function_call_max_tokens <= 0:
function_call_max_tokens = int(self.truncate_limit / 2)
print("function_call_max_tokens", function_call_max_tokens)
Expand All @@ -617,18 +609,30 @@ def ask_stream(
if self.conversation[convo_id][-1 - index]["role"] == "user":
self.conversation[convo_id][-1 - index]["content"] = self.conversation[convo_id][-1 - index]["content"].replace("search: ", "")
prompt = self.conversation[convo_id][-1 - index]["content"]
prompt = " ".join([prompt, json.loads(full_response)["prompt"]])
prompt = " ".join([prompt, json.loads(full_response)["prompt"].strip()]).strip()
print("\n\nprompt", prompt)
break
# prompt = self.conversation[convo_id][-1]["content"]
# print(self.truncate_limit, self.get_token_count(convo_id), max_context_tokens)
function_response = eval(function_call_name)(prompt, function_call_max_tokens)
function_response = "web search results: \n" + function_response
if self.encode_web_text_list == []:
tiktoken.get_encoding("cl100k_base")
encoding = tiktoken.encoding_for_model(config.GPT_ENGINE)
self.encode_web_text_list = encoding.encode(" ".join(get_url_text_list(prompt)))
print("search len", len(self.encode_web_text_list))
function_response = encoding.decode(self.encode_web_text_list[:function_call_max_tokens])
self.encode_web_text_list = self.encode_web_text_list[function_call_max_tokens:]
# function_response = eval(function_call_name)(prompt, function_call_max_tokens)
function_response = (
"Here is the Search results, inside <Search_results></Search_results> XML tags:"
"<Search_results>"
"{}"
"</Search_results>"
).format(function_response)
user_prompt = f"You need to response the following question: {prompt}. Search results is provided inside <Search_results></Search_results> XML tags. Your task is to think about the question step by step and then answer the above question in {config.LANGUAGE} based on the Search results provided. Please response in {config.LANGUAGE} and adopt a style that is logical, in-depth, and detailed. Note: In order to make the answer appear highly professional, you should be an expert in textual analysis, aiming to make the answer precise and comprehensive. Directly response markdown format, without using markdown code blocks"
self.add_to_conversation(user_prompt, "user", convo_id=convo_id)
if function_call_name == "get_url_content":
url = json.loads(full_response)["url"]
print("\n\nurl", url)
function_response = Web_crawler(url)
function_response = self.cut_message(function_response, function_call_max_tokens)
function_response, text_len = cut_message(function_response, function_call_max_tokens)
else:
function_response = "抱歉,直接告诉用户,无法找到相关信息"
response_role = "function"
Expand All @@ -637,6 +641,7 @@ def ask_stream(
self.add_to_conversation(full_response, response_role, convo_id=convo_id)
self.function_calls_counter = {}
self.clear_function_call(convo_id=convo_id)
self.encode_web_text_list = []
# total_tokens = self.get_token_count(convo_id)

async def ask_stream_async(
Expand Down

0 comments on commit b62e628

Please sign in to comment.