diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml
index 1c68e12240..21af8ff068 100644
--- a/.github/workflows/nightly.yml
+++ b/.github/workflows/nightly.yml
@@ -46,7 +46,7 @@ jobs:
           draft: true
           prerelease: true
 
-      - name: Build Package
+      - name: Build Package (api only)
         shell: powershell
         run: |
           ./setup_venv.ps1
@@ -54,10 +54,10 @@ jobs:
          $env:SHARK_PACKAGE_VERSION=${{ env.package_version }}
          pip install -e .
          pip freeze -l
-          pyinstaller .\apps\shark_studio\shark_studio.spec
-          mv ./dist/nodai_shark_studio.exe ./dist/nodai_shark_studio_${{ env.package_version_ }}.exe
+          pyinstaller .\apps\shark_studio\shark_studio_apionly.spec
+          mv ./dist/shark_sd3_server.exe ./dist/nodai_shark_studio_${{ env.package_version_ }}.exe
          signtool sign /f c:\g\shark_02152023.cer /fd certHash /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/nodai_shark_studio_${{ env.package_version_ }}.exe
-
+
       - name: Upload Release Assets
         id: upload-release-assets
         uses: dwenegar/upload-release-assets@v1
diff --git a/apps/shark_studio/api/initializers.py b/apps/shark_studio/api/initializers.py
index c61e748c34..e05d43a5fc 100644
--- a/apps/shark_studio/api/initializers.py
+++ b/apps/shark_studio/api/initializers.py
@@ -34,9 +34,9 @@ def imports():
         action="ignore", category=UserWarning, module="huggingface-hub"
     )
 
-    import gradio  # noqa: F401
+    # import gradio  # noqa: F401
 
-    startup_timer.record("import gradio")
+    # startup_timer.record("import gradio")
 
     import apps.shark_studio.web.utils.globals as global_obj
 
@@ -56,9 +56,8 @@ def initialize():
     # existing temporary images there if they exist. Then we can import gradio.
     # It has to be in this order or gradio ignores what we've set up.
 
-    config_tmp()
-    # clear_tmp_mlir()
-    clear_tmp_imgs()
+    # config_tmp()
+    # clear_tmp_imgs()
 
     from apps.shark_studio.web.utils.file_utils import (
         create_model_folders,
@@ -67,8 +66,6 @@ def initialize():
     # Create custom models folders if they don't exist
     create_model_folders()
 
-    import gradio as gr
-
     # initialize_rest(reload_script_modules=False)
diff --git a/apps/shark_studio/api/sd.py b/apps/shark_studio/api/sd.py
index 502b290578..51cd9b7113 100644
--- a/apps/shark_studio/api/sd.py
+++ b/apps/shark_studio/api/sd.py
@@ -14,7 +14,6 @@
 from random import randint
 
-from apps.shark_studio.api.controlnet import control_adapter_map
 from apps.shark_studio.api.utils import parse_device
 from apps.shark_studio.web.utils.state import status_label
@@ -30,6 +29,7 @@
 from subprocess import check_output
 
+
 EMPTY_SD_MAP = {
     "clip": None,
     "scheduler": None,
@@ -114,11 +114,14 @@ def __init__(
             from turbine_models.custom_models.sdxl_inference.sdxl_compiled_pipeline import (
                 SharkSDXLPipeline,
             )
+
             self.turbine_pipe = SharkSDXLPipeline
             self.dynamic_steps = False
             self.model_map = EMPTY_SDXL_MAP
         else:
-            from turbine_models.custom_models.sd_inference.sd_pipeline import SharkSDPipeline
+            from turbine_models.custom_models.sd_inference.sd_pipeline import (
+                SharkSDPipeline,
+            )
 
             self.turbine_pipe = SharkSDPipeline
             self.dynamic_steps = True
@@ -209,6 +212,7 @@ def prepare_pipe(
             preprocessCKPT,
             save_irpa,
         )
+
         custom_weights = os.path.join(
             get_checkpoints_path("checkpoints"),
             safe_name(self.base_model_id.split("/")[-1]),
@@ -223,14 +227,20 @@ def prepare_pipe(
                     "diffusion_pytorch_model.safetensors",
                 )
                 weights[key] = save_irpa(unet_weights_path, "unet.")
-
-            elif key in ["clip", "prompt_encoder"]:
-                if not self.is_sdxl:
+            if key in ["mmdit"]:
+                mmdit_weights_path = os.path.join(
+                    diffusers_weights_path,
+                    "mmdit",
+                    "diffusion_pytorch_model_fp16.safetensors",
+                )
+                weights[key] = save_irpa(mmdit_weights_path, "mmdit.")
+            elif key in ["clip", "prompt_encoder", "text_encoder"]:
+                if not self.is_sdxl and not self.is_custom:
                     sd1_path = os.path.join(
                         diffusers_weights_path, "text_encoder", "model.safetensors"
                     )
                     weights[key] = save_irpa(sd1_path, "text_encoder_model.")
-                else:
+                elif self.is_sdxl:
                     clip_1_path = os.path.join(
                         diffusers_weights_path, "text_encoder", "model.safetensors"
                     )
@@ -243,7 +253,27 @@ def prepare_pipe(
                         save_irpa(clip_1_path, "text_encoder_model_1."),
                         save_irpa(clip_2_path, "text_encoder_model_2."),
                     ]
-
+                elif self.is_custom:
+                    clip_g_path = os.path.join(
+                        diffusers_weights_path,
+                        "text_encoder",
+                        "model.fp16.safetensors",
+                    )
+                    clip_l_path = os.path.join(
+                        diffusers_weights_path,
+                        "text_encoder_2",
+                        "model.fp16.safetensors",
+                    )
+                    t5xxl_path = os.path.join(
+                        diffusers_weights_path,
+                        "text_encoder_3",
+                        "model.fp16.safetensors",
+                    )
+                    weights[key] = [
+                        save_irpa(clip_g_path, "clip_g.transformer."),
+                        save_irpa(clip_l_path, "clip_l.transformer."),
+                        save_irpa(t5xxl_path, "t5xxl.transformer."),
+                    ]
             elif key in ["vae_decode"] and weights[key] is None:
                 vae_weights_path = os.path.join(
                     diffusers_weights_path,
@@ -251,6 +281,7 @@ def prepare_pipe(
                     "diffusion_pytorch_model.safetensors",
                 )
                 weights[key] = save_irpa(vae_weights_path, "vae.")
+
         progress(0.25, desc=f"Preparing pipeline for {self.ui_device}...")
 
         vmfbs, weights = self.sd_pipe.check_prepared(
@@ -291,49 +322,6 @@ def generate_images(
         return img
 
 
-def shark_sd_fn_dict_input(sd_kwargs: dict, *, progress=gr.Progress()):
-    print("\n[LOG] Submitting Request...")
-
-    for key in sd_kwargs:
-        if sd_kwargs[key] in [None, []]:
-            sd_kwargs[key] = None
-        if sd_kwargs[key] in ["None"]:
-            sd_kwargs[key] = ""
-        if key in ["steps", "height", "width", "batch_count", "batch_size"]:
-            sd_kwargs[key] = int(sd_kwargs[key])
-        if key == "seed":
-            sd_kwargs[key] = int(sd_kwargs[key])
-
-    # TODO: move these checks into the UI code so we don't have gradio warnings in a generalized dict input function.
-    if not sd_kwargs["device"]:
-        gr.Warning("No device specified. Please specify a device.")
-        return None, ""
-    if sd_kwargs["height"] not in [512, 1024]:
-        gr.Warning("Height must be 512 or 1024. This is a temporary limitation.")
-        return None, ""
-    if sd_kwargs["height"] != sd_kwargs["width"]:
-        gr.Warning("Height and width must be the same. This is a temporary limitation.")
-        return None, ""
-    if sd_kwargs["base_model_id"] == "stabilityai/sdxl-turbo":
-        if sd_kwargs["steps"] > 10:
-            gr.Warning("Max steps for sdxl-turbo is 10. 1 to 4 steps are recommended.")
-            return None, ""
-        if sd_kwargs["guidance_scale"] > 3:
-            gr.Warning(
-                "sdxl-turbo CFG scale should be less than 2.0 if using negative prompt, 0 otherwise."
-            )
-            return None, ""
-    if sd_kwargs["target_triple"] == "":
-        if not parse_device(sd_kwargs["device"], sd_kwargs["target_triple"])[2]:
-            gr.Warning(
-                "Target device architecture could not be inferred. Please specify a target triple, e.g. 'gfx1100' for a Radeon 7900xtx."
-            )
-            return None, ""
-
-    generated_imgs = yield from shark_sd_fn(**sd_kwargs)
-    return generated_imgs
-
-
 def shark_sd_fn(
     prompt,
     negative_prompt,
@@ -359,7 +347,8 @@ def shark_sd_fn(
     controlnets: dict,
     embeddings: dict,
     seed_increment: str | int = 1,
-    progress=gr.Progress(),
+    output_type: str = "png",
+    # progress=gr.Progress(),
 ):
     sd_kwargs = locals()
     if not isinstance(sd_init_image, list):
@@ -464,8 +453,8 @@ def shark_sd_fn(
     if submit_run_kwargs["seed"] in [-1, "-1"]:
         submit_run_kwargs["seed"] = randint(0, 4294967295)
         seed_increment = "random"
-        #print(f"\n[LOG] Random seed: {seed}")
-    progress(None, desc=f"Generating...")
+        # print(f"\n[LOG] Random seed: {seed}")
+    # progress(None, desc=f"Generating...")
 
     for current_batch in range(batch_count):
         start_time = time.time()
@@ -479,13 +468,14 @@ def shark_sd_fn(
         #         break
         #     else:
         for batch in range(batch_size):
-            save_output_img(
-                out_imgs[batch],
-                seed,
-                sd_kwargs,
-            )
+            if output_type == "png":
+                save_output_img(
+                    out_imgs[batch],
+                    seed,
+                    sd_kwargs,
+                )
         generated_imgs.extend(out_imgs)
-
+
         yield generated_imgs, status_label(
             "Stable Diffusion", current_batch + 1, batch_count, batch_size
         )
@@ -495,13 +485,56 @@ def shark_sd_fn(
     return (generated_imgs, "")
 
 
+def shark_sd_fn_dict_input(sd_kwargs: dict, *, progress=gr.Progress()):
+    print("\n[LOG] Submitting Request...")
+
+    for key in sd_kwargs:
+        if sd_kwargs[key] in [None, []]:
+            sd_kwargs[key] = None
+        if sd_kwargs[key] in ["None"]:
+            sd_kwargs[key] = ""
+        if key in ["steps", "height", "width", "batch_count", "batch_size"]:
+            sd_kwargs[key] = int(sd_kwargs[key])
+        if key == "seed":
+            sd_kwargs[key] = int(sd_kwargs[key])
+
+    # TODO: move these checks into the UI code so we don't have gradio warnings in a generalized dict input function.
+    if not sd_kwargs["device"]:
+        gr.Warning("No device specified. Please specify a device.")
+        return None, ""
+    if sd_kwargs["height"] not in [512, 1024]:
+        gr.Warning("Height must be 512 or 1024. This is a temporary limitation.")
+        return None, ""
+    if sd_kwargs["height"] != sd_kwargs["width"]:
+        gr.Warning("Height and width must be the same. This is a temporary limitation.")
+        return None, ""
+    if sd_kwargs["base_model_id"] == "stabilityai/sdxl-turbo":
+        if sd_kwargs["steps"] > 10:
+            gr.Warning("Max steps for sdxl-turbo is 10. 1 to 4 steps are recommended.")
+            return None, ""
+        if sd_kwargs["guidance_scale"] > 3:
+            gr.Warning(
+                "sdxl-turbo CFG scale should be less than 2.0 if using negative prompt, 0 otherwise."
+            )
+            return None, ""
+    if sd_kwargs["target_triple"] == "":
+        if not parse_device(sd_kwargs["device"], sd_kwargs["target_triple"])[2]:
+            gr.Warning(
+                "Target device architecture could not be inferred. Please specify a target triple, e.g. 'gfx1100' for a Radeon 7900xtx."
+            )
+            return None, ""
+
+    generated_imgs = yield from shark_sd_fn(**sd_kwargs)
+    return generated_imgs
+
+
 def get_next_seed(seed, seed_increment: str | int = 10):
     if isinstance(seed_increment, int):
-        #print(f"\n[LOG] Seed after batch increment: {seed + seed_increment}")
+        # print(f"\n[LOG] Seed after batch increment: {seed + seed_increment}")
         return int(seed + seed_increment)
     elif seed_increment == "random":
         seed = randint(0, 4294967295)
-        #print(f"\n[LOG] Random seed: {seed}")
+        # print(f"\n[LOG] Random seed: {seed}")
         return seed
diff --git a/apps/shark_studio/api/utils.py b/apps/shark_studio/api/utils.py
index 85a59ada35..b68ef5a5bb 100644
--- a/apps/shark_studio/api/utils.py
+++ b/apps/shark_studio/api/utils.py
@@ -63,9 +63,9 @@ def iree_target_map(device):
     }
 
-
 def get_available_devices():
-    return ['rocm', 'cpu']
+    return ["rocm", "cpu"]
+
 
 def get_devices_by_name(driver_name):
     device_list = []
@@ -94,7 +94,7 @@ def get_devices_by_name(driver_name):
         device_list.append(f"{device_name} => {driver_name}://{i}")
     return device_list
 
-    #set_iree_runtime_flags()
+    # set_iree_runtime_flags()
 
     available_devices = []
     rocm_devices = get_devices_by_name("rocm")
@@ -140,17 +140,14 @@ def get_devices_by_name(driver_name):
             break
     return available_devices
 
+
 def clean_device_info(raw_device):
     # return appropriate device and device_id for consumption by Studio pipeline
     # Multiple devices only supported for vulkan and rocm (as of now).
     # default device must be selected for all others
 
     device_id = None
-    device = (
-        raw_device
-        if "=>" not in raw_device
-        else raw_device.split("=>")[1].strip()
-    )
+    device = raw_device if "=>" not in raw_device else raw_device.split("=>")[1].strip()
     if "://" in device:
         device, device_id = device.split("://")
         if len(device_id) <= 2:
@@ -162,6 +159,7 @@ def clean_device_info(raw_device):
             device_id = 0
     return device, device_id
 
+
 def parse_device(device_str, target_override=""):
     rt_driver, device_id = clean_device_info(device_str)
@@ -287,4 +285,4 @@ def get_all_devices(driver_name):
 #         # Due to lack of support for multi-reduce, we always collapse reduction
 #         # dims before dispatch formation right now.
 #         iree_flags += ["--iree-flow-collapse-reduction-dims"]
-#     return iree_flags
\ No newline at end of file
+#     return iree_flags
diff --git a/apps/shark_studio/modules/shared_cmd_opts.py b/apps/shark_studio/modules/shared_cmd_opts.py
index fe77022068..e3dffca102 100644
--- a/apps/shark_studio/modules/shared_cmd_opts.py
+++ b/apps/shark_studio/modules/shared_cmd_opts.py
@@ -597,7 +597,7 @@ def is_valid_file(arg):
     "--defaults",
     default="sdxl-turbo.json",
     type=str,
-    help="Path to the default API request .json file. Works for CLI and webui."
+    help="Path to the default API request .json file. Works for CLI and webui.",
 )
 
 p.add_argument(
diff --git a/apps/shark_studio/shark_studio_apionly.spec b/apps/shark_studio/shark_studio_apionly.spec
new file mode 100644
index 0000000000..6e8b5a4f64
--- /dev/null
+++ b/apps/shark_studio/shark_studio_apionly.spec
@@ -0,0 +1,45 @@
+# -*- mode: python ; coding: utf-8 -*-
+from apps.shark_studio.studio_imports_apionly import pathex, datas, hiddenimports
+
+binaries = []
+
+block_cipher = None
+
+a = Analysis(
+    ['web/index.py'],
+    pathex=pathex,
+    binaries=binaries,
+    datas=datas,
+    hiddenimports=hiddenimports,
+    hookspath=[],
+    hooksconfig={},
+    runtime_hooks=[],
+    excludes=[],
+    win_no_prefer_redirects=False,
+    win_private_assemblies=False,
+    cipher=block_cipher,
+    noarchive=False,
+)
+pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)
+
+exe = EXE(
+    pyz,
+    a.scripts,
+    a.binaries,
+    a.zipfiles,
+    a.datas,
+    [],
+    name='shark_sd3_server',
+    debug=False,
+    bootloader_ignore_signals=False,
+    strip=False,
+    upx=False,
+    upx_exclude=[],
+    runtime_tmpdir=None,
+    console=True,
+    disable_windowed_traceback=False,
+    argv_emulation=False,
+    target_arch=None,
+    codesign_identity=None,
+    entitlements_file=None,
+)
diff --git a/apps/shark_studio/studio_imports.py b/apps/shark_studio/studio_imports.py
index 3f7aa319ba..37a6ffa496 100644
--- a/apps/shark_studio/studio_imports.py
+++ b/apps/shark_studio/studio_imports.py
@@ -22,30 +22,25 @@
 datas += copy_metadata("filelock")
 datas += copy_metadata("numpy")
 datas += copy_metadata("importlib_metadata")
-datas += copy_metadata("omegaconf")
 datas += copy_metadata("safetensors")
 datas += copy_metadata("Pillow")
 datas += copy_metadata("sentencepiece")
 datas += copy_metadata("pyyaml")
 datas += copy_metadata("huggingface-hub")
 datas += copy_metadata("gradio")
-datas += copy_metadata("scipy")
 datas += collect_data_files("torch")
 datas += collect_data_files("tokenizers")
-datas += collect_data_files("accelerate")
 datas += collect_data_files("diffusers")
 datas += collect_data_files("transformers")
 datas += collect_data_files("gradio")
 datas += collect_data_files("gradio_client")
 datas += collect_data_files("iree", include_py_files=True)
-datas += collect_data_files("shark", include_py_files=True)
+datas += collect_data_files("shark_turbine", include_py_files=True)
 datas += collect_data_files("tqdm")
-datas += collect_data_files("tkinter")
 datas += collect_data_files("sentencepiece")
 datas += collect_data_files("jsonschema")
 datas += collect_data_files("jsonschema_specifications")
 datas += collect_data_files("cpuinfo")
-datas += collect_data_files("scipy", include_py_files=True)
 datas += [
     ("web/ui/css/*", "ui/css"),
     ("web/ui/js/*", "ui/js"),
@@ -54,7 +49,7 @@
 
 # hidden imports for pyinstaller
-hiddenimports = ["shark", "apps"]
+hiddenimports = ["apps", "shark_turbine"]
 hiddenimports += [x for x in collect_submodules("gradio") if "tests" not in x]
 hiddenimports += [x for x in collect_submodules("diffusers") if "tests" not in x]
 blacklist = ["tests", "convert"]
@@ -65,4 +60,3 @@
 ]
 hiddenimports += [x for x in collect_submodules("iree") if "test" not in x]
 hiddenimports += ["iree._runtime"]
-hiddenimports += [x for x in collect_submodules("scipy") if "test" not in x]
diff --git a/apps/shark_studio/studio_imports_apionly.py b/apps/shark_studio/studio_imports_apionly.py
new file mode 100644
index 0000000000..72697960fd
--- /dev/null
+++ b/apps/shark_studio/studio_imports_apionly.py
@@ -0,0 +1,46 @@
+from PyInstaller.utils.hooks import collect_data_files
+from PyInstaller.utils.hooks import copy_metadata
+from PyInstaller.utils.hooks import collect_submodules
+
+import sys
+
+sys.setrecursionlimit(sys.getrecursionlimit() * 5)
+
+# python path for pyinstaller
+pathex = [
+    ".",
+]
+
+# datafiles for pyinstaller
+datas = []
+datas += copy_metadata("torch")
+datas += copy_metadata("tokenizers")
+datas += copy_metadata("tqdm")
+datas += copy_metadata("regex")
+datas += copy_metadata("requests")
+datas += copy_metadata("packaging")
+datas += copy_metadata("filelock")
+datas += copy_metadata("numpy")
+datas += copy_metadata("importlib_metadata")
+datas += copy_metadata("safetensors")
+datas += copy_metadata("Pillow")
+datas += copy_metadata("sentencepiece")
+datas += copy_metadata("pyyaml")
+datas += copy_metadata("huggingface-hub")
+datas += copy_metadata("gradio")
+datas += collect_data_files("torch")
+datas += collect_data_files("tokenizers")
+datas += collect_data_files("diffusers")
+datas += collect_data_files("transformers")
+datas += collect_data_files("iree", include_py_files=True)
+datas += collect_data_files("tqdm")
+datas += collect_data_files("jsonschema")
+datas += collect_data_files("jsonschema_specifications")
+datas += collect_data_files("cpuinfo")
+
+
+# hidden imports for pyinstaller
+hiddenimports = ["apps", "shark_turbine"]
+hiddenimports += [x for x in collect_submodules("diffusers") if "tests" not in x]
+hiddenimports += [x for x in collect_submodules("iree") if "test" not in x]
+hiddenimports += ["iree._runtime"]
diff --git a/apps/shark_studio/web/api/compat.py b/apps/shark_studio/web/api/compat.py
index b5e81f2e9a..1a17f7919e 100644
--- a/apps/shark_studio/web/api/compat.py
+++ b/apps/shark_studio/web/api/compat.py
@@ -20,9 +20,6 @@
 
 from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
 
-# from sdapi_v1 import shark_sd_api
-from apps.shark_studio.api.llm import llm_chat_api
-
 
 def decode_base64_to_image(encoding):
     if encoding.startswith("http://") or encoding.startswith("https://"):
@@ -183,50 +180,8 @@ def __init__(self, app: FastAPI, queue_lock: Lock):
         self.app = app
         self.queue_lock = queue_lock
         api_middleware(self.app)
+
         # self.add_api_route("/sdapi/v1/txt2img", shark_sd_api, methods=["POST"])
-        # self.add_api_route("/sdapi/v1/img2img", shark_sd_api, methods=["POST"])
-        # self.add_api_route("/sdapi/v1/upscaler", self.upscaler_api, methods=["POST"])
-        # self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=models.ExtrasSingleImageResponse)
-        # self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=models.ExtrasBatchImagesResponse)
-        # self.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=models.PNGInfoResponse)
-        # self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=models.ProgressResponse)
-        # self.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"])
-        # self.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
-        # self.add_api_route("/sdapi/v1/skip", self.skip, methods=["POST"])
-        # self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=models.OptionsModel)
-        # self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
-        # self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel)
-        # self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[models.SamplerItem])
-        # self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[models.UpscalerItem])
-        # self.add_api_route("/sdapi/v1/latent-upscale-modes", self.get_latent_upscale_modes, methods=["GET"], response_model=List[models.LatentUpscalerModeItem])
-        # self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[models.SDModelItem])
-        # self.add_api_route("/sdapi/v1/sd-vae", self.get_sd_vaes, methods=["GET"], response_model=List[models.SDVaeItem])
-        # self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[models.HypernetworkItem])
-        # self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[models.FaceRestorerItem])
-        # self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[models.RealesrganItem])
-        # self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[models.PromptStyleItem])
-        # self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse)
-        # self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
-        # self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"])
-        # self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse)
-        # self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=models.CreateResponse)
-        # self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=models.PreprocessResponse)
-        # self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=models.TrainResponse)
-        # self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=models.TrainResponse)
-        # self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=models.MemoryResponse)
-        # self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"])
-        # self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"])
-        # self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList)
-        # self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo])
-
-        # chat APIs needed for compatibility with multiple extensions using OpenAI API
-        self.add_api_route("/v1/chat/completions", llm_chat_api, methods=["POST"])
-        self.add_api_route("/v1/completions", llm_chat_api, methods=["POST"])
-        self.add_api_route("/chat/completions", llm_chat_api, methods=["POST"])
-        self.add_api_route("/completions", llm_chat_api, methods=["POST"])
-        self.add_api_route(
-            "/v1/engines/codegen/completions", llm_chat_api, methods=["POST"]
-        )
 
         self.default_script_arg_txt2img = []
         self.default_script_arg_img2img = []
@@ -234,27 +189,6 @@ def __init__(self, app: FastAPI, queue_lock: Lock):
     def add_api_route(self, path: str, endpoint, **kwargs):
         return self.app.add_api_route(path, endpoint, **kwargs)
 
-    # def refresh_checkpoints(self):
-    #     with self.queue_lock:
-    #         studio_data.refresh_checkpoints()
-
-    # def refresh_vae(self):
-    #     with self.queue_lock:
-    #         studio_data.refresh_vae_list()
-
-    # def unloadapi(self):
-    #     unload_model_weights()
-
-    #     return {}
-
-    # def reloadapi(self):
-    #     reload_model_weights()
-
-    #     return {}
-
-    # def skip(self):
-    #     studio.state.skip()
-
     def launch(self, server_name, port, root_path):
         self.app.include_router(self.router)
         uvicorn.run(
diff --git a/apps/shark_studio/web/api/sd.py b/apps/shark_studio/web/api/sd.py
index 8b13789179..0543a033ba 100644
--- a/apps/shark_studio/web/api/sd.py
+++ b/apps/shark_studio/web/api/sd.py
@@ -1 +1,115 @@
+import base64
+from fastapi import FastAPI
+
+from io import BytesIO
+from PIL import Image
+from pydantic import BaseModel, Field
+from fastapi.exceptions import HTTPException
+
+from apps.shark_studio.api.sd import shark_sd_fn
+
+sdapi = FastAPI()
+
+
+class GenerationInputData(BaseModel):
+    prompt: list = [""]
+    negative_prompt: list = [""]
+    hf_model_id: str | None = None
+    height: int = Field(default=512, ge=128, le=1024, multiple_of=8)
+    width: int = Field(default=512, ge=128, le=1024, multiple_of=8)
+    sampler_name: str = "EulerDiscrete"
+    cfg_scale: float = Field(default=7.5, ge=1)
+    steps: int = Field(default=20, ge=1, le=100)
+    seed: int = Field(default=-1)
+    n_iter: int = Field(default=1)
+    config: dict = None
+
+
+class GenerationResponseData(BaseModel):
+    images: list[str] = Field(description="Generated images, Base64 encoded")
+    properties: dict = {}
+    info: str
+
+
+def encode_pil_to_base64(images: list[Image.Image]):
+    encoded_imgs = []
+    for image in images:
+        with BytesIO() as output_bytes:
+            image.save(output_bytes, format="PNG")
+            bytes_data = output_bytes.getvalue()
+        encoded_imgs.append(base64.b64encode(bytes_data))
+    return encoded_imgs
+
+
+def decode_base64_to_image(encoding: str):
+    if encoding.startswith("data:image/"):
+        encoding = encoding.split(";", 1)[1].split(",", 1)[1]
+    try:
+        image = Image.open(BytesIO(base64.b64decode(encoding)))
+        return image
+    except Exception as err:
+        print(err)
+        raise HTTPException(status_code=400, detail="Invalid encoded image")
+
+
+@sdapi.post(
+    "/v1/txt2img",
+    summary="Does text to image generation",
+    response_model=GenerationResponseData,
+)
+def txt2img_api(InputData: GenerationInputData):
+    model_id = (
+        InputData.hf_model_id or "stabilityai/stable-diffusion-3-medium-diffusers"
+    )
+    scheduler = "FlowEulerDiscrete"
+    print(
+        f"Prompt: {InputData.prompt}, "
+        f"Negative Prompt: {InputData.negative_prompt}, "
+        f"Seed: {InputData.seed}, "
+        f"Model: {model_id}, "
+        f"Scheduler: {scheduler}."
+    )
+    if not InputData.config:
+        InputData.config = {
+            "precision": "fp16",
+            "device": "rocm",
+            "target_triple": "gfx1150",
+        }
+
+    res = shark_sd_fn(
+        InputData.prompt,
+        InputData.negative_prompt,
+        None,
+        InputData.height,
+        InputData.width,
+        InputData.steps,
+        None,
+        InputData.cfg_scale,
+        InputData.seed,
+        custom_vae=None,
+        batch_count=InputData.n_iter,
+        batch_size=1,
+        scheduler=scheduler,
+        base_model_id=model_id,
+        custom_weights=None,
+        precision=InputData.config["precision"],
+        device=InputData.config["device"],
+        target_triple=InputData.config["target_triple"],
+        output_type="pil",
+        ondemand=False,
+        compiled_pipeline=False,
+        resample_type=None,
+        controlnets=[],
+        embeddings=[],
+    )
+
+    # Since we're not streaming we just want the last generator result
+    for items_so_far in res:
+        items = items_so_far
+
+    return {
+        "images": encode_pil_to_base64(items[0]),
+        "properties": {},
+        "info": items[1],
+    }
diff --git a/apps/shark_studio/web/index.py b/apps/shark_studio/web/index.py
index f32bda123b..b023c756c3 100644
--- a/apps/shark_studio/web/index.py
+++ b/apps/shark_studio/web/index.py
@@ -32,13 +32,15 @@ def create_api(app):
 
 
 def api_only():
-    from fastapi import FastAPI
     from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
+    from apps.shark_studio.web.api.sd import sdapi
+    from fastapi import FastAPI
 
     initialize.initialize()
 
     app = FastAPI()
     initialize.setup_middleware(app)
+    app.mount("/sdapi/", sdapi)
     api = create_api(app)
 
     # from modules import script_callbacks
@@ -56,6 +58,7 @@ def api_only():
 def launch_webui(address):
     from tkinter import Tk
     import webview
+    import gradio as gr
 
     window = Tk()
 
@@ -83,7 +86,7 @@ def webui():
     launch_api = cmd_opts.api
     initialize.initialize()
 
-    #from ui.chat import chat_element
+    # from ui.chat import chat_element
     from ui.sd import sd_element
     from ui.outputgallery import outputgallery_element
 
@@ -216,7 +219,8 @@ def register_outputgallery_button(button, selectedid, inputs, outputs):
 if __name__ == "__main__":
     from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
 
-    if cmd_opts.webui == False:
-        api_only()
-    else:
-        webui()
+    api_only()
+    # if cmd_opts.webui == False:
+    #     api_only()
+    # else:
+    #     webui()
diff --git a/apps/shark_studio/web/ui/sd.py b/apps/shark_studio/web/ui/sd.py
index 1f987c3e81..ec314498e4 100644
--- a/apps/shark_studio/web/ui/sd.py
+++ b/apps/shark_studio/web/ui/sd.py
@@ -51,7 +51,7 @@
     # "stabilityai/stable-diffusion-2-1-base",
     # "stabilityai/stable-diffusion-2-1",
     # "stabilityai/stable-diffusion-xl-base-1.0",
-    #"stabilityai/sdxl-turbo",
+    # "stabilityai/sdxl-turbo",
 ]
 sd_default_models.extend(get_checkpoints(model_type="scripts"))
 
@@ -154,7 +154,9 @@ def load_sd_cfg(sd_json: dict, load_sd_config: str):
     elif os.path.exists(os.path.join(get_configs_path(), load_sd_config)):
         config = os.path.join(get_configs_path(), load_sd_config)
     else:
-        print("Default config not found as absolute path or in configs folder. Using sdxl-turbo as default config.")
+        print(
+            "Default config not found as absolute path or in configs folder. Using sdxl-turbo as default config."
+        )
         config = sd_json
     new_sd_config = none_to_str_none(json.loads(view_json_file(config)))
     if sd_json:
@@ -284,6 +286,7 @@ def base_model_changed(base_model_id):
         new_steps,
     ]
 
+
 init_config = global_obj.get_init_config()
 init_config = none_to_str_none(json.loads(view_json_file(init_config)))
 
@@ -307,15 +310,17 @@ def base_model_changed(base_model_id):
                         show_copy_button=True,
                     )
                     with gr.Accordion(
-                        label="\U0001F4D0\U0000FE0F Advanced Settings", open=False 
+                        label="\U0001F4D0\U0000FE0F Advanced Settings", open=False
                     ):
-                        with gr.Accordion(
-                            label="Device Settings", open=False
-                        ):
+                        with gr.Accordion(label="Device Settings", open=False):
                             device = gr.Dropdown(
                                 elem_id="device",
                                 label="Device",
-                                value=init_config["device"] if init_config["device"] else "rocm",
+                                value=(
+                                    init_config["device"]
+                                    if init_config["device"]
+                                    else "rocm"
+                                ),
                                 choices=global_obj.get_device_list(),
                                 allow_custom_value=True,
                             )
@@ -347,7 +352,7 @@ def base_model_changed(base_model_id):
                                 value=512,
                                 step=512,
                                 label="\U00002195\U0000FE0F Height",
-                                interactive=False, # DEMO
+                                interactive=False,  # DEMO
                                 visible=False,  # DEMO
                             )
                             width = gr.Slider(
@@ -356,10 +361,10 @@ def base_model_changed(base_model_id):
                                 1024,
                                 1024,
                                 value=512,
                                 step=512,
                                 label="\U00002194\U0000FE0F Width",
-                                interactive=False, # DEMO
+                                interactive=False,  # DEMO
                                 visible=False,  # DEMO
                             )
-
+
                             with gr.Accordion(
                                 label="\U0001F9EA\U0000FE0F Input Image Processing",
                                 open=False,
@@ -379,7 +384,9 @@ def base_model_changed(base_model_id):
                                 allow_custom_value=True,
                             )
                         with gr.Row():
-                            sd_model_info = f"Checkpoint Path: {str(get_checkpoints_path())}"
+                            sd_model_info = (
+                                f"Checkpoint Path: {str(get_checkpoints_path())}"
+                            )
                             base_model_id = gr.Dropdown(
                                 label="\U000026F0\U0000FE0F Base Model",
                                 info="Select or enter HF model ID",
@@ -413,7 +420,7 @@ def base_model_changed(base_model_id):
                             )
                             guidance_scale = gr.Slider(
                                 0,
-                                5, #DEMO
+                                5,  # DEMO
                                 value=4,
                                 step=0.1,
                                 label="\U0001F5C3\U0000FE0F CFG Scale",
@@ -444,9 +451,7 @@ def base_model_changed(base_model_id):
                                 visible=False,  # DEMO
                             )
                         with gr.Row(elem_classes=["fill"], visible=False):
-                            Path(get_configs_path()).mkdir(
-                                parents=True, exist_ok=True
-                            )
+                            Path(get_configs_path()).mkdir(parents=True, exist_ok=True)
                             write_default_sd_configs(get_configs_path())
                             default_config_file = global_obj.get_init_config()
                             sd_json = gr.JSON(
@@ -463,9 +468,7 @@ def base_model_changed(base_model_id):
                                 visible=False,
                             )
                             with gr.Row():
-                                save_sd_config = gr.Button(
-                                    value="Save Config", size="sm"
-                                )
+                                save_sd_config = gr.Button(value="Save Config", size="sm")
                                 clear_sd_config = gr.ClearButton(
                                     value="Clear Config",
                                     size="sm",
@@ -514,7 +517,11 @@ def base_model_changed(base_model_id):
                                 label=f"Standalone LoRA Weights",
                                 info=sd_lora_info,
                                 elem_id="lora_weights",
-                                value=init_config["embeddings"][0] if (len(init_config["embeddings"].keys()) > 1) else "None",
+                                value=(
+                                    init_config["embeddings"][0]
+                                    if (len(init_config["embeddings"].keys()) > 1)
+                                    else "None"
+                                ),
                                 multiselect=True,
                                 choices=[] + get_checkpoints("lora"),
                                 scale=2,
diff --git a/apps/shark_studio/web/utils/file_utils.py b/apps/shark_studio/web/utils/file_utils.py
index b83b989ec4..ce40d3abb6 100644
--- a/apps/shark_studio/web/utils/file_utils.py
+++ b/apps/shark_studio/web/utils/file_utils.py
@@ -100,6 +100,7 @@ def get_checkpoints(model_type="checkpoints"):
         ckpt_files.extend(files)
     return sorted(ckpt_files, key=str.casefold)
 
+
 def get_configs():
     return sorted(
         [
diff --git a/apps/shark_studio/web/utils/globals.py b/apps/shark_studio/web/utils/globals.py
index 963cef3d5f..ec5ce68e88 100644
--- a/apps/shark_studio/web/utils/globals.py
+++ b/apps/shark_studio/web/utils/globals.py
@@ -3,17 +3,21 @@
 from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
 import os
 from apps.shark_studio.web.utils.file_utils import get_configs_path
+
 """
 The global objects include SD pipeline and config.
 Maintaining the global objects would avoid creating extra pipeline objects when switching modes.
 Also we could avoid memory leak when switching models by clearing the cache.
 """
+
+
 def view_json_file(file_path):
     content = ""
     with open(file_path, "r") as fopen:
         content = fopen.read()
     return content
 
+
 def _init():
     global _sd_obj
     global _llm_obj
@@ -95,6 +99,7 @@ def get_device_list():
     global _devices
     return _devices
 
+
 def get_init_config():
     global _init_config
     if os.path.exists(cmd_opts.defaults):
@@ -102,10 +107,13 @@ def get_init_config():
     elif os.path.exists(os.path.join(get_configs_path(), cmd_opts.defaults)):
         _init_config = os.path.join(get_configs_path(), cmd_opts.defaults)
     else:
-        print("Default config not found as absolute path or in configs folder. Using sdxl-turbo as default config.")
+        print(
+            "Default config not found as absolute path or in configs folder. Using sdxl-turbo as default config."
+        )
         _init_config = os.path.join(get_configs_path(), "sdxl-turbo.json")
     return _init_config
 
+
 def get_sd_status():
     global _sd_obj
     return _sd_obj.status
diff --git a/requirements.txt b/requirements.txt
index a6e5d75a50..358fe73181 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,5 @@
+-r https://raw.githubusercontent.com/llvm/torch-mlir/main/requirements.txt
+-r https://raw.githubusercontent.com/llvm/torch-mlir/main/torchvision-requirements.txt
 -f https://download.pytorch.org/whl/nightly/cpu
 -f https://iree.dev/pip-release-links.html
 --pre
@@ -5,40 +7,19 @@
 setuptools
 wheel
 
-
-torch==2.3.0
 shark-turbine @ git+https://github.com/iree-org/iree-turbine.git@main
-turbine-models @ git+https://github.com/nod-ai/SHARK-Turbine.git@ean-unify-sd#subdirectory=models
+turbine-models @ git+https://github.com/nod-ai/SHARK-Turbine.git@merge_punet_sdxl#subdirectory=models
 diffusers @ git+https://github.com/nod-ai/diffusers@0.29.0.dev0-shark
-brevitas @ git+https://github.com/Xilinx/brevitas.git@6695e8df7f6a2c7715b9ed69c4b78157376bb60b
-
-# SHARK Runner
-tqdm
-
-# SHARK Downloader
-google-cloud-storage
-
-# Testing
-pytest
 Pillow
-parameterized
-
-# Add transformers, diffusers and scipy since it most commonly used
-#accelerate is now required for diffusers import from ckpt.
-accelerate
-scipy
-transformers==4.37.1
-torchsde # Required for Stable Diffusion SDE schedulers.
+transformers==4.43.3
 ftfy
-gradio==4.29.0
-altair
-omegaconf
-# 0.3.2 doesn't have binaries for arm64
-safetensors==0.3.1
+safetensors
 py-cpuinfo
 pydantic==2.4.1 # pin until pyinstaller-hooks-contrib works with beta versions
 mpmath==1.3.0
-optimum
+
+# Testing
+pytest
 
 # Keep PyInstaller at the end. Sometimes Windows Defender flags it, but most folks can continue even if it errors.
 pefile
diff --git a/rest_api_tests/sd3api_test.py b/rest_api_tests/sd3api_test.py
new file mode 100644
index 0000000000..5cb6cf0f0a
--- /dev/null
+++ b/rest_api_tests/sd3api_test.py
@@ -0,0 +1,77 @@
+import requests
+from pydantic import BaseModel, Field
+import json
+
+
+def view_json_file(file_path):
+    content = ""
+    with open(file_path, "r") as fopen:
+        content = fopen.read()
+    return content
+
+
+# Define the URL of the REST API endpoint
+api_url = "http://127.0.0.1:8080/sdapi/v1/txt2img/"  # Replace with your actual API URL
+
+
+class GenerationInputData(BaseModel):
+    prompt: list = [""]
+    negative_prompt: list = [""]
+    hf_model_id: str | None = None
+    height: int = Field(default=512, ge=128, le=1024, multiple_of=8)
+    width: int = Field(default=512, ge=128, le=1024, multiple_of=8)
+    sampler_name: str = "EulerDiscrete"
+    cfg_scale: float = Field(default=7.5, ge=1)
+    steps: int = Field(default=20, ge=1, le=100)
+    seed: int = Field(default=-1)
+    n_iter: int = Field(default=1)
+    config: dict = None
+
+
+# Create an instance of GenerationInputData with example arguments
+data = GenerationInputData(
+    prompt=[
+        "A phoenix made of diamond, black background, dream sequence, rising from coals"
+    ],
+    negative_prompt=[
+        "cropped, cartoon, lowres, low quality, black and white, bad scan, pixelated"
+    ],
+    hf_model_id="shark_sd3.py",
+    height=512,
+    width=512,
+    sampler_name="EulerDiscrete",
+    cfg_scale=7.5,
+    steps=20,
+    seed=-1,
+    n_iter=1,
+    config=json.loads(view_json_file("../configs/sd3_phoenix_npu.json")),
+)
+
+# Convert the data to a dictionary
+data_dict = data.dict()
+
+# Optional: Define headers if needed (e.g., for authentication)
+headers = {
+    "User-Agent": "PythonTest",
+    "Accept": "*/*",
+    "Accept-Encoding": "gzip, deflate, br",
+}
+
+
+def test_post_request(url, data, headers=None):
+    try:
+        # Send a POST request to the API endpoint
+        response = requests.post(url, json=data, headers=headers)
+
+        # Print the status code and response content
+        print(f"Status Code: {response.status_code}")
+        print("Response Content:")
+        # print(response.json())  # Print the JSON response
+
+    except requests.RequestException as e:
+        # Handle any exceptions that occur during the request
+        print(f"An error occurred: {e}")
+
+
+# Run the test
+test_post_request(api_url, data_dict, headers)
diff --git a/setup_venv.ps1 b/setup_venv.ps1
index 651b19421a..8465c5599f 100644
--- a/setup_venv.ps1
+++ b/setup_venv.ps1
@@ -87,9 +87,8 @@
 if ($NULL -ne $PyVer) {py -3.11 -m venv .\shark.venv\} else {python -m venv .\shark.venv\}
 .\shark.venv\Scripts\activate
 python -m pip install --upgrade pip
-pip install wheel
+pip install https://github.com/nod-ai/SRT/releases/download/candidate-20240619.291/iree_compiler-20240619.291-cp311-cp311-win_amd64.whl https://github.com/nod-ai/SRT/releases/download/candidate-20240619.291/iree_runtime-20240619.291-cp311-cp311-win_amd64.whl
 pip install --pre -r requirements.txt
-pip install https://github.com/nod-ai/SRT/releases/download/candidate-20240602.283/iree_compiler-20240602.283-cp311-cp311-win_amd64.whl https://github.com/nod-ai/SRT/releases/download/candidate-20240602.283/iree_runtime-20240602.283-cp311-cp311-win_amd64.whl
 pip install -e .
 
 Write-Host "Source your venv with ./shark.venv/Scripts/activate"
diff --git a/webui_requirements.txt b/webui_requirements.txt
new file mode 100644
index 0000000000..8e88377cf8
--- /dev/null
+++ b/webui_requirements.txt
@@ -0,0 +1 @@
+gradio==4.29.0
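
Example usage of the new /sdapi/v1/txt2img endpoint, as a minimal client sketch rather than part of the diff above. It assumes the API-only server from apps/shark_studio/web/index.py is listening on 127.0.0.1:8080 (the same address used by rest_api_tests/sd3api_test.py), and it decodes the Base64-encoded PNGs returned per GenerationResponseData; the prompt text and output filenames are illustrative only:

import base64

import requests

# Only "prompt" strictly needs to be supplied; every other GenerationInputData
# field has a default, and txt2img_api falls back to its fp16/rocm config when
# "config" is omitted.
response = requests.post(
    "http://127.0.0.1:8080/sdapi/v1/txt2img/",
    json={"prompt": ["a watercolor lighthouse at dusk"], "steps": 20},
)
response.raise_for_status()
payload = response.json()

# "images" is a list of Base64-encoded PNGs; write each one to disk.
for i, encoded in enumerate(payload["images"]):
    with open(f"txt2img_out_{i}.png", "wb") as out_file:
        out_file.write(base64.b64decode(encoded))
print(payload["info"])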