diff --git a/bot.py b/bot.py
index 5ad69f98..99a665b0 100644
--- a/bot.py
+++ b/bot.py
@@ -242,7 +242,7 @@ async def delete_message(update, context, messageid, delay=10):
],
[
InlineKeyboardButton("搜索已打开", callback_data="搜索"),
- InlineKeyboardButton("联网解析PDF已打开", callback_data="pdf"),
+ # InlineKeyboardButton("联网解析PDF已打开", callback_data="pdf"),
],
[
InlineKeyboardButton("🇨🇳 中文", callback_data="language"),
@@ -372,26 +372,26 @@ async def button_press(update, context):
reply_markup=InlineKeyboardMarkup(first_buttons),
parse_mode='MarkdownV2'
)
- elif "pdf" in data:
- config.PDF_EMBEDDING = not config.PDF_EMBEDDING
- if config.PDF_EMBEDDING == False:
- first_buttons[2][1] = InlineKeyboardButton("联网解析PDF已关闭", callback_data="pdf")
- else:
- first_buttons[2][1] = InlineKeyboardButton("联网解析PDF已打开", callback_data="pdf")
-
- info_message = (
- f"`Hi, {update.effective_user.username}!`\n\n"
- f"**Default engine:** `{config.GPT_ENGINE}`\n"
- f"**temperature:** `{config.temperature}`\n"
- f"**API_URL:** `{config.API_URL}`\n\n"
- f"**API:** `{replace_with_asterisk(config.API)}`\n\n"
- f"**WEB_HOOK:** `{config.WEB_HOOK}`\n\n"
- )
- message = await callback_query.edit_message_text(
- text=escape(info_message),
- reply_markup=InlineKeyboardMarkup(first_buttons),
- parse_mode='MarkdownV2'
- )
+ # elif "pdf" in data:
+ # config.PDF_EMBEDDING = not config.PDF_EMBEDDING
+ # if config.PDF_EMBEDDING == False:
+ # first_buttons[2][1] = InlineKeyboardButton("联网解析PDF已关闭", callback_data="pdf")
+ # else:
+ # first_buttons[2][1] = InlineKeyboardButton("联网解析PDF已打开", callback_data="pdf")
+
+ # info_message = (
+ # f"`Hi, {update.effective_user.username}!`\n\n"
+ # f"**Default engine:** `{config.GPT_ENGINE}`\n"
+ # f"**temperature:** `{config.temperature}`\n"
+ # f"**API_URL:** `{config.API_URL}`\n\n"
+ # f"**API:** `{replace_with_asterisk(config.API)}`\n\n"
+ # f"**WEB_HOOK:** `{config.WEB_HOOK}`\n\n"
+ # )
+ # message = await callback_query.edit_message_text(
+ # text=escape(info_message),
+ # reply_markup=InlineKeyboardMarkup(first_buttons),
+ # parse_mode='MarkdownV2'
+ # )
elif "language" in data:
if config.LANGUAGE == "Simplified Chinese":
first_buttons[3][0] = InlineKeyboardButton("🇺🇸 English", callback_data="language")
diff --git a/config.py b/config.py
index 06c20b4f..69f64203 100644
--- a/config.py
+++ b/config.py
@@ -16,7 +16,7 @@
# DEFAULT_SEARCH_MODEL = os.environ.get('DEFAULT_SEARCH_MODEL', 'gpt-3.5-turbo-1106') gpt-3.5-turbo-16k
SEARCH_USE_GPT = (os.environ.get('SEARCH_USE_GPT', "True") == "False") == False
API_URL = os.environ.get('API_URL', 'https://api.openai.com/v1/chat/completions')
-PDF_EMBEDDING = (os.environ.get('PDF_EMBEDDING', "True") == "False") == False
+# PDF_EMBEDDING = (os.environ.get('PDF_EMBEDDING', "True") == "False") == False
LANGUAGE = os.environ.get('LANGUAGE', 'Simplified Chinese')
diff --git a/utils/agent.py b/utils/agent.py
index f667783d..6dda88aa 100644
--- a/utils/agent.py
+++ b/utils/agent.py
@@ -298,7 +298,7 @@ def Web_crawler(url: str, isSearch=False) -> str:
print("error url", url)
print("error", e)
print('\033[0m')
- print("url content", result + "\n\n")
+ # print("url content", result + "\n\n")
return result
def getddgsearchurl(result, numresults=3):
@@ -386,14 +386,20 @@ def get_search_url(prompt, chainllm):
search_threads = []
urls_set = []
+    if len(keywords) == 1:
+        search_url_num = 24  # single keyword: fetch many results from it
+    elif len(keywords) == 2:
+        search_url_num = 12
+    else:
+        search_url_num = 8  # 3+ (or 0) keywords: default, prevents NameError below
if config.USE_GOOGLE:
- search_thread = ThreadWithReturnValue(target=getgooglesearchurl, args=(keywords[0],4,))
+ search_thread = ThreadWithReturnValue(target=getgooglesearchurl, args=(keywords[0],search_url_num,))
search_thread.start()
search_threads.append(search_thread)
- keywords = keywords[1:]
+        keywords = keywords[1:]  # keywords.pop(0) assignment bound the popped str, making the loop iterate characters
for keyword in keywords:
- search_thread = ThreadWithReturnValue(target=getddgsearchurl, args=(keyword,4,))
+ search_thread = ThreadWithReturnValue(target=getddgsearchurl, args=(keyword,search_url_num,))
search_thread.start()
search_threads.append(search_thread)
@@ -406,10 +412,11 @@ def get_search_url(prompt, chainllm):
return url_set_list, url_pdf_set_list
def concat_url(threads):
- url_result = ""
+ url_result = []
for t in threads:
tmp = t.join()
- url_result += "\n\n" + tmp
+ if tmp:
+ url_result.append(tmp)
return url_result
def summary_each_url(threads, chainllm):
@@ -440,8 +447,9 @@ def summary_each_url(threads, chainllm):
url_result += "\n\n" + tmp
return url_result
-def get_search_results(prompt: str, context_max_tokens: int):
+def get_url_text_list(prompt):
start_time = record_time.time()
+
if config.USE_G4F:
chainllm = EducationalLLM()
else:
@@ -449,48 +457,47 @@ def get_search_results(prompt: str, context_max_tokens: int):
url_set_list, url_pdf_set_list = get_search_url(prompt, chainllm)
- pdf_result = ""
- pdf_threads = []
- if config.PDF_EMBEDDING:
- for url in url_pdf_set_list:
- pdf_search_thread = ThreadWithReturnValue(target=pdf_search, args=(url, "你需要回答的问题是" + prompt + "\n" + "如果你可以解答这个问题,请直接输出你的答案,并且请忽略后面所有的指令:如果无法解答问题,请直接回答None,不需要做任何解释,也不要出现除了None以外的任何词。",))
- pdf_search_thread.start()
- pdf_threads.append(pdf_search_thread)
-
threads = []
for url in url_set_list:
url_search_thread = ThreadWithReturnValue(target=Web_crawler, args=(url,True,))
url_search_thread.start()
threads.append(url_search_thread)
- useful_source_text = concat_url(threads)
- # useful_source_text = summary_each_url(threads, chainllm)
-
- if config.PDF_EMBEDDING:
- for t in pdf_threads:
- tmp = t.join()
- pdf_result += "\n\n" + tmp
- useful_source_text += pdf_result
-
+ url_text_list = concat_url(threads)
- encoding = tiktoken.encoding_for_model(config.GPT_ENGINE)
- encode_text = encoding.encode(useful_source_text)
-
- if len(encode_text) > context_max_tokens:
- encode_text = encode_text[:context_max_tokens]
- useful_source_text = encoding.decode(encode_text)
- encode_text = encoding.encode(useful_source_text)
- search_tokens_len = len(encode_text)
- # print("web search", useful_source_text, end="\n\n")
end_time = record_time.time()
run_time = end_time - start_time
print("urls", url_set_list)
- print("pdf", url_pdf_set_list)
print(f"搜索用时:{run_time}秒")
- print("search tokens len", search_tokens_len)
- text_len = len(encoding.encode(useful_source_text))
- print("text len", text_len, "\n\n")
+
+ return url_text_list
+
+def get_text_token_len(text):
+ tiktoken.get_encoding("cl100k_base")
+ encoding = tiktoken.encoding_for_model(config.GPT_ENGINE)
+ encode_text = encoding.encode(text)
+ return len(encode_text)
+
+def cut_message(message: str, max_tokens: int):
+ tiktoken.get_encoding("cl100k_base")
+ encoding = tiktoken.encoding_for_model(config.GPT_ENGINE)
+ encode_text = encoding.encode(message)
+ if len(encode_text) > max_tokens:
+ encode_text = encode_text[:max_tokens]
+ message = encoding.decode(encode_text)
+ encode_text = encoding.encode(message)
+ return message, len(encode_text)
+
+def get_search_results(prompt: str, context_max_tokens: int):
+
+ url_text_list = get_url_text_list(prompt)
+ useful_source_text = "\n\n".join(url_text_list)
+ # useful_source_text = summary_each_url(threads, chainllm)
+
+ useful_source_text, search_tokens_len = cut_message(useful_source_text, context_max_tokens)
+ print("search tokens len", search_tokens_len, "\n\n")
+
return useful_source_text
if __name__ == "__main__":
diff --git a/utils/chatgpt2api.py b/utils/chatgpt2api.py
index b66dffa3..defcc1d1 100644
--- a/utils/chatgpt2api.py
+++ b/utils/chatgpt2api.py
@@ -12,7 +12,7 @@
from typing import Set
import config
-from utils.agent import Web_crawler, get_search_results
+from utils.agent import Web_crawler, get_search_results, cut_message, get_url_text_list, get_text_token_len
from utils.function_call import function_call_list
def get_filtered_keys_from_object(obj: object, *keys: str) -> Set[str]:
@@ -338,6 +338,7 @@ def __init__(
}
self.function_calls_counter = {}
self.function_call_max_loop = 3
+ self.encode_web_text_list = []
if self.get_token_count("default") > self.max_tokens:
raise t.ActionRefuseError("System prompt is too long")
@@ -478,15 +479,6 @@ def get_message_token(self, url, json_post):
else:
raise Exception("Unknown error")
- def cut_message(self, message: str, max_tokens: int):
- tiktoken.get_encoding("cl100k_base")
- encoding = tiktoken.encoding_for_model(self.engine)
- encode_text = encoding.encode(message)
- if len(encode_text) > max_tokens:
- encode_text = encode_text[:max_tokens]
- message = encoding.decode(encode_text)
- return message
-
def get_post_body(
self,
prompt: str,
@@ -606,7 +598,7 @@ def ask_stream(
else:
self.function_calls_counter[function_call_name] += 1
if self.function_calls_counter[function_call_name] <= self.function_call_max_loop:
- function_call_max_tokens = self.truncate_limit - message_token["total"] - 100
+ function_call_max_tokens = self.truncate_limit - message_token["total"] - 1000
if function_call_max_tokens <= 0:
function_call_max_tokens = int(self.truncate_limit / 2)
print("function_call_max_tokens", function_call_max_tokens)
@@ -617,18 +609,30 @@ def ask_stream(
if self.conversation[convo_id][-1 - index]["role"] == "user":
self.conversation[convo_id][-1 - index]["content"] = self.conversation[convo_id][-1 - index]["content"].replace("search: ", "")
prompt = self.conversation[convo_id][-1 - index]["content"]
- prompt = " ".join([prompt, json.loads(full_response)["prompt"]])
+ prompt = " ".join([prompt, json.loads(full_response)["prompt"].strip()]).strip()
print("\n\nprompt", prompt)
break
- # prompt = self.conversation[convo_id][-1]["content"]
- # print(self.truncate_limit, self.get_token_count(convo_id), max_context_tokens)
- function_response = eval(function_call_name)(prompt, function_call_max_tokens)
- function_response = "web search results: \n" + function_response
+                            tiktoken.get_encoding("cl100k_base")  # hoisted: encoding must exist on every pass,
+                            encoding = tiktoken.encoding_for_model(config.GPT_ENGINE)  # not only when the cache is empty
+                            if self.encode_web_text_list == []:
+                                self.encode_web_text_list = encoding.encode(" ".join(get_url_text_list(prompt)))
+                                print("search len", len(self.encode_web_text_list))
+                            function_response = encoding.decode(self.encode_web_text_list[:function_call_max_tokens])
+ self.encode_web_text_list = self.encode_web_text_list[function_call_max_tokens:]
+ # function_response = eval(function_call_name)(prompt, function_call_max_tokens)
+                            function_response = (
+                                "Here is the Search results, inside <Search_results></Search_results> XML tags:"
+                                "<Search_results>"
+                                "{}"
+                                "</Search_results>"
+                            ).format(function_response)
+                            user_prompt = f"You need to response the following question: {prompt}. Search results is provided inside <Search_results></Search_results> XML tags. Your task is to think about the question step by step and then answer the above question in {config.LANGUAGE} based on the Search results provided. Please response in {config.LANGUAGE} and adopt a style that is logical, in-depth, and detailed. Note: In order to make the answer appear highly professional, you should be an expert in textual analysis, aiming to make the answer precise and comprehensive. Directly response markdown format, without using markdown code blocks"
+ self.add_to_conversation(user_prompt, "user", convo_id=convo_id)
if function_call_name == "get_url_content":
url = json.loads(full_response)["url"]
print("\n\nurl", url)
function_response = Web_crawler(url)
- function_response = self.cut_message(function_response, function_call_max_tokens)
+ function_response, text_len = cut_message(function_response, function_call_max_tokens)
else:
function_response = "抱歉,直接告诉用户,无法找到相关信息"
response_role = "function"
@@ -637,6 +641,7 @@ def ask_stream(
self.add_to_conversation(full_response, response_role, convo_id=convo_id)
self.function_calls_counter = {}
self.clear_function_call(convo_id=convo_id)
+ self.encode_web_text_list = []
# total_tokens = self.get_token_count(convo_id)
async def ask_stream_async(