-
Notifications
You must be signed in to change notification settings - Fork 273
/
inference-qlora.py
127 lines (109 loc) · 4.71 KB
/
inference-qlora.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import os
import sys
import math
import torch
import argparse
import textwrap
import transformers
from peft import PeftModel
from transformers import GenerationConfig, TextStreamer, BitsAndBytesConfig
from llama_attn_replace import replace_llama_attn
# Prompt templates keyed by name. Every template contains an "{instruction}"
# placeholder that callers fill with str.format / str.format_map.
PROMPT_DICT = dict(
    # Alpaca-style instruction prompt without an input section.
    prompt_no_input=(
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response:"
    ),
    # Llama-2 chat prompt with the default system message embedded.
    prompt_no_input_llama2=(
        "<s>[INST] <<SYS>>\n"
        "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\n"
        "If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n"
        "<</SYS>> \n\n {instruction} [/INST]"
    ),
    # Bare Llama-2 instruction wrapper (no system message).
    prompt_llama2="[INST]{instruction}[/INST]",
)
def _str2bool(value):
    """Convert a command-line string such as "True"/"false"/"1" to a bool.

    argparse with ``type=bool`` calls ``bool(str)``, which is True for every
    non-empty string (including "False"), so a dedicated converter is needed.

    Raises:
        argparse.ArgumentTypeError: If the string is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("1", "true", "t", "yes", "y", "on"):
        return True
    # "" maps to False, matching what bool("") returned previously.
    if lowered in ("0", "false", "f", "no", "n", "off", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_config(argv=None):
    """Parse the command-line options for this inference script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace with ``material``, ``question``, ``base_model``,
        ``cache_dir``, ``context_size``, ``flash_attn``, ``temperature``,
        ``top_p`` and ``max_gen_len``.
    """
    parser = argparse.ArgumentParser(description='arg parser')
    parser.add_argument('--material', type=str, default="",
                        help='path to a .txt file whose text is placed before the question')
    parser.add_argument('--question', type=str, default="",
                        help='question appended after the material')
    parser.add_argument('--base_model', type=str, default="/data1/pretrained-models/llama-7b-hf",
                        help='model path or hub id passed to from_pretrained')
    parser.add_argument('--cache_dir', type=str, default="./cache",
                        help='cache directory passed to from_pretrained')
    parser.add_argument('--context_size', type=int, default=-1, help='context size during fine-tuning')
    # Accepts "--flash_attn", "--flash_attn True" and "--flash_attn False";
    # the value is parsed properly instead of via bool(str).
    parser.add_argument('--flash_attn', type=_str2bool, nargs='?', const=True, default=False,
                        help='patch LLaMA attention via replace_llama_attn (true/false)')
    parser.add_argument('--temperature', type=float, default=0.6,
                        help='sampling temperature forwarded to generate()')
    parser.add_argument('--top_p', type=float, default=0.9,
                        help='top-p value forwarded to generate()')
    parser.add_argument('--max_gen_len', type=int, default=512,
                        help='maximum number of new tokens to generate')
    return parser.parse_args(argv)
def read_txt_file(material_txt):
    """Return the full text of a ``.txt`` file.

    Args:
        material_txt: Path to the file. Its extension must be ``.txt``
            (compared case-insensitively).

    Returns:
        The file contents as one string, decoded as UTF-8 in text mode.

    Raises:
        ValueError: If the path does not have a ``.txt`` extension.
        OSError: If the file cannot be opened.
    """
    extension = os.path.splitext(material_txt)[1].lower()
    if extension != ".txt":
        # Only plain-text material is supported (no PDF handling exists).
        raise ValueError(f"Only support txt file, got {material_txt!r}.")
    # Explicit UTF-8 so decoding does not depend on the platform locale; a
    # single read() replaces the quadratic line-by-line concatenation.
    with open(material_txt, encoding="utf-8") as f:
        return f.read()
def build_generator(
    model, tokenizer, temperature=0.6, top_p=0.9, max_gen_len=4096, use_cache=True
):
    """Build a closure that generates a completion for a single prompt.

    Args:
        model: Causal LM exposing ``generate`` and ``device``.
        tokenizer: Matching tokenizer, called with ``return_tensors="pt"``.
        temperature: Forwarded to ``model.generate``.
        top_p: Forwarded to ``model.generate``.
        max_gen_len: Forwarded as ``max_new_tokens``.
        use_cache: Forwarded to ``model.generate``.

    Returns:
        ``response(prompt) -> str``. It passes a ``TextStreamer`` to
        ``generate`` and returns only the newly generated text, stripped of
        surrounding whitespace and special tokens.
    """
    def response(prompt):
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        streamer = TextStreamer(tokenizer)
        output = model.generate(
            **inputs,
            max_new_tokens=max_gen_len,
            temperature=temperature,
            top_p=top_p,
            use_cache=use_cache,
            streamer=streamer,
        )
        # Decode only the tokens produced after the prompt. Splitting the
        # decoded text on the prompt was fragile: str.lstrip("<s>") strips any
        # of the characters '<', 's', '>' (not the "<s>" prefix), and the
        # detokenized prompt need not match the input text exactly, in which
        # case indexing the split result with [1] raised IndexError.
        prompt_len = inputs["input_ids"].shape[-1]
        out = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
        return out.strip()

    return response
def _load_quantized_model(args):
    """Load ``args.base_model`` in 4-bit NF4, applying linear RoPE scaling if needed.

    Returns:
        ``(model, orig_ctx_len)``, where ``orig_ctx_len`` is the config's
        ``max_position_embeddings`` (``None`` if the config has no such field).
    """
    config = transformers.AutoConfig.from_pretrained(
        args.base_model,
        cache_dir=args.cache_dir,
    )
    # Set RoPE scaling factor: when the requested context exceeds the
    # pretrained one, stretch positions linearly by the ratio rounded up.
    orig_ctx_len = getattr(config, "max_position_embeddings", None)
    if orig_ctx_len and args.context_size > orig_ctx_len:
        scaling_factor = float(math.ceil(args.context_size / orig_ctx_len))
        config.rope_scaling = {"type": "linear", "factor": scaling_factor}

    model = transformers.AutoModelForCausalLM.from_pretrained(
        args.base_model,
        config=config,
        cache_dir=args.cache_dir,
        # NOTE(review): non-quantized modules load in float16 while 4-bit
        # compute runs in bfloat16; kept as-is — confirm the mix is intended.
        torch_dtype=torch.float16,
        device_map="auto",
        quantization_config=BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        ),
    )
    # NOTE(review): hard-coded vocabulary size; presumably the base vocab plus
    # one pad token added during fine-tuning — TODO confirm against training.
    model.resize_token_embeddings(32001)
    return model, orig_ctx_len


def _tokenizer_max_length(context_size, orig_ctx_len):
    """Return ``model_max_length`` for the tokenizer, or ``None`` to keep its default.

    Uses the larger of the requested and pretrained context length. Handles a
    config without ``max_position_embeddings`` (``orig_ctx_len is None``),
    where comparing an int with ``None`` would raise ``TypeError``.
    """
    if orig_ctx_len:
        return max(context_size, orig_ctx_len)
    return context_size if context_size > 0 else None


def main(args):
    """Answer ``args.question`` about the text in ``args.material`` with the model.

    Args:
        args: Namespace produced by ``parse_config()``.

    Returns:
        The generated answer text (also streamed while it is generated).
    """
    # Read the material first so an invalid path fails before the slow model load.
    material = read_txt_file(args.material)

    if args.flash_attn:
        replace_llama_attn(inference=True)

    # Load model and tokenizer
    model, orig_ctx_len = _load_quantized_model(args)

    tokenizer_kwargs = {}
    max_length = _tokenizer_max_length(args.context_size, orig_ctx_len)
    if max_length is not None:
        tokenizer_kwargs["model_max_length"] = max_length
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        args.base_model,
        cache_dir=args.cache_dir,
        padding_side="right",
        use_fast=False,
        **tokenizer_kwargs,
    )

    model.eval()
    if torch.__version__ >= "2" and sys.platform != "win32":
        model = torch.compile(model)

    respond = build_generator(
        model,
        tokenizer,
        temperature=args.temperature,
        top_p=args.top_p,
        max_gen_len=args.max_gen_len,
        use_cache=True,
    )

    template = PROMPT_DICT["prompt_llama2"]
    prompt = template.format_map({"instruction": f"{material}\n{args.question}"})
    return respond(prompt=prompt)
if __name__ == "__main__":
    # Script entry point: parse CLI flags and run one generation pass.
    main(args=parse_config())