Skip to content

Commit

Permalink
Fix the vulnerability in gpt search after token truncation and add se…
Browse files Browse the repository at this point in the history
…arch engine switching functionality
  • Loading branch information
yym68686 committed Sep 25, 2023
1 parent 17319d5 commit 3f961c6
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 77 deletions.
4 changes: 2 additions & 2 deletions agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,7 @@ def search_summary(result, model=config.DEFAULT_SEARCH_MODEL, temperature=config

encoding = tiktoken.encoding_for_model(model)
encode_text = encoding.encode(useful_source_text)
encode_fact_text = encoding.encode(fact_text)

max_token_len = (
30500
Expand All @@ -374,8 +375,7 @@ def search_summary(result, model=config.DEFAULT_SEARCH_MODEL, temperature=config
else 3500
)
if len(encode_text) > max_token_len:
encode_text = encode_text[:max_token_len]
# encode_text = encode_text[:3842]
encode_text = encode_text[:max_token_len-len(encode_fact_text)]
useful_source_text = encoding.decode(encode_text)
encode_text = encoding.encode(useful_source_text)
tokens_len = len(encode_text)
Expand Down
120 changes: 47 additions & 73 deletions bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,7 @@ async def command_bot(update, context, language=None, prompt=translator_prompt,
message = prompt + message
if config.API and message:
await context.bot.send_chat_action(chat_id=update.message.chat_id, action=ChatAction.TYPING)
if config.SEARCH_USE_GPT and "gpt-4" not in title and language == None:
await search(update, context, has_command=False)
else:
await getChatGPT(title, robot, message, update, context)
await getChatGPT(update, context, title, robot, message, config.SEARCH_USE_GPT)
else:
message = await context.bot.send_message(
chat_id=update.message.chat_id,
Expand All @@ -62,10 +59,10 @@ async def command_bot(update, context, language=None, prompt=translator_prompt,
print("\033[32m", update.effective_user.username, update.effective_user.id, update.message.text, "\033[0m")
await context.bot.send_chat_action(chat_id=update.message.chat_id, action=ChatAction.TYPING)
pdf_file = update.message.reply_to_message.document
# print(pdf_file)

file_id = pdf_file.file_id
new_file = await context.bot.get_file(file_id)
# print(new_file)

file_url = new_file.file_path

question = update.message.text
Expand All @@ -79,14 +76,12 @@ async def command_bot(update, context, language=None, prompt=translator_prompt,
async def reset_chat(update, context):
if config.API:
config.ChatGPTbot.reset(convo_id=str(update.message.chat_id), system_prompt=config.systemprompt)
# if config.API4:
# config.ChatGPT4bot.reset(convo_id=str(update.message.chat_id), system_prompt=config.systemprompt)
await context.bot.send_message(
chat_id=update.message.chat_id,
text="重置成功!",
)

async def getChatGPT(title, robot, message, update, context):
async def getChatGPT(update, context, title, robot, message, use_search=config.SEARCH_USE_GPT):
result = title
text = message
modifytime = 0
Expand All @@ -99,19 +94,34 @@ async def getChatGPT(title, robot, message, update, context):
)
messageid = message.message_id
try:
for data in robot.ask_stream(text, convo_id=str(update.message.chat_id), pass_history=config.PASS_HISTORY):
result = result + data
tmpresult = result
modifytime = modifytime + 1
if re.sub(r"```", '', result).count("`") % 2 != 0:
tmpresult = result + "`"
if result.count("```") % 2 != 0:
tmpresult = result + "\n```"
if modifytime % 20 == 0 and lastresult != tmpresult:
if 'claude2' in title:
tmpresult = re.sub(r",", ',', tmpresult)
await context.bot.edit_message_text(chat_id=update.message.chat_id, message_id=messageid, text=escape(tmpresult), parse_mode='MarkdownV2')
lastresult = tmpresult
if use_search:
for data in search_summary(text, model=config.DEFAULT_SEARCH_MODEL, use_goolge=config.USE_GOOGLE, use_gpt=config.SEARCH_USE_GPT):
result = result + data
tmpresult = result
modifytime = modifytime + 1
if re.sub(r"```", '', result).count("`") % 2 != 0:
tmpresult = result + "`"
if result.count("```") % 2 != 0:
tmpresult = result + "\n```"
if modifytime % 20 == 0 and lastresult != tmpresult:
if 'claude2' in title:
tmpresult = re.sub(r",", ',', tmpresult)
await context.bot.edit_message_text(chat_id=update.message.chat_id, message_id=messageid, text=escape(tmpresult), parse_mode='MarkdownV2')
lastresult = tmpresult
else:
for data in robot.ask_stream(text, convo_id=str(update.message.chat_id), pass_history=config.PASS_HISTORY):
result = result + data
tmpresult = result
modifytime = modifytime + 1
if re.sub(r"```", '', result).count("`") % 2 != 0:
tmpresult = result + "`"
if result.count("```") % 2 != 0:
tmpresult = result + "\n```"
if modifytime % 20 == 0 and lastresult != tmpresult:
if 'claude2' in title:
tmpresult = re.sub(r",", ',', tmpresult)
await context.bot.edit_message_text(chat_id=update.message.chat_id, message_id=messageid, text=escape(tmpresult), parse_mode='MarkdownV2')
lastresult = tmpresult
except Exception as e:
print('\033[31m')
print("response_msg", result)
Expand Down Expand Up @@ -175,7 +185,8 @@ async def delete_message(update, context, messageid, delay=10):

first_buttons = [
[
InlineKeyboardButton("更换模型", callback_data="更换模型"),
InlineKeyboardButton("更换问答模型", callback_data="更换问答模型"),
InlineKeyboardButton("更换搜索模型", callback_data="更换搜索模型"),
],
[
InlineKeyboardButton("历史记录已关闭", callback_data="历史记录"),
Expand Down Expand Up @@ -206,7 +217,10 @@ async def button_press(update, context):
await callback_query.answer()
data = callback_query.data
if ("gpt" or "cluade") in data:
config.GPT_ENGINE = data
if config.ENGINE_FLAG:
config.GPT_ENGINE = data
else:
config.DEFAULT_SEARCH_MODEL = data
if config.API:
config.ChatGPTbot = GPT(api_key=f"{config.API}", engine=config.GPT_ENGINE, system_prompt=config.systemprompt, temperature=config.temperature)
config.ChatGPTbot.reset(convo_id=str(update.effective_chat.id), system_prompt=config.systemprompt)
Expand All @@ -215,10 +229,7 @@ async def button_press(update, context):
f"`Hi, {update.effective_user.username}!`\n\n"
f"**Default engine:** `{config.GPT_ENGINE}`\n"
f"**Default search model:** `{config.DEFAULT_SEARCH_MODEL}`\n"

f"**temperature:** `{config.temperature}`\n"
f"**PASS_HISTORY:** `{config.PASS_HISTORY}`\n"
f"**USE_GOOGLE:** `{config.USE_GOOGLE}`\n\n"
f"**API_URL:** `{config.API_URL}`\n\n"
f"**API:** `{config.API}`\n\n"
f"**WEB_HOOK:** `{config.WEB_HOOK}`\n\n"
Expand All @@ -228,18 +239,23 @@ async def button_press(update, context):
reply_markup=InlineKeyboardMarkup(buttons),
parse_mode='MarkdownV2'
)
# messageid = message.message_id
# thread = threading.Thread(target=run_async, args=(delete_message(update, context, messageid, delay=10),))
# thread.start()
except Exception as e:
logger.info(e)
pass
elif "更换模型" in data:
elif "更换问答模型" in data:
message = await callback_query.edit_message_text(
text=escape(info_message + banner),
reply_markup=InlineKeyboardMarkup(buttons),
parse_mode='MarkdownV2'
)
config.ENGINE_FLAG = True
elif "更换搜索模型" in data:
message = await callback_query.edit_message_text(
text=escape(info_message + banner),
reply_markup=InlineKeyboardMarkup(buttons),
parse_mode='MarkdownV2'
)
config.ENGINE_FLAG = False
elif "返回" in data:
message = await callback_query.edit_message_text(
text=escape(info_message),
Expand Down Expand Up @@ -349,48 +365,6 @@ async def info(update, context):
messageid = message.message_id
await context.bot.delete_message(chat_id=update.effective_chat.id, message_id=update.message.message_id)


async def search(update, context, has_command=True):
if has_command == False or len(context.args) > 0:
message = update.message.text if config.NICK is None else update.message.text[botNicKLength:].strip() if update.message.text[:botNicKLength].lower() == botNick else None
if has_command:
message = ' '.join(context.args)
print("\033[32m", update.effective_user.username, update.effective_user.id, update.message.text, "\033[0m")
if message:
await context.bot.send_chat_action(chat_id=update.message.chat_id, action=ChatAction.TYPING)
text = message
result = ''
modifytime = 0
lastresult = ''
message = await context.bot.send_message(
chat_id=update.message.chat_id,
text="思考中💭",
parse_mode='MarkdownV2',
reply_to_message_id=update.message.message_id,
)
messageid = message.message_id
for data in search_summary(text, model=config.DEFAULT_SEARCH_MODEL, use_goolge=config.USE_GOOGLE, use_gpt=config.SEARCH_USE_GPT):
result = result + data
tmpresult = result
modifytime = modifytime + 1
if re.sub(r"```", '', result).count("`") % 2 != 0:
tmpresult = result + "`"
if result.count("```") % 2 != 0:
tmpresult = result + "\n```"
if modifytime % 20 == 0 and lastresult != tmpresult:
await context.bot.edit_message_text(chat_id=update.message.chat_id, message_id=messageid, text=escape(tmpresult), parse_mode='MarkdownV2')
lastresult = tmpresult
print(result)
if lastresult != result and messageid:
await context.bot.edit_message_text(chat_id=update.message.chat_id, message_id=messageid, text=escape(result), parse_mode='MarkdownV2')
else:
message = await context.bot.send_message(
chat_id=update.message.chat_id,
text="请在命令后面放入文本。",
parse_mode='MarkdownV2',
reply_to_message_id=update.message.message_id,
)

from agent import pdfQA, getmd5, persist_emdedding_pdf
async def handle_pdf(update, context):
# 获取接收到的文件
Expand Down
4 changes: 2 additions & 2 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@
PORT = int(os.environ.get('PORT', '8080'))
NICK = os.environ.get('NICK', None)
API = os.environ.get('API', None)
# API4 = os.environ.get('API4', None)
PASS_HISTORY = (os.environ.get('PASS_HISTORY', "False") == "False") == False
USE_GOOGLE = (os.environ.get('USE_GOOGLE', "True") == "False") == False
if os.environ.get('GOOGLE_API_KEY', None) == None and os.environ.get('GOOGLE_CSE_ID', None) == None:
USE_GOOGLE = False
temperature = float(os.environ.get('temperature', '0.5'))
ENGINE_FLAG = True
GPT_ENGINE = os.environ.get('GPT_ENGINE', 'gpt-3.5-turbo')
DEFAULT_SEARCH_MODEL = os.environ.get('DEFAULT_SEARCH_MODEL', 'gpt-3.5-turbo-16k')
SEARCH_USE_GPT = (os.environ.get('SEARCH_USE_GPT', "True") == "False") == False
GPT_ENGINE = os.environ.get('GPT_ENGINE', 'gpt-3.5-turbo')
API_URL = os.environ.get('API_URL', 'https://api.openai.com/v1/chat/completions')
PDF_EMBEDDING = (os.environ.get('PDF_EMBEDDING', "True") == "False") == False

Expand Down

0 comments on commit 3f961c6

Please sign in to comment.