From fef2f5055908f2f43a1befcd58f3b3a66fd000df Mon Sep 17 00:00:00 2001 From: Ryan Nguyen Date: Sun, 25 Aug 2024 20:35:56 +0000 Subject: [PATCH 1/7] fix --- src/rank_llm/rerank/pairwise/__init__.py | 0 src/rank_llm/rerank/pairwise/duot5.py | 208 ++++++++++++++++++ .../rerank/pairwise/pairwise_rankllm.py | 20 ++ 3 files changed, 228 insertions(+) create mode 100644 src/rank_llm/rerank/pairwise/__init__.py create mode 100644 src/rank_llm/rerank/pairwise/duot5.py create mode 100644 src/rank_llm/rerank/pairwise/pairwise_rankllm.py diff --git a/src/rank_llm/rerank/pairwise/__init__.py b/src/rank_llm/rerank/pairwise/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/rank_llm/rerank/pairwise/duot5.py b/src/rank_llm/rerank/pairwise/duot5.py new file mode 100644 index 00000000..8474f9b1 --- /dev/null +++ b/src/rank_llm/rerank/pairwise/duot5.py @@ -0,0 +1,208 @@ +import logging +import math +from concurrent.futures import ThreadPoolExecutor, as_completed +from functools import cmp_to_key +from typing import Dict, List, Optional, Tuple + +import torch +from tqdm import tqdm +from transformers import T5ForConditionalGeneration, T5Tokenizer +from transformers.generation import GenerationConfig + +from rank_llm.data import Candidate, Result +from rank_llm.rerank.pairwise.pairwise_rankllm import PairwiseRankLLM + +logger = logging.getLogger(__name__) + + +class DuoT5(PairwiseRankLLM): + def __init__( + self, + model: str, + device: str = "cuda", + window_size: int = 20, + batched: bool = False, + ): + super.__init(model, device, window_size, batched) + self._tokenizer = T5Tokenizer.from_pretrained("castorini/duot5-base-msmarco") + self._llm = T5ForConditionalGeneration.from_pretrained( + "castorini/duot5-base-msmarco" + ).to(self._device) + + def run_llm_batched( + self, + prompts: List[str | List[Dict[str, str]]], + current_window_size: Optional[int] = None, + ) -> List[Tuple[str, int]]: + if SamplingParams is None: + raise ImportError( + "Please install rank-llm with `pip install rank-llm[vllm]` to use batch inference." 
+ ) + logger.info(f"VLLM Generating!") + sampling_params = SamplingParams( + temperature=0.0, + max_tokens=self.num_output_tokens(current_window_size), + min_tokens=self.num_output_tokens(current_window_size), + ) + outputs = self._llm.generate(prompts, sampling_params) + + return [ + (output.outputs[0].text, len(output.outputs[0].token_ids)) + for output in outputs + ] + + def run_llm( + self, prompt: str, current_window_size: Optional[int] = None + ) -> Tuple[str, int, float]: + # CHANGE THIS CODE + if current_window_size is None: + current_window_size = self._window_size + inputs = self._tokenizer([prompt]) + inputs = {k: torch.tensor(v).to(self._device) for k, v in inputs.items()} + gen_cfg = GenerationConfig.from_model_config(self._llm.config) + gen_cfg.max_new_tokens = self.num_output_tokens() + gen_cfg.min_new_tokens = self.num_output_tokens() + gen_cfg.decoder_start_token_id = None + gen_cfg.output_scores = True + gen_cfg.return_dict_in_generate = True + # gen_cfg.temperature = 0 + gen_cfg.do_sample = False + token_prompt = self._tokenizer.encode(prompt, return_tensors="pt").to( + self._device + ) + output = self._llm.generate(token_prompt, generation_config=gen_cfg) + output_ids = output.sequences + logits = output.scores + + if self._llm.config.is_encoder_decoder: + output_ids = output_ids[0] + output_ids = output_ids[1:] + + outputs = self._tokenizer.decode( + output_ids, skip_special_tokens=True, spaces_between_special_tokens=False + ) + truth_logit = logits[0][0][1176] + false_logit = logits[0][0][6136] + score = math.exp(truth_logit) / (math.exp(truth_logit) + math.exp(false_logit)) + # print(outputs, output_ids.size(0)) + return outputs, output_ids.size(0), score + + def num_output_tokens(self, current_window_size: Optional[int] = None) -> int: + return 1 + + def _add_prefix_prompt(self, query: str, num: int) -> str: + return f"Given the query: {query}, output its relevance to the {num} documents." + + def _add_post_prompt(self, query: str, num: int) -> str: + return f"Given the query: {query}, output its relevance to the {num} documents." 
+ + def _add_few_shot_examples(self, conv): + return 1 + # unused for now + + def create_prompt( + self, result: Result, rank_start: int, rank_end: int + ) -> Tuple[str, int]: + # query = result.query.text + # query = self._replace_number(query) + # input = f"Query: {query} Document: {result.candidates[rank_start].doc['contents']}" + # prompt = self._tokenizer.decode(self._tokenizer.encode(input)[:480])[:-4] + " Relevant: " + # prompt = prompt.replace("","") + + # CHANGE THIS CODE + query = result.query.text + query = self._replace_number(query) + doc1 = result.candidates[rank_start].doc["contents"] + doc2 = result.candidates[rank_end].doc["contents"] + doc1 = self._tokenizer.decode(self._tokenizer.encode(doc1)[:240])[:-4] + doc2 = self._tokenizer.decode(self._tokenizer.encode(doc2)[:240])[:-4] + prompt = f"Query: {query} Document0: {doc1} Document1: {doc2} Relevant:" + prompt = prompt.replace("", "") + return prompt, self.get_num_tokens(prompt) + + def create_prompt_batched( + self, + results: List[Result], + rank_start: int, + rank_end: int, + batch_size: int = 32, + ) -> List[Tuple[str, int]]: + def chunks(lst, n): + """Yield successive n-sized chunks from lst.""" + for i in range(0, len(lst), n): + yield lst[i : i + n] + + all_completed_prompts = [] + + with ThreadPoolExecutor() as executor: + for batch in tqdm(chunks(results, batch_size), desc="Processing batches"): + futures = [ + executor.submit(self.create_prompt, result, rank_start, rank_end) + for result in batch + ] + completed_prompts = [ + future.result() for future in as_completed(futures) + ] + all_completed_prompts.extend(completed_prompts) + return all_completed_prompts + + def get_num_tokens(self, prompt: str) -> int: + return len(self._tokenizer.encode(prompt)) + + def cost_per_1k_token(self, input_token: bool) -> float: + return 0 + + def candidate_comparator(self, x: Candidate, y: Candidate) -> int: + if x.score < y.score: + return -1 + elif x.score > y.score: + return 1 + else: + return 0 + + def permutation_pipeline( + self, + result: Result, + rank_start: int, + rank_end: int, + logging: bool = False, + ) -> Result: + """ + Runs the permutation pipeline on the passed in result set within the passed in rank range. + + Args: + result (Result): The result object to process. + rank_start (int): The start index for ranking. + rank_end (int): The end index for ranking. + logging (bool, optional): Flag to enable logging of operations. Defaults to False. + + Returns: + Result: The processed result object after applying permutation. 
+ """ + # CHANGE THIS CODE + # print(len(result.candidates)) + # for i in range (len(result.candidates)): + # prompt, num_tokens = self.create_prompt(result, i, rank_end) + # output, output_num_tokens, score = self.run_llm(prompt=prompt) + # (result.candidates[i]).score = score + + # result.candidates.sort(key=cmp_to_key(self.candidate_comparator)) + n = len(result.candidates) + scores = [0 for _ in range(n)] + for i in range(n): + for j in range(n): + if j == i: + continue + else: + prompt1, num_tokens1 = self.create_prompt(result, i, j) + prompt2, num_tokens2 = self.create_prompt(result, j, i) + _, _, pi_j = self.run_llm(prompt=prompt1) + _, _, pj_i = self.run_llm(prompt=prompt2) + scores[i] = scores[i] + pi_j + 1 - pj_i + + for i in range(n): + (result.candidates[i]).score = scores[i] + + result.candidates.sort(key=cmp_to_key(self.candidate_comparator)) + + return result diff --git a/src/rank_llm/rerank/pairwise/pairwise_rankllm.py b/src/rank_llm/rerank/pairwise/pairwise_rankllm.py new file mode 100644 index 00000000..90242c7e --- /dev/null +++ b/src/rank_llm/rerank/pairwise/pairwise_rankllm.py @@ -0,0 +1,20 @@ +import logging +from abc import ABC + +from rank_llm.rerank.rankllm import RankLLM + +logger = logging.getLogger(__name__) + + +class PairwiseRankLLM(RankLLM, ABC): + def __init__( + self, + model: str, + device: str = "cuda", + window_size: int = 20, + batched: bool = False, + ) -> None: + super.__init__(model) + self._window_size = window_size + self._device = device + self._batched = batched From 9cb96123f459a7f182d13fcfda2d33f332b880aa Mon Sep 17 00:00:00 2001 From: Ryan Nguyen Date: Sun, 25 Aug 2024 21:09:38 +0000 Subject: [PATCH 2/7] inits, cleanup --- .../rerank/listwise/listwise_rankllm.py | 2 + src/rank_llm/rerank/pairwise/__init__.py | 3 + src/rank_llm/rerank/pairwise/duot5.py | 72 ++++++++-------- .../rerank/pairwise/pairwise_rankllm.py | 37 ++++++--- src/rank_llm/rerank/pointwise/__init__.py | 4 +- .../rerank/pointwise/pointwise_rankllm.py | 25 ++++-- src/rank_llm/rerank/rankllm.py | 82 +++++++++++-------- 7 files changed, 132 insertions(+), 93 deletions(-) diff --git a/src/rank_llm/rerank/listwise/listwise_rankllm.py b/src/rank_llm/rerank/listwise/listwise_rankllm.py index 96495df8..09f0e0b0 100644 --- a/src/rank_llm/rerank/listwise/listwise_rankllm.py +++ b/src/rank_llm/rerank/listwise/listwise_rankllm.py @@ -17,6 +17,8 @@ class ListwiseRankLLM(RankLLM, ABC): """ + Abstract base class that all listwise rerankers inherit. 
+ All children of ListwiseRankLLM must implement these functions: - rerank_batched - run_llm_batched diff --git a/src/rank_llm/rerank/pairwise/__init__.py b/src/rank_llm/rerank/pairwise/__init__.py index e69de29b..99042941 100644 --- a/src/rank_llm/rerank/pairwise/__init__.py +++ b/src/rank_llm/rerank/pairwise/__init__.py @@ -0,0 +1,3 @@ +from .duot5 import DuoT5 + +__all__ = ["DuoT5"] diff --git a/src/rank_llm/rerank/pairwise/duot5.py b/src/rank_llm/rerank/pairwise/duot5.py index 8474f9b1..e77f14a9 100644 --- a/src/rank_llm/rerank/pairwise/duot5.py +++ b/src/rank_llm/rerank/pairwise/duot5.py @@ -2,15 +2,16 @@ import math from concurrent.futures import ThreadPoolExecutor, as_completed from functools import cmp_to_key -from typing import Dict, List, Optional, Tuple +from typing import List, Optional, Tuple import torch from tqdm import tqdm from transformers import T5ForConditionalGeneration, T5Tokenizer from transformers.generation import GenerationConfig -from rank_llm.data import Candidate, Result +from rank_llm.data import Candidate, Request, Result from rank_llm.rerank.pairwise.pairwise_rankllm import PairwiseRankLLM +from rank_llm.rerank.rankllm import PromptMode logger = logging.getLogger(__name__) @@ -19,37 +20,32 @@ class DuoT5(PairwiseRankLLM): def __init__( self, model: str, - device: str = "cuda", - window_size: int = 20, - batched: bool = False, + context_size: int, + prompt_mode: PromptMode, ): - super.__init(model, device, window_size, batched) + super.__init__(model, context_size, prompt_mode) self._tokenizer = T5Tokenizer.from_pretrained("castorini/duot5-base-msmarco") self._llm = T5ForConditionalGeneration.from_pretrained( "castorini/duot5-base-msmarco" ).to(self._device) - def run_llm_batched( + # TODO + def rerank_batch( self, - prompts: List[str | List[Dict[str, str]]], - current_window_size: Optional[int] = None, - ) -> List[Tuple[str, int]]: - if SamplingParams is None: - raise ImportError( - "Please install rank-llm with `pip install rank-llm[vllm]` to use batch inference." - ) - logger.info(f"VLLM Generating!") - sampling_params = SamplingParams( - temperature=0.0, - max_tokens=self.num_output_tokens(current_window_size), - min_tokens=self.num_output_tokens(current_window_size), - ) - outputs = self._llm.generate(prompts, sampling_params) + requests: List[Request], + rank_start: int = 0, + rank_end: int = 100, + shuffle_candidates: bool = False, + logging: bool = False, + **kwargs: logging.Any, + ) -> List[Result]: + return - return [ - (output.outputs[0].text, len(output.outputs[0].token_ids)) - for output in outputs - ] + # TODO + def run_llm_batched( + self, prompts: List[str | List[torch.Dict[str, str]]], **kwargs + ) -> List[Tuple[str | int]]: + return def run_llm( self, prompt: str, current_window_size: Optional[int] = None @@ -87,19 +83,6 @@ def run_llm( # print(outputs, output_ids.size(0)) return outputs, output_ids.size(0), score - def num_output_tokens(self, current_window_size: Optional[int] = None) -> int: - return 1 - - def _add_prefix_prompt(self, query: str, num: int) -> str: - return f"Given the query: {query}, output its relevance to the {num} documents." - - def _add_post_prompt(self, query: str, num: int) -> str: - return f"Given the query: {query}, output its relevance to the {num} documents." 
- - def _add_few_shot_examples(self, conv): - return 1 - # unused for now - def create_prompt( self, result: Result, rank_start: int, rank_end: int ) -> Tuple[str, int]: @@ -152,6 +135,9 @@ def get_num_tokens(self, prompt: str) -> int: def cost_per_1k_token(self, input_token: bool) -> float: return 0 + def num_output_tokens(self, current_window_size: Optional[int] = None) -> int: + return 1 + def candidate_comparator(self, x: Candidate, y: Candidate) -> int: if x.score < y.score: return -1 @@ -160,6 +146,16 @@ def candidate_comparator(self, x: Candidate, y: Candidate) -> int: else: return 0 + def _add_prefix_prompt(self, query: str, num: int) -> str: + return f"Given the query: {query}, output its relevance to the {num} documents." + + def _add_post_prompt(self, query: str, num: int) -> str: + return f"Given the query: {query}, output its relevance to the {num} documents." + + def _add_few_shot_examples(self, conv): + return 1 + # unused for now + def permutation_pipeline( self, result: Result, diff --git a/src/rank_llm/rerank/pairwise/pairwise_rankllm.py b/src/rank_llm/rerank/pairwise/pairwise_rankllm.py index 90242c7e..38d7da27 100644 --- a/src/rank_llm/rerank/pairwise/pairwise_rankllm.py +++ b/src/rank_llm/rerank/pairwise/pairwise_rankllm.py @@ -1,20 +1,35 @@ import logging from abc import ABC -from rank_llm.rerank.rankllm import RankLLM +from rank_llm.rerank.rankllm import PromptMode, RankLLM logger = logging.getLogger(__name__) class PairwiseRankLLM(RankLLM, ABC): - def __init__( + """ + Abstract base class that all pairwise rerankers implement. + + All concrete children of RankLLM must implement these functions: + - rerank_batch + - run_llm_batched + - run_llm + - create_prompt_batched + - create_prompt + - get_num_tokens + - cost_per_1k_tokens + - num_output_tokens + """ + + def __init__(self, model: str, context_size: int, prompt_mode: PromptMode) -> None: + super.__init__(model, context_size, prompt_mode) + + # TODO + def get_output_filename( self, - model: str, - device: str = "cuda", - window_size: int = 20, - batched: bool = False, - ) -> None: - super.__init__(model) - self._window_size = window_size - self._device = device - self._batched = batched + top_k_candidates: int, + dataset_name: str, + shuffle_candidates: bool, + **kwargs: logging.Any, + ) -> str: + return diff --git a/src/rank_llm/rerank/pointwise/__init__.py b/src/rank_llm/rerank/pointwise/__init__.py index 2d622682..28403f7b 100644 --- a/src/rank_llm/rerank/pointwise/__init__.py +++ b/src/rank_llm/rerank/pointwise/__init__.py @@ -1,3 +1,3 @@ -from .pointwise_rankllm import PointwiseRankLLM +from .monot5 import MonoT5 -__all__ = ["PointwiseRankLLM"] +__all__ = ["MonoT5"] diff --git a/src/rank_llm/rerank/pointwise/pointwise_rankllm.py b/src/rank_llm/rerank/pointwise/pointwise_rankllm.py index 40178d7a..674e6857 100644 --- a/src/rank_llm/rerank/pointwise/pointwise_rankllm.py +++ b/src/rank_llm/rerank/pointwise/pointwise_rankllm.py @@ -18,8 +18,15 @@ class PointwiseRankLLM(RankLLM, ABC): """ + Abstract base class that all pointwise rerankers implement. 
+ All children of PointwiseRankLLM must implement these functions: - - currently all abstract functions of RankLLM + - run_llm_batched + - run_llm + - create_prompt + - get_num_tokens + - cost_per_1k_tokens + - num_output_tokens """ @@ -114,14 +121,6 @@ def create_prompt_batched( return prompts, token_counts - def candidate_comparator(self, x: Candidate, y: Candidate) -> int: - if x.score < y.score: - return -1 - elif x.score > y.score: - return 1 - else: - return 0 - def get_output_filename( self, top_k_candidates: int, @@ -151,6 +150,14 @@ def get_output_filename( else f"{name}_{datetime.isoformat(datetime.now())}" ) + def candidate_comparator(self, x: Candidate, y: Candidate) -> int: + if x.score < y.score: + return -1 + elif x.score > y.score: + return 1 + else: + return 0 + def _replace_number(self, s: str) -> str: return re.sub(r"\[(\d+)\]", r"(\1)", s) diff --git a/src/rank_llm/rerank/rankllm.py b/src/rank_llm/rerank/rankllm.py index 7df7b1c8..4b3e911d 100644 --- a/src/rank_llm/rerank/rankllm.py +++ b/src/rank_llm/rerank/rankllm.py @@ -21,11 +21,60 @@ def __str__(self): class RankLLM(ABC): + """ + Abstract base class that all rerankers inherit. + + All concrete children of RankLLM must implement these functions: + - rerank_batch + - run_llm_batched + - run_llm + - create_prompt_batched + - create_prompt + - get_num_tokens + - cost_per_1k_tokens + - num_output_tokens + - get_output_filename + + """ + def __init__(self, model: str, context_size: int, prompt_mode: PromptMode) -> None: self._model = model self._context_size = context_size self._prompt_mode = prompt_mode + @abstractmethod + def rerank_batch( + self, + requests: List[Request], + rank_start: int = 0, + rank_end: int = 100, + shuffle_candidates: bool = False, + logging: bool = False, + **kwargs: Any, + ) -> List[Result]: + """ + Reranks a list of requests using the RankLLM agent. + + This function applies a sliding window algorithm to rerank the results. + Each window of results is processed by the RankLLM agent to obtain a new ranking. + + Args: + requests (List[Request]): The list of requests. Each request has a query and a candidates list. + rank_start (int, optional): The starting rank for processing. Defaults to 0. + rank_end (int, optional): The end rank for processing. Defaults to 100. + window_size (int, optional): The size of each sliding window. Defaults to 20. + step (int, optional): The step size for moving the window. Defaults to 10. + shuffle_candidates (bool, optional): Whether to shuffle candidates before reranking. Defaults to False. + logging (bool, optional): Enables logging of the reranking process. Defaults to False. + vllm_batched (bool, optional): Whether to use VLLM batched processing. Defaults to False. + populate_exec_summary (bool, optional): Whether to populate the exec summary. Defaults to False. + batched (bool, optional): Whether to use batched processing. Defaults to False. + + Returns: + List[Result]: A list containing the reranked candidates. + """ + pass + @abstractmethod def run_llm_batched( self, prompts: List[Union[str, List[Dict[str, str]]]], **kwargs @@ -126,39 +175,6 @@ def num_output_tokens(self) -> int: """ pass - @abstractmethod - def rerank_batch( - self, - requests: List[Request], - rank_start: int = 0, - rank_end: int = 100, - shuffle_candidates: bool = False, - logging: bool = False, - **kwargs: Any, - ) -> List[Result]: - """ - Reranks a list of requests using the RankLLM agent. - - This function applies a sliding window algorithm to rerank the results. 
- Each window of results is processed by the RankLLM agent to obtain a new ranking. - - Args: - requests (List[Request]): The list of requests. Each request has a query and a candidates list. - rank_start (int, optional): The starting rank for processing. Defaults to 0. - rank_end (int, optional): The end rank for processing. Defaults to 100. - window_size (int, optional): The size of each sliding window. Defaults to 20. - step (int, optional): The step size for moving the window. Defaults to 10. - shuffle_candidates (bool, optional): Whether to shuffle candidates before reranking. Defaults to False. - logging (bool, optional): Enables logging of the reranking process. Defaults to False. - vllm_batched (bool, optional): Whether to use VLLM batched processing. Defaults to False. - populate_exec_summary (bool, optional): Whether to populate the exec summary. Defaults to False. - batched (bool, optional): Whether to use batched processing. Defaults to False. - - Returns: - List[Result]: A list containing the reranked candidates. - """ - pass - @abstractmethod def get_output_filename( self, From 912de1aa29d37ca312a52ceef9c063acba612362 Mon Sep 17 00:00:00 2001 From: Eric Wang Date: Thu, 29 Aug 2024 13:58:33 -0400 Subject: [PATCH 3/7] duot5 and pairwise implementation --- src/rank_llm/rerank/pairwise/duot5.py | 224 ++++++------------ .../rerank/pairwise/pairwise_rankllm.py | 173 +++++++++++++- src/rank_llm/rerank/rankllm.py | 83 +++---- src/rank_llm/rerank/reranker.py | 27 +++ src/rank_llm/retrieve_and_rerank.py | 4 +- 5 files changed, 307 insertions(+), 204 deletions(-) diff --git a/src/rank_llm/rerank/pairwise/duot5.py b/src/rank_llm/rerank/pairwise/duot5.py index e77f14a9..5c901318 100644 --- a/src/rank_llm/rerank/pairwise/duot5.py +++ b/src/rank_llm/rerank/pairwise/duot5.py @@ -1,17 +1,12 @@ import logging import math -from concurrent.futures import ThreadPoolExecutor, as_completed -from functools import cmp_to_key -from typing import List, Optional, Tuple +from typing import List, Tuple -import torch -from tqdm import tqdm from transformers import T5ForConditionalGeneration, T5Tokenizer from transformers.generation import GenerationConfig -from rank_llm.data import Candidate, Request, Result +from rank_llm.data import Result from rank_llm.rerank.pairwise.pairwise_rankllm import PairwiseRankLLM -from rank_llm.rerank.rankllm import PromptMode logger = logging.getLogger(__name__) @@ -20,49 +15,81 @@ class DuoT5(PairwiseRankLLM): def __init__( self, model: str, - context_size: int, - prompt_mode: PromptMode, + prompt_mode: str = "duot5", + context_size: int = 512, + device: str = "cuda", + batch_size: int = 32, ): - super.__init__(model, context_size, prompt_mode) - self._tokenizer = T5Tokenizer.from_pretrained("castorini/duot5-base-msmarco") - self._llm = T5ForConditionalGeneration.from_pretrained( - "castorini/duot5-base-msmarco" - ).to(self._device) + super().__init__( + model=model, + context_size=context_size, + prompt_mode=prompt_mode, + device=device, + batch_size=batch_size, + ) + + self._tokenizer = T5Tokenizer.from_pretrained(model) + self._llm = T5ForConditionalGeneration.from_pretrained(model).to(self._device) + self._context_size = context_size - # TODO - def rerank_batch( - self, - requests: List[Request], - rank_start: int = 0, - rank_end: int = 100, - shuffle_candidates: bool = False, - logging: bool = False, - **kwargs: logging.Any, - ) -> List[Result]: - return - - # TODO def run_llm_batched( - self, prompts: List[str | List[torch.Dict[str, str]]], **kwargs - ) -> 
List[Tuple[str | int]]: - return - - def run_llm( - self, prompt: str, current_window_size: Optional[int] = None - ) -> Tuple[str, int, float]: - # CHANGE THIS CODE - if current_window_size is None: - current_window_size = self._window_size - inputs = self._tokenizer([prompt]) - inputs = {k: torch.tensor(v).to(self._device) for k, v in inputs.items()} + self, + prompts: List[str], + ) -> Tuple[List[str], List[int], List[float]]: gen_cfg = GenerationConfig.from_model_config(self._llm.config) gen_cfg.max_new_tokens = self.num_output_tokens() gen_cfg.min_new_tokens = self.num_output_tokens() - gen_cfg.decoder_start_token_id = None gen_cfg.output_scores = True gen_cfg.return_dict_in_generate = True - # gen_cfg.temperature = 0 gen_cfg.do_sample = False + + all_outputs = [] + all_output_token_counts = [] + all_scores = [] + + batch_prompts = prompts + + token_prompts = self._tokenizer( + batch_prompts, padding=True, truncation=True, return_tensors="pt" + ).to(self._device) + + token_prompts = token_prompts["input_ids"] + + batch_outputs = self._llm.generate(token_prompts, generation_config=gen_cfg) + + batch_output_ids = batch_outputs.sequences + batch_logits = batch_outputs.scores + + batch_outputs = [ + self._tokenizer.decode( + single_token_sequence, + skip_special_tokens=True, + spaces_between_special_tokens=False, + ) + for single_token_sequence in batch_output_ids + ] + + for logit_tensor in batch_logits[0]: + truth_logit = logit_tensor[1176] + false_logit = logit_tensor[6136] + score = math.exp(truth_logit) / ( + math.exp(truth_logit) + math.exp(false_logit) + ) + all_scores.append(score) + all_output_token_counts.append(self.num_output_tokens) + + all_outputs.extend(batch_outputs) + + return all_outputs, all_output_token_counts, all_scores + + def run_llm(self, prompt: str) -> Tuple[str, int, float]: + gen_cfg = GenerationConfig.from_model_config(self._llm.config) + gen_cfg.max_new_tokens = self.num_output_tokens() + gen_cfg.min_new_tokens = self.num_output_tokens() + gen_cfg.output_scores = True + gen_cfg.return_dict_in_generate = True + gen_cfg.do_sample = False + token_prompt = self._tokenizer.encode(prompt, return_tensors="pt").to( self._device ) @@ -80,125 +107,26 @@ def run_llm( truth_logit = logits[0][0][1176] false_logit = logits[0][0][6136] score = math.exp(truth_logit) / (math.exp(truth_logit) + math.exp(false_logit)) - # print(outputs, output_ids.size(0)) + return outputs, output_ids.size(0), score - def create_prompt( - self, result: Result, rank_start: int, rank_end: int - ) -> Tuple[str, int]: - # query = result.query.text - # query = self._replace_number(query) - # input = f"Query: {query} Document: {result.candidates[rank_start].doc['contents']}" - # prompt = self._tokenizer.decode(self._tokenizer.encode(input)[:480])[:-4] + " Relevant: " - # prompt = prompt.replace("","") + def num_output_tokens(self) -> int: + return 1 - # CHANGE THIS CODE + def create_prompt(self, result: Result, index1: int, index2: int) -> Tuple[str, int]: query = result.query.text query = self._replace_number(query) - doc1 = result.candidates[rank_start].doc["contents"] - doc2 = result.candidates[rank_end].doc["contents"] + doc1 = self.convert_doc_to_prompt_content(result.candidates[index1].doc, max_length=self._context_size) + doc2 = self.convert_doc_to_prompt_content(result.candidates[index2].doc, max_length=self._context_size) doc1 = self._tokenizer.decode(self._tokenizer.encode(doc1)[:240])[:-4] doc2 = self._tokenizer.decode(self._tokenizer.encode(doc2)[:240])[:-4] prompt = f"Query: {query} 
Document0: {doc1} Document1: {doc2} Relevant:" - prompt = prompt.replace("", "") + prompt = prompt.replace("","") + return prompt, self.get_num_tokens(prompt) - def create_prompt_batched( - self, - results: List[Result], - rank_start: int, - rank_end: int, - batch_size: int = 32, - ) -> List[Tuple[str, int]]: - def chunks(lst, n): - """Yield successive n-sized chunks from lst.""" - for i in range(0, len(lst), n): - yield lst[i : i + n] - - all_completed_prompts = [] - - with ThreadPoolExecutor() as executor: - for batch in tqdm(chunks(results, batch_size), desc="Processing batches"): - futures = [ - executor.submit(self.create_prompt, result, rank_start, rank_end) - for result in batch - ] - completed_prompts = [ - future.result() for future in as_completed(futures) - ] - all_completed_prompts.extend(completed_prompts) - return all_completed_prompts - def get_num_tokens(self, prompt: str) -> int: return len(self._tokenizer.encode(prompt)) def cost_per_1k_token(self, input_token: bool) -> float: return 0 - - def num_output_tokens(self, current_window_size: Optional[int] = None) -> int: - return 1 - - def candidate_comparator(self, x: Candidate, y: Candidate) -> int: - if x.score < y.score: - return -1 - elif x.score > y.score: - return 1 - else: - return 0 - - def _add_prefix_prompt(self, query: str, num: int) -> str: - return f"Given the query: {query}, output its relevance to the {num} documents." - - def _add_post_prompt(self, query: str, num: int) -> str: - return f"Given the query: {query}, output its relevance to the {num} documents." - - def _add_few_shot_examples(self, conv): - return 1 - # unused for now - - def permutation_pipeline( - self, - result: Result, - rank_start: int, - rank_end: int, - logging: bool = False, - ) -> Result: - """ - Runs the permutation pipeline on the passed in result set within the passed in rank range. - - Args: - result (Result): The result object to process. - rank_start (int): The start index for ranking. - rank_end (int): The end index for ranking. - logging (bool, optional): Flag to enable logging of operations. Defaults to False. - - Returns: - Result: The processed result object after applying permutation. 
- """ - # CHANGE THIS CODE - # print(len(result.candidates)) - # for i in range (len(result.candidates)): - # prompt, num_tokens = self.create_prompt(result, i, rank_end) - # output, output_num_tokens, score = self.run_llm(prompt=prompt) - # (result.candidates[i]).score = score - - # result.candidates.sort(key=cmp_to_key(self.candidate_comparator)) - n = len(result.candidates) - scores = [0 for _ in range(n)] - for i in range(n): - for j in range(n): - if j == i: - continue - else: - prompt1, num_tokens1 = self.create_prompt(result, i, j) - prompt2, num_tokens2 = self.create_prompt(result, j, i) - _, _, pi_j = self.run_llm(prompt=prompt1) - _, _, pj_i = self.run_llm(prompt=prompt2) - scores[i] = scores[i] + pi_j + 1 - pj_i - - for i in range(n): - (result.candidates[i]).score = scores[i] - - result.candidates.sort(key=cmp_to_key(self.candidate_comparator)) - - return result diff --git a/src/rank_llm/rerank/pairwise/pairwise_rankllm.py b/src/rank_llm/rerank/pairwise/pairwise_rankllm.py index 38d7da27..5f4913ec 100644 --- a/src/rank_llm/rerank/pairwise/pairwise_rankllm.py +++ b/src/rank_llm/rerank/pairwise/pairwise_rankllm.py @@ -1,11 +1,20 @@ +import copy import logging +import math +import re from abc import ABC +from datetime import datetime +from functools import cmp_to_key +from typing import Any, Dict, List, Tuple +from ftfy import fix_text +from tqdm import tqdm + +from rank_llm.data import Candidate, Request, Result from rank_llm.rerank.rankllm import PromptMode, RankLLM logger = logging.getLogger(__name__) - class PairwiseRankLLM(RankLLM, ABC): """ Abstract base class that all pairwise rerankers implement. @@ -21,15 +30,167 @@ class PairwiseRankLLM(RankLLM, ABC): - num_output_tokens """ - def __init__(self, model: str, context_size: int, prompt_mode: PromptMode) -> None: - super.__init__(model, context_size, prompt_mode) + def __init__( + self, + model: str, + context_size: int, + prompt_mode: PromptMode, + device: str = "cuda", + filename: str = "", + batch_size: int = 32, + ) -> None: + super().__init__(model, context_size, prompt_mode) + self._device = device + self._filename = filename + self._batch_size = batch_size + + def rerank_batch( + self, + requests: List[Request], + rank_start: int = 0, + rank_end: int = 100, + shuffle_candidates: bool = False, + logging: bool = False, + **kwargs: Any, + ) -> List[Result]: + rerank_results = [ + Result( + query=copy.deepcopy(request.query), + candidates=copy.deepcopy(request.candidates), + ranking_exec_summary=[], + ) + for request in requests + ] + + for result in rerank_results: + for i in result.candidates: + i.score = 0 + + + end = len(rerank_results[0].candidates) * len(rerank_results[0].candidates) * len(requests) + with tqdm(total=end, desc="Progress through (q, d) pairs") as progress_bar: + for index in range(0, end, self._batch_size): + prompts, token_counts = self.create_prompt_batched( + results=rerank_results, index=index + ) + + outputs, output_tokens, scores = self.run_llm_batched(prompts=prompts) + + for update_index in range ( + index, + min( + index + self._batch_size, + end + ) + ): + query_number = math.floor( + update_index / (len(rerank_results[0].candidates) ** 2) + ) + candidate_1 = math.floor( + (update_index % (len(rerank_results[0].candidates) ** 2)) / len(rerank_results[0].candidates) + ) + candidate_2 = update_index % len(rerank_results[0].candidates) + rerank_results[query_number].candidates[candidate_1].score += scores[update_index - index] + rerank_results[query_number].candidates[candidate_2].score += 1 - 
scores[update_index - index] + + if index + self._batch_size > end: + progress_bar.update(end - index) + else: + progress_bar.update(self._batch_size) + + + for result in rerank_results: + result.candidates.sort( + key=cmp_to_key(self.candidate_comparator), reverse=True + ) + + return rerank_results + + def create_prompt_batched( + self, results: List[Result], index + ) -> Tuple[List[str], List[int]]: + prompts = [] + token_counts = [] + + for index in range( + index, + min(index + self._batch_size, len(results[0].candidates) * len(results)), + ): + query_number = math.floor( + index / (len(results[0].candidates) ** 2) + ) + candidate_1 = math.floor( + (index % (len(results[0].candidates) ** 2)) / len(results[0].candidates) + ) + candidate_2 = index % len(results[0].candidates) + + prompt, token_count = self.create_prompt( + result=results[query_number], index1=candidate_1, index2=candidate_2 + ) + + prompts.append(prompt) + token_counts.append(token_count) + return prompts, token_counts - # TODO def get_output_filename( self, top_k_candidates: int, dataset_name: str, shuffle_candidates: bool, - **kwargs: logging.Any, + **kwargs: Any, + ) -> str: + if self._filename != "": + return self._filename + _modelname = self._model.split("/")[-1] + if _modelname.startswith("checkpoint"): + _modelname = self._model.split("/")[-2] + "_" + _modelname + name = ( + f"{_modelname}_{self._context_size}_{top_k_candidates}_{self._prompt_mode}" + ) + if dataset_name: + name = f"{name}_{dataset_name}" + + if shuffle_candidates: + self._filename = f"{name}_shuffled_{datetime.isoformat(datetime.now())}" + else: + self._filename = f"{name}_{datetime.isoformat(datetime.now())}" + + return ( + f"{name}_shuffled_{datetime.isoformat(datetime.now())}" + if shuffle_candidates + else f"{name}_{datetime.isoformat(datetime.now())}" + ) + + def candidate_comparator(self, x: Candidate, y: Candidate) -> int: + if x.score < y.score: + return -1 + elif x.score > y.score: + return 1 + else: + return 0 + + def _replace_number(self, s: str) -> str: + return re.sub(r"\[(\d+)\]", r"(\1)", s) + + def convert_doc_to_prompt_content( + self, doc: Dict[str, Any], max_length: int ) -> str: - return + if "text" in doc: + content = doc["text"] + elif "segment" in doc: + content = doc["segment"] + elif "contents" in doc: + content = doc["contents"] + elif "content" in doc: + content = doc["content"] + elif "body" in doc: + content = doc["body"] + else: + content = doc["passage"] + if "title" in doc and doc["title"]: + content = "Title: " + doc["title"] + " " + "Content: " + content + content = content.strip() + content = fix_text(content) + # For Japanese should cut by character: content = content[:int(max_length)] + content = " ".join(content.split()[: int(max_length)]) + return self._replace_number(content) diff --git a/src/rank_llm/rerank/rankllm.py b/src/rank_llm/rerank/rankllm.py index 4b3e911d..0071e7a7 100644 --- a/src/rank_llm/rerank/rankllm.py +++ b/src/rank_llm/rerank/rankllm.py @@ -14,6 +14,7 @@ class PromptMode(Enum): RANK_GPT_APEER = "rank_GPT_APEER" LRL = "LRL" MONOT5 = "monot5" + DUOT5 = "duot5" LiT5 = "LiT5" def __str__(self): @@ -21,60 +22,11 @@ def __str__(self): class RankLLM(ABC): - """ - Abstract base class that all rerankers inherit. 
- - All concrete children of RankLLM must implement these functions: - - rerank_batch - - run_llm_batched - - run_llm - - create_prompt_batched - - create_prompt - - get_num_tokens - - cost_per_1k_tokens - - num_output_tokens - - get_output_filename - - """ - def __init__(self, model: str, context_size: int, prompt_mode: PromptMode) -> None: self._model = model self._context_size = context_size self._prompt_mode = prompt_mode - @abstractmethod - def rerank_batch( - self, - requests: List[Request], - rank_start: int = 0, - rank_end: int = 100, - shuffle_candidates: bool = False, - logging: bool = False, - **kwargs: Any, - ) -> List[Result]: - """ - Reranks a list of requests using the RankLLM agent. - - This function applies a sliding window algorithm to rerank the results. - Each window of results is processed by the RankLLM agent to obtain a new ranking. - - Args: - requests (List[Request]): The list of requests. Each request has a query and a candidates list. - rank_start (int, optional): The starting rank for processing. Defaults to 0. - rank_end (int, optional): The end rank for processing. Defaults to 100. - window_size (int, optional): The size of each sliding window. Defaults to 20. - step (int, optional): The step size for moving the window. Defaults to 10. - shuffle_candidates (bool, optional): Whether to shuffle candidates before reranking. Defaults to False. - logging (bool, optional): Enables logging of the reranking process. Defaults to False. - vllm_batched (bool, optional): Whether to use VLLM batched processing. Defaults to False. - populate_exec_summary (bool, optional): Whether to populate the exec summary. Defaults to False. - batched (bool, optional): Whether to use batched processing. Defaults to False. - - Returns: - List[Result]: A list containing the reranked candidates. - """ - pass - @abstractmethod def run_llm_batched( self, prompts: List[Union[str, List[Dict[str, str]]]], **kwargs @@ -175,6 +127,39 @@ def num_output_tokens(self) -> int: """ pass + @abstractmethod + def rerank_batch( + self, + requests: List[Request], + rank_start: int = 0, + rank_end: int = 100, + shuffle_candidates: bool = False, + logging: bool = False, + **kwargs: Any, + ) -> List[Result]: + """ + Reranks a list of requests using the RankLLM agent. + + This function applies a sliding window algorithm to rerank the results. + Each window of results is processed by the RankLLM agent to obtain a new ranking. + + Args: + requests (List[Request]): The list of requests. Each request has a query and a candidates list. + rank_start (int, optional): The starting rank for processing. Defaults to 0. + rank_end (int, optional): The end rank for processing. Defaults to 100. + window_size (int, optional): The size of each sliding window. Defaults to 20. + step (int, optional): The step size for moving the window. Defaults to 10. + shuffle_candidates (bool, optional): Whether to shuffle candidates before reranking. Defaults to False. + logging (bool, optional): Enables logging of the reranking process. Defaults to False. + vllm_batched (bool, optional): Whether to use VLLM batched processing. Defaults to False. + populate_exec_summary (bool, optional): Whether to populate the exec summary. Defaults to False. + batched (bool, optional): Whether to use batched processing. Defaults to False. + + Returns: + List[Result]: A list containing the reranked candidates. 
+ """ + pass + @abstractmethod def get_output_filename( self, diff --git a/src/rank_llm/rerank/reranker.py b/src/rank_llm/rerank/reranker.py index 21fa3f1d..446b033a 100644 --- a/src/rank_llm/rerank/reranker.py +++ b/src/rank_llm/rerank/reranker.py @@ -11,6 +11,7 @@ from rank_llm.rerank.listwise import RankListwiseOSLLM, SafeOpenai from rank_llm.rerank.listwise.rank_fid import RankFiDDistill, RankFiDScore from rank_llm.rerank.pointwise.monot5 import MonoT5 +from rank_llm.rerank.pairwise.duot5 import DuoT5 from rank_llm.rerank.rankllm import RankLLM @@ -282,7 +283,33 @@ def create_agent( device=device, batch_size=batch_size, ) + elif "duot5" in model_path: + # using monot5 + print(f"Loading {model_path} ...") + + model_full_paths = {"duot5": "castorini/duot5-3b-med-msmarco"} + + keys_and_defaults = [ + ("prompt_mode", PromptMode.DUOT5), + ("context_size", 512), + ("device", "cuda"), + ("batch_size", 64), + ] + [prompt_mode, context_size, device, batch_size] = extract_kwargs( + keys_and_defaults, **kwargs + ) + agent = DuoT5( + model=( + model_full_paths[model_path] + if model_path in model_full_paths + else model_path + ), + prompt_mode=prompt_mode, + context_size=context_size, + device=device, + batch_size=batch_size, + ) elif "lit5-distill" in model_path.lower(): keys_and_defaults = [ ("context_size", 150), diff --git a/src/rank_llm/retrieve_and_rerank.py b/src/rank_llm/retrieve_and_rerank.py index 992985dd..0d574fe6 100644 --- a/src/rank_llm/retrieve_and_rerank.py +++ b/src/rank_llm/retrieve_and_rerank.py @@ -53,7 +53,9 @@ def retrieve_and_rerank( dataset=dataset, **kwargs, ) - + print(top_k_retrieve) + print("Number of candidates per query: ") + print(len(requests[0].candidates)) # Reranking stages print(f"Reranking and returning {top_k_rerank} passages with {model_path}...") if reranker.get_agent() is None: From b1d7eff3df5fcc6ca0a1f6d68c339142cc7dca4e Mon Sep 17 00:00:00 2001 From: IR3KT4FUNZ Date: Mon, 2 Sep 2024 16:45:51 -0400 Subject: [PATCH 4/7] implementation for pairwise and duot5 --- src/rank_llm/rerank/pairwise/pairwise_rankllm.py | 10 ++++++++-- src/rank_llm/rerank/reranker.py | 1 + 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/src/rank_llm/rerank/pairwise/pairwise_rankllm.py b/src/rank_llm/rerank/pairwise/pairwise_rankllm.py index 5f4913ec..f5202d05 100644 --- a/src/rank_llm/rerank/pairwise/pairwise_rankllm.py +++ b/src/rank_llm/rerank/pairwise/pairwise_rankllm.py @@ -53,6 +53,11 @@ def rerank_batch( logging: bool = False, **kwargs: Any, ) -> List[Result]: + self._enumerated_indices = [] + + for index in range(len(requests) * len(requests[0].candidates) * len(requests[0].candidates)): + self._enumerated_indices.append(index) + rerank_results = [ Result( query=copy.deepcopy(request.query), @@ -66,8 +71,7 @@ def rerank_batch( for i in result.candidates: i.score = 0 - - end = len(rerank_results[0].candidates) * len(rerank_results[0].candidates) * len(requests) + end = len(rerank_results[0].candidates - 1) * len(rerank_results[0].candidates) * len(requests) with tqdm(total=end, desc="Progress through (q, d) pairs") as progress_bar: for index in range(0, end, self._batch_size): prompts, token_counts = self.create_prompt_batched( @@ -83,6 +87,7 @@ def rerank_batch( end ) ): + update_index = self._enumerated_indices[update_index] query_number = math.floor( update_index / (len(rerank_results[0].candidates) ** 2) ) @@ -116,6 +121,7 @@ def create_prompt_batched( index, min(index + self._batch_size, len(results[0].candidates) * len(results)), ): + index = 
self._enumerated_indices[index] query_number = math.floor( index / (len(results[0].candidates) ** 2) ) diff --git a/src/rank_llm/rerank/reranker.py b/src/rank_llm/rerank/reranker.py index 446b033a..a7863660 100644 --- a/src/rank_llm/rerank/reranker.py +++ b/src/rank_llm/rerank/reranker.py @@ -294,6 +294,7 @@ def create_agent( ("context_size", 512), ("device", "cuda"), ("batch_size", 64), + ("interactive", True) ] [prompt_mode, context_size, device, batch_size] = extract_kwargs( keys_and_defaults, **kwargs From b35f2521767005276f8e15dcefa462178cd8fe0a Mon Sep 17 00:00:00 2001 From: Eric Wang Date: Mon, 2 Sep 2024 18:43:14 -0400 Subject: [PATCH 5/7] duot5 bug fixes, add fix for retrieving <100 candidates in non-interactive cases --- .../rerank/pairwise/pairwise_rankllm.py | 23 ++++++++++--------- src/rank_llm/rerank/reranker.py | 1 - src/rank_llm/retrieve_and_rerank.py | 7 +++--- src/rank_llm/scripts/run_rank_llm.py | 8 +++++++ 4 files changed, 24 insertions(+), 15 deletions(-) diff --git a/src/rank_llm/rerank/pairwise/pairwise_rankllm.py b/src/rank_llm/rerank/pairwise/pairwise_rankllm.py index f5202d05..a6e1d8ba 100644 --- a/src/rank_llm/rerank/pairwise/pairwise_rankllm.py +++ b/src/rank_llm/rerank/pairwise/pairwise_rankllm.py @@ -71,7 +71,7 @@ def rerank_batch( for i in result.candidates: i.score = 0 - end = len(rerank_results[0].candidates - 1) * len(rerank_results[0].candidates) * len(requests) + end = (len(rerank_results[0].candidates) - 1) * len(rerank_results[0].candidates) * len(requests) with tqdm(total=end, desc="Progress through (q, d) pairs") as progress_bar: for index in range(0, end, self._batch_size): prompts, token_counts = self.create_prompt_batched( @@ -87,14 +87,15 @@ def rerank_batch( end ) ): - update_index = self._enumerated_indices[update_index] + update_index_copy = self._enumerated_indices[update_index] query_number = math.floor( - update_index / (len(rerank_results[0].candidates) ** 2) + update_index_copy / (len(rerank_results[0].candidates) ** 2) ) candidate_1 = math.floor( - (update_index % (len(rerank_results[0].candidates) ** 2)) / len(rerank_results[0].candidates) + (update_index_copy % (len(rerank_results[0].candidates) ** 2)) / len(rerank_results[0].candidates) ) - candidate_2 = update_index % len(rerank_results[0].candidates) + candidate_2 = update_index_copy % len(rerank_results[0].candidates) + rerank_results[query_number].candidates[candidate_1].score += scores[update_index - index] rerank_results[query_number].candidates[candidate_2].score += 1 - scores[update_index - index] @@ -117,18 +118,18 @@ def create_prompt_batched( prompts = [] token_counts = [] - for index in range( + for current_index in range( index, - min(index + self._batch_size, len(results[0].candidates) * len(results)), + min(index + self._batch_size, len(results[0].candidates) * (len(results[0].candidates) - 1) * len(results)), ): - index = self._enumerated_indices[index] + current_index = self._enumerated_indices[current_index] query_number = math.floor( - index / (len(results[0].candidates) ** 2) + current_index / (len(results[0].candidates) ** 2) ) candidate_1 = math.floor( - (index % (len(results[0].candidates) ** 2)) / len(results[0].candidates) + (current_index % (len(results[0].candidates) ** 2)) / len(results[0].candidates) ) - candidate_2 = index % len(results[0].candidates) + candidate_2 = current_index % len(results[0].candidates) prompt, token_count = self.create_prompt( result=results[query_number], index1=candidate_1, index2=candidate_2 diff --git 
a/src/rank_llm/rerank/reranker.py b/src/rank_llm/rerank/reranker.py index a7863660..446b033a 100644 --- a/src/rank_llm/rerank/reranker.py +++ b/src/rank_llm/rerank/reranker.py @@ -294,7 +294,6 @@ def create_agent( ("context_size", 512), ("device", "cuda"), ("batch_size", 64), - ("interactive", True) ] [prompt_mode, context_size, device, batch_size] = extract_kwargs( keys_and_defaults, **kwargs diff --git a/src/rank_llm/retrieve_and_rerank.py b/src/rank_llm/retrieve_and_rerank.py index 0d574fe6..cb904953 100644 --- a/src/rank_llm/retrieve_and_rerank.py +++ b/src/rank_llm/retrieve_and_rerank.py @@ -53,9 +53,10 @@ def retrieve_and_rerank( dataset=dataset, **kwargs, ) - print(top_k_retrieve) - print("Number of candidates per query: ") - print(len(requests[0].candidates)) + + for request in requests: + request.candidates = request.candidates[:top_k_retrieve] + # Reranking stages print(f"Reranking and returning {top_k_rerank} passages with {model_path}...") if reranker.get_agent() is None: diff --git a/src/rank_llm/scripts/run_rank_llm.py b/src/rank_llm/scripts/run_rank_llm.py index 3639525b..1ce04192 100644 --- a/src/rank_llm/scripts/run_rank_llm.py +++ b/src/rank_llm/scripts/run_rank_llm.py @@ -38,6 +38,7 @@ def main(args): window_size = args.window_size system_message = args.system_message vllm_batched = args.vllm_batched + interactive = args.interactive _ = retrieve_and_rerank( model_path=model_path, @@ -62,6 +63,7 @@ def main(args): step_size=step_size, system_message=system_message, vllm_batched=vllm_batched, + interactive=interactive, ) @@ -175,5 +177,11 @@ def main(args): action="store_true", help="whether to run the model in batches", ) + parser.add_argument( + "--interactive", + type=bool, + default=False, + help="whether retrieval is done from the server or a prebuilt index" + ) args = parser.parse_args() main(args) From 0950167458e9220397c34eda2f330b0a75b464b1 Mon Sep 17 00:00:00 2001 From: Eric Wang Date: Mon, 2 Sep 2024 18:50:26 -0400 Subject: [PATCH 6/7] remove temporarily unnecessary interactive argument --- src/rank_llm/scripts/run_rank_llm.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/src/rank_llm/scripts/run_rank_llm.py b/src/rank_llm/scripts/run_rank_llm.py index 1ce04192..3639525b 100644 --- a/src/rank_llm/scripts/run_rank_llm.py +++ b/src/rank_llm/scripts/run_rank_llm.py @@ -38,7 +38,6 @@ def main(args): window_size = args.window_size system_message = args.system_message vllm_batched = args.vllm_batched - interactive = args.interactive _ = retrieve_and_rerank( model_path=model_path, @@ -63,7 +62,6 @@ def main(args): step_size=step_size, system_message=system_message, vllm_batched=vllm_batched, - interactive=interactive, ) @@ -177,11 +175,5 @@ def main(args): action="store_true", help="whether to run the model in batches", ) - parser.add_argument( - "--interactive", - type=bool, - default=False, - help="whether retrieval is done from the server or a prebuilt index" - ) args = parser.parse_args() main(args) From 5604a610a5ce301bf948af5dd0356eaca94518b8 Mon Sep 17 00:00:00 2001 From: Eric Wang Date: Sat, 7 Sep 2024 19:40:25 -0400 Subject: [PATCH 7/7] fix enumeration bug --- src/rank_llm/rerank/pairwise/pairwise_rankllm.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/src/rank_llm/rerank/pairwise/pairwise_rankllm.py b/src/rank_llm/rerank/pairwise/pairwise_rankllm.py index a6e1d8ba..8f47f8ba 100644 --- a/src/rank_llm/rerank/pairwise/pairwise_rankllm.py +++ b/src/rank_llm/rerank/pairwise/pairwise_rankllm.py @@ -55,9 +55,6 @@ def 
rerank_batch( ) -> List[Result]: self._enumerated_indices = [] - for index in range(len(requests) * len(requests[0].candidates) * len(requests[0].candidates)): - self._enumerated_indices.append(index) - rerank_results = [ Result( query=copy.deepcopy(request.query), @@ -71,6 +68,14 @@ def rerank_batch( for i in result.candidates: i.score = 0 + for index in range(len(requests) * len(requests[0].candidates) * len(requests[0].candidates)): + candidate_1 = math.floor( + (index % (len(rerank_results[0].candidates) ** 2)) / len(rerank_results[0].candidates) + ) + candidate_2 = index % len(rerank_results[0].candidates) + if candidate_1 != candidate_2: + self._enumerated_indices.append(index) + end = (len(rerank_results[0].candidates) - 1) * len(rerank_results[0].candidates) * len(requests) with tqdm(total=end, desc="Progress through (q, d) pairs") as progress_bar: for index in range(0, end, self._batch_size): @@ -95,10 +100,10 @@ def rerank_batch( (update_index_copy % (len(rerank_results[0].candidates) ** 2)) / len(rerank_results[0].candidates) ) candidate_2 = update_index_copy % len(rerank_results[0].candidates) - + rerank_results[query_number].candidates[candidate_1].score += scores[update_index - index] rerank_results[query_number].candidates[candidate_2].score += 1 - scores[update_index - index] - + if index + self._batch_size > end: progress_bar.update(end - index) else:
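
Taken together, the patches above have DuoT5 score an ordered pair of candidates by prompting the model with "Query: {q} Document0: {d0} Document1: {d1} Relevant:" and comparing the logits of the "true" token (id 1176) and the "false" token (id 6136) at the first generated step. The following is a minimal, self-contained sketch of that scoring step, not part of the patch series; the helper name pairwise_probability and the example strings are illustrative only, while the model name, prompt template, generation settings, and token ids are taken from the patches.

# Standalone sketch of the DuoT5 pairwise scoring step (illustrative helper,
# not part of the patches). Token ids 1176 ("true") and 6136 ("false") and the
# prompt template mirror run_llm()/create_prompt() in duot5.py above.
import math

from transformers import T5ForConditionalGeneration, T5Tokenizer

TRUE_TOKEN_ID = 1176
FALSE_TOKEN_ID = 6136


def pairwise_probability(model, tokenizer, query, doc0, doc1, device="cuda"):
    """Return p(doc0 is more relevant than doc1 for the query)."""
    prompt = f"Query: {query} Document0: {doc0} Document1: {doc1} Relevant:"
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    # Greedy-generate a single token and keep its scores, as run_llm() does.
    output = model.generate(
        input_ids,
        max_new_tokens=1,
        min_new_tokens=1,
        do_sample=False,
        output_scores=True,
        return_dict_in_generate=True,
    )
    logits = output.scores[0][0]  # logits over the vocabulary for the first generated token
    true_logit = logits[TRUE_TOKEN_ID].item()
    false_logit = logits[FALSE_TOKEN_ID].item()
    # Two-way softmax over the "true"/"false" logits, as in the patch.
    return math.exp(true_logit) / (math.exp(true_logit) + math.exp(false_logit))


if __name__ == "__main__":
    model_name = "castorini/duot5-base-msmarco"
    tokenizer = T5Tokenizer.from_pretrained(model_name)
    model = T5ForConditionalGeneration.from_pretrained(model_name).to("cuda")
    p = pairwise_probability(
        model,
        tokenizer,
        query="what is a lobster roll",
        doc0="A lobster roll is a sandwich filled with lobster meat.",
        doc1="Cats are small carnivorous mammals kept as pets.",
    )
    print(f"p(doc0 > doc1) = {p:.3f}")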
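
The final form of pairwise_rankllm.py enumerates every ordered pair of candidates per query as a flat index (skipping i == j pairs after the last commit's fix), decodes each index back into (query_number, candidate_1, candidate_2), and adds the pairwise probability to candidate_1's score and its complement to candidate_2's. The sketch below restates that aggregation in isolation, assuming some probability function such as the one above; decode_pair_index and aggregate_pairwise_scores are illustrative names, not functions from the patches.

# Illustrative restatement (not part of the patches) of the index decoding and
# score aggregation used by PairwiseRankLLM.rerank_batch() in its final form.
from typing import Callable, List, Tuple


def decode_pair_index(flat_index: int, num_candidates: int) -> Tuple[int, int, int]:
    """Map a flat pair index to (query_number, candidate_1, candidate_2)."""
    query_number = flat_index // (num_candidates ** 2)
    candidate_1 = (flat_index % (num_candidates ** 2)) // num_candidates
    candidate_2 = flat_index % num_candidates
    return query_number, candidate_1, candidate_2


def aggregate_pairwise_scores(
    num_candidates: int, prob: Callable[[int, int], float]
) -> List[float]:
    """Aggregate pairwise probabilities into one score per candidate.

    prob(i, j) should return p(candidate i beats candidate j), e.g. the DuoT5
    probability sketched above. Each comparison credits i with p and j with
    1 - p, so a candidate's final score is its net wins over all opponents.
    """
    scores = [0.0] * num_candidates
    for i in range(num_candidates):
        for j in range(num_candidates):
            if i == j:
                continue  # i == j pairs are skipped, as in the enumeration fix
            p = prob(i, j)
            scores[i] += p
            scores[j] += 1.0 - p
    return scores


if __name__ == "__main__":
    print(decode_pair_index(7, num_candidates=3))  # -> (0, 2, 1)
    # Toy probability function that always prefers candidate 2.
    toy_prob = lambda i, j: 0.9 if i == 2 else (0.1 if j == 2 else 0.5)
    print(aggregate_pairwise_scores(3, toy_prob))  # candidate 2 gets the highest score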