diff --git a/apps/stable_diffusion/shark_sd.spec b/apps/stable_diffusion/shark_sd.spec index 0a25318086..07b4a81a24 100644 --- a/apps/stable_diffusion/shark_sd.spec +++ b/apps/stable_diffusion/shark_sd.spec @@ -19,6 +19,9 @@ a = Analysis( win_private_assemblies=False, cipher=block_cipher, noarchive=False, + module_collection_mode={ + 'gradio': 'py', # Collect gradio package as source .py files + }, ) pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher) diff --git a/apps/stable_diffusion/shark_studio_imports.py b/apps/stable_diffusion/shark_studio_imports.py index 70874134d6..dfb5bf8027 100644 --- a/apps/stable_diffusion/shark_studio_imports.py +++ b/apps/stable_diffusion/shark_studio_imports.py @@ -31,6 +31,7 @@ datas += copy_metadata("sentencepiece") datas += copy_metadata("pyyaml") datas += copy_metadata("huggingface-hub") +datas += copy_metadata("gradio") datas += collect_data_files("torch") datas += collect_data_files("tokenizers") datas += collect_data_files("tiktoken") @@ -75,6 +76,7 @@ # hidden imports for pyinstaller hiddenimports = ["shark", "shark.shark_inference", "apps"] hiddenimports += [x for x in collect_submodules("skimage") if "tests" not in x] +hiddenimports += [x for x in collect_submodules("gradio") if "tests" not in x] hiddenimports += [ x for x in collect_submodules("diffusers") if "tests" not in x ] @@ -85,4 +87,4 @@ if not any(kw in x for kw in blacklist) ] hiddenimports += [x for x in collect_submodules("iree") if "tests" not in x] -hiddenimports += ["iree._runtime", "iree.compiler._mlir_libs._mlir.ir"] +hiddenimports += ["iree._runtime"] diff --git a/apps/stable_diffusion/src/models/model_wrappers.py b/apps/stable_diffusion/src/models/model_wrappers.py index b762b85b2c..863b6434b0 100644 --- a/apps/stable_diffusion/src/models/model_wrappers.py +++ b/apps/stable_diffusion/src/models/model_wrappers.py @@ -109,7 +109,7 @@ def process_vmfb_ir_sdxl(extended_model_name, model_name, device, precision): if "vulkan" in device: _device = args.iree_vulkan_target_triple _device = _device.replace("-", "_") - vmfb_path = Path(extended_model_name_for_vmfb + f"_{_device}.vmfb") + vmfb_path = Path(extended_model_name_for_vmfb + f"_vulkan.vmfb") if vmfb_path.exists(): shark_module = SharkInference( None, @@ -436,24 +436,48 @@ def __init__( super().__init__() self.vae = None if custom_vae == "": + print(f"Loading default vae, with target {model_id}") self.vae = AutoencoderKL.from_pretrained( model_id, subfolder="vae", low_cpu_mem_usage=low_cpu_mem_usage, ) elif not isinstance(custom_vae, dict): - self.vae = AutoencoderKL.from_pretrained( - custom_vae, - subfolder="vae", - low_cpu_mem_usage=low_cpu_mem_usage, - ) + precision = "fp16" if "fp16" in custom_vae else None + print(f"Loading custom vae, with target {custom_vae}") + if os.path.exists(custom_vae): + self.vae = AutoencoderKL.from_pretrained( + custom_vae, + low_cpu_mem_usage=low_cpu_mem_usage, + ) + else: + custom_vae = "/".join( + [ + custom_vae.split("/")[-2].split("\\")[-1], + custom_vae.split("/")[-1], + ] + ) + print("Using hub to get custom vae") + try: + self.vae = AutoencoderKL.from_pretrained( + custom_vae, + low_cpu_mem_usage=low_cpu_mem_usage, + variant=precision, + ) + except: + self.vae = AutoencoderKL.from_pretrained( + custom_vae, + low_cpu_mem_usage=low_cpu_mem_usage, + ) else: + print(f"Loading custom vae, with state {custom_vae}") self.vae = AutoencoderKL.from_pretrained( model_id, subfolder="vae", low_cpu_mem_usage=low_cpu_mem_usage, ) self.vae.load_state_dict(custom_vae) + self.base_vae = base_vae def forward(self, latents): image = self.vae.decode(latents / 0.13025, return_dict=False)[ @@ -465,7 +489,12 @@ def forward(self, latents): inputs = tuple(self.inputs["vae"]) # Make sure the VAE is in float32 mode, as it overflows in float16 as per SDXL # pipeline. - is_f16 = False + if not self.custom_vae: + is_f16 = False + elif "16" in self.custom_vae: + is_f16 = True + else: + is_f16 = False save_dir = os.path.join(self.sharktank_dir, self.model_name["vae"]) if self.debug: os.makedirs(save_dir, exist_ok=True) @@ -917,11 +946,19 @@ def __init__( low_cpu_mem_usage=False, ): super().__init__() - self.unet = UNet2DConditionModel.from_pretrained( - model_id, - subfolder="unet", - low_cpu_mem_usage=low_cpu_mem_usage, - ) + try: + self.unet = UNet2DConditionModel.from_pretrained( + model_id, + subfolder="unet", + low_cpu_mem_usage=low_cpu_mem_usage, + variant="fp16", + ) + except: + self.unet = UNet2DConditionModel.from_pretrained( + model_id, + subfolder="unet", + low_cpu_mem_usage=low_cpu_mem_usage, + ) if ( args.attention_slicing is not None and args.attention_slicing != "none" diff --git a/apps/stable_diffusion/src/models/opt_params.py b/apps/stable_diffusion/src/models/opt_params.py index 5dd59b006e..f03897de6e 100644 --- a/apps/stable_diffusion/src/models/opt_params.py +++ b/apps/stable_diffusion/src/models/opt_params.py @@ -123,7 +123,10 @@ def get_clip(): return get_shark_model(bucket, model_name, iree_flags) -def get_tokenizer(subfolder="tokenizer"): +def get_tokenizer(subfolder="tokenizer", hf_model_id=None): + if hf_model_id is not None: + args.hf_model_id = hf_model_id + tokenizer = CLIPTokenizer.from_pretrained( args.hf_model_id, subfolder=subfolder ) diff --git a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_img2img.py b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_img2img.py index f340655e5e..ba1d52fd80 100644 --- a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_img2img.py +++ b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_img2img.py @@ -158,7 +158,10 @@ def generate_images( use_base_vae, cpu_scheduling, max_embeddings_multiples, + stencils, + images, resample_type, + control_mode, ): # prompts and negative prompts must be a list. if isinstance(prompts, str): diff --git a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_txt2img.py b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_txt2img.py index 07b827a771..51a790bb41 100644 --- a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_txt2img.py +++ b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_txt2img.py @@ -18,7 +18,10 @@ KDPM2AncestralDiscreteScheduler, HeunDiscreteScheduler, ) -from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler +from apps.stable_diffusion.src.schedulers import ( + SharkEulerDiscreteScheduler, + SharkEulerAncestralDiscreteScheduler, +) from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import ( StableDiffusionPipeline, ) diff --git a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_txt2img_sdxl.py b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_txt2img_sdxl.py index 42df60af9a..a3b52793e9 100644 --- a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_txt2img_sdxl.py +++ b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_txt2img_sdxl.py @@ -16,7 +16,10 @@ KDPM2AncestralDiscreteScheduler, HeunDiscreteScheduler, ) -from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler +from apps.stable_diffusion.src.schedulers import ( + SharkEulerDiscreteScheduler, + SharkEulerAncestralDiscreteScheduler, +) from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import ( StableDiffusionPipeline, ) @@ -38,6 +41,7 @@ def __init__( EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler, SharkEulerDiscreteScheduler, + SharkEulerAncestralDiscreteScheduler, DEISMultistepScheduler, DDPMScheduler, DPMSolverSinglestepScheduler, @@ -48,8 +52,10 @@ def __init__( import_mlir: bool, use_lora: str, ondemand: bool, + is_fp32_vae: bool, ): super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand) + self.is_fp32_vae = is_fp32_vae def prepare_latents( self, @@ -203,10 +209,10 @@ def generate_images( # Img latents -> PIL images. all_imgs = [] self.load_vae() - # imgs = self.decode_latents_sdxl(None) - # all_imgs.extend(imgs) for i in range(0, latents.shape[0], batch_size): - imgs = self.decode_latents_sdxl(latents[i : i + batch_size]) + imgs = self.decode_latents_sdxl( + latents[i : i + batch_size], is_fp32_vae=self.is_fp32_vae + ) all_imgs.extend(imgs) if self.ondemand: self.unload_vae() diff --git a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_utils.py b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_utils.py index 7a6f763e24..7d3d6b1c5b 100644 --- a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_utils.py +++ b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_utils.py @@ -20,7 +20,10 @@ HeunDiscreteScheduler, ) from shark.shark_inference import SharkInference -from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler +from apps.stable_diffusion.src.schedulers import ( + SharkEulerDiscreteScheduler, + SharkEulerAncestralDiscreteScheduler, +) from apps.stable_diffusion.src.models import ( SharkifyStableDiffusionModel, get_vae, @@ -52,6 +55,7 @@ def __init__( EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler, SharkEulerDiscreteScheduler, + SharkEulerAncestralDiscreteScheduler, DEISMultistepScheduler, DDPMScheduler, DPMSolverSinglestepScheduler, @@ -62,6 +66,7 @@ def __init__( import_mlir: bool, use_lora: str, ondemand: bool, + is_f32_vae: bool = False, ): self.vae = None self.text_encoder = None @@ -69,14 +74,15 @@ def __init__( self.unet = None self.unet_512 = None self.model_max_length = 77 - self.scheduler = scheduler # TODO: Implement using logging python utility. self.log = "" self.status = SD_STATE_IDLE self.sd_model = sd_model + self.scheduler = scheduler self.import_mlir = import_mlir self.use_lora = use_lora self.ondemand = ondemand + self.is_f32_vae = is_f32_vae # TODO: Find a better workaround for fetching base_model_id early # enough for CLIPTokenizer. try: @@ -202,6 +208,9 @@ def encode_prompt_sdxl( negative_prompt_embeds: Optional[torch.FloatTensor] = None, pooled_prompt_embeds: Optional[torch.FloatTensor] = None, negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + hf_model_id: Optional[ + str + ] = "stabilityai/stable-diffusion-xl-base-1.0", ): if prompt is not None and isinstance(prompt, str): batch_size = 1 @@ -211,7 +220,7 @@ def encode_prompt_sdxl( batch_size = prompt_embeds.shape[0] # Define tokenizers and text encoders - self.tokenizer_2 = get_tokenizer("tokenizer_2") + self.tokenizer_2 = get_tokenizer("tokenizer_2", hf_model_id) self.load_clip_sdxl() tokenizers = ( [self.tokenizer, self.tokenizer_2] @@ -332,7 +341,7 @@ def encode_prompt_sdxl( gc.collect() # TODO: Look into dtype for text_encoder_2! - prompt_embeds = prompt_embeds.to(dtype=torch.float32) + prompt_embeds = prompt_embeds.to(dtype=torch.float16) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) @@ -523,6 +532,9 @@ def produce_img_latents_sdxl( cpu_scheduling, guidance_scale, dtype, + mask=None, + masked_image_latents=None, + return_all_latents=False, ): # return None self.status = SD_STATE_IDLE @@ -533,11 +545,22 @@ def produce_img_latents_sdxl( step_start_time = time.time() timestep = torch.tensor([t]).to(dtype).detach().numpy() # expand the latents if we are doing classifier free guidance + if isinstance(latents, np.ndarray): + latents = torch.tensor(latents) latent_model_input = torch.cat([latents] * 2) latent_model_input = self.scheduler.scale_model_input( latent_model_input, t - ).to(dtype) + ) + if mask is not None and masked_image_latents is not None: + latent_model_input = torch.cat( + [ + torch.from_numpy(np.asarray(latent_model_input)), + mask, + masked_image_latents, + ], + dim=1, + ).to(dtype) noise_pred = self.unet( "forward", @@ -549,11 +572,17 @@ def produce_img_latents_sdxl( add_time_ids, guidance_scale, ), - send_to_host=False, + send_to_host=True, ) + if not isinstance(latents, torch.Tensor): + latents = torch.from_numpy(latents).to("cpu") + noise_pred = torch.from_numpy(noise_pred).to("cpu") + latents = self.scheduler.step( noise_pred, t, latents, **extra_step_kwargs, return_dict=False )[0] + latents = latents.detach().numpy() + noise_pred = noise_pred.detach().numpy() step_time = (time.time() - step_start_time) * 1000 step_time_sum += step_time @@ -569,11 +598,15 @@ def produce_img_latents_sdxl( return latents - def decode_latents_sdxl(self, latents): - latents = latents.to(torch.float32) + def decode_latents_sdxl(self, latents, is_fp32_vae): + # latents are in unet dtype here so switch if we want to use fp32 + if is_fp32_vae: + print("Casting latents to float32 for VAE") + latents = latents.to(torch.float32) images = self.vae("forward", (latents,)) images = (torch.from_numpy(images) / 2 + 0.5).clamp(0, 1) images = images.cpu().permute(0, 2, 3, 1).float().numpy() + images = (images * 255).round().astype("uint8") pil_images = [Image.fromarray(image[:, :, :3]) for image in images] @@ -666,6 +699,17 @@ def from_pretrained( return cls( scheduler, sd_model, import_mlir, use_lora, ondemand, stencils ) + if cls.__name__ == "Text2ImageSDXLPipeline": + is_fp32_vae = True if "16" not in custom_vae else False + return cls( + scheduler, + sd_model, + import_mlir, + use_lora, + ondemand, + is_fp32_vae, + ) + return cls(scheduler, sd_model, import_mlir, use_lora, ondemand) # ##################################################### diff --git a/apps/stable_diffusion/src/schedulers/__init__.py b/apps/stable_diffusion/src/schedulers/__init__.py index 8a939b0f05..4e6d8db9e1 100644 --- a/apps/stable_diffusion/src/schedulers/__init__.py +++ b/apps/stable_diffusion/src/schedulers/__init__.py @@ -2,3 +2,6 @@ from apps.stable_diffusion.src.schedulers.shark_eulerdiscrete import ( SharkEulerDiscreteScheduler, ) +from apps.stable_diffusion.src.schedulers.shark_eulerancestraldiscrete import ( + SharkEulerAncestralDiscreteScheduler, +) diff --git a/apps/stable_diffusion/src/schedulers/sd_schedulers.py b/apps/stable_diffusion/src/schedulers/sd_schedulers.py index 325047c8b0..544fa1efa1 100644 --- a/apps/stable_diffusion/src/schedulers/sd_schedulers.py +++ b/apps/stable_diffusion/src/schedulers/sd_schedulers.py @@ -15,9 +15,22 @@ from apps.stable_diffusion.src.schedulers.shark_eulerdiscrete import ( SharkEulerDiscreteScheduler, ) +from apps.stable_diffusion.src.schedulers.shark_eulerancestraldiscrete import ( + SharkEulerAncestralDiscreteScheduler, +) def get_schedulers(model_id): + # TODO: Robust scheduler setup on pipeline creation -- if we don't + # set batch_size here, the SHARK schedulers will + # compile with batch size = 1 regardless of whether the model + # outputs latents of a larger batch size, e.g. SDXL. + # This also goes towards enabling batch size cfg for SD in general. + # However, obviously, searching for whether the base model ID + # contains "xl" is not very robust. + + batch_size = 2 if "xl" in model_id.lower() else 1 + schedulers = dict() schedulers["PNDM"] = PNDMScheduler.from_pretrained( model_id, @@ -84,6 +97,12 @@ def get_schedulers(model_id): model_id, subfolder="scheduler", ) + schedulers[ + "SharkEulerAncestralDiscrete" + ] = SharkEulerAncestralDiscreteScheduler.from_pretrained( + model_id, + subfolder="scheduler", + ) schedulers[ "DPMSolverSinglestep" ] = DPMSolverSinglestepScheduler.from_pretrained( @@ -100,5 +119,6 @@ def get_schedulers(model_id): model_id, subfolder="scheduler", ) - schedulers["SharkEulerDiscrete"].compile() + schedulers["SharkEulerDiscrete"].compile(batch_size) + schedulers["SharkEulerAncestralDiscrete"].compile(batch_size) return schedulers diff --git a/apps/stable_diffusion/src/schedulers/shark_eulerancestraldiscrete.py b/apps/stable_diffusion/src/schedulers/shark_eulerancestraldiscrete.py new file mode 100644 index 0000000000..c941e56220 --- /dev/null +++ b/apps/stable_diffusion/src/schedulers/shark_eulerancestraldiscrete.py @@ -0,0 +1,251 @@ +import sys +import numpy as np +from typing import List, Optional, Tuple, Union +from diffusers import ( + EulerAncestralDiscreteScheduler, +) +from diffusers.utils.torch_utils import randn_tensor +from diffusers.configuration_utils import register_to_config +from apps.stable_diffusion.src.utils import ( + compile_through_fx, + get_shark_model, + args, +) +import torch + + +class SharkEulerAncestralDiscreteScheduler(EulerAncestralDiscreteScheduler): + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + prediction_type: str = "epsilon", + timestep_spacing: str = "linspace", + steps_offset: int = 0, + ): + super().__init__( + num_train_timesteps, + beta_start, + beta_end, + beta_schedule, + trained_betas, + prediction_type, + timestep_spacing, + steps_offset, + ) + # TODO: make it dynamic so we dont have to worry about batch size + self.batch_size = None + self.init_input_shape = None + + def compile(self, batch_size=1): + SCHEDULER_BUCKET = "gs://shark_tank/stable_diffusion/schedulers" + device = args.device.split(":", 1)[0].strip() + self.batch_size = batch_size + + model_input = { + "eulera": { + "output": torch.randn( + batch_size, 4, args.height // 8, args.width // 8 + ), + "latent": torch.randn( + batch_size, 4, args.height // 8, args.width // 8 + ), + "sigma": torch.tensor(1).to(torch.float32), + "sigma_from": torch.tensor(1).to(torch.float32), + "sigma_to": torch.tensor(1).to(torch.float32), + "noise": torch.randn( + batch_size, 4, args.height // 8, args.width // 8 + ), + }, + } + + example_latent = model_input["eulera"]["latent"] + example_output = model_input["eulera"]["output"] + example_noise = model_input["eulera"]["noise"] + if args.precision == "fp16": + example_latent = example_latent.half() + example_output = example_output.half() + example_noise = example_noise.half() + example_sigma = model_input["eulera"]["sigma"] + example_sigma_from = model_input["eulera"]["sigma_from"] + example_sigma_to = model_input["eulera"]["sigma_to"] + + class ScalingModel(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, latent, sigma): + return latent / ((sigma**2 + 1) ** 0.5) + + class SchedulerStepEpsilonModel(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward( + self, noise_pred, latent, sigma, sigma_from, sigma_to, noise + ): + sigma_up = ( + sigma_to**2 + * (sigma_from**2 - sigma_to**2) + / sigma_from**2 + ) ** 0.5 + sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5 + dt = sigma_down - sigma + pred_original_sample = latent - sigma * noise_pred + derivative = (latent - pred_original_sample) / sigma + prev_sample = latent + derivative * dt + return prev_sample + noise * sigma_up + + class SchedulerStepVPredictionModel(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward( + self, noise_pred, sigma, sigma_from, sigma_to, latent, noise + ): + sigma_up = ( + sigma_to**2 + * (sigma_from**2 - sigma_to**2) + / sigma_from**2 + ) ** 0.5 + sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5 + dt = sigma_down - sigma + pred_original_sample = noise_pred * ( + -sigma / (sigma**2 + 1) ** 0.5 + ) + (latent / (sigma**2 + 1)) + derivative = (latent - pred_original_sample) / sigma + prev_sample = latent + derivative * dt + return prev_sample + noise * sigma_up + + iree_flags = [] + if len(args.iree_vulkan_target_triple) > 0: + iree_flags.append( + f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}" + ) + + def _import(self): + scaling_model = ScalingModel() + self.scaling_model, _ = compile_through_fx( + model=scaling_model, + inputs=(example_latent, example_sigma), + extended_model_name=f"euler_a_scale_model_input_{self.batch_size}_{args.height}_{args.width}_{device}_" + + args.precision, + extra_args=iree_flags, + ) + + pred_type_model_dict = { + "epsilon": SchedulerStepEpsilonModel(), + "v_prediction": SchedulerStepVPredictionModel(), + } + step_model = pred_type_model_dict[self.config.prediction_type] + self.step_model, _ = compile_through_fx( + step_model, + ( + example_output, + example_latent, + example_sigma, + example_sigma_from, + example_sigma_to, + example_noise, + ), + extended_model_name=f"euler_a_step_{self.config.prediction_type}_{self.batch_size}_{args.height}_{args.width}_{device}_" + + args.precision, + extra_args=iree_flags, + ) + + if args.import_mlir: + _import(self) + + else: + try: + self.scaling_model = get_shark_model( + SCHEDULER_BUCKET, + "euler_a_scale_model_input_" + args.precision, + iree_flags, + ) + self.step_model = get_shark_model( + SCHEDULER_BUCKET, + "euler_a_step_" + + self.config.prediction_type + + args.precision, + iree_flags, + ) + except: + print( + "failed to download model, falling back and using import_mlir" + ) + args.import_mlir = True + _import(self) + + def scale_model_input(self, sample, timestep): + if self.step_index is None: + self._init_step_index(timestep) + sigma = self.sigmas[self.step_index] + return self.scaling_model( + "forward", + ( + sample, + sigma, + ), + send_to_host=False, + ) + + def step( + self, + noise_pred, + timestep, + latent, + generator: Optional[torch.Generator] = None, + return_dict: Optional[bool] = False, + ): + step_inputs = [] + + if self.step_index is None: + self._init_step_index(timestep) + + sigma = self.sigmas[self.step_index] + + sigma_from = self.sigmas[self.step_index] + sigma_to = self.sigmas[self.step_index + 1] + noise = randn_tensor( + torch.Size(noise_pred.shape), + dtype=torch.float16, + device="cpu", + generator=generator, + ) + step_inputs = [ + noise_pred, + latent, + sigma, + sigma_from, + sigma_to, + noise, + ] + # TODO: deal with dynamic inputs in turbine flow. + # update step index since we're done with the variable and will return with compiled module output. + self._step_index += 1 + + if noise_pred.shape[0] < self.batch_size: + for i in [0, 1, 5]: + try: + step_inputs[i] = torch.tensor(step_inputs[i]) + except: + step_inputs[i] = torch.tensor(step_inputs[i].to_host()) + step_inputs[i] = torch.cat( + (step_inputs[i], step_inputs[i]), axis=0 + ) + return self.step_model( + "forward", + tuple(step_inputs), + send_to_host=True, + ) + + return self.step_model( + "forward", + tuple(step_inputs), + send_to_host=False, + ) diff --git a/apps/stable_diffusion/src/schedulers/shark_eulerdiscrete.py b/apps/stable_diffusion/src/schedulers/shark_eulerdiscrete.py index 85aba5d870..c074af4d7c 100644 --- a/apps/stable_diffusion/src/schedulers/shark_eulerdiscrete.py +++ b/apps/stable_diffusion/src/schedulers/shark_eulerdiscrete.py @@ -2,12 +2,9 @@ import numpy as np from typing import List, Optional, Tuple, Union from diffusers import ( - LMSDiscreteScheduler, - PNDMScheduler, - DDIMScheduler, - DPMSolverMultistepScheduler, EulerDiscreteScheduler, ) +from diffusers.utils.torch_utils import randn_tensor from diffusers.configuration_utils import register_to_config from apps.stable_diffusion.src.utils import ( compile_through_fx, @@ -27,6 +24,13 @@ def __init__( beta_schedule: str = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, prediction_type: str = "epsilon", + interpolation_type: str = "linear", + use_karras_sigmas: bool = False, + sigma_min: Optional[float] = None, + sigma_max: Optional[float] = None, + timestep_spacing: str = "linspace", + timestep_type: str = "discrete", + steps_offset: int = 0, ): super().__init__( num_train_timesteps, @@ -35,20 +39,29 @@ def __init__( beta_schedule, trained_betas, prediction_type, + interpolation_type, + use_karras_sigmas, + sigma_min, + sigma_max, + timestep_spacing, + timestep_type, + steps_offset, ) + # TODO: make it dynamic so we dont have to worry about batch size + self.batch_size = None - def compile(self): + def compile(self, batch_size=1): SCHEDULER_BUCKET = "gs://shark_tank/stable_diffusion/schedulers" - BATCH_SIZE = args.batch_size device = args.device.split(":", 1)[0].strip() + self.batch_size = batch_size model_input = { "euler": { "latent": torch.randn( - BATCH_SIZE, 4, args.height // 8, args.width // 8 + batch_size, 4, args.height // 8, args.width // 8 ), "output": torch.randn( - BATCH_SIZE, 4, args.height // 8, args.width // 8 + batch_size, 4, args.height // 8, args.width // 8 ), "sigma": torch.tensor(1).to(torch.float32), "dt": torch.tensor(1).to(torch.float32), @@ -70,12 +83,32 @@ def __init__(self): def forward(self, latent, sigma): return latent / ((sigma**2 + 1) ** 0.5) - class SchedulerStepModel(torch.nn.Module): + class SchedulerStepEpsilonModel(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, noise_pred, sigma_hat, latent, dt): + pred_original_sample = latent - sigma_hat * noise_pred + derivative = (latent - pred_original_sample) / sigma_hat + return latent + derivative * dt + + class SchedulerStepSampleModel(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, noise_pred, sigma_hat, latent, dt): + pred_original_sample = noise_pred + derivative = (latent - pred_original_sample) / sigma_hat + return latent + derivative * dt + + class SchedulerStepVPredictionModel(torch.nn.Module): def __init__(self): super().__init__() def forward(self, noise_pred, sigma, latent, dt): - pred_original_sample = latent - sigma * noise_pred + pred_original_sample = noise_pred * ( + -sigma / (sigma**2 + 1) ** 0.5 + ) + (latent / (sigma**2 + 1)) derivative = (latent - pred_original_sample) / sigma return latent + derivative * dt @@ -90,16 +123,22 @@ def _import(self): self.scaling_model, _ = compile_through_fx( model=scaling_model, inputs=(example_latent, example_sigma), - extended_model_name=f"euler_scale_model_input_{BATCH_SIZE}_{args.height}_{args.width}_{device}_" + extended_model_name=f"euler_scale_model_input_{self.batch_size}_{args.height}_{args.width}_{device}_" + args.precision, extra_args=iree_flags, ) - step_model = SchedulerStepModel() + pred_type_model_dict = { + "epsilon": SchedulerStepEpsilonModel(), + "v_prediction": SchedulerStepVPredictionModel(), + "sample": SchedulerStepSampleModel(), + "original_sample": SchedulerStepSampleModel(), + } + step_model = pred_type_model_dict[self.config.prediction_type] self.step_model, _ = compile_through_fx( step_model, (example_output, example_sigma, example_latent, example_dt), - extended_model_name=f"euler_step_{BATCH_SIZE}_{args.height}_{args.width}_{device}_" + extended_model_name=f"euler_step_{self.config.prediction_type}_{self.batch_size}_{args.height}_{args.width}_{device}_" + args.precision, extra_args=iree_flags, ) @@ -109,6 +148,11 @@ def _import(self): else: try: + step_model_type = ( + "sample" + if "sample" in self.config.prediction_type + else self.config.prediction_type + ) self.scaling_model = get_shark_model( SCHEDULER_BUCKET, "euler_scale_model_input_" + args.precision, @@ -116,7 +160,7 @@ def _import(self): ) self.step_model = get_shark_model( SCHEDULER_BUCKET, - "euler_step_" + args.precision, + "euler_step_" + step_model_type + args.precision, iree_flags, ) except: @@ -138,15 +182,52 @@ def scale_model_input(self, sample, timestep): send_to_host=False, ) - def step(self, noise_pred, timestep, latent): - step_index = (self.timesteps == timestep).nonzero().item() - sigma = self.sigmas[step_index] - dt = self.sigmas[step_index + 1] - sigma + def step( + self, + noise_pred, + timestep, + latent, + s_churn: float = 0.0, + s_tmin: float = 0.0, + s_tmax: float = float("inf"), + s_noise: float = 1.0, + generator: Optional[torch.Generator] = None, + return_dict: Optional[bool] = False, + ): + if self.step_index is None: + self._init_step_index(timestep) + + sigma = self.sigmas[self.step_index] + + gamma = ( + min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) + if s_tmin <= sigma <= s_tmax + else 0.0 + ) + + sigma_hat = sigma * (gamma + 1) + + noise = randn_tensor( + noise_pred.shape, + dtype=noise_pred.dtype, + device="cpu", + generator=generator, + ) + + eps = noise * s_noise + + if gamma > 0: + latent = latent + eps * (sigma_hat**2 - sigma**2) ** 0.5 + + if self.config.prediction_type == "v_prediction": + sigma_hat = sigma + + dt = self.sigmas[self.step_index + 1] - sigma_hat return self.step_model( "forward", ( noise_pred, - sigma, + sigma_hat, latent, dt, ), diff --git a/apps/stable_diffusion/src/utils/resources/base_model.json b/apps/stable_diffusion/src/utils/resources/base_model.json index 09617ca40a..5cbf965581 100644 --- a/apps/stable_diffusion/src/utils/resources/base_model.json +++ b/apps/stable_diffusion/src/utils/resources/base_model.json @@ -189,6 +189,49 @@ "dtype": "i64" } }, + "stabilityai/sdxl-turbo": { + "latents": { + "shape": [ + "2*batch_size", + 4, + "height", + "width" + ], + "dtype": "f32" + }, + "timesteps": { + "shape": [ + 1 + ], + "dtype": "f32" + }, + "prompt_embeds": { + "shape": [ + "2*batch_size", + "max_len", + 2048 + ], + "dtype": "f32" + }, + "text_embeds": { + "shape": [ + "2*batch_size", + 1280 + ], + "dtype": "f32" + }, + "time_ids": { + "shape": [ + "2*batch_size", + 6 + ], + "dtype": "f32" + }, + "guidance_scale": { + "shape": 1, + "dtype": "f32" + } + }, "stabilityai/stable-diffusion-xl-base-1.0": { "latents": { "shape": [ @@ -449,4 +492,4 @@ } } } -} \ No newline at end of file +} diff --git a/apps/stable_diffusion/src/utils/resources/prompts.json b/apps/stable_diffusion/src/utils/resources/prompts.json index 8387178460..7ecce99e5a 100644 --- a/apps/stable_diffusion/src/utils/resources/prompts.json +++ b/apps/stable_diffusion/src/utils/resources/prompts.json @@ -1,4 +1,5 @@ [["A high tech solarpunk utopia in the Amazon rainforest"], +["Astrophotography, the shark nebula, nebula with a tiny shark-like cloud in the middle in the middle, hubble telescope, vivid colors"], ["A pikachu fine dining with a view to the Eiffel Tower"], ["A mecha robot in a favela in expressionist style"], ["an insect robot preparing a delicious meal"], diff --git a/apps/stable_diffusion/studio_bundle.spec b/apps/stable_diffusion/studio_bundle.spec index 2403b58128..d73abd1e38 100644 --- a/apps/stable_diffusion/studio_bundle.spec +++ b/apps/stable_diffusion/studio_bundle.spec @@ -19,6 +19,9 @@ a = Analysis( win_private_assemblies=False, cipher=block_cipher, noarchive=False, + module_collection_mode={ + 'gradio': 'py', # Collect gradio package as source .py files + }, ) pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher) diff --git a/apps/stable_diffusion/web/index.py b/apps/stable_diffusion/web/index.py index 07242e3fc7..7bb7be9c7c 100644 --- a/apps/stable_diffusion/web/index.py +++ b/apps/stable_diffusion/web/index.py @@ -110,11 +110,15 @@ def resource_path(relative_path): txt2img_sendto_outpaint, txt2img_sendto_upscaler, # SDXL - txt2img_sdxl_inf, txt2img_sdxl_web, txt2img_sdxl_custom_model, txt2img_sdxl_gallery, + txt2img_sdxl_png_info_img, txt2img_sdxl_status, + txt2img_sdxl_sendto_img2img, + txt2img_sdxl_sendto_inpaint, + txt2img_sdxl_sendto_outpaint, + txt2img_sdxl_sendto_upscaler, # h2ogpt_upload, # h2ogpt_web, img2img_web, @@ -151,7 +155,7 @@ def resource_path(relative_path): upscaler_sendto_outpaint, # lora_train_web, # model_web, - # model_config_web, + model_config_web, hf_models, modelmanager_sendto_txt2img, modelmanager_sendto_img2img, @@ -165,6 +169,7 @@ def resource_path(relative_path): outputgallery_watch, outputgallery_filename, outputgallery_sendto_txt2img, + outputgallery_sendto_txt2img_sdxl, outputgallery_sendto_img2img, outputgallery_sendto_inpaint, outputgallery_sendto_outpaint, @@ -178,7 +183,7 @@ def register_button_click(button, selectedid, inputs, outputs): button.click( lambda x: ( x[0]["name"] if len(x) != 0 else None, - gr.Tabs.update(selected=selectedid), + gr.Tabs(selected=selectedid), ), inputs, outputs, @@ -189,7 +194,7 @@ def register_modelmanager_button(button, selectedid, inputs, outputs): lambda x: ( "None", x, - gr.Tabs.update(selected=selectedid), + gr.Tabs(selected=selectedid), ), inputs, outputs, @@ -199,7 +204,7 @@ def register_outputgallery_button(button, selectedid, inputs, outputs): button.click( lambda x: ( x, - gr.Tabs.update(selected=selectedid), + gr.Tabs(selected=selectedid), ), inputs, outputs, @@ -241,6 +246,7 @@ def register_outputgallery_button(button, selectedid, inputs, outputs): inpaint_status, outpaint_status, upscaler_status, + txt2img_sdxl_status, ] ) # with gr.TabItem(label="Model Manager", id=6): @@ -249,17 +255,17 @@ def register_outputgallery_button(button, selectedid, inputs, outputs): # lora_train_web.render() with gr.TabItem(label="Chat Bot", id=8): stablelm_chat.render() - # with gr.TabItem( - # label="Generate Sharding Config (Experimental)", id=9 - # ): - # model_config_web.render() - with gr.TabItem(label="MultiModal (Experimental)", id=10): - minigpt4_web.render() + with gr.TabItem( + label="Generate Sharding Config (Experimental)", id=9 + ): + model_config_web.render() + # with gr.TabItem(label="MultiModal (Experimental)", id=10): + # minigpt4_web.render() # with gr.TabItem(label="DocuChat Upload", id=11): # h2ogpt_upload.render() # with gr.TabItem(label="DocuChat(Experimental)", id=12): # h2ogpt_web.render() - with gr.TabItem(label="Text-to-Image-SDXL (Experimental)", id=13): + with gr.TabItem(label="Text-to-Image (SDXL)", id=13): txt2img_sdxl_web.render() actual_port = app.usable_port() diff --git a/apps/stable_diffusion/web/ui/__init__.py b/apps/stable_diffusion/web/ui/__init__.py index 7ba21b73ef..979c129850 100644 --- a/apps/stable_diffusion/web/ui/__init__.py +++ b/apps/stable_diffusion/web/ui/__init__.py @@ -16,6 +16,11 @@ txt2img_sdxl_custom_model, txt2img_sdxl_gallery, txt2img_sdxl_status, + txt2img_sdxl_png_info_img, + txt2img_sdxl_sendto_img2img, + txt2img_sdxl_sendto_inpaint, + txt2img_sdxl_sendto_outpaint, + txt2img_sdxl_sendto_upscaler, ) from apps.stable_diffusion.web.ui.img2img_ui import ( img2img_inf, @@ -83,6 +88,7 @@ outputgallery_watch, outputgallery_filename, outputgallery_sendto_txt2img, + outputgallery_sendto_txt2img_sdxl, outputgallery_sendto_img2img, outputgallery_sendto_inpaint, outputgallery_sendto_outpaint, diff --git a/apps/stable_diffusion/web/ui/img2img_ui.py b/apps/stable_diffusion/web/ui/img2img_ui.py index 0e03287b33..1d8a7f5395 100644 --- a/apps/stable_diffusion/web/ui/img2img_ui.py +++ b/apps/stable_diffusion/web/ui/img2img_ui.py @@ -34,6 +34,7 @@ from apps.stable_diffusion.src.utils.stencils import ( CannyDetector, OpenposeDetector, + ZoeDetector, ) from apps.stable_diffusion.web.utils.common_label_calc import status_label import numpy as np @@ -97,18 +98,21 @@ def img2img_inf( for i, stencil in enumerate(stencils): if images[i] is None and stencil is not None: - return None, "A stencil must have an Image input" + return if images[i] is not None: + if isinstance(images[i], dict): + images[i] = images[i]["composite"] images[i] = images[i].convert("RGB") - if image_dict is None: + if image_dict is None and images[0] is None: return None, "An Initial Image is required" - # if use_stencil == "scribble": - # image = image_dict["mask"].convert("RGB") if isinstance(image_dict, PIL.Image.Image): image = image_dict.convert("RGB") - else: + elif image_dict: image = image_dict["image"].convert("RGB") + else: + # TODO: enable t2i + controlnets + image = None # set ckpt_loc and hf_model_id. args.ckpt_loc = "" @@ -140,9 +144,8 @@ def img2img_inf( if stencil is not None: stencil_count += 1 if stencil_count > 0: - args.scheduler = "DDIM" args.hf_model_id = "runwayml/stable-diffusion-v1-5" - # image, width, height = resize_stencil(image) + image, width, height = resize_stencil(image) elif "Shark" in args.scheduler: print( f"Shark schedulers are not supported. Switching to EulerDiscrete " @@ -363,71 +366,182 @@ def img2img_inf( # TODO: make this import image prompt info if it exists img2img_init_image = gr.Image( label="Input Image", - source="upload", - tool="sketch", type="pil", - height=300, + height=512, + interactive=True, ) with gr.Accordion(label="Multistencil Options", open=False): - choices = ["None", "canny", "openpose", "scribble"] + choices = [ + "None", + "canny", + "openpose", + "scribble", + "zoedepth", + ] def cnet_preview( - checked, model, input_image, index, stencils, images + model, input_image, index, stencils, images ): - if not checked: - stencils[index] = None - images[index] = None - return (None, stencils, images) images[index] = input_image stencils[index] = model match model: case "canny": canny = CannyDetector() - result = canny(np.array(input_image), 100, 200) + result = canny( + np.array(input_image["composite"]), + 100, + 200, + ) return ( - [Image.fromarray(result), result], + Image.fromarray(result), stencils, images, ) case "openpose": openpose = OpenposeDetector() - result = openpose(np.array(input_image)) + result = openpose( + np.array(input_image["composite"]) + ) # TODO: This is just an empty canvas, need to draw the candidates (which are in result[1]) return ( - [Image.fromarray(result[0]), result], + Image.fromarray(result[0]), stencils, images, ) + case "zoedepth": + zoedepth = ZoeDetector() + result = zoedepth( + np.array(input_image["composite"]) + ) + return ( + Image.fromarray(result[0]), + stencils, + images, + ) + case "scribble": + result = input_image["composite"].convert("L") + return (result, stencils, images) case _: return (None, stencils, images) - with gr.Row(): - cnet_1 = gr.Checkbox(show_label=False) - cnet_1_model = gr.Dropdown( - label="Controlnet 1", - value="None", - choices=choices, + def create_canvas(width, height): + data = ( + np.zeros( + shape=(height, width, 3), + dtype=np.uint8, + ) + + 255 ) - cnet_1_image = gr.Image( - source="upload", - tool=None, + return data + + def update_cn_input(model, width, height): + if model == "scribble": + return [ + gr.ImageEditor( + visible=True, + image_mode="RGB", + interactive=True, + show_label=False, + type="pil", + value=create_canvas(width, height), + crop_size=(width, height), + ), + gr.Image( + visible=True, + show_label=False, + interactive=False, + ), + gr.Slider(visible=True), + gr.Slider(visible=True), + gr.Button(visible=True), + ] + else: + return [ + gr.ImageEditor( + visible=True, + image_mode="RGB", + type="pil", + interactive=True, + value=None, + ), + gr.Image( + visible=True, + show_label=False, + interactive=True, + ), + gr.Slider(visible=False), + gr.Slider(visible=False), + gr.Button(visible=False), + ] + + with gr.Row(): + with gr.Column(): + cnet_1 = gr.Button( + value="Generate controlnet input" + ) + cnet_1_model = gr.Dropdown( + label="Controlnet 1", + value="None", + choices=choices, + ) + canvas_width = gr.Slider( + label="Canvas Width", + minimum=256, + maximum=1024, + value=512, + step=1, + visible=False, + ) + canvas_height = gr.Slider( + label="Canvas Height", + minimum=256, + maximum=1024, + value=512, + step=1, + visible=False, + ) + make_canvas = gr.Button( + value="Make Canvas!", + visible=False, + ) + cnet_1_image = gr.ImageEditor( + visible=False, + image_mode="RGB", + interactive=True, + show_label=False, type="pil", ) - cnet_1_output = gr.Gallery( - show_label=False, - object_fit="scale-down", - rows=1, - columns=1, + cnet_1_output = gr.Image( + visible=True, show_label=False + ) + cnet_1_model.input( + update_cn_input, + [cnet_1_model, canvas_width, canvas_height], + [ + cnet_1_image, + cnet_1_output, + canvas_width, + canvas_height, + make_canvas, + ], ) - cnet_1.change( + make_canvas.click( + update_cn_input, + [cnet_1_model, canvas_width, canvas_height], + [ + cnet_1_image, + cnet_1_output, + canvas_width, + canvas_height, + make_canvas, + ], + ) + cnet_1.click( fn=( - lambda a, b, c, s, i: cnet_preview( - a, b, c, 0, s, i - ) + lambda a, b, s, i: cnet_preview(a, b, 0, s, i) ), inputs=[ - cnet_1, cnet_1_model, cnet_1_image, stencils, @@ -436,31 +550,72 @@ def cnet_preview( outputs=[cnet_1_output, stencils, images], ) with gr.Row(): - cnet_2 = gr.Checkbox(show_label=False) - cnet_2_model = gr.Dropdown( - label="Controlnet 2", - value="None", - choices=choices, - ) - cnet_2_image = gr.Image( - source="upload", - tool=None, + with gr.Column(): + cnet_2 = gr.Button( + value="Generate controlnet input" + ) + cnet_2_model = gr.Dropdown( + label="Controlnet 2", + value="None", + choices=choices, + ) + canvas_width = gr.Slider( + label="Canvas Width", + minimum=256, + maximum=1024, + value=512, + step=1, + visible=False, + ) + canvas_height = gr.Slider( + label="Canvas Height", + minimum=256, + maximum=1024, + value=512, + step=1, + visible=False, + ) + make_canvas = gr.Button( + value="Make Canvas!", + visible=False, + ) + cnet_2_image = gr.ImageEditor( + visible=False, + image_mode="RGB", + interactive=True, + show_label=False, type="pil", ) - cnet_2_output = gr.Gallery( - show_label=False, - object_fit="scale-down", - rows=1, - columns=1, + cnet_2_output = gr.Image( + visible=True, show_label=False ) - cnet_2.change( + cnet_2_model.select( + update_cn_input, + [cnet_2_model, canvas_width, canvas_height], + [ + cnet_2_image, + cnet_2_output, + canvas_width, + canvas_height, + make_canvas, + ], + ) + make_canvas.click( + update_cn_input, + [cnet_2_model, canvas_width, canvas_height], + [ + cnet_2_image, + cnet_2_output, + canvas_width, + canvas_height, + make_canvas, + ], + ) + cnet_2.click( fn=( - lambda a, b, c, s, i: cnet_preview( - a, b, c, 1, s, i - ) + lambda a, b, s, i: cnet_preview(a, b, 1, s, i) ), inputs=[ - cnet_2, cnet_2_model, cnet_2_image, stencils, diff --git a/apps/stable_diffusion/web/ui/inpaint_ui.py b/apps/stable_diffusion/web/ui/inpaint_ui.py index 97e643d259..3cc3fc3c96 100644 --- a/apps/stable_diffusion/web/ui/inpaint_ui.py +++ b/apps/stable_diffusion/web/ui/inpaint_ui.py @@ -290,8 +290,7 @@ def inpaint_inf( inpaint_init_image = gr.Image( label="Masked Image", - source="upload", - tool="sketch", + sources="upload", type="pil", height=350, ) diff --git a/apps/stable_diffusion/web/ui/model_manager.py b/apps/stable_diffusion/web/ui/model_manager.py index 16e812b4ed..11e01fe873 100644 --- a/apps/stable_diffusion/web/ui/model_manager.py +++ b/apps/stable_diffusion/web/ui/model_manager.py @@ -104,7 +104,6 @@ def get_image_from_model(model_json): civit_models = gr.Gallery( label="Civitai Model Gallery", value=None, - interactive=True, visible=False, ) diff --git a/apps/stable_diffusion/web/ui/outputgallery_ui.py b/apps/stable_diffusion/web/ui/outputgallery_ui.py index f5b96a6808..711b70a7e0 100644 --- a/apps/stable_diffusion/web/ui/outputgallery_ui.py +++ b/apps/stable_diffusion/web/ui/outputgallery_ui.py @@ -95,7 +95,7 @@ def output_subdirs() -> list[str]: ) with gr.Column(scale=4): - with gr.Box(): + with gr.Group(): with gr.Row(): with gr.Column( scale=15, @@ -152,6 +152,7 @@ def output_subdirs() -> list[str]: wrap=True, elem_classes="output_parameters_dataframe", value=[["Status", "No image selected"]], + interactive=True, ) with gr.Accordion(label="Send To", open=True): @@ -162,6 +163,12 @@ def output_subdirs() -> list[str]: elem_classes="outputgallery_sendto", size="sm", ) + outputgallery_sendto_txt2img_sdxl = gr.Button( + value="Txt2Img XL", + interactive=False, + elem_classes="outputgallery_sendto", + size="sm", + ) outputgallery_sendto_img2img = gr.Button( value="Img2Img", @@ -195,17 +202,17 @@ def output_subdirs() -> list[str]: def on_clear_gallery(): return [ - gr.Gallery.update( + gr.Gallery( value=[], visible=False, ), - gr.Image.update( + gr.Image( visible=True, ), ] def on_image_columns_change(columns): - return gr.Gallery.update(columns=columns) + return gr.Gallery(columns=columns) def on_select_subdir(subdir) -> list: # evt.value is the subdirectory name @@ -215,12 +222,12 @@ def on_select_subdir(subdir) -> list: ) return [ new_images, - gr.Gallery.update( + gr.Gallery( value=new_images, label=new_label, visible=len(new_images) > 0, ), - gr.Image.update( + gr.Image( label=new_label, visible=len(new_images) == 0, ), @@ -254,16 +261,16 @@ def on_refresh(current_subdir: str) -> list: ) return [ - gr.Dropdown.update( + gr.Dropdown( choices=refreshed_subdirs, value=new_subdir, ), refreshed_subdirs, new_images, - gr.Gallery.update( + gr.Gallery( value=new_images, label=new_label, visible=len(new_images) > 0 ), - gr.Image.update( + gr.Image( label=new_label, visible=len(new_images) == 0, ), @@ -289,12 +296,12 @@ def on_new_image(subdir, subdir_paths, status) -> list: return [ new_images, - gr.Gallery.update( + gr.Gallery( value=new_images, label=new_label, visible=len(new_images) > 0, ), - gr.Image.update( + gr.Image( label=new_label, visible=len(new_images) == 0, ), @@ -332,12 +339,12 @@ def on_outputgallery_filename_change(filename: str) -> list: return [ # disable or enable each of the sendto button based on whether # an image is selected - gr.Button.update(interactive=exists), - gr.Button.update(interactive=exists), - gr.Button.update(interactive=exists), - gr.Button.update(interactive=exists), - gr.Button.update(interactive=exists), - gr.Button.update(interactive=exists), + gr.Button(interactive=exists), + gr.Button(interactive=exists), + gr.Button(interactive=exists), + gr.Button(interactive=exists), + gr.Button(interactive=exists), + gr.Button(interactive=exists), ] # The time first our tab is selected we need to do an initial refresh @@ -414,6 +421,7 @@ def on_select_tab(subdir_paths, request: gr.Request): [outputgallery_filename], [ outputgallery_sendto_txt2img, + outputgallery_sendto_txt2img_sdxl, outputgallery_sendto_img2img, outputgallery_sendto_inpaint, outputgallery_sendto_outpaint, diff --git a/apps/stable_diffusion/web/ui/stablelm_ui.py b/apps/stable_diffusion/web/ui/stablelm_ui.py index 0df3f8442d..f3baa3c5ce 100644 --- a/apps/stable_diffusion/web/ui/stablelm_ui.py +++ b/apps/stable_diffusion/web/ui/stablelm_ui.py @@ -431,8 +431,8 @@ def view_json_file(file_obj): config_file = gr.File( label="Upload sharding configuration", visible=False ) - json_view_button = gr.Button(label="View as JSON", visible=False) - json_view = gr.JSON(interactive=True, visible=False) + json_view_button = gr.Button(value="View as JSON", visible=False) + json_view = gr.JSON(visible=False) json_view_button.click( fn=view_json_file, inputs=[config_file], outputs=[json_view] ) diff --git a/apps/stable_diffusion/web/ui/txt2img_sdxl_ui.py b/apps/stable_diffusion/web/ui/txt2img_sdxl_ui.py index 7bcbb3c147..c3a653bd13 100644 --- a/apps/stable_diffusion/web/ui/txt2img_sdxl_ui.py +++ b/apps/stable_diffusion/web/ui/txt2img_sdxl_ui.py @@ -11,9 +11,11 @@ get_custom_model_path, get_custom_model_files, scheduler_list, - predefined_models, + predefined_sdxl_models, cancel_sd, + set_model_default_configs, ) +from apps.stable_diffusion.web.ui.common_ui_events import lora_changed from apps.stable_diffusion.web.utils.metadata import import_png_metadata from apps.stable_diffusion.web.utils.common_label_calc import status_label from apps.stable_diffusion.src import ( @@ -50,17 +52,17 @@ def txt2img_sdxl_inf( batch_size: int, scheduler: str, model_id: str, + custom_vae: str, precision: str, device: str, max_length: int, save_metadata_to_json: bool, save_metadata_to_png: bool, + lora_weights: str, + lora_hf_id: str, ondemand: bool, repeatable_seeds: bool, ): - if precision != "fp16": - print("currently we support fp16 for SDXL") - precision = "fp16" from apps.stable_diffusion.web.ui.utils import ( get_custom_model_pathfile, get_custom_vae_or_lora_weights, @@ -71,6 +73,10 @@ def txt2img_sdxl_inf( SD_STATE_CANCEL, ) + if precision != "fp16": + print("currently we support fp16 for SDXL") + precision = "fp16" + args.prompts = [prompt] args.negative_prompts = [negative_prompt] args.guidance_scale = guidance_scale @@ -93,13 +99,15 @@ def txt2img_sdxl_inf( else: args.hf_model_id = model_id - # if custom_vae != "None": - # args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae") + if custom_vae: + args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae") args.save_metadata_to_json = save_metadata_to_json args.write_metadata_to_png = save_metadata_to_png - args.use_lora = "" + args.use_lora = get_custom_vae_or_lora_weights( + lora_weights, lora_hf_id, "lora" + ) dtype = torch.float32 if precision == "fp32" else torch.half cpu_scheduling = not scheduler.startswith("Shark") @@ -115,7 +123,7 @@ def txt2img_sdxl_inf( width, device, use_lora=args.use_lora, - use_stencil=None, + stencils=None, ondemand=ondemand, ) if ( @@ -144,31 +152,29 @@ def txt2img_sdxl_inf( ) global_obj.set_schedulers(get_schedulers(model_id)) scheduler_obj = global_obj.get_scheduler(scheduler) - # For SDXL we set max_length as 77. - print("Setting max_length = 77") - max_length = 77 if global_obj.get_cfg_obj().ondemand: print("Running txt2img in memory efficient mode.") - txt2img_sdxl_obj = Text2ImageSDXLPipeline.from_pretrained( - scheduler=scheduler_obj, - import_mlir=args.import_mlir, - model_id=args.hf_model_id, - ckpt_loc=args.ckpt_loc, - precision=precision, - max_length=max_length, - batch_size=batch_size, - height=height, - width=width, - use_base_vae=args.use_base_vae, - use_tuned=args.use_tuned, - custom_vae=args.custom_vae, - low_cpu_mem_usage=args.low_cpu_mem_usage, - debug=args.import_debug if args.import_mlir else False, - use_lora=args.use_lora, - use_quantize=args.use_quantize, - ondemand=global_obj.get_cfg_obj().ondemand, + global_obj.set_sd_obj( + Text2ImageSDXLPipeline.from_pretrained( + scheduler=scheduler_obj, + import_mlir=args.import_mlir, + model_id=args.hf_model_id, + ckpt_loc=args.ckpt_loc, + precision=precision, + max_length=max_length, + batch_size=batch_size, + height=height, + width=width, + use_base_vae=args.use_base_vae, + use_tuned=args.use_tuned, + custom_vae=args.custom_vae, + low_cpu_mem_usage=args.low_cpu_mem_usage, + debug=args.import_debug if args.import_mlir else False, + use_lora=args.use_lora, + use_quantize=args.use_quantize, + ondemand=global_obj.get_cfg_obj().ondemand, + ) ) - global_obj.set_sd_obj(txt2img_sdxl_obj) global_obj.set_sd_scheduler(scheduler) @@ -239,7 +245,7 @@ def txt2img_sdxl_inf( with gr.Row(): with gr.Column(scale=10): with gr.Row(): - t2i_model_info = f"Custom Model Path: {str(get_custom_model_path())}" + t2i_sdxl_model_info = f"Custom Model Path: {str(get_custom_model_path())}" txt2img_sdxl_custom_model = gr.Dropdown( label=f"Models", info="Select, or enter HuggingFace Model ID or Civitai model download URL", @@ -247,12 +253,39 @@ def txt2img_sdxl_inf( value=os.path.basename(args.ckpt_loc) if args.ckpt_loc else "stabilityai/stable-diffusion-xl-base-1.0", - choices=[ - "stabilityai/stable-diffusion-xl-base-1.0" - ], + choices=predefined_sdxl_models + + get_custom_model_files( + custom_checkpoint_type="sdxl" + ), allow_custom_value=True, scale=2, ) + t2i_sdxl_vae_info = ( + str(get_custom_model_path("vae")) + ).replace("\\", "\n\\") + t2i_sdxl_vae_info = ( + f"VAE Path: {t2i_sdxl_vae_info}" + ) + custom_vae = gr.Dropdown( + label=f"VAE Models", + info=t2i_sdxl_vae_info, + elem_id="custom_model", + value="None", + choices=[ + None, + "madebyollin/sdxl-vae-fp16-fix", + ] + + get_custom_model_files("vae"), + allow_custom_value=True, + scale=1, + ) + with gr.Column(scale=1, min_width=170): + txt2img_sdxl_png_info_img = gr.Image( + label="Import PNG info", + elem_id="txt2img_prompt_image", + type="pil", + visible=True, + ) with gr.Group(elem_id="prompt_box_outer"): prompt = gr.Textbox( @@ -267,16 +300,49 @@ def txt2img_sdxl_inf( lines=2, elem_id="negative_prompt_box", ) - + with gr.Accordion(label="LoRA Options", open=False): + with gr.Row(): + # janky fix for overflowing text + t2i_sdxl_lora_info = ( + str(get_custom_model_path("lora")) + ).replace("\\", "\n\\") + t2i_sdxl_lora_info = f"LoRA Path: {t2i_sdxl_lora_info}" + lora_weights = gr.Dropdown( + label=f"Standalone LoRA Weights", + info=t2i_sdxl_lora_info, + elem_id="lora_weights", + value="None", + choices=["None"] + get_custom_model_files("lora"), + allow_custom_value=True, + ) + lora_hf_id = gr.Textbox( + elem_id="lora_hf_id", + placeholder="Select 'None' in the Standalone LoRA " + "weights dropdown on the left if you want to use " + "a standalone HuggingFace model ID for LoRA here " + "e.g: sayakpaul/sd-model-finetuned-lora-t4", + value="", + label="HuggingFace Model ID", + lines=3, + ) + with gr.Row(): + lora_tags = gr.HTML( + value="