-
Notifications
You must be signed in to change notification settings - Fork 52
/
rag.py
71 lines (61 loc) · 2.44 KB
/
rag.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
#!/usr/bin/env python
# -*- coding:utf-8 _*-
"""
@author:quincy qiang
@license: Apache Licence
@file: RagApplication.py
@time: 2024/05/20
@contact: yanqiangmiffy@gamil.com
"""
import os
from trustrag.modules.citation.match_citation import MatchCitation
from trustrag.modules.document.common_parser import CommonParser
from trustrag.modules.generator.llm import GLMChat
from trustrag.modules.reranker.bge_reranker import BgeReranker
from trustrag.modules.retrieval.dense_retriever import DenseRetriever
class ApplicationConfig():
    """Configuration container consumed by ``RagApplication``.

    Every attribute that ``RagApplication`` reads is declared here, so a
    default-constructed config fails with a clear ``None`` value instead of an
    ``AttributeError`` on a missing attribute.

    Args:
        retriever_config: passed to ``DenseRetriever``. ``RagApplication`` also
            reads its ``index_path`` attribute when saving or loading the index.
        rerank_config: passed to ``BgeReranker``.
        llm_model_path: passed to ``GLMChat``.
        docs_path: directory scanned by ``RagApplication.init_vector_store``.
    """

    def __init__(self, retriever_config=None, rerank_config=None,
                 llm_model_path=None, docs_path=None):
        self.retriever_config = retriever_config
        self.rerank_config = rerank_config
        # Previously missing, although RagApplication reads both of these.
        self.llm_model_path = llm_model_path
        self.docs_path = docs_path
class RagApplication():
    """Retrieval-augmented generation pipeline.

    Combines a document parser, a dense retriever, a BGE reranker, a GLM chat
    model and a citation matcher, all built from a single config object.
    """

    def __init__(self, config):
        """Build every pipeline component from *config*.

        Args:
            config: object exposing ``retriever_config``, ``rerank_config`` and
                ``llm_model_path``. Other methods also read ``docs_path`` and
                ``retriever_config.index_path``.
        """
        self.config = config
        self.parser = CommonParser()
        self.retriever = DenseRetriever(self.config.retriever_config)
        self.reranker = BgeReranker(self.config.rerank_config)
        self.llm = GLMChat(self.config.llm_model_path)
        # NOTE(review): not used by chat() yet; citation grounding is not wired up.
        self.mc = MatchCitation()

    def init_vector_store(self):
        """Parse every file in ``config.docs_path``, index the chunks and save the index.

        Details:
            - Sub-directories are skipped.
            - Files are visited in sorted name order, so the index is reproducible.
            - A file that fails to parse is reported and skipped rather than
              aborting the whole build.

        The index is written to ``config.retriever_config.index_path``.
        """
        print("init_vector_store ... ")
        docs_path = self.config.docs_path
        chunks = []
        for filename in sorted(os.listdir(docs_path)):
            file_path = os.path.join(docs_path, filename)
            if not os.path.isfile(file_path):
                continue
            try:
                chunks.extend(self.parser.parse(file_path))
            except Exception as exc:
                # Best-effort: one unparseable file must not abort the build.
                # Report it instead of hiding it. Catch Exception, not
                # everything, so KeyboardInterrupt/SystemExit still propagate.
                print(f"init_vector_store: skipping {file_path}: {exc!r}")
        self.retriever.build_from_texts(chunks)
        print("init_vector_store done! ")
        self.retriever.save_index(self.config.retriever_config.index_path)

    def load_vector_store(self):
        """Load a previously saved index from ``config.retriever_config.index_path``."""
        self.retriever.load_index(self.config.retriever_config.index_path)

    def add_document(self, file_path):
        """Parse *file_path* and add its chunks to the in-memory index one by one.

        This method does not persist the updated index.
        """
        chunks = self.parser.parse(file_path)
        for chunk in chunks:
            self.retriever.add_text(chunk)
        print("add_document done!")

    def chat(self, question: str = '', top_k: int = 5):
        """Answer *question* using retrieved and reranked context.

        Args:
            question: the user query.
            top_k: number of passages fetched from the retriever before reranking.

        Returns:
            A tuple ``(result, history, contents)``:
                - ``result`` and ``history`` come from ``self.llm.chat``.
                - ``contents`` is the reranker output; each item has a
                  ``'text'`` key.
        """
        retrieved = self.retriever.retrieve(query=question, top_k=top_k)
        reranked = self.reranker.rerank(
            query=question, documents=[doc['text'] for doc in retrieved]
        )
        context = '\n'.join(doc['text'] for doc in reranked)
        print(reranked)
        # The empty list is presumably the chat history, so each call is single-turn.
        result, history = self.llm.chat(question, [], context)
        return result, history, reranked