From 166f3ed2fa7142d79adb96d3ab8e73f391ce9c0d Mon Sep 17 00:00:00 2001
From: Rushabh Solanki
Date: Mon, 4 Mar 2024 00:10:32 +0000
Subject: [PATCH 1/3] Add script for AWQ-quantizing a model

---
 save_awq.py                                 | 99 +++++++++++++++++++++
 src/rank_llm/rerank/rank_listwise_os_llm.py | 16 +++-
 2 files changed, 114 insertions(+), 1 deletion(-)
 create mode 100644 save_awq.py

diff --git a/save_awq.py b/save_awq.py
new file mode 100644
index 00000000..74e28636
--- /dev/null
+++ b/save_awq.py
@@ -0,0 +1,99 @@
+"""Converts and stores an AWQ-quantized model."""
+
+import argparse
+import json
+import logging
+
+import awq
+import transformers
+
+QUANT_CONFIG = {
+    "zero_point": True,
+    "q_group_size": 128,
+    "w_bit": 4,
+    "version": "GEMM",
+}
+
+
+def parse_args():
+    """Parses command line arguments."""
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--dataset",
+        type=str,
+        default="msp_open_ai_ada2_random_s5000_gpt4_da0_mr20_sampled_mix.jsonl",
+        help="Path to the calibration dataset.",
+    )
+    parser.add_argument(
+        "--model_path",
+        type=str,
+        default="castorini/rank_zephyr_7b_v1_full",
+        help="Path/slug to the original model.",
+    )
+    parser.add_argument(
+        "--quant_path",
+        type=str,
+        default="awq_rank_zephyr_7b_v1_full",
+        help="Path/slug where the quantized model is to be stored.")
+    args = parser.parse_args()
+    return args
+
+
+def load_dataset(dataset: str):
+    """Returns list of prompts for given dataset."""
+    with open(dataset, "r") as file:
+        data = [json.loads(line) for line in file]
+    prompts = []
+    for content in data:
+        content = content["conversations"]
+        prompt = ""
+        for prompt_dict in content:
+            if prompt_dict["from"] == "system":
+                prompt += prompt_dict["value"] + "\n"
+        for prompt_dict in content:
+            if prompt_dict["from"] == "human":
+                prompt += prompt_dict["value"] + "\n"
+        for prompt_dict in content:
+            if prompt_dict["from"] == "gpt":
+                prompt += prompt_dict["value"]
+        prompts.append(prompt)
+    return prompts
+
+
+def main():
+    """Entry point of the script."""
+    args = parse_args()
+    model_path = args.model_path
+    quant_path = args.quant_path
+    dataset = args.dataset
+
+    # Load model
+    logging.info(f"Loading model from {model_path}.")
+    model = awq.AutoAWQForCausalLM.from_pretrained(model_path)
+    tokenizer = transformers.AutoTokenizer.from_pretrained(
+        model_path, trust_remote_code=True)
+
+    logging.info(f"Starting AWQ with data {dataset}.")
+    model.quantize(
+        tokenizer=tokenizer,
+        quant_config=QUANT_CONFIG,
+        calib_data=load_dataset(dataset=dataset),
+    )
+
+    # Convert the quant config into the format expected by transformers.
+
+    quantization_config = transformers.AwqConfig(
+        bits=QUANT_CONFIG["w_bit"],
+        group_size=QUANT_CONFIG["q_group_size"],
+        zero_point=QUANT_CONFIG["zero_point"],
+        version=QUANT_CONFIG["version"].lower(),
+    ).to_dict()
+    model.model.config.quantization_config = quantization_config
+
+    logging.info(f"Saving quantized model at {quant_path}.")
+    model.save_quantized(quant_path)
+    tokenizer.save_pretrained(quant_path)
+
+
+if __name__ == "__main__":
+    logging.basicConfig(level=logging.INFO)
+    main()
diff --git a/src/rank_llm/rerank/rank_listwise_os_llm.py b/src/rank_llm/rerank/rank_listwise_os_llm.py
index 7eac11c6..01d87885 100644
--- a/src/rank_llm/rerank/rank_listwise_os_llm.py
+++ b/src/rank_llm/rerank/rank_listwise_os_llm.py
@@ -2,9 +2,11 @@
 import random
 from typing import Optional, Tuple
 
+from awq import AutoAWQForCausalLM
 import torch
 from fastchat.model import get_conversation_template, load_model
 from ftfy import fix_text
+from transformers import AutoTokenizer
 from transformers.generation import GenerationConfig
 
 from rank_llm.rerank.rankllm import PromptMode, RankLLM
@@ -62,7 +64,19 @@ def __init__(
             f"Unsupported prompt mode: {prompt_mode}. The only prompt mode currently supported is a slight variation of Rank_GPT prompt."
         )
         # ToDo: Make repetition_penalty configurable
-        self._llm, self._tokenizer = load_model(model, device=device, num_gpus=num_gpus)
+        if "awq" in model:
+            self._llm = AutoAWQForCausalLM.from_quantized(
+                model,
+                fuse_layers=True,
+                max_seq_len=context_size,
+            ).to(0)
+            self._tokenizer = AutoTokenizer.from_pretrained(model)
+        else:
+            self._llm, self._tokenizer = load_model(
+                model,
+                device=device,
+                num_gpus=num_gpus,
+            )
         self._variable_passages = variable_passages
         self._window_size = window_size
         self._system_message = system_message

From eeff31701a3739c0a39e35b0fe3da46ebe289366 Mon Sep 17 00:00:00 2001
From: Rushabh Solanki
Date: Mon, 4 Mar 2024 00:15:11 +0000
Subject: [PATCH 2/3] Refactor formatting and import order

---
 save_awq.py                                 | 6 ++++--
 src/rank_llm/rerank/rank_listwise_os_llm.py | 2 +-
 2 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/save_awq.py b/save_awq.py
index 74e28636..d3ababbd 100644
--- a/save_awq.py
+++ b/save_awq.py
@@ -34,7 +34,8 @@ def parse_args():
         "--quant_path",
         type=str,
         default="awq_rank_zephyr_7b_v1_full",
-        help="Path/slug where the quantized model is to be stored.")
+        help="Path/slug where the quantized model is to be stored.",
+    )
     args = parser.parse_args()
     return args
 
@@ -71,7 +72,8 @@ def main():
     logging.info(f"Loading model from {model_path}.")
     model = awq.AutoAWQForCausalLM.from_pretrained(model_path)
     tokenizer = transformers.AutoTokenizer.from_pretrained(
-        model_path, trust_remote_code=True)
+        model_path, trust_remote_code=True
+    )
 
     logging.info(f"Starting AWQ with data {dataset}.")
     model.quantize(
diff --git a/src/rank_llm/rerank/rank_listwise_os_llm.py b/src/rank_llm/rerank/rank_listwise_os_llm.py
index 01d87885..7af9aa01 100644
--- a/src/rank_llm/rerank/rank_listwise_os_llm.py
+++ b/src/rank_llm/rerank/rank_listwise_os_llm.py
@@ -2,8 +2,8 @@
 import random
 from typing import Optional, Tuple
 
-from awq import AutoAWQForCausalLM
 import torch
+from awq import AutoAWQForCausalLM
 from fastchat.model import get_conversation_template, load_model
 from ftfy import fix_text
 from transformers import AutoTokenizer

From b5d92f4164f14c950708b1c985e25482fe20e36a Mon Sep 17 00:00:00 2001
From: Rushabh Solanki
Date: Thu, 4 Apr 2024 19:03:53 +0000
Subject: [PATCH 3/3] Moved file to scripts folder

---
 requirements.txt                                |  1 +
 save_awq.py => src/rank_llm/scripts/save_awq.py | 12 +++---------
 2 files changed, 4 insertions(+), 9 deletions(-)
 rename save_awq.py => src/rank_llm/scripts/save_awq.py (85%)

diff --git a/requirements.txt b/requirements.txt
index a5052abd..2957ff21 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -7,3 +7,4 @@ python-dotenv>=1.0.1
 faiss-cpu>=1.7.2
 ftfy>=6.1.3
 fschat[model_worker]>=0.2.35
+autoawq==0.2.4
diff --git a/save_awq.py b/src/rank_llm/scripts/save_awq.py
similarity index 85%
rename from save_awq.py
rename to src/rank_llm/scripts/save_awq.py
index d3ababbd..2fa52393 100644
--- a/save_awq.py
+++ b/src/rank_llm/scripts/save_awq.py
@@ -47,16 +47,10 @@ def load_dataset(dataset: str):
     prompts = []
     for content in data:
         content = content["conversations"]
+        prompt_dict = {pmt["from"]: pmt["value"] for pmt in content}
         prompt = ""
-        for prompt_dict in content:
-            if prompt_dict["from"] == "system":
-                prompt += prompt_dict["value"] + "\n"
-        for prompt_dict in content:
-            if prompt_dict["from"] == "human":
-                prompt += prompt_dict["value"] + "\n"
-        for prompt_dict in content:
-            if prompt_dict["from"] == "gpt":
-                prompt += prompt_dict["value"]
+        for key in ["system", "human", "gpt"]:
+            prompt += prompt_dict.get(key, "") + "\n"
         prompts.append(prompt)
     return prompts
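
For context on the load_dataset refactor in PATCH 3/3: each record in the calibration file is a ShareGPT-style conversation, which the function flattens into one prompt string (system, then human, then gpt turns). A minimal sketch of what one record produces; the record below is invented, but the field names ("conversations", "from", "value") and the flattening logic come straight from the patch:

import json

line = (
    '{"conversations": ['
    '{"from": "system", "value": "You are RankLLM."}, '
    '{"from": "human", "value": "Query: ... [1] ... [2] ..."}, '
    '{"from": "gpt", "value": "[2] > [1]"}]}'
)
record = json.loads(line)

# Keep one message per role, then concatenate in a fixed role order.
prompt_dict = {pmt["from"]: pmt["value"] for pmt in record["conversations"]}
prompt = ""
for key in ["system", "human", "gpt"]:
    prompt += prompt_dict.get(key, "") + "\n"
print(prompt)  # system turn, human turn, gpt turn, newline-separated

Note one behavioral consequence of the dict comprehension: if a role appears more than once in a conversation, only its last message survives, whereas the loops it replaces concatenated all of them.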
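The transformers.AwqConfig conversion in save_awq.py is what lets the saved checkpoint advertise its quantization scheme in the key format transformers expects, so the artifact can be loaded by plain transformers as well as by autoawq. A smoke-test sketch, assuming a CUDA device is available and that quant_path is the script's default output directory:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

quant_path = "awq_rank_zephyr_7b_v1_full"  # default --quant_path of save_awq.py

tokenizer = AutoTokenizer.from_pretrained(quant_path)
# transformers reads quantization_config from the saved model config and
# dispatches to the AWQ kernels provided by the installed autoawq package.
model = AutoModelForCausalLM.from_pretrained(
    quant_path, torch_dtype=torch.float16, device_map="cuda:0"
)

inputs = tokenizer("Rank the following passages:", return_tensors="pt").to("cuda:0")
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=16)[0]))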
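Finally, the branch added to rank_listwise_os_llm.py keys off the literal substring "awq" in the model path, so a quantized checkpoint must be stored or published under a name containing it. A standalone sketch of that loading path, using the same autoawq 0.2.x arguments as the patch (the local path is hypothetical, and a GPU is assumed since the model is pinned to device 0):

from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = "awq_rank_zephyr_7b_v1_full"  # must contain "awq" to take this branch
context_size = 4096  # assumed value; the reranker passes its configured context size

# fuse_layers enables autoawq's fused attention/MLP kernels for faster
# inference; .to(0) moves the model onto GPU 0, as in the patch.
llm = AutoAWQForCausalLM.from_quantized(
    model_path,
    fuse_layers=True,
    max_seq_len=context_size,
).to(0)
tokenizer = AutoTokenizer.from_pretrained(model_path)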