forked from DeltaVML/Platypus
-
Notifications
You must be signed in to change notification settings - Fork 0
/
inference.py
122 lines (100 loc) · 3.5 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import os
import sys
import time
import pandas as pd
import fire
import torch
import transformers
from peft import PeftModel
from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer
from utils.callbacks import Iteratorize, Stream
from utils.prompter import Prompter
import gc
if torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
try:
if torch.backends.mps.is_available():
device = "mps"
except: # noqa: E722
pass
def main(
load_8bit: bool = False,
base_model: str = "../llama30B_hf",
lora_weights: str = "",
prompt_template: str = "alpaca",
csv_path: str = "",
output_csv_path: str = ""
):
base_model = base_model or os.environ.get("BASE_MODEL", "")
prompter = Prompter(prompt_template)
tokenizer = LlamaTokenizer.from_pretrained(base_model)
if device == "cuda":
model = LlamaForCausalLM.from_pretrained(
base_model,
load_in_8bit=load_8bit,
torch_dtype=torch.float16,
device_map="auto",
)
model = PeftModel.from_pretrained(
model,
lora_weights,
torch_dtype=torch.float16,
)
elif device == "mps":
model = LlamaForCausalLM.from_pretrained(
base_model,
device_map={"": device},
torch_dtype=torch.float16,
)
model = PeftModel.from_pretrained(
model,
lora_weights,
device_map={"": device},
torch_dtype=torch.float16,
)
else:
model = LlamaForCausalLM.from_pretrained(
base_model, device_map={"": device}, low_cpu_mem_usage=True
)
model = PeftModel.from_pretrained(
model,
lora_weights,
device_map={"": device},
)
if not load_8bit:
model.half()
model.eval()
if torch.__version__ >= "2" and sys.platform != "win32":
model = torch.compile(model)
df = pd.read_csv(csv_path)
instructions = df["instruction"].tolist()
inputs = df["input"].tolist()
results = []
max_batch_size = 16
for i in range(0, len(instructions), max_batch_size):
instruction_batch = instructions[i:i + max_batch_size]
input_batch = inputs[i:i + max_batch_size]
print(f"Processing batch {i // max_batch_size + 1} of {len(instructions) // max_batch_size + 1}...")
start_time = time.time()
prompts = [prompter.generate_prompt(instruction, None) for instruction, input in zip(instruction_batch, input_batch)]
batch_results = evaluate(prompter, prompts, model, tokenizer)
results.extend(batch_results)
print(f"Finished processing batch {i // max_batch_size + 1}. Time taken: {time.time() - start_time:.2f} seconds")
df["model_output"] = results
df.to_csv(output_csv_path, index=False)
def evaluate(prompter, prompts, model, tokenizer):
batch_outputs = []
for prompt in prompts:
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
generation_output = model.generate(input_ids=input_ids, num_beams=1, num_return_sequences=1,
max_new_tokens=2048, temperature=0.15, top_p=0.95)
output = tokenizer.decode(generation_output[0], skip_special_tokens=True)
resp = prompter.get_response(output)
print(resp)
batch_outputs.append(resp)
return batch_outputs
if __name__ == "__main__":
torch.cuda.empty_cache()
fire.Fire(main)