From 8e90f1b81a401deb185f43089c5214f2528129d6 Mon Sep 17 00:00:00 2001
From: Gaurav Shukla
Date: Mon, 7 Aug 2023 21:59:39 +0530
Subject: [PATCH] [vicuna] add default config in case of sharded vicuna

Signed-Off-by: Gaurav Shukla
---
 apps/stable_diffusion/web/ui/stablelm_ui.py | 29 +++++++++++++++++++--
 shark/shark_generate_model_config.py        |  2 +-
 2 files changed, 28 insertions(+), 3 deletions(-)

diff --git a/apps/stable_diffusion/web/ui/stablelm_ui.py b/apps/stable_diffusion/web/ui/stablelm_ui.py
index 26a024ecfa..c92672baa3 100644
--- a/apps/stable_diffusion/web/ui/stablelm_ui.py
+++ b/apps/stable_diffusion/web/ui/stablelm_ui.py
@@ -113,6 +113,31 @@ def set_vicuna_model(model):
     vicuna_model = model
 
 
+def get_default_config():
+    import torch
+    from transformers import AutoTokenizer
+
+    hf_model_path = "TheBloke/vicuna-7B-1.1-HF"
+    tokenizer = AutoTokenizer.from_pretrained(hf_model_path, use_fast=False)
+    compilation_prompt = "".join(["0" for _ in range(17)])
+    compilation_input_ids = tokenizer(
+        compilation_prompt,
+        return_tensors="pt",
+    ).input_ids
+    compilation_input_ids = torch.tensor(compilation_input_ids).reshape(
+        [1, 19]
+    )
+    firstVicunaCompileInput = (compilation_input_ids,)
+    from apps.language_models.src.model_wrappers.vicuna_model import (
+        CombinedModel,
+    )
+    from shark.shark_generate_model_config import GenerateConfigFile
+
+    model = CombinedModel()
+    c = GenerateConfigFile(model, 1, ["gpu_id"], firstVicunaCompileInput)
+    c.split_into_layers()
+
+
 # TODO: Make chat reusable for UI and API
 def chat(
     curr_system_message,
@@ -185,7 +210,7 @@ def chat(
         config_json = json.load(config_file)
         config_file.close()
     else:
-        config_json = None
+        config_json = get_default_config()
     vicuna_model = Vicuna(
         model_name,
         device=device,
@@ -379,7 +404,7 @@ def view_json_file(file_obj):
     with gr.Group():
         config_file = gr.File(label="Upload sharding configuration")
         json_view_button = gr.Button("View as JSON")
-        json_view = gr.JSON()
+        json_view = gr.JSON(interactive=True)
         json_view_button.click(
             fn=view_json_file, inputs=[config_file], outputs=[json_view]
         )
diff --git a/shark/shark_generate_model_config.py b/shark/shark_generate_model_config.py
index a0476c5384..906a9a0bff 100644
--- a/shark/shark_generate_model_config.py
+++ b/shark/shark_generate_model_config.py
@@ -144,4 +144,4 @@ def generate_json(self, artifacts):
 
     model = CombinedModel()
     c = GenerateConfigFile(model, 1, ["gpu_id"], firstVicunaCompileInput)
-    c.split_into_dispatches("vulkan")
+    c.split_into_layers()
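
A minimal standalone sketch of the default-config path added above, assuming
the GenerateConfigFile signature shown in the diff and that split_into_layers()
emits the per-layer sharding config consumed by the sharded Vicuna path; the
17-character dummy prompt and the [1, 19] reshape are copied from the patch and
assume that prompt tokenizes to exactly 19 ids.

import torch
from transformers import AutoTokenizer

from apps.language_models.src.model_wrappers.vicuna_model import CombinedModel
from shark.shark_generate_model_config import GenerateConfigFile

# Tokenize a fixed dummy prompt so the compile-time input shape is deterministic.
hf_model_path = "TheBloke/vicuna-7B-1.1-HF"
tokenizer = AutoTokenizer.from_pretrained(hf_model_path, use_fast=False)
compilation_prompt = "0" * 17
input_ids = tokenizer(compilation_prompt, return_tensors="pt").input_ids
input_ids = input_ids.reshape([1, 19])  # shape taken from the patch

# Same call pattern as the patch: one shard target, placeholder device name.
model = CombinedModel()
config_gen = GenerateConfigFile(model, 1, ["gpu_id"], (input_ids,))
config_gen.split_into_layers()  # assumed to produce the per-layer config JSON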