Init Pairwise Rerankers #133
Open
xpbowler wants to merge 7 commits into castorini:main from xpbowler:pairwise
Changes from all commits (7 commits)
fef2f50  fix (xpbowler)
9cb9612  inits, cleanup (xpbowler)
912de1a  duot5 and pairwise implementation (IR3KT4FUNZ)
b1d7eff  implementation for pairwise and duot5 (IR3KT4FUNZ)
b35f252  duot5 bug fixes, add fix for retrieving <100 candidates in non-intera… (IR3KT4FUNZ)
0950167  remove temporarily unnecessary interactive argument (IR3KT4FUNZ)
5604a61  fix enumeration bug (IR3KT4FUNZ)
@@ -0,0 +1,3 @@
from .duot5 import DuoT5

__all__ = ["DuoT5"]
@@ -0,0 +1,132 @@
import logging
import math
from typing import List, Tuple

from transformers import T5ForConditionalGeneration, T5Tokenizer
from transformers.generation import GenerationConfig

from rank_llm.data import Result
from rank_llm.rerank.pairwise.pairwise_rankllm import PairwiseRankLLM

logger = logging.getLogger(__name__)


class DuoT5(PairwiseRankLLM):
    def __init__(
        self,
        model: str,
        prompt_mode: str = "duot5",
        context_size: int = 512,
        device: str = "cuda",
        batch_size: int = 32,
    ):
        super().__init__(
            model=model,
            context_size=context_size,
            prompt_mode=prompt_mode,
            device=device,
            batch_size=batch_size,
        )

        self._tokenizer = T5Tokenizer.from_pretrained(model)
        self._llm = T5ForConditionalGeneration.from_pretrained(model).to(self._device)
        self._context_size = context_size

    def run_llm_batched(
        self,
        prompts: List[str],
    ) -> Tuple[List[str], List[int], List[float]]:
        gen_cfg = GenerationConfig.from_model_config(self._llm.config)
        gen_cfg.max_new_tokens = self.num_output_tokens()
        gen_cfg.min_new_tokens = self.num_output_tokens()
        gen_cfg.output_scores = True
        gen_cfg.return_dict_in_generate = True
        gen_cfg.do_sample = False

        all_outputs = []
        all_output_token_counts = []
        all_scores = []

        token_prompts = self._tokenizer(
            prompts, padding=True, truncation=True, return_tensors="pt"
        ).to(self._device)
        token_prompts = token_prompts["input_ids"]

        batch_outputs = self._llm.generate(token_prompts, generation_config=gen_cfg)

        batch_output_ids = batch_outputs.sequences
        batch_logits = batch_outputs.scores

        batch_outputs = [
            self._tokenizer.decode(
                single_token_sequence,
                skip_special_tokens=True,
                spaces_between_special_tokens=False,
            )
            for single_token_sequence in batch_output_ids
        ]

        # batch_logits[0] holds the logits for the single generated token of each
        # prompt in the batch; 1176 and 6136 are the T5 vocabulary ids of the
        # "true" and "false" tokens.
        for logit_tensor in batch_logits[0]:
            truth_logit = logit_tensor[1176]
            false_logit = logit_tensor[6136]
            score = math.exp(truth_logit) / (
                math.exp(truth_logit) + math.exp(false_logit)
            )
            all_scores.append(score)
            all_output_token_counts.append(self.num_output_tokens())

        all_outputs.extend(batch_outputs)

        return all_outputs, all_output_token_counts, all_scores

    def run_llm(self, prompt: str) -> Tuple[str, int, float]:
        gen_cfg = GenerationConfig.from_model_config(self._llm.config)
        gen_cfg.max_new_tokens = self.num_output_tokens()
        gen_cfg.min_new_tokens = self.num_output_tokens()
        gen_cfg.output_scores = True
        gen_cfg.return_dict_in_generate = True
        gen_cfg.do_sample = False

        token_prompt = self._tokenizer.encode(prompt, return_tensors="pt").to(
            self._device
        )
        output = self._llm.generate(token_prompt, generation_config=gen_cfg)
        output_ids = output.sequences
        logits = output.scores

        if self._llm.config.is_encoder_decoder:
            output_ids = output_ids[0]
            output_ids = output_ids[1:]

        outputs = self._tokenizer.decode(
            output_ids, skip_special_tokens=True, spaces_between_special_tokens=False
        )
        # Softmax over the "true" (1176) and "false" (6136) logits of the single
        # generated token.
        truth_logit = logits[0][0][1176]
        false_logit = logits[0][0][6136]
        score = math.exp(truth_logit) / (math.exp(truth_logit) + math.exp(false_logit))

        return outputs, output_ids.size(0), score

    def num_output_tokens(self) -> int:
        return 1

    def create_prompt(
        self, result: Result, index1: int, index2: int
    ) -> Tuple[str, int]:
        query = result.query.text
        query = self._replace_number(query)
        doc1 = self.convert_doc_to_prompt_content(
            result.candidates[index1].doc, max_length=self._context_size
        )
        doc2 = self.convert_doc_to_prompt_content(
            result.candidates[index2].doc, max_length=self._context_size
        )
        doc1 = self._tokenizer.decode(self._tokenizer.encode(doc1)[:240])[:-4]
        doc2 = self._tokenizer.decode(self._tokenizer.encode(doc2)[:240])[:-4]
        prompt = f"Query: {query} Document0: {doc1} Document1: {doc2} Relevant:"
        prompt = prompt.replace("<unk>", "")

        return prompt, self.get_num_tokens(prompt)

    def get_num_tokens(self, prompt: str) -> int:
        return len(self._tokenizer.encode(prompt))

    def cost_per_1k_token(self, input_token: bool) -> float:
        return 0
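The class can be exercised on its own with a hand-built prompt in the format create_prompt produces. A minimal sketch follows; the checkpoint name castorini/duot5-base-msmarco and the example strings are assumptions for illustration, not part of this PR.

from rank_llm.rerank.pairwise import DuoT5

# Assumed checkpoint name for illustration; any duoT5-style T5 checkpoint
# trained to emit "true"/"false" should behave the same way.
reranker = DuoT5(model="castorini/duot5-base-msmarco", device="cuda", batch_size=8)

# One pairwise prompt in the same format create_prompt builds.
prompt = (
    "Query: what do pairwise rerankers do "
    "Document0: Pairwise rerankers compare two candidate documents at a time. "
    "Document1: The weather in Toronto is mild in the spring. "
    "Relevant:"
)

output, token_count, score = reranker.run_llm(prompt)
# score is the softmax probability of "true" over "false", i.e. how strongly
# the model prefers Document0 over Document1 for this query.
print(output, token_count, score)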
@@ -0,0 +1,208 @@
import copy
import logging
import math
import re
from abc import ABC
from datetime import datetime
from functools import cmp_to_key
from typing import Any, Dict, List, Tuple

from ftfy import fix_text
from tqdm import tqdm

from rank_llm.data import Candidate, Request, Result
from rank_llm.rerank.rankllm import PromptMode, RankLLM

logger = logging.getLogger(__name__)


class PairwiseRankLLM(RankLLM, ABC):
    """
    Abstract base class that all pairwise rerankers implement.

    All concrete children of RankLLM must implement these functions:
      - rerank_batch
      - run_llm_batched
      - run_llm
      - create_prompt_batched
      - create_prompt
      - get_num_tokens
      - cost_per_1k_tokens
      - num_output_tokens
    """

    def __init__(
        self,
        model: str,
        context_size: int,
        prompt_mode: PromptMode,
        device: str = "cuda",
        filename: str = "",
        batch_size: int = 32,
    ) -> None:
        super().__init__(model, context_size, prompt_mode)
        self._device = device
        self._filename = filename
        self._batch_size = batch_size

    def rerank_batch(
        self,
        requests: List[Request],
        rank_start: int = 0,
        rank_end: int = 100,
        shuffle_candidates: bool = False,
        logging: bool = False,
        **kwargs: Any,
    ) -> List[Result]:
        self._enumerated_indices = []

        rerank_results = [
            Result(
                query=copy.deepcopy(request.query),
                candidates=copy.deepcopy(request.candidates),
                ranking_exec_summary=[],
            )
            for request in requests
        ]

        for result in rerank_results:
            for candidate in result.candidates:
                candidate.score = 0

        # Enumerate every ordered (candidate_1, candidate_2) pair for every query,
        # skipping pairs that would compare a candidate with itself. This assumes
        # every request has the same number of candidates.
        num_candidates = len(rerank_results[0].candidates)
        for index in range(len(requests) * num_candidates * num_candidates):
            candidate_1 = math.floor((index % (num_candidates**2)) / num_candidates)
            candidate_2 = index % num_candidates
            if candidate_1 != candidate_2:
                self._enumerated_indices.append(index)

        end = (num_candidates - 1) * num_candidates * len(requests)
        with tqdm(total=end, desc="Progress through (q, d) pairs") as progress_bar:
            for index in range(0, end, self._batch_size):
                prompts, token_counts = self.create_prompt_batched(
                    results=rerank_results, index=index
                )

                outputs, output_tokens, scores = self.run_llm_batched(prompts=prompts)

                for update_index in range(index, min(index + self._batch_size, end)):
                    enumerated_index = self._enumerated_indices[update_index]
                    query_number = math.floor(enumerated_index / (num_candidates**2))
                    candidate_1 = math.floor(
                        (enumerated_index % (num_candidates**2)) / num_candidates
                    )
                    candidate_2 = enumerated_index % num_candidates

                    # Credit the "true" probability to the first document of the
                    # pair and its complement to the second.
                    rerank_results[query_number].candidates[candidate_1].score += scores[
                        update_index - index
                    ]
                    rerank_results[query_number].candidates[candidate_2].score += (
                        1 - scores[update_index - index]
                    )

                progress_bar.update(min(self._batch_size, end - index))

        for result in rerank_results:
            result.candidates.sort(
                key=cmp_to_key(self.candidate_comparator), reverse=True
            )

        return rerank_results

    def create_prompt_batched(
        self, results: List[Result], index: int
    ) -> Tuple[List[str], List[int]]:
        prompts = []
        token_counts = []

        num_candidates = len(results[0].candidates)
        end = num_candidates * (num_candidates - 1) * len(results)
        for current_index in range(index, min(index + self._batch_size, end)):
            enumerated_index = self._enumerated_indices[current_index]
            query_number = math.floor(enumerated_index / (num_candidates**2))
            candidate_1 = math.floor(
                (enumerated_index % (num_candidates**2)) / num_candidates
            )
            candidate_2 = enumerated_index % num_candidates

            prompt, token_count = self.create_prompt(
                result=results[query_number], index1=candidate_1, index2=candidate_2
            )

            prompts.append(prompt)
            token_counts.append(token_count)
        return prompts, token_counts

    def get_output_filename(
        self,
        top_k_candidates: int,
        dataset_name: str,
        shuffle_candidates: bool,
        **kwargs: Any,
    ) -> str:
        if self._filename != "":
            return self._filename
        _modelname = self._model.split("/")[-1]
        if _modelname.startswith("checkpoint"):
            _modelname = self._model.split("/")[-2] + "_" + _modelname
        name = (
            f"{_modelname}_{self._context_size}_{top_k_candidates}_{self._prompt_mode}"
        )
        if dataset_name:
            name = f"{name}_{dataset_name}"

        if shuffle_candidates:
            self._filename = f"{name}_shuffled_{datetime.isoformat(datetime.now())}"
        else:
            self._filename = f"{name}_{datetime.isoformat(datetime.now())}"

        return self._filename

    def candidate_comparator(self, x: Candidate, y: Candidate) -> int:
        if x.score < y.score:
            return -1
        elif x.score > y.score:
            return 1
        else:
            return 0

    def _replace_number(self, s: str) -> str:
        return re.sub(r"\[(\d+)\]", r"(\1)", s)

    def convert_doc_to_prompt_content(
        self, doc: Dict[str, Any], max_length: int
    ) -> str:
        if "text" in doc:
            content = doc["text"]
        elif "segment" in doc:
            content = doc["segment"]
        elif "contents" in doc:
            content = doc["contents"]
        elif "content" in doc:
            content = doc["content"]
        elif "body" in doc:
            content = doc["body"]
        else:
            content = doc["passage"]
        if "title" in doc and doc["title"]:
            content = "Title: " + doc["title"] + " " + "Content: " + content
        content = content.strip()
        content = fix_text(content)
        # For Japanese should cut by character: content = content[:int(max_length)]
        content = " ".join(content.split()[: int(max_length)])
        return self._replace_number(content)
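The index arithmetic above is easier to check in isolation. Below is a standalone sketch of the mapping used by rerank_batch and create_prompt_batched; it is illustrative only and assumes, as the code does, that every request has the same number of candidates.

import math

def decode_index(index: int, num_candidates: int):
    # A flat index packs (query_number, candidate_1, candidate_2): there are
    # num_candidates ** 2 slots per query and num_candidates slots per candidate_1.
    query_number = math.floor(index / (num_candidates ** 2))
    candidate_1 = math.floor((index % (num_candidates ** 2)) / num_candidates)
    candidate_2 = index % num_candidates
    return query_number, candidate_1, candidate_2

# With 3 candidates per query the diagonal pairs (0,0), (1,1), (2,2) are filtered
# out, leaving 3 * 2 = 6 ordered pairs per query -- which is why rerank_batch
# uses end = (n - 1) * n * len(requests).
assert decode_index(11, 3) == (1, 0, 2)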
@@ -1,3 +1,3 @@
-from .pointwise_rankllm import PointwiseRankLLM
+from .monot5 import MonoT5

-__all__ = ["PointwiseRankLLM"]
+__all__ = ["MonoT5"]
Review comment:
The 240 and -4 are a bit hardcoded here?
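For context, the constants come from create_prompt in DuoT5 above: each document is re-encoded, cut to its first 240 token ids (presumably so that two documents plus the query fit the 512-token default context), and decoded again, after which [:-4] strips the 4-character "</s>" string that T5Tokenizer.decode appends. A small sketch of that round trip, with t5-base standing in for the actual checkpoint:

from transformers import T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-base")  # stand-in checkpoint

doc = "Pairwise rerankers compare two candidate documents at a time."
ids = tokenizer.encode(doc)          # T5 appends the </s> id (1) at the end
text = tokenizer.decode(ids[:240])   # for short docs the </s> survives the cut
print(text[:-4])                     # drops the trailing "</s>"

Note that for documents longer than 240 tokens the </s> id is already truncated away before decoding, so the [:-4] then removes four characters of real text, which may be part of what the comment is pointing at.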