Init Pairwise Rerankers #133

Open · wants to merge 7 commits into base: main
2 changes: 2 additions & 0 deletions src/rank_llm/rerank/listwise/listwise_rankllm.py
@@ -17,6 +17,8 @@

class ListwiseRankLLM(RankLLM, ABC):
"""
Abstract base class that all listwise rerankers inherit.

All children of ListwiseRankLLM must implement these functions:
- rerank_batched
- run_llm_batched
3 changes: 3 additions & 0 deletions src/rank_llm/rerank/pairwise/__init__.py
@@ -0,0 +1,3 @@
from .duot5 import DuoT5

__all__ = ["DuoT5"]
132 changes: 132 additions & 0 deletions src/rank_llm/rerank/pairwise/duot5.py
@@ -0,0 +1,132 @@
import logging
import math
from typing import List, Tuple

from transformers import T5ForConditionalGeneration, T5Tokenizer
from transformers.generation import GenerationConfig

from rank_llm.data import Result
from rank_llm.rerank.pairwise.pairwise_rankllm import PairwiseRankLLM

logger = logging.getLogger(__name__)


class DuoT5(PairwiseRankLLM):
def __init__(
self,
model: str,
prompt_mode: str = "duot5",
context_size: int = 512,
device: str = "cuda",
batch_size: int = 32,
):
super().__init__(
model=model,
context_size=context_size,
prompt_mode=prompt_mode,
device=device,
batch_size=batch_size,
)

self._tokenizer = T5Tokenizer.from_pretrained(model)
self._llm = T5ForConditionalGeneration.from_pretrained(model).to(self._device)
self._context_size = context_size

def run_llm_batched(
self,
prompts: List[str],
) -> Tuple[List[str], List[int], List[float]]:
gen_cfg = GenerationConfig.from_model_config(self._llm.config)
gen_cfg.max_new_tokens = self.num_output_tokens()
gen_cfg.min_new_tokens = self.num_output_tokens()
gen_cfg.output_scores = True
gen_cfg.return_dict_in_generate = True
gen_cfg.do_sample = False

all_outputs = []
all_output_token_counts = []
all_scores = []

batch_prompts = prompts

token_prompts = self._tokenizer(
batch_prompts, padding=True, truncation=True, return_tensors="pt"
).to(self._device)

token_prompts = token_prompts["input_ids"]

batch_outputs = self._llm.generate(token_prompts, generation_config=gen_cfg)

batch_output_ids = batch_outputs.sequences
batch_logits = batch_outputs.scores

batch_outputs = [
self._tokenizer.decode(
single_token_sequence,
skip_special_tokens=True,
spaces_between_special_tokens=False,
)
for single_token_sequence in batch_output_ids
]

for logit_tensor in batch_logits[0]:
# Token ids 1176 and 6136 are the T5 vocabulary ids for "true" and "false";
# the score is the softmax probability assigned to "true".
truth_logit = logit_tensor[1176]
false_logit = logit_tensor[6136]
score = math.exp(truth_logit) / (
math.exp(truth_logit) + math.exp(false_logit)
)
all_scores.append(score)
all_output_token_counts.append(self.num_output_tokens())

all_outputs.extend(batch_outputs)

return all_outputs, all_output_token_counts, all_scores

def run_llm(self, prompt: str) -> Tuple[str, int, float]:
gen_cfg = GenerationConfig.from_model_config(self._llm.config)
gen_cfg.max_new_tokens = self.num_output_tokens()
gen_cfg.min_new_tokens = self.num_output_tokens()
gen_cfg.output_scores = True
gen_cfg.return_dict_in_generate = True
gen_cfg.do_sample = False

token_prompt = self._tokenizer.encode(prompt, return_tensors="pt").to(
self._device
)
output = self._llm.generate(token_prompt, generation_config=gen_cfg)
output_ids = output.sequences
logits = output.scores

if self._llm.config.is_encoder_decoder:
output_ids = output_ids[0]
output_ids = output_ids[1:]

outputs = self._tokenizer.decode(
output_ids, skip_special_tokens=True, spaces_between_special_tokens=False
)
# Softmax over the "true" (1176) and "false" (6136) token logits from the single decoding step.
truth_logit = logits[0][0][1176]
false_logit = logits[0][0][6136]
score = math.exp(truth_logit) / (math.exp(truth_logit) + math.exp(false_logit))

return outputs, output_ids.size(0), score

def num_output_tokens(self) -> int:
return 1

def create_prompt(self, result: Result, index1: int, index2: int) -> Tuple[str, int]:
query = result.query.text
query = self._replace_number(query)
doc1 = self.convert_doc_to_prompt_content(result.candidates[index1].doc, max_length=self._context_size)
doc2 = self.convert_doc_to_prompt_content(result.candidates[index2].doc, max_length=self._context_size)
# Truncate each passage to 240 tokens; the [:-4] slice drops the trailing "</s>" that decode() appends.
doc1 = self._tokenizer.decode(self._tokenizer.encode(doc1)[:240])[:-4]
doc2 = self._tokenizer.decode(self._tokenizer.encode(doc2)[:240])[:-4]
Comment on lines +121 to +122 (Member): The 240 and -4 are a bit hardcoded here? (A less hardcoded alternative is sketched after this file's diff.)

prompt = f"Query: {query} Document0: {doc1} Document1: {doc2} Relevant:"
prompt = prompt.replace("<unk>","")

return prompt, self.get_num_tokens(prompt)

def get_num_tokens(self, prompt: str) -> int:
return len(self._tokenizer.encode(prompt))

def cost_per_1k_token(self, input_token: bool) -> float:
return 0
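
On the review comment above about the hardcoded 240 and -4 in create_prompt: one way to remove the magic numbers is to truncate by a token budget and let the tokenizer strip special tokens on decode, rather than slicing characters. A minimal sketch under those assumptions (the helper name and the budget split are hypothetical, not part of this PR):

def _truncate_doc(self, doc: str, max_doc_tokens: int) -> str:
    # Hypothetical helper: truncate by token count and decode with
    # skip_special_tokens=True so no trailing "</s>" needs to be sliced off.
    token_ids = self._tokenizer.encode(doc, truncation=True, max_length=max_doc_tokens)
    return self._tokenizer.decode(token_ids, skip_special_tokens=True)

# Assumed budget split: reserve tokens for the query and prompt template,
# then share the remaining context equally between the two documents, e.g.
# max_doc_tokens = (self._context_size - query_and_template_tokens) // 2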
208 changes: 208 additions & 0 deletions src/rank_llm/rerank/pairwise/pairwise_rankllm.py
@@ -0,0 +1,208 @@
import copy
import logging
import math
import re
from abc import ABC
from datetime import datetime
from functools import cmp_to_key
from typing import Any, Dict, List, Tuple

from ftfy import fix_text
from tqdm import tqdm

from rank_llm.data import Candidate, Request, Result
from rank_llm.rerank.rankllm import PromptMode, RankLLM

logger = logging.getLogger(__name__)

class PairwiseRankLLM(RankLLM, ABC):
"""
Abstract base class from which all pairwise rerankers inherit.

All concrete children of PairwiseRankLLM must implement these functions:
- rerank_batch
- run_llm_batched
- run_llm
- create_prompt_batched
- create_prompt
- get_num_tokens
- cost_per_1k_token
- num_output_tokens
"""

def __init__(
self,
model: str,
context_size: int,
prompt_mode: PromptMode,
device: str = "cuda",
filename: str = "",
batch_size: int = 32,
) -> None:
super().__init__(model, context_size, prompt_mode)
self._device = device
self._filename = filename
self._batch_size = batch_size

def rerank_batch(
self,
requests: List[Request],
rank_start: int = 0,
rank_end: int = 100,
shuffle_candidates: bool = False,
logging: bool = False,
**kwargs: Any,
) -> List[Result]:
self._enumerated_indices = []

rerank_results = [
Result(
query=copy.deepcopy(request.query),
candidates=copy.deepcopy(request.candidates),
ranking_exec_summary=[],
)
for request in requests
]

for result in rerank_results:
for i in result.candidates:
i.score = 0

# Enumerate flat indices over (query, candidate_1, candidate_2) triples, where
# index = query * n^2 + candidate_1 * n + candidate_2 and n is the number of
# candidates per query; pairs with candidate_1 == candidate_2 are skipped.
for index in range(len(requests) * len(requests[0].candidates) * len(requests[0].candidates)):
candidate_1 = math.floor(
(index % (len(rerank_results[0].candidates) ** 2)) / len(rerank_results[0].candidates)
)
candidate_2 = index % len(rerank_results[0].candidates)
if candidate_1 != candidate_2:
self._enumerated_indices.append(index)

# end is the total number of ordered, distinct candidate pairs across all queries.
end = (len(rerank_results[0].candidates) - 1) * len(rerank_results[0].candidates) * len(requests)
with tqdm(total=end, desc="Progress through (q, d) pairs") as progress_bar:
for index in range(0, end, self._batch_size):
prompts, token_counts = self.create_prompt_batched(
results=rerank_results, index=index
)

outputs, output_tokens, scores = self.run_llm_batched(prompts=prompts)

for update_index in range(index, min(index + self._batch_size, end)):
update_index_copy = self._enumerated_indices[update_index]
query_number = math.floor(
update_index_copy / (len(rerank_results[0].candidates) ** 2)
)
candidate_1 = math.floor(
(update_index_copy % (len(rerank_results[0].candidates) ** 2)) / len(rerank_results[0].candidates)
)
candidate_2 = update_index_copy % len(rerank_results[0].candidates)

rerank_results[query_number].candidates[candidate_1].score += scores[update_index - index]
rerank_results[query_number].candidates[candidate_2].score += 1 - scores[update_index - index]

if index + self._batch_size > end:
progress_bar.update(end - index)
else:
progress_bar.update(self._batch_size)


for result in rerank_results:
result.candidates.sort(
key=cmp_to_key(self.candidate_comparator), reverse=True
)

return rerank_results

def create_prompt_batched(
self, results: List[Result], index: int
) -> Tuple[List[str], List[int]]:
prompts = []
token_counts = []

for current_index in range(
index,
min(index + self._batch_size, len(results[0].candidates) * (len(results[0].candidates) - 1) * len(results)),
):
current_index = self._enumerated_indices[current_index]
query_number = math.floor(
current_index / (len(results[0].candidates) ** 2)
)
candidate_1 = math.floor(
(current_index % (len(results[0].candidates) ** 2)) / len(results[0].candidates)
)
candidate_2 = current_index % len(results[0].candidates)

prompt, token_count = self.create_prompt(
result=results[query_number], index1=candidate_1, index2=candidate_2
)

prompts.append(prompt)
token_counts.append(token_count)
return prompts, token_counts

def get_output_filename(
self,
top_k_candidates: int,
dataset_name: str,
shuffle_candidates: bool,
**kwargs: Any,
) -> str:
if self._filename != "":
return self._filename
_modelname = self._model.split("/")[-1]
if _modelname.startswith("checkpoint"):
_modelname = self._model.split("/")[-2] + "_" + _modelname
name = (
f"{_modelname}_{self._context_size}_{top_k_candidates}_{self._prompt_mode}"
)
if dataset_name:
name = f"{name}_{dataset_name}"

if shuffle_candidates:
self._filename = f"{name}_shuffled_{datetime.isoformat(datetime.now())}"
else:
self._filename = f"{name}_{datetime.isoformat(datetime.now())}"

return self._filename

def candidate_comparator(self, x: Candidate, y: Candidate) -> int:
if x.score < y.score:
return -1
elif x.score > y.score:
return 1
else:
return 0

def _replace_number(self, s: str) -> str:
return re.sub(r"\[(\d+)\]", r"(\1)", s)

def convert_doc_to_prompt_content(
self, doc: Dict[str, Any], max_length: int
) -> str:
if "text" in doc:
content = doc["text"]
elif "segment" in doc:
content = doc["segment"]
elif "contents" in doc:
content = doc["contents"]
elif "content" in doc:
content = doc["content"]
elif "body" in doc:
content = doc["body"]
else:
content = doc["passage"]
if "title" in doc and doc["title"]:
content = "Title: " + doc["title"] + " " + "Content: " + content
content = content.strip()
content = fix_text(content)
# For Japanese should cut by character: content = content[:int(max_length)]
content = " ".join(content.split()[: int(max_length)])
return self._replace_number(content)
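
For readers following the flat pair enumeration used in rerank_batch and create_prompt_batched above: each flat index packs (query, candidate_1, candidate_2) in base n, where n is the number of candidates per query, and each candidate accumulates the pairwise "true" probabilities before the final sort. A small standalone sketch of the decoding with a worked example (illustrative only, mirroring the arithmetic in the code above):

import math

def decode_pair_index(index: int, num_candidates: int) -> tuple:
    # index = query * n^2 + candidate_1 * n + candidate_2
    query_number = math.floor(index / (num_candidates ** 2))
    candidate_1 = math.floor((index % (num_candidates ** 2)) / num_candidates)
    candidate_2 = index % num_candidates
    return query_number, candidate_1, candidate_2

# With 100 candidates per query, flat index 1 * 100**2 + 2 * 100 + 3 = 10203
# decodes to query 1 and the ordered candidate pair (2, 3).
assert decode_pair_index(10203, 100) == (1, 2, 3)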
4 changes: 2 additions & 2 deletions src/rank_llm/rerank/pointwise/__init__.py
@@ -1,3 +1,3 @@
-from .pointwise_rankllm import PointwiseRankLLM
+from .monot5 import MonoT5

-__all__ = ["PointwiseRankLLM"]
+__all__ = ["MonoT5"]