[vicuna] add default config in case of sharded vicuna
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
Committed by Shukla-Gaurav on Aug 10, 2023 (parent e8c1203, commit 8e90f1b)
Showing 2 changed files with 28 additions and 3 deletions.
29 changes: 27 additions & 2 deletions apps/stable_diffusion/web/ui/stablelm_ui.py
@@ -113,6 +113,31 @@ def set_vicuna_model(model):
     vicuna_model = model


+def get_default_config():
+    import torch
+    from transformers import AutoTokenizer
+
+    hf_model_path = "TheBloke/vicuna-7B-1.1-HF"
+    tokenizer = AutoTokenizer.from_pretrained(hf_model_path, use_fast=False)
+    compilation_prompt = "".join(["0" for _ in range(17)])
+    compilation_input_ids = tokenizer(
+        compilation_prompt,
+        return_tensors="pt",
+    ).input_ids
+    compilation_input_ids = torch.tensor(compilation_input_ids).reshape(
+        [1, 19]
+    )
+    firstVicunaCompileInput = (compilation_input_ids,)
+    from apps.language_models.src.model_wrappers.vicuna_model import (
+        CombinedModel,
+    )
+    from shark.shark_generate_model_config import GenerateConfigFile
+
+    model = CombinedModel()
+    c = GenerateConfigFile(model, 1, ["gpu_id"], firstVicunaCompileInput)
+    c.split_into_layers()
+
+
 # TODO: Make chat reusable for UI and API
 def chat(
     curr_system_message,
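The new helper builds a throwaway compile input: a prompt of seventeen "0" characters is tokenized with the Vicuna 7B tokenizer and reshaped to [1, 19] before being handed to GenerateConfigFile. A minimal standalone sketch of that step follows; it assumes transformers is installed and that the TheBloke/vicuna-7B-1.1-HF tokenizer can be downloaded, and the [1, 19] shape is taken from the diff rather than verified here.

from transformers import AutoTokenizer

# Same dummy prompt as get_default_config(): seventeen "0" characters.
tokenizer = AutoTokenizer.from_pretrained(
    "TheBloke/vicuna-7B-1.1-HF", use_fast=False
)
compilation_prompt = "0" * 17
compilation_input_ids = tokenizer(
    compilation_prompt, return_tensors="pt"
).input_ids
# The diff reshapes this tensor to [1, 19]; printing the shape shows
# whether the tokenizer really emits 19 ids for this prompt.
print(compilation_input_ids.shape)
firstVicunaCompileInput = (compilation_input_ids,)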
@@ -185,7 +210,7 @@ def chat(
             config_json = json.load(config_file)
             config_file.close()
         else:
-            config_json = None
+            config_json = get_default_config()
         vicuna_model = Vicuna(
             model_name,
             device=device,
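This hunk changes the fallback in chat(): when no sharding configuration has been uploaded, the UI now asks get_default_config() for one instead of handing Vicuna a None config. A minimal sketch of that selection logic, with the rest of chat() omitted, might look like this (load_sharding_config is a hypothetical helper name, and get_default_config refers to the function added above):

import json

def load_sharding_config(config_file):
    """Return the uploaded sharding config, or a generated default."""
    if config_file is not None:
        # Mirrors the diff: config_file is treated as an open, readable
        # file object containing JSON.
        config_json = json.load(config_file)
        config_file.close()
        return config_json
    # No upload: fall back to the default generated from CombinedModel.
    return get_default_config()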
@@ -379,7 +404,7 @@ def view_json_file(file_obj):
         with gr.Group():
             config_file = gr.File(label="Upload sharding configuration")
             json_view_button = gr.Button("View as JSON")
-            json_view = gr.JSON()
+            json_view = gr.JSON(interactive=True)
             json_view_button.click(
                 fn=view_json_file, inputs=[config_file], outputs=[json_view]
             )
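The last UI hunk only makes the JSON viewer editable. For context, a self-contained sketch of the viewer wiring is below; view_json_file() is not part of this hunk, so its body here is an assumption, while the component calls mirror the diff.

import json
import gradio as gr

def view_json_file(file_obj):
    # Assumed behaviour: read the uploaded file from disk and return it as
    # a Python object so gr.JSON can render it.
    if file_obj is None:
        return None
    with open(file_obj.name) as f:
        return json.load(f)

with gr.Blocks() as demo:
    with gr.Group():
        config_file = gr.File(label="Upload sharding configuration")
        json_view_button = gr.Button("View as JSON")
        json_view = gr.JSON(interactive=True)
        json_view_button.click(
            fn=view_json_file, inputs=[config_file], outputs=[json_view]
        )

demo.launch()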
2 changes: 1 addition & 1 deletion shark/shark_generate_model_config.py
@@ -144,4 +144,4 @@ def generate_json(self, artifacts):

 model = CombinedModel()
 c = GenerateConfigFile(model, 1, ["gpu_id"], firstVicunaCompileInput)
-c.split_into_dispatches("vulkan")
+c.split_into_layers()
