diff --git a/apps/stable_diffusion/shark_sd.spec b/apps/stable_diffusion/shark_sd.spec index 0a25318086..07b4a81a24 100644 --- a/apps/stable_diffusion/shark_sd.spec +++ b/apps/stable_diffusion/shark_sd.spec @@ -19,6 +19,9 @@ a = Analysis( win_private_assemblies=False, cipher=block_cipher, noarchive=False, + module_collection_mode={ + 'gradio': 'py', # Collect gradio package as source .py files + }, ) pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher) diff --git a/apps/stable_diffusion/shark_studio_imports.py b/apps/stable_diffusion/shark_studio_imports.py index 70874134d6..dfb5bf8027 100644 --- a/apps/stable_diffusion/shark_studio_imports.py +++ b/apps/stable_diffusion/shark_studio_imports.py @@ -31,6 +31,7 @@ datas += copy_metadata("sentencepiece") datas += copy_metadata("pyyaml") datas += copy_metadata("huggingface-hub") +datas += copy_metadata("gradio") datas += collect_data_files("torch") datas += collect_data_files("tokenizers") datas += collect_data_files("tiktoken") @@ -75,6 +76,7 @@ # hidden imports for pyinstaller hiddenimports = ["shark", "shark.shark_inference", "apps"] hiddenimports += [x for x in collect_submodules("skimage") if "tests" not in x] +hiddenimports += [x for x in collect_submodules("gradio") if "tests" not in x] hiddenimports += [ x for x in collect_submodules("diffusers") if "tests" not in x ] @@ -85,4 +87,4 @@ if not any(kw in x for kw in blacklist) ] hiddenimports += [x for x in collect_submodules("iree") if "tests" not in x] -hiddenimports += ["iree._runtime", "iree.compiler._mlir_libs._mlir.ir"] +hiddenimports += ["iree._runtime"] diff --git a/apps/stable_diffusion/src/models/model_wrappers.py b/apps/stable_diffusion/src/models/model_wrappers.py index b762b85b2c..863b6434b0 100644 --- a/apps/stable_diffusion/src/models/model_wrappers.py +++ b/apps/stable_diffusion/src/models/model_wrappers.py @@ -109,7 +109,7 @@ def process_vmfb_ir_sdxl(extended_model_name, model_name, device, precision): if "vulkan" in device: _device = args.iree_vulkan_target_triple _device = _device.replace("-", "_") - vmfb_path = Path(extended_model_name_for_vmfb + f"_{_device}.vmfb") + vmfb_path = Path(extended_model_name_for_vmfb + f"_vulkan.vmfb") if vmfb_path.exists(): shark_module = SharkInference( None, @@ -436,24 +436,48 @@ def __init__( super().__init__() self.vae = None if custom_vae == "": + print(f"Loading default vae, with target {model_id}") self.vae = AutoencoderKL.from_pretrained( model_id, subfolder="vae", low_cpu_mem_usage=low_cpu_mem_usage, ) elif not isinstance(custom_vae, dict): - self.vae = AutoencoderKL.from_pretrained( - custom_vae, - subfolder="vae", - low_cpu_mem_usage=low_cpu_mem_usage, - ) + precision = "fp16" if "fp16" in custom_vae else None + print(f"Loading custom vae, with target {custom_vae}") + if os.path.exists(custom_vae): + self.vae = AutoencoderKL.from_pretrained( + custom_vae, + low_cpu_mem_usage=low_cpu_mem_usage, + ) + else: + custom_vae = "/".join( + [ + custom_vae.split("/")[-2].split("\\")[-1], + custom_vae.split("/")[-1], + ] + ) + print("Using hub to get custom vae") + try: + self.vae = AutoencoderKL.from_pretrained( + custom_vae, + low_cpu_mem_usage=low_cpu_mem_usage, + variant=precision, + ) + except: + self.vae = AutoencoderKL.from_pretrained( + custom_vae, + low_cpu_mem_usage=low_cpu_mem_usage, + ) else: + print(f"Loading custom vae, with state {custom_vae}") self.vae = AutoencoderKL.from_pretrained( model_id, subfolder="vae", low_cpu_mem_usage=low_cpu_mem_usage, ) self.vae.load_state_dict(custom_vae) + self.base_vae = base_vae def 
forward(self, latents): image = self.vae.decode(latents / 0.13025, return_dict=False)[ @@ -465,7 +489,12 @@ def forward(self, latents): inputs = tuple(self.inputs["vae"]) # Make sure the VAE is in float32 mode, as it overflows in float16 as per SDXL # pipeline. - is_f16 = False + if not self.custom_vae: + is_f16 = False + elif "16" in self.custom_vae: + is_f16 = True + else: + is_f16 = False save_dir = os.path.join(self.sharktank_dir, self.model_name["vae"]) if self.debug: os.makedirs(save_dir, exist_ok=True) @@ -917,11 +946,19 @@ def __init__( low_cpu_mem_usage=False, ): super().__init__() - self.unet = UNet2DConditionModel.from_pretrained( - model_id, - subfolder="unet", - low_cpu_mem_usage=low_cpu_mem_usage, - ) + try: + self.unet = UNet2DConditionModel.from_pretrained( + model_id, + subfolder="unet", + low_cpu_mem_usage=low_cpu_mem_usage, + variant="fp16", + ) + except: + self.unet = UNet2DConditionModel.from_pretrained( + model_id, + subfolder="unet", + low_cpu_mem_usage=low_cpu_mem_usage, + ) if ( args.attention_slicing is not None and args.attention_slicing != "none" diff --git a/apps/stable_diffusion/src/models/opt_params.py b/apps/stable_diffusion/src/models/opt_params.py index 5dd59b006e..f03897de6e 100644 --- a/apps/stable_diffusion/src/models/opt_params.py +++ b/apps/stable_diffusion/src/models/opt_params.py @@ -123,7 +123,10 @@ def get_clip(): return get_shark_model(bucket, model_name, iree_flags) -def get_tokenizer(subfolder="tokenizer"): +def get_tokenizer(subfolder="tokenizer", hf_model_id=None): + if hf_model_id is not None: + args.hf_model_id = hf_model_id + tokenizer = CLIPTokenizer.from_pretrained( args.hf_model_id, subfolder=subfolder ) diff --git a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_img2img.py b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_img2img.py index f340655e5e..ba1d52fd80 100644 --- a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_img2img.py +++ b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_img2img.py @@ -158,7 +158,10 @@ def generate_images( use_base_vae, cpu_scheduling, max_embeddings_multiples, + stencils, + images, resample_type, + control_mode, ): # prompts and negative prompts must be a list. 
if isinstance(prompts, str): diff --git a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_txt2img.py b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_txt2img.py index 07b827a771..51a790bb41 100644 --- a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_txt2img.py +++ b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_txt2img.py @@ -18,7 +18,10 @@ KDPM2AncestralDiscreteScheduler, HeunDiscreteScheduler, ) -from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler +from apps.stable_diffusion.src.schedulers import ( + SharkEulerDiscreteScheduler, + SharkEulerAncestralDiscreteScheduler, +) from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import ( StableDiffusionPipeline, ) diff --git a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_txt2img_sdxl.py b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_txt2img_sdxl.py index 42df60af9a..a3b52793e9 100644 --- a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_txt2img_sdxl.py +++ b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_txt2img_sdxl.py @@ -16,7 +16,10 @@ KDPM2AncestralDiscreteScheduler, HeunDiscreteScheduler, ) -from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler +from apps.stable_diffusion.src.schedulers import ( + SharkEulerDiscreteScheduler, + SharkEulerAncestralDiscreteScheduler, +) from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import ( StableDiffusionPipeline, ) @@ -38,6 +41,7 @@ def __init__( EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler, SharkEulerDiscreteScheduler, + SharkEulerAncestralDiscreteScheduler, DEISMultistepScheduler, DDPMScheduler, DPMSolverSinglestepScheduler, @@ -48,8 +52,10 @@ def __init__( import_mlir: bool, use_lora: str, ondemand: bool, + is_fp32_vae: bool, ): super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand) + self.is_fp32_vae = is_fp32_vae def prepare_latents( self, @@ -203,10 +209,10 @@ def generate_images( # Img latents -> PIL images. 
all_imgs = [] self.load_vae() - # imgs = self.decode_latents_sdxl(None) - # all_imgs.extend(imgs) for i in range(0, latents.shape[0], batch_size): - imgs = self.decode_latents_sdxl(latents[i : i + batch_size]) + imgs = self.decode_latents_sdxl( + latents[i : i + batch_size], is_fp32_vae=self.is_fp32_vae + ) all_imgs.extend(imgs) if self.ondemand: self.unload_vae() diff --git a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_utils.py b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_utils.py index 7a6f763e24..7d3d6b1c5b 100644 --- a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_utils.py +++ b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_utils.py @@ -20,7 +20,10 @@ HeunDiscreteScheduler, ) from shark.shark_inference import SharkInference -from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler +from apps.stable_diffusion.src.schedulers import ( + SharkEulerDiscreteScheduler, + SharkEulerAncestralDiscreteScheduler, +) from apps.stable_diffusion.src.models import ( SharkifyStableDiffusionModel, get_vae, @@ -52,6 +55,7 @@ def __init__( EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler, SharkEulerDiscreteScheduler, + SharkEulerAncestralDiscreteScheduler, DEISMultistepScheduler, DDPMScheduler, DPMSolverSinglestepScheduler, @@ -62,6 +66,7 @@ def __init__( import_mlir: bool, use_lora: str, ondemand: bool, + is_f32_vae: bool = False, ): self.vae = None self.text_encoder = None @@ -69,14 +74,15 @@ def __init__( self.unet = None self.unet_512 = None self.model_max_length = 77 - self.scheduler = scheduler # TODO: Implement using logging python utility. self.log = "" self.status = SD_STATE_IDLE self.sd_model = sd_model + self.scheduler = scheduler self.import_mlir = import_mlir self.use_lora = use_lora self.ondemand = ondemand + self.is_f32_vae = is_f32_vae # TODO: Find a better workaround for fetching base_model_id early # enough for CLIPTokenizer. try: @@ -202,6 +208,9 @@ def encode_prompt_sdxl( negative_prompt_embeds: Optional[torch.FloatTensor] = None, pooled_prompt_embeds: Optional[torch.FloatTensor] = None, negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + hf_model_id: Optional[ + str + ] = "stabilityai/stable-diffusion-xl-base-1.0", ): if prompt is not None and isinstance(prompt, str): batch_size = 1 @@ -211,7 +220,7 @@ def encode_prompt_sdxl( batch_size = prompt_embeds.shape[0] # Define tokenizers and text encoders - self.tokenizer_2 = get_tokenizer("tokenizer_2") + self.tokenizer_2 = get_tokenizer("tokenizer_2", hf_model_id) self.load_clip_sdxl() tokenizers = ( [self.tokenizer, self.tokenizer_2] @@ -332,7 +341,7 @@ def encode_prompt_sdxl( gc.collect() # TODO: Look into dtype for text_encoder_2! 
- prompt_embeds = prompt_embeds.to(dtype=torch.float32) + prompt_embeds = prompt_embeds.to(dtype=torch.float16) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) @@ -523,6 +532,9 @@ def produce_img_latents_sdxl( cpu_scheduling, guidance_scale, dtype, + mask=None, + masked_image_latents=None, + return_all_latents=False, ): # return None self.status = SD_STATE_IDLE @@ -533,11 +545,22 @@ def produce_img_latents_sdxl( step_start_time = time.time() timestep = torch.tensor([t]).to(dtype).detach().numpy() # expand the latents if we are doing classifier free guidance + if isinstance(latents, np.ndarray): + latents = torch.tensor(latents) latent_model_input = torch.cat([latents] * 2) latent_model_input = self.scheduler.scale_model_input( latent_model_input, t - ).to(dtype) + ) + if mask is not None and masked_image_latents is not None: + latent_model_input = torch.cat( + [ + torch.from_numpy(np.asarray(latent_model_input)), + mask, + masked_image_latents, + ], + dim=1, + ).to(dtype) noise_pred = self.unet( "forward", @@ -549,11 +572,17 @@ def produce_img_latents_sdxl( add_time_ids, guidance_scale, ), - send_to_host=False, + send_to_host=True, ) + if not isinstance(latents, torch.Tensor): + latents = torch.from_numpy(latents).to("cpu") + noise_pred = torch.from_numpy(noise_pred).to("cpu") + latents = self.scheduler.step( noise_pred, t, latents, **extra_step_kwargs, return_dict=False )[0] + latents = latents.detach().numpy() + noise_pred = noise_pred.detach().numpy() step_time = (time.time() - step_start_time) * 1000 step_time_sum += step_time @@ -569,11 +598,15 @@ def produce_img_latents_sdxl( return latents - def decode_latents_sdxl(self, latents): - latents = latents.to(torch.float32) + def decode_latents_sdxl(self, latents, is_fp32_vae): + # latents are in unet dtype here so switch if we want to use fp32 + if is_fp32_vae: + print("Casting latents to float32 for VAE") + latents = latents.to(torch.float32) images = self.vae("forward", (latents,)) images = (torch.from_numpy(images) / 2 + 0.5).clamp(0, 1) images = images.cpu().permute(0, 2, 3, 1).float().numpy() + images = (images * 255).round().astype("uint8") pil_images = [Image.fromarray(image[:, :, :3]) for image in images] @@ -666,6 +699,17 @@ def from_pretrained( return cls( scheduler, sd_model, import_mlir, use_lora, ondemand, stencils ) + if cls.__name__ == "Text2ImageSDXLPipeline": + is_fp32_vae = True if "16" not in custom_vae else False + return cls( + scheduler, + sd_model, + import_mlir, + use_lora, + ondemand, + is_fp32_vae, + ) + return cls(scheduler, sd_model, import_mlir, use_lora, ondemand) # ##################################################### diff --git a/apps/stable_diffusion/src/schedulers/__init__.py b/apps/stable_diffusion/src/schedulers/__init__.py index 8a939b0f05..4e6d8db9e1 100644 --- a/apps/stable_diffusion/src/schedulers/__init__.py +++ b/apps/stable_diffusion/src/schedulers/__init__.py @@ -2,3 +2,6 @@ from apps.stable_diffusion.src.schedulers.shark_eulerdiscrete import ( SharkEulerDiscreteScheduler, ) +from apps.stable_diffusion.src.schedulers.shark_eulerancestraldiscrete import ( + SharkEulerAncestralDiscreteScheduler, +) diff --git a/apps/stable_diffusion/src/schedulers/sd_schedulers.py b/apps/stable_diffusion/src/schedulers/sd_schedulers.py index 325047c8b0..544fa1efa1 100644 --- a/apps/stable_diffusion/src/schedulers/sd_schedulers.py +++ 
b/apps/stable_diffusion/src/schedulers/sd_schedulers.py @@ -15,9 +15,22 @@ from apps.stable_diffusion.src.schedulers.shark_eulerdiscrete import ( SharkEulerDiscreteScheduler, ) +from apps.stable_diffusion.src.schedulers.shark_eulerancestraldiscrete import ( + SharkEulerAncestralDiscreteScheduler, +) def get_schedulers(model_id): + # TODO: Robust scheduler setup on pipeline creation -- if we don't + # set batch_size here, the SHARK schedulers will + # compile with batch size = 1 regardless of whether the model + # outputs latents of a larger batch size, e.g. SDXL. + # This also goes towards enabling batch size cfg for SD in general. + # However, obviously, searching for whether the base model ID + # contains "xl" is not very robust. + + batch_size = 2 if "xl" in model_id.lower() else 1 + schedulers = dict() schedulers["PNDM"] = PNDMScheduler.from_pretrained( model_id, @@ -84,6 +97,12 @@ def get_schedulers(model_id): model_id, subfolder="scheduler", ) + schedulers[ + "SharkEulerAncestralDiscrete" + ] = SharkEulerAncestralDiscreteScheduler.from_pretrained( + model_id, + subfolder="scheduler", + ) schedulers[ "DPMSolverSinglestep" ] = DPMSolverSinglestepScheduler.from_pretrained( @@ -100,5 +119,6 @@ def get_schedulers(model_id): model_id, subfolder="scheduler", ) - schedulers["SharkEulerDiscrete"].compile() + schedulers["SharkEulerDiscrete"].compile(batch_size) + schedulers["SharkEulerAncestralDiscrete"].compile(batch_size) return schedulers diff --git a/apps/stable_diffusion/src/schedulers/shark_eulerancestraldiscrete.py b/apps/stable_diffusion/src/schedulers/shark_eulerancestraldiscrete.py new file mode 100644 index 0000000000..c941e56220 --- /dev/null +++ b/apps/stable_diffusion/src/schedulers/shark_eulerancestraldiscrete.py @@ -0,0 +1,251 @@ +import sys +import numpy as np +from typing import List, Optional, Tuple, Union +from diffusers import ( + EulerAncestralDiscreteScheduler, +) +from diffusers.utils.torch_utils import randn_tensor +from diffusers.configuration_utils import register_to_config +from apps.stable_diffusion.src.utils import ( + compile_through_fx, + get_shark_model, + args, +) +import torch + + +class SharkEulerAncestralDiscreteScheduler(EulerAncestralDiscreteScheduler): + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + prediction_type: str = "epsilon", + timestep_spacing: str = "linspace", + steps_offset: int = 0, + ): + super().__init__( + num_train_timesteps, + beta_start, + beta_end, + beta_schedule, + trained_betas, + prediction_type, + timestep_spacing, + steps_offset, + ) + # TODO: make it dynamic so we dont have to worry about batch size + self.batch_size = None + self.init_input_shape = None + + def compile(self, batch_size=1): + SCHEDULER_BUCKET = "gs://shark_tank/stable_diffusion/schedulers" + device = args.device.split(":", 1)[0].strip() + self.batch_size = batch_size + + model_input = { + "eulera": { + "output": torch.randn( + batch_size, 4, args.height // 8, args.width // 8 + ), + "latent": torch.randn( + batch_size, 4, args.height // 8, args.width // 8 + ), + "sigma": torch.tensor(1).to(torch.float32), + "sigma_from": torch.tensor(1).to(torch.float32), + "sigma_to": torch.tensor(1).to(torch.float32), + "noise": torch.randn( + batch_size, 4, args.height // 8, args.width // 8 + ), + }, + } + + example_latent = model_input["eulera"]["latent"] + example_output = 
model_input["eulera"]["output"] + example_noise = model_input["eulera"]["noise"] + if args.precision == "fp16": + example_latent = example_latent.half() + example_output = example_output.half() + example_noise = example_noise.half() + example_sigma = model_input["eulera"]["sigma"] + example_sigma_from = model_input["eulera"]["sigma_from"] + example_sigma_to = model_input["eulera"]["sigma_to"] + + class ScalingModel(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, latent, sigma): + return latent / ((sigma**2 + 1) ** 0.5) + + class SchedulerStepEpsilonModel(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward( + self, noise_pred, latent, sigma, sigma_from, sigma_to, noise + ): + sigma_up = ( + sigma_to**2 + * (sigma_from**2 - sigma_to**2) + / sigma_from**2 + ) ** 0.5 + sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5 + dt = sigma_down - sigma + pred_original_sample = latent - sigma * noise_pred + derivative = (latent - pred_original_sample) / sigma + prev_sample = latent + derivative * dt + return prev_sample + noise * sigma_up + + class SchedulerStepVPredictionModel(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward( + self, noise_pred, sigma, sigma_from, sigma_to, latent, noise + ): + sigma_up = ( + sigma_to**2 + * (sigma_from**2 - sigma_to**2) + / sigma_from**2 + ) ** 0.5 + sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5 + dt = sigma_down - sigma + pred_original_sample = noise_pred * ( + -sigma / (sigma**2 + 1) ** 0.5 + ) + (latent / (sigma**2 + 1)) + derivative = (latent - pred_original_sample) / sigma + prev_sample = latent + derivative * dt + return prev_sample + noise * sigma_up + + iree_flags = [] + if len(args.iree_vulkan_target_triple) > 0: + iree_flags.append( + f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}" + ) + + def _import(self): + scaling_model = ScalingModel() + self.scaling_model, _ = compile_through_fx( + model=scaling_model, + inputs=(example_latent, example_sigma), + extended_model_name=f"euler_a_scale_model_input_{self.batch_size}_{args.height}_{args.width}_{device}_" + + args.precision, + extra_args=iree_flags, + ) + + pred_type_model_dict = { + "epsilon": SchedulerStepEpsilonModel(), + "v_prediction": SchedulerStepVPredictionModel(), + } + step_model = pred_type_model_dict[self.config.prediction_type] + self.step_model, _ = compile_through_fx( + step_model, + ( + example_output, + example_latent, + example_sigma, + example_sigma_from, + example_sigma_to, + example_noise, + ), + extended_model_name=f"euler_a_step_{self.config.prediction_type}_{self.batch_size}_{args.height}_{args.width}_{device}_" + + args.precision, + extra_args=iree_flags, + ) + + if args.import_mlir: + _import(self) + + else: + try: + self.scaling_model = get_shark_model( + SCHEDULER_BUCKET, + "euler_a_scale_model_input_" + args.precision, + iree_flags, + ) + self.step_model = get_shark_model( + SCHEDULER_BUCKET, + "euler_a_step_" + + self.config.prediction_type + + args.precision, + iree_flags, + ) + except: + print( + "failed to download model, falling back and using import_mlir" + ) + args.import_mlir = True + _import(self) + + def scale_model_input(self, sample, timestep): + if self.step_index is None: + self._init_step_index(timestep) + sigma = self.sigmas[self.step_index] + return self.scaling_model( + "forward", + ( + sample, + sigma, + ), + send_to_host=False, + ) + + def step( + self, + noise_pred, + timestep, + latent, + generator: Optional[torch.Generator] = None, + return_dict: 
Optional[bool] = False, + ): + step_inputs = [] + + if self.step_index is None: + self._init_step_index(timestep) + + sigma = self.sigmas[self.step_index] + + sigma_from = self.sigmas[self.step_index] + sigma_to = self.sigmas[self.step_index + 1] + noise = randn_tensor( + torch.Size(noise_pred.shape), + dtype=torch.float16, + device="cpu", + generator=generator, + ) + step_inputs = [ + noise_pred, + latent, + sigma, + sigma_from, + sigma_to, + noise, + ] + # TODO: deal with dynamic inputs in turbine flow. + # update step index since we're done with the variable and will return with compiled module output. + self._step_index += 1 + + if noise_pred.shape[0] < self.batch_size: + for i in [0, 1, 5]: + try: + step_inputs[i] = torch.tensor(step_inputs[i]) + except: + step_inputs[i] = torch.tensor(step_inputs[i].to_host()) + step_inputs[i] = torch.cat( + (step_inputs[i], step_inputs[i]), axis=0 + ) + return self.step_model( + "forward", + tuple(step_inputs), + send_to_host=True, + ) + + return self.step_model( + "forward", + tuple(step_inputs), + send_to_host=False, + ) diff --git a/apps/stable_diffusion/src/schedulers/shark_eulerdiscrete.py b/apps/stable_diffusion/src/schedulers/shark_eulerdiscrete.py index 85aba5d870..c074af4d7c 100644 --- a/apps/stable_diffusion/src/schedulers/shark_eulerdiscrete.py +++ b/apps/stable_diffusion/src/schedulers/shark_eulerdiscrete.py @@ -2,12 +2,9 @@ import numpy as np from typing import List, Optional, Tuple, Union from diffusers import ( - LMSDiscreteScheduler, - PNDMScheduler, - DDIMScheduler, - DPMSolverMultistepScheduler, EulerDiscreteScheduler, ) +from diffusers.utils.torch_utils import randn_tensor from diffusers.configuration_utils import register_to_config from apps.stable_diffusion.src.utils import ( compile_through_fx, @@ -27,6 +24,13 @@ def __init__( beta_schedule: str = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, prediction_type: str = "epsilon", + interpolation_type: str = "linear", + use_karras_sigmas: bool = False, + sigma_min: Optional[float] = None, + sigma_max: Optional[float] = None, + timestep_spacing: str = "linspace", + timestep_type: str = "discrete", + steps_offset: int = 0, ): super().__init__( num_train_timesteps, @@ -35,20 +39,29 @@ def __init__( beta_schedule, trained_betas, prediction_type, + interpolation_type, + use_karras_sigmas, + sigma_min, + sigma_max, + timestep_spacing, + timestep_type, + steps_offset, ) + # TODO: make it dynamic so we dont have to worry about batch size + self.batch_size = None - def compile(self): + def compile(self, batch_size=1): SCHEDULER_BUCKET = "gs://shark_tank/stable_diffusion/schedulers" - BATCH_SIZE = args.batch_size device = args.device.split(":", 1)[0].strip() + self.batch_size = batch_size model_input = { "euler": { "latent": torch.randn( - BATCH_SIZE, 4, args.height // 8, args.width // 8 + batch_size, 4, args.height // 8, args.width // 8 ), "output": torch.randn( - BATCH_SIZE, 4, args.height // 8, args.width // 8 + batch_size, 4, args.height // 8, args.width // 8 ), "sigma": torch.tensor(1).to(torch.float32), "dt": torch.tensor(1).to(torch.float32), @@ -70,12 +83,32 @@ def __init__(self): def forward(self, latent, sigma): return latent / ((sigma**2 + 1) ** 0.5) - class SchedulerStepModel(torch.nn.Module): + class SchedulerStepEpsilonModel(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, noise_pred, sigma_hat, latent, dt): + pred_original_sample = latent - sigma_hat * noise_pred + derivative = (latent - pred_original_sample) / 
sigma_hat + return latent + derivative * dt + + class SchedulerStepSampleModel(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, noise_pred, sigma_hat, latent, dt): + pred_original_sample = noise_pred + derivative = (latent - pred_original_sample) / sigma_hat + return latent + derivative * dt + + class SchedulerStepVPredictionModel(torch.nn.Module): def __init__(self): super().__init__() def forward(self, noise_pred, sigma, latent, dt): - pred_original_sample = latent - sigma * noise_pred + pred_original_sample = noise_pred * ( + -sigma / (sigma**2 + 1) ** 0.5 + ) + (latent / (sigma**2 + 1)) derivative = (latent - pred_original_sample) / sigma return latent + derivative * dt @@ -90,16 +123,22 @@ def _import(self): self.scaling_model, _ = compile_through_fx( model=scaling_model, inputs=(example_latent, example_sigma), - extended_model_name=f"euler_scale_model_input_{BATCH_SIZE}_{args.height}_{args.width}_{device}_" + extended_model_name=f"euler_scale_model_input_{self.batch_size}_{args.height}_{args.width}_{device}_" + args.precision, extra_args=iree_flags, ) - step_model = SchedulerStepModel() + pred_type_model_dict = { + "epsilon": SchedulerStepEpsilonModel(), + "v_prediction": SchedulerStepVPredictionModel(), + "sample": SchedulerStepSampleModel(), + "original_sample": SchedulerStepSampleModel(), + } + step_model = pred_type_model_dict[self.config.prediction_type] self.step_model, _ = compile_through_fx( step_model, (example_output, example_sigma, example_latent, example_dt), - extended_model_name=f"euler_step_{BATCH_SIZE}_{args.height}_{args.width}_{device}_" + extended_model_name=f"euler_step_{self.config.prediction_type}_{self.batch_size}_{args.height}_{args.width}_{device}_" + args.precision, extra_args=iree_flags, ) @@ -109,6 +148,11 @@ def _import(self): else: try: + step_model_type = ( + "sample" + if "sample" in self.config.prediction_type + else self.config.prediction_type + ) self.scaling_model = get_shark_model( SCHEDULER_BUCKET, "euler_scale_model_input_" + args.precision, @@ -116,7 +160,7 @@ def _import(self): ) self.step_model = get_shark_model( SCHEDULER_BUCKET, - "euler_step_" + args.precision, + "euler_step_" + step_model_type + args.precision, iree_flags, ) except: @@ -138,15 +182,52 @@ def scale_model_input(self, sample, timestep): send_to_host=False, ) - def step(self, noise_pred, timestep, latent): - step_index = (self.timesteps == timestep).nonzero().item() - sigma = self.sigmas[step_index] - dt = self.sigmas[step_index + 1] - sigma + def step( + self, + noise_pred, + timestep, + latent, + s_churn: float = 0.0, + s_tmin: float = 0.0, + s_tmax: float = float("inf"), + s_noise: float = 1.0, + generator: Optional[torch.Generator] = None, + return_dict: Optional[bool] = False, + ): + if self.step_index is None: + self._init_step_index(timestep) + + sigma = self.sigmas[self.step_index] + + gamma = ( + min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) + if s_tmin <= sigma <= s_tmax + else 0.0 + ) + + sigma_hat = sigma * (gamma + 1) + + noise = randn_tensor( + noise_pred.shape, + dtype=noise_pred.dtype, + device="cpu", + generator=generator, + ) + + eps = noise * s_noise + + if gamma > 0: + latent = latent + eps * (sigma_hat**2 - sigma**2) ** 0.5 + + if self.config.prediction_type == "v_prediction": + sigma_hat = sigma + + dt = self.sigmas[self.step_index + 1] - sigma_hat return self.step_model( "forward", ( noise_pred, - sigma, + sigma_hat, latent, dt, ), diff --git a/apps/stable_diffusion/src/utils/resources/base_model.json 
b/apps/stable_diffusion/src/utils/resources/base_model.json
index 09617ca40a..5cbf965581 100644
--- a/apps/stable_diffusion/src/utils/resources/base_model.json
+++ b/apps/stable_diffusion/src/utils/resources/base_model.json
@@ -189,6 +189,49 @@
             "dtype": "i64"
         }
     },
+    "stabilityai/sdxl-turbo": {
+        "latents": {
+            "shape": [
+                "2*batch_size",
+                4,
+                "height",
+                "width"
+            ],
+            "dtype": "f32"
+        },
+        "timesteps": {
+            "shape": [
+                1
+            ],
+            "dtype": "f32"
+        },
+        "prompt_embeds": {
+            "shape": [
+                "2*batch_size",
+                "max_len",
+                2048
+            ],
+            "dtype": "f32"
+        },
+        "text_embeds": {
+            "shape": [
+                "2*batch_size",
+                1280
+            ],
+            "dtype": "f32"
+        },
+        "time_ids": {
+            "shape": [
+                "2*batch_size",
+                6
+            ],
+            "dtype": "f32"
+        },
+        "guidance_scale": {
+            "shape": 1,
+            "dtype": "f32"
+        }
+    },
     "stabilityai/stable-diffusion-xl-base-1.0": {
         "latents": {
             "shape": [
@@ -449,4 +492,4 @@
             }
         }
     }
-}
\ No newline at end of file
+}
diff --git a/apps/stable_diffusion/src/utils/resources/prompts.json b/apps/stable_diffusion/src/utils/resources/prompts.json
index 8387178460..7ecce99e5a 100644
--- a/apps/stable_diffusion/src/utils/resources/prompts.json
+++ b/apps/stable_diffusion/src/utils/resources/prompts.json
@@ -1,4 +1,5 @@
 [["A high tech solarpunk utopia in the Amazon rainforest"],
+["Astrophotography, the shark nebula, nebula with a tiny shark-like cloud in the middle, hubble telescope, vivid colors"],
 ["A pikachu fine dining with a view to the Eiffel Tower"],
 ["A mecha robot in a favela in expressionist style"],
 ["an insect robot preparing a delicious meal"],
diff --git a/apps/stable_diffusion/studio_bundle.spec b/apps/stable_diffusion/studio_bundle.spec
index 2403b58128..d73abd1e38 100644
--- a/apps/stable_diffusion/studio_bundle.spec
+++ b/apps/stable_diffusion/studio_bundle.spec
@@ -19,6 +19,9 @@ a = Analysis(
     win_private_assemblies=False,
     cipher=block_cipher,
     noarchive=False,
+    module_collection_mode={
+        'gradio': 'py',  # Collect gradio package as source .py files
+    },
 )

 pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)
diff --git a/apps/stable_diffusion/web/index.py b/apps/stable_diffusion/web/index.py
index 07242e3fc7..7bb7be9c7c 100644
--- a/apps/stable_diffusion/web/index.py
+++ b/apps/stable_diffusion/web/index.py
@@ -110,11 +110,15 @@ def resource_path(relative_path):
     txt2img_sendto_outpaint,
     txt2img_sendto_upscaler,
     # SDXL
-    txt2img_sdxl_inf,
     txt2img_sdxl_web,
     txt2img_sdxl_custom_model,
     txt2img_sdxl_gallery,
+    txt2img_sdxl_png_info_img,
     txt2img_sdxl_status,
+    txt2img_sdxl_sendto_img2img,
+    txt2img_sdxl_sendto_inpaint,
+    txt2img_sdxl_sendto_outpaint,
+    txt2img_sdxl_sendto_upscaler,
     # h2ogpt_upload,
     # h2ogpt_web,
     img2img_web,
@@ -151,7 +155,7 @@ def resource_path(relative_path):
     upscaler_sendto_outpaint,
     # lora_train_web,
     # model_web,
-    # model_config_web,
+    model_config_web,
     hf_models,
     modelmanager_sendto_txt2img,
     modelmanager_sendto_img2img,
@@ -165,6 +169,7 @@ def resource_path(relative_path):
     outputgallery_watch,
     outputgallery_filename,
     outputgallery_sendto_txt2img,
+    outputgallery_sendto_txt2img_sdxl,
     outputgallery_sendto_img2img,
     outputgallery_sendto_inpaint,
     outputgallery_sendto_outpaint,
@@ -178,7 +183,7 @@ def register_button_click(button, selectedid, inputs, outputs):
     button.click(
         lambda x: (
             x[0]["name"] if len(x) != 0 else None,
-            gr.Tabs.update(selected=selectedid),
+            gr.Tabs(selected=selectedid),
         ),
         inputs,
         outputs,
@@ -189,7 +194,7 @@ def register_modelmanager_button(button, selectedid, inputs, outputs):
     button.click(
         lambda x: (
             "None",
             x,
-            gr.Tabs.update(selected=selectedid),
+
gr.Tabs(selected=selectedid), ), inputs, outputs, @@ -199,7 +204,7 @@ def register_outputgallery_button(button, selectedid, inputs, outputs): button.click( lambda x: ( x, - gr.Tabs.update(selected=selectedid), + gr.Tabs(selected=selectedid), ), inputs, outputs, @@ -241,6 +246,7 @@ def register_outputgallery_button(button, selectedid, inputs, outputs): inpaint_status, outpaint_status, upscaler_status, + txt2img_sdxl_status, ] ) # with gr.TabItem(label="Model Manager", id=6): @@ -249,17 +255,17 @@ def register_outputgallery_button(button, selectedid, inputs, outputs): # lora_train_web.render() with gr.TabItem(label="Chat Bot", id=8): stablelm_chat.render() - # with gr.TabItem( - # label="Generate Sharding Config (Experimental)", id=9 - # ): - # model_config_web.render() - with gr.TabItem(label="MultiModal (Experimental)", id=10): - minigpt4_web.render() + with gr.TabItem( + label="Generate Sharding Config (Experimental)", id=9 + ): + model_config_web.render() + # with gr.TabItem(label="MultiModal (Experimental)", id=10): + # minigpt4_web.render() # with gr.TabItem(label="DocuChat Upload", id=11): # h2ogpt_upload.render() # with gr.TabItem(label="DocuChat(Experimental)", id=12): # h2ogpt_web.render() - with gr.TabItem(label="Text-to-Image-SDXL (Experimental)", id=13): + with gr.TabItem(label="Text-to-Image (SDXL)", id=13): txt2img_sdxl_web.render() actual_port = app.usable_port() diff --git a/apps/stable_diffusion/web/ui/__init__.py b/apps/stable_diffusion/web/ui/__init__.py index 7ba21b73ef..979c129850 100644 --- a/apps/stable_diffusion/web/ui/__init__.py +++ b/apps/stable_diffusion/web/ui/__init__.py @@ -16,6 +16,11 @@ txt2img_sdxl_custom_model, txt2img_sdxl_gallery, txt2img_sdxl_status, + txt2img_sdxl_png_info_img, + txt2img_sdxl_sendto_img2img, + txt2img_sdxl_sendto_inpaint, + txt2img_sdxl_sendto_outpaint, + txt2img_sdxl_sendto_upscaler, ) from apps.stable_diffusion.web.ui.img2img_ui import ( img2img_inf, @@ -83,6 +88,7 @@ outputgallery_watch, outputgallery_filename, outputgallery_sendto_txt2img, + outputgallery_sendto_txt2img_sdxl, outputgallery_sendto_img2img, outputgallery_sendto_inpaint, outputgallery_sendto_outpaint, diff --git a/apps/stable_diffusion/web/ui/img2img_ui.py b/apps/stable_diffusion/web/ui/img2img_ui.py index 0e03287b33..1d8a7f5395 100644 --- a/apps/stable_diffusion/web/ui/img2img_ui.py +++ b/apps/stable_diffusion/web/ui/img2img_ui.py @@ -34,6 +34,7 @@ from apps.stable_diffusion.src.utils.stencils import ( CannyDetector, OpenposeDetector, + ZoeDetector, ) from apps.stable_diffusion.web.utils.common_label_calc import status_label import numpy as np @@ -97,18 +98,21 @@ def img2img_inf( for i, stencil in enumerate(stencils): if images[i] is None and stencil is not None: - return None, "A stencil must have an Image input" + return if images[i] is not None: + if isinstance(images[i], dict): + images[i] = images[i]["composite"] images[i] = images[i].convert("RGB") - if image_dict is None: + if image_dict is None and images[0] is None: return None, "An Initial Image is required" - # if use_stencil == "scribble": - # image = image_dict["mask"].convert("RGB") if isinstance(image_dict, PIL.Image.Image): image = image_dict.convert("RGB") - else: + elif image_dict: image = image_dict["image"].convert("RGB") + else: + # TODO: enable t2i + controlnets + image = None # set ckpt_loc and hf_model_id. 
args.ckpt_loc = "" @@ -140,9 +144,8 @@ def img2img_inf( if stencil is not None: stencil_count += 1 if stencil_count > 0: - args.scheduler = "DDIM" args.hf_model_id = "runwayml/stable-diffusion-v1-5" - # image, width, height = resize_stencil(image) + image, width, height = resize_stencil(image) elif "Shark" in args.scheduler: print( f"Shark schedulers are not supported. Switching to EulerDiscrete " @@ -363,71 +366,182 @@ def img2img_inf( # TODO: make this import image prompt info if it exists img2img_init_image = gr.Image( label="Input Image", - source="upload", - tool="sketch", type="pil", - height=300, + height=512, + interactive=True, ) with gr.Accordion(label="Multistencil Options", open=False): - choices = ["None", "canny", "openpose", "scribble"] + choices = [ + "None", + "canny", + "openpose", + "scribble", + "zoedepth", + ] def cnet_preview( - checked, model, input_image, index, stencils, images + model, input_image, index, stencils, images ): - if not checked: - stencils[index] = None - images[index] = None - return (None, stencils, images) images[index] = input_image stencils[index] = model match model: case "canny": canny = CannyDetector() - result = canny(np.array(input_image), 100, 200) + result = canny( + np.array(input_image["composite"]), + 100, + 200, + ) return ( - [Image.fromarray(result), result], + Image.fromarray(result), stencils, images, ) case "openpose": openpose = OpenposeDetector() - result = openpose(np.array(input_image)) + result = openpose( + np.array(input_image["composite"]) + ) # TODO: This is just an empty canvas, need to draw the candidates (which are in result[1]) return ( - [Image.fromarray(result[0]), result], + Image.fromarray(result[0]), stencils, images, ) + case "zoedepth": + zoedepth = ZoeDetector() + result = zoedepth( + np.array(input_image["composite"]) + ) + return ( + Image.fromarray(result[0]), + stencils, + images, + ) + case "scribble": + result = input_image["composite"].convert("L") + return (result, stencils, images) case _: return (None, stencils, images) - with gr.Row(): - cnet_1 = gr.Checkbox(show_label=False) - cnet_1_model = gr.Dropdown( - label="Controlnet 1", - value="None", - choices=choices, + def create_canvas(width, height): + data = ( + np.zeros( + shape=(height, width, 3), + dtype=np.uint8, + ) + + 255 ) - cnet_1_image = gr.Image( - source="upload", - tool=None, + return data + + def update_cn_input(model, width, height): + if model == "scribble": + return [ + gr.ImageEditor( + visible=True, + image_mode="RGB", + interactive=True, + show_label=False, + type="pil", + value=create_canvas(width, height), + crop_size=(width, height), + ), + gr.Image( + visible=True, + show_label=False, + interactive=False, + ), + gr.Slider(visible=True), + gr.Slider(visible=True), + gr.Button(visible=True), + ] + else: + return [ + gr.ImageEditor( + visible=True, + image_mode="RGB", + type="pil", + interactive=True, + value=None, + ), + gr.Image( + visible=True, + show_label=False, + interactive=True, + ), + gr.Slider(visible=False), + gr.Slider(visible=False), + gr.Button(visible=False), + ] + + with gr.Row(): + with gr.Column(): + cnet_1 = gr.Button( + value="Generate controlnet input" + ) + cnet_1_model = gr.Dropdown( + label="Controlnet 1", + value="None", + choices=choices, + ) + canvas_width = gr.Slider( + label="Canvas Width", + minimum=256, + maximum=1024, + value=512, + step=1, + visible=False, + ) + canvas_height = gr.Slider( + label="Canvas Height", + minimum=256, + maximum=1024, + value=512, + step=1, + visible=False, + ) + 
make_canvas = gr.Button( + value="Make Canvas!", + visible=False, + ) + cnet_1_image = gr.ImageEditor( + visible=False, + image_mode="RGB", + interactive=True, + show_label=False, type="pil", ) - cnet_1_output = gr.Gallery( - show_label=False, - object_fit="scale-down", - rows=1, - columns=1, + cnet_1_output = gr.Image( + visible=True, show_label=False + ) + cnet_1_model.input( + update_cn_input, + [cnet_1_model, canvas_width, canvas_height], + [ + cnet_1_image, + cnet_1_output, + canvas_width, + canvas_height, + make_canvas, + ], ) - cnet_1.change( + make_canvas.click( + update_cn_input, + [cnet_1_model, canvas_width, canvas_height], + [ + cnet_1_image, + cnet_1_output, + canvas_width, + canvas_height, + make_canvas, + ], + ) + cnet_1.click( fn=( - lambda a, b, c, s, i: cnet_preview( - a, b, c, 0, s, i - ) + lambda a, b, s, i: cnet_preview(a, b, 0, s, i) ), inputs=[ - cnet_1, cnet_1_model, cnet_1_image, stencils, @@ -436,31 +550,72 @@ def cnet_preview( outputs=[cnet_1_output, stencils, images], ) with gr.Row(): - cnet_2 = gr.Checkbox(show_label=False) - cnet_2_model = gr.Dropdown( - label="Controlnet 2", - value="None", - choices=choices, - ) - cnet_2_image = gr.Image( - source="upload", - tool=None, + with gr.Column(): + cnet_2 = gr.Button( + value="Generate controlnet input" + ) + cnet_2_model = gr.Dropdown( + label="Controlnet 2", + value="None", + choices=choices, + ) + canvas_width = gr.Slider( + label="Canvas Width", + minimum=256, + maximum=1024, + value=512, + step=1, + visible=False, + ) + canvas_height = gr.Slider( + label="Canvas Height", + minimum=256, + maximum=1024, + value=512, + step=1, + visible=False, + ) + make_canvas = gr.Button( + value="Make Canvas!", + visible=False, + ) + cnet_2_image = gr.ImageEditor( + visible=False, + image_mode="RGB", + interactive=True, + show_label=False, type="pil", ) - cnet_2_output = gr.Gallery( - show_label=False, - object_fit="scale-down", - rows=1, - columns=1, + cnet_2_output = gr.Image( + visible=True, show_label=False ) - cnet_2.change( + cnet_2_model.select( + update_cn_input, + [cnet_2_model, canvas_width, canvas_height], + [ + cnet_2_image, + cnet_2_output, + canvas_width, + canvas_height, + make_canvas, + ], + ) + make_canvas.click( + update_cn_input, + [cnet_2_model, canvas_width, canvas_height], + [ + cnet_2_image, + cnet_2_output, + canvas_width, + canvas_height, + make_canvas, + ], + ) + cnet_2.click( fn=( - lambda a, b, c, s, i: cnet_preview( - a, b, c, 1, s, i - ) + lambda a, b, s, i: cnet_preview(a, b, 1, s, i) ), inputs=[ - cnet_2, cnet_2_model, cnet_2_image, stencils, diff --git a/apps/stable_diffusion/web/ui/inpaint_ui.py b/apps/stable_diffusion/web/ui/inpaint_ui.py index 97e643d259..3cc3fc3c96 100644 --- a/apps/stable_diffusion/web/ui/inpaint_ui.py +++ b/apps/stable_diffusion/web/ui/inpaint_ui.py @@ -290,8 +290,7 @@ def inpaint_inf( inpaint_init_image = gr.Image( label="Masked Image", - source="upload", - tool="sketch", + sources="upload", type="pil", height=350, ) diff --git a/apps/stable_diffusion/web/ui/model_manager.py b/apps/stable_diffusion/web/ui/model_manager.py index 16e812b4ed..11e01fe873 100644 --- a/apps/stable_diffusion/web/ui/model_manager.py +++ b/apps/stable_diffusion/web/ui/model_manager.py @@ -104,7 +104,6 @@ def get_image_from_model(model_json): civit_models = gr.Gallery( label="Civitai Model Gallery", value=None, - interactive=True, visible=False, ) diff --git a/apps/stable_diffusion/web/ui/outputgallery_ui.py b/apps/stable_diffusion/web/ui/outputgallery_ui.py index f5b96a6808..711b70a7e0 100644 --- 
a/apps/stable_diffusion/web/ui/outputgallery_ui.py +++ b/apps/stable_diffusion/web/ui/outputgallery_ui.py @@ -95,7 +95,7 @@ def output_subdirs() -> list[str]: ) with gr.Column(scale=4): - with gr.Box(): + with gr.Group(): with gr.Row(): with gr.Column( scale=15, @@ -152,6 +152,7 @@ def output_subdirs() -> list[str]: wrap=True, elem_classes="output_parameters_dataframe", value=[["Status", "No image selected"]], + interactive=True, ) with gr.Accordion(label="Send To", open=True): @@ -162,6 +163,12 @@ def output_subdirs() -> list[str]: elem_classes="outputgallery_sendto", size="sm", ) + outputgallery_sendto_txt2img_sdxl = gr.Button( + value="Txt2Img XL", + interactive=False, + elem_classes="outputgallery_sendto", + size="sm", + ) outputgallery_sendto_img2img = gr.Button( value="Img2Img", @@ -195,17 +202,17 @@ def output_subdirs() -> list[str]: def on_clear_gallery(): return [ - gr.Gallery.update( + gr.Gallery( value=[], visible=False, ), - gr.Image.update( + gr.Image( visible=True, ), ] def on_image_columns_change(columns): - return gr.Gallery.update(columns=columns) + return gr.Gallery(columns=columns) def on_select_subdir(subdir) -> list: # evt.value is the subdirectory name @@ -215,12 +222,12 @@ def on_select_subdir(subdir) -> list: ) return [ new_images, - gr.Gallery.update( + gr.Gallery( value=new_images, label=new_label, visible=len(new_images) > 0, ), - gr.Image.update( + gr.Image( label=new_label, visible=len(new_images) == 0, ), @@ -254,16 +261,16 @@ def on_refresh(current_subdir: str) -> list: ) return [ - gr.Dropdown.update( + gr.Dropdown( choices=refreshed_subdirs, value=new_subdir, ), refreshed_subdirs, new_images, - gr.Gallery.update( + gr.Gallery( value=new_images, label=new_label, visible=len(new_images) > 0 ), - gr.Image.update( + gr.Image( label=new_label, visible=len(new_images) == 0, ), @@ -289,12 +296,12 @@ def on_new_image(subdir, subdir_paths, status) -> list: return [ new_images, - gr.Gallery.update( + gr.Gallery( value=new_images, label=new_label, visible=len(new_images) > 0, ), - gr.Image.update( + gr.Image( label=new_label, visible=len(new_images) == 0, ), @@ -332,12 +339,12 @@ def on_outputgallery_filename_change(filename: str) -> list: return [ # disable or enable each of the sendto button based on whether # an image is selected - gr.Button.update(interactive=exists), - gr.Button.update(interactive=exists), - gr.Button.update(interactive=exists), - gr.Button.update(interactive=exists), - gr.Button.update(interactive=exists), - gr.Button.update(interactive=exists), + gr.Button(interactive=exists), + gr.Button(interactive=exists), + gr.Button(interactive=exists), + gr.Button(interactive=exists), + gr.Button(interactive=exists), + gr.Button(interactive=exists), ] # The time first our tab is selected we need to do an initial refresh @@ -414,6 +421,7 @@ def on_select_tab(subdir_paths, request: gr.Request): [outputgallery_filename], [ outputgallery_sendto_txt2img, + outputgallery_sendto_txt2img_sdxl, outputgallery_sendto_img2img, outputgallery_sendto_inpaint, outputgallery_sendto_outpaint, diff --git a/apps/stable_diffusion/web/ui/stablelm_ui.py b/apps/stable_diffusion/web/ui/stablelm_ui.py index 0df3f8442d..f3baa3c5ce 100644 --- a/apps/stable_diffusion/web/ui/stablelm_ui.py +++ b/apps/stable_diffusion/web/ui/stablelm_ui.py @@ -431,8 +431,8 @@ def view_json_file(file_obj): config_file = gr.File( label="Upload sharding configuration", visible=False ) - json_view_button = gr.Button(label="View as JSON", visible=False) - json_view = gr.JSON(interactive=True, visible=False) 
+ json_view_button = gr.Button(value="View as JSON", visible=False) + json_view = gr.JSON(visible=False) json_view_button.click( fn=view_json_file, inputs=[config_file], outputs=[json_view] ) diff --git a/apps/stable_diffusion/web/ui/txt2img_sdxl_ui.py b/apps/stable_diffusion/web/ui/txt2img_sdxl_ui.py index 7bcbb3c147..c3a653bd13 100644 --- a/apps/stable_diffusion/web/ui/txt2img_sdxl_ui.py +++ b/apps/stable_diffusion/web/ui/txt2img_sdxl_ui.py @@ -11,9 +11,11 @@ get_custom_model_path, get_custom_model_files, scheduler_list, - predefined_models, + predefined_sdxl_models, cancel_sd, + set_model_default_configs, ) +from apps.stable_diffusion.web.ui.common_ui_events import lora_changed from apps.stable_diffusion.web.utils.metadata import import_png_metadata from apps.stable_diffusion.web.utils.common_label_calc import status_label from apps.stable_diffusion.src import ( @@ -50,17 +52,17 @@ def txt2img_sdxl_inf( batch_size: int, scheduler: str, model_id: str, + custom_vae: str, precision: str, device: str, max_length: int, save_metadata_to_json: bool, save_metadata_to_png: bool, + lora_weights: str, + lora_hf_id: str, ondemand: bool, repeatable_seeds: bool, ): - if precision != "fp16": - print("currently we support fp16 for SDXL") - precision = "fp16" from apps.stable_diffusion.web.ui.utils import ( get_custom_model_pathfile, get_custom_vae_or_lora_weights, @@ -71,6 +73,10 @@ def txt2img_sdxl_inf( SD_STATE_CANCEL, ) + if precision != "fp16": + print("currently we support fp16 for SDXL") + precision = "fp16" + args.prompts = [prompt] args.negative_prompts = [negative_prompt] args.guidance_scale = guidance_scale @@ -93,13 +99,15 @@ def txt2img_sdxl_inf( else: args.hf_model_id = model_id - # if custom_vae != "None": - # args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae") + if custom_vae: + args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae") args.save_metadata_to_json = save_metadata_to_json args.write_metadata_to_png = save_metadata_to_png - args.use_lora = "" + args.use_lora = get_custom_vae_or_lora_weights( + lora_weights, lora_hf_id, "lora" + ) dtype = torch.float32 if precision == "fp32" else torch.half cpu_scheduling = not scheduler.startswith("Shark") @@ -115,7 +123,7 @@ def txt2img_sdxl_inf( width, device, use_lora=args.use_lora, - use_stencil=None, + stencils=None, ondemand=ondemand, ) if ( @@ -144,31 +152,29 @@ def txt2img_sdxl_inf( ) global_obj.set_schedulers(get_schedulers(model_id)) scheduler_obj = global_obj.get_scheduler(scheduler) - # For SDXL we set max_length as 77. 
- print("Setting max_length = 77") - max_length = 77 if global_obj.get_cfg_obj().ondemand: print("Running txt2img in memory efficient mode.") - txt2img_sdxl_obj = Text2ImageSDXLPipeline.from_pretrained( - scheduler=scheduler_obj, - import_mlir=args.import_mlir, - model_id=args.hf_model_id, - ckpt_loc=args.ckpt_loc, - precision=precision, - max_length=max_length, - batch_size=batch_size, - height=height, - width=width, - use_base_vae=args.use_base_vae, - use_tuned=args.use_tuned, - custom_vae=args.custom_vae, - low_cpu_mem_usage=args.low_cpu_mem_usage, - debug=args.import_debug if args.import_mlir else False, - use_lora=args.use_lora, - use_quantize=args.use_quantize, - ondemand=global_obj.get_cfg_obj().ondemand, + global_obj.set_sd_obj( + Text2ImageSDXLPipeline.from_pretrained( + scheduler=scheduler_obj, + import_mlir=args.import_mlir, + model_id=args.hf_model_id, + ckpt_loc=args.ckpt_loc, + precision=precision, + max_length=max_length, + batch_size=batch_size, + height=height, + width=width, + use_base_vae=args.use_base_vae, + use_tuned=args.use_tuned, + custom_vae=args.custom_vae, + low_cpu_mem_usage=args.low_cpu_mem_usage, + debug=args.import_debug if args.import_mlir else False, + use_lora=args.use_lora, + use_quantize=args.use_quantize, + ondemand=global_obj.get_cfg_obj().ondemand, + ) ) - global_obj.set_sd_obj(txt2img_sdxl_obj) global_obj.set_sd_scheduler(scheduler) @@ -239,7 +245,7 @@ def txt2img_sdxl_inf( with gr.Row(): with gr.Column(scale=10): with gr.Row(): - t2i_model_info = f"Custom Model Path: {str(get_custom_model_path())}" + t2i_sdxl_model_info = f"Custom Model Path: {str(get_custom_model_path())}" txt2img_sdxl_custom_model = gr.Dropdown( label=f"Models", info="Select, or enter HuggingFace Model ID or Civitai model download URL", @@ -247,12 +253,39 @@ def txt2img_sdxl_inf( value=os.path.basename(args.ckpt_loc) if args.ckpt_loc else "stabilityai/stable-diffusion-xl-base-1.0", - choices=[ - "stabilityai/stable-diffusion-xl-base-1.0" - ], + choices=predefined_sdxl_models + + get_custom_model_files( + custom_checkpoint_type="sdxl" + ), allow_custom_value=True, scale=2, ) + t2i_sdxl_vae_info = ( + str(get_custom_model_path("vae")) + ).replace("\\", "\n\\") + t2i_sdxl_vae_info = ( + f"VAE Path: {t2i_sdxl_vae_info}" + ) + custom_vae = gr.Dropdown( + label=f"VAE Models", + info=t2i_sdxl_vae_info, + elem_id="custom_model", + value="None", + choices=[ + None, + "madebyollin/sdxl-vae-fp16-fix", + ] + + get_custom_model_files("vae"), + allow_custom_value=True, + scale=1, + ) + with gr.Column(scale=1, min_width=170): + txt2img_sdxl_png_info_img = gr.Image( + label="Import PNG info", + elem_id="txt2img_prompt_image", + type="pil", + visible=True, + ) with gr.Group(elem_id="prompt_box_outer"): prompt = gr.Textbox( @@ -267,16 +300,49 @@ def txt2img_sdxl_inf( lines=2, elem_id="negative_prompt_box", ) - + with gr.Accordion(label="LoRA Options", open=False): + with gr.Row(): + # janky fix for overflowing text + t2i_sdxl_lora_info = ( + str(get_custom_model_path("lora")) + ).replace("\\", "\n\\") + t2i_sdxl_lora_info = f"LoRA Path: {t2i_sdxl_lora_info}" + lora_weights = gr.Dropdown( + label=f"Standalone LoRA Weights", + info=t2i_sdxl_lora_info, + elem_id="lora_weights", + value="None", + choices=["None"] + get_custom_model_files("lora"), + allow_custom_value=True, + ) + lora_hf_id = gr.Textbox( + elem_id="lora_hf_id", + placeholder="Select 'None' in the Standalone LoRA " + "weights dropdown on the left if you want to use " + "a standalone HuggingFace model ID for LoRA here " + "e.g: 
sayakpaul/sd-model-finetuned-lora-t4",
+                        value="",
+                        label="HuggingFace Model ID",
+                        lines=3,
+                    )
+                with gr.Row():
+                    lora_tags = gr.HTML(
+                        value="
No LoRA selected
", + elem_classes="lora-tags", + ) with gr.Accordion(label="Advanced Options", open=False): with gr.Row(): scheduler = gr.Dropdown( elem_id="scheduler", label="Scheduler", - value="DDIM", - choices=["DDIM"], - allow_custom_value=True, - visible=False, + value=args.scheduler, + choices=[ + "DDIM", + "EulerAncestralDiscrete", + "EulerDiscrete", + ], + allow_custom_value=False, + visible=True, ) with gr.Column(): save_metadata_to_png = gr.Checkbox( @@ -291,18 +357,22 @@ def txt2img_sdxl_inf( ) with gr.Row(): height = gr.Slider( + 512, 1024, value=1024, - step=8, + step=512, label="Height", - visible=False, + visible=True, + interactive=True, ) width = gr.Slider( + 512, 1024, value=1024, - step=8, + step=512, label="Width", - visible=False, + visible=True, + interactive=True, ) precision = gr.Radio( label="Precision", @@ -315,7 +385,7 @@ def txt2img_sdxl_inf( ) max_length = gr.Radio( label="Max Length", - value=args.max_length, + value=77, choices=[ 64, 77, @@ -333,7 +403,7 @@ def txt2img_sdxl_inf( 50, value=args.guidance_scale, step=0.1, - label="CFG Scale", + label="Guidance Scale", ) ondemand = gr.Checkbox( value=args.ondemand, @@ -391,10 +461,10 @@ def txt2img_sdxl_inf( show_label=False, elem_id="gallery", columns=[2], - object_fit="contain", + object_fit="scale_down", ) std_output = gr.Textbox( - value=f"{t2i_model_info}\n" + value=f"{t2i_sdxl_model_info}\n" f"Images will be saved at " f"{get_generated_imgs_path()}", lines=1, @@ -413,7 +483,18 @@ def txt2img_sdxl_inf( ) stop_batch = gr.Button("Stop Batch") with gr.Row(): - blank_thing_for_row = None + txt2img_sdxl_sendto_img2img = gr.Button( + value="Send To Img2Img" + ) + txt2img_sdxl_sendto_inpaint = gr.Button( + value="Send To Inpaint" + ) + txt2img_sdxl_sendto_outpaint = gr.Button( + value="Send To Outpaint" + ) + txt2img_sdxl_sendto_upscaler = gr.Button( + value="Send To Upscaler" + ) kwargs = dict( fn=txt2img_sdxl_inf, @@ -429,11 +510,14 @@ def txt2img_sdxl_inf( batch_size, scheduler, txt2img_sdxl_custom_model, + custom_vae, precision, device, max_length, save_metadata_to_json, save_metadata_to_png, + lora_weights, + lora_hf_id, ondemand, repeatable_seeds, ], @@ -456,3 +540,59 @@ def txt2img_sdxl_inf( fn=cancel_sd, cancels=[prompt_submit, neg_prompt_submit, generate_click], ) + + txt2img_sdxl_png_info_img.change( + fn=import_png_metadata, + inputs=[ + txt2img_sdxl_png_info_img, + prompt, + negative_prompt, + steps, + scheduler, + guidance_scale, + seed, + width, + height, + txt2img_sdxl_custom_model, + lora_weights, + lora_hf_id, + custom_vae, + ], + outputs=[ + txt2img_sdxl_png_info_img, + prompt, + negative_prompt, + steps, + scheduler, + guidance_scale, + seed, + width, + height, + txt2img_sdxl_custom_model, + lora_weights, + lora_hf_id, + custom_vae, + ], + ) + txt2img_sdxl_custom_model.change( + fn=set_model_default_configs, + inputs=[ + txt2img_sdxl_custom_model, + ], + outputs=[ + prompt, + negative_prompt, + steps, + scheduler, + guidance_scale, + width, + height, + custom_vae, + ], + ) + lora_weights.change( + fn=lora_changed, + inputs=[lora_weights], + outputs=[lora_tags], + queue=True, + ) diff --git a/apps/stable_diffusion/web/ui/txt2img_ui.py b/apps/stable_diffusion/web/ui/txt2img_ui.py index 6a8ee91664..d6b4abd03a 100644 --- a/apps/stable_diffusion/web/ui/txt2img_ui.py +++ b/apps/stable_diffusion/web/ui/txt2img_ui.py @@ -281,6 +281,7 @@ def txt2img_inf( cpu_scheduling, args.max_embeddings_multiples, stencils=[], + control_mode=None, resample_type=resample_type, ) total_time = time.time() - start_time @@ -302,7 +303,17 
@@ def txt2img_inf( return generated_imgs, text_output, "" -with gr.Blocks(title="Text-to-Image") as txt2img_web: +def resource_path(relative_path): + """Get absolute path to resource, works for dev and for PyInstaller""" + base_path = getattr( + sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__)) + ) + return os.path.join(base_path, relative_path) + + +dark_theme = resource_path("ui/css/sd_dark_theme.css") + +with gr.Blocks(title="Text-to-Image", css=dark_theme) as txt2img_web: with gr.Row(elem_id="ui_title"): nod_logo = Image.open(nodlogo_loc) with gr.Row(): @@ -356,7 +367,6 @@ def txt2img_inf( label="Import PNG info", elem_id="txt2img_prompt_image", type="pil", - tool="None", visible=True, ) @@ -367,6 +377,11 @@ def txt2img_inf( lines=2, elem_id="prompt_box", ) + # TODO: coming soon + autogen = gr.Checkbox( + label="Continuous Generation", + visible=False, + ) negative_prompt = gr.Textbox( label="Negative Prompt", value=args.negative_prompts[0], @@ -680,12 +695,12 @@ def txt2img_inf( # SharkEulerDiscrete doesn't work with img2img which hires_fix uses def set_compatible_schedulers(hires_fix_selected): if hires_fix_selected: - return gr.Dropdown.update( + return gr.Dropdown( choices=scheduler_list_cpu_only, value="DEISMultistep", ) else: - return gr.Dropdown.update( + return gr.Dropdown( choices=scheduler_list, value="SharkEulerDiscrete", ) diff --git a/apps/stable_diffusion/web/ui/utils.py b/apps/stable_diffusion/web/ui/utils.py index 08be24633b..3770f98203 100644 --- a/apps/stable_diffusion/web/ui/utils.py +++ b/apps/stable_diffusion/web/ui/utils.py @@ -4,6 +4,7 @@ import math import json import safetensors +import gradio as gr from pathlib import Path from apps.stable_diffusion.src import args @@ -64,9 +65,11 @@ class HSLHue(IntEnum): "DPMSolverSinglestep", "DDPM", "HeunDiscrete", + "LCMScheduler", ] scheduler_list = scheduler_list_cpu_only + [ "SharkEulerDiscrete", + "SharkEulerAncestralDiscrete", ] predefined_models = [ @@ -87,6 +90,10 @@ class HSLHue(IntEnum): predefined_upscaler_models = [ "stabilityai/stable-diffusion-x4-upscaler", ] +predefined_sdxl_models = [ + "stabilityai/sdxl-turbo", + "stabilityai/stable-diffusion-xl-base-1.0", +] def resource_path(relative_path): @@ -140,6 +147,12 @@ def get_custom_model_files(model="models", custom_checkpoint_type=""): ) ] match custom_checkpoint_type: + case "sdxl": + files = [ + val + for val in files + if any(x in val for x in ["XL", "xl", "Xl"]) + ] case "inpainting": files = [ val @@ -247,6 +260,84 @@ def cancel_sd(): pass +def set_model_default_configs(model_ckpt_or_id, jsonconfig=None): + import gradio as gr + + config_modelname = default_config_exists(model_ckpt_or_id) + if jsonconfig: + return get_config_from_json(jsonconfig) + elif config_modelname: + return default_configs[config_modelname] + # TODO: Use HF metadata to setup pipeline if available + # elif is_valid_hf_id(model_ckpt_or_id): + # return get_HF_default_configs(model_ckpt_or_id) + else: + # We don't have default metadata to setup a good config. Do not change configs. + return [ + gr.Textbox(label="Prompt", interactive=True, visible=True), + gr.Textbox(label="Negative Prompt", interactive=True), + gr.update(), + gr.update(), + gr.update(), + gr.update(), + gr.update(), + gr.update(), + ] + + +def get_config_from_json(model_ckpt_or_id, jsonconfig): + # TODO: make this work properly. It is currently not user-exposed. 
+    cfgdata = json.load(jsonconfig)
+    return [
+        cfgdata["prompt_box_behavior"],
+        cfgdata["neg_prompt_box_behavior"],
+        cfgdata["steps"],
+        cfgdata["scheduler"],
+        cfgdata["guidance_scale"],
+        cfgdata["width"],
+        cfgdata["height"],
+        cfgdata["custom_vae"],
+    ]
+
+
+def default_config_exists(model_ckpt_or_id):
+    if model_ckpt_or_id in [
+        "stabilityai/sdxl-turbo",
+        "stabilityai/stable-diffusion-xl-base-1.0",
+    ]:
+        return model_ckpt_or_id
+    elif "turbo" in model_ckpt_or_id.lower():
+        return "stabilityai/sdxl-turbo"
+    else:
+        return None
+
+
+default_configs = {
+    "stabilityai/sdxl-turbo": [
+        gr.Textbox(label="", interactive=False, value=None, visible=False),
+        gr.Textbox(
+            label="Prompt",
+            value="role-playing game (RPG) style fantasy, An enchanting image featuring an adorable kitten mage wearing intricate ancient robes, holding an ancient staff, hard at work in her fantastical workshop, magic runes floating in the air",
+        ),
+        gr.Slider(0, 10, value=2),
+        gr.Dropdown(value="EulerAncestralDiscrete"),
+        gr.Slider(0, value=0),
+        512,
+        512,
+        "madebyollin/sdxl-vae-fp16-fix",
+    ],
+    "stabilityai/stable-diffusion-xl-base-1.0": [
+        gr.Textbox(label="Prompt", interactive=True, visible=True),
+        gr.Textbox(label="Negative Prompt", interactive=True),
+        40,
+        "EulerDiscrete",
+        7.5,
+        gr.Slider(value=1024, interactive=False),
+        gr.Slider(value=1024, interactive=False),
+        "madebyollin/sdxl-vae-fp16-fix",
+    ],
+}
+
 nodlogo_loc = resource_path("logos/nod-logo.png")
 nodicon_loc = resource_path("logos/nod-icon.png")
 available_devices = get_available_devices()
diff --git a/requirements.txt b/requirements.txt
index 76498241d4..a97baa83a3 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -26,7 +26,7 @@ diffusers
 accelerate
 scipy
 ftfy
-gradio==3.44.3
+gradio==4.7.1
 altair
 omegaconf
 # 0.3.2 doesn't have binaries for arm64
diff --git a/shark/iree_utils/vulkan_utils.py b/shark/iree_utils/vulkan_utils.py
index 859c4ba833..a08fb6f5aa 100644
--- a/shark/iree_utils/vulkan_utils.py
+++ b/shark/iree_utils/vulkan_utils.py
@@ -183,6 +183,9 @@ def get_iree_vulkan_args(device_num=0, extra_args=[]):

     # res_vulkan_flag = ["--iree-flow-demote-i64-to-i32"]
     res_vulkan_flag = []
+    res_vulkan_flag += [
+        "--iree-stream-resource-max-allocation-size=3221225472"
+    ]
     vulkan_triple_flag = None
     for arg in extra_args:
         if "-iree-vulkan-target-triple=" in arg:
@@ -204,7 +207,9 @@ def get_iree_vulkan_args(device_num=0, extra_args=[]):

 @functools.cache
 def get_iree_vulkan_runtime_flags():
     vulkan_runtime_flags = [
-        f"--vulkan_validation_layers={'true' if shark_args.vulkan_validation_layers else 'false'}",
+        f"--vulkan_validation_layers={'true' if shark_args.vulkan_debug_utils else 'false'}",
+        f"--vulkan_debug_verbosity={'4' if shark_args.vulkan_debug_utils else '0'}",
+        f"--vulkan-robust-buffer-access={'true' if shark_args.vulkan_debug_utils else 'false'}",
     ]
     return vulkan_runtime_flags