Commit

Merge pull request #32 from gomate-community/pipeline
bugfix: bm25 tokenized query
yanqiangmiffy authored Jun 25, 2024
2 parents b750d2d + a5ecef3 commit cc2bb8a
Showing 2 changed files with 13 additions and 8 deletions.
18 changes: 11 additions & 7 deletions examples/retrievers/bm5retriever_example.py
@@ -8,6 +8,8 @@
 @time: 2024/6/1 15:48
 """
 import os
+
+from gomate.modules.document.common_parser import CommonParser
 from gomate.modules.retrieval.bm25_retriever import BM25RetrieverConfig, BM25Retriever, tokenizer

 if __name__ == '__main__':
@@ -26,15 +28,17 @@
     root_dir = os.path.abspath(os.path.dirname(__file__))
     print(root_dir)
     new_files = [
-        r'H:\Projects\GoMate\data\伊朗.txt',
-        r'H:\Projects\GoMate\data\伊朗总统罹难事件.txt',
-        r'H:\Projects\GoMate\data\伊朗总统莱希及多位高级官员遇难的直升机事故.txt',
-        r'H:\Projects\GoMate\data\伊朗问题.txt',
+        r'/data/users/searchgpt/yq/GoMate_dev/data/docs/伊朗.txt',
+        r'/data/users/searchgpt/yq/GoMate_dev/data/docs/伊朗总统罹难事件.txt',
+        r'/data/users/searchgpt/yq/GoMate_dev/data/docs/伊朗总统莱希及多位高级官员遇难的直升机事故.txt',
+        r'/data/users/searchgpt/yq/GoMate_dev/data/docs/伊朗问题.txt',
+        r'/data/users/searchgpt/yq/GoMate_dev/data/docs/新冠肺炎疫情.pdf',
     ]
+    parser = CommonParser()
     for filename in new_files:
-        with open(filename, 'r', encoding="utf-8") as file:
-            corpus.append(file.read())
+        chunks = parser.parse(filename)
+        corpus.extend(chunks)
     bm25_retriever.build_from_texts(corpus)
-    query = "伊朗总统莱希"
+    query = "新冠肺炎疫情"
     search_docs = bm25_retriever.retrieve(query)
     print(search_docs)
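Note on the example change: the corpus is now built through CommonParser rather than by reading each file whole. parse() evidently returns a list of chunks (the result is passed to corpus.extend), so each chunk becomes its own BM25 document and non-plain-text formats such as the newly added PDF can be ingested. A minimal sketch of the revised flow, with illustrative paths, assuming parse() returns a list of strings:

    from gomate.modules.document.common_parser import CommonParser

    # Illustrative paths; CommonParser.parse is assumed to return a list of
    # text chunks, as implied by corpus.extend(chunks) in the diff above.
    parser = CommonParser()
    corpus = []
    for path in ['docs/伊朗.txt', 'docs/新冠肺炎疫情.pdf']:
        corpus.extend(parser.parse(path))  # one BM25 document per chunk

Retrieval then returns matching chunks instead of entire files, which is usually the granularity a downstream RAG pipeline wants.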
3 changes: 2 additions & 1 deletion gomate/modules/retrieval/bm25_retriever.py
@@ -264,7 +264,8 @@ def build_from_texts(self, corpus):
             raise ValueError('Algorithm not supported')

     def retrieve(self, query: str='',top_k:int=3) -> List[Dict]:
-        tokenized_query = " ".join(self.tokenizer(query))
+        # tokenized_query = " ".join(self.tokenizer(query))
+        tokenized_query= self.tokenizer(query)
         search_docs = self.bm25.get_top_n(tokenized_query, self.corpus, n=top_k)
         return search_docs
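Note on the fix: rank_bm25-style indexes (which the get_top_n call suggests self.bm25 is) iterate over the query, and iterating a Python str yields single characters. The old code therefore scored the query character by character, spaces included, rather than token by token; passing the tokenizer's list output restores token-level matching. A standalone sketch of the difference, assuming the rank_bm25 package and jieba as a stand-in for the repo's tokenizer:

    import jieba                     # stand-in tokenizer (assumption)
    from rank_bm25 import BM25Okapi  # assumed backend behind self.bm25

    corpus = ['伊朗总统莱希遇难', '新冠肺炎疫情防控']
    bm25 = BM25Okapi([list(jieba.cut(doc)) for doc in corpus])

    query = '新冠肺炎疫情'
    bad = ' '.join(jieba.cut(query))  # a str: iterated character by character
    good = list(jieba.cut(query))     # a token list, what get_top_n expects
    print(bm25.get_top_n(good, corpus, n=1))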

