Skip to content

Commit

Permalink
Merge pull request #35 from gomate-community/pipeline
Browse files Browse the repository at this point in the history
features@rewriter:hyde
  • Loading branch information
yanqiangmiffy authored Jun 26, 2024
2 parents 5468e4f + 1c11e63 commit 8dd06ca
Show file tree
Hide file tree
Showing 7 changed files with 587 additions and 18 deletions.
4 changes: 2 additions & 2 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

# 修改成自己的配置!!!
app_config = ApplicationConfig()
app_config.docs_path = "./docs/"
app_config.docs_path = "/data/users/searchgpt/yq/GoMate_dev/data/docs/"
app_config.llm_model_path = "/data/users/searchgpt/pretrained_models/glm-4-9b-chat"

retriever_config = DenseRetrieverConfig(
Expand Down Expand Up @@ -155,7 +155,7 @@ def predict(input,
value='知识库问答',
interactive=False)

kg_name = gr.Radio(["伊朗新闻"],
kg_name = gr.Radio(["文档知识库"],
label="知识库",
value=None,
info="使用知识库问答,请加载知识库",
Expand Down
258 changes: 258 additions & 0 deletions docs/rewriter.md

Large diffs are not rendered by default.

102 changes: 102 additions & 0 deletions examples/rewriters/hyde_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import os

import pandas as pd
from tqdm import tqdm

from gomate.modules.document.common_parser import CommonParser
from gomate.modules.generator.llm import GLMChat
from gomate.modules.retrieval.dense_retriever import DenseRetriever, DenseRetrieverConfig
from gomate.modules.rewriter.hyde_rewriter import HydeRewriter
from gomate.modules.rewriter.promptor import Promptor


def compare_rewrites(query, hyde, retriever, generator):
    """Compare HyDE-based retrieval against plain dense retrieval for one query.

    Prints the LLM-generated hypothesis document, both retrieval result
    sets, and the answer generated from each retrieved context.
    """
    hypothesis_document = hyde.rewrite(query)
    print("==================hypothesis_document=================\n")
    print(hypothesis_document)
    hyde_result = hyde.retrieve(query)
    print("==================hyde_result=================\n")
    print(hyde_result['retrieve_result'])
    dense_result = retriever.retrieve(query)
    print("==================dense_result=================\n")
    print(dense_result)
    hyde_answer, _ = generator.chat(
        prompt=query,
        content='\n'.join([doc['text'] for doc in hyde_result['retrieve_result']]))
    print("==================hyde_answer=================\n")
    print(hyde_answer)
    dense_answer, _ = generator.chat(
        prompt=query,
        content='\n'.join([doc['text'] for doc in dense_result]))
    print("==================dense_answer=================\n")
    print(dense_answer)


if __name__ == '__main__':
    promptor = Promptor(task="WEB_SEARCH", language="zh")

    retriever_config = DenseRetrieverConfig(
        model_name_or_path="/data/users/searchgpt/pretrained_models/bge-large-zh-v1.5",
        dim=1024,
        index_dir='/data/users/searchgpt/yq/GoMate/examples/retrievers/dense_cache'
    )
    config_info = retriever_config.log_config()
    retriever = DenseRetriever(config=retriever_config)
    parser = CommonParser()

    # Best-effort indexing of the docs directory: unparseable files are
    # skipped, but the failure is reported rather than silently swallowed.
    chunks = []
    docs_path = '/data/users/searchgpt/yq/GoMate_dev/data/docs'
    for filename in os.listdir(docs_path):
        file_path = os.path.join(docs_path, filename)
        try:
            chunks.extend(parser.parse(file_path))
        except Exception as exc:
            print(f"skipping unparseable file {file_path}: {exc}")
    retriever.build_from_texts(chunks)

    # Enrich the index with a small sample of positive/negative passages
    # from the NDJSON QA dataset (first 5 rows only, for a quick demo).
    data = pd.read_json('/data/users/searchgpt/yq/GoMate/data/docs/zh_refine.json', lines=True)[:5]
    for documents in tqdm(data['positive'], total=len(data)):
        for document in documents:
            retriever.add_text(document)
    for documents in tqdm(data['negative'], total=len(data)):
        for document in documents:
            retriever.add_text(document)

    print("init_vector_store done! ")
    generator = GLMChat("/data/users/searchgpt/pretrained_models/glm-4-9b-chat")

    hyde = HydeRewriter(promptor, generator, retriever)

    # Run the same comparison for each demo query instead of copy-pasting
    # the evaluation code three times.
    for query in ["RCEP具体包括哪些国家", "数据集类型有哪些?", "Sklearn可以使用的数据集有哪些?"]:
        compare_rewrites(query, hyde, retriever, generator)
        print("****" * 20)
13 changes: 9 additions & 4 deletions gomate/applications/rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@
"""
import os

from gomate.modules.citation.match_citation import MatchCitation
from gomate.modules.document.common_parser import CommonParser
from gomate.modules.generator.llm import GLMChat
from gomate.modules.reranker.bge_reranker import BgeReranker
from gomate.modules.retrieval.dense_retriever import DenseRetriever
from gomate.modules.citation.match_citation import MatchCitation


class ApplicationConfig():
def __init__(self):
Expand All @@ -28,7 +29,8 @@ def __init__(self, config):
self.retriever = DenseRetriever(self.config.retriever_config)
self.reranker = BgeReranker(self.config.rerank_config)
self.llm = GLMChat(self.config.llm_model_path)
self.mc=MatchCitation()
self.mc = MatchCitation()

def init_vector_store(self):
"""
Expand All @@ -37,7 +39,10 @@ def init_vector_store(self):
chunks = []
for filename in os.listdir(self.config.docs_path):
file_path = os.path.join(self.config.docs_path, filename)
chunks.extend(self.parser.parse(file_path))
try:
chunks.extend(self.parser.parse(file_path))
except:
pass
self.retriever.build_from_texts(chunks)
print("init_vector_store done! ")
self.retriever.save_index(self.config.retriever_config.index_dir)
Expand All @@ -53,7 +58,7 @@ def add_document(self, file_path):

def chat(self, question: str = '', top_k: int = 5):
contents = self.retriever.retrieve(query=question, top_k=top_k)
contents=self.reranker.rerank(query=question,documents=[content['text'] for content in contents])
contents = self.reranker.rerank(query=question, documents=[content['text'] for content in contents])
content = '\n'.join([content['text'] for content in contents])
print(contents)
response, history = self.llm.chat(question, [], content)
Expand Down
22 changes: 13 additions & 9 deletions gomate/modules/generator/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,11 @@
@description: coding..
"""
import os
from typing import Dict, List, Optional, Tuple, Union, Any
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from typing import Dict, List, Any

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from openai import OpenAI
from transformers import AutoTokenizer, AutoModelForCausalLM

PROMPT_TEMPLATE = dict(
RAG_PROMPT_TEMPALTE="""使用以上下文来回答用户的问题。如果你不知道答案,就说你不知道。总是使用中文回答。
Expand Down Expand Up @@ -92,18 +91,22 @@ def chat(self, prompt: str, history: List = [], content: str = '') -> str:
return response

def load_model(self):

self.tokenizer = AutoTokenizer.from_pretrained(self.path, trust_remote_code=True)
self.model = AutoModelForCausalLM.from_pretrained(self.path, torch_dtype=torch.float16,
trust_remote_code=True).cuda()


class GLMChat(BaseModel):
def __init__(self, path: str = '') -> None:
super().__init__(path)
self.load_model()

def chat(self, prompt: str, history: List = [], content: str = '') -> tuple[Any, Any]:
prompt = PROMPT_TEMPLATE['GLM_PROMPT_TEMPALTE'].format(question=prompt, context=content)
response, history = self.model.chat(self.tokenizer, prompt, history)
def chat(self, prompt: str, history: List = [], content: str = '', llm_only: bool = False) -> tuple[Any, Any]:
if llm_only:
prompt = prompt
else:
prompt = PROMPT_TEMPLATE['GLM_PROMPT_TEMPALTE'].format(question=prompt, context=content)
response, history = self.model.chat(self.tokenizer, prompt, history,max_length= 32000, num_beams=1, do_sample=True, top_p=0.8, temperature=0.2,)
return response, history

def load_model(self):
Expand All @@ -112,6 +115,7 @@ def load_model(self):
self.model = AutoModelForCausalLM.from_pretrained(self.path, torch_dtype=torch.float16,
trust_remote_code=True).cuda()


class DashscopeChat(BaseModel):
def __init__(self, path: str = '', model: str = "qwen-turbo") -> None:
super().__init__(path)
Expand Down Expand Up @@ -148,4 +152,4 @@ def chat(self, prompt: str, history: List[Dict], content: str) -> str:
max_tokens=150,
temperature=0.1
)
return response.choices[0].message
return response.choices[0].message
122 changes: 119 additions & 3 deletions gomate/modules/rewriter/hyde_rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,127 @@
@contact:1185918903@qq.com
@license: Apache Licence
@time: 2024/5/31 1:23
@reference:https://github.com/texttron/hyde/blob/main/src/hyde/generator.py
"""
import pandas as pd
from tqdm import tqdm
import os
from gomate.modules.generator.llm import GLMChat
from gomate.modules.retrieval.dense_retriever import DenseRetriever, DenseRetrieverConfig
from gomate.modules.rewriter.base import BaseRewriter
from gomate.modules.rewriter.promptor import Promptor
from gomate.modules.document.common_parser import CommonParser

class HydeRewriter(BaseRewriter):
    """HyDE (Hypothetical Document Embeddings) query rewriter.

    Uses an LLM to generate a hypothetical document that answers the
    query, then retrieves against that document instead of the raw query.
    Reference: https://github.com/texttron/hyde/blob/main/src/hyde/generator.py
    """

    def __init__(self, promptor, generator, retriever):
        # promptor builds the HyDE generation prompt, generator is the LLM,
        # retriever performs dense retrieval over the vector index.
        self.promptor = promptor
        self.generator = generator
        self.retriever = retriever

    def prompt(self, query):
        """Build the HyDE generation prompt for ``query``."""
        return self.promptor.build_prompt(query)

    def rewrite(self, query):
        """Return an LLM-generated hypothetical document answering ``query``."""
        prompt = self.promptor.build_prompt(query)
        # llm_only=True sends the HyDE prompt verbatim, bypassing the
        # generator's RAG prompt template.
        hypothesis_document, _ = self.generator.chat(prompt, llm_only=True)
        return hypothesis_document

    def retrieve(self, query, top_k=5):
        """Retrieve ``top_k`` hits using the hypothetical document as the query.

        Returns:
            Dict with keys ``hypothesis_document`` (the rewritten query) and
            ``retrieve_result`` (the retriever's hits).
        """
        hypothesis_document = self.rewrite(query)
        hits = self.retriever.retrieve(hypothesis_document, top_k=top_k)
        return {'hypothesis_document': hypothesis_document, 'retrieve_result': hits}


if __name__ == '__main__':
    promptor = Promptor(task="WEB_SEARCH", language="zh")

    retriever_config = DenseRetrieverConfig(
        model_name_or_path="/data/users/searchgpt/pretrained_models/bge-large-zh-v1.5",
        dim=1024,
        index_dir='/data/users/searchgpt/yq/GoMate/examples/retrievers/dense_cache'
    )
    config_info = retriever_config.log_config()
    retriever = DenseRetriever(config=retriever_config)
    parser = CommonParser()

    # Best-effort indexing: unparseable files are skipped, but the failure
    # is reported instead of being silently swallowed by a bare except.
    chunks = []
    docs_path = '/data/users/searchgpt/yq/GoMate_dev/data/docs'
    for filename in os.listdir(docs_path):
        file_path = os.path.join(docs_path, filename)
        try:
            chunks.extend(parser.parse(file_path))
        except Exception as exc:
            print(f"skipping unparseable file {file_path}: {exc}")
    retriever.build_from_texts(chunks)

    # Enrich the index with positive/negative passages from the first 5
    # rows of the NDJSON QA dataset.
    data = pd.read_json('/data/users/searchgpt/yq/GoMate/data/docs/zh_refine.json', lines=True)[:5]
    for documents in tqdm(data['positive'], total=len(data)):
        for document in documents:
            retriever.add_text(document)
    for documents in tqdm(data['negative'], total=len(data)):
        for document in documents:
            retriever.add_text(document)

    print("init_vector_store done! ")
    generator = GLMChat("/data/users/searchgpt/pretrained_models/glm-4-9b-chat")

    hyde = HydeRewriter(promptor, generator, retriever)

    def compare(query):
        """Print HyDE vs. plain dense retrieval results and answers for ``query``."""
        hypothesis_document = hyde.rewrite(query)
        print("==================hypothesis_document=================\n")
        print(hypothesis_document)
        hyde_result = hyde.retrieve(query)
        print("==================hyde_result=================\n")
        print(hyde_result['retrieve_result'])
        dense_result = retriever.retrieve(query)
        print("==================dense_result=================\n")
        print(dense_result)
        hyde_answer, _ = generator.chat(
            prompt=query,
            content='\n'.join([doc['text'] for doc in hyde_result['retrieve_result']]))
        print("==================hyde_answer=================\n")
        print(hyde_answer)
        dense_answer, _ = generator.chat(
            prompt=query,
            content='\n'.join([doc['text'] for doc in dense_result]))
        print("==================dense_answer=================\n")
        print(dense_answer)

    # One loop over the demo queries replaces three copy-pasted sections.
    for query in ["RCEP具体包括哪些国家", "数据集类型有哪些?", "Sklearn可以使用的数据集有哪些?"]:
        compare(query)
        print("****" * 20)
Loading

0 comments on commit 8dd06ca

Please sign in to comment.