diff --git a/apps/shark_studio/api/controlnet.py b/apps/shark_studio/api/controlnet.py index ea8cdf0cc9..2c8a8b566b 100644 --- a/apps/shark_studio/api/controlnet.py +++ b/apps/shark_studio/api/controlnet.py @@ -1,4 +1,15 @@ # from turbine_models.custom_models.controlnet import control_adapter, preprocessors +import os +import PIL +import numpy as np +from apps.shark_studio.web.utils.file_utils import ( + get_generated_imgs_path, +) +from datetime import datetime +from PIL import Image +from gradio.components.image_editor import ( + EditorValue, +) class control_adapter: @@ -29,20 +40,12 @@ def export_controlnet_model(model_keyword): control_adapter_map = { "sd15": { "canny": {"initializer": control_adapter.export_control_adapter_model}, - "openpose": { - "initializer": control_adapter.export_control_adapter_model - }, - "scribble": { - "initializer": control_adapter.export_control_adapter_model - }, - "zoedepth": { - "initializer": control_adapter.export_control_adapter_model - }, + "openpose": {"initializer": control_adapter.export_control_adapter_model}, + "scribble": {"initializer": control_adapter.export_control_adapter_model}, + "zoedepth": {"initializer": control_adapter.export_control_adapter_model}, }, "sdxl": { - "canny": { - "initializer": control_adapter.export_xl_control_adapter_model - }, + "canny": {"initializer": control_adapter.export_xl_control_adapter_model}, }, } preprocessor_model_map = { @@ -57,78 +60,48 @@ class PreprocessorModel: def __init__( self, hf_model_id, - device, + device="cpu", ): - self.model = None + self.model = hf_model_id + self.device = device - def compile(self, device): + def compile(self): print("compile not implemented for preprocessor.") return def run(self, inputs): print("run not implemented for preprocessor.") - return + return inputs -def cnet_preview(model, input_img, stencils, images, preprocessed_hints): - if isinstance(input_image, PIL.Image.Image): - img_dict = { - "background": None, - "layers": [None], - "composite": input_image, - } - input_image = EditorValue(img_dict) - images[index] = input_image - if model: - stencils[index] = model +def cnet_preview(model, input_image): + curr_datetime = datetime.now().strftime("%Y-%m-%d.%H-%M-%S") + control_imgs_path = os.path.join(get_generated_imgs_path(), "control_hints") + if not os.path.exists(control_imgs_path): + os.mkdir(control_imgs_path) + img_dest = os.path.join(control_imgs_path, model + curr_datetime + ".png") match model: case "canny": - canny = CannyDetector() + canny = PreprocessorModel("canny") result = canny( - np.array(input_image["composite"]), + np.array(input_image), 100, 200, ) - preprocessed_hints[index] = Image.fromarray(result) - return ( - Image.fromarray(result), - stencils, - images, - preprocessed_hints, - ) + Image.fromarray(result).save(fp=img_dest) + return result, img_dest case "openpose": - openpose = OpenposeDetector() - result = openpose(np.array(input_image["composite"])) - preprocessed_hints[index] = Image.fromarray(result[0]) - return ( - Image.fromarray(result[0]), - stencils, - images, - preprocessed_hints, - ) + openpose = PreprocessorModel("openpose") + result = openpose(np.array(input_image)) + Image.fromarray(result[0]).save(fp=img_dest) + return result, img_dest case "zoedepth": - zoedepth = ZoeDetector() - result = zoedepth(np.array(input_image["composite"])) - preprocessed_hints[index] = Image.fromarray(result) - return ( - Image.fromarray(result), - stencils, - images, - preprocessed_hints, - ) + zoedepth = PreprocessorModel("ZoeDepth") + result = zoedepth(np.array(input_image)) + Image.fromarray(result).save(fp=img_dest) + return result, img_dest case "scribble": - preprocessed_hints[index] = input_image["composite"] - return ( - input_image["composite"], - stencils, - images, - preprocessed_hints, - ) + input_image.save(fp=img_dest) + return input_image, img_dest case _: - preprocessed_hints[index] = None - return ( - None, - stencils, - images, - preprocessed_hints, - ) + return None, None diff --git a/apps/shark_studio/api/initializers.py b/apps/shark_studio/api/initializers.py index bbb273354c..ef9816cfca 100644 --- a/apps/shark_studio/api/initializers.py +++ b/apps/shark_studio/api/initializers.py @@ -1,14 +1,17 @@ import importlib -import logging import os import signal import sys -import re import warnings import json from threading import Thread from apps.shark_studio.modules.timer import startup_timer +from apps.shark_studio.web.utils.tmp_configs import ( + config_tmp, + clear_tmp_mlir, + clear_tmp_imgs, +) def imports(): @@ -18,9 +21,8 @@ def imports(): warnings.filterwarnings( action="ignore", category=DeprecationWarning, module="torch" ) - warnings.filterwarnings( - action="ignore", category=UserWarning, module="torchvision" - ) + warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision") + warnings.filterwarnings(action="ignore", category=UserWarning, module="torch") import gradio # noqa: F401 @@ -34,20 +36,28 @@ def imports(): from apps.shark_studio.modules import ( img_processing, ) # noqa: F401 - from apps.shark_studio.modules.schedulers import scheduler_model_map startup_timer.record("other imports") def initialize(): configure_sigint_handler() + # Setup to use shark_tmp for gradio's temporary image files and clear any + # existing temporary images there if they exist. Then we can import gradio. + # It has to be in this order or gradio ignores what we've set up. + + config_tmp() + clear_tmp_mlir() + clear_tmp_imgs() + + from apps.shark_studio.web.utils.file_utils import ( + create_checkpoint_folders, + ) - # from apps.shark_studio.modules import modelloader - # modelloader.cleanup_models() + # Create custom models folders if they don't exist + create_checkpoint_folders() - # from apps.shark_studio.modules import sd_models - # sd_models.setup_model() - # startup_timer.record("setup SD model") + import gradio as gr # initialize_rest(reload_script_modules=False) diff --git a/apps/shark_studio/api/llm.py b/apps/shark_studio/api/llm.py index a209d8d1ba..852f5eff58 100644 --- a/apps/shark_studio/api/llm.py +++ b/apps/shark_studio/api/llm.py @@ -4,7 +4,7 @@ get_iree_compiled_module, load_vmfb_using_mmap, ) -from apps.shark_studio.api.utils import get_resource_path +from apps.shark_studio.web.utils.file_utils import get_resource_path import iree.runtime as ireert from itertools import chain import gc diff --git a/apps/shark_studio/api/sd.py b/apps/shark_studio/api/sd.py index a601a068f7..2822d83829 100644 --- a/apps/shark_studio/api/sd.py +++ b/apps/shark_studio/api/sd.py @@ -1,90 +1,78 @@ +import gc +import torch +import time +import os +import json +import numpy as np +from tqdm.auto import tqdm + +from pathlib import Path +from random import randint from turbine_models.custom_models.sd_inference import clip, unet, vae -from shark.iree_utils.compile_utils import get_iree_compiled_module -from apps.shark_studio.api.utils import get_resource_path from apps.shark_studio.api.controlnet import control_adapter_map from apps.shark_studio.web.utils.state import status_label +from apps.shark_studio.web.utils.file_utils import ( + safe_name, + get_resource_path, + get_checkpoints_path, +) from apps.shark_studio.modules.pipeline import SharkPipelineBase -import iree.runtime as ireert -import gc -import torch -import gradio as gr +from apps.shark_studio.modules.schedulers import get_schedulers +from apps.shark_studio.modules.prompt_encoding import ( + get_weighted_text_embeddings, +) +from apps.shark_studio.modules.img_processing import ( + resize_stencil, + save_output_img, + resamplers, + resampler_list, +) + +from apps.shark_studio.modules.ckpt_processing import ( + preprocessCKPT, + process_custom_pipe_weights, +) +from transformers import CLIPTokenizer +from diffusers.image_processor import VaeImageProcessor sd_model_map = { - "CompVis/stable-diffusion-v1-4": { - "clip": { - "initializer": clip.export_clip_model, - "max_tokens": 64, - }, - "vae_encode": { - "initializer": vae.export_vae_model, - "max_tokens": 64, - }, - "unet": { - "initializer": unet.export_unet_model, - "max_tokens": 512, - }, - "vae_decode": { - "initializer": vae.export_vae_model, - "max_tokens": 64, - }, + "clip": { + "initializer": clip.export_clip_model, + "ireec_flags": [ + "--iree-flow-collapse-reduction-dims", + "--iree-opt-const-expr-hoisting=False", + "--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807", + "--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))", + ], }, - "runwayml/stable-diffusion-v1-5": { - "clip": { - "initializer": clip.export_clip_model, - "max_tokens": 64, - }, - "vae_encode": { - "initializer": vae.export_vae_model, - "max_tokens": 64, - }, - "unet": { - "initializer": unet.export_unet_model, - "max_tokens": 512, - }, - "vae_decode": { - "initializer": vae.export_vae_model, - "max_tokens": 64, - }, + "vae_encode": { + "initializer": vae.export_vae_model, + "ireec_flags": [ + "--iree-flow-collapse-reduction-dims", + "--iree-opt-const-expr-hoisting=False", + "--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807", + "--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-global-opt-detach-elementwise-from-named-ops,iree-global-opt-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32},iree-linalg-ext-convert-conv2d-to-winograd))", + "--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))", + ], }, - "stabilityai/stable-diffusion-2-1-base": { - "clip": { - "initializer": clip.export_clip_model, - "max_tokens": 64, - }, - "vae_encode": { - "initializer": vae.export_vae_model, - "max_tokens": 64, - }, - "unet": { - "initializer": unet.export_unet_model, - "max_tokens": 512, - }, - "vae_decode": { - "initializer": vae.export_vae_model, - "max_tokens": 64, - }, + "unet": { + "initializer": unet.export_unet_model, + "ireec_flags": [ + "--iree-flow-collapse-reduction-dims", + "--iree-opt-const-expr-hoisting=False", + "--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807", + "--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-global-opt-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32}))", + ], }, - "stabilityai/stable_diffusion-xl-1.0": { - "clip_1": { - "initializer": clip.export_clip_model, - "max_tokens": 64, - }, - "clip_2": { - "initializer": clip.export_clip_model, - "max_tokens": 64, - }, - "vae_encode": { - "initializer": vae.export_vae_model, - "max_tokens": 64, - }, - "unet": { - "initializer": unet.export_unet_model, - "max_tokens": 512, - }, - "vae_decode": { - "initializer": vae.export_vae_model, - "max_tokens": 64, - }, + "vae_decode": { + "initializer": vae.export_vae_model, + "ireec_flags": [ + "--iree-flow-collapse-reduction-dims", + "--iree-opt-const-expr-hoisting=False", + "--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807", + "--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-global-opt-detach-elementwise-from-named-ops,iree-global-opt-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32},iree-linalg-ext-convert-conv2d-to-winograd))", + "--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))", + ], }, } @@ -95,38 +83,343 @@ class StableDiffusion(SharkPipelineBase): # aims to be as general as possible, and the class will infer and compile # a list of necessary modules or a combined "pipeline module" for a # specified job based on the inference task. - # - # custom_model_ids: a dict of submodel + HF ID pairs for custom submodels. - # e.g. {"vae_decode": "madebyollin/sdxl-vae-fp16-fix"} - # - # embeddings: a dict of embedding checkpoints or model IDs to use when - # initializing the compiled modules. def __init__( self, - base_model_id: str = "runwayml/stable-diffusion-v1-5", - height: int = 512, - width: int = 512, - precision: str = "fp16", - device: str = None, - custom_model_map: dict = {}, - embeddings: dict = {}, + base_model_id, + height: int, + width: int, + batch_size: int, + precision: str, + device: str, + custom_vae: str = None, + num_loras: int = 0, import_ir: bool = True, + is_controlled: bool = False, ): - super().__init__(sd_model_map[base_model_id], device, import_ir) - self.base_model_id = base_model_id - self.device = device + self.model_max_length = 77 + self.batch_size = batch_size self.precision = precision - self.iree_module_dict = None - self.get_compiled_map() + self.dtype = torch.float16 if precision == "fp16" else torch.float32 + self.height = height + self.width = width + self.scheduler_obj = {} + static_kwargs = { + "pipe": { + "external_weights": "safetensors", + }, + "clip": {"hf_model_name": base_model_id}, + "unet": { + "hf_model_name": base_model_id, + "unet_model": unet.UnetModel( + hf_model_name=base_model_id, hf_auth_token=None + ), + "batch_size": batch_size, + # "is_controlled": is_controlled, + # "num_loras": num_loras, + "height": height, + "width": width, + "precision": precision, + "max_length": self.model_max_length, + }, + "vae_encode": { + "hf_model_name": base_model_id, + "vae_model": vae.VaeModel( + hf_model_name=base_model_id, + custom_vae=custom_vae, + ), + "batch_size": batch_size, + "height": height, + "width": width, + "precision": precision, + }, + "vae_decode": { + "hf_model_name": base_model_id, + "vae_model": vae.VaeModel( + hf_model_name=base_model_id, + custom_vae=custom_vae, + ), + "batch_size": batch_size, + "height": height, + "width": width, + "precision": precision, + }, + } + super().__init__(sd_model_map, base_model_id, static_kwargs, device, import_ir) + pipe_id_list = [ + safe_name(base_model_id), + str(batch_size), + str(static_kwargs["unet"]["max_length"]), + f"{str(height)}x{str(width)}", + precision, + ] + if num_loras > 0: + pipe_id_list.append(str(num_loras) + "lora") + if is_controlled: + pipe_id_list.append("controlled") + if custom_vae: + pipe_id_list.append(custom_vae) + self.pipe_id = "_".join(pipe_id_list) + print(f"\n[LOG] Pipeline initialized with pipe_id: {self.pipe_id}.") + del static_kwargs + gc.collect() + + def prepare_pipe(self, custom_weights, adapters, embeddings, is_img2img): + print(f"\n[LOG] Preparing pipeline...") + self.is_img2img = is_img2img + self.schedulers = get_schedulers(self.base_model_id) + + self.weights_path = os.path.join( + get_checkpoints_path(), self.safe_name(self.base_model_id) + ) + if not os.path.exists(self.weights_path): + os.mkdir(self.weights_path) + + for model in adapters: + self.model_map[model] = adapters[model] + + for submodel in self.static_kwargs: + if custom_weights: + custom_weights_params, _ = process_custom_pipe_weights(custom_weights) + if submodel not in ["clip", "clip2"]: + self.static_kwargs[submodel][ + "external_weight_file" + ] = custom_weights_params + else: + self.static_kwargs[submodel]["external_weight_path"] = os.path.join( + self.weights_path, submodel + ".safetensors" + ) + else: + self.static_kwargs[submodel]["external_weight_path"] = os.path.join( + self.weights_path, submodel + ".safetensors" + ) + + self.get_compiled_map(pipe_id=self.pipe_id) + print("\n[LOG] Pipeline successfully prepared for runtime.") + return + + def encode_prompts_weight( + self, + prompt, + negative_prompt, + do_classifier_free_guidance=True, + ): + # Encodes the prompt into text encoder hidden states. + self.load_submodels(["clip"]) + self.tokenizer = CLIPTokenizer.from_pretrained( + self.base_model_id, + subfolder="tokenizer", + ) + clip_inf_start = time.time() + + text_embeddings, uncond_embeddings = get_weighted_text_embeddings( + pipe=self, + prompt=prompt, + uncond_prompt=negative_prompt if do_classifier_free_guidance else None, + ) + + if do_classifier_free_guidance: + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + + pad = (0, 0) * (len(text_embeddings.shape) - 2) + pad = pad + ( + 0, + self.static_kwargs["unet"]["max_length"] - text_embeddings.shape[1], + ) + text_embeddings = torch.nn.functional.pad(text_embeddings, pad) + + # SHARK: Report clip inference time + clip_inf_time = (time.time() - clip_inf_start) * 1000 + if self.ondemand: + self.unload_submodels(["clip"]) + gc.collect() + print(f"\n[LOG] Clip Inference time (ms) = {clip_inf_time:.3f}") + + return text_embeddings.numpy().astype(np.float16) + + def prepare_latents( + self, + generator, + num_inference_steps, + image, + strength, + ): + noise = torch.randn( + ( + self.batch_size, + 4, + self.height // 8, + self.width // 8, + ), + generator=generator, + dtype=self.dtype, + ).to("cpu") + + self.scheduler.set_timesteps(num_inference_steps) + if self.is_img2img: + init_timestep = min( + int(num_inference_steps * strength), num_inference_steps + ) + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start:] + latents = self.encode_image(image) + latents = self.scheduler.add_noise(latents, noise, timesteps[0].repeat(1)) + return latents, [timesteps] + else: + self.scheduler.is_scale_input_called = True + latents = noise * self.scheduler.init_noise_sigma + return latents, self.scheduler.timesteps + + def encode_image(self, input_image): + self.load_submodels(["vae_encode"]) + vae_encode_start = time.time() + latents = self.run("vae_encode", input_image) + vae_inf_time = (time.time() - vae_encode_start) * 1000 + if self.ondemand: + self.unload_submodels(["vae_encode"]) + print(f"\n[LOG] VAE Encode Inference time (ms): {vae_inf_time:.3f}") + + return latents + + def produce_img_latents( + self, + latents, + text_embeddings, + guidance_scale, + total_timesteps, + cpu_scheduling, + mask=None, + masked_image_latents=None, + return_all_latents=False, + ): + # self.status = SD_STATE_IDLE + step_time_sum = 0 + latent_history = [latents] + text_embeddings = torch.from_numpy(text_embeddings).to(self.dtype) + text_embeddings_numpy = text_embeddings.detach().numpy() + guidance_scale = torch.Tensor([guidance_scale]).to(self.dtype) + self.load_submodels(["unet"]) + for i, t in tqdm(enumerate(total_timesteps)): + step_start_time = time.time() + timestep = torch.tensor([t]).to(self.dtype).detach().numpy() + latent_model_input = self.scheduler.scale_model_input(latents, t).to( + self.dtype + ) + if mask is not None and masked_image_latents is not None: + latent_model_input = torch.cat( + [ + torch.from_numpy(np.asarray(latent_model_input)).to(self.dtype), + mask, + masked_image_latents, + ], + dim=1, + ).to(self.dtype) + if cpu_scheduling: + latent_model_input = latent_model_input.detach().numpy() + + # Profiling Unet. + # profile_device = start_profiling(file_path="unet.rdc") + noise_pred = self.run( + "unet", + [ + latent_model_input, + timestep, + text_embeddings_numpy, + guidance_scale, + ], + ) + # end_profiling(profile_device) + + if cpu_scheduling: + noise_pred = torch.from_numpy(noise_pred.to_host()) + latents = self.scheduler.step(noise_pred, t, latents).prev_sample + else: + latents = self.run("scheduler_step", (noise_pred, t, latents)) + + latent_history.append(latents) + step_time = (time.time() - step_start_time) * 1000 + # print( + # f"\n [LOG] step = {i} | timestep = {t} | time = {step_time:.2f}ms" + # ) + step_time_sum += step_time + + # if self.status == SD_STATE_CANCEL: + # break - def prepare_pipeline(self, scheduler, custom_model_map): - return None + if self.ondemand: + self.unload_submodels(["unet"]) + gc.collect() + + avg_step_time = step_time_sum / len(total_timesteps) + print(f"\n[LOG] Average step time: {avg_step_time}ms/it") + + if not return_all_latents: + return latents + all_latents = torch.cat(latent_history, dim=0) + return all_latents + + def decode_latents(self, latents, cpu_scheduling=True): + latents_numpy = latents.to(self.dtype) + if cpu_scheduling: + latents_numpy = latents.detach().numpy() + + # profile_device = start_profiling(file_path="vae.rdc") + vae_start = time.time() + images = self.run("vae_decode", latents_numpy).to_host() + vae_inf_time = (time.time() - vae_start) * 1000 + # end_profiling(profile_device) + print(f"\n[LOG] VAE Inference time (ms): {vae_inf_time:.3f}") + + images = torch.from_numpy(images).permute(0, 2, 3, 1).float().numpy() + pil_images = self.image_processor.numpy_to_pil(images) + return pil_images + + # def process_sd_init_image(self, sd_init_image, resample_type): + # if isinstance(sd_init_image, list): + # images = [] + # for img in sd_init_image: + # img, _ = self.process_sd_init_image(img, resample_type) + # images.append(img) + # is_img2img = True + # return images, is_img2img + # if isinstance(sd_init_image, str): + # if os.path.isfile(sd_init_image): + # sd_init_image = Image.open(sd_init_image, mode="r").convert("RGB") + # image, is_img2img = self.process_sd_init_image( + # sd_init_image, resample_type + # ) + # else: + # image = None + # is_img2img = False + # elif isinstance(sd_init_image, Image.Image): + # image = sd_init_image.convert("RGB") + # elif sd_init_image: + # image = sd_init_image["image"].convert("RGB") + # else: + # image = None + # is_img2img = False + # if image: + # resample_type = ( + # resamplers[resample_type] + # if resample_type in resampler_list + # # Fallback to Lanczos + # else Image.Resampling.LANCZOS + # ) + # image = image.resize((self.width, self.height), resample=resample_type) + # image_arr = np.stack([np.array(i) for i in (image,)], axis=0) + # image_arr = image_arr / 255.0 + # image_arr = torch.from_numpy(image_arr).permute(0, 3, 1, 2).to(self.dtype) + # image_arr = 2 * (image_arr - 0.5) + # is_img2img = True + # image = image_arr + # return image, is_img2img def generate_images( self, prompt, negative_prompt, + image, + scheduler, steps, strength, guidance_scale, @@ -135,26 +428,101 @@ def generate_images( repeatable_seeds, resample_type, control_mode, - preprocessed_hints, + hints, ): - return None, None, None, None, None + # TODO: Batched args + self.image_processor = VaeImageProcessor(do_convert_rgb=True) + self.scheduler = self.schedulers[scheduler] + self.ondemand = ondemand + if self.is_img2img: + image, _ = self.image_processor.preprocess(image, resample_type) + else: + image = None + + print("\n[LOG] Generating images...") + batched_args = [ + prompt, + negative_prompt, + image, + ] + for arg in batched_args: + if not isinstance(arg, list): + arg = [arg] * self.batch_size + if len(arg) < self.batch_size: + arg = arg * self.batch_size + else: + arg = [arg[i] for i in range(self.batch_size)] + + text_embeddings = self.encode_prompts_weight( + prompt, + negative_prompt, + ) + uint32_info = np.iinfo(np.uint32) + uint32_min, uint32_max = uint32_info.min, uint32_info.max + if seed < uint32_min or seed >= uint32_max: + seed = randint(uint32_min, uint32_max) + + generator = torch.manual_seed(seed) + + init_latents, final_timesteps = self.prepare_latents( + generator=generator, + num_inference_steps=steps, + image=image, + strength=strength, + ) + + latents = self.produce_img_latents( + latents=init_latents, + text_embeddings=text_embeddings, + guidance_scale=guidance_scale, + total_timesteps=final_timesteps, + cpu_scheduling=True, # until we have schedulers through Turbine + ) + + # Img latents -> PIL images + all_imgs = [] + self.load_submodels(["vae_decode"]) + for i in tqdm(range(0, latents.shape[0], self.batch_size)): + imgs = self.decode_latents( + latents=latents[i : i + self.batch_size], + cpu_scheduling=True, + ) + all_imgs.extend(imgs) + if self.ondemand: + self.unload_submodels(["vae_decode"]) + + return all_imgs + + +def shark_sd_fn_dict_input( + sd_kwargs: dict, +): + print("[LOG] Submitting Request...") -# NOTE: Each `hf_model_id` should have its own starting configuration. + for key in sd_kwargs: + if sd_kwargs[key] in [None, []]: + sd_kwargs[key] = None + if sd_kwargs[key] in ["None"]: + sd_kwargs[key] = "" + if key == "seed": + sd_kwargs[key] = int(sd_kwargs[key]) -# model_vmfb_key = "" + for i in range(1): + generated_imgs = yield from shark_sd_fn(**sd_kwargs) + yield generated_imgs def shark_sd_fn( prompt, negative_prompt, - image_dict, + sd_init_image: list, height: int, width: int, steps: int, strength: float, guidance_scale: float, - seed: str | int, + seed: list, batch_count: int, batch_size: int, scheduler: str, @@ -163,86 +531,75 @@ def shark_sd_fn( custom_vae: str, precision: str, device: str, - lora_weights: str | list, ondemand: bool, repeatable_seeds: bool, resample_type: str, - control_mode: str, - stencils: list, - images: list, - preprocessed_hints: list, - progress=gr.Progress(), + controlnets: dict, + embeddings: dict, ): - # Handling gradio ImageEditor datatypes so we have unified inputs to the SD API - for i, stencil in enumerate(stencils): - if images[i] is None and stencil is not None: - continue - elif stencil is None and any( - img is not None for img in [images[i], preprocessed_hints[i]] - ): - images[i] = None - preprocessed_hints[i] = None - elif images[i] is not None: - if isinstance(images[i], dict): - images[i] = images[i]["composite"] - images[i] = images[i].convert("RGB") - - if isinstance(image_dict, PIL.Image.Image): - image = image_dict.convert("RGB") - elif image_dict: - image = image_dict["image"].convert("RGB") - else: - image = None - is_img2img = False - if image: - ( - image, - _, - _, - ) = resize_stencil(image, width, height) - is_img2img = True - print("Performing Stable Diffusion Pipeline setup...") + sd_kwargs = locals() + if not isinstance(sd_init_image, list): + sd_init_image = [sd_init_image] + is_img2img = True if sd_init_image[0] is not None else False - device_id = None + print("\n[LOG] Performing Stable Diffusion Pipeline setup...") from apps.shark_studio.modules.shared_cmd_opts import cmd_opts import apps.shark_studio.web.utils.globals as global_obj - custom_model_map = {} - if custom_weights != "None": - custom_model_map["unet"] = {"custom_weights": custom_weights} - if custom_vae != "None": - custom_model_map["vae"] = {"custom_weights": custom_vae} - if stencils: - for i, stencil in enumerate(stencils): + adapters = {} + is_controlled = False + control_mode = None + hints = [] + num_loras = 0 + for i in embeddings: + num_loras += 1 if embeddings[i] else 0 + if "model" in controlnets: + for i, model in enumerate(controlnets["model"]): if "xl" not in base_model_id.lower(): - custom_model_map[f"control_adapter_{i}"] = stencil_adapter_map[ - "runwayml/stable-diffusion-v1-5" - ][stencil] + adapters[f"control_adapter_{model}"] = { + "hf_id": control_adapter_map["runwayml/stable-diffusion-v1-5"][ + model + ], + "strength": controlnets["strength"][i], + } else: - custom_model_map[f"control_adapter_{i}"] = stencil_adapter_map[ - "stabilityai/stable-diffusion-xl-1.0" - ][stencil] + adapters[f"control_adapter_{model}"] = { + "hf_id": control_adapter_map["stabilityai/stable-diffusion-xl-1.0"][ + model + ], + "strength": controlnets["strength"][i], + } + if model is not None: + is_controlled = True + control_mode = controlnets["control_mode"] + for i in controlnets["hint"]: + hints.append[i] submit_pipe_kwargs = { "base_model_id": base_model_id, "height": height, "width": width, + "batch_size": batch_size, "precision": precision, "device": device, - "custom_model_map": custom_model_map, + "custom_vae": custom_vae, + "num_loras": num_loras, "import_ir": cmd_opts.import_mlir, - "is_img2img": is_img2img, + "is_controlled": is_controlled, } submit_prep_kwargs = { - "scheduler": scheduler, - "custom_model_map": custom_model_map, - "embeddings": lora_weights, + "custom_weights": custom_weights, + "adapters": adapters, + "embeddings": embeddings, + "is_img2img": is_img2img, } submit_run_kwargs = { "prompt": prompt, "negative_prompt": negative_prompt, + "image": sd_init_image, "steps": steps, + "scheduler": scheduler, "strength": strength, "guidance_scale": guidance_scale, "seed": seed, @@ -250,49 +607,52 @@ def shark_sd_fn( "repeatable_seeds": repeatable_seeds, "resample_type": resample_type, "control_mode": control_mode, - "preprocessed_hints": preprocessed_hints, + "hints": hints, } - - global sd_pipe - global sd_pipe_kwargs - - if sd_pipe_kwargs and sd_pipe_kwargs != submit_pipe_kwargs: - sd_pipe = None - sd_pipe_kwargs = submit_pipe_kwargs + if ( + not global_obj.get_sd_obj() + or global_obj.get_pipe_kwargs() != submit_pipe_kwargs + ): + print("\n[LOG] Initializing new pipeline...") + global_obj.clear_cache() gc.collect() - if sd_pipe is None: - history[-1][-1] = "Getting the pipeline ready..." - yield history, "" - # Initializes the pipeline and retrieves IR based on all # parameters that are static in the turbine output format, # which is currently MLIR in the torch dialect. - sd_pipe = SharkStableDiffusionPipeline( + sd_pipe = StableDiffusion( **submit_pipe_kwargs, ) - - sd_pipe.prepare_pipe(**submit_prep_kwargs) - - for prompt, msg, exec_time in progress.tqdm( - out_imgs=sd_pipe.generate_images(**submit_run_kwargs), - desc="Generating Image...", + global_obj.set_sd_obj(sd_pipe) + global_obj.set_pipe_kwargs(submit_pipe_kwargs) + if ( + not global_obj.get_prep_kwargs() + or global_obj.get_prep_kwargs() != submit_prep_kwargs ): - text_output = get_generation_text_info( - seeds[: current_batch + 1], device - ) + global_obj.set_prep_kwargs(submit_prep_kwargs) + global_obj.get_sd_obj().prepare_pipe(**submit_prep_kwargs) + + generated_imgs = [] + for current_batch in range(batch_count): + start_time = time.time() + out_imgs = global_obj.get_sd_obj().generate_images(**submit_run_kwargs) + total_time = time.time() - start_time + text_output = f"Total image(s) generation time: {total_time:.4f}sec" + print(f"\n[LOG] {text_output}") + # if global_obj.get_sd_status() == SD_STATE_CANCEL: + # break + # else: save_output_img( - out_imgs[0], - seeds[current_batch], - extra_info, + out_imgs[current_batch], + seed, + sd_kwargs, ) generated_imgs.extend(out_imgs) - yield generated_imgs, text_output, status_label( + yield generated_imgs, status_label( "Stable Diffusion", current_batch + 1, batch_count, batch_size - ), stencils, images - - return generated_imgs, text_output, "", stencils, images + ) + return generated_imgs, "" def cancel_sd(): @@ -300,9 +660,23 @@ def cancel_sd(): return +def view_json_file(file_path): + content = "" + with open(file_path, "r") as fopen: + content = fopen.read() + return content + + if __name__ == "__main__": - sd = StableDiffusion( - "runwayml/stable-diffusion-v1-5", - device="vulkan", - ) - print("model loaded") + from apps.shark_studio.modules.shared_cmd_opts import cmd_opts + import apps.shark_studio.web.utils.globals as global_obj + + global_obj._init() + + sd_json = view_json_file(get_resource_path("../configs/default_sd_config.json")) + sd_kwargs = json.loads(sd_json) + for arg in vars(cmd_opts): + if arg in sd_kwargs: + sd_kwargs[arg] = getattr(cmd_opts, arg) + for i in shark_sd_fn_dict_input(sd_kwargs): + print(i) diff --git a/apps/shark_studio/api/utils.py b/apps/shark_studio/api/utils.py index a4f52dca24..e9268aa83b 100644 --- a/apps/shark_studio/api/utils.py +++ b/apps/shark_studio/api/utils.py @@ -1,8 +1,5 @@ -import os -import sys -import os import numpy as np -import glob +import json from random import ( randint, seed as seed_random, @@ -11,7 +8,6 @@ ) from pathlib import Path -from safetensors.torch import load_file from apps.shark_studio.modules.shared_cmd_opts import cmd_opts from cpuinfo import get_cpu_info @@ -22,11 +18,6 @@ get_iree_vulkan_runtime_flags, ) -checkpoints_filetypes = ( - "*.ckpt", - "*.safetensors", -) - def get_available_devices(): def get_devices_by_name(driver_name): @@ -55,9 +46,7 @@ def get_devices_by_name(driver_name): if len(device_list_dict) == 1: device_list.append(f"{device_name} => {driver_name}") else: - device_list.append( - f"{device_name} => {driver_name}://{i}" - ) + device_list.append(f"{device_name} => {driver_name}://{i}") return device_list set_iree_runtime_flags() @@ -109,6 +98,8 @@ def set_init_device_flags(): elif "metal" in cmd_opts.device: device_name, cmd_opts.device = map_device_to_name_path(cmd_opts.device) if not cmd_opts.iree_metal_target_platform: + from shark.iree_utils.metal_utils import get_metal_target_triple + triple = get_metal_target_triple(device_name) if triple is not None: cmd_opts.iree_metal_target_platform = triple.split("-")[-1] @@ -150,60 +141,6 @@ def get_all_devices(driver_name): return device_list_src -def get_resource_path(relative_path): - """Get absolute path to resource, works for dev and for PyInstaller""" - base_path = getattr(sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))) - return os.path.join(base_path, relative_path) - - -def get_generated_imgs_path() -> Path: - return Path( - cmd_opts.output_dir - if cmd_opts.output_dir - else get_resource_path("..\web\generated_imgs") - ) - - -def get_generated_imgs_todays_subdir() -> str: - return dt.now().strftime("%Y%m%d") - - -def create_checkpoint_folders(): - dir = ["vae", "lora"] - if not cmd_opts.ckpt_dir: - dir.insert(0, "models") - else: - if not os.path.isdir(cmd_opts.ckpt_dir): - sys.exit( - f"Invalid --ckpt_dir argument, " - f"{args.ckpt_dir} folder does not exists." - ) - for root in dir: - Path(get_checkpoints_path(root)).mkdir(parents=True, exist_ok=True) - - -def get_checkpoints_path(model=""): - return get_resource_path(f"..\web\models\{model}") - - -def get_checkpoints(model="models"): - ckpt_files = [] - file_types = checkpoints_filetypes - if model == "lora": - file_types = file_types + ("*.pt", "*.bin") - for extn in file_types: - files = [ - os.path.basename(x) - for x in glob.glob(os.path.join(get_checkpoints_path(model), extn)) - ] - ckpt_files.extend(files) - return sorted(ckpt_files, key=str.casefold) - - -def get_checkpoint_pathfile(checkpoint_name, model="models"): - return os.path.join(get_checkpoints_path(model), checkpoint_name) - - def get_device_mapping(driver, key_combination=3): """This method ensures consistent device ordering when choosing specific devices for execution @@ -250,6 +187,8 @@ def get_opt_flags(model, precision="fp16"): f"-iree-vulkan-target-triple={cmd_opts.iree_vulkan_target_triple}" ) if "rocm" in cmd_opts.device: + from shark.iree_utils.gpu_utils import get_iree_rocm_args + rocm_args = get_iree_rocm_args() iree_flags.extend(rocm_args) if cmd_opts.iree_constant_folding == False: @@ -318,9 +257,7 @@ def get_devices_by_name(driver_name): if len(device_list_dict) == 1: device_list.append(f"{device_name} => {driver_name}") else: - device_list.append( - f"{device_name} => {driver_name}://{i}" - ) + device_list.append(f"{device_name} => {driver_name}://{i}") return device_list set_iree_runtime_flags() @@ -352,28 +289,6 @@ def get_devices_by_name(driver_name): return available_devices -# take a seed expression in an input format and convert it to -# a list of integers, where possible -def parse_seed_input(seed_input: str | list | int): - if isinstance(seed_input, str): - try: - seed_input = json.loads(seed_input) - except (ValueError, TypeError): - seed_input = None - - if isinstance(seed_input, int): - return [seed_input] - - if isinstance(seed_input, list) and all( - type(seed) is int for seed in seed_input - ): - return seed_input - - raise TypeError( - "Seed input must be an integer or an array of integers in JSON format" - ) - - # Generate and return a new seed if the provided one is not in the # supported range (including -1) def sanitize_seed(seed: int | str): @@ -397,9 +312,7 @@ def parse_seed_input(seed_input: str | list | int): if isinstance(seed_input, int): return [seed_input] - if isinstance(seed_input, list) and all( - type(seed) is int for seed in seed_input - ): + if isinstance(seed_input, list) and all(type(seed) is int for seed in seed_input): return seed_input raise TypeError( diff --git a/apps/shark_studio/modules/checkpoint_proc.py b/apps/shark_studio/modules/checkpoint_proc.py deleted file mode 100644 index e924de4640..0000000000 --- a/apps/shark_studio/modules/checkpoint_proc.py +++ /dev/null @@ -1,66 +0,0 @@ -import os -import json -import re -from pathlib import Path -from omegaconf import OmegaConf - - -def get_path_to_diffusers_checkpoint(custom_weights): - path = Path(custom_weights) - diffusers_path = path.parent.absolute() - diffusers_directory_name = os.path.join("diffusers", path.stem) - complete_path_to_diffusers = diffusers_path / diffusers_directory_name - complete_path_to_diffusers.mkdir(parents=True, exist_ok=True) - path_to_diffusers = complete_path_to_diffusers.as_posix() - return path_to_diffusers - - -def preprocessCKPT(custom_weights, is_inpaint=False): - path_to_diffusers = get_path_to_diffusers_checkpoint(custom_weights) - if next(Path(path_to_diffusers).iterdir(), None): - print("Checkpoint already loaded at : ", path_to_diffusers) - return - else: - print( - "Diffusers' checkpoint will be identified here : ", - path_to_diffusers, - ) - from_safetensors = ( - True if custom_weights.lower().endswith(".safetensors") else False - ) - # EMA weights usually yield higher quality images for inference but - # non-EMA weights have been yielding better results in our case. - # TODO: Add an option `--ema` (`--no-ema`) for users to specify if - # they want to go for EMA weight extraction or not. - extract_ema = False - print( - "Loading diffusers' pipeline from original stable diffusion checkpoint" - ) - num_in_channels = 9 if is_inpaint else 4 - pipe = download_from_original_stable_diffusion_ckpt( - checkpoint_path_or_dict=custom_weights, - extract_ema=extract_ema, - from_safetensors=from_safetensors, - num_in_channels=num_in_channels, - ) - pipe.save_pretrained(path_to_diffusers) - print("Loading complete") - - -def convert_original_vae(vae_checkpoint): - vae_state_dict = {} - for key in list(vae_checkpoint.keys()): - vae_state_dict["first_stage_model." + key] = vae_checkpoint.get(key) - - config_url = ( - "https://raw.githubusercontent.com/CompVis/stable-diffusion/" - "main/configs/stable-diffusion/v1-inference.yaml" - ) - original_config_file = BytesIO(requests.get(config_url).content) - original_config = OmegaConf.load(original_config_file) - vae_config = create_vae_diffusers_config(original_config, image_size=512) - - converted_vae_checkpoint = convert_ldm_vae_checkpoint( - vae_state_dict, vae_config - ) - return converted_vae_checkpoint diff --git a/apps/shark_studio/modules/ckpt_processing.py b/apps/shark_studio/modules/ckpt_processing.py new file mode 100644 index 0000000000..08681f6c56 --- /dev/null +++ b/apps/shark_studio/modules/ckpt_processing.py @@ -0,0 +1,122 @@ +import os +import json +import re +import requests +from io import BytesIO +from pathlib import Path +from tqdm import tqdm +from omegaconf import OmegaConf +from apps.shark_studio.modules.shared_cmd_opts import cmd_opts +from diffusers.pipelines.stable_diffusion.convert_from_ckpt import ( + download_from_original_stable_diffusion_ckpt, + create_vae_diffusers_config, + convert_ldm_vae_checkpoint, +) + + +def get_path_to_diffusers_checkpoint(custom_weights): + path = Path(custom_weights) + diffusers_path = path.parent.absolute() + diffusers_directory_name = os.path.join("diffusers", path.stem) + complete_path_to_diffusers = diffusers_path / diffusers_directory_name + complete_path_to_diffusers.mkdir(parents=True, exist_ok=True) + path_to_diffusers = complete_path_to_diffusers.as_posix() + return path_to_diffusers + + +def preprocessCKPT(custom_weights, is_inpaint=False): + path_to_diffusers = get_path_to_diffusers_checkpoint(custom_weights) + if next(Path(path_to_diffusers).iterdir(), None): + print("Checkpoint already loaded at : ", path_to_diffusers) + return + else: + print( + "Diffusers' checkpoint will be identified here : ", + path_to_diffusers, + ) + from_safetensors = ( + True if custom_weights.lower().endswith(".safetensors") else False + ) + # EMA weights usually yield higher quality images for inference but + # non-EMA weights have been yielding better results in our case. + # TODO: Add an option `--ema` (`--no-ema`) for users to specify if + # they want to go for EMA weight extraction or not. + extract_ema = False + print("Loading diffusers' pipeline from original stable diffusion checkpoint") + num_in_channels = 9 if is_inpaint else 4 + pipe = download_from_original_stable_diffusion_ckpt( + checkpoint_path_or_dict=custom_weights, + extract_ema=extract_ema, + from_safetensors=from_safetensors, + num_in_channels=num_in_channels, + ) + pipe.save_pretrained(path_to_diffusers) + print("Loading complete") + + +def convert_original_vae(vae_checkpoint): + vae_state_dict = {} + for key in list(vae_checkpoint.keys()): + vae_state_dict["first_stage_model." + key] = vae_checkpoint.get(key) + + config_url = ( + "https://raw.githubusercontent.com/CompVis/stable-diffusion/" + "main/configs/stable-diffusion/v1-inference.yaml" + ) + original_config_file = BytesIO(requests.get(config_url).content) + original_config = OmegaConf.load(original_config_file) + vae_config = create_vae_diffusers_config(original_config, image_size=512) + + converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_state_dict, vae_config) + return converted_vae_checkpoint + + +def process_custom_pipe_weights(custom_weights): + if custom_weights != "": + if custom_weights.startswith("https://civitai.com/api/"): + # download the checkpoint from civitai if we don't already have it + weights_path = get_civitai_checkpoint(custom_weights) + + # act as if we were given the local file as custom_weights originally + custom_weights_tgt = get_path_to_diffusers_checkpoint(weights_path) + custom_weights_params = weights_path + + else: + assert custom_weights.lower().endswith( + (".ckpt", ".safetensors") + ), "checkpoint files supported can be any of [.ckpt, .safetensors] type" + custom_weights_tgt = get_path_to_diffusers_checkpoint(custom_weights) + custom_weights_params = custom_weights + return custom_weights_params, custom_weights_tgt + + +def get_civitai_checkpoint(url: str): + with requests.get(url, allow_redirects=True, stream=True) as response: + response.raise_for_status() + + # civitai api returns the filename in the content disposition + base_filename = re.findall( + '"([^"]*)"', response.headers["Content-Disposition"] + )[0] + destination_path = Path.cwd() / (cmd_opts.ckpt_dir or "models") / base_filename + + # we don't have this model downloaded yet + if not destination_path.is_file(): + print(f"downloading civitai model from {url} to {destination_path}") + + size = int(response.headers["content-length"], 0) + progress_bar = tqdm(total=size, unit="iB", unit_scale=True) + + with open(destination_path, "wb") as f: + for chunk in response.iter_content(chunk_size=65536): + f.write(chunk) + progress_bar.update(len(chunk)) + + progress_bar.close() + + # we already have this model downloaded + else: + print(f"civitai model already downloaded to {destination_path}") + + response.close() + return destination_path.as_posix() diff --git a/apps/shark_studio/modules/embeddings.py b/apps/shark_studio/modules/embeddings.py index 131c9006e5..87924c819e 100644 --- a/apps/shark_studio/modules/embeddings.py +++ b/apps/shark_studio/modules/embeddings.py @@ -5,7 +5,10 @@ import safetensors from dataclasses import dataclass from safetensors.torch import load_file -from apps.shark_studio.api.utils import get_checkpoint_pathfile +from apps.shark_studio.web.utils.file_utils import ( + get_checkpoint_pathfile, + get_path_stem, +) @dataclass @@ -73,22 +76,14 @@ def processLoRA(model, use_lora, splitting_prefix, lora_strength=0.75): scale = lora_weight.alpha * lora_strength if len(weight.size()) == 2: if len(lora_weight.up.shape) == 4: - weight_up = ( - lora_weight.up.squeeze(3).squeeze(2).to(torch.float32) - ) - weight_down = ( - lora_weight.down.squeeze(3).squeeze(2).to(torch.float32) - ) - change = ( - torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3) - ) + weight_up = lora_weight.up.squeeze(3).squeeze(2).to(torch.float32) + weight_down = lora_weight.down.squeeze(3).squeeze(2).to(torch.float32) + change = torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3) else: change = torch.mm(lora_weight.up, lora_weight.down) elif lora_weight.down.size()[2:4] == (1, 1): weight_up = lora_weight.up.squeeze(3).squeeze(2).to(torch.float32) - weight_down = ( - lora_weight.down.squeeze(3).squeeze(2).to(torch.float32) - ) + weight_down = lora_weight.down.squeeze(3).squeeze(2).to(torch.float32) change = torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3) else: change = torch.nn.functional.conv2d( @@ -163,9 +158,7 @@ def get_lora_metadata(lora_filename): # get a figure for the total number of images processed for this dataset # either then number actually listed or in its dataset_dir entry or # the highest frequency's number if that doesn't exist - img_count = dataset_dirs.get(dir, {}).get( - "img_count", frequencies[0][1] - ) + img_count = dataset_dirs.get(dir, {}).get("img_count", frequencies[0][1]) # add the dataset frequencies to the overall frequencies replacing the # frequency counts on the tags with a percentage/ratio diff --git a/apps/shark_studio/modules/img_processing.py b/apps/shark_studio/modules/img_processing.py index b5cf28ce47..80062814cf 100644 --- a/apps/shark_studio/modules/img_processing.py +++ b/apps/shark_studio/modules/img_processing.py @@ -1,11 +1,33 @@ import os -import sys -from PIL import Image +import re +import json + +from csv import DictWriter +from PIL import Image, PngImagePlugin from pathlib import Path +from datetime import datetime as dt +from base64 import decode + +resamplers = { + "Lanczos": Image.Resampling.LANCZOS, + "Nearest Neighbor": Image.Resampling.NEAREST, + "Bilinear": Image.Resampling.BILINEAR, + "Bicubic": Image.Resampling.BICUBIC, + "Hamming": Image.Resampling.HAMMING, + "Box": Image.Resampling.BOX, +} + +resampler_list = resamplers.keys() # save output images and the inputs corresponding to it. def save_output_img(output_img, img_seed, extra_info=None): + from apps.shark_studio.web.utils.file_utils import ( + get_generated_imgs_path, + get_generated_imgs_todays_subdir, + ) + from apps.shark_studio.modules.shared_cmd_opts import cmd_opts + if extra_info is None: extra_info = {} generated_imgs_path = Path( @@ -14,20 +36,23 @@ def save_output_img(output_img, img_seed, extra_info=None): generated_imgs_path.mkdir(parents=True, exist_ok=True) csv_path = Path(generated_imgs_path, "imgs_details.csv") - prompt_slice = re.sub("[^a-zA-Z0-9]", "_", cmd_opts.prompts[0][:15]) + prompt_slice = re.sub("[^a-zA-Z0-9]", "_", extra_info["prompt"][0][:15]) out_img_name = f"{dt.now().strftime('%H%M%S')}_{prompt_slice}_{img_seed}" - img_model = cmd_opts.hf_model_id - if cmd_opts.ckpt_loc: - img_model = Path(os.path.basename(cmd_opts.ckpt_loc)).stem + img_model = extra_info["base_model_id"] + if extra_info["custom_weights"] not in [None, "None"]: + img_model = Path(os.path.basename(extra_info["custom_weights"])).stem img_vae = None - if cmd_opts.custom_vae: - img_vae = Path(os.path.basename(cmd_opts.custom_vae)).stem + if extra_info["custom_vae"]: + img_vae = Path(os.path.basename(extra_info["custom_vae"])).stem - img_lora = None - if cmd_opts.use_lora: - img_lora = Path(os.path.basename(cmd_opts.use_lora)).stem + img_loras = None + if extra_info["embeddings"]: + img_lora = [] + for i in extra_info["embeddings"]: + img_lora += Path(os.path.basename(cmd_opts.use_lora)).stem + img_loras = ", ".join(img_lora) if cmd_opts.output_img_format == "jpg": out_img_path = Path(generated_imgs_path, f"{out_img_name}.jpg") @@ -39,25 +64,25 @@ def save_output_img(output_img, img_seed, extra_info=None): if cmd_opts.write_metadata_to_png: # Using a conditional expression caused problems, so setting a new # variable for now. - if cmd_opts.use_hiresfix: - png_size_text = ( - f"{cmd_opts.hiresfix_width}x{cmd_opts.hiresfix_height}" - ) - else: - png_size_text = f"{cmd_opts.width}x{cmd_opts.height}" + # if cmd_opts.use_hiresfix: + # png_size_text = ( + # f"{cmd_opts.hiresfix_width}x{cmd_opts.hiresfix_height}" + # ) + # else: + png_size_text = f"{extra_info['width']}x{extra_info['height']}" pngInfo.add_text( "parameters", - f"{cmd_opts.prompts[0]}" - f"\nNegative prompt: {cmd_opts.negative_prompts[0]}" - f"\nSteps: {cmd_opts.steps}," - f"Sampler: {cmd_opts.scheduler}, " - f"CFG scale: {cmd_opts.guidance_scale}, " + f"{extra_info['prompt'][0]}" + f"\nNegative prompt: {extra_info['negative_prompt'][0]}" + f"\nSteps: {extra_info['steps']}," + f"Sampler: {extra_info['scheduler']}, " + f"CFG scale: {extra_info['guidance_scale']}, " f"Seed: {img_seed}," f"Size: {png_size_text}, " f"Model: {img_model}, " f"VAE: {img_vae}, " - f"LoRA: {img_lora}", + f"LoRA: {img_loras}", ) output_img.save(out_img_path, "PNG", pnginfo=pngInfo) @@ -72,26 +97,7 @@ def save_output_img(output_img, img_seed, extra_info=None): # To be as low-impact as possible to the existing CSV format, we append # "VAE" and "LORA" to the end. However, it does not fit the hierarchy of # importance for each data point. Something to consider. - new_entry = { - "VARIANT": img_model, - "SCHEDULER": cmd_opts.scheduler, - "PROMPT": cmd_opts.prompts[0], - "NEG_PROMPT": cmd_opts.negative_prompts[0], - "SEED": img_seed, - "CFG_SCALE": cmd_opts.guidance_scale, - "PRECISION": cmd_opts.precision, - "STEPS": cmd_opts.steps, - "HEIGHT": cmd_opts.height - if not cmd_opts.use_hiresfix - else cmd_opts.hiresfix_height, - "WIDTH": cmd_opts.width - if not cmd_opts.use_hiresfix - else cmd_opts.hiresfix_width, - "MAX_LENGTH": cmd_opts.max_length, - "OUTPUT": out_img_path, - "VAE": img_vae, - "LORA": img_lora, - } + new_entry = {} new_entry.update(extra_info) @@ -103,23 +109,9 @@ def save_output_img(output_img, img_seed, extra_info=None): dictwriter_obj.writerow(new_entry) csv_obj.close() - if cmd_opts.save_metadata_to_json: - del new_entry["OUTPUT"] - json_path = Path(generated_imgs_path, f"{out_img_name}.json") - with open(json_path, "w") as f: - json.dump(new_entry, f, indent=4) - - -resamplers = { - "Lanczos": Image.Resampling.LANCZOS, - "Nearest Neighbor": Image.Resampling.NEAREST, - "Bilinear": Image.Resampling.BILINEAR, - "Bicubic": Image.Resampling.BICUBIC, - "Hamming": Image.Resampling.HAMMING, - "Box": Image.Resampling.BOX, -} - -resampler_list = resamplers.keys() + json_path = Path(generated_imgs_path, f"{out_img_name}.json") + with open(json_path, "w") as f: + json.dump(new_entry, f, indent=4) # For stencil, the input image can be of any size, but we need to ensure that diff --git a/apps/shark_studio/modules/logger.py b/apps/shark_studio/modules/logger.py new file mode 100644 index 0000000000..bff6c933b7 --- /dev/null +++ b/apps/shark_studio/modules/logger.py @@ -0,0 +1,37 @@ +import sys + + +class Logger: + def __init__(self, filename, filter=None): + self.terminal = sys.stdout + self.log = open(filename, "w") + self.filter = filter + + def write(self, message): + for x in message.split("\n"): + if self.filter in x: + self.log.write(message) + else: + self.terminal.write(message) + + def flush(self): + self.terminal.flush() + self.log.flush() + + def isatty(self): + return False + + +def logger_test(x): + print("[LOG] This is a test") + print(f"This is another test, without the filter") + return x + + +def read_sd_logs(): + sys.stdout.flush() + with open("shark_tmp/sd.log", "r") as f: + return f.read() + + +sys.stdout = Logger("shark_tmp/sd.log", filter="[LOG]") diff --git a/apps/shark_studio/modules/pipeline.py b/apps/shark_studio/modules/pipeline.py index c087175de4..5dee266b13 100644 --- a/apps/shark_studio/modules/pipeline.py +++ b/apps/shark_studio/modules/pipeline.py @@ -1,4 +1,21 @@ -from shark.iree_utils.compile_utils import get_iree_compiled_module +from msvcrt import kbhit +from shark.iree_utils.compile_utils import ( + get_iree_compiled_module, + load_vmfb_using_mmap, + clean_device_info, + get_iree_target_triple, +) +from apps.shark_studio.web.utils.file_utils import ( + get_checkpoints_path, + get_resource_path, +) +from apps.shark_studio.modules.shared_cmd_opts import ( + cmd_opts, +) +from iree import runtime as ireert +from pathlib import Path +import gc +import os class SharkPipelineBase: @@ -12,60 +29,178 @@ class SharkPipelineBase: def __init__( self, model_map: dict, + base_model_id: str, + static_kwargs: dict, device: str, import_mlir: bool = True, ): self.model_map = model_map - self.device = device + self.pipe_map = {} + self.static_kwargs = static_kwargs + self.base_model_id = base_model_id + self.triple = get_iree_target_triple(device) + self.device, self.device_id = clean_device_info(device) self.import_mlir = import_mlir + self.iree_module_dict = {} + self.tmp_dir = get_resource_path(os.path.join("..", "shark_tmp")) + if not os.path.exists(self.tmp_dir): + os.mkdir(self.tmp_dir) + self.tempfiles = {} + self.pipe_vmfb_path = "" - def import_torch_ir(self, base_model_id): - for submodel in self.model_map: - hf_id = ( - submodel["custom_hf_id"] - if submodel["custom_hf_id"] - else base_model_id - ) - torch_ir = submodel["initializer"]( - hf_id, **submodel["init_kwargs"], compile_to="torch" - ) - submodel["tempfile_name"] = get_resource_path( - f"{submodel}.torch.tempfile" - ) - with open(submodel["tempfile_name"], "w+") as f: - f.write(torch_ir) - del torch_ir - gc.collect() - - def load_vmfb(self, submodel): - if self.iree_module_dict[submodel]: - print( - f".vmfb for {submodel} found at {self.iree_module_dict[submodel]['vmfb']}" - ) - elif self.model_map[submodel]["tempfile_name"]: - submodel["tempfile_name"] - - return submodel["vmfb"] + def get_compiled_map(self, pipe_id, submodel="None", init_kwargs={}) -> None: + # First checks whether we have .vmfbs precompiled, then populates the map + # with the precompiled executables and fetches executables for the rest of the map. + # The weights aren't static here anymore so this function should be a part of pipeline + # initialization. As soon as you have a pipeline ID unique to your static torch IR parameters, + # and your model map is populated with any IR - unique model IDs and their static params, + # call this method to get the artifacts associated with your map. + self.pipe_id = self.safe_name(pipe_id) + self.pipe_vmfb_path = Path( + os.path.join(get_checkpoints_path(".."), self.pipe_id) + ) + self.pipe_vmfb_path.mkdir(parents=False, exist_ok=True) + if submodel == "None": + print("\n[LOG] Gathering any pre-compiled artifacts....") + for key in self.model_map: + self.get_compiled_map(pipe_id, submodel=key) + else: + self.pipe_map[submodel] = {} + self.get_precompiled(self.pipe_id, submodel) + ireec_flags = [] + if submodel in self.iree_module_dict: + return + elif "vmfb_path" in self.pipe_map[submodel]: + return + elif submodel not in self.tempfiles: + print( + f"\n[LOG] Tempfile for {submodel} not found. Fetching torch IR..." + ) + if submodel in self.static_kwargs: + init_kwargs = self.static_kwargs[submodel] + for key in self.static_kwargs["pipe"]: + if key not in init_kwargs: + init_kwargs[key] = self.static_kwargs["pipe"][key] + self.import_torch_ir(submodel, init_kwargs) + self.get_compiled_map(pipe_id, submodel) + else: + ireec_flags = ( + self.model_map[submodel]["ireec_flags"] + if "ireec_flags" in self.model_map[submodel] + else [] + ) - def merge_custom_map(self, custom_model_map): - for submodel in custom_model_map: - for key in submodel: - self.model_map[submodel][key] = key - print(self.model_map) + weights_path = self.get_io_params(submodel) - def get_compiled_map(self, device) -> None: - # this comes with keys: "vmfb", "config", and "temp_file_to_unlink". - for submodel in self.model_map: - if not self.iree_module_dict[submodel][vmfb]: self.iree_module_dict[submodel] = get_iree_compiled_module( - submodel.tempfile_name, + self.tempfiles[submodel], device=self.device, frontend="torch", + mmap=True, + external_weight_file=weights_path, + extra_args=ireec_flags, + write_to=os.path.join(self.pipe_vmfb_path, submodel + ".vmfb"), ) - # TODO: delete the temp file + return - def run(self, submodel, inputs): + def get_io_params(self, submodel): + if "external_weight_file" in self.static_kwargs[submodel]: + # we are using custom weights + weights_path = self.static_kwargs[submodel]["external_weight_file"] + elif "external_weight_path" in self.static_kwargs[submodel]: + # we are using the default weights for the HF model + weights_path = self.static_kwargs[submodel]["external_weight_path"] + else: + # assume the torch IR contains the weights. + weights_path = None + return weights_path + + def get_precompiled(self, pipe_id, submodel="None"): + if submodel == "None": + for model in self.model_map: + self.get_precompiled(pipe_id, model) + vmfbs = [] + for dirpath, dirnames, filenames in os.walk(self.pipe_vmfb_path): + vmfbs.extend(filenames) + break + for file in vmfbs: + if submodel in file: + self.pipe_map[submodel]["vmfb_path"] = os.path.join( + self.pipe_vmfb_path, file + ) return - def safe_name(name): - return name.replace("/", "_").replace("-", "_") + def import_torch_ir(self, submodel, kwargs): + torch_ir = self.model_map[submodel]["initializer"]( + **self.safe_dict(kwargs), compile_to="torch" + ) + if submodel == "clip": + # clip.export_clip_model returns (torch_ir, tokenizer) + torch_ir = torch_ir[0] + + self.tempfiles[submodel] = os.path.join( + self.tmp_dir, f"{submodel}.torch.tempfile" + ) + + with open(self.tempfiles[submodel], "w+") as f: + f.write(torch_ir) + del torch_ir + gc.collect() + return + + def load_submodels(self, submodels: list): + for submodel in submodels: + if submodel in self.iree_module_dict: + print(f"\n[LOG] {submodel} is ready for inference.") + continue + if "vmfb_path" in self.pipe_map[submodel]: + weights_path = self.get_io_params(submodel) + # print( + # f"\n[LOG] Loading .vmfb for {submodel} from {self.pipe_map[submodel]['vmfb_path']}" + # ) + self.iree_module_dict[submodel] = {} + ( + self.iree_module_dict[submodel]["vmfb"], + self.iree_module_dict[submodel]["config"], + self.iree_module_dict[submodel]["temp_file_to_unlink"], + ) = load_vmfb_using_mmap( + self.pipe_map[submodel]["vmfb_path"], + self.device, + device_idx=0, + rt_flags=[], + external_weight_file=weights_path, + ) + else: + self.get_compiled_map(self.pipe_id, submodel) + return + + def unload_submodels(self, submodels: list): + for submodel in submodels: + if submodel in self.iree_module_dict: + del self.iree_module_dict[submodel] + gc.collect() + return + + def run(self, submodel, inputs): + if not isinstance(inputs, list): + inputs = [inputs] + inp = [ + ireert.asdevicearray( + self.iree_module_dict[submodel]["config"].device, input + ) + for input in inputs + ] + return self.iree_module_dict[submodel]["vmfb"]["main"](*inp) + + def safe_name(self, name): + return name.replace("/", "_").replace("-", "_").replace("\\", "_") + + def safe_dict(self, kwargs: dict): + flat_args = {} + for i in kwargs: + if isinstance(kwargs[i], dict) and "pass_dict" not in kwargs[i]: + flat_args[i] = [kwargs[i][j] for j in kwargs[i]] + else: + flat_args[i] = kwargs[i] + + return flat_args diff --git a/apps/shark_studio/modules/prompt_encoding.py b/apps/shark_studio/modules/prompt_encoding.py new file mode 100644 index 0000000000..3dc61aba08 --- /dev/null +++ b/apps/shark_studio/modules/prompt_encoding.py @@ -0,0 +1,376 @@ +from typing import List, Optional, Union +from iree import runtime as ireert +import re +import torch +import numpy as np + +re_attention = re.compile( + r""" +\\\(| +\\\)| +\\\[| +\\]| +\\\\| +\\| +\(| +\[| +:([+-]?[.\d]+)\)| +\)| +]| +[^\\()\[\]:]+| +: +""", + re.X, +) + + +def parse_prompt_attention(text): + """ + Parses a string with attention tokens and returns a list of pairs: + text and its associated weight. + Accepted tokens are: + (abc) - increases attention to abc by a multiplier of 1.1 + (abc:3.12) - increases attention to abc by a multiplier of 3.12 + [abc] - decreases attention to abc by a multiplier of 1.1 + \( - literal character '(' + \[ - literal character '[' + \) - literal character ')' + \] - literal character ']' + \\ - literal character '\' + anything else - just text + >>> parse_prompt_attention('normal text') + [['normal text', 1.0]] + >>> parse_prompt_attention('an (important) word') + [['an ', 1.0], ['important', 1.1], [' word', 1.0]] + >>> parse_prompt_attention('(unbalanced') + [['unbalanced', 1.1]] + >>> parse_prompt_attention('\(literal\]') + [['(literal]', 1.0]] + >>> parse_prompt_attention('(unnecessary)(parens)') + [['unnecessaryparens', 1.1]] + >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).') + [['a ', 1.0], + ['house', 1.5730000000000004], + [' ', 1.1], + ['on', 1.0], + [' a ', 1.1], + ['hill', 0.55], + [', sun, ', 1.1], + ['sky', 1.4641000000000006], + ['.', 1.1]] + """ + + res = [] + round_brackets = [] + square_brackets = [] + + round_bracket_multiplier = 1.1 + square_bracket_multiplier = 1 / 1.1 + + def multiply_range(start_position, multiplier): + for p in range(start_position, len(res)): + res[p][1] *= multiplier + + for m in re_attention.finditer(text): + text = m.group(0) + weight = m.group(1) + + if text.startswith("\\"): + res.append([text[1:], 1.0]) + elif text == "(": + round_brackets.append(len(res)) + elif text == "[": + square_brackets.append(len(res)) + elif weight is not None and len(round_brackets) > 0: + multiply_range(round_brackets.pop(), float(weight)) + elif text == ")" and len(round_brackets) > 0: + multiply_range(round_brackets.pop(), round_bracket_multiplier) + elif text == "]" and len(square_brackets) > 0: + multiply_range(square_brackets.pop(), square_bracket_multiplier) + else: + res.append([text, 1.0]) + + for pos in round_brackets: + multiply_range(pos, round_bracket_multiplier) + + for pos in square_brackets: + multiply_range(pos, square_bracket_multiplier) + + if len(res) == 0: + res = [["", 1.0]] + + # merge runs of identical weights + i = 0 + while i + 1 < len(res): + if res[i][1] == res[i + 1][1]: + res[i][0] += res[i + 1][0] + res.pop(i + 1) + else: + i += 1 + + return res + + +def get_prompts_with_weights(pipe, prompt: List[str], max_length: int): + r""" + Tokenize a list of prompts and return its tokens with weights of each token. + No padding, starting or ending token is included. + """ + tokens = [] + weights = [] + truncated = False + for text in prompt: + texts_and_weights = parse_prompt_attention(text) + text_token = [] + text_weight = [] + for word, weight in texts_and_weights: + # tokenize and discard the starting and the ending token + token = pipe.tokenizer(word).input_ids[1:-1] + text_token += token + # copy the weight by length of token + text_weight += [weight] * len(token) + # stop if the text is too long (longer than truncation limit) + if len(text_token) > max_length: + truncated = True + break + # truncate + if len(text_token) > max_length: + truncated = True + text_token = text_token[:max_length] + text_weight = text_weight[:max_length] + tokens.append(text_token) + weights.append(text_weight) + if truncated: + print( + "Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples" + ) + return tokens, weights + + +def pad_tokens_and_weights( + tokens, + weights, + max_length, + bos, + eos, + no_boseos_middle=True, + chunk_length=77, +): + r""" + Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length. + """ + max_embeddings_multiples = (max_length - 2) // (chunk_length - 2) + weights_length = ( + max_length if no_boseos_middle else max_embeddings_multiples * chunk_length + ) + for i in range(len(tokens)): + tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i])) + if no_boseos_middle: + weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i])) + else: + w = [] + if len(weights[i]) == 0: + w = [1.0] * weights_length + else: + for j in range(max_embeddings_multiples): + w.append(1.0) # weight for starting token in this chunk + w += weights[i][ + j + * (chunk_length - 2) : min( + len(weights[i]), (j + 1) * (chunk_length - 2) + ) + ] + w.append(1.0) # weight for ending token in this chunk + w += [1.0] * (weights_length - len(w)) + weights[i] = w[:] + + return tokens, weights + + +def get_unweighted_text_embeddings( + pipe, + text_input, + chunk_length: int, + no_boseos_middle: Optional[bool] = True, +): + """ + When the length of tokens is a multiple of the capacity of the text encoder, + it should be split into chunks and sent to the text encoder individually. + """ + max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2) + if max_embeddings_multiples > 1: + text_embeddings = [] + for i in range(max_embeddings_multiples): + # extract the i-th chunk + text_input_chunk = text_input[ + :, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2 + ].clone() + + # cover the head and the tail by the starting and the ending tokens + text_input_chunk[:, 0] = text_input[0, 0] + text_input_chunk[:, -1] = text_input[0, -1] + + text_embedding = pipe.run("clip", text_input_chunk)[0].to_host() + + if no_boseos_middle: + if i == 0: + # discard the ending token + text_embedding = text_embedding[:, :-1] + elif i == max_embeddings_multiples - 1: + # discard the starting token + text_embedding = text_embedding[:, 1:] + else: + # discard both starting and ending tokens + text_embedding = text_embedding[:, 1:-1] + + text_embeddings.append(text_embedding) + # SHARK: Convert the result to tensor + # text_embeddings = torch.concat(text_embeddings, axis=1) + text_embeddings_np = np.concatenate(np.array(text_embeddings)) + text_embeddings = torch.from_numpy(text_embeddings_np) + else: + text_embeddings = pipe.run("clip", text_input)[0] + text_embeddings = torch.from_numpy(text_embeddings.to_host()) + return text_embeddings + + +# This function deals with NoneType values occuring in tokens after padding +# It switches out None with 49407 as truncating None values causes matrix dimension errors, +def filter_nonetype_tokens(tokens: List[List]): + return [[49407 if token is None else token for token in tokens[0]]] + + +def get_weighted_text_embeddings( + pipe, + prompt: List[str], + uncond_prompt: List[str] = None, + max_embeddings_multiples: Optional[int] = 8, + no_boseos_middle: Optional[bool] = True, + skip_parsing: Optional[bool] = False, + skip_weighting: Optional[bool] = False, +): + max_length = (pipe.model_max_length - 2) * max_embeddings_multiples + 2 + + if not skip_parsing: + prompt_tokens, prompt_weights = get_prompts_with_weights( + pipe, prompt, max_length - 2 + ) + if uncond_prompt is not None: + uncond_tokens, uncond_weights = get_prompts_with_weights( + pipe, uncond_prompt, max_length - 2 + ) + else: + prompt_tokens = [ + token[1:-1] + for token in pipe.tokenizer( + prompt, max_length=max_length, truncation=True + ).input_ids + ] + prompt_weights = [[1.0] * len(token) for token in prompt_tokens] + if uncond_prompt is not None: + if isinstance(uncond_prompt, str): + uncond_prompt = [uncond_prompt] + uncond_tokens = [ + token[1:-1] + for token in pipe.tokenizer( + uncond_prompt, max_length=max_length, truncation=True + ).input_ids + ] + uncond_weights = [[1.0] * len(token) for token in uncond_tokens] + + # round up the longest length of tokens to a multiple of (model_max_length - 2) + max_length = max([len(token) for token in prompt_tokens]) + if uncond_prompt is not None: + max_length = max(max_length, max([len(token) for token in uncond_tokens])) + max_embeddings_multiples = min( + max_embeddings_multiples, + (max_length - 1) // (pipe.model_max_length - 2) + 1, + ) + max_embeddings_multiples = max(1, max_embeddings_multiples) + + max_length = (pipe.model_max_length - 2) * max_embeddings_multiples + 2 + + # pad the length of tokens and weights + bos = pipe.tokenizer.bos_token_id + eos = pipe.tokenizer.eos_token_id + prompt_tokens, prompt_weights = pad_tokens_and_weights( + prompt_tokens, + prompt_weights, + max_length, + bos, + eos, + no_boseos_middle=no_boseos_middle, + chunk_length=pipe.model_max_length, + ) + + # FIXME: This is a hacky fix caused by tokenizer padding with None values + prompt_tokens = filter_nonetype_tokens(prompt_tokens) + + # prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device) + prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device="cpu") + if uncond_prompt is not None: + uncond_tokens, uncond_weights = pad_tokens_and_weights( + uncond_tokens, + uncond_weights, + max_length, + bos, + eos, + no_boseos_middle=no_boseos_middle, + chunk_length=pipe.model_max_length, + ) + + # FIXME: This is a hacky fix caused by tokenizer padding with None values + uncond_tokens = filter_nonetype_tokens(uncond_tokens) + + # uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device) + uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device="cpu") + + # get the embeddings + text_embeddings = get_unweighted_text_embeddings( + pipe, + prompt_tokens, + pipe.model_max_length, + no_boseos_middle=no_boseos_middle, + ) + # prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device) + prompt_weights = torch.tensor(prompt_weights, dtype=torch.float, device="cpu") + if uncond_prompt is not None: + uncond_embeddings = get_unweighted_text_embeddings( + pipe, + uncond_tokens, + pipe.model_max_length, + no_boseos_middle=no_boseos_middle, + ) + # uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device) + uncond_weights = torch.tensor(uncond_weights, dtype=torch.float, device="cpu") + + # assign weights to the prompts and normalize in the sense of mean + # TODO: should we normalize by chunk or in a whole (current implementation)? + if (not skip_parsing) and (not skip_weighting): + previous_mean = ( + text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) + ) + text_embeddings *= prompt_weights.unsqueeze(-1) + current_mean = ( + text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) + ) + text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) + if uncond_prompt is not None: + previous_mean = ( + uncond_embeddings.float() + .mean(axis=[-2, -1]) + .to(uncond_embeddings.dtype) + ) + uncond_embeddings *= uncond_weights.unsqueeze(-1) + current_mean = ( + uncond_embeddings.float() + .mean(axis=[-2, -1]) + .to(uncond_embeddings.dtype) + ) + uncond_embeddings *= ( + (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) + ) + + if uncond_prompt is not None: + return text_embeddings, uncond_embeddings + return text_embeddings, None diff --git a/apps/shark_studio/modules/schedulers.py b/apps/shark_studio/modules/schedulers.py index c62646f69c..8c2413c638 100644 --- a/apps/shark_studio/modules/schedulers.py +++ b/apps/shark_studio/modules/schedulers.py @@ -1,4 +1,99 @@ # from shark_turbine.turbine_models.schedulers import export_scheduler_model +from diffusers import ( + LCMScheduler, + LMSDiscreteScheduler, + PNDMScheduler, + DDPMScheduler, + DDIMScheduler, + DPMSolverMultistepScheduler, + KDPM2DiscreteScheduler, + EulerDiscreteScheduler, + EulerAncestralDiscreteScheduler, + DEISMultistepScheduler, + DPMSolverSinglestepScheduler, + KDPM2AncestralDiscreteScheduler, + HeunDiscreteScheduler, +) + + +def get_schedulers(model_id): + # TODO: switch over to turbine and run all on GPU + print(f"\n[LOG] Initializing schedulers from model id: {model_id}") + schedulers = dict() + schedulers["PNDM"] = PNDMScheduler.from_pretrained( + model_id, + subfolder="scheduler", + ) + schedulers["DDPM"] = DDPMScheduler.from_pretrained( + model_id, + subfolder="scheduler", + ) + schedulers["KDPM2Discrete"] = KDPM2DiscreteScheduler.from_pretrained( + model_id, + subfolder="scheduler", + ) + schedulers["LMSDiscrete"] = LMSDiscreteScheduler.from_pretrained( + model_id, + subfolder="scheduler", + ) + schedulers["DDIM"] = DDIMScheduler.from_pretrained( + model_id, + subfolder="scheduler", + ) + schedulers["LCMScheduler"] = LCMScheduler.from_pretrained( + model_id, + subfolder="scheduler", + ) + schedulers["DPMSolverMultistep"] = DPMSolverMultistepScheduler.from_pretrained( + model_id, subfolder="scheduler", algorithm_type="dpmsolver" + ) + schedulers["DPMSolverMultistep++"] = DPMSolverMultistepScheduler.from_pretrained( + model_id, subfolder="scheduler", algorithm_type="dpmsolver++" + ) + schedulers[ + "DPMSolverMultistepKarras" + ] = DPMSolverMultistepScheduler.from_pretrained( + model_id, + subfolder="scheduler", + use_karras_sigmas=True, + ) + schedulers[ + "DPMSolverMultistepKarras++" + ] = DPMSolverMultistepScheduler.from_pretrained( + model_id, + subfolder="scheduler", + algorithm_type="dpmsolver++", + use_karras_sigmas=True, + ) + schedulers["EulerDiscrete"] = EulerDiscreteScheduler.from_pretrained( + model_id, + subfolder="scheduler", + ) + schedulers[ + "EulerAncestralDiscrete" + ] = EulerAncestralDiscreteScheduler.from_pretrained( + model_id, + subfolder="scheduler", + ) + schedulers["DEISMultistep"] = DEISMultistepScheduler.from_pretrained( + model_id, + subfolder="scheduler", + ) + schedulers["DPMSolverSinglestep"] = DPMSolverSinglestepScheduler.from_pretrained( + model_id, + subfolder="scheduler", + ) + schedulers[ + "KDPM2AncestralDiscrete" + ] = KDPM2AncestralDiscreteScheduler.from_pretrained( + model_id, + subfolder="scheduler", + ) + schedulers["HeunDiscrete"] = HeunDiscreteScheduler.from_pretrained( + model_id, + subfolder="scheduler", + ) + return schedulers def export_scheduler_model(model): @@ -7,24 +102,16 @@ def export_scheduler_model(model): scheduler_model_map = { "EulerDiscrete": export_scheduler_model("EulerDiscreteScheduler"), - "EulerAncestralDiscrete": export_scheduler_model( - "EulerAncestralDiscreteScheduler" - ), + "EulerAncestralDiscrete": export_scheduler_model("EulerAncestralDiscreteScheduler"), "LCM": export_scheduler_model("LCMScheduler"), "LMSDiscrete": export_scheduler_model("LMSDiscreteScheduler"), "PNDM": export_scheduler_model("PNDMScheduler"), "DDPM": export_scheduler_model("DDPMScheduler"), "DDIM": export_scheduler_model("DDIMScheduler"), - "DPMSolverMultistep": export_scheduler_model( - "DPMSolverMultistepScheduler" - ), + "DPMSolverMultistep": export_scheduler_model("DPMSolverMultistepScheduler"), "KDPM2Discrete": export_scheduler_model("KDPM2DiscreteScheduler"), "DEISMultistep": export_scheduler_model("DEISMultistepScheduler"), - "DPMSolverSinglestep": export_scheduler_model( - "DPMSolverSingleStepScheduler" - ), - "KDPM2AncestralDiscrete": export_scheduler_model( - "KDPM2AncestralDiscreteScheduler" - ), + "DPMSolverSinglestep": export_scheduler_model("DPMSolverSingleStepScheduler"), + "KDPM2AncestralDiscrete": export_scheduler_model("KDPM2AncestralDiscreteScheduler"), "HeunDiscrete": export_scheduler_model("HeunDiscreteScheduler"), } diff --git a/apps/shark_studio/modules/seed.py b/apps/shark_studio/modules/seed.py new file mode 100644 index 0000000000..d0b022a6f1 --- /dev/null +++ b/apps/shark_studio/modules/seed.py @@ -0,0 +1,66 @@ +import numpy as np +import json +from random import ( + randint, + seed as seed_random, + getstate as random_getstate, + setstate as random_setstate, +) + + +# Generate and return a new seed if the provided one is not in the +# supported range (including -1) +def sanitize_seed(seed: int | str): + seed = int(seed) + uint32_info = np.iinfo(np.uint32) + uint32_min, uint32_max = uint32_info.min, uint32_info.max + if seed < uint32_min or seed >= uint32_max: + seed = randint(uint32_min, uint32_max) + return seed + + +# take a seed expression in an input format and convert it to +# a list of integers, where possible +def parse_seed_input(seed_input: str | list | int): + if isinstance(seed_input, str): + try: + seed_input = json.loads(seed_input) + except (ValueError, TypeError): + seed_input = None + + if isinstance(seed_input, int): + return [seed_input] + + if isinstance(seed_input, list) and all(type(seed) is int for seed in seed_input): + return seed_input + + raise TypeError( + "Seed input must be an integer or an array of integers in JSON format" + ) + + +# Generate a set of seeds from an input expression for batch_count batches, +# optionally using that input as the rng seed for any randomly generated seeds. +def batch_seeds(seed_input: str | list | int, batch_count: int, repeatable=False): + # turn the input into a list if possible + seeds = parse_seed_input(seed_input) + + # slice or pad the list to be of batch_count length + seeds = seeds[:batch_count] + [-1] * (batch_count - len(seeds)) + + if repeatable: + if all(seed < 0 for seed in seeds): + seeds[0] = sanitize_seed(seeds[0]) + + # set seed for the rng based on what we have so far + saved_random_state = random_getstate() + seed_random(str([n for n in seeds if n > -1])) + + # generate any seeds that are unspecified + seeds = [sanitize_seed(seed) for seed in seeds] + + if repeatable: + # reset the rng back to normal + random_setstate(saved_random_state) + + return seeds diff --git a/apps/shark_studio/modules/shared.py b/apps/shark_studio/modules/shared.py deleted file mode 100644 index d9dc3ea26e..0000000000 --- a/apps/shark_studio/modules/shared.py +++ /dev/null @@ -1,69 +0,0 @@ -import sys - -import gradio as gr - -from modules import ( - shared_cmd_options, - shared_gradio, - options, - shared_items, - sd_models_types, -) -from modules.paths_internal import ( - models_path, - script_path, - data_path, - sd_configs_path, - sd_default_config, - sd_model_file, - default_sd_model_file, - extensions_dir, - extensions_builtin_dir, -) # noqa: F401 -from modules import util - -cmd_opts = shared_cmd_options.cmd_opts -parser = shared_cmd_options.parser - -parallel_processing_allowed = True -styles_filename = cmd_opts.styles_file -config_filename = cmd_opts.ui_settings_file - -demo = None - -device = None - -weight_load_location = None - -state = None - -prompt_styles = None - -options_templates = None -opts = None -restricted_opts = None - -sd_model: sd_models_types.WebuiSdModel = None - -settings_components = None -"""assinged from ui.py, a mapping on setting names to gradio components repsponsible for those settings""" - -tab_names = [] - -sd_upscalers = [] - -clip_model = None - -progress_print_out = sys.stdout - -gradio_theme = gr.themes.Base() - -total_tqdm = None - -mem_mon = None - -reload_gradio_theme = shared_gradio.reload_gradio_theme - -list_checkpoint_tiles = shared_items.list_checkpoint_tiles -refresh_checkpoints = shared_items.refresh_checkpoints -list_samplers = shared_items.list_samplers diff --git a/apps/shark_studio/modules/shared_cmd_opts.py b/apps/shark_studio/modules/shared_cmd_opts.py index dfb166a52e..93a09c6758 100644 --- a/apps/shark_studio/modules/shared_cmd_opts.py +++ b/apps/shark_studio/modules/shared_cmd_opts.py @@ -32,7 +32,7 @@ def is_valid_file(arg): ) p.add_argument( "-p", - "--prompts", + "--prompt", nargs="+", default=[ "a photo taken of the front of a super-car drifting on a road near " @@ -44,7 +44,7 @@ def is_valid_file(arg): ) p.add_argument( - "--negative_prompts", + "--negative_prompt", nargs="+", default=[ "watermark, signature, logo, text, lowres, ((monochrome, grayscale)), " @@ -54,7 +54,7 @@ def is_valid_file(arg): ) p.add_argument( - "--img_path", + "--sd_init_image", type=str, help="Path to the image input for img2img/inpainting.", ) @@ -130,8 +130,7 @@ def is_valid_file(arg): "--strength", type=float, default=0.8, - help="The strength of change applied on the given input image for " - "img2img.", + help="The strength of change applied on the given input image for " "img2img.", ) p.add_argument( @@ -290,9 +289,7 @@ def is_valid_file(arg): # Model Config and Usage Params ############################################################################## -p.add_argument( - "--device", type=str, default="vulkan", help="Device to run the model." -) +p.add_argument("--device", type=str, default="vulkan", help="Device to run the model.") p.add_argument( "--precision", type=str, default="fp16", help="Precision to run the model." @@ -323,7 +320,7 @@ def is_valid_file(arg): p.add_argument( "--scheduler", type=str, - default="SharkEulerDiscrete", + default="DDIM", help="Other supported schedulers are [DDIM, PNDM, LMSDiscrete, " "DPMSolverMultistep, DPMSolverMultistep++, DPMSolverMultistepKarras, " "DPMSolverMultistepKarras++, EulerDiscrete, EulerAncestralDiscrete, " @@ -350,8 +347,7 @@ def is_valid_file(arg): "--batch_count", type=int, default=1, - help="Number of batches to be generated with random seeds in " - "single execution.", + help="Number of batches to be generated with random seeds in " "single execution.", ) p.add_argument( @@ -363,10 +359,10 @@ def is_valid_file(arg): ) p.add_argument( - "--ckpt_loc", + "--custom_weights", type=str, default="", - help="Path to SD's .ckpt file.", + help="Path to a .safetensors or .ckpt file for SD pipeline weights.", ) p.add_argument( @@ -378,7 +374,7 @@ def is_valid_file(arg): ) p.add_argument( - "--hf_model_id", + "--base_model_id", type=str, default="stabilityai/stable-diffusion-2-1-base", help="The repo-id of hugging face.", @@ -416,8 +412,7 @@ def is_valid_file(arg): "--use_lora", type=str, default="", - help="Use standalone LoRA weight using a HF ID or a checkpoint " - "file (~3 MB).", + help="Use standalone LoRA weight using a HF ID or a checkpoint " "file (~3 MB).", ) p.add_argument( @@ -453,12 +448,6 @@ def is_valid_file(arg): "Example: --device_allocator_heap_key='*;1gib' (will limit caching on device to 1 gigabyte)", ) -p.add_argument( - "--custom_model_map", - type=str, - default="", - help="path to custom model map to import. This should be a .json file", -) ############################################################################## # IREE - Vulkan supported flags ############################################################################## @@ -499,8 +488,7 @@ def is_valid_file(arg): "--dump_isa", default=False, action="store_true", - help="When enabled call amdllpc to get ISA dumps. " - "Use with dispatch benchmarks.", + help="When enabled call amdllpc to get ISA dumps. " "Use with dispatch benchmarks.", ) p.add_argument( @@ -521,8 +509,7 @@ def is_valid_file(arg): "--enable_rgp", default=False, action=argparse.BooleanOptionalAction, - help="Flag for inserting debug frames between iterations " - "for use with rgp.", + help="Flag for inserting debug frames between iterations " "for use with rgp.", ) p.add_argument( @@ -608,8 +595,7 @@ def is_valid_file(arg): "--progress_bar", default=True, action=argparse.BooleanOptionalAction, - help="Flag for removing the progress bar animation during " - "image generation.", + help="Flag for removing the progress bar animation during " "image generation.", ) p.add_argument( @@ -675,6 +661,13 @@ def is_valid_file(arg): "images under --output_dir in the UI.", ) +p.add_argument( + "--configs_path", + default=None, + type=str, + help="Path to .json config directory.", +) + p.add_argument( "--output_gallery_followlinks", default=False, diff --git a/apps/shark_studio/modules/timer.py b/apps/shark_studio/modules/timer.py index 8fd1e6a7df..d6918e9c8c 100644 --- a/apps/shark_studio/modules/timer.py +++ b/apps/shark_studio/modules/timer.py @@ -11,9 +11,7 @@ def __init__(self, timer, category): def __enter__(self): self.start = time.time() - self.timer.base_category = ( - self.original_base_category + self.category + "/" - ) + self.timer.base_category = self.original_base_category + self.category + "/" self.timer.subcategory_level += 1 if self.timer.print_log: @@ -82,10 +80,7 @@ def summary(self): res += " (" res += ", ".join( - [ - f"{category}: {time_taken:.1f}s" - for category, time_taken in additions - ] + [f"{category}: {time_taken:.1f}s" for category, time_taken in additions] ) res += ")" diff --git a/apps/shark_studio/tests/jupiter.png b/apps/shark_studio/tests/jupiter.png new file mode 100644 index 0000000000..e479e20548 Binary files /dev/null and b/apps/shark_studio/tests/jupiter.png differ diff --git a/apps/shark_studio/web/api/compat.py b/apps/shark_studio/web/api/compat.py index 80399505c4..3f92c41d02 100644 --- a/apps/shark_studio/web/api/compat.py +++ b/apps/shark_studio/web/api/compat.py @@ -30,17 +30,13 @@ def decode_base64_to_image(encoding): status_code=500, detail="Request to local resource not allowed" ) - headers = ( - {"user-agent": opts.api_useragent} if opts.api_useragent else {} - ) + headers = {"user-agent": opts.api_useragent} if opts.api_useragent else {} response = requests.get(encoding, timeout=30, headers=headers) try: image = Image.open(BytesIO(response.content)) return image except Exception as e: - raise HTTPException( - status_code=500, detail="Invalid image url" - ) from e + raise HTTPException(status_code=500, detail="Invalid image url") from e if encoding.startswith("data:image/"): encoding = encoding.split(";")[1].split(",")[1] @@ -48,9 +44,7 @@ def decode_base64_to_image(encoding): image = Image.open(BytesIO(base64.b64decode(encoding))) return image except Exception as e: - raise HTTPException( - status_code=500, detail="Invalid encoded image" - ) from e + raise HTTPException(status_code=500, detail="Invalid encoded image") from e def encode_pil_to_base64(image): diff --git a/apps/shark_studio/web/configs/default_sd_config.json b/apps/shark_studio/web/configs/default_sd_config.json new file mode 100644 index 0000000000..7a98a441df --- /dev/null +++ b/apps/shark_studio/web/configs/default_sd_config.json @@ -0,0 +1 @@ +{"prompt": ["a photo taken of the front of a super-car drifting on a road near mountains at high speeds with smoke coming off the tires, front angle, front point of view, trees in the mountains of the background, ((sharp focus))"], "negative_prompt": ["watermark, signature, logo, text, lowres, ((monochrome, grayscale)), blurry, ugly, blur, oversaturated, cropped"], "sd_init_image": [null], "height": 512, "width": 512, "steps": 50, "strength": 0.8, "guidance_scale": 7.5, "seed": "-1", "batch_count": 1, "batch_size": 1, "scheduler": "EulerDiscrete", "base_model_id": "stabilityai/stable-diffusion-2-1-base", "custom_weights": "None", "custom_vae": "None", "precision": "fp16", "device": "AMD Radeon RX 7900 XTX => vulkan://0", "ondemand": false, "repeatable_seeds": false, "resample_type": "Nearest Neighbor", "controlnets": {}, "embeddings": {}} \ No newline at end of file diff --git a/apps/shark_studio/web/configs/foo.json b/apps/shark_studio/web/configs/foo.json deleted file mode 100644 index 0967ef424b..0000000000 --- a/apps/shark_studio/web/configs/foo.json +++ /dev/null @@ -1 +0,0 @@ -{} diff --git a/apps/shark_studio/web/index.py b/apps/shark_studio/web/index.py index 58b0c6c00b..05a9bc363d 100644 --- a/apps/shark_studio/web/index.py +++ b/apps/shark_studio/web/index.py @@ -5,9 +5,6 @@ import logging import apps.shark_studio.api.initializers as initialize -from ui.chat import chat_element -from ui.sd import sd_element -from ui.outputgallery import outputgallery_element from apps.shark_studio.modules import timer @@ -75,11 +72,13 @@ def launch_webui(address): def webui(): from apps.shark_studio.modules.shared_cmd_opts import cmd_opts - logging.basicConfig(level=logging.DEBUG) - launch_api = cmd_opts.api initialize.initialize() + from ui.chat import chat_element + from ui.sd import sd_element + from ui.outputgallery import outputgallery_element + # required to do multiprocessing in a pyinstaller freeze freeze_support() @@ -127,27 +126,8 @@ def webui(): # # uvicorn.run(api, host="0.0.0.0", port=args.server_port) # sys.exit(0) - # Setup to use shark_tmp for gradio's temporary image files and clear any - # existing temporary images there if they exist. Then we can import gradio. - # It has to be in this order or gradio ignores what we've set up. - from apps.shark_studio.web.utils.tmp_configs import ( - config_tmp, - clear_tmp_mlir, - clear_tmp_imgs, - ) - from apps.shark_studio.api.utils import ( - create_checkpoint_folders, - ) - import gradio as gr - config_tmp() - clear_tmp_mlir() - clear_tmp_imgs() - - # Create custom models folders if they don't exist - create_checkpoint_folders() - def resource_path(relative_path): """Get absolute path to resource, works for dev and for PyInstaller""" base_path = getattr(sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))) @@ -198,6 +178,7 @@ def register_outputgallery_button(button, selectedid, inputs, outputs): chat_element.render() studio_web.queue() + # if args.ui == "app": # t = Process( # target=launch_app, args=[f"http://localhost:{args.server_port}"] diff --git a/apps/shark_studio/web/ui/chat.py b/apps/shark_studio/web/ui/chat.py index 3a374eb5e2..917ac870bf 100644 --- a/apps/shark_studio/web/ui/chat.py +++ b/apps/shark_studio/web/ui/chat.py @@ -5,13 +5,11 @@ from datetime import datetime as dt import json import sys -from apps.shark_studio.api.utils import ( - get_available_devices, -) from apps.shark_studio.api.llm import ( llm_model_map, LanguageModel, ) +import apps.shark_studio.web.utils.globals as global_obj def user(message, history): @@ -186,7 +184,7 @@ def view_json_file(file_obj): choices=model_choices, allow_custom_value=True, ) - supported_devices = get_available_devices() + supported_devices = global_obj.get_device_list() enabled = True if len(supported_devices) == 0: supported_devices = ["cpu-task"] @@ -240,9 +238,7 @@ def view_json_file(file_obj): with gr.Row(visible=False): with gr.Group(): - config_file = gr.File( - label="Upload sharding configuration", visible=False - ) + config_file = gr.File(label="Upload sharding configuration", visible=False) json_view_button = gr.Button("View as JSON", visible=False) json_view = gr.JSON(visible=False) json_view_button.click( diff --git a/apps/shark_studio/web/ui/common_events.py b/apps/shark_studio/web/ui/common_events.py index 37555ed7ee..7dda8ba268 100644 --- a/apps/shark_studio/web/ui/common_events.py +++ b/apps/shark_studio/web/ui/common_events.py @@ -7,49 +7,61 @@ # Answers HTML to show the most frequent tags used when a LoRA was trained, # taken from the metadata of its .safetensors file. -def lora_changed(lora_file): +def lora_changed(lora_files): # tag frequency percentage, that gets maximum amount of the staring hue TAG_COLOR_THRESHOLD = 0.55 # tag frequency percentage, above which a tag is displayed TAG_DISPLAY_THRESHOLD = 0.65 # template for the html used to display a tag - TAG_HTML_TEMPLATE = '{tag}' - - if lora_file == "None": - return ["
No LoRA selected
"] - elif not lora_file.lower().endswith(".safetensors"): - return [ - "
Only metadata queries for .safetensors files are currently supported
" - ] - else: - metadata = get_lora_metadata(lora_file) - if metadata: - frequencies = metadata["frequencies"] - return [ - "".join( + TAG_HTML_TEMPLATE = ( + '{tag}' + ) + output = [] + for lora_file in lora_files: + if lora_file == "": + output.extend(["
No LoRA selected
"]) + elif not lora_file.lower().endswith(".safetensors"): + output.extend( + [ + "
Only metadata queries for .safetensors files are currently supported
" + ] + ) + else: + metadata = get_lora_metadata(lora_file) + if metadata: + frequencies = metadata["frequencies"] + output.extend( [ - f'
Trained against weights in: {metadata["model"]}
' - ] - + [ - TAG_HTML_TEMPLATE.format( - color=hsl_color( - (tag[1] - TAG_COLOR_THRESHOLD) - / (1 - TAG_COLOR_THRESHOLD), - start=HSLHue.RED, - end=HSLHue.GREEN, - ), - tag=tag[0], + "".join( + [ + f'
Trained against weights in: {metadata["model"]}
' + ] + + [ + TAG_HTML_TEMPLATE.format( + color=hsl_color( + (tag[1] - TAG_COLOR_THRESHOLD) + / (1 - TAG_COLOR_THRESHOLD), + start=HSLHue.RED, + end=HSLHue.GREEN, + ), + tag=tag[0], + ) + for tag in frequencies + if tag[1] > TAG_DISPLAY_THRESHOLD + ], ) - for tag in frequencies - if tag[1] > TAG_DISPLAY_THRESHOLD - ], + ] ) - ] - elif metadata is None: - return [ - "
This LoRA does not publish tag frequency metadata
" - ] - else: - return [ - "
This LoRA has empty tag frequency metadata, or we could not parse it
" - ] + elif metadata is None: + output.extend( + [ + "
This LoRA does not publish tag frequency metadata
" + ] + ) + else: + output.extend( + [ + "
This LoRA has empty tag frequency metadata, or we could not parse it
" + ] + ) + return output diff --git a/apps/shark_studio/web/ui/outputgallery.py b/apps/shark_studio/web/ui/outputgallery.py index dd58541aae..a3de6f7b57 100644 --- a/apps/shark_studio/web/ui/outputgallery.py +++ b/apps/shark_studio/web/ui/outputgallery.py @@ -6,7 +6,7 @@ from PIL import Image from apps.shark_studio.modules.shared_cmd_opts import cmd_opts -from apps.shark_studio.api.utils import ( +from apps.shark_studio.web.utils.file_utils import ( get_generated_imgs_path, get_generated_imgs_todays_subdir, ) @@ -22,8 +22,7 @@ def outputgallery_filenames(subdir) -> list[str]: new_dir_path = os.path.join(output_dir, subdir) if os.path.exists(new_dir_path): filenames = [ - glob.glob(new_dir_path + "/" + ext) - for ext in ("*.png", "*.jpg", "*.jpeg") + glob.glob(new_dir_path + "/" + ext) for ext in ("*.png", "*.jpg", "*.jpeg") ] return sorted(sum(filenames, []), key=os.path.getmtime, reverse=True) @@ -52,11 +51,7 @@ def output_subdirs() -> list[str]: [path for path in relative_paths if path.isnumeric()], reverse=True ) result_paths = generated_paths + sorted( - [ - path - for path in relative_paths - if (not path.isnumeric()) and path != "." - ] + [path for path in relative_paths if (not path.isnumeric()) and path != "."] ) return result_paths @@ -184,9 +179,7 @@ def on_image_columns_change(columns): def on_select_subdir(subdir) -> list: # evt.value is the subdirectory name new_images = outputgallery_filenames(subdir) - new_label = ( - f"{len(new_images)} images in {os.path.join(output_dir, subdir)}" - ) + new_label = f"{len(new_images)} images in {os.path.join(output_dir, subdir)}" return [ new_images, gr.Gallery( @@ -223,8 +216,7 @@ def on_refresh(current_subdir: str) -> list: ) new_images = outputgallery_filenames(new_subdir) new_label = ( - f"{len(new_images)} images in " - f"{os.path.join(output_dir, new_subdir)}" + f"{len(new_images)} images in " f"{os.path.join(output_dir, new_subdir)}" ) return [ @@ -234,9 +226,7 @@ def on_refresh(current_subdir: str) -> list: ), refreshed_subdirs, new_images, - gr.Gallery( - value=new_images, label=new_label, visible=len(new_images) > 0 - ), + gr.Gallery(value=new_images, label=new_label, visible=len(new_images) > 0), gr.Image( label=new_label, visible=len(new_images) == 0, diff --git a/apps/shark_studio/web/ui/sd.py b/apps/shark_studio/web/ui/sd.py index f26c7967e3..6cc0ce035f 100644 --- a/apps/shark_studio/web/ui/sd.py +++ b/apps/shark_studio/web/ui/sd.py @@ -1,35 +1,26 @@ import os -import time -import gradio as gr -import PIL import json -import sys - -from math import ceil +import gradio as gr +import numpy as np from inspect import signature from PIL import Image from pathlib import Path from datetime import datetime as dt from gradio.components.image_editor import ( - Brush, - Eraser, EditorValue, ) - -from apps.shark_studio.api.utils import ( - get_available_devices, +from apps.shark_studio.web.utils.file_utils import ( get_generated_imgs_path, get_checkpoints_path, get_checkpoints, + get_configs_path, ) from apps.shark_studio.api.sd import ( sd_model_map, - shark_sd_fn, + shark_sd_fn_dict_input, cancel_sd, ) from apps.shark_studio.api.controlnet import ( - preprocessor_model_map, - PreprocessorModel, cnet_preview, ) from apps.shark_studio.modules.schedulers import ( @@ -44,46 +35,167 @@ nodlogo_loc, ) from apps.shark_studio.web.utils.state import ( - get_generation_text_info, status_label, ) from apps.shark_studio.web.ui.common_events import lora_changed +from apps.shark_studio.modules import logger +import apps.shark_studio.web.utils.globals as global_obj + +sd_default_models = [ + "CompVis/stable-diffusion-v1-4", + "runwayml/stable-diffusion-v1-5", + "stabilityai/stable-diffusion-2-1-base", + "stabilityai/stable-diffusion-2-1", + "stabilityai/stable-diffusion-xl-1.0", + "stabilityai/sdxl-turbo", +] -def view_json_file(file_obj): +def view_json_file(file_path): content = "" - with open(file_obj.name, "r") as fopen: + with open(file_path, "r") as fopen: content = fopen.read() return content -max_controlnets = 3 -max_loras = 5 +def submit_to_cnet_config( + stencil: str, + preprocessed_hint: str, + cnet_strength: int, + control_mode: str, + curr_config: dict, +): + if any(i in [None, ""] for i in [stencil, preprocessed_hint]): + return gr.update() + if curr_config is not None: + if "controlnets" in curr_config: + curr_config["controlnets"]["control_mode"] = control_mode + curr_config["controlnets"]["model"].append(stencil) + curr_config["controlnets"]["hint"].append(preprocessed_hint) + curr_config["controlnets"]["strength"].append(cnet_strength) + return curr_config + + cnet_map = {} + cnet_map["controlnets"] = { + "control_mode": control_mode, + "model": [stencil], + "hint": [preprocessed_hint], + "strength": [cnet_strength], + } + return cnet_map -def show_loras(k): - k = int(k) - return gr.State( - [gr.Dropdown(visible=True)] * k - + [gr.Dropdown(visible=False, value="None")] * (max_loras - k) - ) +def update_embeddings_json(embedding): + return {"embeddings": [embedding]} + + +def submit_to_main_config(input_cfg: dict, main_cfg: dict): + if main_cfg in [None, "", {}]: + return input_cfg + + for base_key in input_cfg: + main_cfg[base_key] = input_cfg[base_key] + return main_cfg + + +def pull_sd_configs( + prompt, + negative_prompt, + sd_init_image, + height, + width, + steps, + strength, + guidance_scale, + seed, + batch_count, + batch_size, + scheduler, + base_model_id, + custom_weights, + custom_vae, + precision, + device, + ondemand, + repeatable_seeds, + resample_type, + controlnets, + embeddings, +): + sd_args = locals() + sd_cfg = {} + for arg in sd_args: + if arg in [ + "prompt", + "negative_prompt", + "sd_init_image", + ]: + sd_cfg[arg] = [sd_args[arg]] + elif arg in ["controlnets", "embeddings"]: + if isinstance(arg, dict): + sd_cfg[arg] = json.loads(sd_args[arg]) + else: + sd_cfg[arg] = {} + else: + sd_cfg[arg] = sd_args[arg] + return sd_cfg -def show_controlnets(k): - k = int(k) +def load_sd_cfg(sd_json: dict, load_sd_config: str): + new_sd_config = json.loads(view_json_file(load_sd_config)) + if sd_json: + for key in new_sd_config: + sd_json[key] = new_sd_config[key] + else: + sd_json = new_sd_config + for i in sd_json["sd_init_image"]: + if i is not None: + if os.path.isfile(i): + sd_image = [Image.open(i, mode="r")] + else: + sd_image = None + return [ - gr.State( - [ - [gr.Row(visible=True, render=True)] * k - + [gr.Row(visible=False)] * (max_controlnets - k) - ] - ), - gr.State([None] * k), - gr.State([None] * k), - gr.State([None] * k), + sd_json["prompt"][0], + sd_json["negative_prompt"][0], + sd_image, + sd_json["height"], + sd_json["width"], + sd_json["steps"], + sd_json["strength"], + sd_json["guidance_scale"], + sd_json["seed"], + sd_json["batch_count"], + sd_json["batch_size"], + sd_json["scheduler"], + sd_json["base_model_id"], + sd_json["custom_weights"], + sd_json["custom_vae"], + sd_json["precision"], + sd_json["device"], + sd_json["ondemand"], + sd_json["repeatable_seeds"], + sd_json["resample_type"], + sd_json["controlnets"], + sd_json["embeddings"], + sd_json, ] +def save_sd_cfg(config: dict, save_name: str): + if os.path.exists(save_name): + filepath = save_name + elif cmd_opts.configs_path: + filepath = os.path.join(cmd_opts.configs_path, save_name) + else: + filepath = os.path.join(get_configs_path(), save_name) + if ".json" not in filepath: + filepath += ".json" + with open(filepath, mode="w") as f: + f.write(json.dumps(config)) + return "..." + + def create_canvas(width, height): data = Image.fromarray( np.zeros( @@ -94,110 +206,27 @@ def create_canvas(width, height): ) img_dict = { "background": data, - "layers": [data], + "layers": [], "composite": None, } return EditorValue(img_dict) def import_original(original_img, width, height): - resized_img, _, _ = resize_stencil(original_img, width, height) - img_dict = { - "background": resized_img, - "layers": [resized_img], - "composite": None, - } - return gr.ImageEditor( - value=EditorValue(img_dict), - crop_size=(width, height), - ) - - -def update_cn_input( - model, - width, - height, - stencils, - images, - preprocessed_hints, -): - if model == None: - stencils[index] = None - images[index] = None - preprocessed_hints[index] = None - return [ - gr.update(), - gr.update(), - gr.update(), - gr.update(), - gr.update(), - gr.update(), - stencils, - images, - preprocessed_hints, - ] - elif model == "scribble": - return [ - gr.ImageEditor( - visible=True, - interactive=True, - show_label=False, - image_mode="RGB", - type="pil", - brush=Brush( - colors=["#000000"], - color_mode="fixed", - default_size=5, - ), - ), - gr.Image( - visible=True, - show_label=False, - interactive=True, - show_download_button=False, - ), - gr.Slider(visible=True, label="Canvas Width"), - gr.Slider(visible=True, label="Canvas Height"), - gr.Button(visible=True), - gr.Button(visible=False), - stencils, - images, - preprocessed_hints, - ] + if original_img is None: + resized_img = create_canvas(width, height) + return resized_img else: - return [ - gr.ImageEditor( - visible=True, - interactive=True, - show_label=False, - image_mode="RGB", - type="pil", - ), - gr.Image( - visible=True, - show_label=False, - interactive=True, - show_download_button=False, - ), - gr.Slider(visible=True, label="Canvas Width"), - gr.Slider(visible=True, label="Canvas Height"), - gr.Button(visible=True), - gr.Button(visible=False), - stencils, - images, - preprocessed_hints, - ] + resized_img, _, _ = resize_stencil(original_img, width, height) + img_dict = { + "background": resized_img, + "layers": [], + "composite": None, + } + return EditorValue(img_dict) -sd_fn_inputs = [] -sd_fn_sig = signature(shark_sd_fn).replace() -for i in sd_fn_sig.parameters: - sd_fn_inputs.append(i) - with gr.Blocks(title="Stable Diffusion") as sd_element: - # Get a list of arguments needed for the API call, then - # initialize an empty list that will manage the corresponding - # gradio values. with gr.Row(elem_id="ui_title"): nod_logo = Image.open(nodlogo_loc) with gr.Row(variant="compact", equal_height=True): @@ -216,33 +245,33 @@ def update_cn_input( ) with gr.Column(elem_id="ui_body"): with gr.Row(): - with gr.Column(scale=1, min_width=600): + with gr.Column(scale=2, min_width=600): with gr.Row(equal_height=True): with gr.Column(scale=3): sd_model_info = ( f"Checkpoint Path: {str(get_checkpoints_path())}" ) - sd_base = gr.Dropdown( + base_model_id = gr.Dropdown( label="Base Model", info="Select or enter HF model ID", elem_id="custom_model", value="stabilityai/stable-diffusion-2-1-base", - choices=sd_model_map.keys(), + choices=sd_default_models, ) # base_model_id - sd_custom_weights = gr.Dropdown( - label="Weights (Optional)", + custom_weights = gr.Dropdown( + label="Custom Weights", info="Select or enter HF model ID", elem_id="custom_model", value="None", allow_custom_value=True, - choices=get_checkpoints(sd_base), + choices=["None"] + get_checkpoints(base_model_id), ) # with gr.Column(scale=2): - sd_vae_info = ( - str(get_checkpoints_path("vae")) - ).replace("\\", "\n\\") + sd_vae_info = (str(get_checkpoints_path("vae"))).replace( + "\\", "\n\\" + ) sd_vae_info = f"VAE Path: {sd_vae_info}" - sd_custom_vae = gr.Dropdown( + custom_vae = gr.Dropdown( label=f"Custom VAE Models", info=sd_vae_info, elem_id="custom_model", @@ -253,28 +282,31 @@ def update_cn_input( allow_custom_value=True, scale=1, ) - with gr.Column(scale=1): - save_sd_config = gr.Button( - value="Save Config", size="sm" - ) - clear_sd_config = gr.ClearButton( - value="Clear Config", size="sm" - ) - load_sd_config = gr.FileExplorer( - label="Load Config", - root=os.path.basename("./configs"), - ) - + with gr.Row(): + ondemand = gr.Checkbox( + value=cmd_opts.lowvram, + label="Low VRAM", + interactive=True, + ) + precision = gr.Radio( + label="Precision", + value=cmd_opts.precision, + choices=[ + "fp16", + "fp32", + ], + visible=True, + ) with gr.Group(elem_id="prompt_box_outer"): prompt = gr.Textbox( label="Prompt", - value=cmd_opts.prompts[0], + value=cmd_opts.prompt[0], lines=2, elem_id="prompt_box", ) negative_prompt = gr.Textbox( label="Negative Prompt", - value=cmd_opts.negative_prompts[0], + value=cmd_opts.negative_prompt[0], lines=2, elem_id="negative_prompt_box", ) @@ -287,41 +319,39 @@ def update_cn_input( height=300, interactive=True, ) - with gr.Accordion( - label="Embeddings options", open=False, render=True - ): - sd_lora_info = ( - str(get_checkpoints_path("loras")) - ).replace("\\", "\n\\") - num_loras = gr.Slider( - 1, max_loras, value=1, step=1, label="LoRA Count" + with gr.Accordion(label="Embeddings options", open=True, render=True): + sd_lora_info = (str(get_checkpoints_path("loras"))).replace( + "\\", "\n\\" ) - loras = gr.State([]) - for i in range(max_loras): - with gr.Row(): - lora_opt = gr.Dropdown( - allow_custom_value=True, - label=f"Standalone LoRA Weights", - info=sd_lora_info, - elem_id="lora_weights", - value="None", - choices=["None"] + get_checkpoints("lora"), - ) - with gr.Row(): - lora_tags = gr.HTML( - value="
No LoRA selected
", - elem_classes="lora-tags", - ) - gr.on( - triggers=[lora_opt.change], - fn=lora_changed, - inputs=[lora_opt], - outputs=[lora_tags], - queue=True, + with gr.Row(): + embeddings_config = gr.JSON(min_width=50, scale=1) + lora_opt = gr.Dropdown( + allow_custom_value=True, + label=f"Standalone LoRA Weights", + info=sd_lora_info, + elem_id="lora_weights", + value=None, + multiselect=True, + choices=[] + get_checkpoints("lora"), + scale=2, ) - loras.value.append(lora_opt) - - num_loras.change(show_loras, [num_loras], [loras]) + lora_tags = gr.HTML( + value="
No LoRA selected
", + elem_classes="lora-tags", + ) + gr.on( + triggers=[lora_opt.change], + fn=lora_changed, + inputs=[lora_opt], + outputs=[lora_tags], + queue=True, + show_progress=False, + ).then( + fn=update_embeddings_json, + inputs=[lora_opt], + outputs=[embeddings_config], + show_progress=False, + ) with gr.Accordion(label="Advanced Options", open=True): with gr.Row(): scheduler = gr.Dropdown( @@ -331,7 +361,6 @@ def update_cn_input( choices=scheduler_model_map.keys(), allow_custom_value=False, ) - with gr.Row(): height = gr.Slider( 384, 768, @@ -397,20 +426,6 @@ def update_cn_input( step=0.1, label="CFG Scale", ) - ondemand = gr.Checkbox( - value=cmd_opts.lowvram, - label="Low VRAM", - interactive=True, - ) - precision = gr.Radio( - label="Precision", - value=cmd_opts.precision, - choices=[ - "fp16", - "fp32", - ], - visible=True, - ) with gr.Row(): seed = gr.Textbox( value=cmd_opts.seed, @@ -420,159 +435,149 @@ def update_cn_input( device = gr.Dropdown( elem_id="device", label="Device", - value=get_available_devices()[0], - choices=get_available_devices(), + value=global_obj.get_device_list()[0], + choices=global_obj.get_device_list(), allow_custom_value=False, ) with gr.Accordion( - label="Controlnet Options", open=False, render=False + label="Controlnet Options", + open=False, + visible=False, ): - sd_cnet_info = ( - str(get_checkpoints_path("controlnet")) - ).replace("\\", "\n\\") - num_cnets = gr.Slider( - 0, - max_controlnets, - value=0, - step=1, - label="Controlnet Count", - ) - cnet_rows = [] - stencils = gr.State([]) - images = gr.State([]) preprocessed_hints = gr.State([]) - control_mode = gr.Radio( - choices=["Prompt", "Balanced", "Controlnet"], - value="Balanced", - label="Control Mode", - ) - - for i in range(max_controlnets): - with gr.Row(visible=False) as cnet_row: - with gr.Column(): - cnet_gen = gr.Button( - value="Preprocess controlnet input", - ) - cnet_model = gr.Dropdown( - allow_custom_value=True, - label=f"Controlnet Model", - info=sd_cnet_info, - elem_id="lora_weights", - value="None", - choices=[ - "None", - "canny", - "openpose", - "scribble", - "zoedepth", - ] - + get_checkpoints("controlnet"), - ) + with gr.Column(): + sd_cnet_info = ( + str(get_checkpoints_path("controlnet")) + ).replace("\\", "\n\\") + with gr.Row(): + cnet_config = gr.JSON() + with gr.Column(): + clear_config = gr.ClearButton( + value="Clear Controlnet Config", + size="sm", + components=cnet_config, + ) + control_mode = gr.Radio( + choices=["Prompt", "Balanced", "Controlnet"], + value="Balanced", + label="Control Mode", + ) + with gr.Row(): + with gr.Column(scale=1): + cnet_model = gr.Dropdown( + allow_custom_value=True, + label=f"Controlnet Model", + info=sd_cnet_info, + value="None", + choices=[ + "None", + "canny", + "openpose", + "scribble", + "zoedepth", + ] + + get_checkpoints("controlnet"), + ) + cnet_strength = gr.Slider( + label="Controlnet Strength", + minimum=0, + maximum=100, + value=50, + step=1, + ) + with gr.Row(): canvas_width = gr.Slider( label="Canvas Width", minimum=256, maximum=1024, value=512, - step=1, - visible=False, + step=8, ) canvas_height = gr.Slider( label="Canvas Height", minimum=256, maximum=1024, value=512, - step=1, - visible=False, - ) - make_canvas = gr.Button( - value="Make Canvas!", - visible=False, + step=8, ) - use_input_img = gr.Button( - value="Use Original Image", - visible=False, - ) - cnet_input = gr.ImageEditor( - visible=True, - image_mode="RGB", - interactive=True, - show_label=True, - label="Input Image", - type="pil", + make_canvas = gr.Button( + value="Make Canvas!", ) + use_input_img = gr.Button( + value="Use Original Image", + size="sm", + ) + cnet_input = gr.Image( + value=None, + type="pil", + image_mode="RGB", + interactive=True, + ) + with gr.Column(scale=1): cnet_output = gr.Image( value=None, visible=True, label="Preprocessed Hint", - interactive=True, + interactive=False, show_label=True, ) - use_input_img.click( - import_original, - [sd_init_image, canvas_width, canvas_height], - [cnet_input], + cnet_gen = gr.Button( + value="Preprocess controlnet input", ) - cnet_model.change( - fn=update_cn_input, - inputs=[ - cnet_model, - canvas_width, - canvas_height, - stencils, - images, - preprocessed_hints, - ], - outputs=[ - cnet_input, - cnet_output, - canvas_width, - canvas_height, - make_canvas, - use_input_img, - stencils, - images, - preprocessed_hints, - ], - ) - make_canvas.click( - create_canvas, - [canvas_width, canvas_height], - [ - cnet_input, - ], + use_result = gr.Button( + "Submit", + size="sm", ) - gr.on( - triggers=[cnet_gen.click], - fn=cnet_preview, - inputs=[ - cnet_model, - cnet_input, - stencils, - images, - preprocessed_hints, - ], - outputs=[ - cnet_output, - stencils, - images, - preprocessed_hints, - ], - ) - cnet_rows.value.append(cnet_row) - - num_cnets.change( - show_controlnets, - [num_cnets], - [cnet_rows, stencils, images, preprocessed_hints], + use_input_img.click( + fn=import_original, + inputs=[ + sd_init_image, + canvas_width, + canvas_height, + ], + outputs=[cnet_input], + queue=False, + ) + make_canvas.click( + fn=create_canvas, + inputs=[canvas_width, canvas_height], + outputs=[cnet_input], + queue=False, + ) + cnet_gen.click( + fn=cnet_preview, + inputs=[ + cnet_model, + cnet_input, + ], + outputs=[ + cnet_output, + preprocessed_hints, + ], ) - with gr.Column(scale=1, min_width=600): + use_result.click( + fn=submit_to_cnet_config, + inputs=[ + cnet_model, + cnet_output, + cnet_strength, + control_mode, + cnet_config, + ], + outputs=[ + cnet_config, + ], + queue=False, + ) + with gr.Column(scale=3, min_width=600): with gr.Group(): sd_gallery = gr.Gallery( label="Generated images", show_label=False, elem_id="gallery", columns=2, - object_fit="contain", + object_fit="fit", + preview=True, ) std_output = gr.Textbox( value=f"{sd_model_info}\n" @@ -582,6 +587,7 @@ def update_cn_input( elem_id="std_output", show_label=False, ) + sd_element.load(logger.read_sd_logs, None, std_output, every=1) sd_status = gr.Textbox(visible=False) with gr.Row(): stable_diffusion = gr.Button("Generate Image(s)") @@ -591,11 +597,75 @@ def update_cn_input( inputs=[], outputs=[seed], queue=False, + show_progress=False, ) stop_batch = gr.Button("Stop Batch") + with gr.Group(): + with gr.Column(scale=3): + sd_json = gr.JSON( + value=view_json_file( + os.path.join( + get_configs_path(), + "default_sd_config.json", + ) + ) + ) + with gr.Column(scale=1): + clear_sd_config = gr.ClearButton( + value="Clear Config", size="sm", components=sd_json + ) + with gr.Row(): + save_sd_config = gr.Button(value="Save Config", size="sm") + sd_config_name = gr.Textbox( + value="Config Name", + info="Name of the file this config will be saved to.", + interactive=True, + ) + load_sd_config = gr.FileExplorer( + label="Load Config", + file_count="single", + root=cmd_opts.configs_path + if cmd_opts.configs_path + else get_configs_path(), + height=75, + ) + load_sd_config.change( + fn=load_sd_cfg, + inputs=[sd_json, load_sd_config], + outputs=[ + prompt, + negative_prompt, + sd_init_image, + height, + width, + steps, + strength, + guidance_scale, + seed, + batch_count, + batch_size, + scheduler, + base_model_id, + custom_weights, + custom_vae, + precision, + device, + ondemand, + repeatable_seeds, + resample_type, + cnet_config, + embeddings_config, + sd_json, + ], + ) + save_sd_config.click( + fn=save_sd_cfg, + inputs=[sd_json, sd_config_name], + outputs=[sd_config_name], + ) - kwargs = dict( - fn=shark_sd_fn, + pull_kwargs = dict( + fn=pull_sd_configs, inputs=[ prompt, negative_prompt, @@ -609,28 +679,20 @@ def update_cn_input( batch_count, batch_size, scheduler, - sd_base, - sd_custom_weights, - sd_custom_vae, + base_model_id, + custom_weights, + custom_vae, precision, device, - loras, ondemand, repeatable_seeds, resample_type, - control_mode, - stencils, - images, - preprocessed_hints, + cnet_config, + embeddings_config, ], outputs=[ - sd_gallery, - std_output, - sd_status, - stencils, - images, + sd_json, ], - show_progress="minimal", ) status_kwargs = dict( @@ -639,11 +701,22 @@ def update_cn_input( outputs=sd_status, ) - prompt_submit = prompt.submit(**status_kwargs).then(**kwargs) - neg_prompt_submit = negative_prompt.submit(**status_kwargs).then( - **kwargs + gen_kwargs = dict( + fn=shark_sd_fn_dict_input, + inputs=[sd_json], + outputs=[ + sd_gallery, + sd_status, + ], + ) + + prompt_submit = prompt.submit(**status_kwargs).then(**pull_kwargs) + neg_prompt_submit = negative_prompt.submit(**status_kwargs).then(**pull_kwargs) + generate_click = ( + stable_diffusion.click(**status_kwargs) + .then(**pull_kwargs) + .then(**gen_kwargs) ) - generate_click = stable_diffusion.click(**status_kwargs).then(**kwargs) stop_batch.click( fn=cancel_sd, cancels=[prompt_submit, neg_prompt_submit, generate_click], diff --git a/apps/shark_studio/web/ui/utils.py b/apps/shark_studio/web/ui/utils.py index ba62e5adc0..34a94fa014 100644 --- a/apps/shark_studio/web/ui/utils.py +++ b/apps/shark_studio/web/ui/utils.py @@ -6,9 +6,7 @@ def resource_path(relative_path): """Get absolute path to resource, works for dev and for PyInstaller""" - base_path = getattr( - sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__)) - ) + base_path = getattr(sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))) return os.path.join(base_path, relative_path) diff --git a/apps/shark_studio/web/utils/file_utils.py b/apps/shark_studio/web/utils/file_utils.py new file mode 100644 index 0000000000..cae925f5e2 --- /dev/null +++ b/apps/shark_studio/web/utils/file_utils.py @@ -0,0 +1,83 @@ +import os +import sys +import glob +from datetime import datetime as dt +from pathlib import Path + +from apps.shark_studio.modules.shared_cmd_opts import cmd_opts + +checkpoints_filetypes = ( + "*.ckpt", + "*.safetensors", +) + + +def safe_name(name): + return name.replace("/", "_").replace("-", "_") + + +def get_path_stem(path): + path = Path(path) + return path.stem + + +def get_resource_path(relative_path): + """Get absolute path to resource, works for dev and for PyInstaller""" + base_path = getattr(sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))) + result = Path(os.path.join(base_path, relative_path)).resolve(strict=False) + return result + + +def get_configs_path() -> Path: + configs = get_resource_path(os.path.join("..", "configs")) + if not os.path.exists(configs): + os.mkdir(configs) + return Path(get_resource_path("../configs")) + + +def get_generated_imgs_path() -> Path: + return Path( + cmd_opts.output_dir + if cmd_opts.output_dir + else get_resource_path("../generated_imgs") + ) + + +def get_generated_imgs_todays_subdir() -> str: + return dt.now().strftime("%Y%m%d") + + +def create_checkpoint_folders(): + dir = ["vae", "lora", "../vmfb"] + if not cmd_opts.ckpt_dir: + dir.insert(0, "models") + else: + if not os.path.isdir(cmd_opts.ckpt_dir): + sys.exit( + f"Invalid --ckpt_dir argument, " + f"{cmd_opts.ckpt_dir} folder does not exists." + ) + for root in dir: + Path(get_checkpoints_path(root)).mkdir(parents=True, exist_ok=True) + + +def get_checkpoints_path(model=""): + return get_resource_path(f"../models/{model}") + + +def get_checkpoints(model="models"): + ckpt_files = [] + file_types = checkpoints_filetypes + if model == "lora": + file_types = file_types + ("*.pt", "*.bin") + for extn in file_types: + files = [ + os.path.basename(x) + for x in glob.glob(os.path.join(get_checkpoints_path(model), extn)) + ] + ckpt_files.extend(files) + return sorted(ckpt_files, key=str.casefold) + + +def get_checkpoint_pathfile(checkpoint_name, model="models"): + return os.path.join(get_checkpoints_path(model), checkpoint_name) diff --git a/apps/shark_studio/web/utils/globals.py b/apps/shark_studio/web/utils/globals.py index 0b5f54636a..977df7304a 100644 --- a/apps/shark_studio/web/utils/globals.py +++ b/apps/shark_studio/web/utils/globals.py @@ -1,4 +1,5 @@ import gc +from ...api.utils import get_available_devices """ The global objects include SD pipeline and config. @@ -9,11 +10,18 @@ def _init(): global _sd_obj - global _config_obj + global _devices + global _pipe_kwargs + global _prep_kwargs + global _gen_kwargs global _schedulers _sd_obj = None - _config_obj = None + _devices = None + _pipe_kwargs = None + _prep_kwargs = None + _gen_kwargs = None _schedulers = None + set_devices() def set_sd_obj(value): @@ -21,6 +29,11 @@ def set_sd_obj(value): _sd_obj = value +def set_devices(): + global _devices + _devices = get_available_devices() + + def set_sd_scheduler(key): global _sd_obj _sd_obj.scheduler = _schedulers[key] @@ -31,9 +44,19 @@ def set_sd_status(value): _sd_obj.status = value -def set_cfg_obj(value): - global _config_obj - _config_obj = value +def set_pipe_kwargs(value): + global _pipe_kwargs + _pipe_kwargs = value + + +def set_prep_kwargs(value): + global _prep_kwargs + _prep_kwargs = value + + +def set_gen_kwargs(value): + global _gen_kwargs + _gen_kwargs = value def set_schedulers(value): @@ -46,14 +69,29 @@ def get_sd_obj(): return _sd_obj +def get_device_list(): + global _devices + return _devices + + def get_sd_status(): global _sd_obj return _sd_obj.status -def get_cfg_obj(): - global _config_obj - return _config_obj +def get_pipe_kwargs(): + global _pipe_kwargs + return _pipe_kwargs + + +def get_prep_kwargs(): + global _prep_kwargs + return _prep_kwargs + + +def get_gen_kwargs(): + global _gen_kwargs + return _gen_kwargs def get_scheduler(key): @@ -63,12 +101,15 @@ def get_scheduler(key): def clear_cache(): global _sd_obj - global _config_obj + global _pipe_kwargs + global _prep_kwargs + global _gen_kwargs global _schedulers del _sd_obj - del _config_obj del _schedulers gc.collect() _sd_obj = None - _config_obj = None + _pipe_kwargs = None + _prep_kwargs = None + _gen_kwargs = None _schedulers = None diff --git a/apps/shark_studio/web/utils/metadata/csv_metadata.py b/apps/shark_studio/web/utils/metadata/csv_metadata.py index d617e802bf..d515234083 100644 --- a/apps/shark_studio/web/utils/metadata/csv_metadata.py +++ b/apps/shark_studio/web/utils/metadata/csv_metadata.py @@ -29,9 +29,7 @@ def parse_csv(image_filename: str): has_header = csv.Sniffer().has_header(csv_file.read(2048)) csv_file.seek(0) - reader = ( - csv.DictReader(csv_file) if has_header else csv.reader(csv_file) - ) + reader = csv.DictReader(csv_file) if has_header else csv.reader(csv_file) matches = [ # we rely on humanize and humanizable to work out the parsing of the individual .csv rows diff --git a/apps/shark_studio/web/utils/metadata/format.py b/apps/shark_studio/web/utils/metadata/format.py index f097dab54f..308d9f8e8b 100644 --- a/apps/shark_studio/web/utils/metadata/format.py +++ b/apps/shark_studio/web/utils/metadata/format.py @@ -92,15 +92,11 @@ def compact(metadata: dict) -> dict: result["Hires resize"] = f"{hires_y}x{hires_x}" # remove VAE if it exists and is empty - if (result.keys() & {"VAE"}) and ( - not result["VAE"] or result["VAE"] == "None" - ): + if (result.keys() & {"VAE"}) and (not result["VAE"] or result["VAE"] == "None"): result.pop("VAE") # remove LoRA if it exists and is empty - if (result.keys() & {"LoRA"}) and ( - not result["LoRA"] or result["LoRA"] == "None" - ): + if (result.keys() & {"LoRA"}) and (not result["LoRA"] or result["LoRA"] == "None"): result.pop("LoRA") return result diff --git a/apps/shark_studio/web/utils/metadata/png_metadata.py b/apps/shark_studio/web/utils/metadata/png_metadata.py index cffc385ab7..72f663f246 100644 --- a/apps/shark_studio/web/utils/metadata/png_metadata.py +++ b/apps/shark_studio/web/utils/metadata/png_metadata.py @@ -1,6 +1,6 @@ import re from pathlib import Path -from apps.shark_studio.api.utils import ( +from apps.shark_studio.web.utils.file_utils import ( get_checkpoint_pathfile, ) from apps.shark_studio.api.sd import ( @@ -66,9 +66,7 @@ def parse_generation_parameters(x: str): return res -def try_find_model_base_from_png_metadata( - file: str, folder: str = "models" -) -> str: +def try_find_model_base_from_png_metadata(file: str, folder: str = "models") -> str: custom = "" # Remove extension from file info @@ -101,16 +99,13 @@ def find_model_from_png_metadata( # No matching model was found if not png_custom and not png_hf_id: print( - "Import PNG info: Unable to find a matching model for %s" - % model_file + "Import PNG info: Unable to find a matching model for %s" % model_file ) return png_custom, png_hf_id -def find_vae_from_png_metadata( - key: str, metadata: dict[str, str | int] -) -> str: +def find_vae_from_png_metadata(key: str, metadata: dict[str, str | int]) -> str: vae_custom = "" if key in metadata: diff --git a/apps/shark_studio/web/utils/state.py b/apps/shark_studio/web/utils/state.py index 626d4ce53f..133c8fd82f 100644 --- a/apps/shark_studio/web/utils/state.py +++ b/apps/shark_studio/web/utils/state.py @@ -3,7 +3,6 @@ def status_label(tab_name, batch_index=0, batch_count=1, batch_size=1): - print(f"Getting status label for {tab_name}") if batch_index < batch_count: bs = f"x{batch_size}" if batch_size > 1 else "" return f"{tab_name} generating {batch_index+1}/{batch_count}{bs}" @@ -18,8 +17,7 @@ def get_generation_text_info(seeds, device): text_output = f"prompt={cfg_dump['prompts']}" text_output += f"\nnegative prompt={cfg_dump['negative_prompts']}" text_output += ( - f"\nmodel_id={cfg_dump['hf_model_id']}, " - f"ckpt_loc={cfg_dump['ckpt_loc']}" + f"\nmodel_id={cfg_dump['hf_model_id']}, " f"ckpt_loc={cfg_dump['ckpt_loc']}" ) text_output += f"\nscheduler={cfg_dump['scheduler']}, " f"device={device}" text_output += ( diff --git a/apps/shark_studio/web/utils/tmp_configs.py b/apps/shark_studio/web/utils/tmp_configs.py index 3e6ba46bfe..4415276ea3 100644 --- a/apps/shark_studio/web/utils/tmp_configs.py +++ b/apps/shark_studio/web/utils/tmp_configs.py @@ -7,9 +7,7 @@ def clear_tmp_mlir(): cleanup_start = time() - print( - "Clearing .mlir temporary files from a prior run. This may take some time..." - ) + print("Clearing .mlir temporary files from a prior run. This may take some time...") mlir_files = [ filename for filename in os.listdir(shark_tmp) @@ -18,9 +16,7 @@ def clear_tmp_mlir(): ] for filename in mlir_files: os.remove(shark_tmp + filename) - print( - f"Clearing .mlir temporary files took {time() - cleanup_start:.4f} seconds." - ) + print(f"Clearing .mlir temporary files took {time() - cleanup_start:.4f} seconds.") def clear_tmp_imgs(): diff --git a/shark/iree_utils/compile_utils.py b/shark/iree_utils/compile_utils.py index ca6a12c45b..f5f9557744 100644 --- a/shark/iree_utils/compile_utils.py +++ b/shark/iree_utils/compile_utils.py @@ -64,6 +64,14 @@ def get_iree_device_args(device, extra_args=[]): return get_iree_rocm_args(device_num=device_num, extra_args=extra_args) return [] +def get_iree_target_triple(device): + args = get_iree_device_args(device) + for flag in args: + if "triple" in flag.split("-"): + triple = flag.split("=") + return triple + return "" + def clean_device_info(raw_device): # return appropriate device and device_id for consumption by Studio pipeline @@ -105,7 +113,6 @@ def get_iree_frontend_args(frontend): # Common args to be used given any frontend or device. def get_iree_common_args(debug=False): common_args = [ - "--iree-stream-resource-max-allocation-size=4294967295", "--iree-vm-bytecode-module-strip-source-map=true", "--iree-util-zero-fill-elided-attrs", ]