diff --git a/apps/shark_studio/api/llm.py b/apps/shark_studio/api/llm.py
new file mode 100644
index 0000000000..9e92e58cb5
--- /dev/null
+++ b/apps/shark_studio/api/llm.py
@@ -0,0 +1,91 @@
+from turbine_models.custom_models import stateless_llama
+from shark.iree_utils.compile_utils import get_iree_compiled_module
+from apps.shark_studio.api.utils import get_resource_path
+import iree.runtime as ireert
+import gc
+import torch
+
+llm_model_map = {
+    "llama2_7b": {
+        "initializer": stateless_llama.export_transformer_model,
+        "hf_model_name": "meta-llama/Llama-2-7b-chat-hf",
+        "stop_token": 2,
+        "max_tokens": 4096,
+    }
+}
+
+
+class LanguageModel:
+    def __init__(
+        self, model_name, hf_auth_token=None, device=None, precision="fp32"
+    ):
+        print(llm_model_map[model_name])
+        self.hf_model_name = llm_model_map[model_name]["hf_model_name"]
+        self.torch_ir, self.tokenizer = llm_model_map[model_name][
+            "initializer"
+        ](self.hf_model_name, hf_auth_token, compile_to="torch")
+        self.tempfile_name = get_resource_path("llm.torch.tempfile")
+        with open(self.tempfile_name, "w+") as f:
+            f.write(self.torch_ir)
+        del self.torch_ir
+        gc.collect()
+
+        self.device = device
+        self.precision = precision
+        self.max_tokens = llm_model_map[model_name]["max_tokens"]
+        self.iree_module_dict = None
+        self.compile()
+
+    def compile(self) -> None:
+        # this comes with keys: "vmfb", "config", and "temp_file_to_unlink".
+        self.iree_module_dict = get_iree_compiled_module(
+            self.tempfile_name, device=self.device, frontend="torch"
+        )
+        # TODO: delete the temp file
+
+    def chat(self, prompt):
+        history = []
+        for step in range(self.max_tokens):
+            input_tensor = self.tokenizer(
+                prompt, return_tensors="pt"
+            ).input_ids
+            device_inputs = [
+                ireert.asdevicearray(
+                    self.iree_module_dict["config"], input_tensor
+                )
+            ]
+            if step == 0:  # the first call initializes the compiled module with the prompt
+                token = torch.tensor(
+                    self.iree_module_dict["vmfb"]["run_initialize"](
+                        *device_inputs
+                    ).to_host()[0][0]
+                )
+            else:
+                token = torch.tensor(
+                    self.iree_module_dict["vmfb"]["run_forward"](
+                        *device_inputs
+                    ).to_host()[0][0]
+                )
+
+            history.append(token)
+            yield self.tokenizer.decode(history)  # stream the decoded text so far
+
+            if token == llm_model_map["llama2_7b"]["stop_token"]:
+                break
+
+        for i in range(len(history)):
+            if not isinstance(history[i], int):
+                history[i] = int(history[i])
+        result_output = self.tokenizer.decode(history)
+        yield result_output
+
+
+if __name__ == "__main__":
+    lm = LanguageModel(
+        "llama2_7b",
+        hf_auth_token=None,  # supply your own Hugging Face token; never commit real tokens
+        device="cpu-task",
+    )
+    print("model loaded")
+    for i in lm.chat("Hello, I am a robot."):
+        print(i)
diff --git a/apps/shark_studio/api/utils.py b/apps/shark_studio/api/utils.py
new file mode 100644
index 0000000000..bb5e150364
--- /dev/null
+++ b/apps/shark_studio/api/utils.py
@@ -0,0 +1,14 @@
+import os
+import sys
+
+
+def get_available_devices():
+    return ["cpu-task"]
+
+
+def get_resource_path(relative_path):
+    """Get absolute path to resource, works for dev and for PyInstaller"""
+    base_path = getattr(
+        sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))
+    )
+    return os.path.join(base_path, relative_path)
diff --git a/apps/shark_studio/web/index.py b/apps/shark_studio/web/index.py
new file mode 100644
index 0000000000..59b66bee23
--- /dev/null
+++ b/apps/shark_studio/web/index.py
@@ -0,0 +1,428 @@
+from multiprocessing import Process, freeze_support
+import os
+import sys
+import logging
+from ui.chat import chat_element
+
+if sys.platform == "darwin":
+    os.environ["DYLD_LIBRARY_PATH"] = "/usr/local/lib"
+    # import before IREE to avoid MLIR library issues
+    import torch_mlir
+
+# import PIL, transformers, sentencepiece  # ensures inclusion in pyinstaller exe generation
+# from apps.stable_diffusion.src import args, clear_all
+# import apps.stable_diffusion.web.utils.global_obj as global_obj
+
+
+def launch_app(address):
+    from tkinter import Tk
+    import webview
+
+    window = Tk()
+
+    # get the screen width and height of the display and make the window more
+    # reasonably sized, since we aren't making it full-screen or maximized
+    width = int(window.winfo_screenwidth() * 0.81)
+    height = int(window.winfo_screenheight() * 0.91)
+    webview.create_window(
+        "SHARK AI Studio",
+        url=address,
+        width=width,
+        height=height,
+        text_select=True,
+    )
+    webview.start(private_mode=False, storage_path=os.getcwd())
+
+
+if __name__ == "__main__":
+    # if args.debug:
+    logging.basicConfig(level=logging.DEBUG)
+    # required to do multiprocessing in a pyinstaller freeze
+    freeze_support()
+    # if args.api or "api" in args.ui.split(","):
+    #     from apps.stable_diffusion.web.ui import (
+    #         txt2img_api,
+    #         img2img_api,
+    #         upscaler_api,
+    #         inpaint_api,
+    #         outpaint_api,
+    #         llm_chat_api,
+    #     )
+    #
+    #     from fastapi import FastAPI, APIRouter
+    #     import uvicorn
+    #
+    #     # init global sd pipeline and config
+    #     global_obj._init()
+    #
+    #     app = FastAPI()
+    #     app.add_api_route("/sdapi/v1/txt2img", txt2img_api, methods=["post"])
+    #     app.add_api_route("/sdapi/v1/img2img", img2img_api, methods=["post"])
+    #     app.add_api_route("/sdapi/v1/inpaint", inpaint_api, methods=["post"])
+    #     app.add_api_route("/sdapi/v1/outpaint", outpaint_api, methods=["post"])
+    #     app.add_api_route("/sdapi/v1/upscaler", upscaler_api, methods=["post"])
+    #
+    #     # chat APIs needed for compatibility with multiple extensions using OpenAI API
+    #     app.add_api_route(
+    #         "/v1/chat/completions", llm_chat_api, methods=["post"]
+    #     )
+    #     app.add_api_route("/v1/completions", llm_chat_api, methods=["post"])
+    #     app.add_api_route("/chat/completions", llm_chat_api, methods=["post"])
+    #     app.add_api_route("/completions", llm_chat_api, methods=["post"])
+    #     app.add_api_route(
+    #         "/v1/engines/codegen/completions", llm_chat_api, methods=["post"]
+    #     )
+    #     app.include_router(APIRouter())
+    #     uvicorn.run(app, host="0.0.0.0", port=args.server_port)
+    #     sys.exit(0)
+    #
+    # Setup to use shark_tmp for gradio's temporary image files and clear any
+    # existing temporary images there if they exist. Then we can import gradio.
+    # It has to be in this order or gradio ignores what we've set up.
+ # from apps.stable_diffusion.web.utils.gradio_configs import ( + # config_gradio_tmp_imgs_folder, + # ) + + # config_gradio_tmp_imgs_folder() + import gradio as gr + + # Create custom models folders if they don't exist + # from apps.stable_diffusion.web.ui.utils import create_custom_models_folders + + # create_custom_models_folders() + + def resource_path(relative_path): + """Get absolute path to resource, works for dev and for PyInstaller""" + base_path = getattr( + sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__)) + ) + return os.path.join(base_path, relative_path) + + dark_theme = resource_path("ui/css/sd_dark_theme.css") + + # from apps.stable_diffusion.web.ui import ( + # txt2img_web, + # txt2img_custom_model, + # txt2img_gallery, + # txt2img_png_info_img, + # txt2img_status, + # txt2img_sendto_img2img, + # txt2img_sendto_inpaint, + # txt2img_sendto_outpaint, + # txt2img_sendto_upscaler, + ## h2ogpt_upload, + ## h2ogpt_web, + # img2img_web, + # img2img_custom_model, + # img2img_gallery, + # img2img_init_image, + # img2img_status, + # img2img_sendto_inpaint, + # img2img_sendto_outpaint, + # img2img_sendto_upscaler, + # inpaint_web, + # inpaint_custom_model, + # inpaint_gallery, + # inpaint_init_image, + # inpaint_status, + # inpaint_sendto_img2img, + # inpaint_sendto_outpaint, + # inpaint_sendto_upscaler, + # outpaint_web, + # outpaint_custom_model, + # outpaint_gallery, + # outpaint_init_image, + # outpaint_status, + # outpaint_sendto_img2img, + # outpaint_sendto_inpaint, + # outpaint_sendto_upscaler, + # upscaler_web, + # upscaler_custom_model, + # upscaler_gallery, + # upscaler_init_image, + # upscaler_status, + # upscaler_sendto_img2img, + # upscaler_sendto_inpaint, + # upscaler_sendto_outpaint, + ## lora_train_web, + ## model_web, + ## model_config_web, + # hf_models, + # modelmanager_sendto_txt2img, + # modelmanager_sendto_img2img, + # modelmanager_sendto_inpaint, + # modelmanager_sendto_outpaint, + # modelmanager_sendto_upscaler, + # stablelm_chat, + # minigpt4_web, + # outputgallery_web, + # outputgallery_tab_select, + # outputgallery_watch, + # outputgallery_filename, + # outputgallery_sendto_txt2img, + # outputgallery_sendto_img2img, + # outputgallery_sendto_inpaint, + # outputgallery_sendto_outpaint, + # outputgallery_sendto_upscaler, + # ) + + # init global sd pipeline and config + # global_obj._init() + + def register_button_click(button, selectedid, inputs, outputs): + button.click( + lambda x: ( + x[0]["name"] if len(x) != 0 else None, + gr.Tabs.update(selected=selectedid), + ), + inputs, + outputs, + ) + + def register_modelmanager_button(button, selectedid, inputs, outputs): + button.click( + lambda x: ( + "None", + x, + gr.Tabs.update(selected=selectedid), + ), + inputs, + outputs, + ) + + def register_outputgallery_button(button, selectedid, inputs, outputs): + button.click( + lambda x: ( + x, + gr.Tabs.update(selected=selectedid), + ), + inputs, + outputs, + ) + + with gr.Blocks( + css=dark_theme, analytics_enabled=False, title="Stable Diffusion" + ) as sd_web: + with gr.Tabs() as tabs: + # NOTE: If adding, removing, or re-ordering tabs, make sure that they + # have a unique id that doesn't clash with any of the other tabs, + # and that the order in the code here is the order they should + # appear in the ui, as the id value doesn't determine the order. + + # Where possible, avoid changing the id of any tab that is the + # destination of one of the 'send to' buttons. 
If you do have to change + # that id, make sure you update the relevant register_button_click calls + # further down with the new id. + # with gr.TabItem(label="Text-to-Image", id=0): + # txt2img_web.render() + # with gr.TabItem(label="Image-to-Image", id=1): + # img2img_web.render() + # with gr.TabItem(label="Inpainting", id=2): + # inpaint_web.render() + # with gr.TabItem(label="Outpainting", id=3): + # outpaint_web.render() + # with gr.TabItem(label="Upscaler", id=4): + # upscaler_web.render() + # if args.output_gallery: + # with gr.TabItem(label="Output Gallery", id=5) as og_tab: + # outputgallery_web.render() + + # # extra output gallery configuration + # outputgallery_tab_select(og_tab.select) + # outputgallery_watch( + # [ + # txt2img_status, + # img2img_status, + # inpaint_status, + # outpaint_status, + # upscaler_status, + # ] + # ) + ## with gr.TabItem(label="Model Manager", id=6): + ## model_web.render() + ## with gr.TabItem(label="LoRA Training (Experimental)", id=7): + ## lora_train_web.render() + with gr.TabItem(label="Chat Bot", id=0): + chat_element.render() + ## with gr.TabItem( + ## label="Generate Sharding Config (Experimental)", id=9 + ## ): + ## model_config_web.render() + # with gr.TabItem(label="MultiModal (Experimental)", id=10): + # minigpt4_web.render() + # with gr.TabItem(label="DocuChat Upload", id=11): + # h2ogpt_upload.render() + # with gr.TabItem(label="DocuChat(Experimental)", id=12): + # h2ogpt_web.render() + + # send to buttons + # register_button_click( + # txt2img_sendto_img2img, + # 1, + # [txt2img_gallery], + # [img2img_init_image, tabs], + # ) + # register_button_click( + # txt2img_sendto_inpaint, + # 2, + # [txt2img_gallery], + # [inpaint_init_image, tabs], + # ) + # register_button_click( + # txt2img_sendto_outpaint, + # 3, + # [txt2img_gallery], + # [outpaint_init_image, tabs], + # ) + # register_button_click( + # txt2img_sendto_upscaler, + # 4, + # [txt2img_gallery], + # [upscaler_init_image, tabs], + # ) + # register_button_click( + # img2img_sendto_inpaint, + # 2, + # [img2img_gallery], + # [inpaint_init_image, tabs], + # ) + # register_button_click( + # img2img_sendto_outpaint, + # 3, + # [img2img_gallery], + # [outpaint_init_image, tabs], + # ) + # register_button_click( + # img2img_sendto_upscaler, + # 4, + # [img2img_gallery], + # [upscaler_init_image, tabs], + # ) + # register_button_click( + # inpaint_sendto_img2img, + # 1, + # [inpaint_gallery], + # [img2img_init_image, tabs], + # ) + # register_button_click( + # inpaint_sendto_outpaint, + # 3, + # [inpaint_gallery], + # [outpaint_init_image, tabs], + # ) + # register_button_click( + # inpaint_sendto_upscaler, + # 4, + # [inpaint_gallery], + # [upscaler_init_image, tabs], + # ) + # register_button_click( + # outpaint_sendto_img2img, + # 1, + # [outpaint_gallery], + # [img2img_init_image, tabs], + # ) + # register_button_click( + # outpaint_sendto_inpaint, + # 2, + # [outpaint_gallery], + # [inpaint_init_image, tabs], + # ) + # register_button_click( + # outpaint_sendto_upscaler, + # 4, + # [outpaint_gallery], + # [upscaler_init_image, tabs], + # ) + # register_button_click( + # upscaler_sendto_img2img, + # 1, + # [upscaler_gallery], + # [img2img_init_image, tabs], + # ) + # register_button_click( + # upscaler_sendto_inpaint, + # 2, + # [upscaler_gallery], + # [inpaint_init_image, tabs], + # ) + # register_button_click( + # upscaler_sendto_outpaint, + # 3, + # [upscaler_gallery], + # [outpaint_init_image, tabs], + # ) + # if args.output_gallery: + # register_outputgallery_button( + # 
outputgallery_sendto_txt2img, + # 0, + # [outputgallery_filename], + # [txt2img_png_info_img, tabs], + # ) + # register_outputgallery_button( + # outputgallery_sendto_img2img, + # 1, + # [outputgallery_filename], + # [img2img_init_image, tabs], + # ) + # register_outputgallery_button( + # outputgallery_sendto_inpaint, + # 2, + # [outputgallery_filename], + # [inpaint_init_image, tabs], + # ) + # register_outputgallery_button( + # outputgallery_sendto_outpaint, + # 3, + # [outputgallery_filename], + # [outpaint_init_image, tabs], + # ) + # register_outputgallery_button( + # outputgallery_sendto_upscaler, + # 4, + # [outputgallery_filename], + # [upscaler_init_image, tabs], + # ) + # register_modelmanager_button( + # modelmanager_sendto_txt2img, + # 0, + # [hf_models], + # [txt2img_custom_model, tabs], + # ) + # register_modelmanager_button( + # modelmanager_sendto_img2img, + # 1, + # [hf_models], + # [img2img_custom_model, tabs], + # ) + # register_modelmanager_button( + # modelmanager_sendto_inpaint, + # 2, + # [hf_models], + # [inpaint_custom_model, tabs], + # ) + # register_modelmanager_button( + # modelmanager_sendto_outpaint, + # 3, + # [hf_models], + # [outpaint_custom_model, tabs], + # ) + # register_modelmanager_button( + # modelmanager_sendto_upscaler, + # 4, + # [hf_models], + # [upscaler_custom_model, tabs], + # ) + + sd_web.queue() + # if args.ui == "app": + # t = Process( + # target=launch_app, args=[f"http://localhost:{args.server_port}"] + # ) + # t.start() + sd_web.launch( + share=True, + inbrowser=True, + server_name="0.0.0.0", + server_port=11911, # args.server_port, + ) diff --git a/apps/shark_studio/web/ui/__init__.py b/apps/shark_studio/web/ui/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/apps/shark_studio/web/ui/chat.py b/apps/shark_studio/web/ui/chat.py new file mode 100644 index 0000000000..dd1c2d94e3 --- /dev/null +++ b/apps/shark_studio/web/ui/chat.py @@ -0,0 +1,517 @@ +import gradio as gr +import os +from pathlib import Path +from datetime import datetime as dt +import json +import sys +from apps.shark_studio.api.utils import ( + get_available_devices, +) +from apps.shark_studio.api.llm import ( + llm_model_map, + LanguageModel, +) + + +def user(message, history): + # Append the user's message to the conversation history + return "", history + [[message, ""]] + + +language_model = None + + +# NOTE: Each `model_name` should have its own start message +start_message = { + "llama2_7b": ( + "You are a helpful, respectful and honest assistant. Always answer " + "as helpfully as possible, while being safe. Your answers should not " + "include any harmful, unethical, racist, sexist, toxic, dangerous, or " + "illegal content. Please ensure that your responses are socially " + "unbiased and positive in nature. If a question does not make any " + "sense, or is not factually coherent, explain why instead of " + "answering something not correct. If you don't know the answer " + "to a question, please don't share false information." + ), + "llama2_13b": ( + "You are a helpful, respectful and honest assistant. Always answer " + "as helpfully as possible, while being safe. Your answers should not " + "include any harmful, unethical, racist, sexist, toxic, dangerous, or " + "illegal content. Please ensure that your responses are socially " + "unbiased and positive in nature. If a question does not make any " + "sense, or is not factually coherent, explain why instead of " + "answering something not correct. 
If you don't know the answer " + "to a question, please don't share false information." + ), + "llama2_70b": ( + "You are a helpful, respectful and honest assistant. Always answer " + "as helpfully as possible, while being safe. Your answers should not " + "include any harmful, unethical, racist, sexist, toxic, dangerous, or " + "illegal content. Please ensure that your responses are socially " + "unbiased and positive in nature. If a question does not make any " + "sense, or is not factually coherent, explain why instead of " + "answering something not correct. If you don't know the answer " + "to a question, please don't share false information." + ), + "vicuna": ( + "A chat between a curious user and an artificial intelligence " + "assistant. The assistant gives helpful, detailed, and " + "polite answers to the user's questions.\n" + ), +} + + +def create_prompt(model_name, history, prompt_prefix): + return "" + system_message = "" + if prompt_prefix: + system_message = start_message[model_name] + + if "llama2" in model_name: + B_INST, E_INST = "[INST]", "[/INST]" + B_SYS, E_SYS = "<>\n", "\n<>\n\n" + conversation = "".join( + [f"{B_INST} {item[0]} {E_INST} {item[1]} " for item in history[1:]] + ) + if prompt_prefix: + msg = f"{B_INST} {B_SYS}{system_message}{E_SYS}{history[0][0]} {E_INST} {history[0][1]} {conversation}" + else: + msg = f"{B_INST} {history[0][0]} {E_INST} {history[0][1]} {conversation}" + elif model_name in ["vicuna"]: + conversation = "".join( + [ + "".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]]) + for item in history + ] + ) + msg = system_message + conversation + msg = msg.strip() + else: + conversation = "".join( + ["".join([item[0], item[1]]) for item in history] + ) + msg = system_message + conversation + msg = msg.strip() + return msg + + +def get_default_config(): + return False + import torch + from transformers import AutoTokenizer + + hf_model_path = "TheBloke/vicuna-7B-1.1-HF" + tokenizer = AutoTokenizer.from_pretrained(hf_model_path, use_fast=False) + compilation_prompt = "".join(["0" for _ in range(17)]) + compilation_input_ids = tokenizer( + compilation_prompt, + return_tensors="pt", + ).input_ids + compilation_input_ids = torch.tensor(compilation_input_ids).reshape( + [1, 19] + ) + firstVicunaCompileInput = (compilation_input_ids,) + from apps.language_models.src.model_wrappers.vicuna_model import ( + CombinedModel, + ) + from shark.shark_generate_model_config import GenerateConfigFile + + model = CombinedModel() + c = GenerateConfigFile(model, 1, ["gpu_id"], firstVicunaCompileInput) + c.split_into_layers() + + +# model_vmfb_key = "" + + +def chat_fn( + prompt_prefix, + history, + model, + device, + precision, + download_vmfb, + config_file, + cli=False, + progress=gr.Progress(), +): + global language_model + if language_model is None: + language_model = LanguageModel( + model, device=device, precision=precision + ) + + language_model.chat(prompt_prefix) + return "", "" + global past_key_values + global model_vmfb_key + + device_id = None + model_name, model_path = list(map(str.strip, model.split("=>"))) + if "cuda" in device: + device = "cuda" + elif "sync" in device: + device = "cpu-sync" + elif "task" in device: + device = "cpu-task" + elif "vulkan" in device: + device_id = int(device.split("://")[1]) + device = "vulkan" + elif "rocm" in device: + device = "rocm" + else: + print("unrecognized device") + + from apps.language_models.scripts.vicuna import ShardedVicuna + from apps.language_models.scripts.vicuna import UnshardedVicuna + from 
apps.stable_diffusion.src import args + + new_model_vmfb_key = f"{model_name}#{model_path}#{device}#{device_id}#{precision}#{download_vmfb}" + if vicuna_model is None or new_model_vmfb_key != model_vmfb_key: + model_vmfb_key = new_model_vmfb_key + max_toks = 128 if model_name == "codegen" else 512 + + # get iree flags that need to be overridden, from commandline args + _extra_args = [] + # vulkan target triple + vulkan_target_triple = args.iree_vulkan_target_triple + from shark.iree_utils.vulkan_utils import ( + get_all_vulkan_devices, + get_vulkan_target_triple, + ) + + if device == "vulkan": + vulkaninfo_list = get_all_vulkan_devices() + if vulkan_target_triple == "": + # We already have the device_id extracted via WebUI, so we directly use + # that to find the target triple. + vulkan_target_triple = get_vulkan_target_triple( + vulkaninfo_list[device_id] + ) + _extra_args.append( + f"-iree-vulkan-target-triple={vulkan_target_triple}" + ) + if "rdna" in vulkan_target_triple: + flags_to_add = [ + "--iree-spirv-index-bits=64", + ] + _extra_args = _extra_args + flags_to_add + + if device_id is None: + id = 0 + for device in vulkaninfo_list: + target_triple = get_vulkan_target_triple( + vulkaninfo_list[id] + ) + if target_triple == vulkan_target_triple: + device_id = id + break + id += 1 + + assert ( + device_id + ), f"no vulkan hardware for target-triple '{vulkan_target_triple}' exists" + print(f"Will use vulkan target triple : {vulkan_target_triple}") + + elif "rocm" in device: + # add iree rocm flags + _extra_args.append( + f"--iree-rocm-target-chip={args.iree_rocm_target_chip}" + ) + print(f"extra args = {_extra_args}") + + if model_name == "vicuna4": + vicuna_model = ShardedVicuna( + model_name, + hf_model_path=model_path, + device=device, + precision=precision, + max_num_tokens=max_toks, + compressed=True, + extra_args_cmd=_extra_args, + ) + else: + # if config_file is None: + vicuna_model = UnshardedVicuna( + model_name, + hf_model_path=model_path, + hf_auth_token=args.hf_auth_token, + device=device, + vulkan_target_triple=vulkan_target_triple, + precision=precision, + max_num_tokens=max_toks, + download_vmfb=download_vmfb, + load_mlir_from_shark_tank=True, + extra_args_cmd=_extra_args, + device_id=device_id, + ) + + if vicuna_model is None: + sys.exit("Unable to instantiate the model object, exiting.") + + prompt = create_prompt(model_name, history, prompt_prefix) + + partial_text = "" + token_count = 0 + total_time_ms = 0.001 # In order to avoid divide by zero error + prefill_time = 0 + is_first = True + for text, msg, exec_time in progress.tqdm( + vicuna_model.generate(prompt, cli=cli), + desc="generating response", + ): + if msg is None: + if is_first: + prefill_time = exec_time + is_first = False + else: + total_time_ms += exec_time + token_count += 1 + partial_text += text + " " + history[-1][1] = partial_text + yield history, f"Prefill: {prefill_time:.2f}" + elif "formatted" in msg: + history[-1][1] = text + tokens_per_sec = (token_count / total_time_ms) * 1000 + yield history, f"Prefill: {prefill_time:.2f} seconds\n Decode: {tokens_per_sec:.2f} tokens/sec" + else: + sys.exit( + "unexpected message from the vicuna generate call, exiting." 
+            )
+
+    return history, ""
+
+
+def llm_chat_api(InputData: dict):
+    return None  # short-circuit: the legacy vicuna-based API path below is currently disabled
+    print(f"Input keys : {InputData.keys()}")
+    # print(f"model : {InputData['model']}")
+    is_chat_completion_api = (
+        "messages" in InputData.keys()
+    )  # else it is the legacy `completion` api
+    # For debugging input data from the API
+    # if is_chat_completion_api:
+    #     print(f"message -> role : {InputData['messages'][0]['role']}")
+    #     print(f"message -> content : {InputData['messages'][0]['content']}")
+    # else:
+    #     print(f"prompt : {InputData['prompt']}")
+    #     print(f"max_tokens : {InputData['max_tokens']}")  # Default to 128 for now
+    global vicuna_model
+    model_name = (
+        InputData["model"] if "model" in InputData.keys() else "codegen"
+    )
+    model_path = llm_model_map[model_name]
+    device = "cpu-task"
+    precision = "fp16"
+    max_toks = (
+        None
+        if "max_tokens" not in InputData.keys()
+        else InputData["max_tokens"]
+    )
+    if max_toks is None:
+        max_toks = 128 if model_name == "codegen" else 512
+
+    # make it work for codegen first
+    from apps.language_models.scripts.vicuna import (
+        UnshardedVicuna,
+    )
+
+    device_id = None
+    if vicuna_model is None:
+        if "cuda" in device:
+            device = "cuda"
+        elif "sync" in device:
+            device = "cpu-sync"
+        elif "task" in device:
+            device = "cpu-task"
+        elif "vulkan" in device:
+            device_id = int(device.split("://")[1])
+            device = "vulkan"
+        else:
+            print("unrecognized device")
+
+        vicuna_model = UnshardedVicuna(
+            model_name,
+            hf_model_path=model_path,
+            device=device,
+            precision=precision,
+            max_num_tokens=max_toks,
+            download_vmfb=True,
+            load_mlir_from_shark_tank=True,
+            device_id=device_id,
+        )
+
+    # TODO: add role dict for different models
+    if is_chat_completion_api:
+        # TODO: add functionality for multiple messages
+        prompt = create_prompt(
+            model_name,
+            [(InputData["messages"][0]["content"], "")],
+            prompt_prefix=False,
+        )
+    else:
+        prompt = InputData["prompt"]
+    print("prompt = ", prompt)
+
+    res = vicuna_model.generate(prompt)
+    res_op = None
+    for op in res:
+        res_op = op
+
+    if is_chat_completion_api:
+        choices = [
+            {
+                "index": 0,
+                "message": {
+                    "role": "assistant",
+                    "content": res_op,  # since we are yielding the result
+                },
+                "finish_reason": "stop",  # or length
+            }
+        ]
+    else:
+        choices = [
+            {
+                "text": res_op,
+                "index": 0,
+                "logprobs": None,
+                "finish_reason": "stop",  # or length
+            }
+        ]
+    end_time = dt.now().strftime("%Y%m%d%H%M%S%f")
+    return {
+        "id": end_time,
+        "object": "chat.completion"
+        if is_chat_completion_api
+        else "text_completion",
+        "created": int(end_time),
+        "choices": choices,
+    }
+
+
+def view_json_file(file_obj):
+    content = ""
+    with open(file_obj.name, "r") as fopen:
+        content = fopen.read()
+    return content
+
+
+with gr.Blocks(title="Chat") as chat_element:
+    with gr.Row():
+        model_choices = list(llm_model_map.keys())
+        model = gr.Dropdown(
+            label="Select Model",
+            value=model_choices[0],
+            choices=model_choices,
+            allow_custom_value=True,
+        )
+        supported_devices = get_available_devices()
+        enabled = True
+        if len(supported_devices) == 0:
+            supported_devices = ["cpu-task"]
+        supported_devices = [x for x in supported_devices if "sync" not in x]
+        device = gr.Dropdown(
+            label="Device",
+            value=supported_devices[0],
+            choices=supported_devices,
+            interactive=enabled,
+            allow_custom_value=True,
+        )
+        precision = gr.Radio(
+            label="Precision",
+            value="fp32",  # keep the default in sync with the enabled choices
+            choices=[
+                # "int4",
+                # "int8",
+                # "fp16",
+                "fp32",
+            ],
+            visible=False,
+        )
+        tokens_time = gr.Textbox(label="Tokens generated per second")
+        with gr.Column():
+            download_vmfb = gr.Checkbox(
label="Download vmfb from Shark tank if available", + value=True, + interactive=True, + ) + prompt_prefix = gr.Checkbox( + label="Add System Prompt", + value=False, + interactive=True, + ) + + chatbot = gr.Chatbot(height=500) + with gr.Row(): + with gr.Column(): + msg = gr.Textbox( + label="Chat Message Box", + placeholder="Chat Message Box", + show_label=False, + interactive=enabled, + container=False, + ) + with gr.Column(): + with gr.Row(): + submit = gr.Button("Submit", interactive=enabled) + stop = gr.Button("Stop", interactive=enabled) + clear = gr.Button("Clear", interactive=enabled) + + with gr.Row(visible=False): + with gr.Group(): + config_file = gr.File( + label="Upload sharding configuration", visible=False + ) + json_view_button = gr.Button(label="View as JSON", visible=False) + json_view = gr.JSON(interactive=True, visible=False) + json_view_button.click( + fn=view_json_file, inputs=[config_file], outputs=[json_view] + ) + submit_event = msg.submit( + fn=user, + inputs=[msg, chatbot], + outputs=[msg, chatbot], + show_progress=False, + queue=False, + ).then( + fn=chat_fn, + inputs=[ + prompt_prefix, + chatbot, + model, + device, + precision, + download_vmfb, + config_file, + ], + outputs=[chatbot, tokens_time], + show_progress=False, + queue=True, + ) + submit_click_event = submit.click( + fn=user, + inputs=[msg, chatbot], + outputs=[msg, chatbot], + show_progress=False, + queue=False, + ).then( + fn=chat_fn, + inputs=[ + prompt_prefix, + chatbot, + model, + device, + precision, + download_vmfb, + config_file, + ], + outputs=[chatbot, tokens_time], + show_progress=False, + queue=True, + ) + stop.click( + fn=None, + inputs=None, + outputs=None, + cancels=[submit_event, submit_click_event], + queue=False, + ) + clear.click(lambda: None, None, [chatbot], queue=False) diff --git a/requirements.txt b/requirements.txt index cdce91e542..f43adae8fa 100644 --- a/requirements.txt +++ b/requirements.txt @@ -17,6 +17,7 @@ pytest-forked Pillow parameterized +#shark-turbine @ git+https://github.com/nod-ai/SHARK-Turbine.git@main # Add transformers, diffusers and scipy since it most commonly used tokenizers==0.13.3 transformers @@ -49,4 +50,4 @@ pefile pyinstaller # vicuna quantization -brevitas @ git+https://github.com/Xilinx/brevitas.git@56edf56a3115d5ac04f19837b388fd7d3b1ff7ea +brevitas @ git+https://github.com/Xilinx/brevitas.git@56edf56a3115d5ac04f19837b388fd7d3b1ff7ea \ No newline at end of file diff --git a/shark/iree_utils/compile_utils.py b/shark/iree_utils/compile_utils.py index 623968f1b0..b5c4527827 100644 --- a/shark/iree_utils/compile_utils.py +++ b/shark/iree_utils/compile_utils.py @@ -16,7 +16,6 @@ import os import re import tempfile -import time from pathlib import Path import iree.runtime as ireert @@ -63,8 +62,7 @@ def get_iree_device_args(device, extra_args=[]): + data_tiling_flag + u_kernel_flag + stack_size_flag - + ["--iree-flow-enable-quantized-matmul-reassociation"] - + ["--iree-llvmcpu-enable-quantized-matmul-reassociation"] + + ["--iree-global-opt-enable-quantized-matmul-reassociation"] ) if device_uri[0] == "cuda": from shark.iree_utils.gpu_utils import get_iree_gpu_args @@ -321,6 +319,8 @@ def compile_module_to_flatbuffer( input_type = "tosa" elif frontend in ["tm_tensor"]: input_type = ireec.InputType.TM_TENSOR + elif frontend in ["torch", "pytorch"]: + input_type = "torch" if compile_str: flatbuffer_blob = ireec.compile_str(
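For reviewers who want to exercise the new `LanguageModel` API added in `apps/shark_studio/api/llm.py`, here is a minimal usage sketch (not part of the patch). It assumes a Hugging Face token with access to the gated `meta-llama/Llama-2-7b-chat-hf` weights, read here from a hypothetical `HF_AUTH_TOKEN` environment variable, and an IREE build that supports the `cpu-task` driver.

```python
# Minimal sketch of driving the new LanguageModel API; HF_AUTH_TOKEN is an
# illustrative environment variable, not something this patch introduces.
import os

from apps.shark_studio.api.llm import LanguageModel

lm = LanguageModel(
    "llama2_7b",
    hf_auth_token=os.environ.get("HF_AUTH_TOKEN"),
    device="cpu-task",
    precision="fp32",
)

# chat() is a generator: it yields the decoded text after each generated
# token, then yields the full decoded history once generation stops.
for partial_text in lm.chat("What is the capital of Canada?"):
    print(partial_text)
```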