Skip to content

Commit

Permalink
add bge large reranker
Browse files Browse the repository at this point in the history
  • Loading branch information
Wenshansilvia committed Feb 10, 2024
1 parent 6bf7054 commit f1e7b18
Show file tree
Hide file tree
Showing 11 changed files with 132 additions and 1 deletion.
Binary file not shown.
26 changes: 26 additions & 0 deletions gomate/applications/RerankerApp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from gomate.modules import bge_large_reranker


class RerankerApp():
    """Rerank module: evaluate document relevance and reorder documents.

    The aim is to place the documents most likely to provide an accurate,
    relevant answer first. Implementations include:
    1. bge-reranker-large, the open-source rerank model from BAAI.
    2. ...
    """

    def __init__(self, component_name=None):
        """Init required reranker according to component name.

        Args:
            component_name: Key of the reranker to build. It has to be one
                of the entries of ``self.reranker_list`` (currently only
                ``'bge_large'``); the default ``None`` is therefore rejected.

        Raises:
            ValueError: If ``component_name`` is not a supported reranker.
        """
        self.reranker_list = ['bge_large']
        # Raise explicitly instead of using ``assert``: assertions are
        # removed under ``python -O``, which would leave ``self.reranker``
        # unset and defer the failure to the first ``run`` call.
        if component_name not in self.reranker_list:
            raise ValueError(
                f'unsupported reranker {component_name!r}; '
                f'expected one of {self.reranker_list}'
            )
        if component_name == 'bge_large':
            self.reranker = bge_large_reranker()

    def run(self, query, contexts):
        """Run the required reranker.

        Args:
            query: The query forwarded to the underlying reranker.
            contexts: The candidate documents forwarded to the reranker.

        Returns:
            Whatever the selected reranker's ``run`` method returns.

        Raises:
            ValueError: If ``query`` or ``contexts`` is ``None``.
        """
        if query is None:
            raise ValueError('missing query')
        if contexts is None:
            raise ValueError('missing contexts')
        return self.reranker.run(query, contexts)
File renamed without changes.
3 changes: 2 additions & 1 deletion gomate/applications/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .RewriteApp import RewriterApp
from .RewriterApp import RewriterApp
from .RerankerApp import RerankerApp
1 change: 1 addition & 0 deletions gomate/modules/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .reranker import bge_large_reranker
from .rewriter import HyDE_rewriter
1 change: 1 addition & 0 deletions gomate/modules/reranker/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .bge_large_reranker import bge_large_reranker
14 changes: 14 additions & 0 deletions gomate/modules/reranker/base_reranker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from abc import ABC, abstractmethod


class base_reranker(ABC):
    """Abstract interface shared by reranker components."""

    @abstractmethod
    def __init__(self, component_name=None):
        """Set up the reranker identified by ``component_name``.

        Declared abstract, so a concrete subclass has to supply its own
        constructor before it can be instantiated.
        """

    def run(self, query, contexts):
        """Rerank ``contexts`` with respect to ``query``.

        This base implementation performs no work and returns ``None``;
        it is not marked abstract.
        """
        return None
66 changes: 66 additions & 0 deletions gomate/modules/reranker/bge_large_reranker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from tqdm import tqdm
from typing import List
import numpy as np


class bge_large_reranker():
    """Relevance scorer built on the ``BAAI/bge-reranker-large`` checkpoint.

    Every (query, context) pair is fed to a sequence-classification model
    and the resulting logit is mapped through a sigmoid.
    """

    def __init__(self,
                 model_name_or_path: str = 'BAAI/bge-reranker-large',
                 use_fp16: bool = False):
        """Load the tokenizer and model, then place the model on a device.

        Args:
            model_name_or_path: Identifier or path passed to
                ``from_pretrained`` for both the tokenizer and the model.
            use_fp16: Convert the model to half precision. Forced to
                ``False`` when neither CUDA nor MPS is available.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path)
        if torch.cuda.is_available():
            self.device = torch.device('cuda')
        elif torch.backends.mps.is_available():
            self.device = torch.device('mps')
        else:
            self.device = torch.device('cpu')
            # Half precision is disabled on the CPU fallback.
            use_fp16 = False
        if use_fp16:
            self.model.half()
        self.model = self.model.to(self.device)
        self.model.eval()

        self.num_gpus = torch.cuda.device_count()
        if self.num_gpus > 1:
            print(f"----------using {self.num_gpus}*GPUs----------")
            self.model = torch.nn.DataParallel(self.model)

    @torch.no_grad()
    def run(self, query, contexts, batch_size: int = 256,
            max_length: int = 512) -> np.ndarray:
        """Score every context against the query.

        Args:
            query: The query string.
            contexts: List of candidate texts to score against ``query``.
            batch_size: Number of pairs per forward pass; multiplied by the
                CUDA device count when at least one CUDA device is present.
            max_length: Truncation length handed to the tokenizer.

        Returns:
            A 1-D ``numpy`` array holding one sigmoid score per context, in
            the same order as ``contexts``. The contexts themselves are not
            reordered here.

        Raises:
            TypeError: If ``query`` is not a ``str`` or ``contexts`` is not
                a ``list``.
        """
        # Explicit checks instead of ``assert`` so that validation is still
        # performed when Python runs with ``-O``.
        if not isinstance(query, str):
            raise TypeError(f'query must be a str, got {type(query).__name__}')
        if not isinstance(contexts, list):
            raise TypeError(f'contexts must be a list, got {type(contexts).__name__}')

        # Nothing to score: same empty float array the full path would give.
        if not contexts:
            return np.array([], dtype=np.float64)

        if self.num_gpus > 0:
            batch_size = batch_size * self.num_gpus

        sentence_pairs = [[query, cxt] for cxt in contexts]

        all_scores = []
        for start_index in tqdm(range(0, len(sentence_pairs), batch_size),
                                desc="Compute Reranking Scores",
                                disable=len(sentence_pairs) < 128):
            sentences_batch = sentence_pairs[start_index:start_index + batch_size]
            inputs = self.tokenizer(
                sentences_batch,
                padding=True,
                truncation=True,
                return_tensors='pt',
                max_length=max_length,
            ).to(self.device)

            # Flatten the logits to one value per pair and cast to float32
            # before moving them off the device.
            scores = self.model(**inputs, return_dict=True).logits.view(-1, ).float()
            all_scores.extend(scores.cpu().numpy().tolist())

        # Sigmoid maps the raw logits into the (0, 1) range.
        return 1 / (1 + np.exp(-np.array(all_scores)))
2 changes: 2 additions & 0 deletions gomate/modules/rewriter/base_rewriter.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from abc import ABC, abstractmethod


class base_rewriter(ABC):
"""Define base rewriter."""

@abstractmethod
def __init__(self, component_name=None):
"""Init required rewriter according to component name."""
Expand Down
4 changes: 4 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,7 @@ pydocstyle == 2.1
openai == 1.10.0
datasets == 2.16.1
langchain == 0.1.4
transformers == 4.37.2
torch == 2.2.0
pandas == 2.0.0
nltk == 3.8.1
16 changes: 16 additions & 0 deletions tests/units/test_reranker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# -*- coding: utf-8 -*-

import pytest
from gomate.applications import RerankerApp
# import os

def test_reranker():
    """Check that the bge_large reranker returns one valid score per context.

    Builds ``RerankerApp`` with the ``'bge_large'`` component and scores two
    passages against a single query.
    """
    component_name = 'bge_large'
    model = RerankerApp(component_name=component_name)
    query = "恐龙是怎么被命名的?"
    contexts = [
        "[12]恐龙是 介于冷血和温血之间的动物2014年6月,有关恐龙究竟是像鸟类和哺乳动物一样的温血动物,还是类似爬行动物、鱼类和两栖动物的冷血动物的问题终于有了答案——恐龙其实是介于冷血和温血之间的动物。 [12]“我们的结果显示恐龙所具有的生长速率和新陈代谢速率,既不是冷血生物体也不是温血生物体所具有的特征。它们既不像哺乳动物或者鸟类,也不像爬行动物或者鱼类,而是介于现代冷血动物和温血动物之间。简言之,它们的生理机能在现代社会并不常见。”美国亚利桑那大学进化生物学家和生态学家布莱恩·恩奎斯特说。墨西哥生物学家表示,正是这种中等程度的新陈代谢使得恐龙可以长得比任何哺乳动物都要大。温血动物需要大量进食,因此它们频繁猎捕和咀嚼植物。“很难想象霸王龙大小的狮子能够吃饱以 存活下来。",
        "[12]哺乳动物起源于爬行动物,它们的前身是“似哺乳类的爬行动物”,即兽孔目,早期则是“似爬行类的哺乳动物”,即哺乳型动物。 [12]中生代的爬行动物,大部分在中生代的末期灭绝了;一部分适应了变化的环境被保留下来,即现存的爬行动物(如龟鳖类、蛇类、鳄类等);还有一部分沿着不同的进化方向,进化成了现今的鸟类和哺乳类。 [12]恐龙是 介于冷血和温血之间的动物2014年6月,有关恐龙究竟是像鸟类和哺乳动物一样的温血动物,还是类似爬行动物、鱼类和两栖动物的冷血动物的问题终于有了答案——恐龙其实是介于冷血和温血之间的动物。",
    ]
    probabilities = model.run(query, contexts)
    assert probabilities is not None
    # A bare "is not None" check cannot catch a wrong-shaped or out-of-range
    # result, so also pin the count and the sigmoid output range.
    assert len(probabilities) == len(contexts)
    assert all(0.0 <= float(p) <= 1.0 for p in probabilities)

# Allow running this test file directly as a script, without a test runner.
if __name__ == '__main__':
    test_reranker()

0 comments on commit f1e7b18

Please sign in to comment.