"""
File: llamma_prompter.py
Provides a simple prompt-based conversation interface for a language model
(Llama 2), acting as a stream handler and allowing the user to input commands
and receive responses the model while preventing model monologues.
"""
import os
from threading import Thread

from huggingface_hub import hf_hub_download
from huggingface_hub import snapshot_download as hf_hub_snap_download
from huggingface_hub import login as hf_hub_login
from llama_cpp import Llama  # type: ignore
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import TextIteratorStreamer
from trl import AutoModelForCausalLMWithValueHead

try:
    import torch
    print("CUDA available for PyTorch: " + str(torch.cuda.is_available()))
except ModuleNotFoundError:
    print("Module 'torch' is not installed.")

try:
    from auto_gptq import AutoGPTQForCausalLM
except ModuleNotFoundError:
    print("Module 'auto_gptq' with CUDA extensions is required for GPTQ models.")

from llama_formatter import llama_formatter
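

# model_metadata schema (keys as read by llama_prompter.__init__ below):
#   "online"       -- bool, download the model from the Hugging Face Hub
#   "name"         -- Hub repo id (used when downloading or loading by name)
#   "file"         -- optional single model file (GGML-style models)
#   "path"         -- local directory for downloaded or offline models
#   "architecture" -- one of "ggml", "gptq", "original", "tlrsft"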
class llama_prompter:
    timeout = 60 * 60 * 24  # Streamer timeout, in seconds
    model_metadata = None
    model = None
    tokenizer = None
    formatter = None
    thread = None

""" Initialize the model and tokenizer from metadata specifications """
def __init__(self, model_metadata: dict, huggingface_token: str):
self.model_metadata = model_metadata
self.formatter = llama_formatter()
if model_metadata["online"]:
print("Downloading the model...")
if ('file' not in model_metadata):
hf_hub_login(token=huggingface_token)
hf_hub_snap_download(
repo_id=model_metadata["name"],
local_dir=model_metadata["path"],
local_dir_use_symlinks=True
)
else:
file_path = hf_hub_download(
repo_id=model_metadata["name"],
filename=model_metadata["file"],
local_dir=model_metadata["path"],
local_dir_use_symlinks=True
)
else:
pass
# https://huggingface.co/docs/transformers/index
print("Loading the model...")
if model_metadata["architecture"] == "ggml":
self.model = Llama(file_path, n_ctx=2048) # 4096
self.tokenizer = None
elif model_metadata["architecture"] == "gptq":
self.model = AutoGPTQForCausalLM.from_quantized(
model_metadata["name"],
device_map="auto",
use_safetensors=True)
self.tokenizer = AutoTokenizer.from_pretrained(
model_metadata["name"])
elif model_metadata["architecture"] == "original":
self.model = AutoModelForCausalLM.from_pretrained(
model_metadata["name"],
device_map="auto",
token=True)
self.tokenizer = AutoTokenizer.from_pretrained(
model_metadata["name"],
token=True)
elif model_metadata["architecture"] == "tlrsft":
self.model = AutoModelForCausalLMWithValueHead.from_pretrained(
model_metadata["path"],
device_map=None,
offload_folder="offload/",
token=True
)
self.tokenizer = AutoTokenizer.from_pretrained(
model_metadata["path"],
token=True)
if hasattr(self.model, 'device'):
print(f"MODEL_DEVICE: {self.model.device}")
else:
print("MODEL_DEVICE: Undefined")
print("Model loaded.")
""" Get the current prompt text in llama v2 format """
def get_prompt(self) -> str:
return self.formatter.format()
""" Add a new text to the stack of prompts """
def stack(self, role: str, text: str) -> None:
self.formatter.add(role, text)
def empty(self) -> None:
self.formatter.empty()
""" Submit a prompt to the model and return a streamer object """
def submit(self, prompt: str):
kwargs = dict(temperature=0.6, top_p=0.9)
if self.model_metadata["architecture"] == 'ggml':
kwargs["max_tokens"] = 512
# stream=False do not solve the broken emojies issue
# https://github.com/abetlen/llama-cpp-python/issues/372
streamer = self.model(prompt=prompt, stream=True, **kwargs)
else:
streamer = TextIteratorStreamer(
self.tokenizer, skip_prompt=True, timeout=self.timeout)
if self.model_metadata["architecture"] == "tlrsft":
inputs = self.tokenizer(
prompt, return_tensors="pt") # device_map = None
else:
inputs = self.tokenizer(
prompt, return_tensors="pt").to(self.model.device)
kwargs["max_new_tokens"] = 512
kwargs["input_ids"] = inputs["input_ids"]
kwargs["streamer"] = streamer
self.thread = Thread(target=self.model.generate, kwargs=kwargs)
self.thread.start()
return streamer
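    # Note (added, illustrative): the object returned by submit() differs by
    # backend. For "ggml" models, llama_cpp streams dict chunks whose text is
    # in chunk["choices"][0]["text"]; for the transformers-based backends,
    # TextIteratorStreamer yields already-decoded strings. The usage sketch
    # after this class assumes the caller normalizes both shapes.
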
""" Prevent model monologues by checking the chat history"""
def check_history(self, new_text: str, history: list) -> bool:
bloviated = False # True if model got crazy (monologue)
merged = history[-1][1] + new_text
if (self.formatter.BOS not in merged
and self.formatter.EOS not in merged):
history[-1][1] += new_text # Update chat history
self.formatter.concat_last(new_text) # Concat to last entry
else:
bloviated = True # We need to cut the monologue part
bos_pos = merged.find(self.formatter.BOS)
eos_pos = merged.find(self.formatter.EOS)
cut_pos = min(bos_pos, eos_pos) # Assume is the 1st one
if (cut_pos == -1):
cut_pos = max(bos_pos, eos_pos) # Change to the last one
history[-1][1] = merged[:cut_pos] # Cut and update chat hist.
self.formatter.replace_last(history[-1][1]) # Replace last entry
return bloviated
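

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). The metadata keys
# mirror those read by __init__; the repo id, filename, token placeholder,
# and the role names passed to stack() are illustrative assumptions that
# depend on llama_formatter and on the model actually used.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    example_metadata = {
        "online": True,                            # download from the Hub
        "name": "TheBloke/Llama-2-7B-Chat-GGML",   # assumed repo id
        "file": "llama-2-7b-chat.ggmlv3.q4_0.bin", # assumed model file
        "path": "models/",                         # local download directory
        "architecture": "ggml",                    # ggml/gptq/original/tlrsft
    }
    prompter = llama_prompter(example_metadata, huggingface_token="hf_...")

    history = [["Hello, who are you?", ""]]        # [user text, model reply]
    prompter.stack("user", history[-1][0])         # assumed role name
    streamer = prompter.submit(prompter.get_prompt())

    for chunk in streamer:
        # Normalize the two streamer shapes (see the note inside the class).
        text = (chunk["choices"][0]["text"]
                if isinstance(chunk, dict) else chunk)
        if prompter.check_history(text, history):
            break                                  # monologue detected, stop
    print(history[-1][1])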