-
Notifications
You must be signed in to change notification settings - Fork 56
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
6bf7054
commit f1e7b18
Showing
11 changed files
with
132 additions
and
1 deletion.
There are no files selected for viewing
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
from gomate.modules import bge_large_reranker | ||
|
||
|
||
class RerankerApp():
    """Reranking module: score documents for relevance and reorder them.

    Documents most likely to yield an accurate, relevant answer are placed
    first. Implementations include:
    1. bge-reranker-large, the open-source rerank model from BAAI (Zhiyuan).
    2. ...
    """

    def __init__(self, component_name=None):
        """Init required reranker according to component name.

        Args:
            component_name: Key of the reranker backend to build. Must be one
                of the names in ``self.reranker_list``.

        Raises:
            ValueError: If ``component_name`` is not a supported reranker.
        """
        self.reranker_list = ['bge_large']
        # Raise explicitly rather than ``assert``: assertions are removed
        # under ``python -O``, which would leave ``self.reranker`` unset and
        # surface later as an AttributeError inside ``run``.
        if component_name not in self.reranker_list:
            raise ValueError(
                f'unsupported reranker {component_name!r}; '
                f'expected one of {self.reranker_list}'
            )
        if component_name == 'bge_large':
            self.reranker = bge_large_reranker()

    def run(self, query, contexts):
        """Run the required reranker.

        Args:
            query: The query to score the contexts against.
            contexts: The candidate documents to score.

        Returns:
            Whatever the configured reranker's ``run`` returns.

        Raises:
            ValueError: If ``query`` or ``contexts`` is ``None``.
        """
        if query is None:
            raise ValueError('missing query')
        if contexts is None:
            raise ValueError('missing contexts')
        return self.reranker.run(query, contexts)
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,2 @@ | ||
from .RewriteApp import RewriterApp | ||
from .RewriterApp import RewriterApp | ||
from .RerankerApp import RerankerApp |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,2 @@ | ||
from .reranker import bge_large_reranker | ||
from .rewriter import HyDE_rewriter |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .bge_large_reranker import bge_large_reranker |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
from abc import ABC, abstractmethod | ||
|
||
|
||
class base_reranker(ABC):
    """Abstract base class for reranker components."""

    @abstractmethod
    def __init__(self, component_name=None):
        """Set up the reranker; concrete subclasses must override this.

        Args:
            component_name: Optional component name; not used by this base
                implementation.
        """

    def run(self, query, contexts):
        """Rerank ``contexts`` for ``query``.

        This base implementation does nothing and returns ``None``.
        """
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
import torch | ||
from transformers import AutoModelForSequenceClassification, AutoTokenizer | ||
from tqdm import tqdm | ||
from typing import List | ||
import numpy as np | ||
|
||
|
||
class bge_large_reranker():
    """Reranker that scores (query, context) pairs with bge-reranker-large."""

    def __init__(self,
                 model_name_or_path: str = 'BAAI/bge-reranker-large',
                 use_fp16: bool = False):
        """Load the reranker tokenizer and model and place the model on a device.

        Args:
            model_name_or_path: Identifier or path passed to
                ``from_pretrained`` for both tokenizer and model.
            use_fp16: Convert the model to half precision. Ignored (forced to
                ``False``) when neither CUDA nor MPS is available.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path)

        # Device preference: CUDA, then MPS, then CPU. ``torch.backends.mps``
        # is looked up defensively because older torch builds do not have it.
        mps_backend = getattr(torch.backends, 'mps', None)
        if torch.cuda.is_available():
            self.device = torch.device('cuda')
        elif mps_backend is not None and mps_backend.is_available():
            self.device = torch.device('mps')
        else:
            self.device = torch.device('cpu')
            use_fp16 = False
        if use_fp16:
            self.model.half()
        self.model = self.model.to(self.device)
        self.model.eval()

        self.num_gpus = torch.cuda.device_count()
        if self.num_gpus > 1:
            print(f"----------using {self.num_gpus}*GPUs----------")
            self.model = torch.nn.DataParallel(self.model)

    @torch.no_grad()
    def run(self, query: str, contexts: List[str], batch_size: int = 256,
            max_length: int = 512) -> np.ndarray:
        """Score every context against the query.

        Args:
            query: The query string.
            contexts: List of candidate passages.
            batch_size: Number of pairs per forward pass; multiplied by the
                number of CUDA devices when at least one is present.
            max_length: Maximum tokenized length; longer pairs are truncated.

        Returns:
            A 1-D float64 NumPy array with one sigmoid-squashed score per
            context, in the same order as ``contexts``.

        Raises:
            TypeError: If ``query`` is not a ``str`` or ``contexts`` is not a
                ``list``.
        """
        # Explicit checks instead of ``assert`` so validation survives ``-O``.
        if not isinstance(query, str):
            raise TypeError(f'query must be a str, got {type(query).__name__}')
        if not isinstance(contexts, list):
            raise TypeError(f'contexts must be a list, got {type(contexts).__name__}')
        if not contexts:
            # Nothing to score; same result the batching loop would produce.
            return np.array([], dtype=np.float64)

        if self.num_gpus > 0:
            batch_size = batch_size * self.num_gpus

        sentence_pairs = [[query, cxt] for cxt in contexts]

        all_scores = []
        for start_index in tqdm(range(0, len(sentence_pairs), batch_size), desc="Compute Reranking Scores",
                                disable=len(sentence_pairs) < 128):
            sentences_batch = sentence_pairs[start_index:start_index + batch_size]
            inputs = self.tokenizer(
                sentences_batch,
                padding=True,
                truncation=True,
                return_tensors='pt',
                max_length=max_length,
            ).to(self.device)

            # Flatten logits to one value per pair; cast to float32 in case
            # the model runs in half precision.
            scores = self.model(**inputs, return_dict=True).logits.view(-1, ).float()
            all_scores.extend(scores.cpu().numpy().tolist())

        def sigmoid(x):
            return 1 / (1 + np.exp(-x))

        return sigmoid(np.array(all_scores))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
# -*- coding: utf-8 -*- | ||
|
||
import pytest | ||
from gomate.applications import RerankerApp | ||
# import os | ||
|
||
def test_reranker():
    """Smoke-test RerankerApp with the 'bge_large' component.

    Builds the app, scores two passages against one query, and checks that
    one score per passage comes back and that each score lies in [0, 1].
    """
    component_name = 'bge_large'
    model = RerankerApp(component_name=component_name)
    query = "恐龙是怎么被命名的?"
    contexts = [
        "[12]恐龙是 介于冷血和温血之间的动物2014年6月,有关恐龙究竟是像鸟类和哺乳动物一样的温血动物,还是类似爬行动物、鱼类和两栖动物的冷血动物的问题终于有了答案——恐龙其实是介于冷血和温血之间的动物。 [12]“我们的结果显示恐龙所具有的生长速率和新陈代谢速率,既不是冷血生物体也不是温血生物体所具有的特征。它们既不像哺乳动物或者鸟类,也不像爬行动物或者鱼类,而是介于现代冷血动物和温血动物之间。简言之,它们的生理机能在现代社会并不常见。”美国亚利桑那大学进化生物学家和生态学家布莱恩·恩奎斯特说。墨西哥生物学家表示,正是这种中等程度的新陈代谢使得恐龙可以长得比任何哺乳动物都要大。温血动物需要大量进食,因此它们频繁猎捕和咀嚼植物。“很难想象霸王龙大小的狮子能够吃饱以 存活下来。",
        "[12]哺乳动物起源于爬行动物,它们的前身是“似哺乳类的爬行动物”,即兽孔目,早期则是“似爬行类的哺乳动物”,即哺乳型动物。 [12]中生代的爬行动物,大部分在中生代的末期灭绝了;一部分适应了变化的环境被保留下来,即现存的爬行动物(如龟鳖类、蛇类、鳄类等);还有一部分沿着不同的进化方向,进化成了现今的鸟类和哺乳类。 [12]恐龙是 介于冷血和温血之间的动物2014年6月,有关恐龙究竟是像鸟类和哺乳动物一样的温血动物,还是类似爬行动物、鱼类和两栖动物的冷血动物的问题终于有了答案——恐龙其实是介于冷血和温血之间的动物。",
    ]
    probabilities = model.run(query, contexts)
    assert probabilities is not None
    # A bare "not None" check would pass even for a wrongly shaped result.
    assert len(probabilities) == len(contexts)
    assert all(0.0 <= float(p) <= 1.0 for p in probabilities)


# Allow running this check directly as a script, without invoking pytest.
if __name__ == '__main__':
    test_reranker()