diff --git a/apps/stable_diffusion/web/ui/stablelm_ui.py b/apps/stable_diffusion/web/ui/stablelm_ui.py
index 2cb490ce77..f31344d715 100644
--- a/apps/stable_diffusion/web/ui/stablelm_ui.py
+++ b/apps/stable_diffusion/web/ui/stablelm_ui.py
@@ -7,6 +7,7 @@
 )
 from apps.stable_diffusion.web.ui.utils import available_devices
 from datetime import datetime as dt
+import json
 
 
 def user(message, history):
@@ -106,7 +107,15 @@ def set_vicuna_model(model):
 
 
 # TODO: Make chat reusable for UI and API
-def chat(curr_system_message, history, model, device, precision, cli=True):
+def chat(
+    curr_system_message,
+    history,
+    model,
+    devices,
+    precision,
+    config_file,
+    cli=True,
+):
     global past_key_values
     global vicuna_model
 
@@ -121,10 +130,12 @@ def chat(curr_system_message, history, model, device, precision, cli=True):
     ]:
         from apps.language_models.scripts.vicuna import (
             UnshardedVicuna,
+            ShardedVicuna,
         )
         from apps.stable_diffusion.src import args
 
         if vicuna_model == 0:
+            device = devices[0]
             if "cuda" in device:
                 device = "cuda"
             elif "sync" in device:
@@ -137,14 +148,28 @@ def chat(curr_system_message, history, model, device, precision, cli=True):
                 print("unrecognized device")
 
             max_toks = 128 if model_name == "codegen" else 512
-            vicuna_model = UnshardedVicuna(
-                model_name,
-                hf_model_path=model_path,
-                hf_auth_token=args.hf_auth_token,
-                device=device,
-                precision=precision,
-                max_num_tokens=max_toks,
-            )
+            if len(devices) == 1 and config_file is None:
+                vicuna_model = UnshardedVicuna(
+                    model_name,
+                    hf_model_path=model_path,
+                    hf_auth_token=args.hf_auth_token,
+                    device=device,
+                    precision=precision,
+                    max_num_tokens=max_toks,
+                )
+            else:
+                if config_file is not None:
+                    config_file = open(config_file)
+                    config_json = json.load(config_file)
+                    config_file.close()
+                else:
+                    config_json = None
+                vicuna_model = ShardedVicuna(
+                    model_name,
+                    device=device,
+                    precision=precision,
+                    config_json=config_json,
+                )
 
     prompt = create_prompt(model_name, history)
     for partial_text in vicuna_model.generate(prompt, cli=cli):
@@ -307,13 +332,14 @@ def view_json_file(file_obj):
         supported_devices = supported_devices[-1:] + supported_devices[:-1]
         supported_devices = [x for x in supported_devices if "sync" not in x]
         print(supported_devices)
-        device = gr.Dropdown(
+        devices = gr.Dropdown(
             label="Device",
             value=supported_devices[0]
             if enabled
             else "Only CUDA Supported for now",
             choices=supported_devices,
             interactive=enabled,
+            multiselect=True,
         )
         precision = gr.Radio(
             label="Precision",
@@ -357,7 +383,7 @@ def view_json_file(file_obj):
         fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
     ).then(
         fn=chat,
-        inputs=[system_msg, chatbot, model, device, precision],
+        inputs=[system_msg, chatbot, model, devices, precision, config_file],
         outputs=[chatbot],
         queue=True,
     )
@@ -365,7 +391,7 @@ def view_json_file(file_obj):
         fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
     ).then(
         fn=chat,
-        inputs=[system_msg, chatbot, model, device, precision],
+        inputs=[system_msg, chatbot, model, devices, precision, config_file],
        outputs=[chatbot],
         queue=True,
     )