# gradio_template_phi3mini.py
"""
Gradio chat template using Gradio's Chat Interface.
We do not need to store history manually, just need to
tokenize it properly.
"""
import threading

import gradio as gr
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
    BitsAndBytesConfig
)

device = 'cuda'

# Load the model in 4-bit precision to keep GPU memory usage low.
quant_config = BitsAndBytesConfig(load_in_4bit=True)

tokenizer = AutoTokenizer.from_pretrained('microsoft/Phi-3-mini-4k-instruct')
model = AutoModelForCausalLM.from_pretrained(
    'microsoft/Phi-3-mini-4k-instruct',
    quantization_config=quant_config,
    device_map=device
)

# Stream tokens as they are generated; skip the prompt and special tokens
# so only the new assistant text is yielded. A single module-level streamer
# is shared across requests, which is fine for a single-user demo.
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# Uses around 9.9 GB of GPU memory when the maximum context length is reached.
CONTEXT_LENGTH = 3800
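
# Optional sanity check (an assumption, not part of the original script):
# report the model's in-memory footprint after 4-bit loading.
# print(f'Model size: {model.get_memory_footprint() / 1e9:.2f} GB')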
def generate_next_tokens(user_input, history):
    print('History: ', history)
    print('*' * 50)

    if len(history) == 0:
        # First turn: build the prompt with the tokenizer's chat template,
        # primed with a short dummy exchange.
        chat = [
            {'role': 'user', 'content': 'Hi'},
            {'role': 'assistant', 'content': 'Hello.'},
            {'role': 'user', 'content': user_input},
        ]
        template = tokenizer.apply_chat_template(
            chat,
            tokenize=False,
            add_generation_prompt=True
        )
        prompt = '<s>' + template
    else:
        # Later turns: rebuild the prompt from the [user, assistant] pairs
        # that ChatInterface passes in, using Phi-3's chat format.
        prompt = '<s>'
        for user_msg, assistant_msg in history:
            prompt += f"<|user|>\n{user_msg}<|end|>\n<|assistant|>\n{assistant_msg}<|end|>\n"
        prompt += f"<|user|>\n{user_input}<|end|>\n<|assistant|>\n"

    print('Prompt: ', prompt)
    print('*' * 50)

    inputs = tokenizer(prompt, return_tensors='pt').to(device)
    input_ids, attention_mask = inputs.input_ids, inputs.attention_mask

    # Manage context length + memory by keeping only the most recent tokens.
    print('Global context length till now: ', input_ids.shape[1])
    if input_ids.shape[1] > CONTEXT_LENGTH:
        input_ids = input_ids[:, -CONTEXT_LENGTH:]
        attention_mask = attention_mask[:, -CONTEXT_LENGTH:]
    print('-' * 100)

    generate_kwargs = dict(
        input_ids=input_ids,
        attention_mask=attention_mask,
        streamer=streamer,
        max_new_tokens=1024,
    )

    # Generate in a background thread so tokens can be consumed from the
    # streamer while the model is still running.
    thread = threading.Thread(
        target=model.generate,
        kwargs=generate_kwargs
    )
    thread.start()

    # Yield the growing output so ChatInterface updates as tokens arrive.
    outputs = []
    for new_token in streamer:
        outputs.append(new_token)
        yield ''.join(outputs)
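
# An alternative way to manage context length (a sketch, not part of the
# original script): instead of slicing raw token tensors, which can cut a
# turn in half, drop the oldest [user, assistant] pairs until the prompt
# fits the budget. `build_prompt` and `trim_history_to_fit` are
# hypothetical helpers assuming the same history format that
# ChatInterface passes to generate_next_tokens above.
def build_prompt(history, user_input):
    prompt = '<s>'
    for user_msg, assistant_msg in history:
        prompt += f"<|user|>\n{user_msg}<|end|>\n<|assistant|>\n{assistant_msg}<|end|>\n"
    prompt += f"<|user|>\n{user_input}<|end|>\n<|assistant|>\n"
    return prompt

def trim_history_to_fit(history, user_input, budget=CONTEXT_LENGTH):
    # Drop the oldest turn until the tokenized prompt fits within the budget.
    history = list(history)
    while history and len(tokenizer(build_prompt(history, user_input)).input_ids) > budget:
        history.pop(0)
    return history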
# ChatInterface renders its own chatbot component for the output; we only
# supply a custom input textbox.
input_text = gr.Textbox(lines=5, label='Prompt')

iface = gr.ChatInterface(
    fn=generate_next_tokens,
    textbox=input_text,
    title='Token generator'
)

iface.launch()
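
# To try the generator outside the UI (a hypothetical usage sketch; note
# that iface.launch() above blocks, so run this instead of launching):
#
#     for partial in generate_next_tokens('What is 2 + 2?', history=[]):
#         print(partial)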