From 2004d1694580cdfef48b4e00a0041a81c7aeeb2d Mon Sep 17 00:00:00 2001 From: Ean Garvey <87458719+monorimet@users.noreply.github.com> Date: Mon, 9 Oct 2023 20:01:44 -0500 Subject: [PATCH] Revert "[SDXL] Add SDXL pipeline to SHARK (#1731)" (#1882) This reverts commit 9f0a42176474e36f8348f941fd6b9167052365de. --- apps/stable_diffusion/scripts/txt2img.py | 71 ++--- apps/stable_diffusion/src/__init__.py | 1 - .../src/models/model_wrappers.py | 253 +---------------- .../stable_diffusion/src/models/opt_params.py | 4 +- .../src/pipelines/__init__.py | 3 - ...ine_shark_stable_diffusion_txt2img_sdxl.py | 212 -------------- .../pipeline_shark_stable_diffusion_utils.py | 266 ------------------ .../src/utils/resources/base_model.json | 52 ---- .../stable_diffusion/src/utils/stable_args.py | 4 +- apps/stable_diffusion/web/ui/txt2img_ui.py | 49 +--- 10 files changed, 40 insertions(+), 875 deletions(-) delete mode 100644 apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_txt2img_sdxl.py diff --git a/apps/stable_diffusion/scripts/txt2img.py b/apps/stable_diffusion/scripts/txt2img.py index 235bb96402..f425f48c0f 100644 --- a/apps/stable_diffusion/scripts/txt2img.py +++ b/apps/stable_diffusion/scripts/txt2img.py @@ -1,9 +1,9 @@ import torch +import transformers import time from apps.stable_diffusion.src import ( args, Text2ImagePipeline, - Text2ImageSDXLPipeline, get_schedulers, set_init_device_flags, utils, @@ -16,62 +16,31 @@ def main(): if args.clear_all: clear_all() - # TODO: prompt_embeds and text_embeds form base_model.json requires fixing dtype = torch.float32 if args.precision == "fp32" else torch.half cpu_scheduling = not args.scheduler.startswith("Shark") set_init_device_flags() schedulers = get_schedulers(args.hf_model_id) scheduler_obj = schedulers[args.scheduler] seed = args.seed - if args.height == 1024: - assert ( - args.width == 1024 - ), "currently we support only 1024x1024 image size via SDXL" - assert args.precision == "fp16", "currently we support fp16 for SDXL" - # For SDXL we set max_length as 77. - args.max_length = 77 - txt2img_obj = Text2ImageSDXLPipeline.from_pretrained( - scheduler=scheduler_obj, - import_mlir=args.import_mlir, - model_id=args.hf_model_id, - ckpt_loc=args.ckpt_loc, - precision=args.precision, - max_length=args.max_length, - batch_size=args.batch_size, - height=args.height, - width=args.width, - use_base_vae=args.use_base_vae, - use_tuned=args.use_tuned, - custom_vae=args.custom_vae, - low_cpu_mem_usage=args.low_cpu_mem_usage, - debug=args.import_debug if args.import_mlir else False, - use_lora=args.use_lora, - use_quantize=args.use_quantize, - ondemand=args.ondemand, - ) - else: - assert ( - args.height <= 768 and args.width <= 768 - ), "height/width not in supported range" - txt2img_obj = Text2ImagePipeline.from_pretrained( - scheduler=scheduler_obj, - import_mlir=args.import_mlir, - model_id=args.hf_model_id, - ckpt_loc=args.ckpt_loc, - precision=args.precision, - max_length=args.max_length, - batch_size=args.batch_size, - height=args.height, - width=args.width, - use_base_vae=args.use_base_vae, - use_tuned=args.use_tuned, - custom_vae=args.custom_vae, - low_cpu_mem_usage=args.low_cpu_mem_usage, - debug=args.import_debug if args.import_mlir else False, - use_lora=args.use_lora, - use_quantize=args.use_quantize, - ondemand=args.ondemand, - ) + txt2img_obj = Text2ImagePipeline.from_pretrained( + scheduler=scheduler_obj, + import_mlir=args.import_mlir, + model_id=args.hf_model_id, + ckpt_loc=args.ckpt_loc, + precision=args.precision, + max_length=args.max_length, + batch_size=args.batch_size, + height=args.height, + width=args.width, + use_base_vae=args.use_base_vae, + use_tuned=args.use_tuned, + custom_vae=args.custom_vae, + low_cpu_mem_usage=args.low_cpu_mem_usage, + debug=args.import_debug if args.import_mlir else False, + use_lora=args.use_lora, + use_quantize=args.use_quantize, + ondemand=args.ondemand, + ) seeds = utils.batch_seeds(seed, args.batch_count, args.repeatable_seeds) for current_batch in range(args.batch_count): diff --git a/apps/stable_diffusion/src/__init__.py b/apps/stable_diffusion/src/__init__.py index a40bafb798..a30ee16b32 100644 --- a/apps/stable_diffusion/src/__init__.py +++ b/apps/stable_diffusion/src/__init__.py @@ -9,7 +9,6 @@ ) from apps.stable_diffusion.src.pipelines import ( Text2ImagePipeline, - Text2ImageSDXLPipeline, Image2ImagePipeline, InpaintPipeline, OutpaintPipeline, diff --git a/apps/stable_diffusion/src/models/model_wrappers.py b/apps/stable_diffusion/src/models/model_wrappers.py index d04e3c0ffe..2ef9f696fa 100644 --- a/apps/stable_diffusion/src/models/model_wrappers.py +++ b/apps/stable_diffusion/src/models/model_wrappers.py @@ -1,5 +1,5 @@ from diffusers import AutoencoderKL, UNet2DConditionModel, ControlNetModel -from transformers import CLIPTextModel, CLIPTextModelWithProjection +from transformers import CLIPTextModel from collections import defaultdict from pathlib import Path import torch @@ -53,10 +53,6 @@ def replace_shape_str(shape, max_len, width, height, batch_size): new_shape.append(math.ceil(height / div_val)) elif "width" in shape[i]: new_shape.append(math.ceil(width / div_val)) - elif "+" in shape[i]: - # Currently this case only hits for SDXL. So, in case any other - # case requires this operator, change this. - new_shape.append(height + width) else: new_shape.append(shape[i]) return new_shape @@ -88,7 +84,6 @@ def __init__( generate_vmfb: bool = True, is_inpaint: bool = False, is_upscaler: bool = False, - is_sdxl: bool = False, use_stencil: str = None, use_lora: str = "", use_quantize: str = None, @@ -96,14 +91,8 @@ def __init__( ): self.check_params(max_len, width, height) self.max_len = max_len - self.is_sdxl = is_sdxl - self.height = height - self.width = width - if is_sdxl: - # We need to scale down the height/width by vae_scale_factor, which - # happens to be 8 in this case. - self.height = height // 8 - self.width = width // 8 + self.height = height // 8 + self.width = width // 8 self.batch_size = batch_size self.custom_weights = custom_weights self.use_quantize = use_quantize @@ -185,7 +174,6 @@ def get_extended_name_for_all_model(self): model_name = {} sub_model_list = [ "clip", - "clip2", "unet", "unet512", "stencil_unet", @@ -353,71 +341,6 @@ def forward(self, input): ) return shark_vae, vae_mlir - def get_vae_sdxl(self): - class VaeModel(torch.nn.Module): - def __init__( - self, - model_id=self.model_id, - base_vae=self.base_vae, - custom_vae=self.custom_vae, - low_cpu_mem_usage=False, - ): - super().__init__() - self.vae = None - if custom_vae == "": - self.vae = AutoencoderKL.from_pretrained( - model_id, - subfolder="vae", - low_cpu_mem_usage=low_cpu_mem_usage, - ) - elif not isinstance(custom_vae, dict): - self.vae = AutoencoderKL.from_pretrained( - custom_vae, - subfolder="vae", - low_cpu_mem_usage=low_cpu_mem_usage, - ) - else: - self.vae = AutoencoderKL.from_pretrained( - model_id, - subfolder="vae", - low_cpu_mem_usage=low_cpu_mem_usage, - ) - self.vae.load_state_dict(custom_vae) - - def forward(self, latents): - image = self.vae.decode(latents / 0.13025, return_dict=False)[ - 0 - ] - return image - - vae = VaeModel(low_cpu_mem_usage=self.low_cpu_mem_usage) - inputs = tuple(self.inputs["vae"]) - # Make sure the VAE is in float32 mode, as it overflows in float16 as per SDXL - # pipeline. - is_f16 = False - save_dir = os.path.join(self.sharktank_dir, self.model_name["vae"]) - if self.debug: - os.makedirs(save_dir, exist_ok=True) - vae_name_split = self.model_name["vae"].split("_") - vae_name_split[5] = "fp32" - extended_model_name = "_".join(vae_name_split) - shark_vae, vae_mlir = compile_through_fx( - vae, - inputs, - is_f16=is_f16, - use_tuned=self.use_tuned, - extended_model_name=extended_model_name, - debug=self.debug, - generate_vmfb=self.generate_vmfb, - save_dir=save_dir, - extra_args=get_opt_flags("vae", precision=self.precision), - base_model_id=self.base_model_id, - model_name="vae", - precision=self.precision, - return_mlir=self.return_mlir, - ) - return shark_vae, vae_mlir - def get_controlled_unet(self, use_large=False): class ControlledUnetModel(torch.nn.Module): def __init__( @@ -764,85 +687,6 @@ def forward(self, latent, timestep, text_embedding, noise_level): ) return shark_unet, unet_mlir - def get_unet_sdxl(self): - class UnetModel(torch.nn.Module): - def __init__( - self, - model_id=self.model_id, - low_cpu_mem_usage=False, - ): - super().__init__() - self.unet = UNet2DConditionModel.from_pretrained( - model_id, - subfolder="unet", - low_cpu_mem_usage=low_cpu_mem_usage, - ) - if ( - args.attention_slicing is not None - and args.attention_slicing != "none" - ): - if args.attention_slicing.isdigit(): - self.unet.set_attention_slice( - int(args.attention_slicing) - ) - else: - self.unet.set_attention_slice(args.attention_slicing) - - def forward( - self, - latent, - timestep, - prompt_embeds, - text_embeds, - time_ids, - guidance_scale, - ): - added_cond_kwargs = { - "text_embeds": text_embeds, - "time_ids": time_ids, - } - noise_pred = self.unet.forward( - latent, - timestep, - encoder_hidden_states=prompt_embeds, - cross_attention_kwargs=None, - added_cond_kwargs=added_cond_kwargs, - return_dict=False, - )[0] - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * ( - noise_pred_text - noise_pred_uncond - ) - return noise_pred - - unet = UnetModel(low_cpu_mem_usage=self.low_cpu_mem_usage) - is_f16 = True if self.precision == "fp16" else False - inputs = tuple(self.inputs["unet"]) - save_dir = os.path.join(self.sharktank_dir, self.model_name["unet"]) - input_mask = [True, True, True, True, True, True] - if self.debug: - os.makedirs( - save_dir, - exist_ok=True, - ) - shark_unet, unet_mlir = compile_through_fx( - unet, - inputs, - extended_model_name=self.model_name["unet"], - is_f16=is_f16, - f16_input_mask=input_mask, - use_tuned=self.use_tuned, - debug=self.debug, - generate_vmfb=self.generate_vmfb, - save_dir=save_dir, - extra_args=get_opt_flags("unet", precision=self.precision), - base_model_id=self.base_model_id, - model_name="unet", - precision=self.precision, - return_mlir=self.return_mlir, - ) - return shark_unet, unet_mlir - def get_clip(self): class CLIPText(torch.nn.Module): def __init__( @@ -890,68 +734,6 @@ def forward(self, input): ) return shark_clip, clip_mlir - def get_clip_sdxl(self, clip_index=1): - class CLIPText(torch.nn.Module): - def __init__( - self, - model_id=self.model_id, - low_cpu_mem_usage=False, - clip_index=1, - ): - super().__init__() - if clip_index == 1: - self.text_encoder = CLIPTextModel.from_pretrained( - model_id, - subfolder="text_encoder", - low_cpu_mem_usage=low_cpu_mem_usage, - ) - else: - self.text_encoder = ( - CLIPTextModelWithProjection.from_pretrained( - model_id, - subfolder="text_encoder_2", - low_cpu_mem_usage=low_cpu_mem_usage, - ) - ) - - def forward(self, input): - prompt_embeds = self.text_encoder( - input, - output_hidden_states=True, - ) - # We are only ALWAYS interested in the pooled output of the final text encoder - pooled_prompt_embeds = prompt_embeds[0] - prompt_embeds = prompt_embeds.hidden_states[-2] - return prompt_embeds, pooled_prompt_embeds - - clip_model = CLIPText( - low_cpu_mem_usage=self.low_cpu_mem_usage, clip_index=clip_index - ) - if clip_index == 1: - model_name = self.model_name["clip"] - else: - model_name = self.model_name["clip2"] - save_dir = os.path.join(self.sharktank_dir, model_name) - if self.debug: - os.makedirs( - save_dir, - exist_ok=True, - ) - shark_clip, clip_mlir = compile_through_fx( - clip_model, - tuple(self.inputs["clip"]), - extended_model_name=model_name, - debug=self.debug, - generate_vmfb=self.generate_vmfb, - save_dir=save_dir, - extra_args=get_opt_flags("clip", precision="fp32"), - base_model_id=self.base_model_id, - model_name="clip", - precision=self.precision, - return_mlir=self.return_mlir, - ) - return shark_clip, clip_mlir - def process_custom_vae(self): custom_vae = self.custom_vae.lower() if not custom_vae.endswith((".ckpt", ".safetensors")): @@ -984,9 +766,7 @@ def process_custom_vae(self): } return vae_dict - def compile_unet_variants(self, model, use_large=False, base_model=""): - if self.is_sdxl: - return self.get_unet_sdxl() + def compile_unet_variants(self, model, use_large=False): if model == "unet": if self.is_upscaler: return self.get_unet_upscaler(use_large=use_large) @@ -1028,22 +808,6 @@ def clip(self): except Exception as e: sys.exit(e) - def sdxl_clip(self): - try: - self.inputs["clip"] = self.get_input_info_for( - base_models["sdxl_clip"] - ) - compiled_clip, clip_mlir = self.get_clip_sdxl(clip_index=1) - compiled_clip2, clip_mlir2 = self.get_clip_sdxl(clip_index=2) - - check_compilation(compiled_clip, "Clip") - check_compilation(compiled_clip, "Clip2") - if self.return_mlir: - return clip_mlir, clip_mlir2 - return compiled_clip, compiled_clip2 - except Exception as e: - sys.exit(e) - def unet(self, use_large=False): try: model = "stencil_unet" if self.use_stencil is not None else "unet" @@ -1055,7 +819,7 @@ def unet(self, use_large=False): unet_inputs[self.base_model_id] ) compiled_unet, unet_mlir = self.compile_unet_variants( - model, use_large=use_large, base_model=self.base_model_id + model, use_large=use_large ) else: for model_id in unet_inputs: @@ -1066,7 +830,7 @@ def unet(self, use_large=False): try: compiled_unet, unet_mlir = self.compile_unet_variants( - model, use_large=use_large, base_model=model_id + model, use_large=use_large ) except Exception as e: print(e) @@ -1105,10 +869,7 @@ def vae(self): is_base_vae = self.base_vae if self.is_upscaler: self.base_vae = True - if self.is_sdxl: - compiled_vae, vae_mlir = self.get_vae_sdxl() - else: - compiled_vae, vae_mlir = self.get_vae() + compiled_vae, vae_mlir = self.get_vae() self.base_vae = is_base_vae check_compilation(compiled_vae, "Vae") diff --git a/apps/stable_diffusion/src/models/opt_params.py b/apps/stable_diffusion/src/models/opt_params.py index 5dd59b006e..3706a4978d 100644 --- a/apps/stable_diffusion/src/models/opt_params.py +++ b/apps/stable_diffusion/src/models/opt_params.py @@ -123,8 +123,8 @@ def get_clip(): return get_shark_model(bucket, model_name, iree_flags) -def get_tokenizer(subfolder="tokenizer"): +def get_tokenizer(): tokenizer = CLIPTokenizer.from_pretrained( - args.hf_model_id, subfolder=subfolder + args.hf_model_id, subfolder="tokenizer" ) return tokenizer diff --git a/apps/stable_diffusion/src/pipelines/__init__.py b/apps/stable_diffusion/src/pipelines/__init__.py index d65921c781..79d4122c2a 100644 --- a/apps/stable_diffusion/src/pipelines/__init__.py +++ b/apps/stable_diffusion/src/pipelines/__init__.py @@ -1,9 +1,6 @@ from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_txt2img import ( Text2ImagePipeline, ) -from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_txt2img_sdxl import ( - Text2ImageSDXLPipeline, -) from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_img2img import ( Image2ImagePipeline, ) diff --git a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_txt2img_sdxl.py b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_txt2img_sdxl.py deleted file mode 100644 index 6f7146dd5f..0000000000 --- a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_txt2img_sdxl.py +++ /dev/null @@ -1,212 +0,0 @@ -import torch -import numpy as np -from random import randint -from typing import Union -from diffusers import ( - DDIMScheduler, - PNDMScheduler, - LMSDiscreteScheduler, - KDPM2DiscreteScheduler, - EulerDiscreteScheduler, - EulerAncestralDiscreteScheduler, - DPMSolverMultistepScheduler, - DEISMultistepScheduler, - DDPMScheduler, - DPMSolverSinglestepScheduler, - KDPM2AncestralDiscreteScheduler, - HeunDiscreteScheduler, -) -from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler -from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import ( - StableDiffusionPipeline, -) -from apps.stable_diffusion.src.models import SharkifyStableDiffusionModel -from transformers.utils import logging - -logger = logging.get_logger(__name__) - - -class Text2ImageSDXLPipeline(StableDiffusionPipeline): - def __init__( - self, - scheduler: Union[ - DDIMScheduler, - PNDMScheduler, - LMSDiscreteScheduler, - KDPM2DiscreteScheduler, - EulerDiscreteScheduler, - EulerAncestralDiscreteScheduler, - DPMSolverMultistepScheduler, - SharkEulerDiscreteScheduler, - DEISMultistepScheduler, - DDPMScheduler, - DPMSolverSinglestepScheduler, - KDPM2AncestralDiscreteScheduler, - HeunDiscreteScheduler, - ], - sd_model: SharkifyStableDiffusionModel, - import_mlir: bool, - use_lora: str, - ondemand: bool, - ): - super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand) - - def prepare_latents( - self, - batch_size, - height, - width, - generator, - num_inference_steps, - dtype, - ): - latents = torch.randn( - ( - batch_size, - 4, - height // 8, - width // 8, - ), - generator=generator, - dtype=torch.float32, - ).to(dtype) - - self.scheduler.set_timesteps(num_inference_steps) - self.scheduler.is_scale_input_called = True - latents = latents * self.scheduler.init_noise_sigma - return latents - - def _get_add_time_ids( - self, original_size, crops_coords_top_left, target_size, dtype - ): - add_time_ids = list( - original_size + crops_coords_top_left + target_size - ) - - # self.unet.config.addition_time_embed_dim IS 256. - # self.text_encoder_2.config.projection_dim IS 1280. - passed_add_embed_dim = 256 * len(add_time_ids) + 1280 - expected_add_embed_dim = 2816 - # self.unet.add_embedding.linear_1.in_features IS 2816. - - if expected_add_embed_dim != passed_add_embed_dim: - raise ValueError( - f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." - ) - - add_time_ids = torch.tensor([add_time_ids], dtype=dtype) - return add_time_ids - - def generate_images( - self, - prompts, - neg_prompts, - batch_size, - height, - width, - num_inference_steps, - guidance_scale, - seed, - max_length, - dtype, - use_base_vae, - cpu_scheduling, - max_embeddings_multiples, - ): - # prompts and negative prompts must be a list. - if isinstance(prompts, str): - prompts = [prompts] - - if isinstance(neg_prompts, str): - neg_prompts = [neg_prompts] - - prompts = prompts * batch_size - neg_prompts = neg_prompts * batch_size - - # seed generator to create the inital latent noise. Also handle out of range seeds. - # TODO: Wouldn't it be preferable to just report an error instead of modifying the seed on the fly? - uint32_info = np.iinfo(np.uint32) - uint32_min, uint32_max = uint32_info.min, uint32_info.max - if seed < uint32_min or seed >= uint32_max: - seed = randint(uint32_min, uint32_max) - generator = torch.manual_seed(seed) - - # Get initial latents. - init_latents = self.prepare_latents( - batch_size=batch_size, - height=height, - width=width, - generator=generator, - num_inference_steps=num_inference_steps, - dtype=dtype, - ) - - # Get text embeddings. - ( - prompt_embeds, - negative_prompt_embeds, - pooled_prompt_embeds, - negative_pooled_prompt_embeds, - ) = self.encode_prompt_sdxl( - prompt=prompts, - num_images_per_prompt=1, - do_classifier_free_guidance=True, - negative_prompt=neg_prompts, - ) - - # Prepare timesteps. - self.scheduler.set_timesteps(num_inference_steps) - - timesteps = self.scheduler.timesteps - - # Prepare added time ids & embeddings. - original_size = (height, width) - target_size = (height, width) - crops_coords_top_left = (0, 0) - add_text_embeds = pooled_prompt_embeds - add_time_ids = self._get_add_time_ids( - original_size, - crops_coords_top_left, - target_size, - dtype=prompt_embeds.dtype, - ) - - prompt_embeds = torch.cat( - [negative_prompt_embeds, prompt_embeds], dim=0 - ) - add_text_embeds = torch.cat( - [negative_pooled_prompt_embeds, add_text_embeds], dim=0 - ) - add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0) - - prompt_embeds = prompt_embeds - add_text_embeds = add_text_embeds.to(dtype) - add_time_ids = add_time_ids.repeat(batch_size * 1, 1) - - # guidance scale as a float32 tensor. - guidance_scale = torch.tensor(guidance_scale).to(dtype) - prompt_embeds = prompt_embeds.to(dtype) - add_time_ids = add_time_ids.to(dtype) - - # Get Image latents. - latents = self.produce_img_latents_sdxl( - init_latents, - timesteps, - add_text_embeds, - add_time_ids, - prompt_embeds, - cpu_scheduling, - guidance_scale, - dtype, - ) - - # Img latents -> PIL images. - all_imgs = [] - self.load_vae() - for i in range(0, latents.shape[0], batch_size): - imgs = self.decode_latents_sdxl(latents[i : i + batch_size]) - all_imgs.extend(imgs) - if self.ondemand: - self.unload_vae() - - return all_imgs diff --git a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_utils.py b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_utils.py index 59e1ed3132..dd81f55341 100644 --- a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_utils.py +++ b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_utils.py @@ -33,7 +33,6 @@ end_profiling, ) import sys -from typing import List, Optional SD_STATE_IDLE = "idle" SD_STATE_CANCEL = "cancel" @@ -64,7 +63,6 @@ def __init__( ): self.vae = None self.text_encoder = None - self.text_encoder_2 = None self.unet = None self.unet_512 = None self.model_max_length = 77 @@ -108,34 +106,6 @@ def unload_clip(self): del self.text_encoder self.text_encoder = None - def load_clip_sdxl(self): - if self.text_encoder and self.text_encoder_2: - return - - if self.import_mlir or self.use_lora: - if not self.import_mlir: - print( - "Warning: LoRA provided but import_mlir not specified. " - "Importing MLIR anyways." - ) - self.text_encoder, self.text_encoder_2 = self.sd_model.sdxl_clip() - else: - try: - # TODO: Fix this for SDXL - self.text_encoder = get_clip() - except Exception as e: - print(e) - print("download pipeline failed, falling back to import_mlir") - ( - self.text_encoder, - self.text_encoder_2, - ) = self.sd_model.sdxl_clip() - - def unload_clip_sdxl(self): - del self.text_encoder, self.text_encoder_2 - self.text_encoder = None - self.text_encoder_2 = None - def load_unet(self): if self.unet is not None: return @@ -190,177 +160,6 @@ def unload_vae(self): del self.vae self.vae = None - def encode_prompt_sdxl( - self, - prompt: str, - num_images_per_prompt: int = 1, - do_classifier_free_guidance: bool = True, - negative_prompt: Optional[str] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - pooled_prompt_embeds: Optional[torch.FloatTensor] = None, - negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, - ): - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - # Define tokenizers and text encoders - self.tokenizer_2 = get_tokenizer("tokenizer_2") - self.load_clip_sdxl() - tokenizers = ( - [self.tokenizer, self.tokenizer_2] - if self.tokenizer is not None - else [self.tokenizer_2] - ) - text_encoders = ( - [self.text_encoder, self.text_encoder_2] - if self.text_encoder is not None - else [self.text_encoder_2] - ) - - # textual inversion: procecss multi-vector tokens if necessary - prompt_embeds_list = [] - prompts = [prompt, prompt] - for prompt, tokenizer, text_encoder in zip( - prompts, tokenizers, text_encoders - ): - text_inputs = tokenizer( - prompt, - padding="max_length", - max_length=tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - - text_input_ids = text_inputs.input_ids - untruncated_ids = tokenizer( - prompt, padding="longest", return_tensors="pt" - ).input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[ - -1 - ] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = tokenizer.batch_decode( - untruncated_ids[:, tokenizer.model_max_length - 1 : -1] - ) - print( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {tokenizer.model_max_length} tokens: {removed_text}" - ) - - text_encoder_output = text_encoder("forward", (text_input_ids,)) - prompt_embeds = torch.from_numpy(text_encoder_output[0]) - pooled_prompt_embeds = torch.from_numpy(text_encoder_output[1]) - - prompt_embeds_list.append(prompt_embeds) - - prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) - - # get unconditional embeddings for classifier free guidance - zero_out_negative_prompt = ( - negative_prompt is None - and self.config.force_zeros_for_empty_prompt - ) - if ( - do_classifier_free_guidance - and negative_prompt_embeds is None - and zero_out_negative_prompt - ): - negative_prompt_embeds = torch.zeros_like(prompt_embeds) - negative_pooled_prompt_embeds = torch.zeros_like( - pooled_prompt_embeds - ) - elif do_classifier_free_guidance and negative_prompt_embeds is None: - negative_prompt = negative_prompt or "" - negative_prompt_2 = negative_prompt - - uncond_tokens: List[str] - if prompt is not None and type(prompt) is not type( - negative_prompt - ): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif isinstance(negative_prompt, str): - uncond_tokens = [negative_prompt, negative_prompt_2] - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." - ) - else: - uncond_tokens = [negative_prompt, negative_prompt_2] - - negative_prompt_embeds_list = [] - for negative_prompt, tokenizer, text_encoder in zip( - uncond_tokens, tokenizers, text_encoders - ): - max_length = prompt_embeds.shape[1] - uncond_input = tokenizer( - negative_prompt, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="pt", - ) - text_encoder_output = text_encoder( - "forward", (uncond_input.input_ids,) - ) - negative_prompt_embeds = torch.from_numpy( - text_encoder_output[0] - ) - negative_pooled_prompt_embeds = torch.from_numpy( - text_encoder_output[1] - ) - - negative_prompt_embeds_list.append(negative_prompt_embeds) - - negative_prompt_embeds = torch.concat( - negative_prompt_embeds_list, dim=-1 - ) - - if self.ondemand: - self.unload_clip_sdxl() - - # TODO: Look into dtype for text_encoder_2! - prompt_embeds = prompt_embeds.to(dtype=torch.float32) - bs_embed, seq_len, _ = prompt_embeds.shape - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view( - bs_embed * num_images_per_prompt, seq_len, -1 - ) - - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = negative_prompt_embeds.shape[1] - negative_prompt_embeds = negative_prompt_embeds.to(dtype=torch.float32) - negative_prompt_embeds = negative_prompt_embeds.repeat( - 1, num_images_per_prompt, 1 - ) - negative_prompt_embeds = negative_prompt_embeds.view( - batch_size * num_images_per_prompt, seq_len, -1 - ) - - pooled_prompt_embeds = pooled_prompt_embeds.repeat( - 1, num_images_per_prompt - ).view(bs_embed * num_images_per_prompt, -1) - negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat( - 1, num_images_per_prompt - ).view(bs_embed * num_images_per_prompt, -1) - - return ( - prompt_embeds, - negative_prompt_embeds, - pooled_prompt_embeds, - negative_pooled_prompt_embeds, - ) - def encode_prompts(self, prompts, neg_prompts, max_length): # Tokenize text and get embeddings text_input = self.tokenizer( @@ -507,69 +306,6 @@ def produce_img_latents( all_latents = torch.cat(latent_history, dim=0) return all_latents - def produce_img_latents_sdxl( - self, - latents, - total_timesteps, - add_text_embeds, - add_time_ids, - prompt_embeds, - cpu_scheduling, - guidance_scale, - dtype, - ): - self.status = SD_STATE_IDLE - step_time_sum = 0 - extra_step_kwargs = {"generator": None} - self.load_unet() - for i, t in tqdm(enumerate(total_timesteps)): - step_start_time = time.time() - timestep = torch.tensor([t]).to(dtype).detach().numpy() - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) - - latent_model_input = self.scheduler.scale_model_input( - latent_model_input, t - ).to(dtype) - - noise_pred = self.unet( - "forward", - ( - latent_model_input, - timestep, - prompt_embeds, - add_text_embeds, - add_time_ids, - guidance_scale, - ), - send_to_host=False, - ) - latents = self.scheduler.step( - noise_pred, t, latents, **extra_step_kwargs, return_dict=False - )[0] - - step_time = (time.time() - step_start_time) * 1000 - step_time_sum += step_time - - if self.status == SD_STATE_CANCEL: - break - if self.ondemand: - self.unload_unet() - avg_step_time = step_time_sum / len(total_timesteps) - self.log += f"\nAverage step time: {avg_step_time}ms/it" - - return latents - - def decode_latents_sdxl(self, latents): - latents = latents.to(torch.float32) - images = self.vae("forward", (latents,)) - images = (torch.from_numpy(images) / 2 + 0.5).clamp(0, 1) - images = images.cpu().permute(0, 2, 3, 1).float().numpy() - images = (images * 255).round().astype("uint8") - pil_images = [Image.fromarray(image[:, :, :3]) for image in images] - - return pil_images - @classmethod def from_pretrained( cls, @@ -619,7 +355,6 @@ def from_pretrained( "OutpaintPipeline", ] is_upscaler = cls.__name__ in ["UpscalerPipeline"] - is_sdxl = cls.__name__ in ["Text2ImageSDXLPipeline"] sd_model = SharkifyStableDiffusionModel( model_id, @@ -636,7 +371,6 @@ def from_pretrained( debug=debug, is_inpaint=is_inpaint, is_upscaler=is_upscaler, - is_sdxl=is_sdxl, use_stencil=use_stencil, use_lora=use_lora, use_quantize=use_quantize, diff --git a/apps/stable_diffusion/src/utils/resources/base_model.json b/apps/stable_diffusion/src/utils/resources/base_model.json index cede8e0e6e..3666119b97 100644 --- a/apps/stable_diffusion/src/utils/resources/base_model.json +++ b/apps/stable_diffusion/src/utils/resources/base_model.json @@ -8,15 +8,6 @@ "dtype":"i64" } }, - "sdxl_clip": { - "token" : { - "shape" : [ - "1*batch_size", - "max_len" - ], - "dtype":"i64" - } - }, "vae_encode": { "image" : { "shape" : [ @@ -188,49 +179,6 @@ "shape": [2], "dtype": "i64" } - }, - "stabilityai/stable-diffusion-xl-base-1.0": { - "latents": { - "shape": [ - "2*batch_size", - 4, - "height", - "width" - ], - "dtype": "f32" - }, - "timesteps": { - "shape": [ - 1 - ], - "dtype": "f32" - }, - "prompt_embeds": { - "shape": [ - "2*batch_size", - "max_len", - 2048 - ], - "dtype": "f32" - }, - "text_embeds": { - "shape": [ - "2*batch_size", - 1280 - ], - "dtype": "f32" - }, - "time_ids": { - "shape": [ - "2*batch_size", - 6 - ], - "dtype": "f32" - }, - "guidance_scale": { - "shape": 1, - "dtype": "f32" - } } }, "stencil_adaptor": { diff --git a/apps/stable_diffusion/src/utils/stable_args.py b/apps/stable_diffusion/src/utils/stable_args.py index 82da4a34f8..bb5773421a 100644 --- a/apps/stable_diffusion/src/utils/stable_args.py +++ b/apps/stable_diffusion/src/utils/stable_args.py @@ -83,7 +83,7 @@ def is_valid_file(arg): "--height", type=int, default=512, - choices=range(128, 1025, 8), + choices=range(128, 769, 8), help="The height of the output image.", ) @@ -91,7 +91,7 @@ def is_valid_file(arg): "--width", type=int, default=512, - choices=range(128, 1025, 8), + choices=range(128, 769, 8), help="The width of the output image.", ) diff --git a/apps/stable_diffusion/web/ui/txt2img_ui.py b/apps/stable_diffusion/web/ui/txt2img_ui.py index 81da402bbd..ba2f8c3497 100644 --- a/apps/stable_diffusion/web/ui/txt2img_ui.py +++ b/apps/stable_diffusion/web/ui/txt2img_ui.py @@ -22,7 +22,6 @@ from apps.stable_diffusion.src import ( args, Text2ImagePipeline, - Text2ImageSDXLPipeline, get_schedulers, set_init_device_flags, utils, @@ -160,37 +159,8 @@ def txt2img_inf( ) global_obj.set_schedulers(get_schedulers(model_id)) scheduler_obj = global_obj.get_scheduler(scheduler) - if height == 1024: - assert ( - width == 1024 - ), "currently we support only 1024x1024 image size via SDXL" - assert precision == "fp16", "currently we support fp16 for SDXL" - # For SDXL we set max_length as 77. - max_length = 77 - txt2img_obj = Text2ImageSDXLPipeline.from_pretrained( - scheduler=scheduler_obj, - import_mlir=args.import_mlir, - model_id=args.hf_model_id, - ckpt_loc=args.ckpt_loc, - precision=precision, - max_length=max_length, - batch_size=batch_size, - height=height, - width=width, - use_base_vae=args.use_base_vae, - use_tuned=args.use_tuned, - custom_vae=args.custom_vae, - low_cpu_mem_usage=args.low_cpu_mem_usage, - debug=args.import_debug if args.import_mlir else False, - use_lora=args.use_lora, - use_quantize=args.use_quantize, - ondemand=args.ondemand, - ) - else: - assert ( - height <= 768 and width <= 768 - ), "height/width not in supported range" - txt2img_obj = Text2ImagePipeline.from_pretrained( + global_obj.set_sd_obj( + Text2ImagePipeline.from_pretrained( scheduler=scheduler_obj, import_mlir=args.import_mlir, model_id=args.hf_model_id, @@ -198,18 +168,17 @@ def txt2img_inf( precision=args.precision, max_length=args.max_length, batch_size=args.batch_size, - height=height, - width=width, + height=args.height, + width=args.width, use_base_vae=args.use_base_vae, use_tuned=args.use_tuned, custom_vae=args.custom_vae, low_cpu_mem_usage=args.low_cpu_mem_usage, debug=args.import_debug if args.import_mlir else False, use_lora=args.use_lora, - use_quantize=args.use_quantize, ondemand=args.ondemand, ) - global_obj.set_sd_obj(txt2img_obj) + ) global_obj.set_sd_scheduler(scheduler) @@ -533,15 +502,15 @@ def txt2img_api( ) with gr.Row(): height = gr.Slider( - 128, - 1024, + 384, + 768, value=args.height, step=8, label="Height", ) width = gr.Slider( - 128, - 1024, + 384, + 768, value=args.width, step=8, label="Width",