From 60d5e7f03ffe5c46c908626e3c7522f2689689e5 Mon Sep 17 00:00:00 2001 From: Yidi Chen Date: Tue, 13 Aug 2024 00:33:57 -0400 Subject: [PATCH 01/30] First step update - during create_prompt, use selected_index to substitute rank_start and rank_end --- .../rerank/listwise/listwise_rankllm.py | 8 +- src/rank_llm/rerank/listwise/rank_fid.py | 20 +-- src/rank_llm/rerank/listwise/rank_gpt.py | 31 +++-- .../rerank/listwise/rank_listwise_os_llm.py | 22 +-- .../rerank/listwise/reorder/__init__.py | 0 .../listwise/reorder/reorder_executor.py | 126 ++++++++++++++++++ src/rank_llm/rerank/rankllm.py | 10 +- 7 files changed, 173 insertions(+), 44 deletions(-) create mode 100644 src/rank_llm/rerank/listwise/reorder/__init__.py create mode 100644 src/rank_llm/rerank/listwise/reorder/reorder_executor.py diff --git a/src/rank_llm/rerank/listwise/listwise_rankllm.py b/src/rank_llm/rerank/listwise/listwise_rankllm.py index 838cc97a..93f03916 100644 --- a/src/rank_llm/rerank/listwise/listwise_rankllm.py +++ b/src/rank_llm/rerank/listwise/listwise_rankllm.py @@ -95,7 +95,7 @@ def permutation_pipeline_batched( prompts = [] logger.info("Loading prompts.") prompts = self.create_prompt_batched( - results, rank_start, rank_end, batch_size=32 + results, list(range(rank_start, rank_end)), batch_size=32 ) if logging: for prompt in prompts: @@ -142,7 +142,9 @@ def permutation_pipeline( Returns: Result: The processed result object after applying permutation. """ - prompt, in_token_count = self.create_prompt(result, rank_start, rank_end) + prompt, in_token_count = self.create_prompt( + result, list(range(rank_start, rank_end)) + ) if logging: logger.info(f"Prompt: {prompt}\n") permutation, out_token_count = self.run_llm( @@ -338,7 +340,7 @@ def get_ranking_cost( start_pos = rank_end - window_size while start_pos >= rank_start: start_pos = max(start_pos, rank_start) - prompt, _ = self.create_prompt(result, start_pos, end_pos) + prompt, _ = self.create_prompt(result, list(range(start_pos, end_pos))) input_token_count += self.get_num_tokens(prompt) end_pos = end_pos - step start_pos = start_pos - step diff --git a/src/rank_llm/rerank/listwise/rank_fid.py b/src/rank_llm/rerank/listwise/rank_fid.py index f6f2f02c..3f4dd285 100644 --- a/src/rank_llm/rerank/listwise/rank_fid.py +++ b/src/rank_llm/rerank/listwise/rank_fid.py @@ -188,9 +188,9 @@ def run_llm_batched( return self._run_llm_by_length_unified(prompt_infos) def create_prompt_batched( - self, results: List[Result], rank_start: int, rank_end: int, batch_size: int + self, results: List[Result], selected_index: List[int], batch_size: int ) -> List[Tuple[List[Dict[str, str]], int]]: - return [self.create_prompt(result, rank_start, rank_end) for result in results] + return [self.create_prompt(result, selected_index) for result in results] def run_llm(self, prompts: List[Dict[str, str]], **kwargs) -> Tuple[str, int]: """ @@ -202,7 +202,7 @@ def create_prompt( - self, result: Result, rank_start: int, rank_end: int + self, result: Result, selected_index: List[int] ) -> Tuple[List[Dict[str, str]], int]: """ Create a prompt based on the result and given ranking range. 
@@ -213,13 +213,13 @@ def create_prompt( { "text": self._gen_passage( result.query.text, - i + 1 - rank_start, + loc + 1, self.convert_doc_to_prompt_content( - result.candidates[i].doc, self.max_tokens() + result.candidates[idx].doc, self.max_tokens() ), ) } - for i in range(rank_start, rank_end) + for loc, idx in enumerate(selected_index) ] return prompts, sum(self.get_num_tokens(prompt["text"]) for prompt in prompts) @@ -477,9 +477,9 @@ def run_llm_batched( return self._run_llm_by_length_unified(processed_prompts) def create_prompt_batched( - self, results: List[Result], rank_start: int, rank_end: int, batch_size: int + self, results: List[Result], selected_index: List[int], batch_size: int ) -> List[Tuple[List[Dict[str, str]], int]]: - return [self.create_prompt(result, rank_start, rank_end) for result in results] + return [self.create_prompt(result, selected_index) for result in results] def run_llm(self, prompts: List[Dict[str, str]], **kwargs) -> Tuple[str, int]: # get arbitrary query (they should be the same) @@ -488,7 +488,7 @@ def run_llm(self, prompts: List[Dict[str, str]], **kwargs) -> Tuple[str, int]: )[0] def create_prompt( - self, result: Result, rank_start: int, rank_end: int + self, result: Result, selected_index: List[int] ) -> Tuple[List[Dict[str, str]], int]: """ Create a prompt based on the result and given ranking range. @@ -498,7 +498,7 @@ def create_prompt( sum_token = 0 - for i in range(rank_start, rank_end): + for i in selected_index: results.append( { "query": f"question: {query}", diff --git a/src/rank_llm/rerank/listwise/rank_gpt.py b/src/rank_llm/rerank/listwise/rank_gpt.py index db29759e..2a47b660 100644 --- a/src/rank_llm/rerank/listwise/rank_gpt.py +++ b/src/rank_llm/rerank/listwise/rank_gpt.py @@ -229,7 +229,7 @@ def num_output_tokens(self, current_window_size: Optional[int] = None) -> int: _output_token_estimate = ( len( encoder.encode( - " > ".join([f"[{i+1}]" for i in range(current_window_size)]) + " > ".join([f"[{i + 1}]" for i in range(current_window_size)]) ) ) - 1 @@ -248,20 +248,20 @@ def run_llm_batched(self): pass def create_prompt( - self, result: Result, rank_start: int, rank_end: int + self, result: Result, selected_index: List[int] ) -> Tuple[List[Dict[str, str]], int]: if self._prompt_mode in [PromptMode.RANK_GPT, PromptMode.RANK_GPT_APEER]: - return self.create_rank_gpt_prompt(result, rank_start, rank_end) + return self.create_rank_gpt_prompt(result, selected_index) else: - return self.create_LRL_prompt(result, rank_start, rank_end) + return self.create_LRL_prompt(result, selected_index) def create_rank_gpt_prompt( - self, result: Result, rank_start: int, rank_end: int + self, result: Result, selected_index: List[int] ) -> Tuple[List[Dict[str, str]], int]: query = result.query.text - num = len(result.candidates[rank_start:rank_end]) + num = len(selected_index) - max_length = 300 * (self._window_size / (rank_end - rank_start)) + max_length = 300 * (self._window_size / (len(selected_index))) while True: messages = ( self._get_prefix_for_rank_gpt_apeer_prompt(query, num) @@ -269,7 +269,8 @@ def create_rank_gpt_prompt( else self._get_prefix_for_rank_gpt_prompt(query, num) ) rank = 0 - for cand in result.candidates[rank_start:rank_end]: + for idx in selected_index: + cand = result.candidates[idx] rank += 1 content = self.convert_doc_to_prompt_content(cand.doc, max_length) if self._prompt_mode == PromptMode.RANK_GPT_APEER: @@ -304,21 +305,23 @@ def create_rank_gpt_prompt( max_length -= max( 1, (num_tokens - self.max_tokens() + 
self.num_output_tokens()) - // ((rank_end - rank_start) * 4), + // (len(selected_index) * 4), ) return messages, self.get_num_tokens(messages) def create_LRL_prompt( - self, result: Result, rank_start: int, rank_end: int + self, result: Result, selected_index: List[int] ) -> Tuple[List[Dict[str, str]], int]: query = result.query.text - num = len(result.candidates[rank_start:rank_end]) - max_length = 300 * (20 / (rank_end - rank_start)) + num = len(selected_index) + max_length = 300 * (20 / len(selected_index)) psg_ids = [] while True: message = "Sort the list PASSAGES by how good each text answers the QUESTION (in descending order of relevancy).\n" rank = 0 - for cand in result.candidates[rank_start:rank_end]: + for idx in selected_index: + cand = result.candidates[idx] + rank += 1 psg_id = f"PASSAGE{rank}" content = self.convert_doc_to_prompt_content(cand.doc, max_length) @@ -335,7 +338,7 @@ def create_LRL_prompt( max_length -= max( 1, (num_tokens - self.max_tokens() + self.num_output_tokens()) - // ((rank_end - rank_start) * 4), + // (len(selected_index) * 4), ) return messages, self.get_num_tokens(messages) diff --git a/src/rank_llm/rerank/listwise/rank_listwise_os_llm.py b/src/rank_llm/rerank/listwise/rank_listwise_os_llm.py index e3f9db7a..03b60986 100644 --- a/src/rank_llm/rerank/listwise/rank_listwise_os_llm.py +++ b/src/rank_llm/rerank/listwise/rank_listwise_os_llm.py @@ -208,7 +208,7 @@ def num_output_tokens(self, current_window_size: Optional[int] = None) -> int: _output_token_estimate = ( len( self._tokenizer.encode( - " > ".join([f"[{i+1}]" for i in range(current_window_size)]) + " > ".join([f"[{i + 1}]" for i in range(current_window_size)]) ) ) - 1 @@ -238,12 +238,12 @@ def _add_few_shot_examples(self, conv): return conv def create_prompt( - self, result: Result, rank_start: int, rank_end: int + self, result: Result, selected_index: List[int] ) -> Tuple[str, int]: query = result.query.text query = self._replace_number(query) - num = len(result.candidates[rank_start:rank_end]) - max_length = 300 * (20 / (rank_end - rank_start)) + num = len(selected_index) + max_length = 300 * (20 / (len(selected_index))) while True: conv = get_conversation_template(self._model) if self._system_message: @@ -252,7 +252,8 @@ def create_prompt( prefix = self._add_prefix_prompt(query, num) rank = 0 input_context = f"{prefix}\n" - for cand in result.candidates[rank_start:rank_end]: + for idx in selected_index: + cand = result.candidates[idx] rank += 1 # For Japanese should cut by character: content = content[:int(max_length)] content = self.convert_doc_to_prompt_content(cand.doc, max_length) @@ -265,7 +266,7 @@ def create_prompt( prompt = fix_text(prompt) num_tokens = self.get_num_tokens(prompt) if num_tokens <= self.max_tokens() - self.num_output_tokens( - rank_end - rank_start + len(selected_index) ): break else: @@ -274,17 +275,16 @@ def create_prompt( ( num_tokens - self.max_tokens() - + self.num_output_tokens(rank_end - rank_start) + + self.num_output_tokens(len(selected_index)) ) - // ((rank_end - rank_start) * 4), + // (len(selected_index) * 4), ) return prompt, self.get_num_tokens(prompt) def create_prompt_batched( self, results: List[Result], - rank_start: int, - rank_end: int, + selected_index: List[int], batch_size: int = 32, ) -> List[Tuple[str, int]]: def chunks(lst, n): @@ -298,7 +298,7 @@ def chunks(lst, n): for batch in tqdm(chunks(results, batch_size), desc="Processing batches"): completed_prompts = list( executor.map( - lambda result: self.create_prompt(result, rank_start, 
rank_end), + lambda result: self.create_prompt(result, selected_index), batch, ) ) diff --git a/src/rank_llm/rerank/listwise/reorder/__init__.py b/src/rank_llm/rerank/listwise/reorder/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/rank_llm/rerank/listwise/reorder/reorder_executor.py b/src/rank_llm/rerank/listwise/reorder/reorder_executor.py new file mode 100644 index 00000000..0dea09f5 --- /dev/null +++ b/src/rank_llm/rerank/listwise/reorder/reorder_executor.py @@ -0,0 +1,126 @@ +import copy +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Callable, Dict, List, Tuple, TypeVar, Union + +from rank_llm.data import Request, Result + +T = TypeVar("T") + + +@dataclass +class ModelFunction: + # [(Request, SelectIndex)] -> [Prompt] + create_prompt: Callable[ + [List[Tuple[Request, List[int]]]], List[Union[str, Dict[str, str]]] + ] + + # [Prompt] -> [Permutation] + execute: Callable[[List[Union[str, Dict[str, str]]]], List[List[int]]] + + +class ReorderExecutor(ABC): + @abstractmethod + def reorder( + self, + requests: List[Request], + rank_start: int, + rank_end: int, + model: ModelFunction, + **kwargs, + ) -> Result: + pass + + @staticmethod + def _shuffle_and_rescore( + results: List[Result], select_indexes: List[int] + ) -> List[Result]: + # do nothing for now + return results + + @staticmethod + def _reorder_by_rank(items: List[T], idxes: List[int], rank: List[int]) -> List[T]: + """ + Given items, the window's index locations (idxes), and a window-relative rank permutation, reorder the items so that position idxes[i] receives the item ranked i-th within the window. + """ + assert len(idxes) == len(rank) + + n = len(idxes) + + subset_item = [items[idxes[rank[i]]] for i in range(n)] + + for i in range(len(idxes)): + items[idxes[i]] = subset_item[i] + + return items + + +class SlidingWindowReorderExecutor(ReorderExecutor): + def __init__( + self, window_size: int, step_size: int, shuffle_candidates: bool = False + ): + self._window_size = window_size + self._step_size = step_size + + self._shuffle_candidates = shuffle_candidates + + def reorder( + self, + requests: List[Request], + rank_start: int, + rank_end: int, + model: ModelFunction, + **kwargs, + ) -> List[Result]: + rerank_results = [ + Result( + query=copy.deepcopy(request.query), + candidates=copy.deepcopy(request.candidates), + ranking_exec_summary=[], + ) + for request in requests + ] + + if self._shuffle_candidates: + self._shuffle_and_rescore(rerank_results, [*range(rank_start, rank_end)]) + + # order of requests + request_ranks = [[*range(len(request.candidates))] for request in requests] + + end_pos = rank_end + start_pos = rank_end - self._window_size + + # end_pos > rank_start ensures that the list is non-empty while allowing last window to be smaller than window_size + # start_pos + step != rank_start prevents processing of redundant windows (e.g. 
0-20, followed by 0-10) + while end_pos > rank_start and start_pos + self._step_size != rank_start: + # if logging: + # logger.info(f"start_pos: {start_pos}, end_pos: {end_pos}") + start_pos = max(start_pos, rank_start) + + index_working_on = [*range(start_pos, end_pos)] + prompts = model.create_prompt( + [ + (request, [request_rank[i] for i in index_working_on]) + for request, request_rank in zip(requests, request_ranks) + ] + ) + orders = model.execute(prompts) + + for request_rank, order in zip(request_ranks, orders): + self._reorder_by_rank(request_rank, index_working_on, order) + + end_pos = end_pos - self._step_size + start_pos = start_pos - self._step_size + + return [ + Result( + query=copy.deepcopy(request.query), + candidates=self._reorder_by_rank( + copy.deepcopy(request.candidates), + [*range(len(request.candidates))], + rank, + ), + ranking_exec_summary=[], + ) + for request, rank in zip(requests, request_ranks) + ] diff --git a/src/rank_llm/rerank/rankllm.py b/src/rank_llm/rerank/rankllm.py index 7f5462c9..a6a0f3bf 100644 --- a/src/rank_llm/rerank/rankllm.py +++ b/src/rank_llm/rerank/rankllm.py @@ -57,15 +57,14 @@ def run_llm( @abstractmethod def create_prompt_batched( - self, results: List[Result], rank_start: int, rank_end: int, batch_size: int + self, results: List[Result], selected_index: List[int], batch_size: int ) -> List[Tuple[Union[str, List[Dict[str, str]]], int]]: """ Abstract method to create a batch of prompts based on the results and given ranking range. Args: results (List[Result]): The list of result objects containing data for prompt generation. - rank_start (int): The starting rank for prompt generation. - rank_end (int): The ending rank for prompt generation. + selected_index: select index for prompt creation Returns: Tuple[List[Union[str, List[Dict[str, str]]], List[int]]: A tuple object containing the list of generated prompts and the list of number of tokens in the generated prompts. @@ -74,15 +73,14 @@ def create_prompt_batched( @abstractmethod def create_prompt( - self, result: Result, rank_start: int, rank_end: int + self, result: Result, selected_index: List[int] ) -> Tuple[Union[str, List[Dict[str, str]]], int]: """ Abstract method to create a prompt based on the result and given ranking range. Args: result (Result): The result object containing data for prompt generation. - rank_start (int): The starting rank for prompt generation. - rank_end (int): The ending rank for prompt generation. + selected_index: select index for prompt creation Returns: Tuple[Union[str, List[Dict[str, str]]], int]: A tuple object containing the generated prompt and the number of tokens in the generated prompt. 
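With patch 01, create_prompt receives an explicit list of candidate indices instead of a (rank_start, rank_end) range, and reorder_executor.py writes the model's window-relative permutation back onto those positions through _reorder_by_rank. A minimal standalone sketch of that reordering step (plain Python; the function name and example data are illustrative, not part of the patch):

    # Sketch of the _reorder_by_rank idea: position idxes[i] receives the
    # item the model ranked i-th within the window.
    def reorder_by_rank(items, idxes, rank):
        assert len(idxes) == len(rank)
        subset = [items[idxes[r]] for r in rank]  # items in model-preferred order
        for i, idx in enumerate(idxes):
            items[idx] = subset[i]
        return items

    # Example: a window over positions 2..5 where the model returns the
    # window-relative permutation [3, 0, 2, 1] (best first).
    docs = ["d0", "d1", "d2", "d3", "d4", "d5"]
    print(reorder_by_rank(docs, [2, 3, 4, 5], [3, 0, 2, 1]))
    # -> ['d0', 'd1', 'd5', 'd2', 'd4', 'd3']
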
From 74118c4c7771d21a9275ee95e2247a702d6d9b41 Mon Sep 17 00:00:00 2001 From: Yidi Chen Date: Tue, 13 Aug 2024 00:41:25 -0400 Subject: [PATCH 02/30] Allow selected index to be variable on each step --- src/rank_llm/rerank/listwise/listwise_rankllm.py | 4 +++- src/rank_llm/rerank/listwise/rank_fid.py | 11 +++++++---- src/rank_llm/rerank/listwise/rank_gpt.py | 8 ++++++-- src/rank_llm/rerank/listwise/rank_listwise_os_llm.py | 6 +++--- .../rerank/listwise/reorder/reorder_executor.py | 10 +++++----- src/rank_llm/rerank/rankllm.py | 4 ++-- 6 files changed, 26 insertions(+), 17 deletions(-) diff --git a/src/rank_llm/rerank/listwise/listwise_rankllm.py b/src/rank_llm/rerank/listwise/listwise_rankllm.py index 93f03916..8b4858a5 100644 --- a/src/rank_llm/rerank/listwise/listwise_rankllm.py +++ b/src/rank_llm/rerank/listwise/listwise_rankllm.py @@ -95,7 +95,9 @@ def permutation_pipeline_batched( prompts = [] logger.info("Loading prompts.") prompts = self.create_prompt_batched( - results, list(range(rank_start, rank_end)), batch_size=32 + results, + [list(range(rank_start, rank_end)) for _ in range(len(results))], + batch_size=32, ) if logging: for prompt in prompts: diff --git a/src/rank_llm/rerank/listwise/rank_fid.py b/src/rank_llm/rerank/listwise/rank_fid.py index 3f4dd285..e074e8dd 100644 --- a/src/rank_llm/rerank/listwise/rank_fid.py +++ b/src/rank_llm/rerank/listwise/rank_fid.py @@ -188,9 +188,9 @@ def run_llm_batched( return self._run_llm_by_length_unified(prompt_infos) def create_prompt_batched( - self, results: List[Result], selected_index: List[int], batch_size: int + self, results: List[Result], selected_indexes: List[int], batch_size: int ) -> List[Tuple[List[Dict[str, str]], int]]: - return [self.create_prompt(result, selected_index) for result in results] + return [self.create_prompt(result, selected_indexes) for result in results] def run_llm(self, prompts: List[Dict[str, str]], **kwargs) -> Tuple[str, int]: """ @@ -477,9 +477,12 @@ def run_llm_batched( return self._run_llm_by_length_unified(processed_prompts) def create_prompt_batched( - self, results: List[Result], selected_index: List[int], batch_size: int + self, results: List[Result], selected_indexes: List[List[int]], batch_size: int ) -> List[Tuple[List[Dict[str, str]], int]]: - return [self.create_prompt(result, selected_index) for result in results] + return [ + self.create_prompt(result, selected_index) + for result, selected_index in zip(results, selected_indexes) + ] def run_llm(self, prompts: List[Dict[str, str]], **kwargs) -> Tuple[str, int]: # get arbitrary query (they should be the same) diff --git a/src/rank_llm/rerank/listwise/rank_gpt.py b/src/rank_llm/rerank/listwise/rank_gpt.py index 2a47b660..825fe436 100644 --- a/src/rank_llm/rerank/listwise/rank_gpt.py +++ b/src/rank_llm/rerank/listwise/rank_gpt.py @@ -241,10 +241,14 @@ def num_output_tokens(self, current_window_size: Optional[int] = None) -> int: self._output_token_estimate = _output_token_estimate return _output_token_estimate - def create_prompt_batched(self): + def create_prompt_batched( + self, results: List[Result], selected_indexes: List[List[int]], batch_size: int + ) -> List[Tuple[Union[str, List[Dict[str, str]]], int]]: pass - def run_llm_batched(self): + def run_llm_batched( + self, prompts: List[Union[str, List[Dict[str, str]]]], **kwargs + ) -> List[Tuple[str, int]]: pass def create_prompt( diff --git a/src/rank_llm/rerank/listwise/rank_listwise_os_llm.py b/src/rank_llm/rerank/listwise/rank_listwise_os_llm.py index 03b60986..d8fa4f7e 100644 --- 
a/src/rank_llm/rerank/listwise/rank_listwise_os_llm.py +++ b/src/rank_llm/rerank/listwise/rank_listwise_os_llm.py @@ -284,7 +284,7 @@ def create_prompt( def create_prompt_batched( self, results: List[Result], - selected_index: List[int], + selected_indexes: List[List[int]], batch_size: int = 32, ) -> List[Tuple[str, int]]: def chunks(lst, n): @@ -298,8 +298,8 @@ def chunks(lst, n): for batch in tqdm(chunks(results, batch_size), desc="Processing batches"): completed_prompts = list( executor.map( - lambda result: self.create_prompt(result, selected_index), - batch, + lambda req: self.create_prompt(req[0], req[1]), + zip(batch, selected_indexes), ) ) all_completed_prompts.extend(completed_prompts) diff --git a/src/rank_llm/rerank/listwise/reorder/reorder_executor.py b/src/rank_llm/rerank/listwise/reorder/reorder_executor.py index 0dea09f5..1a5c6800 100644 --- a/src/rank_llm/rerank/listwise/reorder/reorder_executor.py +++ b/src/rank_llm/rerank/listwise/reorder/reorder_executor.py @@ -3,16 +3,16 @@ from dataclasses import dataclass from typing import Callable, Dict, List, Tuple, TypeVar, Union -from rank_llm.data import Request, Result +from rank_llm.data import Result T = TypeVar("T") @dataclass class ModelFunction: - # [(Request, SelectIndex)] -> [Prompt] + # [(Result, SelectIndex)] -> [Prompt] create_prompt: Callable[ - [List[Tuple[Request, List[int]]]], List[Union[str, Dict[str, str]]] + [List[Tuple[Result, List[int]]]], List[Union[str, Dict[str, str]]] ] # [Prompt] -> [Permutation] @@ -23,7 +23,7 @@ class ReorderExecutor(ABC): @abstractmethod def reorder( self, - requests: List[Request], + requests: List[Result], rank_start: int, rank_end: int, model: ModelFunction, @@ -66,7 +66,7 @@ def __init__( def reorder( self, - requests: List[Request], + requests: List[Result], rank_start: int, rank_end: int, model: ModelFunction, diff --git a/src/rank_llm/rerank/rankllm.py b/src/rank_llm/rerank/rankllm.py index a6a0f3bf..f6a05f1c 100644 --- a/src/rank_llm/rerank/rankllm.py +++ b/src/rank_llm/rerank/rankllm.py @@ -57,14 +57,14 @@ def run_llm( @abstractmethod def create_prompt_batched( - self, results: List[Result], selected_index: List[int], batch_size: int + self, results: List[Result], selected_indexes: List[List[int]], batch_size: int ) -> List[Tuple[Union[str, List[Dict[str, str]]], int]]: """ Abstract method to create a batch of prompts based on the results and given ranking range. Args: results (List[Result]): The list of result objects containing data for prompt generation. - selected_index: select index for prompt creation + selected_indexes: select index for prompt creation Returns: Tuple[List[Union[str, List[Dict[str, str]]], List[int]]: A tuple object containing the list of generated prompts and the list of number of tokens in the generated prompts. From 36bbcdc218f558c677e2c05da8f59876f03dac3e Mon Sep 17 00:00:00 2001 From: Yidi Chen Date: Tue, 13 Aug 2024 00:48:52 -0400 Subject: [PATCH 03/30] Added basic _get_model_function - it is wrong for now: we want to return a permutation, but it is currently returning a string. 
Will need the cleanup pipeline in next step --- .../rerank/listwise/listwise_rankllm.py | 33 ++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/src/rank_llm/rerank/listwise/listwise_rankllm.py b/src/rank_llm/rerank/listwise/listwise_rankllm.py index 8b4858a5..a9aca9cb 100644 --- a/src/rank_llm/rerank/listwise/listwise_rankllm.py +++ b/src/rank_llm/rerank/listwise/listwise_rankllm.py @@ -4,13 +4,14 @@ import re from abc import ABC from datetime import datetime -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, List, Tuple, Union from ftfy import fix_text from tqdm import tqdm from rank_llm.data import RankingExecInfo, Request, Result from rank_llm.rerank import PromptMode, RankLLM +from rank_llm.rerank.listwise.reorder.reorder_executor import ModelFunction logger = logging.getLogger(__name__) @@ -441,3 +442,33 @@ def convert_doc_to_prompt_content( # For Japanese should cut by character: content = content[:int(max_length)] content = " ".join(content.split()[: int(max_length)]) return self._replace_number(content) + + def _get_model_function(self, batched: bool = False, **kwargs) -> ModelFunction: + # [(Request, SelectIndex)] -> [Prompt] + if batched: + + def create_prompt(batch: List[Tuple[Result, List[int]]]): + return [ + prompt + for prompt, _ in self.create_prompt_batched( + [result for result, selected_index in batch], + [selected_index for result, selected_index in batch], + 32, + ) + ] + + def execute(batch: List[Union[str, Dict[str, str]]]): + return [s for s, _ in self.run_llm_batched(batch, **kwargs)] + + else: + + def create_prompt(batch: List[Tuple[Result, List[int]]]): + return [ + self.create_prompt(result, selected_index)[0] + for result, selected_index in batch + ] + + def execute(batch: List[Union[str, Dict[str, str]]]): + return [self.run_llm(x, **kwargs)[0] for x in batch] + + return ModelFunction(create_prompt=create_prompt, execute=execute) From 6d2094a8c79de10a008520a9785564f42157b7d9 Mon Sep 17 00:00:00 2001 From: Yidi Chen Date: Wed, 14 Aug 2024 00:36:28 -0400 Subject: [PATCH 04/30] Renamed indexes to indices, [indices] is renamed to indices_batch --- .../rerank/listwise/listwise_rankllm.py | 42 +++++++++++++++---- src/rank_llm/rerank/listwise/rank_fid.py | 23 ++++++---- src/rank_llm/rerank/listwise/rank_gpt.py | 31 +++++++------- .../rerank/listwise/rank_listwise_os_llm.py | 18 ++++---- .../listwise/reorder/reorder_executor.py | 16 +++---- src/rank_llm/rerank/rankllm.py | 11 +++-- 6 files changed, 90 insertions(+), 51 deletions(-) diff --git a/src/rank_llm/rerank/listwise/listwise_rankllm.py b/src/rank_llm/rerank/listwise/listwise_rankllm.py index a9aca9cb..c5fcf487 100644 --- a/src/rank_llm/rerank/listwise/listwise_rankllm.py +++ b/src/rank_llm/rerank/listwise/listwise_rankllm.py @@ -443,6 +443,16 @@ def convert_doc_to_prompt_content( content = " ".join(content.split()[: int(max_length)]) return self._replace_number(content) + def _permutation_to_rank(self, perm_string: str, selected_indices: List[int]): + perm = [int(x) - 1 for x in self._clean_response(perm_string).split(" ")] + perm = [ + int(x) + for x in self._remove_duplicate(perm) + if 0 <= x < len(selected_indices) + ] + perm = perm + [i for i in range(len(selected_indices)) if i not in perm] + return perm + def _get_model_function(self, batched: bool = False, **kwargs) -> ModelFunction: # [(Request, SelectIndex)] -> [Prompt] if batched: @@ -451,24 +461,40 @@ def create_prompt(batch: List[Tuple[Result, List[int]]]): return [ prompt for prompt, _ in 
self.create_prompt_batched( - [result for result, selected_index in batch], - [selected_index for result, selected_index in batch], + [result for result, selected_location in batch], + [selected_indices for result, selected_indices in batch], 32, ) ] - def execute(batch: List[Union[str, Dict[str, str]]]): - return [s for s, _ in self.run_llm_batched(batch, **kwargs)] + def execute( + batch: List[Union[str, Dict[str, str]]], + selected_indices_batch: List[List[int]], + ): + return [ + self._permutation_to_rank(s, selected_indices) + for (s, _), selected_indices in zip( + self.run_llm_batched(batch, **kwargs), selected_indices_batch + ) + ] else: def create_prompt(batch: List[Tuple[Result, List[int]]]): return [ - self.create_prompt(result, selected_index)[0] - for result, selected_index in batch + self.create_prompt(result, selected_indices)[0] + for result, selected_indices in batch ] - def execute(batch: List[Union[str, Dict[str, str]]]): - return [self.run_llm(x, **kwargs)[0] for x in batch] + def execute( + batch: List[Union[str, Dict[str, str]]], + selected_indices_batch: List[List[int]], + ): + return [ + self._permutation_to_rank( + self.run_llm(x, **kwargs)[0], selected_indices + ) + for x, selected_indices in zip(batch, selected_indices_batch) + ] return ModelFunction(create_prompt=create_prompt, execute=execute) diff --git a/src/rank_llm/rerank/listwise/rank_fid.py b/src/rank_llm/rerank/listwise/rank_fid.py index e074e8dd..5b695d9d 100644 --- a/src/rank_llm/rerank/listwise/rank_fid.py +++ b/src/rank_llm/rerank/listwise/rank_fid.py @@ -188,9 +188,11 @@ def run_llm_batched( return self._run_llm_by_length_unified(prompt_infos) def create_prompt_batched( - self, results: List[Result], selected_indexes: List[int], batch_size: int + self, results: List[Result], selected_indices_batch: List[int], batch_size: int ) -> List[Tuple[List[Dict[str, str]], int]]: - return [self.create_prompt(result, selected_indexes) for result in results] + return [ + self.create_prompt(result, selected_indices_batch) for result in results + ] def run_llm(self, prompts: List[Dict[str, str]], **kwargs) -> Tuple[str, int]: """ @@ -202,7 +204,7 @@ def run_llm(self, prompts: List[Dict[str, str]], **kwargs) -> Tuple[str, int]: )[0] def create_prompt( - self, result: Result, selected_index: List[int] + self, result: Result, selected_indices: List[int] ) -> Tuple[List[Dict[str, str]], int]: """ Create a prompt based on the result and given ranking range. 
@@ -219,7 +221,7 @@ def create_prompt( ), ) } - for loc, idx in enumerate(selected_index) + for loc, idx in enumerate(selected_indices) ] return prompts, sum(self.get_num_tokens(prompt["text"]) for prompt in prompts) @@ -477,11 +479,14 @@ def run_llm_batched( return self._run_llm_by_length_unified(processed_prompts) def create_prompt_batched( - self, results: List[Result], selected_indexes: List[List[int]], batch_size: int + self, + results: List[Result], + selected_indices_batch: List[List[int]], + batch_size: int, ) -> List[Tuple[List[Dict[str, str]], int]]: return [ - self.create_prompt(result, selected_index) - for result, selected_index in zip(results, selected_indexes) + self.create_prompt(result, selected_indices) + for result, selected_indices in zip(results, selected_indices_batch) ] def run_llm(self, prompts: List[Dict[str, str]], **kwargs) -> Tuple[str, int]: @@ -491,7 +496,7 @@ def run_llm(self, prompts: List[Dict[str, str]], **kwargs) -> Tuple[str, int]: )[0] def create_prompt( - self, result: Result, selected_index: List[int] + self, result: Result, selected_indices: List[int] ) -> Tuple[List[Dict[str, str]], int]: """ Create a prompt based on the result and given ranking range. @@ -501,7 +506,7 @@ def create_prompt( sum_token = 0 - for i in selected_index: + for i in selected_indices: results.append( { "query": f"question: {query}", diff --git a/src/rank_llm/rerank/listwise/rank_gpt.py b/src/rank_llm/rerank/listwise/rank_gpt.py index 825fe436..5c8cd3e4 100644 --- a/src/rank_llm/rerank/listwise/rank_gpt.py +++ b/src/rank_llm/rerank/listwise/rank_gpt.py @@ -242,7 +242,10 @@ def num_output_tokens(self, current_window_size: Optional[int] = None) -> int: return _output_token_estimate def create_prompt_batched( - self, results: List[Result], selected_indexes: List[List[int]], batch_size: int + self, + results: List[Result], + selected_indices_batch: List[List[int]], + batch_size: int, ) -> List[Tuple[Union[str, List[Dict[str, str]]], int]]: pass @@ -252,20 +255,20 @@ def run_llm_batched( pass def create_prompt( - self, result: Result, selected_index: List[int] + self, result: Result, selected_indices: List[int] ) -> Tuple[List[Dict[str, str]], int]: if self._prompt_mode in [PromptMode.RANK_GPT, PromptMode.RANK_GPT_APEER]: - return self.create_rank_gpt_prompt(result, selected_index) + return self.create_rank_gpt_prompt(result, selected_indices) else: - return self.create_LRL_prompt(result, selected_index) + return self.create_LRL_prompt(result, selected_indices) def create_rank_gpt_prompt( - self, result: Result, selected_index: List[int] + self, result: Result, selected_indices: List[int] ) -> Tuple[List[Dict[str, str]], int]: query = result.query.text - num = len(selected_index) + num = len(selected_indices) - max_length = 300 * (self._window_size / (len(selected_index))) + max_length = 300 * (self._window_size / (len(selected_indices))) while True: messages = ( self._get_prefix_for_rank_gpt_apeer_prompt(query, num) @@ -273,7 +276,7 @@ def create_rank_gpt_prompt( else self._get_prefix_for_rank_gpt_prompt(query, num) ) rank = 0 - for idx in selected_index: + for idx in selected_indices: cand = result.candidates[idx] rank += 1 content = self.convert_doc_to_prompt_content(cand.doc, max_length) @@ -309,21 +312,21 @@ def create_rank_gpt_prompt( max_length -= max( 1, (num_tokens - self.max_tokens() + self.num_output_tokens()) - // (len(selected_index) * 4), + // (len(selected_indices) * 4), ) return messages, self.get_num_tokens(messages) def create_LRL_prompt( - self, result: Result, 
selected_index: List[int] + self, result: Result, selected_indices: List[int] ) -> Tuple[List[Dict[str, str]], int]: query = result.query.text - num = len(selected_index) - max_length = 300 * (20 / len(selected_index)) + num = len(selected_indices) + max_length = 300 * (20 / len(selected_indices)) psg_ids = [] while True: message = "Sort the list PASSAGES by how good each text answers the QUESTION (in descending order of relevancy).\n" rank = 0 - for idx in selected_index: + for idx in selected_indices: cand = result.candidates[idx] rank += 1 @@ -342,7 +345,7 @@ def create_LRL_prompt( max_length -= max( 1, (num_tokens - self.max_tokens() + self.num_output_tokens()) - // (len(selected_index) * 4), + // (len(selected_indices) * 4), ) return messages, self.get_num_tokens(messages) diff --git a/src/rank_llm/rerank/listwise/rank_listwise_os_llm.py b/src/rank_llm/rerank/listwise/rank_listwise_os_llm.py index d8fa4f7e..eb6323b9 100644 --- a/src/rank_llm/rerank/listwise/rank_listwise_os_llm.py +++ b/src/rank_llm/rerank/listwise/rank_listwise_os_llm.py @@ -238,12 +238,12 @@ def _add_few_shot_examples(self, conv): return conv def create_prompt( - self, result: Result, selected_index: List[int] + self, result: Result, selected_indices: List[int] ) -> Tuple[str, int]: query = result.query.text query = self._replace_number(query) - num = len(selected_index) - max_length = 300 * (20 / (len(selected_index))) + num = len(selected_indices) + max_length = 300 * (20 / (len(selected_indices))) while True: conv = get_conversation_template(self._model) if self._system_message: @@ -252,7 +252,7 @@ def create_prompt( prefix = self._add_prefix_prompt(query, num) rank = 0 input_context = f"{prefix}\n" - for idx in selected_index: + for idx in selected_indices: cand = result.candidates[idx] rank += 1 # For Japanese should cut by character: content = content[:int(max_length)] @@ -266,7 +266,7 @@ def create_prompt( prompt = fix_text(prompt) num_tokens = self.get_num_tokens(prompt) if num_tokens <= self.max_tokens() - self.num_output_tokens( - len(selected_index) + len(selected_indices) ): break else: @@ -275,16 +275,16 @@ def create_prompt( ( num_tokens - self.max_tokens() - + self.num_output_tokens(len(selected_index)) + + self.num_output_tokens(len(selected_indices)) ) - // (len(selected_index) * 4), + // (len(selected_indices) * 4), ) return prompt, self.get_num_tokens(prompt) def create_prompt_batched( self, results: List[Result], - selected_indexes: List[List[int]], + selected_indices_batch: List[List[int]], batch_size: int = 32, ) -> List[Tuple[str, int]]: def chunks(lst, n): @@ -299,7 +299,7 @@ def chunks(lst, n): completed_prompts = list( executor.map( lambda req: self.create_prompt(req[0], req[1]), - zip(batch, selected_indexes), + zip(batch, selected_indices_batch), ) ) all_completed_prompts.extend(completed_prompts) diff --git a/src/rank_llm/rerank/listwise/reorder/reorder_executor.py b/src/rank_llm/rerank/listwise/reorder/reorder_executor.py index 1a5c6800..1b5f4784 100644 --- a/src/rank_llm/rerank/listwise/reorder/reorder_executor.py +++ b/src/rank_llm/rerank/listwise/reorder/reorder_executor.py @@ -10,13 +10,15 @@ @dataclass class ModelFunction: - # [(Result, SelectIndex)] -> [Prompt] + # [(Result, SelectIndices)] -> [Prompt] create_prompt: Callable[ [List[Tuple[Result, List[int]]]], List[Union[str, Dict[str, str]]] ] - # [Prompt] -> [Permutation] - execute: Callable[[List[Union[str, Dict[str, str]]]], List[List[int]]] + # [Prompt], [SelectedIndices] -> [Permutation] + execute: Callable[ + 
[List[Union[str, Dict[str, str]]], List[List[int]]], List[List[int]] ] class ReorderExecutor(ABC): @abstractmethod def reorder( @@ -97,17 +99,17 @@ def reorder( # logger.info(f"start_pos: {start_pos}, end_pos: {end_pos}") start_pos = max(start_pos, rank_start) - index_working_on = [*range(start_pos, end_pos)] + indices_working_on = [*range(start_pos, end_pos)] prompts = model.create_prompt( [ - (request, [request_rank[i] for i in index_working_on]) + (request, [request_rank[i] for i in indices_working_on]) for request, request_rank in zip(requests, request_ranks) ] ) - orders = model.execute(prompts) + orders = model.execute(prompts, [indices_working_on] * len(requests)) for request_rank, order in zip(request_ranks, orders): - self._reorder_by_rank(request_rank, index_working_on, order) + self._reorder_by_rank(request_rank, indices_working_on, order) end_pos = end_pos - self._step_size start_pos = start_pos - self._step_size diff --git a/src/rank_llm/rerank/rankllm.py b/src/rank_llm/rerank/rankllm.py index f6a05f1c..ee851bcc 100644 --- a/src/rank_llm/rerank/rankllm.py +++ b/src/rank_llm/rerank/rankllm.py @@ -57,14 +57,17 @@ def run_llm( @abstractmethod def create_prompt_batched( - self, results: List[Result], selected_indexes: List[List[int]], batch_size: int + self, + results: List[Result], + selected_indices_batch: List[List[int]], + batch_size: int, ) -> List[Tuple[Union[str, List[Dict[str, str]]], int]]: """ Abstract method to create a batch of prompts based on the results and given ranking range. Args: results (List[Result]): The list of result objects containing data for prompt generation. - selected_indexes: select index for prompt creation + selected_indices_batch (List[List[int]]): the selected candidate indices for each result, used for prompt creation Returns: Tuple[List[Union[str, List[Dict[str, str]]], List[int]]: A tuple object containing the list of generated prompts and the list of number of tokens in the generated prompts. @@ -73,14 +76,14 @@ def create_prompt_batched( @abstractmethod def create_prompt( - self, result: Result, selected_index: List[int] + self, result: Result, selected_indices: List[int] ) -> Tuple[Union[str, List[Dict[str, str]]], int]: """ Abstract method to create a prompt based on the result and given ranking range. Args: result (Result): The result object containing data for prompt generation. - selected_index: select index for prompt creation + selected_indices (List[int]): the selected candidate indices used for prompt creation Returns: Tuple[Union[str, List[Dict[str, str]]], int]: A tuple object containing the generated prompt and the number of tokens in the generated prompt. 
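With patch 04, ModelFunction.execute receives the selected indices alongside the prompts and must return a window-relative permutation; _permutation_to_rank derives that permutation from the model's ranked-list output string. A minimal sketch of the parsing step (plain Python; the regex split is an illustrative stand-in for the _clean_response/_remove_duplicate helpers):

    import re

    # Parse a response like "[2] > [1] > [3]" into a zero-based permutation
    # over a window of size n: keep the first occurrence of each in-range id,
    # then append any window positions the model omitted.
    def permutation_to_rank(perm_string: str, n: int) -> list:
        ids = [int(x) - 1 for x in re.findall(r"\d+", perm_string)]
        seen, perm = set(), []
        for x in ids:
            if 0 <= x < n and x not in seen:
                seen.add(x)
                perm.append(x)
        perm += [i for i in range(n) if i not in seen]
        return perm

    print(permutation_to_rank("[2] > [1] > [3]", 4))  # [1, 0, 2, 3]
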
From cc78522021b44cb12c3fbf0ec34ff1ec3899ba1c Mon Sep 17 00:00:00 2001 From: Yidi Chen Date: Wed, 14 Aug 2024 00:40:00 -0400 Subject: [PATCH 05/30] Renamed ReorderExecutor to Reorder Policy --- src/rank_llm/rerank/listwise/listwise_rankllm.py | 2 +- .../reorder/{reorder_executor.py => reorder_policy.py} | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) rename src/rank_llm/rerank/listwise/reorder/{reorder_executor.py => reorder_policy.py} (97%) diff --git a/src/rank_llm/rerank/listwise/listwise_rankllm.py b/src/rank_llm/rerank/listwise/listwise_rankllm.py index c5fcf487..2d71bd8c 100644 --- a/src/rank_llm/rerank/listwise/listwise_rankllm.py +++ b/src/rank_llm/rerank/listwise/listwise_rankllm.py @@ -11,7 +11,7 @@ from rank_llm.data import RankingExecInfo, Request, Result from rank_llm.rerank import PromptMode, RankLLM -from rank_llm.rerank.listwise.reorder.reorder_executor import ModelFunction +from rank_llm.rerank.listwise.reorder.reorder_policy import ModelFunction logger = logging.getLogger(__name__) diff --git a/src/rank_llm/rerank/listwise/reorder/reorder_executor.py b/src/rank_llm/rerank/listwise/reorder/reorder_policy.py similarity index 97% rename from src/rank_llm/rerank/listwise/reorder/reorder_executor.py rename to src/rank_llm/rerank/listwise/reorder/reorder_policy.py index 1b5f4784..c54d550e 100644 --- a/src/rank_llm/rerank/listwise/reorder/reorder_executor.py +++ b/src/rank_llm/rerank/listwise/reorder/reorder_policy.py @@ -21,7 +21,7 @@ class ModelFunction: ] -class ReorderExecutor(ABC): +class ReorderPolicy(ABC): @abstractmethod def reorder( self, @@ -57,7 +57,7 @@ def _reorder_by_rank(items: List[T], idxes: List[int], rank: List[int]) -> List[ return items -class SlidingWindowReorderExecutor(ReorderExecutor): +class SlidingWindowReorderPolicy(ReorderPolicy): def __init__( self, window_size: int, step_size: int, shuffle_candidates: bool = False ): From 4936c8cb3b469cc460dc6b84759e4ef9199dac34 Mon Sep 17 00:00:00 2001 From: Yidi Chen Date: Thu, 15 Aug 2024 00:53:25 -0400 Subject: [PATCH 06/30] Moved LiT5Distill to policy --- .../rerank/listwise/listwise_rankllm.py | 83 +++++++++++++++++-- src/rank_llm/rerank/listwise/rank_fid.py | 82 +++++------------- .../rerank/listwise/reorder/reorder_policy.py | 37 ++++++++- src/rank_llm/rerank/reranker.py | 17 ++-- 4 files changed, 137 insertions(+), 82 deletions(-) diff --git a/src/rank_llm/rerank/listwise/listwise_rankllm.py b/src/rank_llm/rerank/listwise/listwise_rankllm.py index 2d71bd8c..d69d0965 100644 --- a/src/rank_llm/rerank/listwise/listwise_rankllm.py +++ b/src/rank_llm/rerank/listwise/listwise_rankllm.py @@ -11,10 +11,16 @@ from rank_llm.data import RankingExecInfo, Request, Result from rank_llm.rerank import PromptMode, RankLLM -from rank_llm.rerank.listwise.reorder.reorder_policy import ModelFunction +from rank_llm.rerank.listwise.reorder.reorder_policy import ( + ModelFunction, + ReorderPolicy, + SlidingWindowReorderPolicy, +) logger = logging.getLogger(__name__) +SUPPORT_REORDER_POLICIES = [SlidingWindowReorderPolicy] + class ListwiseRankLLM(RankLLM, ABC): """ @@ -31,15 +37,68 @@ class ListwiseRankLLM(RankLLM, ABC): def __init__( self, + reorder_policy: ReorderPolicy, model: str, context_size: int, prompt_mode: PromptMode, num_few_shot_examples: int, - window_size: int, ) -> None: super().__init__(model, context_size, prompt_mode) self._num_few_shot_examples = num_few_shot_examples - self._window_size = window_size + + self.reorder_policy = reorder_policy + + def rerank_batch( + self, + requests: List[Request], + 
rank_start: int = 0, + rank_end: int = 100, + shuffle_candidates: bool = False, + logging: bool = False, + batched: bool = False, + **kwargs: Any, + ) -> List[Result]: + populate_exec_summary: bool = kwargs.get("populate_exec_summary", False) + + batch_size = kwargs.get("batch_size", 1) + + if not batched: + batch_size = 1 + + reorder_policy = self.reorder_policy + model_functions = self._get_model_function(batched) + + # reranking using vllm + if len(set([len(req.candidates) for req in requests])) != 1: + raise ValueError("Batched requests must have the same number of candidates") + + result: list[Result] = [] + + with tqdm(range(0, len(requests))) as bar: + for i in range(0, len(requests), batch_size): + batch = requests[i : min(i + batch_size, len(requests))] + batch_result = reorder_policy.reorder( + requests=[ + Result( + query=copy.deepcopy(request.query), + candidates=copy.deepcopy(request.candidates), + ranking_exec_summary=[], + ) + for request in batch + ], + rank_start=max(rank_start, 0), + rank_end=min( + rank_end, len(requests[0].candidates) + ), # TODO: Fails arbitrary hit sizes + model=model_functions, + shuffle_candidates=shuffle_candidates, + logging=logging, + populate_exec_summary=populate_exec_summary, + ) + result.extend(batch_result) + bar.update(len(batch)) + + return result def get_output_filename( self, @@ -358,7 +417,8 @@ def _clean_response(self, response: str) -> str: new_response = "" for c in response: if not c.isdigit(): - new_response += " " + if len(new_response) == 0 or new_response[-1] != " ": + new_response += " " else: new_response += c new_response = new_response.strip() @@ -444,11 +504,13 @@ def convert_doc_to_prompt_content( return self._replace_number(content) def _permutation_to_rank(self, perm_string: str, selected_indices: List[int]): - perm = [int(x) - 1 for x in self._clean_response(perm_string).split(" ")] + perm = [ + int(x) - 1 for x in self._clean_response(perm_string).strip().split(" ") + ] perm = [ int(x) for x in self._remove_duplicate(perm) - if 0 <= x < len(selected_indices) + if 0 <= int(x) < len(selected_indices) ] perm = perm + [i for i in range(len(selected_indices)) if i not in perm] return perm @@ -461,7 +523,7 @@ def create_prompt(batch: List[Tuple[Result, List[int]]]): return [ prompt for prompt, _ in self.create_prompt_batched( - [result for result, selected_location in batch], + [result for result, selected_indices in batch], [selected_indices for result, selected_indices in batch], 32, ) @@ -498,3 +560,10 @@ def execute( ] return ModelFunction(create_prompt=create_prompt, execute=execute) + + @staticmethod + def get_reorder_policy(reorder_policy: str, **kwargs): + for policy in SUPPORT_REORDER_POLICIES: + if policy.name() == reorder_policy: + return policy(**kwargs) + raise Exception(f"Cannot find reorder policy {reorder_policy}") diff --git a/src/rank_llm/rerank/listwise/rank_fid.py b/src/rank_llm/rerank/listwise/rank_fid.py index 5b695d9d..bf5c102a 100644 --- a/src/rank_llm/rerank/listwise/rank_fid.py +++ b/src/rank_llm/rerank/listwise/rank_fid.py @@ -7,6 +7,7 @@ from rank_llm.data import Request, Result from rank_llm.rerank.listwise.listwise_rankllm import ListwiseRankLLM from rank_llm.rerank.listwise.lit5.model import FiD, FiDCrossAttentionScore +from rank_llm.rerank.listwise.reorder.reorder_policy import ReorderPolicy from rank_llm.rerank.rankllm import PromptMode @@ -30,12 +31,11 @@ def _to_precision(self, precision: str) -> None: def __init__( self, + reorder_policy: ReorderPolicy, model: str, context_size: int = 150, 
prompt_mode: PromptMode = PromptMode.LiT5, # Placeholder for actual mode num_few_shot_examples: int = 0, - window_size: int = 20, - step_size: int = 10, precision: str = "bfloat16", device: str = "cuda", batched: bool = False, @@ -44,11 +44,11 @@ def __init__( Creates instance of the RankFiDDistill class, a specialized version of RankLLM designed from Lit5-Distill. """ super().__init__( + reorder_policy=reorder_policy, model=model, context_size=context_size, prompt_mode=prompt_mode, num_few_shot_examples=num_few_shot_examples, - window_size=window_size, ) self._precision = precision self._tokenizer = T5Tokenizer.from_pretrained(model) @@ -56,13 +56,15 @@ def __init__( self._device = device - self._window_size = window_size - self._stride = step_size - self._batched = batched self._answer_maxlength = len( - " > ".join(map(lambda x: f"[{x}]", range(1, window_size + 1))) + " > ".join( + map( + lambda x: f"[{x}]", + range(1, reorder_policy.max_selected_indices() + 1), + ) + ) ) self._output_token_estimate = None @@ -119,60 +121,15 @@ def rerank_batch( logging: bool = False, **kwargs: Any, ) -> List[Result]: - top_k_retrieve: int = kwargs.get("top_k_retrieve", 100) - - window_size: int = kwargs.get("window_size", self._window_size) - window_size = min(window_size, top_k_retrieve) - step: int = kwargs.get("step_size", self._stride) - - populate_exec_summary: bool = kwargs.get("populate_exec_summary", False) - - batch_size = kwargs.get("batch_size", 1) - - if self._batched: - # reranking using vllm - if len(set([len(req.candidates) for req in requests])) != 1: - raise ValueError( - "Batched requests must have the same number of candidates" - ) - - result = [] - - with tqdm(range(0, len(requests))) as bar: - for i in range(0, len(requests), batch_size): - batch = requests[i : min(i + batch_size, len(requests))] - batch_result = self.sliding_windows_batched( - batch, - rank_start=max(rank_start, 0), - rank_end=min( - rank_end, len(requests[0].candidates) - ), # TODO: Fails arbitrary hit sizes - window_size=window_size, - step=step, - shuffle_candidates=shuffle_candidates, - logging=logging, - populate_exec_summary=populate_exec_summary, - ) - result.extend(batch_result) - bar.update(len(batch)) - - return result - else: - # Normal operation mode - results = [] - for request in tqdm(requests): - result = self.sliding_windows( - request, - rank_start=max(rank_start, 0), - rank_end=min(rank_end, len(request.candidates)), - window_size=window_size, - step=step, - shuffle_candidates=shuffle_candidates, - logging=logging, - populate_exec_summary=populate_exec_summary, - ) - results.append(result) - return results + return super().rerank_batch( + requests=requests, + rank_start=rank_start, + rank_end=rank_end, + shuffle_candidates=shuffle_candidates, + logging=logging, + batched=self._batched, + **kwargs, + ) def run_llm_batched( self, prompts: List[List[Dict[str, str]]], **kwargs @@ -191,7 +148,8 @@ def create_prompt_batched( self, results: List[Result], selected_indices_batch: List[int], batch_size: int ) -> List[Tuple[List[Dict[str, str]], int]]: return [ - self.create_prompt(result, selected_indices_batch) for result in results + self.create_prompt(result, selected_indices) + for result, selected_indices in zip(results, selected_indices_batch) ] def run_llm(self, prompts: List[Dict[str, str]], **kwargs) -> Tuple[str, int]: diff --git a/src/rank_llm/rerank/listwise/reorder/reorder_policy.py b/src/rank_llm/rerank/listwise/reorder/reorder_policy.py index c54d550e..553e9def 100644 --- 
a/src/rank_llm/rerank/listwise/reorder/reorder_policy.py +++ b/src/rank_llm/rerank/listwise/reorder/reorder_policy.py @@ -30,14 +30,23 @@ def reorder( rank_end: int, model: ModelFunction, **kwargs, - ) -> Result: + ) -> list[Result]: + pass + + @abstractmethod + def max_selected_indices(self) -> int: + pass + + @staticmethod + @abstractmethod + def name() -> str: pass @staticmethod def _shuffle_and_rescore( results: List[Result], select_indexes: List[int] ) -> List[Result]: - # do nothing for now + # TODO: do nothing for now return results @staticmethod @@ -59,7 +68,11 @@ def _reorder_by_rank(items: List[T], idxes: List[int], rank: List[int]) -> List[ class SlidingWindowReorderPolicy(ReorderPolicy): def __init__( - self, window_size: int, step_size: int, shuffle_candidates: bool = False + self, + window_size: int = 20, + step_size: int = 10, + shuffle_candidates: bool = False, + **kwargs, ): self._window_size = window_size self._step_size = step_size @@ -72,6 +85,9 @@ def reorder( rank_start: int, rank_end: int, model: ModelFunction, + shuffle_candidates=False, + logging=False, + populate_exec_summary=False, **kwargs, ) -> List[Result]: rerank_results = [ @@ -114,7 +130,7 @@ def reorder( end_pos = end_pos - self._step_size start_pos = start_pos - self._step_size - return [ + results = [ Result( query=copy.deepcopy(request.query), candidates=self._reorder_by_rank( @@ -126,3 +142,16 @@ def reorder( ) for request, rank in zip(requests, request_ranks) ] + + for result, request in zip(results, requests): + for j in range(len(result.candidates)): + result.candidates[j].score = request.candidates[j].score + + return results + + @staticmethod + def name() -> str: + return "reorder_policy.sliding_window" + + def max_selected_indices(self) -> int: + return self._window_size diff --git a/src/rank_llm/rerank/reranker.py b/src/rank_llm/rerank/reranker.py index edf47ef4..39f0757b 100644 --- a/src/rank_llm/rerank/reranker.py +++ b/src/rank_llm/rerank/reranker.py @@ -2,13 +2,9 @@ from typing import Any, List, Optional, Tuple from rank_llm.data import DataWriter, Request, Result -from rank_llm.rerank import ( - PromptMode, - RankLLM, - get_azure_openai_args, - get_openai_api_key, -) +from rank_llm.rerank import PromptMode, get_azure_openai_args, get_openai_api_key from rank_llm.rerank.listwise import RankListwiseOSLLM, SafeOpenai +from rank_llm.rerank.listwise.listwise_rankllm import ListwiseRankLLM from rank_llm.rerank.listwise.rank_fid import RankFiDDistill, RankFiDScore from rank_llm.rerank.rankllm import RankLLM @@ -156,6 +152,7 @@ def write_rerank_results( def get_agent(self) -> RankLLM: return self._agent + @staticmethod def create_agent( model_path: str, default_agent: RankLLM, @@ -254,10 +251,10 @@ def create_agent( print(f"Completed loading {model_path}") elif "lit5-distill" in model_path.lower(): keys_and_defaults = [ + ("reorder_policy", "reorder_policy.sliding_window"), ("context_size", 150), ("prompt_mode", PromptMode.LiT5), ("num_few_shot_examples", 0), - ("window_size", 20), ("precision", "bfloat16"), ("device", "cuda"), # reuse this parameter, but its not for "vllm", but only for "batched" @@ -265,21 +262,23 @@ def create_agent( ] ( + reorder_policy, context_size, prompt_mode, num_few_shot_examples, - window_size, precision, device, vllm_batched, ) = extract_kwargs(keys_and_defaults, **kwargs) agent = RankFiDDistill( + reorder_policy=ListwiseRankLLM.get_reorder_policy( + reorder_policy, **kwargs + ), model=model_path, context_size=context_size, prompt_mode=prompt_mode, 
num_few_shot_examples=num_few_shot_examples, - window_size=window_size, precision=precision, device=device, batched=vllm_batched, From 95f065a929c58e33e35ca17fe9dfd8f2502544e2 Mon Sep 17 00:00:00 2001 From: Yidi Chen Date: Sat, 17 Aug 2024 23:54:25 -0400 Subject: [PATCH 07/30] Transition rank listwise os llm to using reorder policy --- .../rerank/listwise/rank_listwise_os_llm.py | 58 +++++-------------- 1 file changed, 16 insertions(+), 42 deletions(-) diff --git a/src/rank_llm/rerank/listwise/rank_listwise_os_llm.py b/src/rank_llm/rerank/listwise/rank_listwise_os_llm.py index eb6323b9..0cf32875 100644 --- a/src/rank_llm/rerank/listwise/rank_listwise_os_llm.py +++ b/src/rank_llm/rerank/listwise/rank_listwise_os_llm.py @@ -15,6 +15,7 @@ from rank_llm.rerank import PromptMode from .listwise_rankllm import ListwiseRankLLM +from .reorder.reorder_policy import ReorderPolicy try: from vllm import LLM, SamplingParams @@ -28,6 +29,7 @@ class RankListwiseOSLLM(ListwiseRankLLM): def __init__( self, + reorder_policy: ReorderPolicy, model: str, name: str, context_size: int = 4096, @@ -71,7 +73,11 @@ def __init__( TODO: Make repetition_penalty configurable """ super().__init__( - model, context_size, prompt_mode, num_few_shot_examples, window_size + reorder_policy=reorder_policy, + model=model, + context_size=context_size, + prompt_mode=prompt_mode, + num_few_shot_examples=num_few_shot_examples, ) self._device = device self._vllm_batched = vllm_batched @@ -113,47 +119,15 @@ def rerank_batch( logging: bool = False, **kwargs: Any, ) -> List[Result]: - top_k_retrieve: int = kwargs.get("top_k_retrieve", 50) - window_size: int = kwargs.get("window_size", 20) - window_size = min(window_size, top_k_retrieve) - step: int = kwargs.get("step", 10) - populate_exec_summary: bool = kwargs.get("populate_exec_summary", False) - - if self._vllm_batched: - # reranking using vllm - if len(set([len(req.candidates) for req in requests])) != 1: - raise ValueError( - "Batched requests must have the same number of candidates" - ) - - return self.sliding_windows_batched( - requests, - rank_start=max(rank_start, 0), - rank_end=min( - rank_end, len(requests[0].candidates) - ), # TODO: Fails arbitrary hit sizes - window_size=window_size, - step=step, - shuffle_candidates=shuffle_candidates, - logging=logging, - populate_exec_summary=populate_exec_summary, - ) - else: - # Normal operation mode - results = [] - for request in tqdm(requests): - result = self.sliding_windows( - request, - rank_start=max(rank_start, 0), - rank_end=min(rank_end, len(request.candidates)), - window_size=window_size, - step=step, - shuffle_candidates=shuffle_candidates, - logging=logging, - populate_exec_summary=populate_exec_summary, - ) - results.append(result) - return results + return super().rerank_batch( + requests=requests, + rank_start=rank_start, + rank_end=rank_end, + shuffle_candidates=shuffle_candidates, + logging=logging, + batched=self._batched, + **kwargs, + ) def run_llm_batched( self, From 438c8bd204ab082ec7063194c092d6c155721d72 Mon Sep 17 00:00:00 2001 From: Yidi Chen Date: Sun, 18 Aug 2024 00:39:52 -0400 Subject: [PATCH 08/30] Added reorder policy for rank listwise os and fid score --- src/rank_llm/rerank/listwise/rank_fid.py | 71 +++---------------- .../rerank/listwise/rank_listwise_os_llm.py | 1 - src/rank_llm/rerank/reranker.py | 16 +++-- 3 files changed, 21 insertions(+), 67 deletions(-) diff --git a/src/rank_llm/rerank/listwise/rank_fid.py b/src/rank_llm/rerank/listwise/rank_fid.py index bf5c102a..a8fa7855 100644 --- 
a/src/rank_llm/rerank/listwise/rank_fid.py +++ b/src/rank_llm/rerank/listwise/rank_fid.py @@ -1,7 +1,6 @@ from typing import Any, Dict, List, Optional, Tuple, Union import torch -from tqdm import tqdm from transformers import T5Tokenizer from rank_llm.data import Request, Result @@ -252,30 +251,27 @@ def _to_precision(self, precision: str) -> None: def __init__( self, + reorder_policy: ReorderPolicy, model: str, context_size: int = 150, prompt_mode: PromptMode = PromptMode.LiT5, # Placeholder for actual mode num_few_shot_examples: int = 0, - window_size: int = 20, - step_size: int = 10, precision: str = "bfloat16", device: str = "cuda", batched: bool = False, ) -> None: super().__init__( + reorder_policy=reorder_policy, model=model, context_size=context_size, prompt_mode=prompt_mode, num_few_shot_examples=num_few_shot_examples, - window_size=window_size, ) self._precision = precision self._tokenizer = T5Tokenizer.from_pretrained(model) self._llm = FiDCrossAttentionScore.from_pretrained(model).to(device).eval() self._device = device - self._window_size = window_size - self._stride = step_size self._batched = batched @@ -366,60 +362,15 @@ def rerank_batch( logging: bool = False, **kwargs: Any, ) -> List[Result]: - top_k_retrieve: int = kwargs.get("top_k_retrieve", 100) - - window_size: int = kwargs.get("window_size", self._window_size) - window_size = min(window_size, top_k_retrieve) - step: int = kwargs.get("step_size", self._stride) - - populate_exec_summary: bool = kwargs.get("populate_exec_summary", False) - - batch_size = kwargs.get("batch_size", 1) - - if self._batched: - # reranking using vllm - if len(set([len(req.candidates) for req in requests])) != 1: - raise ValueError( - "Batched requests must have the same number of candidates" - ) - - result = [] - - with tqdm(range(0, len(requests))) as bar: - for i in range(0, len(requests), batch_size): - batch = requests[i : min(i + batch_size, len(requests))] - batch_result = self.sliding_windows_batched( - batch, - rank_start=max(rank_start, 0), - rank_end=min( - rank_end, len(requests[0].candidates) - ), # TODO: Fails arbitrary hit sizes - window_size=window_size, - step=step, - shuffle_candidates=shuffle_candidates, - logging=logging, - populate_exec_summary=populate_exec_summary, - ) - result.extend(batch_result) - bar.update(len(batch)) - - return result - else: - # Normal operation mode - results = [] - for request in tqdm(requests): - result = self.sliding_windows( - request, - rank_start=max(rank_start, 0), - rank_end=min(rank_end, len(request.candidates)), - window_size=window_size, - step=step, - shuffle_candidates=shuffle_candidates, - logging=logging, - populate_exec_summary=populate_exec_summary, - ) - results.append(result) - return results + return super().rerank_batch( + requests=requests, + rank_start=rank_start, + rank_end=rank_end, + shuffle_candidates=shuffle_candidates, + logging=logging, + batched=self._batched, + **kwargs, + ) def run_llm_batched( self, prompts: List[List[Dict[str, str]]], **kwargs diff --git a/src/rank_llm/rerank/listwise/rank_listwise_os_llm.py b/src/rank_llm/rerank/listwise/rank_listwise_os_llm.py index 0cf32875..0279e4bb 100644 --- a/src/rank_llm/rerank/listwise/rank_listwise_os_llm.py +++ b/src/rank_llm/rerank/listwise/rank_listwise_os_llm.py @@ -38,7 +38,6 @@ def __init__( device: str = "cuda", num_gpus: int = 1, variable_passages: bool = False, - window_size: int = 20, system_message: str = None, vllm_batched: bool = False, ) -> None: diff --git a/src/rank_llm/rerank/reranker.py 
b/src/rank_llm/rerank/reranker.py index 39f0757b..10ef576d 100644 --- a/src/rank_llm/rerank/reranker.py +++ b/src/rank_llm/rerank/reranker.py @@ -171,6 +171,13 @@ def create_agent( """ use_azure_openai: bool = kwargs.get("use_azure_openai", False) + keys_reorder_policy = [("reorder_policy", "reorder_policy.sliding_window")] + [reorder_policy_name] = extract_kwargs(keys_reorder_policy, **kwargs) + + reorder_policy = ListwiseRankLLM.get_reorder_policy( + reorder_policy_name, **kwargs + ) + if interactive and default_agent is not None: # Default rerank agent agent = default_agent @@ -233,6 +240,7 @@ def create_agent( ] = extract_kwargs(keys_and_defaults, **kwargs) agent = RankListwiseOSLLM( + reorder_policy=reorder_policy, model=model_full_paths[model_path] if model_path in model_full_paths else model_path, @@ -243,7 +251,6 @@ def create_agent( device=device, num_gpus=num_gpus, variable_passages=variable_passages, - window_size=window_size, system_message=system_message, vllm_batched=vllm_batched, ) @@ -262,7 +269,6 @@ def create_agent( ] ( - reorder_policy, context_size, prompt_mode, num_few_shot_examples, @@ -272,9 +278,7 @@ def create_agent( ) = extract_kwargs(keys_and_defaults, **kwargs) agent = RankFiDDistill( - reorder_policy=ListwiseRankLLM.get_reorder_policy( - reorder_policy, **kwargs - ), + reorder_policy=reorder_policy, model=model_path, context_size=context_size, prompt_mode=prompt_mode, @@ -307,11 +311,11 @@ def create_agent( ) = extract_kwargs(keys_and_defaults, **kwargs) agent = RankFiDScore( + reorder_policy=reorder_policy, model=model_path, context_size=context_size, prompt_mode=prompt_mode, num_few_shot_examples=num_few_shot_examples, - window_size=window_size, precision=precision, device=device, batched=vllm_batched, From 2aa8644e8a8253d0390a254316e1caf3d2329d01 Mon Sep 17 00:00:00 2001 From: Yidi Chen Date: Sun, 18 Aug 2024 02:14:35 -0400 Subject: [PATCH 09/30] Revised bug in OS LLM, add Rank GPT, deprecated old functions --- .../rerank/listwise/listwise_rankllm.py | 5 +++ src/rank_llm/rerank/listwise/rank_gpt.py | 37 ++++++++----------- .../rerank/listwise/rank_listwise_os_llm.py | 2 +- 3 files changed, 22 insertions(+), 22 deletions(-) diff --git a/src/rank_llm/rerank/listwise/listwise_rankllm.py b/src/rank_llm/rerank/listwise/listwise_rankllm.py index d69d0965..5db4d7b9 100644 --- a/src/rank_llm/rerank/listwise/listwise_rankllm.py +++ b/src/rank_llm/rerank/listwise/listwise_rankllm.py @@ -5,6 +5,7 @@ from abc import ABC from datetime import datetime from typing import Any, Dict, List, Tuple, Union +from warnings import deprecated from ftfy import fix_text from tqdm import tqdm @@ -132,6 +133,7 @@ def max_tokens(self) -> int: """ return self._context_size + @deprecated("old sliding window pipeline is deprecated. please use reorder policy") def permutation_pipeline_batched( self, results: List[Result], @@ -184,6 +186,7 @@ def permutation_pipeline_batched( return results + @deprecated("old sliding window pipeline is deprecated. please use reorder policy") def permutation_pipeline( self, result: Result, @@ -244,6 +247,7 @@ def shuffle_and_rescore( cand["score"] = 1.0 / (i + 1) cand["rank"] = i + 1 + @deprecated("old sliding window pipeline is deprecated. please use reorder policy") def sliding_windows_batched( self, requests: List[Request], @@ -294,6 +298,7 @@ def sliding_windows_batched( start_pos = start_pos - step return rerank_results + @deprecated("old sliding window pipeline is deprecated. 
please use reorder policy") def sliding_windows( self, request: Request, diff --git a/src/rank_llm/rerank/listwise/rank_gpt.py b/src/rank_llm/rerank/listwise/rank_gpt.py index 5c8cd3e4..491ed82e 100644 --- a/src/rank_llm/rerank/listwise/rank_gpt.py +++ b/src/rank_llm/rerank/listwise/rank_gpt.py @@ -4,12 +4,12 @@ import openai import tiktoken -from tqdm import tqdm from rank_llm.data import Request, Result from rank_llm.rerank import PromptMode from .listwise_rankllm import ListwiseRankLLM +from .reorder.reorder_policy import ReorderPolicy class CompletionMode(Enum): @@ -21,11 +21,11 @@ class CompletionMode(Enum): class SafeOpenai(ListwiseRankLLM): def __init__( self, + reorder_policy: ReorderPolicy, model: str, context_size: int, prompt_mode: PromptMode = PromptMode.RANK_GPT, num_few_shot_examples: int = 0, - window_size: int = 20, keys=None, key_start_id=None, proxy=None, @@ -61,7 +61,11 @@ def __init__( - Azure AI integration is depends on the presence of `api_type`, `api_base`, and `api_version`. """ super().__init__( - model, context_size, prompt_mode, num_few_shot_examples, window_size + reorder_policy=reorder_policy, + model=model, + context_size=context_size, + prompt_mode=prompt_mode, + num_few_shot_examples=num_few_shot_examples, ) if isinstance(keys, str): keys = [keys] @@ -100,24 +104,15 @@ def rerank_batch( logging: bool = False, **kwargs: Any, ) -> List[Result]: - window_size: int = kwargs.get("window_size", 20) - step: int = kwargs.get("step", 10) - populate_exec_summary: bool = kwargs.get("populate_exec_summary", False) - - results = [] - for request in tqdm(requests): - result = self.sliding_windows( - request, - rank_start=max(rank_start, 0), - rank_end=min(rank_end, len(request.candidates)), - window_size=window_size, - step=step, - shuffle_candidates=shuffle_candidates, - logging=logging, - populate_exec_summary=populate_exec_summary, - ) - results.append(result) - return results + return super().rerank_batch( + requests=requests, + rank_start=rank_start, + rank_end=rank_end, + shuffle_candidates=shuffle_candidates, + logging=logging, + batched=False, # You never batch in RankGPT + **kwargs, + ) def _call_completion( self, diff --git a/src/rank_llm/rerank/listwise/rank_listwise_os_llm.py b/src/rank_llm/rerank/listwise/rank_listwise_os_llm.py index 0279e4bb..d5e87ceb 100644 --- a/src/rank_llm/rerank/listwise/rank_listwise_os_llm.py +++ b/src/rank_llm/rerank/listwise/rank_listwise_os_llm.py @@ -124,7 +124,7 @@ def rerank_batch( rank_end=rank_end, shuffle_candidates=shuffle_candidates, logging=logging, - batched=self._batched, + batched=self._vllm_batched, **kwargs, ) From 7a19b87f184a27f7a4e531eb050ef095b6c237e7 Mon Sep 17 00:00:00 2001 From: Yidi Chen Date: Sat, 24 Aug 2024 21:08:25 -0400 Subject: [PATCH 10/30] Finish the tournament sort node --- .../rerank/listwise/listwise_rankllm.py | 9 +- .../reorder/tournament_sort_rerank_policy.py | 221 ++++++++++++++++++ 2 files changed, 225 insertions(+), 5 deletions(-) create mode 100644 src/rank_llm/rerank/listwise/reorder/tournament_sort_rerank_policy.py diff --git a/src/rank_llm/rerank/listwise/listwise_rankllm.py b/src/rank_llm/rerank/listwise/listwise_rankllm.py index 5db4d7b9..ff1f0163 100644 --- a/src/rank_llm/rerank/listwise/listwise_rankllm.py +++ b/src/rank_llm/rerank/listwise/listwise_rankllm.py @@ -5,7 +5,6 @@ from abc import ABC from datetime import datetime from typing import Any, Dict, List, Tuple, Union -from warnings import deprecated from ftfy import fix_text from tqdm import tqdm @@ -133,7 +132,7 @@ def 
max_tokens(self) -> int: """ return self._context_size - @deprecated("old sliding window pipeline is deprecated. please use reorder policy") + # @deprecated("old sliding window pipeline is deprecated. please use reorder policy") def permutation_pipeline_batched( self, results: List[Result], @@ -186,7 +185,7 @@ def permutation_pipeline_batched( return results - @deprecated("old sliding window pipeline is deprecated. please use reorder policy") + # @deprecated("old sliding window pipeline is deprecated. please use reorder policy") def permutation_pipeline( self, result: Result, @@ -247,7 +246,7 @@ def shuffle_and_rescore( cand["score"] = 1.0 / (i + 1) cand["rank"] = i + 1 - @deprecated("old sliding window pipeline is deprecated. please use reorder policy") + # @deprecated("old sliding window pipeline is deprecated. please use reorder policy") def sliding_windows_batched( self, requests: List[Request], @@ -298,7 +297,7 @@ def sliding_windows_batched( start_pos = start_pos - step return rerank_results - @deprecated("old sliding window pipeline is deprecated. please use reorder policy") + # @deprecated("old sliding window pipeline is deprecated. please use reorder policy") def sliding_windows( self, request: Request, diff --git a/src/rank_llm/rerank/listwise/reorder/tournament_sort_rerank_policy.py b/src/rank_llm/rerank/listwise/reorder/tournament_sort_rerank_policy.py new file mode 100644 index 00000000..d33f8d58 --- /dev/null +++ b/src/rank_llm/rerank/listwise/reorder/tournament_sort_rerank_policy.py @@ -0,0 +1,221 @@ +from dataclasses import dataclass +from typing import Dict, List, Tuple, Union + +from rank_llm.data import Result + +from .reorder_policy import ModelFunction, ReorderPolicy + + +@dataclass +class ResortRequest: + indices: List[int] + result: List[int] + + +class TournamentSortNode: + @staticmethod + def build( + inds: List[int], window_size: int, top_k: int + ) -> Tuple[ + "TournamentSortNode", + List["TournamentSortNode"], + Dict[int, "TournamentSortNode"], + ]: + assert window_size % top_k == 0 + children_size = window_size // top_k + + cs: List["TournamentSortNode"] = [ + TournamentSortNode(top_k=top_k, index=x) for x in inds + ] + + base_nodes = {idx: c for idx, c in zip(inds, cs)} + all_cs: List["TournamentSortNode"] = [] + all_cs.extend(cs) + + while len(cs) > 1: + nxt = [] + for c in range(0, len(cs), children_size): + children = cs[c : min(len(cs), c + children_size)] + if len(children) == 1: + nxt.append(children[0]) + else: + nxt.append(TournamentSortNode(top_k=top_k, children=children)) + all_cs.append(nxt[-1]) + + cs = nxt + + return cs[0], all_cs, base_nodes + + def __init__( + self, + top_k: int, + *, + children: Union[List["TournamentSortNode"]] = None, + index: int = None, + ): + super().__init__() + + self.parent: "TournamentSortNode" = None + + self._top_k = top_k + + if children is not None: + for child in children: + child.parent = self + + self._n = len(children) + self._children = children + self._top: List[int] = None + self._tmp: List[int] = None + else: + self._n = -1 + self._index = index + self._top: List[int] = [index] + self._tmp: List[int] = None + + def reset(self): + if self._n == -1: + return + self._top = None + + def invalidate(self): + if self._n != -1: + return + + self._top = [] + + def get_resort_param(self) -> Union[List[int], None]: + if self._n == -1 or self._top is not None: + return None + self._tmp = [x for child in self._children for x in child.top()] + return [ind for ind in self._tmp] + + def resort(self, perm: List[int]): + 
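+        # `perm` is the model's ranking over the candidates staged in `self._tmp`
+        # (the concatenated tops of this node's children); keep the first distinct
+        # winners as this node's new top list. Hypothetical values: with
+        # self._tmp == [7, 3, 9] and perm == [2, 0, 1], tops becomes [9, 7, 3],
+        # and top() later truncates it to this node's top_k.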
assert self._tmp is not None and self._top is None + + tops = [] + for i in perm: + if len(tops) > self._top_k: + break + ind = self._tmp[i] + if ind not in tops: + tops.append(ind) + + self._top = tops + + return + + def top(self) -> List[int]: + assert self._top is not None + return self._top[: min(len(self._top), self._top_k)] + + def __str__(self): + if self._n == -1: + return f"[{self._index}]" + else: + return f"({' '.join([str(x) for x in self._children])})" + + +class TournamentSorter: + def _get_random_indices( + self, expect_size: int, ind_choices: List[int] + ) -> List[int]: + choices = set(ind_choices) + result = [] + for j in reversed(range(self._n_passage)): + if len(result) + len(ind_choices) >= expect_size: + break + if j not in choices: + result.append(j) + + for j in reversed(range(self._n_passage)): + if len(result) + len(ind_choices) >= expect_size: + break + result.append(j) + return result + + def _pad_size(self, inds: List[int]) -> List[int]: + if len(inds) >= self._window_size: + return inds + else: + fitters = self._get_random_indices(self._window_size, inds) + return inds + fitters + + def _unpad_perm(self, inds: List[int], padded: List[int], perm: List[int]): + return [x for x in perm if x < len(inds)] + + def __init__(self, indices: List[int], window_size: int, r: int): + super().__init__() + self._window_size = window_size + self._r = r + + self._n_passage = len(indices) + + self._tr, self._all_node, self._idx_to_node = TournamentSortNode.build( + list(range(self._n_passage)), window_size=window_size, top_k=r + ) + + def _pop(self, x: int) -> List[TournamentSortNode]: + on: TournamentSortNode = self._idx_to_node[x] + lst = [] + while on is not None: + lst.append(on) + on.invalidate() + on.reset() + on = on.parent + return lst + + def perform(self, top_k: int): + result = [] + + # firstly, simple sort + for nd in self._all_node: + resort_param = nd.get_resort_param() + if resort_param is not None: + padded = self._pad_size(resort_param) + request = ResortRequest(padded, []) + yield request + cleaned_result = self._unpad_perm(resort_param, padded, request.result) + nd.resort(cleaned_result) + + while len(result) < top_k: + tpv = self._tr.top()[0] + result.append(tpv) + nodes = self._pop(tpv) + for node in nodes: + resort_param = node.get_resort_param() + if resort_param is not None: + padded = self._pad_size(resort_param) + request = ResortRequest(padded, []) + yield request + assert len(request.result) > 0 + cleaned_result = self._unpad_perm( + resort_param, padded, request.result + ) + node.resort(cleaned_result) + + return result + + +class TournamentSortReorderPolicy(ReorderPolicy): + def __init__(self, top_k: int, window_size: int): + super().__init__() + self._top_k = top_k + self._window_size = window_size + + def reorder( + self, + requests: List[Result], + rank_start: int, + rank_end: int, + model: ModelFunction, + **kwargs, + ) -> list[Result]: + pass + + @staticmethod + def name() -> str: + return "reorder_policy.tournament_sort" + + def max_selected_indices(self) -> int: + return self._window_size From c8734f0f41a9ffa5eb9d195bed34b5406a7416a8 Mon Sep 17 00:00:00 2001 From: Yidi Chen Date: Mon, 26 Aug 2024 18:23:52 -0400 Subject: [PATCH 11/30] Finish tournament sort --- .../rerank/listwise/listwise_rankllm.py | 5 +- .../reorder/tournament_sort_rerank_policy.py | 93 ++++++++++++++++++- src/rank_llm/rerank/reranker.py | 5 +- src/rank_llm/scripts/run_rank_llm.py | 8 ++ 4 files changed, 101 insertions(+), 10 deletions(-) diff --git 
a/src/rank_llm/rerank/listwise/listwise_rankllm.py b/src/rank_llm/rerank/listwise/listwise_rankllm.py index ff1f0163..fc7f3517 100644 --- a/src/rank_llm/rerank/listwise/listwise_rankllm.py +++ b/src/rank_llm/rerank/listwise/listwise_rankllm.py @@ -16,10 +16,13 @@ ReorderPolicy, SlidingWindowReorderPolicy, ) +from rank_llm.rerank.listwise.reorder.tournament_sort_rerank_policy import ( + TournamentSortReorderPolicy, +) logger = logging.getLogger(__name__) -SUPPORT_REORDER_POLICIES = [SlidingWindowReorderPolicy] +SUPPORT_REORDER_POLICIES = [SlidingWindowReorderPolicy, TournamentSortReorderPolicy] class ListwiseRankLLM(RankLLM, ABC): diff --git a/src/rank_llm/rerank/listwise/reorder/tournament_sort_rerank_policy.py b/src/rank_llm/rerank/listwise/reorder/tournament_sort_rerank_policy.py index d33f8d58..052b8b59 100644 --- a/src/rank_llm/rerank/listwise/reorder/tournament_sort_rerank_policy.py +++ b/src/rank_llm/rerank/listwise/reorder/tournament_sort_rerank_policy.py @@ -1,5 +1,6 @@ +import copy from dataclasses import dataclass -from typing import Dict, List, Tuple, Union +from typing import Callable, Dict, List, Tuple, Union from rank_llm.data import Result @@ -144,6 +145,14 @@ def _pad_size(self, inds: List[int]) -> List[int]: def _unpad_perm(self, inds: List[int], padded: List[int], perm: List[int]): return [x for x in perm if x < len(inds)] + def _fill_up(self, result: List[int]) -> List[int]: + result_set = set(result) + filled_up_result = [x for x in result] + for idx in self._indices: + if idx not in result_set: + filled_up_result.append(idx) + return filled_up_result + def __init__(self, indices: List[int], window_size: int, r: int): super().__init__() self._window_size = window_size @@ -151,8 +160,10 @@ def __init__(self, indices: List[int], window_size: int, r: int): self._n_passage = len(indices) + self._indices = indices + self._tr, self._all_node, self._idx_to_node = TournamentSortNode.build( - list(range(self._n_passage)), window_size=window_size, top_k=r + indices, window_size=window_size, top_k=r ) def _pop(self, x: int) -> List[TournamentSortNode]: @@ -194,11 +205,51 @@ def perform(self, top_k: int): ) node.resort(cleaned_result) - return result + return self._fill_up(result) + + +def multiple_sort( + requests: List[Result], + indices_batch: List[List[int]], + runner: Callable[[List[Tuple[Result, List[int]]]], List[List[int]]], + window_size: int, + r: int, + top_k: int, +): + batch_size = len(requests) + tournament_sorters: List[TournamentSorter] = [ + TournamentSorter(indices, window_size, r) for indices in indices_batch + ] + progress = [ + tournament_sorter.perform(top_k) for tournament_sorter in tournament_sorters + ] + result = [None for _ in range(batch_size)] + left_not_sorted = set(range(batch_size)) + + while len(left_not_sorted) > 0: + perm_request = [] + + finish_requests = [] + for idx in left_not_sorted: + try: + req = next(progress[idx]) + perm_request.append((idx, req)) + except StopIteration as e: + result[idx] = e.value + finish_requests.append(idx) + for idx in finish_requests: + left_not_sorted.remove(idx) + + outputs = runner([(requests[idx], req.indices) for idx, req in perm_request]) + + for (idx, req), output in zip(perm_request, outputs): + req.result = output + + return result class TournamentSortReorderPolicy(ReorderPolicy): - def __init__(self, top_k: int, window_size: int): + def __init__(self, window_size: int, top_k: int = 10, **kwargs): super().__init__() self._top_k = top_k self._window_size = window_size @@ -211,7 +262,39 @@ def reorder( 
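+        # `model` bundles the create_prompt/execute callbacks produced by
+        # ListwiseRankLLM._get_model_function, so the policy can score any
+        # batch of candidate-index lists without touching the LLM directly.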
model: ModelFunction, **kwargs, ) -> list[Result]: - pass + runner: Callable[ + [List[Tuple[Result, List[int]]]], List[List[int]] + ] = lambda reqs: model.execute( + model.create_prompt(reqs), [ind for req, ind in reqs] + ) + + request_ranks = multiple_sort( + requests, + [list(range(rank_start, rank_end)) for _ in range(len(requests))], + runner=runner, + window_size=self._window_size, + top_k=self._top_k, + r=1, + ) + + results = [ + Result( + query=copy.deepcopy(request.query), + candidates=self._reorder_by_rank( + copy.deepcopy(request.candidates), + [*range(len(request.candidates))], + rank, + ), + ranking_exec_summary=[], + ) + for request, rank in zip(requests, request_ranks) + ] + + for result, request in zip(results, requests): + for j in range(len(result.candidates)): + result.candidates[j].score = request.candidates[j].score + + return results @staticmethod def name() -> str: diff --git a/src/rank_llm/rerank/reranker.py b/src/rank_llm/rerank/reranker.py index 10ef576d..284a95b4 100644 --- a/src/rank_llm/rerank/reranker.py +++ b/src/rank_llm/rerank/reranker.py @@ -174,9 +174,7 @@ def create_agent( keys_reorder_policy = [("reorder_policy", "reorder_policy.sliding_window")] [reorder_policy_name] = extract_kwargs(keys_reorder_policy, **kwargs) - reorder_policy = ListwiseRankLLM.get_reorder_policy( - reorder_policy_name, **kwargs - ) + reorder_policy = ListwiseRankLLM.get_reorder_policy(**kwargs) if interactive and default_agent is not None: # Default rerank agent @@ -258,7 +256,6 @@ def create_agent( print(f"Completed loading {model_path}") elif "lit5-distill" in model_path.lower(): keys_and_defaults = [ - ("reorder_policy", "reorder_policy.sliding_window"), ("context_size", 150), ("prompt_mode", PromptMode.LiT5), ("num_few_shot_examples", 0), diff --git a/src/rank_llm/scripts/run_rank_llm.py b/src/rank_llm/scripts/run_rank_llm.py index 29326add..8c4d251b 100644 --- a/src/rank_llm/scripts/run_rank_llm.py +++ b/src/rank_llm/scripts/run_rank_llm.py @@ -37,6 +37,7 @@ def main(args): system_message = args.system_message vllm_batched = args.vllm_batched batch_size = args.batch_size + reorder_policy = args.reorder_policy _ = retrieve_and_rerank( model_path=model_path, @@ -60,6 +61,7 @@ def main(args): system_message=system_message, vllm_batched=vllm_batched, batch_size=batch_size, + reorder_policy=reorder_policy, ) @@ -174,5 +176,11 @@ def main(args): help="batch size of the non vllm-determined-batch-size models. -1 means not allowed be in batch", type=int, ) + parser.add_argument( + "--reorder_policy", + default="reorder_policy.sliding_window", + help="policy in reordering. 
defaults to sliding window", type=str, ) args = parser.parse_args() main(args) From 7981c06ba2912e0449a89e50c7b8b0d62a658ab7 Mon Sep 17 00:00:00 2001 From: Yidi Chen Date: Fri, 30 Aug 2024 16:32:44 -0400 Subject: [PATCH 12/30] Finish reorganizing parameters, move window_size to ListwiseRankLLM --- .../rerank/listwise/listwise_rankllm.py | 27 ++++++++++++++++--- src/rank_llm/rerank/listwise/rank_fid.py | 6 ++++- src/rank_llm/rerank/listwise/rank_gpt.py | 2 ++ .../rerank/listwise/rank_listwise_os_llm.py | 2 ++ .../rerank/listwise/reorder/reorder_policy.py | 24 +++++++---------- ...y.py => tournament_sort_reorder_policy.py} | 12 ++++----- src/rank_llm/rerank/reranker.py | 6 +++-- src/rank_llm/scripts/run_rank_llm.py | 12 ++------- 8 files changed, 53 insertions(+), 38 deletions(-) rename src/rank_llm/rerank/listwise/reorder/{tournament_sort_rerank_policy.py => tournament_sort_reorder_policy.py} (96%) diff --git a/src/rank_llm/rerank/listwise/listwise_rankllm.py b/src/rank_llm/rerank/listwise/listwise_rankllm.py index fc7f3517..152a4f04 100644 --- a/src/rank_llm/rerank/listwise/listwise_rankllm.py +++ b/src/rank_llm/rerank/listwise/listwise_rankllm.py @@ -1,4 +1,5 @@ import copy +import json import logging import random import re @@ -16,7 +17,7 @@ ReorderPolicy, SlidingWindowReorderPolicy, ) -from rank_llm.rerank.listwise.reorder.tournament_sort_rerank_policy import ( +from rank_llm.rerank.listwise.reorder.tournament_sort_reorder_policy import ( TournamentSortReorderPolicy, ) logger = logging.getLogger(__name__) @@ -43,6 +44,7 @@ def __init__( reorder_policy: ReorderPolicy, model: str, context_size: int, + window_size: int, prompt_mode: PromptMode, num_few_shot_examples: int, ) -> None: @@ -50,6 +52,7 @@ def __init__( self._num_few_shot_examples = num_few_shot_examples self.reorder_policy = reorder_policy + self._window_size = window_size def rerank_batch( self, @@ -566,11 +569,27 @@ def execute( for x, selected_indices in zip(batch, selected_indices_batch) ] - return ModelFunction(create_prompt=create_prompt, execute=execute) + return ModelFunction( + create_prompt=create_prompt, execute=execute, window_size=self._window_size + ) @staticmethod def get_reorder_policy(reorder_policy: str, **kwargs): for policy in SUPPORT_REORDER_POLICIES: - if policy.name() == reorder_policy: - return policy(**kwargs) + if reorder_policy.startswith(policy.name()): + reorder_params = reorder_policy[len(policy.name()) :] + if len(reorder_params) <= 1: + return policy() + else: + assert reorder_params[0] == ":" and reorder_params[1] == "{" + reorder_params = reorder_params[1:] + try: + reorder_param_dict = json.loads(reorder_params) + if not isinstance(reorder_param_dict, dict): + raise Exception() + except Exception as e: + print(e) + raise Exception(f"Cannot load reorder policy {reorder_policy}") + return policy(**reorder_param_dict) + raise Exception(f"Cannot find reorder policy {reorder_policy}") diff --git a/src/rank_llm/rerank/listwise/rank_fid.py b/src/rank_llm/rerank/listwise/rank_fid.py index a8fa7855..b28bdda2 100644 --- a/src/rank_llm/rerank/listwise/rank_fid.py +++ b/src/rank_llm/rerank/listwise/rank_fid.py @@ -33,6 +33,7 @@ def __init__( reorder_policy: ReorderPolicy, model: str, context_size: int = 150, + window_size: int = 20, prompt_mode: PromptMode = PromptMode.LiT5, # Placeholder for actual mode num_few_shot_examples: int = 0, precision: str = "bfloat16", device: str = "cuda", batched: bool = False, ) -> None: @@ -46,6 +47,7 @@ def __init__( reorder_policy=reorder_policy, model=model, context_size=context_size, + window_size=window_size, prompt_mode=prompt_mode, 
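+            # window_size now lives on ListwiseRankLLM itself and is surfaced
+            # to reorder policies through ModelFunction.window_size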
num_few_shot_examples=num_few_shot_examples, ) @@ -61,7 +63,7 @@ def __init__( " > ".join( map( lambda x: f"[{x}]", - range(1, reorder_policy.max_selected_indices() + 1), + range(1, self._window_size + 1), ) ) ) @@ -254,6 +256,7 @@ def __init__( reorder_policy: ReorderPolicy, model: str, context_size: int = 150, + window_size: int = 20, prompt_mode: PromptMode = PromptMode.LiT5, # Placeholder for actual mode num_few_shot_examples: int = 0, precision: str = "bfloat16", @@ -264,6 +267,7 @@ def __init__( reorder_policy=reorder_policy, model=model, context_size=context_size, + window_size=window_size, prompt_mode=prompt_mode, num_few_shot_examples=num_few_shot_examples, ) diff --git a/src/rank_llm/rerank/listwise/rank_gpt.py b/src/rank_llm/rerank/listwise/rank_gpt.py index 491ed82e..3636ec0a 100644 --- a/src/rank_llm/rerank/listwise/rank_gpt.py +++ b/src/rank_llm/rerank/listwise/rank_gpt.py @@ -24,6 +24,7 @@ def __init__( reorder_policy: ReorderPolicy, model: str, context_size: int, + window_size: int, prompt_mode: PromptMode = PromptMode.RANK_GPT, num_few_shot_examples: int = 0, keys=None, @@ -64,6 +65,7 @@ def __init__( reorder_policy=reorder_policy, model=model, context_size=context_size, + window_size=window_size, prompt_mode=prompt_mode, num_few_shot_examples=num_few_shot_examples, ) diff --git a/src/rank_llm/rerank/listwise/rank_listwise_os_llm.py b/src/rank_llm/rerank/listwise/rank_listwise_os_llm.py index d5e87ceb..0f3b0a7a 100644 --- a/src/rank_llm/rerank/listwise/rank_listwise_os_llm.py +++ b/src/rank_llm/rerank/listwise/rank_listwise_os_llm.py @@ -33,6 +33,7 @@ def __init__( model: str, name: str, context_size: int = 4096, + window_size: int = 20, prompt_mode: PromptMode = PromptMode.RANK_GPT, num_few_shot_examples: int = 0, device: str = "cuda", @@ -75,6 +76,7 @@ def __init__( reorder_policy=reorder_policy, model=model, context_size=context_size, + window_size=window_size, prompt_mode=prompt_mode, num_few_shot_examples=num_few_shot_examples, ) diff --git a/src/rank_llm/rerank/listwise/reorder/reorder_policy.py b/src/rank_llm/rerank/listwise/reorder/reorder_policy.py index 553e9def..5dd00798 100644 --- a/src/rank_llm/rerank/listwise/reorder/reorder_policy.py +++ b/src/rank_llm/rerank/listwise/reorder/reorder_policy.py @@ -20,6 +20,9 @@ class ModelFunction: [List[Union[str, Dict[str, str]]], List[List[int]]], List[List[int]] ] + # Accepted Window Size + window_size: int + class ReorderPolicy(ABC): @abstractmethod @@ -33,10 +36,6 @@ def reorder( ) -> list[Result]: pass - @abstractmethod - def max_selected_indices(self) -> int: - pass - @staticmethod @abstractmethod def name() -> str: @@ -69,15 +68,13 @@ def _reorder_by_rank(items: List[T], idxes: List[int], rank: List[int]) -> List[ class SlidingWindowReorderPolicy(ReorderPolicy): def __init__( self, - window_size: int = 20, - step_size: int = 10, + step: int = 10, shuffle_candidates: bool = False, **kwargs, ): - self._window_size = window_size - self._step_size = step_size + self._step_size = step - self._shuffle_candidates = shuffle_candidates + self._shuffle_candidates = bool(shuffle_candidates) def reorder( self, @@ -90,6 +87,8 @@ def reorder( populate_exec_summary=False, **kwargs, ) -> List[Result]: + window_size = model.window_size + rerank_results = [ Result( query=copy.deepcopy(request.query), @@ -106,7 +105,7 @@ def reorder( request_ranks = [[*range(len(request.candidates))] for request in requests] end_pos = rank_end - start_pos = rank_end - self._window_size + start_pos = rank_end - window_size # end_pos > rank_start 
ensures that the list is non-empty while allowing last window to be smaller than window_size # start_pos + step != rank_start prevents processing of redundant windows (e.g. 0-20, followed by 0-10) @@ -151,7 +150,4 @@ def reorder( @staticmethod def name() -> str: - return "reorder_policy.sliding_window" - - def max_selected_indices(self) -> int: - return self._window_size + return "sliding_window" diff --git a/src/rank_llm/rerank/listwise/reorder/tournament_sort_rerank_policy.py b/src/rank_llm/rerank/listwise/reorder/tournament_sort_reorder_policy.py similarity index 96% rename from src/rank_llm/rerank/listwise/reorder/tournament_sort_rerank_policy.py rename to src/rank_llm/rerank/listwise/reorder/tournament_sort_reorder_policy.py index 052b8b59..065a0ff7 100644 --- a/src/rank_llm/rerank/listwise/reorder/tournament_sort_rerank_policy.py +++ b/src/rank_llm/rerank/listwise/reorder/tournament_sort_reorder_policy.py @@ -249,10 +249,9 @@ def multiple_sort( class TournamentSortReorderPolicy(ReorderPolicy): - def __init__(self, window_size: int, top_k: int = 10, **kwargs): + def __init__(self, top_k: int = 10, **kwargs): super().__init__() self._top_k = top_k - self._window_size = window_size def reorder( self, @@ -262,6 +261,8 @@ def reorder( model: ModelFunction, **kwargs, ) -> list[Result]: + window_size = model.window_size + runner: Callable[ [List[Tuple[Result, List[int]]]], List[List[int]] ] = lambda reqs: model.execute( @@ -272,7 +273,7 @@ def reorder( requests, [list(range(rank_start, rank_end)) for _ in range(len(requests))], runner=runner, - window_size=self._window_size, + window_size=window_size, top_k=self._top_k, r=1, ) @@ -298,7 +299,4 @@ def reorder( @staticmethod def name() -> str: - return "reorder_policy.tournament_sort" - - def max_selected_indices(self) -> int: - return self._window_size + return "tournament_sort" diff --git a/src/rank_llm/rerank/reranker.py b/src/rank_llm/rerank/reranker.py index 284a95b4..c0b73040 100644 --- a/src/rank_llm/rerank/reranker.py +++ b/src/rank_llm/rerank/reranker.py @@ -171,10 +171,12 @@ def create_agent( """ use_azure_openai: bool = kwargs.get("use_azure_openai", False) - keys_reorder_policy = [("reorder_policy", "reorder_policy.sliding_window")] + keys_reorder_policy = [("reorder_policy", "sliding_window")] [reorder_policy_name] = extract_kwargs(keys_reorder_policy, **kwargs) - reorder_policy = ListwiseRankLLM.get_reorder_policy(**kwargs) + reorder_policy = ListwiseRankLLM.get_reorder_policy( + reorder_policy=reorder_policy_name + ) if interactive and default_agent is not None: # Default rerank agent diff --git a/src/rank_llm/scripts/run_rank_llm.py b/src/rank_llm/scripts/run_rank_llm.py index 8c4d251b..78cb227c 100644 --- a/src/rank_llm/scripts/run_rank_llm.py +++ b/src/rank_llm/scripts/run_rank_llm.py @@ -32,7 +32,6 @@ def main(args): variable_passages = args.variable_passages retrieval_mode = RetrievalMode.DATASET num_passes = args.num_passes - step_size = args.step_size window_size = args.window_size system_message = args.system_message vllm_batched = args.vllm_batched @@ -57,7 +56,6 @@ def main(args): variable_passages=variable_passages, num_passes=num_passes, window_size=window_size, - step_size=step_size, system_message=system_message, vllm_batched=vllm_batched, batch_size=batch_size, @@ -151,13 +149,7 @@ def main(args): "--window_size", type=int, default=20, - help="window size for the sliding window approach", - ) - parser.add_argument( - "--step_size", - type=int, - default=10, - help="step size for the sliding window approach", + 
help="window size for the LLM", ) parser.add_argument( "--system_message", @@ -178,7 +170,7 @@ def main(args): ) parser.add_argument( "--reorder_policy", - default="reorder_policy.sliding_window", + default="sliding_window", help="policy in reordering. defaultly to be sliding window", type=str, ) From bbbdc31a256b85ccbb60a3c5b9384d90a8639077 Mon Sep 17 00:00:00 2001 From: Yidi Chen Date: Wed, 4 Sep 2024 23:10:06 -0400 Subject: [PATCH 13/30] Added window size back --- src/rank_llm/rerank/reranker.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/rank_llm/rerank/reranker.py b/src/rank_llm/rerank/reranker.py index c0b73040..75702482 100644 --- a/src/rank_llm/rerank/reranker.py +++ b/src/rank_llm/rerank/reranker.py @@ -246,6 +246,7 @@ def create_agent( else model_path, name=model_path, context_size=context_size, + window_size=window_size, prompt_mode=prompt_mode, num_few_shot_examples=num_few_shot_examples, device=device, @@ -259,6 +260,7 @@ def create_agent( elif "lit5-distill" in model_path.lower(): keys_and_defaults = [ ("context_size", 150), + ("window_size", 20), ("prompt_mode", PromptMode.LiT5), ("num_few_shot_examples", 0), ("precision", "bfloat16"), @@ -269,6 +271,7 @@ def create_agent( ( context_size, + window_size, prompt_mode, num_few_shot_examples, precision, @@ -280,6 +283,7 @@ def create_agent( reorder_policy=reorder_policy, model=model_path, context_size=context_size, + window_size=window_size, prompt_mode=prompt_mode, num_few_shot_examples=num_few_shot_examples, precision=precision, @@ -313,6 +317,7 @@ def create_agent( reorder_policy=reorder_policy, model=model_path, context_size=context_size, + window_size=window_size, prompt_mode=prompt_mode, num_few_shot_examples=num_few_shot_examples, precision=precision, From bb0ad2ba78259bd65d019ed0ead39fbb0c539a49 Mon Sep 17 00:00:00 2001 From: Yidi Chen Date: Thu, 5 Sep 2024 21:42:06 +0000 Subject: [PATCH 14/30] Fix Rerankers --- src/rank_llm/rerank/listwise/lit5_reranker.py | 7 +++ .../rerank/listwise/vicuna_reranker.py | 8 +++ .../rerank/listwise/zephyr_reranker.py | 59 +++++++++++-------- 3 files changed, 48 insertions(+), 26 deletions(-) diff --git a/src/rank_llm/rerank/listwise/lit5_reranker.py b/src/rank_llm/rerank/listwise/lit5_reranker.py index 8b755e67..94d5ab5c 100644 --- a/src/rank_llm/rerank/listwise/lit5_reranker.py +++ b/src/rank_llm/rerank/listwise/lit5_reranker.py @@ -1,5 +1,6 @@ from rank_llm.data import Request, Result from rank_llm.rerank.listwise.rank_fid import RankFiDDistill, RankFiDScore +from rank_llm.rerank.listwise.reorder.reorder_policy import ReorderPolicy, SlidingWindowReorderPolicy from rank_llm.rerank.rankllm import PromptMode from rank_llm.rerank.reranker import Reranker @@ -11,12 +12,16 @@ def __init__( context_size: int = 300, prompt_mode: PromptMode = PromptMode.LiT5, window_size: int = 20, + reorder_policy: ReorderPolicy = None ) -> None: + if reorder_policy is None: + reorder_policy = SlidingWindowReorderPolicy() agent = RankFiDDistill( model=model_path, context_size=context_size, prompt_mode=prompt_mode, window_size=window_size, + reorder_policy=reorder_policy ) self._reranker = Reranker(agent) @@ -62,6 +67,7 @@ def rerank( class LiT5ScoreReranker: def __init__( self, + reorder_policy: ReorderPolicy, model_path: str = "castorini/LiT5-Score-base", context_size: int = 300, prompt_mode: PromptMode = PromptMode.LiT5, @@ -69,6 +75,7 @@ def __init__( runfile_path: str = "runs/run.${topics}_${firststage}_${model//\//}", ) -> None: agent = RankFiDScore( + reorder_policy=reorder_policy, 
model=model_path, context_size=context_size, prompt_mode=prompt_mode, diff --git a/src/rank_llm/rerank/listwise/vicuna_reranker.py b/src/rank_llm/rerank/listwise/vicuna_reranker.py index 1bd52830..22f17b96 100644 --- a/src/rank_llm/rerank/listwise/vicuna_reranker.py +++ b/src/rank_llm/rerank/listwise/vicuna_reranker.py @@ -3,6 +3,7 @@ from rank_llm.data import Request, Result from rank_llm.rerank import PromptMode from rank_llm.rerank.listwise import RankListwiseOSLLM +from rank_llm.rerank.listwise.reorder.reorder_policy import ReorderPolicy, SlidingWindowReorderPolicy class VicunaReranker: @@ -16,10 +17,16 @@ def __init__( num_gpus: int = 1, variable_passages: bool = False, window_size: int = 20, + reorder_policy: ReorderPolicy = None, system_message: str = None, ) -> None: + + if reorder_policy is None: + reorder_policy = SlidingWindowReorderPolicy() + self._reranker = RankListwiseOSLLM( model=model_path, + name=model_path, context_size=context_size, prompt_mode=prompt_mode, num_few_shot_examples=num_few_shot_examples, @@ -28,6 +35,7 @@ def __init__( variable_passages=variable_passages, window_size=window_size, system_message=system_message, + reorder_policy=reorder_policy, ) def rerank_batch( diff --git a/src/rank_llm/rerank/listwise/zephyr_reranker.py b/src/rank_llm/rerank/listwise/zephyr_reranker.py index 7210c053..41c6d190 100644 --- a/src/rank_llm/rerank/listwise/zephyr_reranker.py +++ b/src/rank_llm/rerank/listwise/zephyr_reranker.py @@ -3,23 +3,29 @@ from rank_llm.data import Request, Result from rank_llm.rerank import PromptMode from rank_llm.rerank.listwise import RankListwiseOSLLM +from rank_llm.rerank.listwise.reorder.reorder_policy import ReorderPolicy, SlidingWindowReorderPolicy class ZephyrReranker: def __init__( - self, - model_path: str = "castorini/rank_zephyr_7b_v1_full", - context_size: int = 4096, - prompt_mode: PromptMode = PromptMode.RANK_GPT, - num_few_shot_examples: int = 0, - device: str = "cuda", - num_gpus: int = 1, - variable_passages: bool = True, - window_size: int = 20, - system_message: str = "You are RankLLM, an intelligent assistant that can rank passages based on their relevancy to the query", + self, + model_path: str = "castorini/rank_zephyr_7b_v1_full", + context_size: int = 4096, + prompt_mode: PromptMode = PromptMode.RANK_GPT, + num_few_shot_examples: int = 0, + device: str = "cuda", + num_gpus: int = 1, + variable_passages: bool = True, + window_size: int = 20, + reorder_policy: ReorderPolicy = None, + system_message: str = "You are RankLLM, an intelligent assistant that can rank passages based on their relevancy to the query", ) -> None: + if reorder_policy is None: + reorder_policy = SlidingWindowReorderPolicy() + self._reranker = RankListwiseOSLLM( model=model_path, + name=model_path, context_size=context_size, prompt_mode=prompt_mode, num_few_shot_examples=num_few_shot_examples, @@ -28,17 +34,18 @@ def __init__( variable_passages=variable_passages, window_size=window_size, system_message=system_message, + reorder_policy=reorder_policy ) def rerank_batch( - self, - requests: List[Request], - rank_start: int = 0, - rank_end: int = 100, - window_size: int = 20, - step: int = 10, - shuffle_candidates: bool = False, - logging: bool = False, + self, + requests: List[Request], + rank_start: int = 0, + rank_end: int = 100, + window_size: int = 20, + step: int = 10, + shuffle_candidates: bool = False, + logging: bool = False, ) -> List[Result]: """ Reranks a list of requests using the Zephyr model. 
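As of this patch, every listwise reranker takes an optional reorder_policy and falls back to SlidingWindowReorderPolicy when none is given. A minimal usage sketch of that surface (the policy settings are illustrative, and the commented call mirrors the rerank_batch signature above):

from rank_llm.rerank.listwise.zephyr_reranker import ZephyrReranker
from rank_llm.rerank.listwise.reorder.tournament_sort_reorder_policy import (
    TournamentSortReorderPolicy,
)

# Default construction keeps the sliding-window behaviour.
reranker = ZephyrReranker()

# Or swap in tournament sort; the node builder asserts window_size % r == 0.
reranker = ZephyrReranker(
    reorder_policy=TournamentSortReorderPolicy(top_k=10, r=2),
    window_size=20,
)

# `requests` would be a List[Request] from a first-stage retriever:
# results = reranker.rerank_batch(requests, rank_start=0, rank_end=100)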
@@ -69,14 +76,14 @@ def rerank_batch( ) def rerank( - self, - request: Request, - rank_start: int = 0, - rank_end: int = 100, - window_size: int = 20, - step: int = 10, - shuffle_candidates: bool = False, - logging: bool = False, + self, + request: Request, + rank_start: int = 0, + rank_end: int = 100, + window_size: int = 20, + step: int = 10, + shuffle_candidates: bool = False, + logging: bool = False, ) -> Result: """ Reranks a request using the Zephyr model. From 28e3a23f301fe8e685016d080e5e1ca756544226 Mon Sep 17 00:00:00 2001 From: Yidi Chen Date: Thu, 5 Sep 2024 22:07:42 +0000 Subject: [PATCH 15/30] Added r parameter --- .../listwise/reorder/tournament_sort_reorder_policy.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/rank_llm/rerank/listwise/reorder/tournament_sort_reorder_policy.py b/src/rank_llm/rerank/listwise/reorder/tournament_sort_reorder_policy.py index 065a0ff7..6750df62 100644 --- a/src/rank_llm/rerank/listwise/reorder/tournament_sort_reorder_policy.py +++ b/src/rank_llm/rerank/listwise/reorder/tournament_sort_reorder_policy.py @@ -249,9 +249,10 @@ def multiple_sort( class TournamentSortReorderPolicy(ReorderPolicy): - def __init__(self, top_k: int = 10, **kwargs): + def __init__(self, top_k: int = 10, r: int = 1, **kwargs): super().__init__() self._top_k = top_k + self._r = r def reorder( self, @@ -275,7 +276,7 @@ def reorder( runner=runner, window_size=window_size, top_k=self._top_k, - r=1, + r=self._r, ) results = [ From 0e9edcee4d3a1d976e42e7414a0e1f579c66beb4 Mon Sep 17 00:00:00 2001 From: Yidi Chen Date: Fri, 6 Sep 2024 01:52:09 +0000 Subject: [PATCH 16/30] Some bug fix --- .../reorder/tournament_sort_reorder_policy.py | 84 ++++++++++++------- src/rank_llm/rerank/reranker.py | 3 - src/rank_llm/scripts/run_rank_llm.py | 9 -- 3 files changed, 53 insertions(+), 43 deletions(-) diff --git a/src/rank_llm/rerank/listwise/reorder/tournament_sort_reorder_policy.py b/src/rank_llm/rerank/listwise/reorder/tournament_sort_reorder_policy.py index 6750df62..4068c086 100644 --- a/src/rank_llm/rerank/listwise/reorder/tournament_sort_reorder_policy.py +++ b/src/rank_llm/rerank/listwise/reorder/tournament_sort_reorder_policy.py @@ -1,4 +1,5 @@ import copy +from collections import deque from dataclasses import dataclass from typing import Callable, Dict, List, Tuple, Union @@ -16,14 +17,13 @@ class ResortRequest: class TournamentSortNode: @staticmethod def build( - inds: List[int], window_size: int, top_k: int + inds: List[int], window_size: int, top_k: int ) -> Tuple[ "TournamentSortNode", List["TournamentSortNode"], Dict[int, "TournamentSortNode"], ]: assert window_size % top_k == 0 - children_size = window_size // top_k cs: List["TournamentSortNode"] = [ TournamentSortNode(top_k=top_k, index=x) for x in inds @@ -33,26 +33,34 @@ def build( all_cs: List["TournamentSortNode"] = [] all_cs.extend(cs) - while len(cs) > 1: - nxt = [] - for c in range(0, len(cs), children_size): - children = cs[c : min(len(cs), c + children_size)] - if len(children) == 1: - nxt.append(children[0]) - else: - nxt.append(TournamentSortNode(top_k=top_k, children=children)) - all_cs.append(nxt[-1]) + dq = deque(all_cs) - cs = nxt + while len(dq) > 1: - return cs[0], all_cs, base_nodes + cnt = 0 + children = [] + + while (len(dq) != 0) and (cnt + dq[0].estimate_size() <= window_size): + children.append(dq[0]) + cnt += dq[0].estimate_size() + dq.popleft() + + if len(children) == 1: + child = children[0] + dq.append(child) + else: + nd = TournamentSortNode(top_k=top_k, 
children=children) + all_cs.append(nd) + dq.append(nd) + + return dq[0], all_cs, base_nodes def __init__( - self, - top_k: int, - *, - children: Union[List["TournamentSortNode"]] = None, - index: int = None, + self, + top_k: int, + *, + children: Union[List["TournamentSortNode"]] = None, + index: int = None, ): super().__init__() @@ -110,6 +118,12 @@ def top(self) -> List[int]: assert self._top is not None return self._top[: min(len(self._top), self._top_k)] + def estimate_size(self) -> int: + if self._n == -1: + return 1 + else: + return self._top_k + def __str__(self): if self._n == -1: return f"[{self._index}]" @@ -119,7 +133,7 @@ def __str__(self): class TournamentSorter: def _get_random_indices( - self, expect_size: int, ind_choices: List[int] + self, expect_size: int, ind_choices: List[int] ) -> List[int]: choices = set(ind_choices) result = [] @@ -166,6 +180,8 @@ def __init__(self, indices: List[int], window_size: int, r: int): indices, window_size=window_size, top_k=r ) + self.count_inference = 0 + def _pop(self, x: int) -> List[TournamentSortNode]: on: TournamentSortNode = self._idx_to_node[x] lst = [] @@ -186,6 +202,7 @@ def perform(self, top_k: int): padded = self._pad_size(resort_param) request = ResortRequest(padded, []) yield request + self.count_inference += 1 cleaned_result = self._unpad_perm(resort_param, padded, request.result) nd.resort(cleaned_result) @@ -193,12 +210,17 @@ def perform(self, top_k: int): tpv = self._tr.top()[0] result.append(tpv) nodes = self._pop(tpv) + + if len(result) >= top_k: + break + for node in nodes: resort_param = node.get_resort_param() if resort_param is not None: padded = self._pad_size(resort_param) request = ResortRequest(padded, []) yield request + self.count_inference += 1 assert len(request.result) > 0 cleaned_result = self._unpad_perm( resort_param, padded, request.result @@ -209,12 +231,12 @@ def perform(self, top_k: int): def multiple_sort( - requests: List[Result], - indices_batch: List[List[int]], - runner: Callable[[List[Tuple[Result, List[int]]]], List[List[int]]], - window_size: int, - r: int, - top_k: int, + requests: List[Result], + indices_batch: List[List[int]], + runner: Callable[[List[Tuple[Result, List[int]]]], List[List[int]]], + window_size: int, + r: int, + top_k: int, ): batch_size = len(requests) tournament_sorters: List[TournamentSorter] = [ @@ -255,12 +277,12 @@ def __init__(self, top_k: int = 10, r: int = 1, **kwargs): self._r = r def reorder( - self, - requests: List[Result], - rank_start: int, - rank_end: int, - model: ModelFunction, - **kwargs, + self, + requests: List[Result], + rank_start: int, + rank_end: int, + model: ModelFunction, + **kwargs, ) -> list[Result]: window_size = model.window_size diff --git a/src/rank_llm/rerank/reranker.py b/src/rank_llm/rerank/reranker.py index 78276b41..561d12b0 100644 --- a/src/rank_llm/rerank/reranker.py +++ b/src/rank_llm/rerank/reranker.py @@ -247,9 +247,6 @@ def create_agent( else model_path ), reorder_policy=reorder_policy, - model=model_full_paths[model_path] - if model_path in model_full_paths - else model_path, name=model_path, context_size=context_size, window_size=window_size, diff --git a/src/rank_llm/scripts/run_rank_llm.py b/src/rank_llm/scripts/run_rank_llm.py index f9c40ab0..43f69fb6 100644 --- a/src/rank_llm/scripts/run_rank_llm.py +++ b/src/rank_llm/scripts/run_rank_llm.py @@ -62,7 +62,6 @@ def main(args): window_size=window_size, system_message=system_message, vllm_batched=vllm_batched, - batch_size=batch_size, reorder_policy=reorder_policy, ) @@ -163,8 
+162,6 @@ def main(args): parser.add_argument( "--step_size", type=int, - default=10, - help="step size for the sliding window approach", default=20, help="window size for the LLM", ) @@ -179,12 +176,6 @@ def main(args): action="store_true", help="whether to run the model in batches", ) - parser.add_argument( - "--batch_size", - default=-1, - help="batch size of the non vllm-determined-batch-size models. -1 means not allowed be in batch", - type=int, - ) parser.add_argument( "--reorder_policy", default="sliding_window", From 3005455c2d11c66de0b8bdeab9c2b6750961c474 Mon Sep 17 00:00:00 2001 From: Yidi Chen Date: Fri, 6 Sep 2024 01:54:26 +0000 Subject: [PATCH 17/30] Reformatted --- src/rank_llm/rerank/listwise/lit5_reranker.py | 9 ++- .../reorder/tournament_sort_reorder_policy.py | 39 ++++++------ .../rerank/listwise/vicuna_reranker.py | 6 +- .../rerank/listwise/zephyr_reranker.py | 61 ++++++++++--------- 4 files changed, 61 insertions(+), 54 deletions(-) diff --git a/src/rank_llm/rerank/listwise/lit5_reranker.py b/src/rank_llm/rerank/listwise/lit5_reranker.py index 94d5ab5c..769f1611 100644 --- a/src/rank_llm/rerank/listwise/lit5_reranker.py +++ b/src/rank_llm/rerank/listwise/lit5_reranker.py @@ -1,6 +1,9 @@ from rank_llm.data import Request, Result from rank_llm.rerank.listwise.rank_fid import RankFiDDistill, RankFiDScore -from rank_llm.rerank.listwise.reorder.reorder_policy import ReorderPolicy, SlidingWindowReorderPolicy +from rank_llm.rerank.listwise.reorder.reorder_policy import ( + ReorderPolicy, + SlidingWindowReorderPolicy, +) from rank_llm.rerank.rankllm import PromptMode from rank_llm.rerank.reranker import Reranker @@ -12,7 +15,7 @@ def __init__( context_size: int = 300, prompt_mode: PromptMode = PromptMode.LiT5, window_size: int = 20, - reorder_policy: ReorderPolicy = None + reorder_policy: ReorderPolicy = None, ) -> None: if reorder_policy is None: reorder_policy = SlidingWindowReorderPolicy() @@ -21,7 +24,7 @@ def __init__( context_size=context_size, prompt_mode=prompt_mode, window_size=window_size, - reorder_policy=reorder_policy + reorder_policy=reorder_policy, ) self._reranker = Reranker(agent) diff --git a/src/rank_llm/rerank/listwise/reorder/tournament_sort_reorder_policy.py b/src/rank_llm/rerank/listwise/reorder/tournament_sort_reorder_policy.py index 4068c086..1e202489 100644 --- a/src/rank_llm/rerank/listwise/reorder/tournament_sort_reorder_policy.py +++ b/src/rank_llm/rerank/listwise/reorder/tournament_sort_reorder_policy.py @@ -17,7 +17,7 @@ class ResortRequest: class TournamentSortNode: @staticmethod def build( - inds: List[int], window_size: int, top_k: int + inds: List[int], window_size: int, top_k: int ) -> Tuple[ "TournamentSortNode", List["TournamentSortNode"], @@ -36,7 +36,6 @@ def build( dq = deque(all_cs) while len(dq) > 1: - cnt = 0 children = [] @@ -56,11 +55,11 @@ def build( return dq[0], all_cs, base_nodes def __init__( - self, - top_k: int, - *, - children: Union[List["TournamentSortNode"]] = None, - index: int = None, + self, + top_k: int, + *, + children: Union[List["TournamentSortNode"]] = None, + index: int = None, ): super().__init__() @@ -133,7 +132,7 @@ def __str__(self): class TournamentSorter: def _get_random_indices( - self, expect_size: int, ind_choices: List[int] + self, expect_size: int, ind_choices: List[int] ) -> List[int]: choices = set(ind_choices) result = [] @@ -231,12 +230,12 @@ def perform(self, top_k: int): def multiple_sort( - requests: List[Result], - indices_batch: List[List[int]], - runner: Callable[[List[Tuple[Result, 
List[int]]]], List[List[int]]], - window_size: int, - r: int, - top_k: int, + requests: List[Result], + indices_batch: List[List[int]], + runner: Callable[[List[Tuple[Result, List[int]]]], List[List[int]]], + window_size: int, + r: int, + top_k: int, ): batch_size = len(requests) tournament_sorters: List[TournamentSorter] = [ @@ -277,12 +276,12 @@ def __init__(self, top_k: int = 10, r: int = 1, **kwargs): self._r = r def reorder( - self, - requests: List[Result], - rank_start: int, - rank_end: int, - model: ModelFunction, - **kwargs, + self, + requests: List[Result], + rank_start: int, + rank_end: int, + model: ModelFunction, + **kwargs, ) -> list[Result]: window_size = model.window_size diff --git a/src/rank_llm/rerank/listwise/vicuna_reranker.py b/src/rank_llm/rerank/listwise/vicuna_reranker.py index ab900a1c..49498098 100644 --- a/src/rank_llm/rerank/listwise/vicuna_reranker.py +++ b/src/rank_llm/rerank/listwise/vicuna_reranker.py @@ -3,7 +3,10 @@ from rank_llm.data import Request, Result from rank_llm.rerank import PromptMode from rank_llm.rerank.listwise import RankListwiseOSLLM -from rank_llm.rerank.listwise.reorder.reorder_policy import ReorderPolicy, SlidingWindowReorderPolicy +from rank_llm.rerank.listwise.reorder.reorder_policy import ( + ReorderPolicy, + SlidingWindowReorderPolicy, +) class VicunaReranker: @@ -20,7 +23,6 @@ def __init__( reorder_policy: ReorderPolicy = None, system_message: str = None, ) -> None: - if reorder_policy is None: reorder_policy = SlidingWindowReorderPolicy() diff --git a/src/rank_llm/rerank/listwise/zephyr_reranker.py b/src/rank_llm/rerank/listwise/zephyr_reranker.py index 5ce10b81..340f4600 100644 --- a/src/rank_llm/rerank/listwise/zephyr_reranker.py +++ b/src/rank_llm/rerank/listwise/zephyr_reranker.py @@ -3,22 +3,25 @@ from rank_llm.data import Request, Result from rank_llm.rerank import PromptMode from rank_llm.rerank.listwise import RankListwiseOSLLM -from rank_llm.rerank.listwise.reorder.reorder_policy import ReorderPolicy, SlidingWindowReorderPolicy +from rank_llm.rerank.listwise.reorder.reorder_policy import ( + ReorderPolicy, + SlidingWindowReorderPolicy, +) class ZephyrReranker: def __init__( - self, - model_path: str = "castorini/rank_zephyr_7b_v1_full", - context_size: int = 4096, - prompt_mode: PromptMode = PromptMode.RANK_GPT, - num_few_shot_examples: int = 0, - device: str = "cuda", - num_gpus: int = 1, - variable_passages: bool = True, - window_size: int = 20, - reorder_policy: ReorderPolicy = None, - system_message: str = "You are RankLLM, an intelligent assistant that can rank passages based on their relevancy to the query", + self, + model_path: str = "castorini/rank_zephyr_7b_v1_full", + context_size: int = 4096, + prompt_mode: PromptMode = PromptMode.RANK_GPT, + num_few_shot_examples: int = 0, + device: str = "cuda", + num_gpus: int = 1, + variable_passages: bool = True, + window_size: int = 20, + reorder_policy: ReorderPolicy = None, + system_message: str = "You are RankLLM, an intelligent assistant that can rank passages based on their relevancy to the query", ) -> None: if reorder_policy is None: reorder_policy = SlidingWindowReorderPolicy() @@ -34,18 +37,18 @@ def __init__( variable_passages=variable_passages, window_size=window_size, system_message=system_message, - reorder_policy=reorder_policy + reorder_policy=reorder_policy, ) def rerank_batch( - self, - requests: List[Request], - rank_start: int = 0, - rank_end: int = 100, - window_size: int = 20, - step: int = 10, - shuffle_candidates: bool = False, - logging: bool = 
False, + self, + requests: List[Request], + rank_start: int = 0, + rank_end: int = 100, + window_size: int = 20, + step: int = 10, + shuffle_candidates: bool = False, + logging: bool = False, ) -> List[Result]: """ Reranks a list of requests using the Zephyr model. @@ -76,14 +79,14 @@ def rerank_batch( ) def rerank( - self, - request: Request, - rank_start: int = 0, - rank_end: int = 100, - window_size: int = 20, - step: int = 10, - shuffle_candidates: bool = False, - logging: bool = False, + self, + request: Request, + rank_start: int = 0, + rank_end: int = 100, + window_size: int = 20, + step: int = 10, + shuffle_candidates: bool = False, + logging: bool = False, ) -> Result: """ Reranks a request using the Zephyr model. From 379530805da984e28573f25a713190afe51a6327 Mon Sep 17 00:00:00 2001 From: Yidi Chen Date: Thu, 12 Sep 2024 06:27:31 +0000 Subject: [PATCH 18/30] Added top down --- .../rerank/listwise/listwise_rankllm.py | 46 +++- .../reorder/top_down_reorder_policy.py | 230 ++++++++++++++++++ 2 files changed, 269 insertions(+), 7 deletions(-) create mode 100644 src/rank_llm/rerank/listwise/reorder/top_down_reorder_policy.py diff --git a/src/rank_llm/rerank/listwise/listwise_rankllm.py b/src/rank_llm/rerank/listwise/listwise_rankllm.py index 8106842b..42e750f5 100644 --- a/src/rank_llm/rerank/listwise/listwise_rankllm.py +++ b/src/rank_llm/rerank/listwise/listwise_rankllm.py @@ -4,6 +4,7 @@ import random import re from abc import ABC +from dataclasses import dataclass from datetime import datetime from typing import Any, Dict, List, Tuple, Union @@ -17,13 +18,26 @@ ReorderPolicy, SlidingWindowReorderPolicy, ) +from rank_llm.rerank.listwise.reorder.top_down_reorder_policy import ( + TopDownReorderPolicy, +) from rank_llm.rerank.listwise.reorder.tournament_sort_reorder_policy import ( TournamentSortReorderPolicy, ) logger = logging.getLogger(__name__) -SUPPORT_REORDER_POLICIES = [SlidingWindowReorderPolicy, TournamentSortReorderPolicy] +SUPPORT_REORDER_POLICIES = [ + SlidingWindowReorderPolicy, + TournamentSortReorderPolicy, + TopDownReorderPolicy, +] + + +@dataclass +class RerankConsumption: + consumption_reference_by_batch: int + consumption_reference_by_item: int class ListwiseRankLLM(RankLLM, ABC): @@ -51,7 +65,9 @@ def __init__( super().__init__(model, context_size, prompt_mode) self._num_few_shot_examples = num_few_shot_examples - self.reorder_policy = reorder_policy + self.reorder_policy = ( + SlidingWindowReorderPolicy() if reorder_policy is None else reorder_policy + ) self._window_size = window_size def rerank_batch( @@ -72,7 +88,7 @@ def rerank_batch( batch_size = 1 reorder_policy = self.reorder_policy - model_functions = self._get_model_function(batched) + model_functions, consumption = self._get_model_function(batched) # reranking using vllm if len(set([len(req.candidates) for req in requests])) != 1: @@ -80,7 +96,7 @@ def rerank_batch( result: list[Result] = [] - with tqdm(range(0, len(requests))) as bar: + with tqdm(range(0, len(requests)), leave=False) as bar: for i in range(0, len(requests), batch_size): batch = requests[i : min(i + batch_size, len(requests))] batch_result = reorder_policy.reorder( @@ -527,8 +543,13 @@ def _permutation_to_rank(self, perm_string: str, selected_indices: List[int]): perm = perm + [i for i in range(len(selected_indices)) if i not in perm] return perm - def _get_model_function(self, batched: bool = False, **kwargs) -> ModelFunction: + def _get_model_function( + self, batched: bool = False, **kwargs + ) -> Tuple[ModelFunction, 
RerankConsumption]: # [(Request, SelectIndex)] -> [Prompt] + + consumption = RerankConsumption(0, 0) + if batched: def create_prompt(batch: List[Tuple[Result, List[int]]]): @@ -545,6 +566,9 @@ def execute( batch: List[Union[str, Dict[str, str]]], selected_indices_batch: List[List[int]], ): + consumption.consumption_reference_by_batch += 1 + consumption.consumption_reference_by_item += len(batch) + return [ self._permutation_to_rank(s, selected_indices) for (s, _), selected_indices in zip( @@ -564,6 +588,9 @@ def execute( batch: List[Union[str, Dict[str, str]]], selected_indices_batch: List[List[int]], ): + consumption.consumption_reference_by_batch += 1 + consumption.consumption_reference_by_item += len(batch) + return [ self._permutation_to_rank( self.run_llm(x, **kwargs)[0], selected_indices @@ -571,8 +598,13 @@ def execute( for x, selected_indices in zip(batch, selected_indices_batch) ] - return ModelFunction( - create_prompt=create_prompt, execute=execute, window_size=self._window_size + return ( + ModelFunction( + create_prompt=create_prompt, + execute=execute, + window_size=self._window_size, + ), + consumption, ) @staticmethod diff --git a/src/rank_llm/rerank/listwise/reorder/top_down_reorder_policy.py b/src/rank_llm/rerank/listwise/reorder/top_down_reorder_policy.py new file mode 100644 index 00000000..ed2d7363 --- /dev/null +++ b/src/rank_llm/rerank/listwise/reorder/top_down_reorder_policy.py @@ -0,0 +1,230 @@ +import copy +import logging +import random +from dataclasses import dataclass +from typing import Callable, List, Tuple + +from rank_llm.data import Result +from rank_llm.rerank.listwise.reorder.reorder_policy import ModelFunction, ReorderPolicy + +logger = logging.getLogger(__name__) + + +@dataclass +class ReorderRequest: + indices: List[int] + result: List[int] + + +class TopDownReorderProcess: + def __init__(self, top_k: int, window_size: int, pivot: int, indices: List[int]): + super().__init__() + self._window_size = window_size + self._pivot = pivot + self._top_k = top_k + self._indices = indices + + def _find_pivot(self, lst: List[int], piv: int) -> int: + for i in range(len(lst)): + if lst[i] == piv: + return i + # unreachable + assert False + + def _fill_unchoose(self, lst: List[int]): + st = set(lst) + for x in self._indices: + if x not in st: + lst.append(x) + + return lst + + def _pad(self, lst: List[int]): + st = set(lst) + results = [x for x in lst] + for i in reversed(self._indices): + if len(results) >= self._window_size: + break + if i not in st: + results.append(i) + + for i in reversed(self._indices): + if len(results) >= self._window_size: + break + results.append(i) + + return results + + def _unpad(self, lst: List[int], result_perm: List[int]): + return [x for x in result_perm if x < len(lst)] + + def _remove_from_occ(self, lst: List[int], inds: List[int]): + st = set(inds) + return [x for x in lst if x not in st] + + def _shuffle(self, lst: List[int]) -> List[int]: + l = [x for x in lst] + random.shuffle(l) + return l + + def perform(self): + top_k = self._top_k + window_size = self._window_size + pivot = self._pivot + indices = [x for x in self._indices] + + assert pivot <= window_size + assert top_k <= window_size + + """ + Algorithm is O(N^2) here, we can eliminate it to O(N) by split result into result and result_this_turn + """ + + while len(indices) > window_size: + result = [] + + while len(result) < top_k: + # base + base = indices[: min(window_size, len(indices))] + request = ReorderRequest(self._pad(base), None) + yield request + base = 
[base[i] for i in self._unpad(base, request.result)] + + if len(base) < window_size: + for i in base: + if len(result) >= top_k: + break + result.append(i) + break + + piv_item = base[pivot] + for i in range(pivot - 1): + result.append(base[i]) + + # then sort others + for i in range(window_size, len(indices), window_size - 1): + request_indices = indices[i : i + window_size - 1] + [piv_item] + request = ReorderRequest(self._pad(request_indices), None) + yield request + request_indices = [ + request_indices[i] + for i in self._unpad(request_indices, request.result) + ] + + # reordered + loc = self._find_pivot(request_indices, piv_item) + for i in range(loc - 1): + result.append(request_indices[i]) + + if len(result) + 1 == top_k: + result.append(piv_item) + + indices = self._remove_from_occ(indices, result) + + indices = result + + # finally, resort the value + # here len(indices) == top_k + request_indices = indices + request = ReorderRequest(self._pad(request_indices), None) + yield request + indices = [ + request_indices[i] for i in self._unpad(request_indices, request.result) + ] + + return self._fill_unchoose(indices) + + +def multiple_sort( + requests: List[Result], + indices_batch: List[List[int]], + runner: Callable[[List[Tuple[Result, List[int]]]], List[List[int]]], + window_size: int, + pivot: int, + top_k: int, +) -> List[List[int]]: + batch_size = len(requests) + top_down_sorters = [ + TopDownReorderProcess(top_k, window_size, pivot, indices) + for indices in indices_batch + ] + progress = [top_down_sorter.perform() for top_down_sorter in top_down_sorters] + result: List[List[int]] = [None for _ in range(batch_size)] + left_not_sorted = set(range(batch_size)) + + while len(left_not_sorted) > 0: + perm_request = [] + + finish_requests = [] + for idx in left_not_sorted: + try: + req = next(progress[idx]) + perm_request.append((idx, req)) + except StopIteration as e: + result[idx] = e.value + finish_requests.append(idx) + for idx in finish_requests: + left_not_sorted.remove(idx) + + outputs = runner([(requests[idx], req.indices) for idx, req in perm_request]) + + for (idx, req), output in zip(perm_request, outputs): + req.result = output + + return result + + +class TopDownReorderPolicy(ReorderPolicy): + def __init__(self, top_k: int = 10, pivot: int = -1, **kwargs): + super().__init__() + self._top_k = top_k + self._pivot = pivot + + def reorder( + self, + requests: List[Result], + rank_start: int, + rank_end: int, + model: ModelFunction, + **kwargs, + ) -> list[Result]: + window_size = model.window_size + pivot = window_size // 2 if self._pivot < 0 else self._pivot + + runner: Callable[ + [List[Tuple[Result, List[int]]]], List[List[int]] + ] = lambda reqs: model.execute( + model.create_prompt(reqs), [ind for req, ind in reqs] + ) + + request_ranks = multiple_sort( + requests, + [list(range(rank_start, rank_end)) for _ in range(len(requests))], + runner=runner, + top_k=self._top_k, + pivot=pivot, + window_size=window_size, + ) + + results = [ + Result( + query=copy.deepcopy(request.query), + candidates=self._reorder_by_rank( + copy.deepcopy(request.candidates), + [*range(len(request.candidates))], + rank, + ), + ranking_exec_summary=[], + ) + for request, rank in zip(requests, request_ranks) + ] + + for result, request in zip(results, requests): + for j in range(len(result.candidates)): + result.candidates[j].score = request.candidates[j].score + + return results + + @staticmethod + def name() -> str: + return "top_down" From 65d3508a54566b8ce73de0d096598ced09c08713 Mon Sep 17 
00:00:00 2001 From: Yidi Chen Date: Sun, 15 Sep 2024 20:05:22 +0000 Subject: [PATCH 19/30] Added shuffle support and consumption logging --- .../rerank/listwise/listwise_rankllm.py | 4 ++ .../rerank/listwise/rank_listwise_os_llm.py | 6 +- .../rerank/listwise/reorder/reorder_policy.py | 66 +++++++++++++------ .../reorder/top_down_reorder_policy.py | 17 +++-- .../reorder/tournament_sort_reorder_policy.py | 15 +++-- 5 files changed, 73 insertions(+), 35 deletions(-) diff --git a/src/rank_llm/rerank/listwise/listwise_rankllm.py b/src/rank_llm/rerank/listwise/listwise_rankllm.py index 42e750f5..7bbe45a3 100644 --- a/src/rank_llm/rerank/listwise/listwise_rankllm.py +++ b/src/rank_llm/rerank/listwise/listwise_rankllm.py @@ -120,6 +120,10 @@ def rerank_batch( result.extend(batch_result) bar.update(len(batch)) + logger.info( + f"Average consumption per request: {consumption.consumption_reference_by_item / len(requests) : .2f}" + ) + return result def get_output_filename( diff --git a/src/rank_llm/rerank/listwise/rank_listwise_os_llm.py b/src/rank_llm/rerank/listwise/rank_listwise_os_llm.py index 930b9b5b..3e738fbe 100644 --- a/src/rank_llm/rerank/listwise/rank_listwise_os_llm.py +++ b/src/rank_llm/rerank/listwise/rank_listwise_os_llm.py @@ -29,8 +29,8 @@ class RankListwiseOSLLM(ListwiseRankLLM): def __init__( self, - reorder_policy: ReorderPolicy, model: str, + reorder_policy: Optional[ReorderPolicy] = None, name: str = "", context_size: int = 4096, window_size: int = 20, @@ -270,7 +270,9 @@ def chunks(lst, n): all_completed_prompts = [] with ThreadPoolExecutor() as executor: - for batch in tqdm(chunks(results, batch_size), desc="Processing batches"): + for batch in tqdm( + chunks(results, batch_size), desc="Processing batches", leave=False + ): completed_prompts = list( executor.map( lambda req: self.create_prompt(req[0], req[1]), diff --git a/src/rank_llm/rerank/listwise/reorder/reorder_policy.py b/src/rank_llm/rerank/listwise/reorder/reorder_policy.py index 5dd00798..ae884c95 100644 --- a/src/rank_llm/rerank/listwise/reorder/reorder_policy.py +++ b/src/rank_llm/rerank/listwise/reorder/reorder_policy.py @@ -1,8 +1,11 @@ import copy +import random from abc import ABC, abstractmethod from dataclasses import dataclass from typing import Callable, Dict, List, Tuple, TypeVar, Union +import numpy as np + from rank_llm.data import Result T = TypeVar("T") @@ -42,11 +45,39 @@ def name() -> str: pass @staticmethod - def _shuffle_and_rescore( - results: List[Result], select_indexes: List[int] - ) -> List[Result]: - # TODO: do nothing for now - return results + def _shuffle_indices(indices: List[int]) -> List[int]: + indices = list(indices) + random.shuffle(indices) + return indices + + @staticmethod + def _shuffled( + func: Callable[[List[Tuple[Result, List[int]]]], List[List[int]]] + ) -> Callable[[List[Tuple[Result, List[int]]]], List[List[int]]]: + def fun(batch: List[Tuple[Result, List[int]]]) -> List[List[int]]: + perms = [] + perms_back = [] + batch_feed = [] + for res, ind in batch: + perm = np.random.permutation(len(ind)).tolist() + perm_back = [0 for _ in range(len(perm))] + perms.append(perm) + + for i in range(len(perm)): + perm_back[perm[i]] = i + + batch_feed.append((res, [ind[x] for x in perm])) + perms_back.append(perm_back) + + # run the model on the shuffled batch, then map each returned + # rank order of fed positions back to the original positions + result_raw = func(batch_feed) + + results = [] + for result, perm in zip(result_raw, perms): + results.append([perm[x] for x in result]) + + return results + + return fun @staticmethod def _reorder_by_rank(items: List[T], idxes: List[int], rank: List[int]) -> List[T]: @@ 
-69,13 +100,10 @@ class SlidingWindowReorderPolicy(ReorderPolicy): def __init__( self, step: int = 10, - shuffle_candidates: bool = False, **kwargs, ): self._step_size = step - self._shuffle_candidates = bool(shuffle_candidates) - def reorder( self, requests: List[Result], @@ -89,20 +117,16 @@ def reorder( ) -> List[Result]: window_size = model.window_size - rerank_results = [ - Result( - query=copy.deepcopy(request.query), - candidates=copy.deepcopy(request.candidates), - ranking_exec_summary=[], - ) - for request in requests - ] - - if self._shuffle_candidates: - self._shuffle_and_rescore(rerank_results, [*range(rank_start, rank_end)]) - # order of requests - request_ranks = [[*range(len(request.candidates))] for request in requests] + if shuffle_candidates: + request_ranks = [ + self._shuffle_indices(list(range(len(request.candidates)))) + for request in requests + ] + else: + request_ranks = [ + list(range(len(request.candidates))) for request in requests + ] end_pos = rank_end start_pos = rank_end - window_size diff --git a/src/rank_llm/rerank/listwise/reorder/top_down_reorder_policy.py b/src/rank_llm/rerank/listwise/reorder/top_down_reorder_policy.py index ed2d7363..2af58773 100644 --- a/src/rank_llm/rerank/listwise/reorder/top_down_reorder_policy.py +++ b/src/rank_llm/rerank/listwise/reorder/top_down_reorder_policy.py @@ -1,6 +1,5 @@ import copy import logging -import random from dataclasses import dataclass from typing import Callable, List, Tuple @@ -62,11 +61,6 @@ def _remove_from_occ(self, lst: List[int], inds: List[int]): st = set(inds) return [x for x in lst if x not in st] - def _shuffle(self, lst: List[int]) -> List[int]: - l = [x for x in lst] - random.shuffle(l) - return l - def perform(self): top_k = self._top_k window_size = self._window_size @@ -186,6 +180,7 @@ def reorder( rank_start: int, rank_end: int, model: ModelFunction, + shuffle_candidates: bool = False, **kwargs, ) -> list[Result]: window_size = model.window_size @@ -197,9 +192,17 @@ def reorder( model.create_prompt(reqs), [ind for req, ind in reqs] ) + if shuffle_candidates: + indices = [ + self._shuffle_indices(list(range(len(request.candidates)))) + for request in requests + ] + else: + indices = [list(range(rank_start, rank_end)) for _ in range(len(requests))] + request_ranks = multiple_sort( requests, - [list(range(rank_start, rank_end)) for _ in range(len(requests))], + indices, runner=runner, top_k=self._top_k, pivot=pivot, diff --git a/src/rank_llm/rerank/listwise/reorder/tournament_sort_reorder_policy.py b/src/rank_llm/rerank/listwise/reorder/tournament_sort_reorder_policy.py index 1e202489..bc32c088 100644 --- a/src/rank_llm/rerank/listwise/reorder/tournament_sort_reorder_policy.py +++ b/src/rank_llm/rerank/listwise/reorder/tournament_sort_reorder_policy.py @@ -179,8 +179,6 @@ def __init__(self, indices: List[int], window_size: int, r: int): indices, window_size=window_size, top_k=r ) - self.count_inference = 0 - def _pop(self, x: int) -> List[TournamentSortNode]: on: TournamentSortNode = self._idx_to_node[x] lst = [] @@ -201,7 +199,6 @@ def perform(self, top_k: int): padded = self._pad_size(resort_param) request = ResortRequest(padded, []) yield request - self.count_inference += 1 cleaned_result = self._unpad_perm(resort_param, padded, request.result) nd.resort(cleaned_result) @@ -219,7 +216,6 @@ def perform(self, top_k: int): padded = self._pad_size(resort_param) request = ResortRequest(padded, []) yield request - self.count_inference += 1 assert len(request.result) > 0 cleaned_result = 
self._unpad_perm( resort_param, padded, request.result @@ -281,6 +277,7 @@ def reorder( rank_start: int, rank_end: int, model: ModelFunction, + shuffle_candidates: bool = False, **kwargs, ) -> list[Result]: window_size = model.window_size @@ -291,9 +288,17 @@ def reorder( model.create_prompt(reqs), [ind for req, ind in reqs] ) + if shuffle_candidates: + indices = [ + self._shuffle_indices(list(range(len(request.candidates)))) + for request in requests + ] + else: + indices = [list(range(rank_start, rank_end)) for _ in range(len(requests))] + request_ranks = multiple_sort( requests, - [list(range(rank_start, rank_end)) for _ in range(len(requests))], + indices, runner=runner, window_size=window_size, top_k=self._top_k, From 69e3d657b2f9fc3fc7341aa83b38b63b1f678641 Mon Sep 17 00:00:00 2001 From: Yidi Chen Date: Sun, 15 Sep 2024 20:18:33 +0000 Subject: [PATCH 20/30] Updated readme --- README.md | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/README.md b/README.md index 46a639fa..b2ed2fc9 100644 --- a/README.md +++ b/README.md @@ -203,6 +203,35 @@ If you use one of the monoT5 models please cite the following relevant paper: journal = {arXiv:2101.05667}, } ``` + +If you use `reorder_policy=tournament_sort`, please cite the following paper: + +``` +@misc{yoon2024listt5listwisererankingfusionindecoder, + title={ListT5: Listwise Reranking with Fusion-in-Decoder Improves Zero-shot Retrieval}, + author={Soyoung Yoon and Eunbi Choi and Jiyeon Kim and Hyeongu Yun and Yireun Kim and Seung-won Hwang}, + year={2024}, + eprint={2402.15838}, + archivePrefix={arXiv}, + primaryClass={cs.IR}, + url={https://arxiv.org/abs/2402.15838}, +} +``` + +If you use `reorder_policy=top_down`, please cite the following paper: + +``` +@misc{parry2024topdownpartitioningefficientlistwise, + title={Top-Down Partitioning for Efficient List-Wise Ranking}, + author={Andrew Parry and Sean MacAvaney and Debasis Ganguly}, + year={2024}, + eprint={2405.14589}, + archivePrefix={arXiv}, + primaryClass={cs.IR}, + url={https://arxiv.org/abs/2405.14589}, +} +``` + ## 🙏 Acknowledgments This research is supported in part by the Natural Sciences and Engineering Research Council (NSERC) of Canada. 
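As a minimal sketch (not taken from the patches themselves; the model id and request plumbing are assumptions), the top-down policy introduced above might be plugged into `RankListwiseOSLLM` via the `reorder_policy` argument this series threads through the rerankers:

```python
from rank_llm.rerank.listwise.rank_listwise_os_llm import RankListwiseOSLLM
from rank_llm.rerank.listwise.reorder.top_down_reorder_policy import (
    TopDownReorderPolicy,
)

# pivot=-1 falls back to window_size // 2 when reorder() runs.
policy = TopDownReorderPolicy(top_k=10, pivot=-1)

reranker = RankListwiseOSLLM(
    model="castorini/rank_zephyr_7b_v1_full",  # assumed model id
    reorder_policy=policy,  # None falls back to SlidingWindowReorderPolicy
    window_size=20,
)

# `requests` would be a List[Request] from a first-stage retriever:
# results = reranker.rerank_batch(requests, rank_start=0, rank_end=100)
```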
From 8f4f72996d46d19ab19d18e9f413834747275760 Mon Sep 17 00:00:00 2001 From: Yidi Chen Date: Sun, 15 Sep 2024 21:42:05 +0000 Subject: [PATCH 21/30] Clean the README --- README.md | 27 --------------------------- 1 file changed, 27 deletions(-) diff --git a/README.md b/README.md index b2ed2fc9..55374d2e 100644 --- a/README.md +++ b/README.md @@ -204,33 +204,6 @@ If you use one of the monoT5 models please cite the following relevant paper: } ``` -If you use `reorder_policy=tournament_sort`, please cite the following paper: - -``` -@misc{yoon2024listt5listwisererankingfusionindecoder, - title={ListT5: Listwise Reranking with Fusion-in-Decoder Improves Zero-shot Retrieval}, - author={Soyoung Yoon and Eunbi Choi and Jiyeon Kim and Hyeongu Yun and Yireun Kim and Seung-won Hwang}, - year={2024}, - eprint={2402.15838}, - archivePrefix={arXiv}, - primaryClass={cs.IR}, - url={https://arxiv.org/abs/2402.15838}, -} -``` - -If you use `reorder_policy=top_down`, please cite the following paper: - -``` -@misc{parry2024topdownpartitioningefficientlistwise, - title={Top-Down Partitioning for Efficient List-Wise Ranking}, - author={Andrew Parry and Sean MacAvaney and Debasis Ganguly}, - year={2024}, - eprint={2405.14589}, - archivePrefix={arXiv}, - primaryClass={cs.IR}, - url={https://arxiv.org/abs/2405.14589}, -} -``` ## 🙏 Acknowledgments From cc22c9587ffaea4aa210721546d44384948554f4 Mon Sep 17 00:00:00 2001 From: Yidi Chen Date: Mon, 16 Sep 2024 16:41:49 +0000 Subject: [PATCH 22/30] Changed filename --- src/rank_llm/rerank/listwise/listwise_rankllm.py | 1 + src/rank_llm/rerank/listwise/reorder/reorder_policy.py | 7 +++++++ .../rerank/listwise/reorder/top_down_reorder_policy.py | 3 +++ .../listwise/reorder/tournament_sort_reorder_policy.py | 3 +++ 4 files changed, 14 insertions(+) diff --git a/src/rank_llm/rerank/listwise/listwise_rankllm.py b/src/rank_llm/rerank/listwise/listwise_rankllm.py index 7bbe45a3..25416132 100644 --- a/src/rank_llm/rerank/listwise/listwise_rankllm.py +++ b/src/rank_llm/rerank/listwise/listwise_rankllm.py @@ -143,6 +143,7 @@ def get_output_filename( name = f"{name}_{dataset_name}" if self._num_few_shot_examples > 0: name += f"_{self._num_few_shot_examples}_shot" + name += f"_{self.reorder_policy.param_name()}" return ( f"{name}_shuffled_{datetime.isoformat(datetime.now())}" if shuffle_candidates diff --git a/src/rank_llm/rerank/listwise/reorder/reorder_policy.py b/src/rank_llm/rerank/listwise/reorder/reorder_policy.py index ae884c95..c600e06f 100644 --- a/src/rank_llm/rerank/listwise/reorder/reorder_policy.py +++ b/src/rank_llm/rerank/listwise/reorder/reorder_policy.py @@ -39,6 +39,10 @@ def reorder( ) -> list[Result]: pass + @abstractmethod + def param_name(self): + pass + @staticmethod @abstractmethod def name() -> str: @@ -172,6 +176,9 @@ def reorder( return results + def param_name(self): + return f"slidingwindow_stp{self._step_size}" + @staticmethod def name() -> str: return "sliding_window" diff --git a/src/rank_llm/rerank/listwise/reorder/top_down_reorder_policy.py b/src/rank_llm/rerank/listwise/reorder/top_down_reorder_policy.py index 2af58773..333f33dc 100644 --- a/src/rank_llm/rerank/listwise/reorder/top_down_reorder_policy.py +++ b/src/rank_llm/rerank/listwise/reorder/top_down_reorder_policy.py @@ -228,6 +228,9 @@ def reorder( return results + def param_name(self): + return f"topdown_tpk{self._top_k}_pvt{self._pivot}" + @staticmethod def name() -> str: return "top_down" diff --git a/src/rank_llm/rerank/listwise/reorder/tournament_sort_reorder_policy.py 
b/src/rank_llm/rerank/listwise/reorder/tournament_sort_reorder_policy.py index bc32c088..b1061aca 100644 --- a/src/rank_llm/rerank/listwise/reorder/tournament_sort_reorder_policy.py +++ b/src/rank_llm/rerank/listwise/reorder/tournament_sort_reorder_policy.py @@ -324,6 +324,9 @@ def reorder( return results + def param_name(self): + return f"tournamentsort_tpk{self._top_k}_r{self._r}" + @staticmethod def name() -> str: return "tournament_sort" From 26beeedec262d16c8f4309d4a97db27709195e32 Mon Sep 17 00:00:00 2001 From: Yidi Chen Date: Mon, 16 Sep 2024 20:43:59 +0000 Subject: [PATCH 23/30] Updated silence --- .../rerank/listwise/listwise_rankllm.py | 9 +++++---- .../rerank/listwise/rank_listwise_os_llm.py | 19 ++++++++++++++----- src/rank_llm/scripts/run_rank_llm.py | 8 ++++++++ 3 files changed, 27 insertions(+), 9 deletions(-) diff --git a/src/rank_llm/rerank/listwise/listwise_rankllm.py b/src/rank_llm/rerank/listwise/listwise_rankllm.py index 25416132..ae092327 100644 --- a/src/rank_llm/rerank/listwise/listwise_rankllm.py +++ b/src/rank_llm/rerank/listwise/listwise_rankllm.py @@ -88,7 +88,7 @@ def rerank_batch( batch_size = 1 reorder_policy = self.reorder_policy - model_functions, consumption = self._get_model_function(batched) + model_functions, consumption = self._get_model_function(batched, **kwargs) # reranking using vllm if len(set([len(req.candidates) for req in requests])) != 1: @@ -549,7 +549,7 @@ def _permutation_to_rank(self, perm_string: str, selected_indices: List[int]): return perm def _get_model_function( - self, batched: bool = False, **kwargs + self, batched: bool = False, silence: bool = False, **kwargs ) -> Tuple[ModelFunction, RerankConsumption]: # [(Request, SelectIndex)] -> [Prompt] @@ -577,7 +577,8 @@ def execute( return [ self._permutation_to_rank(s, selected_indices) for (s, _), selected_indices in zip( - self.run_llm_batched(batch, **kwargs), selected_indices_batch + self.run_llm_batched(batch, silence=silence, **kwargs), + selected_indices_batch, ) ] @@ -598,7 +599,7 @@ def execute( return [ self._permutation_to_rank( - self.run_llm(x, **kwargs)[0], selected_indices + self.run_llm(x, silence=silence, **kwargs)[0], selected_indices ) for x, selected_indices in zip(batch, selected_indices_batch) ] diff --git a/src/rank_llm/rerank/listwise/rank_listwise_os_llm.py b/src/rank_llm/rerank/listwise/rank_listwise_os_llm.py index 3e738fbe..063e0ee8 100644 --- a/src/rank_llm/rerank/listwise/rank_listwise_os_llm.py +++ b/src/rank_llm/rerank/listwise/rank_listwise_os_llm.py @@ -103,7 +103,9 @@ def __init__( ) elif vllm_batched: self._llm = LLM( - model, download_dir=os.getenv("HF_HOME"), enforce_eager=False + model, + download_dir=os.getenv("HF_HOME"), + enforce_eager=False, ) self._tokenizer = self._llm.get_tokenizer() else: @@ -133,26 +135,31 @@ def rerank_batch( def run_llm_batched( self, prompts: List[str | List[Dict[str, str]]], + silence: bool = False, current_window_size: Optional[int] = None, + **kwargs, ) -> List[Tuple[str, int]]: if SamplingParams is None: raise ImportError( "Please install rank-llm with `pip install rank-llm[vllm]` to use batch inference." 
) - logger.info(f"VLLM Generating!") sampling_params = SamplingParams( temperature=0.0, max_tokens=self.num_output_tokens(current_window_size), min_tokens=self.num_output_tokens(current_window_size), ) - outputs = self._llm.generate(prompts, sampling_params) + outputs = self._llm.generate(prompts, sampling_params, use_tqdm=not silence) return [ (output.outputs[0].text, len(output.outputs[0].token_ids)) for output in outputs ] def run_llm( - self, prompt: str, current_window_size: Optional[int] = None + self, + prompt: str, + silence: bool = False, + current_window_size: Optional[int] = None, + **kwargs, ) -> Tuple[str, int]: if current_window_size is None: current_window_size = self._window_size @@ -163,7 +170,9 @@ def run_llm( gen_cfg.min_new_tokens = self.num_output_tokens(current_window_size) # gen_cfg.temperature = 0 gen_cfg.do_sample = False - output_ids = self._llm.generate(**inputs, generation_config=gen_cfg) + output_ids = self._llm.generate( + **inputs, use_tqdm=not silence, generation_config=gen_cfg + ) if self._llm.config.is_encoder_decoder: output_ids = output_ids[0] diff --git a/src/rank_llm/scripts/run_rank_llm.py b/src/rank_llm/scripts/run_rank_llm.py index 03209646..d570bf53 100644 --- a/src/rank_llm/scripts/run_rank_llm.py +++ b/src/rank_llm/scripts/run_rank_llm.py @@ -39,6 +39,7 @@ def main(args): vllm_batched = args.vllm_batched batch_size = args.batch_size reorder_policy = args.reorder_policy + silence = args.silence _ = retrieve_and_rerank( model_path=model_path, @@ -63,6 +64,7 @@ def main(args): system_message=system_message, vllm_batched=vllm_batched, reorder_policy=reorder_policy, + silence=silence, ) @@ -182,5 +184,11 @@ def main(args): help="reorder policy to use; defaults to sliding_window", type=str, ) + parser.add_argument( + "--silence", + default=False, + action="store_true", + help="Suppress extra tqdm progress bars that cannot otherwise be disabled (leave=False is not always available)", + ) args = parser.parse_args() main(args) From 36beed149a7d14f8aa5b3046f3378c3237597c94 Mon Sep 17 00:00:00 2001 From: Yidi Chen Date: Fri, 20 Sep 2024 06:18:38 +0000 Subject: [PATCH 24/30] Fixed bug in rank_listwise_os_llm's create_prompt when batch size > 32 --- src/rank_llm/rerank/listwise/rank_listwise_os_llm.py | 9 ++++----- src/rank_llm/rerank/listwise/reorder/reorder_policy.py | 7 +++++-- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/src/rank_llm/rerank/listwise/rank_listwise_os_llm.py b/src/rank_llm/rerank/listwise/rank_listwise_os_llm.py index 063e0ee8..a4e9044b 100644 --- a/src/rank_llm/rerank/listwise/rank_listwise_os_llm.py +++ b/src/rank_llm/rerank/listwise/rank_listwise_os_llm.py @@ -280,13 +280,12 @@ def chunks(lst, n): with ThreadPoolExecutor() as executor: for batch in tqdm( - chunks(results, batch_size), desc="Processing batches", leave=False + chunks(list(zip(results, selected_indices_batch)), batch_size), + desc="Processing batches", + leave=False, ): completed_prompts = list( - executor.map( - lambda req: self.create_prompt(req[0], req[1]), - zip(batch, selected_indices_batch), - ) + executor.map(lambda req: self.create_prompt(req[0], req[1]), batch) ) all_completed_prompts.extend(completed_prompts) return all_completed_prompts diff --git a/src/rank_llm/rerank/listwise/reorder/reorder_policy.py b/src/rank_llm/rerank/listwise/reorder/reorder_policy.py index c600e06f..f9549115 100644 --- a/src/rank_llm/rerank/listwise/reorder/reorder_policy.py +++ b/src/rank_llm/rerank/listwise/reorder/reorder_policy.py @@ -107,6 +107,7 @@ def 
__init__( **kwargs, ): self._step_size = step + self.coll = 0 def reorder( self, @@ -142,14 +143,16 @@ def reorder( # logger.info(f"start_pos: {start_pos}, end_pos: {end_pos}") start_pos = max(start_pos, rank_start) - indices_working_on = [*range(start_pos, end_pos)] + indices_working_on = list(range(start_pos, end_pos)) prompts = model.create_prompt( [ (request, [request_rank[i] for i in indices_working_on]) for request, request_rank in zip(requests, request_ranks) ] ) - orders = model.execute(prompts, [indices_working_on] * len(requests)) + orders = model.execute( + prompts, [list(indices_working_on) for _ in requests] + ) for request_rank, order in zip(request_ranks, orders): self._reorder_by_rank(request_rank, indices_working_on, order) From 406fca731e44d93a4311a34f9e808e429b897f6f Mon Sep 17 00:00:00 2001 From: Yidi Chen Date: Wed, 25 Sep 2024 04:30:49 +0000 Subject: [PATCH 25/30] Updated for removing use_tqdm=False --- src/rank_llm/rerank/listwise/rank_listwise_os_llm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rank_llm/rerank/listwise/rank_listwise_os_llm.py b/src/rank_llm/rerank/listwise/rank_listwise_os_llm.py index a4e9044b..78d91f6f 100644 --- a/src/rank_llm/rerank/listwise/rank_listwise_os_llm.py +++ b/src/rank_llm/rerank/listwise/rank_listwise_os_llm.py @@ -171,7 +171,7 @@ def run_llm( # gen_cfg.temperature = 0 gen_cfg.do_sample = False output_ids = self._llm.generate( - **inputs, use_tqdm=not silence, generation_config=gen_cfg + **inputs, generation_config=gen_cfg ) if self._llm.config.is_encoder_decoder: From 9e7f846febc5ed7ff5d2ca32d41b107e6b59fcfc Mon Sep 17 00:00:00 2001 From: Yidi Chen Date: Fri, 27 Sep 2024 01:01:16 +0000 Subject: [PATCH 26/30] Update parallel for topdown --- .../reorder/top_down_reorder_policy.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/src/rank_llm/rerank/listwise/reorder/top_down_reorder_policy.py b/src/rank_llm/rerank/listwise/reorder/top_down_reorder_policy.py index 333f33dc..f13df2c8 100644 --- a/src/rank_llm/rerank/listwise/reorder/top_down_reorder_policy.py +++ b/src/rank_llm/rerank/listwise/reorder/top_down_reorder_policy.py @@ -81,7 +81,7 @@ def perform(self): # base base = indices[: min(window_size, len(indices))] request = ReorderRequest(self._pad(base), None) - yield request + yield [request] base = [base[i] for i in self._unpad(base, request.result)] if len(base) < window_size: @@ -95,11 +95,20 @@ def perform(self): for i in range(pivot - 1): result.append(base[i]) + requests = [] + req_inds = [] + # then sort others for i in range(window_size, len(indices), window_size - 1): request_indices = indices[i : i + window_size - 1] + [piv_item] + req_inds.append(request_indices) request = ReorderRequest(self._pad(request_indices), None) - yield request + requests.append(request) + + yield requests + + for request, request_indices, i \ + in zip(requests, req_inds, range(window_size, len(indices), window_size - 1)): request_indices = [ request_indices[i] for i in self._unpad(request_indices, request.result) @@ -121,7 +130,7 @@ def perform(self): # here len(indices) == top_k request_indices = indices request = ReorderRequest(self._pad(request_indices), None) - yield request + yield [request] indices = [ request_indices[i] for i in self._unpad(request_indices, request.result) ] @@ -152,8 +161,8 @@ def multiple_sort( finish_requests = [] for idx in left_not_sorted: try: - req = next(progress[idx]) - perm_request.append((idx, req)) + reqs = next(progress[idx]) + 
perm_request.extend([(idx, req) for req in reqs]) except StopIteration as e: result[idx] = e.value finish_requests.append(idx) From 329d6eb4dae6b1eef3e8a2d08afa6dd5598bda7c Mon Sep 17 00:00:00 2001 From: Yidi Chen Date: Fri, 27 Sep 2024 01:04:28 +0000 Subject: [PATCH 27/30] Reformat --- .../rerank/listwise/reorder/top_down_reorder_policy.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/rank_llm/rerank/listwise/reorder/top_down_reorder_policy.py b/src/rank_llm/rerank/listwise/reorder/top_down_reorder_policy.py index f13df2c8..4b8709f6 100644 --- a/src/rank_llm/rerank/listwise/reorder/top_down_reorder_policy.py +++ b/src/rank_llm/rerank/listwise/reorder/top_down_reorder_policy.py @@ -107,8 +107,11 @@ def perform(self): yield requests - for request, request_indices, i \ - in zip(requests, req_inds, range(window_size, len(indices), window_size - 1)): + for request, request_indices, i in zip( + requests, + req_inds, + range(window_size, len(indices), window_size - 1), + ): request_indices = [ request_indices[i] for i in self._unpad(request_indices, request.result) From 4f561e8201312aa1b305c647548ebb167b214512 Mon Sep 17 00:00:00 2001 From: Yidi Chen Date: Thu, 3 Oct 2024 00:03:07 +0000 Subject: [PATCH 28/30] Added vllm_chunked_prefill --- src/rank_llm/rerank/listwise/rank_listwise_os_llm.py | 7 ++++--- src/rank_llm/rerank/reranker.py | 3 +++ src/rank_llm/scripts/run_rank_llm.py | 7 +++++++ 3 files changed, 14 insertions(+), 3 deletions(-) diff --git a/src/rank_llm/rerank/listwise/rank_listwise_os_llm.py b/src/rank_llm/rerank/listwise/rank_listwise_os_llm.py index 78d91f6f..4fea0e88 100644 --- a/src/rank_llm/rerank/listwise/rank_listwise_os_llm.py +++ b/src/rank_llm/rerank/listwise/rank_listwise_os_llm.py @@ -41,6 +41,7 @@ def __init__( variable_passages: bool = False, system_message: str = None, vllm_batched: bool = False, + vllm_chunked_prefill: bool = False ) -> None: """ Creates instance of the RankListwiseOSLLM class, an extension of RankLLM designed for performing listwise ranking of passages using a specified language model. Advanced configurations are supported such as GPU acceleration, variable passage handling, and custom system messages for generating prompts. 
@@ -106,6 +107,8 @@ def __init__( model, download_dir=os.getenv("HF_HOME"), enforce_eager=False, + enable_chunked_prefill=vllm_chunked_prefill, + disable_sliding_window=vllm_chunked_prefill ) self._tokenizer = self._llm.get_tokenizer() else: @@ -170,9 +173,7 @@ def run_llm( gen_cfg.min_new_tokens = self.num_output_tokens(current_window_size) # gen_cfg.temperature = 0 gen_cfg.do_sample = False - output_ids = self._llm.generate( - **inputs, generation_config=gen_cfg - ) + output_ids = self._llm.generate(**inputs, generation_config=gen_cfg) if self._llm.config.is_encoder_decoder: output_ids = output_ids[0] diff --git a/src/rank_llm/rerank/reranker.py b/src/rank_llm/rerank/reranker.py index e72e2019..9b9cbafb 100644 --- a/src/rank_llm/rerank/reranker.py +++ b/src/rank_llm/rerank/reranker.py @@ -227,6 +227,7 @@ def create_agent( ("window_size", 20), ("system_message", None), ("vllm_batched", False), + ("vllm_chunked_prefill", False) ] [ context_size, @@ -238,6 +239,7 @@ def create_agent( window_size, system_message, vllm_batched, + vllm_chunked_prefill ] = extract_kwargs(keys_and_defaults, **kwargs) agent = RankListwiseOSLLM( @@ -257,6 +259,7 @@ def create_agent( variable_passages=variable_passages, system_message=system_message, vllm_batched=vllm_batched, + vllm_chunked_prefill=vllm_chunked_prefill ) print(f"Completed loading {model_path}") diff --git a/src/rank_llm/scripts/run_rank_llm.py b/src/rank_llm/scripts/run_rank_llm.py index d570bf53..9d1da28a 100644 --- a/src/rank_llm/scripts/run_rank_llm.py +++ b/src/rank_llm/scripts/run_rank_llm.py @@ -37,6 +37,7 @@ def main(args): window_size = args.window_size system_message = args.system_message vllm_batched = args.vllm_batched + vllm_chunked_prefill = args.vllm_chunked_prefill batch_size = args.batch_size reorder_policy = args.reorder_policy silence = args.silence @@ -63,6 +64,7 @@ def main(args): window_size=window_size, system_message=system_message, vllm_batched=vllm_batched, + vllm_chunked_prefill=vllm_chunked_prefill, reorder_policy=reorder_policy, silence=silence, ) @@ -178,6 +180,11 @@ def main(args): action="store_true", help="whether to run the model in batches", ) + parser.add_argument( + "--vllm_chunked_prefill", + action="store_true", + help="whether to run the model in vllm chunked prefill. no function if vllm_batched is not on", + ) parser.add_argument( "--reorder_policy", default="sliding_window", From 8470d3d9f0592e519461020f76e494da4fce7d95 Mon Sep 17 00:00:00 2001 From: Yidi Chen Date: Sat, 19 Oct 2024 01:30:55 +0000 Subject: [PATCH 29/30] Updated README --- README.md | 2 -- 1 file changed, 2 deletions(-) diff --git a/README.md b/README.md index 55374d2e..46a639fa 100644 --- a/README.md +++ b/README.md @@ -203,8 +203,6 @@ If you use one of the monoT5 models please cite the following relevant paper: journal = {arXiv:2101.05667}, } ``` - - ## 🙏 Acknowledgments This research is supported in part by the Natural Sciences and Engineering Research Council (NSERC) of Canada. 
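For context on what PATCH 28 wires up: chunked prefill lets vLLM split a long prompt's prefill into smaller chunks that can be interleaved with decode steps, and the patch disables sliding-window attention alongside it because vLLM did not support the two together at the time. A minimal sketch of the resulting engine construction (the model id and token budget are assumptions; the real code derives the budget from `num_output_tokens`):

```python
import os

from vllm import LLM, SamplingParams

llm = LLM(
    "castorini/rank_zephyr_7b_v1_full",  # assumed model id
    download_dir=os.getenv("HF_HOME"),
    enforce_eager=False,
    enable_chunked_prefill=True,  # what --vllm_chunked_prefill toggles
    disable_sliding_window=True,  # required alongside chunked prefill here
)

# Greedy decoding with a fixed output budget, mirroring run_llm_batched;
# 100 is a stand-in for self.num_output_tokens(current_window_size).
sampling_params = SamplingParams(temperature=0.0, max_tokens=100, min_tokens=100)
# outputs = llm.generate(prompts, sampling_params, use_tqdm=False)
```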
From 9caa5fb64c7b91021b9e2156b99f1787404ec64b Mon Sep 17 00:00:00 2001 From: Yidi Chen Date: Sat, 19 Oct 2024 04:04:55 +0000 Subject: [PATCH 30/30] Put back step size for maintaining backward compatibility --- src/rank_llm/rerank/listwise/listwise_rankllm.py | 2 +- .../rerank/listwise/reorder/reorder_policy.py | 11 +++++++++-- src/rank_llm/rerank/reranker.py | 6 +++--- src/rank_llm/scripts/run_rank_llm.py | 6 ++++-- 4 files changed, 17 insertions(+), 8 deletions(-) diff --git a/src/rank_llm/rerank/listwise/listwise_rankllm.py b/src/rank_llm/rerank/listwise/listwise_rankllm.py index ae092327..196f7268 100644 --- a/src/rank_llm/rerank/listwise/listwise_rankllm.py +++ b/src/rank_llm/rerank/listwise/listwise_rankllm.py @@ -630,6 +630,6 @@ def get_reorder_policy(reorder_policy: str, **kwargs): except Exception as e: print(e) raise Exception(f"Cannot load reorder policy {reorder_policy}") - return policy(**reorder_param_dict) + return policy(**reorder_param_dict, extra_args=dict(**kwargs)) raise Exception(f"Cannot find reorder policy {reorder_policy}") diff --git a/src/rank_llm/rerank/listwise/reorder/reorder_policy.py b/src/rank_llm/rerank/listwise/reorder/reorder_policy.py index f9549115..f888e4ed 100644 --- a/src/rank_llm/rerank/listwise/reorder/reorder_policy.py +++ b/src/rank_llm/rerank/listwise/reorder/reorder_policy.py @@ -103,10 +103,17 @@ def _reorder_by_rank(items: List[T], idxes: List[int], rank: List[int]) -> List[ class SlidingWindowReorderPolicy(ReorderPolicy): def __init__( self, - step: int = 10, + step: int = None, + extra_args: dict = None, **kwargs, ): - self._step_size = step + self._step_size = ( + step + if step is not None + else extra_args.get("step_size", 10) + if extra_args is not None + else 10 + ) self.coll = 0 def reorder( diff --git a/src/rank_llm/rerank/reranker.py b/src/rank_llm/rerank/reranker.py index 9b9cbafb..e785388a 100644 --- a/src/rank_llm/rerank/reranker.py +++ b/src/rank_llm/rerank/reranker.py @@ -227,7 +227,7 @@ def create_agent( ("window_size", 20), ("system_message", None), ("vllm_batched", False), - ("vllm_chunked_prefill", False) + ("vllm_chunked_prefill", False), ] [ context_size, @@ -239,7 +239,7 @@ def create_agent( window_size, system_message, vllm_batched, - vllm_chunked_prefill + vllm_chunked_prefill, ] = extract_kwargs(keys_and_defaults, **kwargs) agent = RankListwiseOSLLM( @@ -259,7 +259,7 @@ def create_agent( variable_passages=variable_passages, system_message=system_message, vllm_batched=vllm_batched, - vllm_chunked_prefill=vllm_chunked_prefill + vllm_chunked_prefill=vllm_chunked_prefill, ) print(f"Completed loading {model_path}") diff --git a/src/rank_llm/scripts/run_rank_llm.py b/src/rank_llm/scripts/run_rank_llm.py index 9d1da28a..2da15f3d 100644 --- a/src/rank_llm/scripts/run_rank_llm.py +++ b/src/rank_llm/scripts/run_rank_llm.py @@ -35,6 +35,7 @@ def main(args): retrieval_mode = RetrievalMode.DATASET num_passes = args.num_passes window_size = args.window_size + step_size = args.step_size system_message = args.system_message vllm_batched = args.vllm_batched vllm_chunked_prefill = args.vllm_chunked_prefill @@ -62,6 +63,7 @@ def main(args): variable_passages=variable_passages, num_passes=num_passes, window_size=window_size, + step_size=step_size, system_message=system_message, vllm_batched=vllm_batched, vllm_chunked_prefill=vllm_chunked_prefill, @@ -166,8 +168,8 @@ def main(args): parser.add_argument( "--step_size", type=int, - default=20, - help="window size for the LLM", + default=10, + help="step size for the sliding window approach", 
) parser.add_argument( "--system_message",