From a5ecef353fdfdabf6a16aeaafc0c284d1a8a1e6f Mon Sep 17 00:00:00 2001 From: yanqiangmiffy <1185918903@qq.com> Date: Tue, 25 Jun 2024 21:44:52 +0800 Subject: [PATCH] bug@fix bm25 tokenized query --- examples/retrievers/bm5retriever_example.py | 18 +++++++++++------- gomate/modules/retrieval/bm25_retriever.py | 3 ++- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/examples/retrievers/bm5retriever_example.py b/examples/retrievers/bm5retriever_example.py index d729aff..d3f09d0 100644 --- a/examples/retrievers/bm5retriever_example.py +++ b/examples/retrievers/bm5retriever_example.py @@ -8,6 +8,8 @@ @time: 2024/6/1 15:48 """ import os + +from gomate.modules.document.common_parser import CommonParser from gomate.modules.retrieval.bm25_retriever import BM25RetrieverConfig, BM25Retriever, tokenizer if __name__ == '__main__': @@ -26,15 +28,17 @@ root_dir = os.path.abspath(os.path.dirname(__file__)) print(root_dir) new_files = [ - r'H:\Projects\GoMate\data\伊朗.txt', - r'H:\Projects\GoMate\data\伊朗总统罹难事件.txt', - r'H:\Projects\GoMate\data\伊朗总统莱希及多位高级官员遇难的直升机事故.txt', - r'H:\Projects\GoMate\data\伊朗问题.txt', + r'/data/users/searchgpt/yq/GoMate_dev/data/docs/伊朗.txt', + r'/data/users/searchgpt/yq/GoMate_dev/data/docs/伊朗总统罹难事件.txt', + r'/data/users/searchgpt/yq/GoMate_dev/data/docs/伊朗总统莱希及多位高级官员遇难的直升机事故.txt', + r'/data/users/searchgpt/yq/GoMate_dev/data/docs/伊朗问题.txt', + r'/data/users/searchgpt/yq/GoMate_dev/data/docs/新冠肺炎疫情.pdf', ] + parser = CommonParser() for filename in new_files: - with open(filename, 'r', encoding="utf-8") as file: - corpus.append(file.read()) + chunks = parser.parse(filename) + corpus.extend(chunks) bm25_retriever.build_from_texts(corpus) - query = "伊朗总统莱希" + query = "新冠肺炎疫情" search_docs = bm25_retriever.retrieve(query) print(search_docs) diff --git a/gomate/modules/retrieval/bm25_retriever.py b/gomate/modules/retrieval/bm25_retriever.py index 6567146..4ffdee2 100644 --- a/gomate/modules/retrieval/bm25_retriever.py +++ b/gomate/modules/retrieval/bm25_retriever.py @@ -264,7 +264,8 @@ def build_from_texts(self, corpus): raise ValueError('Algorithm not supported') def retrieve(self, query: str='',top_k:int=3) -> List[Dict]: - tokenized_query = " ".join(self.tokenizer(query)) + # tokenized_query = " ".join(self.tokenizer(query)) + tokenized_query= self.tokenizer(query) search_docs = self.bm25.get_top_n(tokenized_query, self.corpus, n=top_k) return search_docs