diff --git a/requirements.txt b/requirements.txt index a5052abd..2957ff21 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,3 +7,4 @@ python-dotenv>=1.0.1 faiss-cpu>=1.7.2 ftfy>=6.1.3 fschat[model_worker]>=0.2.35 +autoawq==0.2.4 diff --git a/src/rank_llm/rerank/rank_listwise_os_llm.py b/src/rank_llm/rerank/rank_listwise_os_llm.py index 7eac11c6..7af9aa01 100644 --- a/src/rank_llm/rerank/rank_listwise_os_llm.py +++ b/src/rank_llm/rerank/rank_listwise_os_llm.py @@ -3,8 +3,10 @@ from typing import Optional, Tuple import torch +from awq import AutoAWQForCausalLM from fastchat.model import get_conversation_template, load_model from ftfy import fix_text +from transformers import AutoTokenizer from transformers.generation import GenerationConfig from rank_llm.rerank.rankllm import PromptMode, RankLLM @@ -62,7 +64,19 @@ def __init__( f"Unsupported prompt mode: {prompt_mode}. The only prompt mode currently supported is a slight variation of Rank_GPT prompt." ) # ToDo: Make repetition_penalty configurable - self._llm, self._tokenizer = load_model(model, device=device, num_gpus=num_gpus) + if "awq" in model: + self._llm = AutoAWQForCausalLM.from_quantized( + model, + fuse_layers=True, + max_seq_len=context_size, + ).to(0) + self._tokenizer = AutoTokenizer.from_pretrained(model) + else: + self._llm, self._tokenizer = load_model( + model, + device=device, + num_gpus=num_gpus, + ) self._variable_passages = variable_passages self._window_size = window_size self._system_message = system_message diff --git a/src/rank_llm/scripts/save_awq.py b/src/rank_llm/scripts/save_awq.py new file mode 100644 index 00000000..2fa52393 --- /dev/null +++ b/src/rank_llm/scripts/save_awq.py @@ -0,0 +1,95 @@ +"""Converts and store AWQ-quantized model.""" + +import argparse +import json +import logging + +import awq +import transformers + +QUANT_CONFIG = { + "zero_point": True, + "q_group_size": 128, + "w_bit": 4, + "version": "GEMM", +} + + +def parse_args(): + """Parses command line arguments.""" + parser = argparse.ArgumentParser() + parser.add_argument( + "--dataset", + type=str, + default="msp_open_ai_ada2_random_s5000_gpt4_da0_mr20_sampled_mix.jsonl", + help="Path to the calibration dataset.", + ) + parser.add_argument( + "--model_path", + type=str, + default="castorini/rank_zephyr_7b_v1_full", + help="Path/slug to the original model.", + ) + parser.add_argument( + "--quant_path", + type=str, + default="awq_rank_zephyr_7b_v1_full", + help="Path/slug where the quantized model is to be stored.", + ) + args = parser.parse_args() + return args + + +def load_dataset(dataset: str): + """Returns list of prompts for given dataset.""" + with open(dataset, "r") as file: + data = json.load(file) + prompts = [] + for content in data: + content = content["conversations"] + prompt_dict = {pmt["from"]: pmt["value"] for pmt in content} + prompt = "" + for key in ["system", "human", "gpt"]: + prompt += prompt_dict[key] + "\n" + prompts.append(prompt) + return prompts + + +def main(): + """Entry point of the script.""" + args = parse_args() + model_path = args.model_path + quant_path = args.quant_path + dataset = args.dataset + + # Load model + logging.info(f"Loading model from {model_path}.") + model = awq.AutoAWQForCausalLM.from_pretrained(model_path) + tokenizer = transformers.AutoTokenizer.from_pretrained( + model_path, trust_remote_code=True + ) + + logging.info(f"Starting AWQ with data {dataset}.") + model.quantize( + tokenizer=tokenizer, + quant_config=QUANT_CONFIG, + calib_data=load_dataset(dataset=dataset), + ) + + # Convert config into appropriate format. + quantization_config = transformers.AwqConfig( + bits=QUANT_CONFIG["w_bit"], + group_size=QUANT_CONFIG["q_group_size"], + zero_point=QUANT_CONFIG["zero_point"], + version=QUANT_CONFIG["version"].lower(), + ).to_dict() + model.model.config.quantization_config = quantization_config + + logging.info(f"Saving quantized model at {quant_path}.") + model.save_quantized(quant_path) + tokenizer.save_pretrained(quant_path) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO) + main()