Merge branch 'main' into fix_vtt_vic
PhaneeshB authored Aug 13, 2023
2 parents ff07675 + 4dc9c59 commit 593e164
Showing 2 changed files with 32 additions and 18 deletions.
16 changes: 7 additions & 9 deletions apps/language_models/scripts/vicuna.py
@@ -364,7 +364,7 @@ def combine_mlir_scripts(
         with open(output_name, "rb") as f:
             return f.read()
 
-    def generate_new_token(self, params, sharded=True):
+    def generate_new_token(self, params, sharded=True, cli=True):
         is_first = params["is_first"]
         if is_first:
             prompt = params["prompt"]
@@ -417,7 +417,8 @@ def generate_new_token(self, params, sharded=True):
             "past_key_values": _past_key_values,
         }
 
-        print(f" token : {_token} | detok : {_detok}")
+        if cli:
+            print(f" token : {_token} | detok : {_detok}")
 
         return ret_dict

@@ -1632,14 +1633,14 @@ def generate(self, prompt, cli=True):
         params = {"prompt": prompt, "is_first": True, "fv": self.shark_model}
 
         generated_token_op = self.generate_new_token(
-            params=params, sharded=False
+            params=params, sharded=False, cli=False
         )
 
         token = generated_token_op["token"]
         logits = generated_token_op["logits"]
         pkv = generated_token_op["past_key_values"]
         detok = generated_token_op["detok"]
-        yield detok
+        yield detok, ""
 
         res_tokens.append(token)
         if cli:
@@ -1672,14 +1673,11 @@ def generate(self, prompt, cli=True):
             else:
                 if cli:
                     print(f"{detok}", end=" ", flush=True)
-
-                if len(res_tokens) % 3 == 0:
-                    part_str = self.decode_tokens(res_tokens)
-                    yield part_str
+                yield detok, ""
 
         res_str = self.decode_tokens(res_tokens)
         # print(f"[DEBUG] final output : \n{res_str}")
-        yield res_str
+        yield res_str, "formatted"
 
     def autocomplete(self, prompt):
         # use First vic alone to complete a story / prompt / sentence.
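With this change, `generate` becomes a generator of `(text, msg)` tuples rather than bare strings: streamed detokenized pieces carry an empty tag, and the final fully decoded string carries the "formatted" tag. A minimal sketch of a consumer of that protocol (the `fake_generate` stub below is hypothetical, standing in for `vicuna_model.generate`):

    # Hypothetical stand-in for vicuna_model.generate; only the shape of
    # the yielded (text, msg) tuples matches the diff above.
    def fake_generate(prompt):
        for piece in ["Hello", ",", "world"]:
            yield piece, ""                # streamed pieces, untagged
        yield "Hello, world", "formatted"  # final decoded string, tagged

    partial_text = ""
    final_text = None
    for text, msg in fake_generate("hi"):
        if "formatted" in msg:
            final_text = text              # replace the accumulation wholesale
        else:
            partial_text += text + " "     # accumulate streamed pieces

    print(final_text)  # -> Hello, world

Yielding the final string under a distinct tag lets the consumer overwrite its word-by-word accumulation with the cleanly decoded output in one step.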
34 changes: 25 additions & 9 deletions apps/stable_diffusion/web/ui/stablelm_ui.py
@@ -8,6 +8,7 @@
 from apps.stable_diffusion.web.ui.utils import available_devices
 from datetime import datetime as dt
 import json
+import time
 
 
 def user(message, history):
@@ -225,13 +226,27 @@ def chat(
 
         prompt = create_prompt(model_name, history)
 
-        for partial_text in progress.tqdm(
-            vicuna_model.generate(prompt, cli=cli), desc="generating response"
+        partial_text = ""
+        count = 0
+        start_time = time.time()
+        for text, msg in progress.tqdm(
+            vicuna_model.generate(prompt, cli=False),
+            desc="generating response",
         ):
-            history[-1][1] = partial_text
-            yield history
+            count += 1
+            if "formatted" in msg:
+                history[-1][1] = text
+                end_time = time.time()
+                tokens_per_sec = count / (end_time - start_time)
+                yield history, str(
+                    format(tokens_per_sec, ".2f")
+                ) + " tokens/sec"
+            else:
+                partial_text += text + " "
+                history[-1][1] = partial_text
+                yield history, ""
 
-        return history
+        return history, ""
 
         # else Model is StableLM
         global sharkModel
@@ -256,7 +271,6 @@ def chat(
 
         partial_text = ""
         for new_text in words_list:
-            print(new_text)
             partial_text += new_text
             history[-1][1] = partial_text
             # Yield an empty string to clean up the message textbox and the updated
@@ -386,7 +400,7 @@ def view_json_file(file_obj):
         # show cpu-task device first in list for chatbot
         supported_devices = supported_devices[-1:] + supported_devices[:-1]
         supported_devices = [x for x in supported_devices if "sync" not in x]
-        print(supported_devices)
+        # print(supported_devices)
         devices = gr.Dropdown(
             label="Device",
             value=supported_devices[0]
@@ -406,6 +420,8 @@ def view_json_file(file_obj):
                     ],
                     visible=True,
                 )
+                tokens_time = gr.Textbox(label="Tokens generated per second")
+
             with gr.Row(visible=False):
                 with gr.Group():
                     config_file = gr.File(
@@ -440,15 +456,15 @@ def view_json_file(file_obj):
             ).then(
                 fn=chat,
                 inputs=[system_msg, chatbot, model, devices, precision, config_file],
-                outputs=[chatbot],
+                outputs=[chatbot, tokens_time],
                 queue=True,
             )
             submit_click_event = submit.click(
                 fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
            ).then(
                 fn=chat,
                 inputs=[system_msg, chatbot, model, devices, precision, config_file],
-                outputs=[chatbot],
+                outputs=[chatbot, tokens_time],
                 queue=True,
             )
             stop.click(
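The UI-side loop times the whole generation and reports an average rate; note that `count` also increments on the final "formatted" yield, so that summary yield is counted as one token. A minimal sketch of the measurement, assuming one loop iteration per streamed token (`stream_tokens` is a hypothetical stand-in for `vicuna_model.generate`):

    import time

    # Hypothetical token source standing in for vicuna_model.generate
    def stream_tokens():
        for piece in ["a", "b", "c", "d"]:
            time.sleep(0.01)  # simulate per-token generation latency
            yield piece, ""
        yield "a b c d", "formatted"

    count = 0
    start_time = time.time()
    for text, msg in stream_tokens():
        count += 1
        if "formatted" in msg:
            tokens_per_sec = count / (time.time() - start_time)
            print(format(tokens_per_sec, ".2f") + " tokens/sec")

Because `chat` now yields `(history, tokens_string)` pairs, the Gradio events list both components in `outputs=[chatbot, tokens_time]`, so each yielded pair updates the chatbot and the tokens-per-second textbox together.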
