Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
yanqiangmiffy committed Sep 9, 2024
1 parent be2e483 commit 9382511
Show file tree
Hide file tree
Showing 10 changed files with 438 additions and 144 deletions.
156 changes: 156 additions & 0 deletions docs/retrieval.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
## 检索器

### BM25Retriever

> 基于[`bm25s`](https://github.com/xhluca/bm25s)实现

参数说明:

- `method`:BM25 算法变体,可选值:'robertson'、'lucene'、'atire'、'bm25l'、'bm25+'
- `index_path`:BM25 索引的保存/加载路径

```python
from gomate.modules.document.common_parser import CommonParser
from gomate.modules.document.utils import PROJECT_BASE
from gomate.modules.retrieval.bm25s_retriever import BM25RetrieverConfig, BM25Retriever

if __name__ == '__main__':

corpus = []

new_files = [
f'{PROJECT_BASE}/data/docs/伊朗.txt',
f'{PROJECT_BASE}/data/docs/伊朗总统罹难事件.txt',
f'{PROJECT_BASE}/data/docs/伊朗总统莱希及多位高级官员遇难的直升机事故.txt',
f'{PROJECT_BASE}/data/docs/伊朗问题.txt',
f'{PROJECT_BASE}/data/docs/汽车操作手册.pdf',
# r'H:\2024-Xfyun-RAG\data\corpus.txt'
]
parser = CommonParser()
for filename in new_files:
chunks = parser.parse(filename)
corpus.extend(chunks)

bm25_config = BM25RetrieverConfig(method='lucene', index_path='indexs/description_bm25.index', k1=1.6, b=0.7)
bm25_config.validate()
print(bm25_config.log_config())

bm25_retriever = BM25Retriever(bm25_config)
bm25_retriever.build_from_texts(corpus)
# bm25_retriever.load_index()
query = "伊朗总统莱希"
search_docs = bm25_retriever.retrieve(query)
print(search_docs)
```

### DenseRetriever

参数说明:

- `model_name_or_path`:embedding模型hf名称或者路径
- `dim`:向量维度
- `index_path`:索引的保存路径(旧版本中该参数名为 `index_dir`)

```python
import pandas as pd
from tqdm import tqdm

from gomate.modules.retrieval.dense_retriever import DenseRetriever, DenseRetrieverConfig

if __name__ == '__main__':
retriever_config = DenseRetrieverConfig(
model_name_or_path="/data/users/searchgpt/pretrained_models/bge-large-zh-v1.5",
dim=1024,
index_dir='/data/users/searchgpt/yq/GoMate/examples/retrievers/dense_cache'
)
config_info = retriever_config.log_config()
print(config_info)
retriever = DenseRetriever(config=retriever_config)
data = pd.read_json('/data/users/searchgpt/yq/GoMate/data/docs/zh_refine.json', lines=True)[:5]
print(data)
print(data.columns)

corpus = []
for documents in tqdm(data['positive'], total=len(data)):
for document in documents:
# retriever.add_text(document)
corpus.append(document)
for documents in tqdm(data['negative'], total=len(data)):
for document in documents:
# retriever.add_text(document)
corpus.append(document)
print("len(corpus)", len(corpus))
retriever.build_from_texts(corpus)
result = retriever.retrieve("RCEP具体包括哪些国家")
print(result)
retriever.save_index()
```

### HybridRetriever

> 混合检索:将 BM25 检索与 Dense 检索的结果进行合并
```python
from gomate.modules.document.common_parser import CommonParser
from gomate.modules.retrieval.bm25s_retriever import BM25RetrieverConfig
from gomate.modules.retrieval.dense_retriever import DenseRetrieverConfig
from gomate.modules.retrieval.hybrid_retriever import HybridRetriever, HybridRetrieverConfig

if __name__ == '__main__':
# BM25 and Dense Retriever configurations
bm25_config = BM25RetrieverConfig(
method='lucene',
index_path='indexs/description_bm25.index',
k1=1.6,
b=0.7
)
bm25_config.validate()
print(bm25_config.log_config())

dense_config = DenseRetrieverConfig(
model_name_or_path="/data/users/searchgpt/pretrained_models/bge-large-zh-v1.5",
dim=1024,
index_path='/data/users/searchgpt/yq/GoMate/examples/retrievers/dense_cache'
)
config_info = dense_config.log_config()
print(config_info)

# Hybrid Retriever configuration
hybrid_config = HybridRetrieverConfig(
bm25_config=bm25_config,
dense_config=dense_config,
bm25_weight=0.5,
dense_weight=0.5
)
hybrid_retriever = HybridRetriever(config=hybrid_config)

# Corpus
corpus = []

# Files to be parsed
new_files = [
r'/data/users/searchgpt/yq/GoMate_dev/data/docs/伊朗.txt',
r'/data/users/searchgpt/yq/GoMate_dev/data/docs/伊朗总统罹难事件.txt',
r'/data/users/searchgpt/yq/GoMate_dev/data/docs/伊朗总统莱希及多位高级官员遇难的直升机事故.txt',
r'/data/users/searchgpt/yq/GoMate_dev/data/docs/伊朗问题.txt',
r'/data/users/searchgpt/yq/GoMate_dev/data/docs/新冠肺炎疫情.pdf',
]

# Parsing documents
parser = CommonParser()
for filename in new_files:
chunks = parser.parse(filename)
corpus.extend(chunks)

# Build hybrid retriever from texts
hybrid_retriever.build_from_texts(corpus)

# Query
query = "新冠肺炎疫情"
results = hybrid_retriever.retrieve(query, top_k=3)

# Output results
for result in results:
print(f"Text: {result['text']}, Score: {result['score']}")

```
31 changes: 17 additions & 14 deletions examples/retrievers/bm25sretrever_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,32 +7,35 @@
@license: Apache Licence
@time: 2024/08/27 14:16
"""
import os

from gomate.modules.document.common_parser import CommonParser
from gomate.modules.retrieval.bm25s_retriever import BM25Retriever
from gomate.modules.document.utils import PROJECT_BASE
if __name__ == '__main__':
from gomate.modules.retrieval.bm25s_retriever import BM25RetrieverConfig, BM25Retriever

if __name__ == '__main__':

corpus = []

new_files = [
# f'{PROJECT_BASE}/data/docs/伊朗.txt',
# f'{PROJECT_BASE}/data/docs/伊朗总统罹难事件.txt',
# f'{PROJECT_BASE}/data/docs/伊朗总统莱希及多位高级官员遇难的直升机事故.txt',
# f'{PROJECT_BASE}/data/docs/伊朗问题.txt',
# f'{PROJECT_BASE}/data/docs/汽车操作手册.pdf',
r'H:\2024-Xfyun-RAG\data\corpus.txt'
f'{PROJECT_BASE}/data/docs/伊朗.txt',
f'{PROJECT_BASE}/data/docs/伊朗总统罹难事件.txt',
f'{PROJECT_BASE}/data/docs/伊朗总统莱希及多位高级官员遇难的直升机事故.txt',
f'{PROJECT_BASE}/data/docs/伊朗问题.txt',
f'{PROJECT_BASE}/data/docs/汽车操作手册.pdf',
# r'H:\2024-Xfyun-RAG\data\corpus.txt'
]
parser = CommonParser()
for filename in new_files:
chunks = parser.parse(filename)
corpus.extend(chunks)
bm25_retriever = BM25Retriever(method="lucene",
index_path="indexs/description_bm25.index",
rebuild_index=True,
corpus=corpus)

bm25_config = BM25RetrieverConfig(method='lucene', index_path='indexs/description_bm25.index', k1=1.6, b=0.7)
bm25_config.validate()
print(bm25_config.log_config())

bm25_retriever = BM25Retriever(bm25_config)
bm25_retriever.build_from_texts(corpus)
# bm25_retriever.load_index()
query = "伊朗总统莱希"
search_docs = bm25_retriever.retrieve(query)
print(search_docs)
print(search_docs)
13 changes: 8 additions & 5 deletions examples/retrievers/denseretriever_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,23 +18,26 @@
retriever_config = DenseRetrieverConfig(
model_name_or_path="/data/users/searchgpt/pretrained_models/bge-large-zh-v1.5",
dim=1024,
index_dir='/data/users/searchgpt/yq/GoMate/examples/retrievers/dense_cache'
index_path='/data/users/searchgpt/yq/GoMate/examples/retrievers/dense_cache'
)
config_info = retriever_config.log_config()
print(config_info)

retriever = DenseRetriever(config=retriever_config)

data = pd.read_json('/data/users/searchgpt/yq/GoMate/data/docs/zh_refine.json', lines=True)[:5]
print(data)
print(data.columns)

corpus = []
for documents in tqdm(data['positive'], total=len(data)):
for document in documents:
retriever.add_text(document)
# retriever.add_text(document)
corpus.append(document)
for documents in tqdm(data['negative'], total=len(data)):
for document in documents:
retriever.add_text(document)
# retriever.add_text(document)
corpus.append(document)
print("len(corpus)",len(corpus))
retriever.build_from_texts(corpus)
result = retriever.retrieve("RCEP具体包括哪些国家")
print(result)
retriever.save_index()
26 changes: 20 additions & 6 deletions examples/retrievers/hybridretriever_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,34 @@
@license: Apache Licence
@time: 2024/08/27 14:16
"""
import os

from gomate.modules.document.common_parser import CommonParser
from gomate.modules.retrieval.bm25s_retriever import BM25RetrieverConfig, BM25Retriever, tokenizer
from gomate.modules.retrieval.dense_retriever import DenseRetriever, DenseRetrieverConfig
from gomate.modules.retrieval.bm25s_retriever import BM25RetrieverConfig
from gomate.modules.retrieval.dense_retriever import DenseRetrieverConfig
from gomate.modules.retrieval.hybrid_retriever import HybridRetriever, HybridRetrieverConfig

if __name__ == '__main__':
# BM25 and Dense Retriever configurations
bm25_config = BM25RetrieverConfig(tokenizer=tokenizer, k1=1.5, b=0.75)
dense_config = DenseRetrieverConfig(model_name_or_path='sentence-transformers/all-mpnet-base-v2')
bm25_config = BM25RetrieverConfig(
method='lucene',
index_path='indexs/description_bm25.index',
k1=1.6,
b=0.7
)
bm25_config.validate()
print(bm25_config.log_config())

dense_config = DenseRetrieverConfig(
model_name_or_path="/data/users/searchgpt/pretrained_models/bge-large-zh-v1.5",
dim=1024,
index_path='/data/users/searchgpt/yq/GoMate/examples/retrievers/dense_cache'
)
config_info = dense_config.log_config()
print(config_info)

# Hybrid Retriever configuration
hybrid_config = HybridRetrieverConfig(bm25_config=bm25_config, dense_config=dense_config, bm25_weight=0.5, dense_weight=0.5)
hybrid_config = HybridRetrieverConfig(bm25_config=bm25_config, dense_config=dense_config, bm25_weight=0.5,
dense_weight=0.5)
hybrid_retriever = HybridRetriever(config=hybrid_config)

# Corpus
Expand Down
2 changes: 1 addition & 1 deletion gomate/modules/document/txt_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ def get_encoding(file):
return tmp['encoding']

class TextParser(object):
def parse(self, fnm, encoding=None, from_page=0, to_page=100000, **kwargs):
def parse(self, fnm, encoding="utf-8", from_page=0, to_page=100000, **kwargs):
# 如果 fnm 不是字符串(假设是字节流等),则使用 find_codec 找到编码
if not isinstance(fnm, str):
encoding = get_encoding(fnm) if encoding is None else encoding
Expand Down
12 changes: 9 additions & 3 deletions gomate/modules/document/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,25 @@

import tiktoken

import pathlib

# 获取当前文件所在的路径
current_path = pathlib.Path(__file__).resolve()
print(current_path)

# 找到根目录,这里假设项目的根目录为 'GoMate'
# 找到根目录,这里假设项目的根目录为 'GoMate' 或 'GoMate_dev'
project_root = current_path
while project_root.name != 'GoMate':
while project_root.name != 'GoMate' and project_root.name != 'GoMate_dev':
project_root = project_root.parent
# 如果到达根目录还没找到项目根目录,则可能路径有问题,防止死循环
if project_root == project_root.parent:
raise Exception("项目根目录未找到")

# 在 Windows 中输出带反斜杠的路径
project_root_str = str(project_root)

print(f"项目根目录为: {project_root_str}")


PROJECT_BASE = project_root_str
all_codecs = [
'utf-8', 'gb2312', 'gbk', 'utf_16', 'ascii', 'big5', 'big5hkscs',
Expand Down
56 changes: 54 additions & 2 deletions gomate/modules/retrieval/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,61 @@
@description: coding..
"""
from abc import ABC, abstractmethod
from typing import List
import json
import os


class BaseRetriever(ABC):
"""通用的检索器接口"""

def save_index(self):
raise NotImplementedError
def load_index(self):
raise NotImplementedError
def build_from_texts(self, corpus):
"""构建索引"""
raise NotImplementedError

@abstractmethod
def retrieve(self, query: str,top_k:int) -> str:
pass
def retrieve(self, query, top_k):
"""检索并返回前K个结果"""
raise NotImplementedError



class BaseConfig:
    """Shared base class for retriever configuration objects.

    Concrete configuration classes (e.g. ``BM25RetrieverConfig``,
    ``DenseRetrieverConfig``) keep their settings as plain instance
    attributes. This base class works purely off ``self.__dict__`` to provide
    a printable summary and JSON persistence, so every attribute a subclass
    wants to persist must be JSON-serializable and must also be accepted as a
    keyword argument by the subclass constructor.
    """

    def log_config(self):
        """Return a human-readable, multi-line summary of the configuration.

        Returns:
            A string whose first line is ``"<ClassName> Configuration:"`` and
            whose following lines are ``"key: value"`` for each instance
            attribute, in attribute insertion order. The string always ends
            with a newline.
        """
        lines = [f"{self.__class__.__name__} Configuration:"]
        lines.extend(f"{key}: {value}" for key, value in self.__dict__.items())
        # Trailing newline keeps the output identical to a per-line "+= ...\n" build.
        return "\n".join(lines) + "\n"

    def save_to_file(self, file_path):
        """Write the instance attributes to ``file_path`` as a JSON object.

        Args:
            file_path: Destination path; an existing file is overwritten.

        Raises:
            TypeError: If an attribute value is not JSON-serializable.
            OSError: If the file cannot be opened for writing.
        """
        # Explicit UTF-8 + ensure_ascii=False: configs routinely hold non-ASCII
        # paths, and the platform default encoding (e.g. GBK on Windows) must
        # not decide how they are stored.
        with open(file_path, 'w', encoding='utf-8') as f:
            json.dump(self.__dict__, f, indent=4, ensure_ascii=False)
        print(f"Configuration saved to {file_path}")

    @classmethod
    def load_from_file(cls, file_path):
        """Build a configuration instance from a JSON file.

        The decoded JSON object is passed to the class constructor as keyword
        arguments, so its keys must match the constructor's parameter names.

        Args:
            file_path: Path of a JSON file, typically one produced by
                ``save_to_file``.

        Returns:
            A new instance of the class the method was called on.

        Raises:
            FileNotFoundError: If ``file_path`` does not exist.
            json.JSONDecodeError: If the file is not valid JSON.
        """
        # Open directly instead of checking existence first: one syscall less
        # and no window in which the file can disappear between check and use.
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                config_dict = json.load(f)
        except FileNotFoundError:
            raise FileNotFoundError(
                f"Configuration file {file_path} does not exist."
            ) from None

        return cls(**config_dict)

    def validate(self):
        """Validate configuration parameters.

        The base class has nothing to check, so calling this on a class that
        does not override it is reported as an error rather than silently
        passing.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError("This method should be implemented in the subclass.")
Loading

0 comments on commit 9382511

Please sign in to comment.