[chatbot] Add tokens generated per second
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
Shukla-Gaurav committed Aug 13, 2023
1 parent 18801dc commit 76f564c
Showing 2 changed files with 27 additions and 13 deletions.

apps/language_models/scripts/vicuna.py — 9 changes: 3 additions & 6 deletions

@@ -1635,7 +1635,7 @@ def generate(self, prompt, cli=True):
             logits = generated_token_op["logits"]
             pkv = generated_token_op["past_key_values"]
             detok = generated_token_op["detok"]
-            yield detok
+            yield detok, ""
 
             res_tokens.append(token)
             if cli:
@@ -1668,14 +1668,11 @@ def generate(self, prompt, cli=True):
             else:
                 if cli:
                     print(f"{detok}", end=" ", flush=True)
-
-                if len(res_tokens) % 3 == 0:
-                    part_str = self.decode_tokens(res_tokens)
-                    yield part_str
+                yield detok, ""
 
         res_str = self.decode_tokens(res_tokens)
         # print(f"[DEBUG] final output : \n{res_str}")
-        yield res_str
+        yield res_str, "formatted"
 
     def autocomplete(self, prompt):
         # use First vic alone to complete a story / prompt / sentence.
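
Taken together, these two hunks change the contract of generate: instead of yielding bare strings (plus a partial decode every third token), it now yields (text, msg) pairs, streaming one detokenized token at a time with an empty tag and marking the final decoded response with "formatted". A minimal consumer sketch of that protocol; the fake_generate stub below is hypothetical and merely stands in for the real method:

def fake_generate():
    # Hypothetical stand-in for the script's generate(): stream
    # per-token text with an empty tag, then yield the full decoded
    # string tagged "formatted".
    for tok in ["The", "quick", "fox"]:
        yield tok, ""
    yield "The quick fox", "formatted"

for text, msg in fake_generate():
    if msg == "formatted":
        print("\nfinal:", text)               # complete response
    else:
        print(text, end=" ", flush=True)      # streaming partial output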
apps/stable_diffusion/web/ui/stablelm_ui.py — 31 changes: 24 additions & 7 deletions

@@ -8,6 +8,7 @@
 from apps.stable_diffusion.web.ui.utils import available_devices
 from datetime import datetime as dt
 import json
+import time
 
 
 def user(message, history):
@@ -214,13 +215,27 @@ def chat(
 
         prompt = create_prompt(model_name, history)
 
-        for partial_text in progress.tqdm(
-            vicuna_model.generate(prompt, cli=cli), desc="generating response"
+        partial_text = ""
+        count = 0
+        start_time = time.time()
+        for text, msg in progress.tqdm(
+            vicuna_model.generate(prompt, cli=False),
+            desc="generating response",
         ):
-            history[-1][1] = partial_text
-            yield history
+            count += 1
+            if "formatted" in msg:
+                history[-1][1] = text
+                end_time = time.time()
+                tokens_per_sec = count / (end_time - start_time)
+                yield history, str(
+                    format(tokens_per_sec, ".2f")
+                ) + " tokens/sec"
+            else:
+                partial_text += text + " "
+                history[-1][1] = partial_text
+                yield history, ""
 
-        return history
+        return history, ""
 
     # else Model is StableLM
     global sharkModel
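
The rate reported to the UI is plain wall-clock throughput: count increments once per yielded chunk (roughly one per generated token), and dividing by the elapsed time gives tokens per second. For example, 42 chunks over 3.5 seconds would display as "12.00 tokens/sec". A self-contained sketch of the same measurement, with a sleep standing in for model decode latency:

import time

start_time = time.time()
count = 0
for _ in range(42):            # pretend each iteration yields one token
    time.sleep(3.5 / 42)       # simulated per-token decode latency
    count += 1
tokens_per_sec = count / (time.time() - start_time)
print(format(tokens_per_sec, ".2f") + " tokens/sec")   # ~12.00 tokens/sec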
@@ -395,6 +410,8 @@ def view_json_file(file_obj):
                 ],
                 visible=True,
             )
+            tokens_time = gr.Textbox(label="Tokens generated per second")
+
         with gr.Row(visible=False):
             with gr.Group():
                 config_file = gr.File(
@@ -429,15 +446,15 @@ def view_json_file(file_obj):
         ).then(
             fn=chat,
             inputs=[system_msg, chatbot, model, devices, precision, config_file],
-            outputs=[chatbot],
+            outputs=[chatbot, tokens_time],
             queue=True,
         )
         submit_click_event = submit.click(
             fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
         ).then(
             fn=chat,
             inputs=[system_msg, chatbot, model, devices, precision, config_file],
-            outputs=[chatbot],
+            outputs=[chatbot, tokens_time],
             queue=True,
         )
         stop.click(
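
Because chat is a generator that now yields (history, rate_string) pairs, appending tokens_time to outputs lets Gradio route the second element of each yield into the new textbox while the first continues to update the chatbot. A stripped-down sketch of that wiring; the toy chat generator and layout below are illustrative, not the app's actual code:

import gradio as gr

def chat(history):
    # Toy generator: stream chunks into the chatbot, then report a
    # made-up rate on the final yield.
    history = history + [["hi", ""]]
    for chunk in ["Hel", "lo", "!"]:
        history[-1][1] += chunk
        yield history, ""
    yield history, "12.00 tokens/sec"

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    tokens_time = gr.Textbox(label="Tokens generated per second")
    send = gr.Button("Send")
    send.click(fn=chat, inputs=[chatbot], outputs=[chatbot, tokens_time])

demo.queue().launch()   # queueing is required for generator callbacks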
