diff --git a/apps/language_models/scripts/vicuna.py b/apps/language_models/scripts/vicuna.py
index 04d5df23cb..6f857455a5 100644
--- a/apps/language_models/scripts/vicuna.py
+++ b/apps/language_models/scripts/vicuna.py
@@ -1635,7 +1635,7 @@ def generate(self, prompt, cli=True):
             logits = generated_token_op["logits"]
             pkv = generated_token_op["past_key_values"]
             detok = generated_token_op["detok"]
-            yield detok
+            yield detok, ""
 
             res_tokens.append(token)
             if cli:
@@ -1668,14 +1668,11 @@ def generate(self, prompt, cli=True):
             else:
                 if cli:
                     print(f"{detok}", end=" ", flush=True)
-
-                if len(res_tokens) % 3 == 0:
-                    part_str = self.decode_tokens(res_tokens)
-                    yield part_str
+                yield detok, ""
 
         res_str = self.decode_tokens(res_tokens)
         # print(f"[DEBUG] final output : \n{res_str}")
-        yield res_str
+        yield res_str, "formatted"
 
     def autocomplete(self, prompt):
         # use First vic alone to complete a story / prompt / sentence.
diff --git a/apps/stable_diffusion/web/ui/stablelm_ui.py b/apps/stable_diffusion/web/ui/stablelm_ui.py
index 9daeb06081..67cb8260f3 100644
--- a/apps/stable_diffusion/web/ui/stablelm_ui.py
+++ b/apps/stable_diffusion/web/ui/stablelm_ui.py
@@ -8,6 +8,7 @@
 from apps.stable_diffusion.web.ui.utils import available_devices
 from datetime import datetime as dt
 import json
+import time
 
 
 def user(message, history):
@@ -214,13 +215,25 @@ def chat(
 
         prompt = create_prompt(model_name, history)
 
-        for partial_text in progress.tqdm(
-            vicuna_model.generate(prompt, cli=cli), desc="generating response"
+        partial_text = ""
+        count = 0
+        start_time = time.time()
+        for text, msg in progress.tqdm(
+            vicuna_model.generate(prompt, cli=False),
+            desc="generating response",
         ):
-            history[-1][1] = partial_text
-            yield history
+            count += 1
+            if "formatted" in msg:
+                history[-1][1] = text
+                end_time = time.time()
+                tokens_per_sec = count / (end_time - start_time)
+                yield history, str(tokens_per_sec) + " tokens/sec"
+            else:
+                partial_text += text + " "
+                history[-1][1] = partial_text
+                yield history, ""
 
-        return history
+        return history, ""
 
         # else Model is StableLM
         global sharkModel
@@ -395,6 +408,8 @@ def view_json_file(file_obj):
                 ],
                 visible=True,
             )
+            tokens_time = gr.Textbox()
+
             with gr.Row(visible=False):
                 with gr.Group():
                     config_file = gr.File(
@@ -429,7 +444,7 @@ def view_json_file(file_obj):
         ).then(
             fn=chat,
             inputs=[system_msg, chatbot, model, devices, precision, config_file],
-            outputs=[chatbot],
+            outputs=[chatbot, tokens_time],
             queue=True,
         )
         submit_click_event = submit.click(
@@ -437,7 +452,7 @@ def view_json_file(file_obj):
         ).then(
             fn=chat,
            inputs=[system_msg, chatbot, model, devices, precision, config_file],
-            outputs=[chatbot],
+            outputs=[chatbot, tokens_time],
             queue=True,
         )
         stop.click(
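
The new contract between `generate()` and `chat()` is easy to miss across the hunks: `generate()` now yields `(text, msg)` tuples, where `msg` is `""` for each streamed detokenized piece and `"formatted"` for the single final, fully decoded string, and `chat()` uses that tag to decide whether to append the piece or replace the whole message and report a rate. A minimal sketch of that contract, with `fake_generate` as a hypothetical stand-in for the real `vicuna_model.generate(prompt, cli=False)`:

```python
import time


def fake_generate(tokens):
    # Hypothetical stand-in for vicuna_model.generate(prompt, cli=False):
    # one (detok, "") pair per streamed token, then the fully decoded
    # string tagged "formatted" as the last item.
    for tok in tokens:
        yield tok, ""
    yield " ".join(tokens), "formatted"


def consume(gen):
    # Mirrors the loop in chat(): accumulate partial text while streaming,
    # then swap in the formatted string and report tokens/sec. As in the
    # diff, count is incremented before the tag check, so the final
    # "formatted" item is included in the rate.
    partial_text = ""
    count = 0
    start_time = time.time()
    for text, msg in gen:
        count += 1
        if "formatted" in msg:
            tokens_per_sec = count / (time.time() - start_time)
            print(f"final: {text!r} ({tokens_per_sec:.2f} tokens/sec)")
        else:
            partial_text += text + " "
            print(f"partial: {partial_text!r}")


consume(fake_generate(["The", "quick", "brown", "fox"]))
```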