Reorder #143

Open · wants to merge 35 commits into base: main
Changes from 26 commits

Commits (35 total)
60d5e7f
First step update - during create_prompt, use selected_index to subst…
XKTZ Aug 13, 2024
74118c4
Allow selected index to be variable on each step
XKTZ Aug 13, 2024
36bbcdc
Added basic _get_model_function - it is wrong now, we want to return …
XKTZ Aug 13, 2024
6d2094a
Renamed indexes to indices, [indices] is renamed to indices_batch
XKTZ Aug 14, 2024
cc78522
Renamed ReorderExecutor to Reorder Policy
XKTZ Aug 14, 2024
4936c8c
Moved LiT5Distill to policy
XKTZ Aug 15, 2024
95f065a
Transition rank listwise os llm to using reorder policy
XKTZ Aug 18, 2024
438c8bd
Added reorder policy for rank listwise os and fid score
XKTZ Aug 18, 2024
2aa8644
Revised bug in OS LLM, add Rank GPT, deprecated old functions
XKTZ Aug 18, 2024
7a19b87
Finish the tournament sort node
XKTZ Aug 25, 2024
c8734f0
Finish tournament sort
XKTZ Aug 26, 2024
7981c06
Finish reorganize of parameters, move window_size to ListwiseRankLLM
XKTZ Aug 30, 2024
bbbdc31
Added window size back
XKTZ Sep 5, 2024
bb0ad2b
Fix Rerankers
XKTZ Sep 5, 2024
62785bc
Merge pull request #4 from castorini/main
XKTZ Sep 5, 2024
9b1972a
Merge branch 'reorder' into main
XKTZ Sep 5, 2024
06b9f1f
Merge pull request #5 from XKTZ/main
XKTZ Sep 5, 2024
28e3a23
Added r parameter
XKTZ Sep 5, 2024
0e9edce
Some bug fix
XKTZ Sep 6, 2024
3005455
Reformatted
XKTZ Sep 6, 2024
3795308
Added top down
XKTZ Sep 12, 2024
edd1ff9
Merge pull request #6 from castorini/main
XKTZ Sep 12, 2024
89bab7c
Merge pull request #7 from XKTZ/main
XKTZ Sep 12, 2024
65d3508
Added final some stuff
XKTZ Sep 15, 2024
69e3d65
Updated readme
XKTZ Sep 15, 2024
8f4f729
Clean the README
XKTZ Sep 15, 2024
cc22c95
Changed filename
XKTZ Sep 16, 2024
26beeed
Updated silence
XKTZ Sep 16, 2024
36beed1
Fixed bug when batch size > 32, which is in rank_listwise_os_llm's cr…
XKTZ Sep 20, 2024
406fca7
Updated for removing use_tqdm=False
XKTZ Sep 25, 2024
9e7f846
Update parallel for topdown
XKTZ Sep 27, 2024
329d6eb
Reformat
XKTZ Sep 27, 2024
4f561e8
Added vllm_chunked_prefill
XKTZ Oct 3, 2024
8470d3d
Updated README
XKTZ Oct 19, 2024
9caa5fb
Put back step size for maintaining backward compatibility
XKTZ Oct 19, 2024
2 changes: 2 additions & 0 deletions README.md
@@ -203,6 +203,8 @@ If you use one of the monoT5 models please cite the following relevant paper:
journal = {arXiv:2101.05667},
}
```


## 🙏 Acknowledgments

This research is supported in part by the Natural Sciences and Engineering Research Council (NSERC) of Canada.
204 changes: 198 additions & 6 deletions src/rank_llm/rerank/listwise/listwise_rankllm.py
@@ -1,19 +1,44 @@
import copy
import json
import logging
import random
import re
from abc import ABC
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Dict, List, Tuple
from typing import Any, Dict, List, Tuple, Union

from ftfy import fix_text
from tqdm import tqdm

from rank_llm.data import RankingExecInfo, Request, Result
from rank_llm.rerank import PromptMode, RankLLM
from rank_llm.rerank.listwise.reorder.reorder_policy import (
ModelFunction,
ReorderPolicy,
SlidingWindowReorderPolicy,
)
from rank_llm.rerank.listwise.reorder.top_down_reorder_policy import (
TopDownReorderPolicy,
)
from rank_llm.rerank.listwise.reorder.tournament_sort_reorder_policy import (
TournamentSortReorderPolicy,
)

logger = logging.getLogger(__name__)

SUPPORT_REORDER_POLICIES = [
SlidingWindowReorderPolicy,
TournamentSortReorderPolicy,
TopDownReorderPolicy,
]


@dataclass
class RerankConsumption:
consumption_reference_by_batch: int
consumption_reference_by_item: int


class ListwiseRankLLM(RankLLM, ABC):
"""
@@ -30,16 +55,77 @@ class ListwiseRankLLM(RankLLM, ABC):

def __init__(
self,
reorder_policy: ReorderPolicy,
model: str,
context_size: int,
window_size: int,
prompt_mode: PromptMode,
num_few_shot_examples: int,
window_size: int,
) -> None:
super().__init__(model, context_size, prompt_mode)
self._num_few_shot_examples = num_few_shot_examples

self.reorder_policy = (
SlidingWindowReorderPolicy() if reorder_policy is None else reorder_policy
)
self._window_size = window_size

def rerank_batch(
self,
requests: List[Request],
rank_start: int = 0,
rank_end: int = 100,
shuffle_candidates: bool = False,
logging: bool = False,
batched: bool = False,
**kwargs: Any,
) -> List[Result]:
populate_exec_summary: bool = kwargs.get("populate_exec_summary", False)

batch_size = kwargs.get("batch_size", 1)

if not batched:
batch_size = 1

reorder_policy = self.reorder_policy
model_functions, consumption = self._get_model_function(batched)

# reranking using vllm
if len(set([len(req.candidates) for req in requests])) != 1:
raise ValueError("Batched requests must have the same number of candidates")

result: list[Result] = []

with tqdm(range(0, len(requests)), leave=False) as bar:
for i in range(0, len(requests), batch_size):
batch = requests[i : min(i + batch_size, len(requests))]
batch_result = reorder_policy.reorder(
requests=[
Result(
query=copy.deepcopy(request.query),
candidates=copy.deepcopy(request.candidates),
ranking_exec_summary=[],
)
for request in batch
],
rank_start=max(rank_start, 0),
rank_end=min(
rank_end, len(requests[0].candidates)
), # TODO: Fails arbitrary hit sizes
model=model_functions,
shuffle_candidates=shuffle_candidates,
logging=logging,
populate_exec_summary=populate_exec_summary,
)
result.extend(batch_result)
bar.update(len(batch))

logger.info(
f"Average consumption per request: {consumption.consumption_reference_by_item / len(requests) : .2f}"
)

return result

def get_output_filename(
self,
top_k_candidates: int,
@@ -72,6 +158,7 @@ def max_tokens(self) -> int:
"""
return self._context_size

# @deprecated("old sliding window pipeline is deprecated. please use reorder policy")
def permutation_pipeline_batched(
self,
results: List[Result],
@@ -95,7 +182,9 @@ def permutation_pipeline_batched(
prompts = []
logger.info("Loading prompts.")
prompts = self.create_prompt_batched(
results, rank_start, rank_end, batch_size=32
results,
[list(range(rank_start, rank_end)) for _ in range(len(results))],
batch_size=32,
)
if logging:
for prompt in prompts:
@@ -122,6 +211,7 @@ def permutation_pipeline_batched(

return results

# @deprecated("old sliding window pipeline is deprecated. please use reorder policy")
def permutation_pipeline(
self,
result: Result,
@@ -142,7 +232,9 @@ def permutation_pipeline(
Returns:
Result: The processed result object after applying permutation.
"""
prompt, in_token_count = self.create_prompt(result, rank_start, rank_end)
prompt, in_token_count = self.create_prompt(
result, list(range(rank_start, rank_end))
)
if logging:
logger.info(f"Prompt: {prompt}\n")
permutation, out_token_count = self.run_llm(
@@ -180,6 +272,7 @@ def shuffle_and_rescore(
cand["score"] = 1.0 / (i + 1)
cand["rank"] = i + 1

# @deprecated("old sliding window pipeline is deprecated. please use reorder policy")
def sliding_windows_batched(
self,
requests: List[Request],
@@ -230,6 +323,7 @@ def sliding_windows_batched(
start_pos = start_pos - step
return rerank_results

# @deprecated("old sliding window pipeline is deprecated. please use reorder policy")
def sliding_windows(
self,
request: Request,
@@ -338,7 +432,7 @@ def get_ranking_cost(
start_pos = rank_end - window_size
while start_pos >= rank_start:
start_pos = max(start_pos, rank_start)
prompt, _ = self.create_prompt(result, start_pos, end_pos)
prompt, _ = self.create_prompt(result, list(range(start_pos, end_pos)))
input_token_count += self.get_num_tokens(prompt)
end_pos = end_pos - step
start_pos = start_pos - step
Expand All @@ -353,7 +447,8 @@ def _clean_response(self, response: str) -> str:
new_response = ""
for c in response:
if not c.isdigit():
new_response += " "
if len(new_response) == 0 or new_response[-1] != " ":
new_response += " "
else:
new_response += c
new_response = new_response.strip()
@@ -439,3 +534,100 @@ def convert_doc_to_prompt_content(
# For Japanese should cut by character: content = content[:int(max_length)]
content = " ".join(content.split()[: int(max_length)])
return self._replace_number(content)

def _permutation_to_rank(self, perm_string: str, selected_indices: List[int]):
perm = [
int(x) - 1 for x in self._clean_response(perm_string).strip().split(" ")
]
perm = [
int(x)
for x in self._remove_duplicate(perm)
if 0 <= int(x) < len(selected_indices)
]
perm = perm + [i for i in range(len(selected_indices)) if i not in perm]
return perm

def _get_model_function(
self, batched: bool = False, **kwargs
) -> Tuple[ModelFunction, RerankConsumption]:
# [(Request, SelectIndex)] -> [Prompt]

consumption = RerankConsumption(0, 0)

if batched:

def create_prompt(batch: List[Tuple[Result, List[int]]]):
return [
prompt
for prompt, _ in self.create_prompt_batched(
[result for result, selected_indices in batch],
[selected_indices for result, selected_indices in batch],
32,
)
]

def execute(
batch: List[Union[str, Dict[str, str]]],
selected_indices_batch: List[List[int]],
):
consumption.consumption_reference_by_batch += 1
consumption.consumption_reference_by_item += len(batch)

return [
self._permutation_to_rank(s, selected_indices)
for (s, _), selected_indices in zip(
self.run_llm_batched(batch, **kwargs), selected_indices_batch
)
]

else:

def create_prompt(batch: List[Tuple[Result, List[int]]]):
return [
self.create_prompt(result, selected_indices)[0]
for result, selected_indices in batch
]

def execute(
batch: List[Union[str, Dict[str, str]]],
selected_indices_batch: List[List[int]],
):
consumption.consumption_reference_by_batch += 1
consumption.consumption_reference_by_item += len(batch)

return [
self._permutation_to_rank(
self.run_llm(x, **kwargs)[0], selected_indices
)
for x, selected_indices in zip(batch, selected_indices_batch)
]

return (
ModelFunction(
create_prompt=create_prompt,
execute=execute,
window_size=self._window_size,
),
consumption,
)

@staticmethod
def get_reorder_policy(reorder_policy: str, **kwargs):
for policy in SUPPORT_REORDER_POLICIES:
if reorder_policy.startswith(policy.name()):
reorder_params = reorder_policy[len(policy.name()) :]
if len(reorder_params) <= 1:
return policy()
else:
assert reorder_params[0] == ":" and reorder_params[1] == "{"
reorder_params = reorder_params[1:]
try:
reorder_param_dict = json.loads(reorder_params)
if not isinstance(reorder_param_dict, dict):
raise Exception()
except Exception as e:
print(e)
raise Exception(f"Cannot load reorder policy {reorder_policy}")
return policy(**reorder_param_dict)

raise Exception(f"Cannot find reorder policy {reorder_policy}")
10 changes: 10 additions & 0 deletions src/rank_llm/rerank/listwise/lit5_reranker.py
@@ -1,5 +1,9 @@
from rank_llm.data import Request, Result
from rank_llm.rerank.listwise.rank_fid import RankFiDDistill, RankFiDScore
from rank_llm.rerank.listwise.reorder.reorder_policy import (
ReorderPolicy,
SlidingWindowReorderPolicy,
)
from rank_llm.rerank.rankllm import PromptMode
from rank_llm.rerank.reranker import Reranker

@@ -11,12 +15,16 @@ def __init__(
context_size: int = 300,
prompt_mode: PromptMode = PromptMode.LiT5,
window_size: int = 20,
reorder_policy: ReorderPolicy = None,
) -> None:
if reorder_policy is None:
reorder_policy = SlidingWindowReorderPolicy()
agent = RankFiDDistill(
model=model_path,
context_size=context_size,
prompt_mode=prompt_mode,
window_size=window_size,
reorder_policy=reorder_policy,
)
self._reranker = Reranker(agent)

@@ -62,13 +70,15 @@ def rerank(
class LiT5ScoreReranker:
def __init__(
self,
reorder_policy: ReorderPolicy,
model_path: str = "castorini/LiT5-Score-base",
context_size: int = 300,
prompt_mode: PromptMode = PromptMode.LiT5,
window_size: int = 20,
runfile_path: str = "runs/run.${topics}_${firststage}_${model//\//}",
) -> None:
agent = RankFiDScore(
reorder_policy=reorder_policy,
model=model_path,
context_size=context_size,
prompt_mode=prompt_mode,
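
To illustrate the new `reorder_policy` argument threaded through the LiT5 rerankers, here is a hedged construction sketch. The hunk above only shows part of the constructor, so the class name `LiT5DistillReranker` and the default checkpoint path are assumptions inferred from the surrounding file; treat them as illustrative only.

```python
# Hedged sketch: passing an explicit reorder policy to the LiT5 distill
# reranker. The class name and model path are assumed, not shown in this hunk.
from rank_llm.rerank.listwise.lit5_reranker import LiT5DistillReranker
from rank_llm.rerank.listwise.reorder.reorder_policy import SlidingWindowReorderPolicy

reranker = LiT5DistillReranker(
    model_path="castorini/LiT5-Distill-base",  # assumed default checkpoint
    window_size=20,
    reorder_policy=SlidingWindowReorderPolicy(),  # same behavior as when omitted
)
```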