diff --git a/apps/language_models/scripts/vicuna.py b/apps/language_models/scripts/vicuna.py
index adb7e428d5..2907f2e891 100644
--- a/apps/language_models/scripts/vicuna.py
+++ b/apps/language_models/scripts/vicuna.py
@@ -364,7 +364,7 @@ def combine_mlir_scripts(
         with open(output_name, "rb") as f:
             return f.read()
 
-    def generate_new_token(self, params, sharded=True):
+    def generate_new_token(self, params, sharded=True, cli=True):
         is_first = params["is_first"]
         if is_first:
             prompt = params["prompt"]
@@ -417,7 +417,8 @@ def generate_new_token(self, params, sharded=True):
             "past_key_values": _past_key_values,
         }
 
-        print(f" token : {_token} | detok : {_detok}")
+        if cli:
+            print(f" token : {_token} | detok : {_detok}")
 
         return ret_dict
 
@@ -1632,14 +1633,14 @@ def generate(self, prompt, cli=True):
         params = {"prompt": prompt, "is_first": True, "fv": self.shark_model}
 
         generated_token_op = self.generate_new_token(
-            params=params, sharded=False
+            params=params, sharded=False, cli=False
         )
         token = generated_token_op["token"]
         logits = generated_token_op["logits"]
         pkv = generated_token_op["past_key_values"]
         detok = generated_token_op["detok"]
 
-        yield detok
+        yield detok, ""
 
         res_tokens.append(token)
         if cli:
@@ -1672,14 +1673,11 @@ def generate(self, prompt, cli=True):
             else:
                 if cli:
                     print(f"{detok}", end=" ", flush=True)
-
-            if len(res_tokens) % 3 == 0:
-                part_str = self.decode_tokens(res_tokens)
-                yield part_str
+            yield detok, ""
 
         res_str = self.decode_tokens(res_tokens)
         # print(f"[DEBUG] final output : \n{res_str}")
-        yield res_str
+        yield res_str, "formatted"
 
     def autocomplete(self, prompt):
         # use First vic alone to complete a story / prompt / sentence.
diff --git a/apps/stable_diffusion/web/ui/stablelm_ui.py b/apps/stable_diffusion/web/ui/stablelm_ui.py
index e125152e0b..2e9e56ff56 100644
--- a/apps/stable_diffusion/web/ui/stablelm_ui.py
+++ b/apps/stable_diffusion/web/ui/stablelm_ui.py
@@ -8,6 +8,7 @@
 from apps.stable_diffusion.web.ui.utils import available_devices
 from datetime import datetime as dt
 import json
+import time
 
 
 def user(message, history):
@@ -225,13 +226,27 @@ def chat(
 
         prompt = create_prompt(model_name, history)
 
-        for partial_text in progress.tqdm(
-            vicuna_model.generate(prompt, cli=cli), desc="generating response"
+        partial_text = ""
+        count = 0
+        start_time = time.time()
+        for text, msg in progress.tqdm(
+            vicuna_model.generate(prompt, cli=False),
+            desc="generating response",
         ):
-            history[-1][1] = partial_text
-            yield history
+            count += 1
+            if "formatted" in msg:
+                history[-1][1] = text
+                end_time = time.time()
+                tokens_per_sec = count / (end_time - start_time)
+                yield history, str(
+                    format(tokens_per_sec, ".2f")
+                ) + " tokens/sec"
+            else:
+                partial_text += text + " "
+                history[-1][1] = partial_text
+                yield history, ""
 
-        return history
+        return history, ""
 
     # else Model is StableLM
     global sharkModel
@@ -256,7 +271,6 @@ def chat(
 
     partial_text = ""
     for new_text in words_list:
-        print(new_text)
         partial_text += new_text
         history[-1][1] = partial_text
         # Yield an empty string to clean up the message textbox and the updated
@@ -386,7 +400,7 @@ def view_json_file(file_obj):
         # show cpu-task device first in list for chatbot
         supported_devices = supported_devices[-1:] + supported_devices[:-1]
         supported_devices = [x for x in supported_devices if "sync" not in x]
-        print(supported_devices)
+        # print(supported_devices)
         devices = gr.Dropdown(
             label="Device",
             value=supported_devices[0]
@@ -406,6 +420,8 @@ def view_json_file(file_obj):
             ],
             visible=True,
         )
+        tokens_time = gr.Textbox(label="Tokens generated per second")
+
     with gr.Row(visible=False):
         with gr.Group():
             config_file = gr.File(
@@ -440,7 +456,7 @@ def view_json_file(file_obj):
     ).then(
         fn=chat,
         inputs=[system_msg, chatbot, model, devices, precision, config_file],
-        outputs=[chatbot],
+        outputs=[chatbot, tokens_time],
         queue=True,
     )
     submit_click_event = submit.click(
@@ -448,7 +464,7 @@ def view_json_file(file_obj):
     ).then(
         fn=chat,
         inputs=[system_msg, chatbot, model, devices, precision, config_file],
-        outputs=[chatbot],
+        outputs=[chatbot, tokens_time],
         queue=True,
     )
     stop.click(