diff --git a/apps/stable_diffusion/scripts/txt2img_sdxl.py b/apps/stable_diffusion/scripts/txt2img_sdxl.py new file mode 100644 index 0000000000..e930c6052c --- /dev/null +++ b/apps/stable_diffusion/scripts/txt2img_sdxl.py @@ -0,0 +1,96 @@ +import torch +import time +from apps.stable_diffusion.src import ( + args, + Text2ImageSDXLPipeline, + get_schedulers, + set_init_device_flags, + utils, + clear_all, + save_output_img, +) + + +def main(): + if args.clear_all: + clear_all() + + # TODO: prompt_embeds and text_embeds form base_model.json requires fixing + args.precision = "fp16" + args.height = 1024 + args.width = 1024 + args.max_length = 77 + args.scheduler = "DDIM" + print( + "Using default supported configuration for SDXL :-\nprecision=fp16, width*height= 1024*1024, max_length=77 and scheduler=DDIM" + ) + dtype = torch.float32 if args.precision == "fp32" else torch.half + cpu_scheduling = not args.scheduler.startswith("Shark") + set_init_device_flags() + schedulers = get_schedulers(args.hf_model_id) + scheduler_obj = schedulers[args.scheduler] + seed = args.seed + txt2img_obj = Text2ImageSDXLPipeline.from_pretrained( + scheduler=scheduler_obj, + import_mlir=args.import_mlir, + model_id=args.hf_model_id, + ckpt_loc=args.ckpt_loc, + precision=args.precision, + max_length=args.max_length, + batch_size=args.batch_size, + height=args.height, + width=args.width, + use_base_vae=args.use_base_vae, + use_tuned=args.use_tuned, + custom_vae=args.custom_vae, + low_cpu_mem_usage=args.low_cpu_mem_usage, + debug=args.import_debug if args.import_mlir else False, + use_lora=args.use_lora, + use_quantize=args.use_quantize, + ondemand=args.ondemand, + ) + + seeds = utils.batch_seeds(seed, args.batch_count, args.repeatable_seeds) + for current_batch in range(args.batch_count): + start_time = time.time() + generated_imgs = txt2img_obj.generate_images( + args.prompts, + args.negative_prompts, + args.batch_size, + args.height, + args.width, + args.steps, + args.guidance_scale, + seeds[current_batch], + args.max_length, + dtype, + args.use_base_vae, + cpu_scheduling, + args.max_embeddings_multiples, + ) + total_time = time.time() - start_time + text_output = f"prompt={args.prompts}" + text_output += f"\nnegative prompt={args.negative_prompts}" + text_output += ( + f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}" + ) + text_output += f"\nscheduler={args.scheduler}, device={args.device}" + text_output += ( + f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}," + ) + text_output += ( + f"seed={seeds[current_batch]}, size={args.height}x{args.width}" + ) + text_output += ( + f", batch size={args.batch_size}, max_length={args.max_length}" + ) + # TODO: if using --batch_count=x txt2img_obj.log will output on each display every iteration infos from the start + text_output += txt2img_obj.log + text_output += f"\nTotal image generation time: {total_time:.4f}sec" + + save_output_img(generated_imgs[0], seed) + print(text_output) + + +if __name__ == "__main__": + main() diff --git a/apps/stable_diffusion/src/__init__.py b/apps/stable_diffusion/src/__init__.py index a30ee16b32..a40bafb798 100644 --- a/apps/stable_diffusion/src/__init__.py +++ b/apps/stable_diffusion/src/__init__.py @@ -9,6 +9,7 @@ ) from apps.stable_diffusion.src.pipelines import ( Text2ImagePipeline, + Text2ImageSDXLPipeline, Image2ImagePipeline, InpaintPipeline, OutpaintPipeline, diff --git a/apps/stable_diffusion/src/models/model_wrappers.py b/apps/stable_diffusion/src/models/model_wrappers.py index 016bf1787e..357d8e93b6 100644 --- a/apps/stable_diffusion/src/models/model_wrappers.py +++ b/apps/stable_diffusion/src/models/model_wrappers.py @@ -1,5 +1,5 @@ from diffusers import AutoencoderKL, UNet2DConditionModel, ControlNetModel -from transformers import CLIPTextModel +from transformers import CLIPTextModel, CLIPTextModelWithProjection from collections import defaultdict from pathlib import Path import torch @@ -24,6 +24,8 @@ get_stencil_model_id, update_lora_weight, ) +from shark.shark_downloader import download_public_file +from shark.shark_inference import SharkInference # These shapes are parameter dependent. @@ -55,6 +57,10 @@ def replace_shape_str(shape, max_len, width, height, batch_size): new_shape.append(math.ceil(height / div_val)) elif "width" in shape[i]: new_shape.append(math.ceil(width / div_val)) + elif "+" in shape[i]: + # Currently this case only hits for SDXL. So, in case any other + # case requires this operator, change this. + new_shape.append(height + width) else: new_shape.append(shape[i]) return new_shape @@ -67,6 +73,70 @@ def check_compilation(model, model_name): ) +def shark_compile_after_ir( + module_name, + device, + vmfb_path, + precision, + ir_path=None, +): + if ir_path: + print(f"[DEBUG] mlir found at {ir_path.absolute()}") + + module = SharkInference( + mlir_module=ir_path, + device=device, + mlir_dialect="tm_tensor", + ) + print(f"Will get extra flag for {module_name} and precision = {precision}") + path = module.save_module( + vmfb_path.parent.absolute(), + vmfb_path.stem, + extra_args=get_opt_flags(module_name, precision=precision), + ) + print(f"Saved {module_name} vmfb at {path}") + module.load_module(path) + return module + + +def process_vmfb_ir_sdxl(extended_model_name, model_name, device, precision): + name_split = extended_model_name.split("_") + if "vae" in model_name: + name_split[5] = "fp32" + extended_model_name_for_vmfb = "_".join(name_split) + extended_model_name_for_mlir = "_".join(name_split[:-1]) + vmfb_path = Path(extended_model_name_for_vmfb + ".vmfb") + if "vulkan" in device: + _device = args.iree_vulkan_target_triple + _device = _device.replace("-", "_") + vmfb_path = Path(extended_model_name_for_vmfb + f"_{_device}.vmfb") + if vmfb_path.exists(): + shark_module = SharkInference( + None, + device=device, + mlir_dialect="tm_tensor", + ) + print(f"loading existing vmfb from: {vmfb_path}") + shark_module.load_module(vmfb_path, extra_args=[]) + return shark_module, None + mlir_path = Path(extended_model_name_for_mlir + ".mlir") + if not mlir_path.exists(): + print(f"Looking into gs://shark_tank/SDXL/mlir/{mlir_path.name}") + download_public_file( + f"gs://shark_tank/SDXL/mlir/{mlir_path.name}", + mlir_path.absolute(), + single_file=True, + ) + if mlir_path.exists(): + return ( + shark_compile_after_ir( + model_name, device, vmfb_path, precision, mlir_path + ), + None, + ) + return None, None + + class SharkifyStableDiffusionModel: def __init__( self, @@ -86,6 +156,7 @@ def __init__( generate_vmfb: bool = True, is_inpaint: bool = False, is_upscaler: bool = False, + is_sdxl: bool = False, use_stencil: str = None, use_lora: str = "", use_quantize: str = None, @@ -93,8 +164,14 @@ def __init__( ): self.check_params(max_len, width, height) self.max_len = max_len - self.height = height // 8 - self.width = width // 8 + self.is_sdxl = is_sdxl + self.height = height + self.width = width + if is_sdxl: + # We need to scale down the height/width by vae_scale_factor, which + # happens to be 8 in this case. + self.height = height // 8 + self.width = width // 8 self.batch_size = batch_size self.custom_weights = custom_weights.strip() self.use_quantize = use_quantize @@ -175,6 +252,7 @@ def get_extended_name_for_all_model(self): model_name = {} sub_model_list = [ "clip", + "clip2", "unet", "unet512", "stencil_unet", @@ -342,6 +420,76 @@ def forward(self, input): ) return shark_vae, vae_mlir + def get_vae_sdxl(self): + # TODO: Remove this after convergence with shark_tank. This should just be part of + # opt_params.py. + shark_module_or_none = process_vmfb_ir_sdxl( + self.model_name["vae"], "vae", args.device, self.precision + ) + if shark_module_or_none[0]: + return shark_module_or_none + + class VaeModel(torch.nn.Module): + def __init__( + self, + model_id=self.model_id, + base_vae=self.base_vae, + custom_vae=self.custom_vae, + low_cpu_mem_usage=False, + ): + super().__init__() + self.vae = None + if custom_vae == "": + self.vae = AutoencoderKL.from_pretrained( + model_id, + subfolder="vae", + low_cpu_mem_usage=low_cpu_mem_usage, + ) + elif not isinstance(custom_vae, dict): + self.vae = AutoencoderKL.from_pretrained( + custom_vae, + subfolder="vae", + low_cpu_mem_usage=low_cpu_mem_usage, + ) + else: + self.vae = AutoencoderKL.from_pretrained( + model_id, + subfolder="vae", + low_cpu_mem_usage=low_cpu_mem_usage, + ) + self.vae.load_state_dict(custom_vae) + + def forward(self, latents): + image = self.vae.decode(latents / 0.13025, return_dict=False)[ + 0 + ] + return image + + vae = VaeModel(low_cpu_mem_usage=self.low_cpu_mem_usage) + inputs = tuple(self.inputs["vae"]) + # Make sure the VAE is in float32 mode, as it overflows in float16 as per SDXL + # pipeline. + is_f16 = False + save_dir = os.path.join(self.sharktank_dir, self.model_name["vae"]) + if self.debug: + os.makedirs(save_dir, exist_ok=True) + shark_vae, vae_mlir = compile_through_fx( + vae, + inputs, + is_f16=is_f16, + use_tuned=self.use_tuned, + extended_model_name=self.model_name["vae"], + debug=self.debug, + generate_vmfb=self.generate_vmfb, + save_dir=save_dir, + extra_args=get_opt_flags("vae", precision=self.precision), + base_model_id=self.base_model_id, + model_name="vae", + precision=self.precision, + return_mlir=self.return_mlir, + ) + return shark_vae, vae_mlir + def get_controlled_unet(self, use_large=False): class ControlledUnetModel(torch.nn.Module): def __init__( @@ -688,6 +836,93 @@ def forward(self, latent, timestep, text_embedding, noise_level): ) return shark_unet, unet_mlir + def get_unet_sdxl(self): + # TODO: Remove this after convergence with shark_tank. This should just be part of + # opt_params.py. + shark_module_or_none = process_vmfb_ir_sdxl( + self.model_name["unet"], "unet", args.device, self.precision + ) + if shark_module_or_none[0]: + return shark_module_or_none + + class UnetModel(torch.nn.Module): + def __init__( + self, + model_id=self.model_id, + low_cpu_mem_usage=False, + ): + super().__init__() + self.unet = UNet2DConditionModel.from_pretrained( + model_id, + subfolder="unet", + low_cpu_mem_usage=low_cpu_mem_usage, + ) + if ( + args.attention_slicing is not None + and args.attention_slicing != "none" + ): + if args.attention_slicing.isdigit(): + self.unet.set_attention_slice( + int(args.attention_slicing) + ) + else: + self.unet.set_attention_slice(args.attention_slicing) + + def forward( + self, + latent, + timestep, + prompt_embeds, + text_embeds, + time_ids, + guidance_scale, + ): + added_cond_kwargs = { + "text_embeds": text_embeds, + "time_ids": time_ids, + } + noise_pred = self.unet.forward( + latent, + timestep, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=None, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + return noise_pred + + unet = UnetModel(low_cpu_mem_usage=self.low_cpu_mem_usage) + is_f16 = True if self.precision == "fp16" else False + inputs = tuple(self.inputs["unet"]) + save_dir = os.path.join(self.sharktank_dir, self.model_name["unet"]) + input_mask = [True, True, True, True, True, True] + if self.debug: + os.makedirs( + save_dir, + exist_ok=True, + ) + shark_unet, unet_mlir = compile_through_fx( + unet, + inputs, + extended_model_name=self.model_name["unet"], + is_f16=is_f16, + f16_input_mask=input_mask, + use_tuned=self.use_tuned, + debug=self.debug, + generate_vmfb=self.generate_vmfb, + save_dir=save_dir, + extra_args=get_opt_flags("unet", precision=self.precision), + base_model_id=self.base_model_id, + model_name="unet", + precision=self.precision, + return_mlir=self.return_mlir, + ) + return shark_unet, unet_mlir + def get_clip(self): class CLIPText(torch.nn.Module): def __init__( @@ -735,6 +970,78 @@ def forward(self, input): ) return shark_clip, clip_mlir + def get_clip_sdxl(self, clip_index=1): + if clip_index == 1: + extended_model_name = self.model_name["clip"] + model_name = "clip" + else: + extended_model_name = self.model_name["clip2"] + model_name = "clip2" + # TODO: Remove this after convergence with shark_tank. This should just be part of + # opt_params.py. + shark_module_or_none = process_vmfb_ir_sdxl( + extended_model_name, f"clip", args.device, self.precision + ) + if shark_module_or_none[0]: + return shark_module_or_none + + class CLIPText(torch.nn.Module): + def __init__( + self, + model_id=self.model_id, + low_cpu_mem_usage=False, + clip_index=1, + ): + super().__init__() + if clip_index == 1: + self.text_encoder = CLIPTextModel.from_pretrained( + model_id, + subfolder="text_encoder", + low_cpu_mem_usage=low_cpu_mem_usage, + ) + else: + self.text_encoder = ( + CLIPTextModelWithProjection.from_pretrained( + model_id, + subfolder="text_encoder_2", + low_cpu_mem_usage=low_cpu_mem_usage, + ) + ) + + def forward(self, input): + prompt_embeds = self.text_encoder( + input, + output_hidden_states=True, + ) + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds.hidden_states[-2] + return prompt_embeds, pooled_prompt_embeds + + clip_model = CLIPText( + low_cpu_mem_usage=self.low_cpu_mem_usage, clip_index=clip_index + ) + save_dir = os.path.join(self.sharktank_dir, extended_model_name) + if self.debug: + os.makedirs( + save_dir, + exist_ok=True, + ) + shark_clip, clip_mlir = compile_through_fx( + clip_model, + tuple(self.inputs["clip"]), + extended_model_name=extended_model_name, + debug=self.debug, + generate_vmfb=self.generate_vmfb, + save_dir=save_dir, + extra_args=get_opt_flags("clip", precision="fp32"), + base_model_id=self.base_model_id, + model_name="clip", + precision=self.precision, + return_mlir=self.return_mlir, + ) + return shark_clip, clip_mlir + def process_custom_vae(self): custom_vae = self.custom_vae.lower() if not custom_vae.endswith((".ckpt", ".safetensors")): @@ -767,7 +1074,9 @@ def process_custom_vae(self): } return vae_dict - def compile_unet_variants(self, model, use_large=False): + def compile_unet_variants(self, model, use_large=False, base_model=""): + if self.is_sdxl: + return self.get_unet_sdxl() if model == "unet": if self.is_upscaler: return self.get_unet_upscaler(use_large=use_large) @@ -809,6 +1118,22 @@ def clip(self): except Exception as e: sys.exit(e) + def sdxl_clip(self): + try: + self.inputs["clip"] = self.get_input_info_for( + base_models["sdxl_clip"] + ) + compiled_clip, clip_mlir = self.get_clip_sdxl(clip_index=1) + compiled_clip2, clip_mlir2 = self.get_clip_sdxl(clip_index=2) + + check_compilation(compiled_clip, "Clip") + check_compilation(compiled_clip, "Clip2") + if self.return_mlir: + return clip_mlir, clip_mlir2 + return compiled_clip, compiled_clip2 + except Exception as e: + sys.exit(e) + def unet(self, use_large=False): try: model = "stencil_unet" if self.use_stencil is not None else "unet" @@ -820,7 +1145,7 @@ def unet(self, use_large=False): unet_inputs[self.base_model_id] ) compiled_unet, unet_mlir = self.compile_unet_variants( - model, use_large=use_large + model, use_large=use_large, base_model=self.base_model_id ) else: for model_id in unet_inputs: @@ -831,7 +1156,7 @@ def unet(self, use_large=False): try: compiled_unet, unet_mlir = self.compile_unet_variants( - model, use_large=use_large + model, use_large=use_large, base_model=model_id ) except Exception as e: print(e) @@ -870,7 +1195,10 @@ def vae(self): is_base_vae = self.base_vae if self.is_upscaler: self.base_vae = True - compiled_vae, vae_mlir = self.get_vae() + if self.is_sdxl: + compiled_vae, vae_mlir = self.get_vae_sdxl() + else: + compiled_vae, vae_mlir = self.get_vae() self.base_vae = is_base_vae check_compilation(compiled_vae, "Vae") diff --git a/apps/stable_diffusion/src/models/opt_params.py b/apps/stable_diffusion/src/models/opt_params.py index 3706a4978d..5dd59b006e 100644 --- a/apps/stable_diffusion/src/models/opt_params.py +++ b/apps/stable_diffusion/src/models/opt_params.py @@ -123,8 +123,8 @@ def get_clip(): return get_shark_model(bucket, model_name, iree_flags) -def get_tokenizer(): +def get_tokenizer(subfolder="tokenizer"): tokenizer = CLIPTokenizer.from_pretrained( - args.hf_model_id, subfolder="tokenizer" + args.hf_model_id, subfolder=subfolder ) return tokenizer diff --git a/apps/stable_diffusion/src/pipelines/__init__.py b/apps/stable_diffusion/src/pipelines/__init__.py index 79d4122c2a..d65921c781 100644 --- a/apps/stable_diffusion/src/pipelines/__init__.py +++ b/apps/stable_diffusion/src/pipelines/__init__.py @@ -1,6 +1,9 @@ from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_txt2img import ( Text2ImagePipeline, ) +from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_txt2img_sdxl import ( + Text2ImageSDXLPipeline, +) from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_img2img import ( Image2ImagePipeline, ) diff --git a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_txt2img_sdxl.py b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_txt2img_sdxl.py new file mode 100644 index 0000000000..42df60af9a --- /dev/null +++ b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_txt2img_sdxl.py @@ -0,0 +1,214 @@ +import torch +import numpy as np +from random import randint +from typing import Union +from diffusers import ( + DDIMScheduler, + PNDMScheduler, + LMSDiscreteScheduler, + KDPM2DiscreteScheduler, + EulerDiscreteScheduler, + EulerAncestralDiscreteScheduler, + DPMSolverMultistepScheduler, + DEISMultistepScheduler, + DDPMScheduler, + DPMSolverSinglestepScheduler, + KDPM2AncestralDiscreteScheduler, + HeunDiscreteScheduler, +) +from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler +from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import ( + StableDiffusionPipeline, +) +from apps.stable_diffusion.src.models import SharkifyStableDiffusionModel +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + +class Text2ImageSDXLPipeline(StableDiffusionPipeline): + def __init__( + self, + scheduler: Union[ + DDIMScheduler, + PNDMScheduler, + LMSDiscreteScheduler, + KDPM2DiscreteScheduler, + EulerDiscreteScheduler, + EulerAncestralDiscreteScheduler, + DPMSolverMultistepScheduler, + SharkEulerDiscreteScheduler, + DEISMultistepScheduler, + DDPMScheduler, + DPMSolverSinglestepScheduler, + KDPM2AncestralDiscreteScheduler, + HeunDiscreteScheduler, + ], + sd_model: SharkifyStableDiffusionModel, + import_mlir: bool, + use_lora: str, + ondemand: bool, + ): + super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand) + + def prepare_latents( + self, + batch_size, + height, + width, + generator, + num_inference_steps, + dtype, + ): + latents = torch.randn( + ( + batch_size, + 4, + height // 8, + width // 8, + ), + generator=generator, + dtype=torch.float32, + ).to(dtype) + + self.scheduler.set_timesteps(num_inference_steps) + self.scheduler.is_scale_input_called = True + latents = latents * self.scheduler.init_noise_sigma + return latents + + def _get_add_time_ids( + self, original_size, crops_coords_top_left, target_size, dtype + ): + add_time_ids = list( + original_size + crops_coords_top_left + target_size + ) + + # self.unet.config.addition_time_embed_dim IS 256. + # self.text_encoder_2.config.projection_dim IS 1280. + passed_add_embed_dim = 256 * len(add_time_ids) + 1280 + expected_add_embed_dim = 2816 + # self.unet.add_embedding.linear_1.in_features IS 2816. + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + def generate_images( + self, + prompts, + neg_prompts, + batch_size, + height, + width, + num_inference_steps, + guidance_scale, + seed, + max_length, + dtype, + use_base_vae, + cpu_scheduling, + max_embeddings_multiples, + ): + # prompts and negative prompts must be a list. + if isinstance(prompts, str): + prompts = [prompts] + + if isinstance(neg_prompts, str): + neg_prompts = [neg_prompts] + + prompts = prompts * batch_size + neg_prompts = neg_prompts * batch_size + + # seed generator to create the inital latent noise. Also handle out of range seeds. + # TODO: Wouldn't it be preferable to just report an error instead of modifying the seed on the fly? + uint32_info = np.iinfo(np.uint32) + uint32_min, uint32_max = uint32_info.min, uint32_info.max + if seed < uint32_min or seed >= uint32_max: + seed = randint(uint32_min, uint32_max) + generator = torch.manual_seed(seed) + + # Get initial latents. + init_latents = self.prepare_latents( + batch_size=batch_size, + height=height, + width=width, + generator=generator, + num_inference_steps=num_inference_steps, + dtype=dtype, + ) + + # Get text embeddings. + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt_sdxl( + prompt=prompts, + num_images_per_prompt=1, + do_classifier_free_guidance=True, + negative_prompt=neg_prompts, + ) + + # Prepare timesteps. + self.scheduler.set_timesteps(num_inference_steps) + + timesteps = self.scheduler.timesteps + + # Prepare added time ids & embeddings. + original_size = (height, width) + target_size = (height, width) + crops_coords_top_left = (0, 0) + add_text_embeds = pooled_prompt_embeds + add_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, + ) + + prompt_embeds = torch.cat( + [negative_prompt_embeds, prompt_embeds], dim=0 + ) + add_text_embeds = torch.cat( + [negative_pooled_prompt_embeds, add_text_embeds], dim=0 + ) + add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds + add_text_embeds = add_text_embeds.to(dtype) + add_time_ids = add_time_ids.repeat(batch_size * 1, 1) + + # guidance scale as a float32 tensor. + guidance_scale = torch.tensor(guidance_scale).to(dtype) + prompt_embeds = prompt_embeds.to(dtype) + add_time_ids = add_time_ids.to(dtype) + + # Get Image latents. + latents = self.produce_img_latents_sdxl( + init_latents, + timesteps, + add_text_embeds, + add_time_ids, + prompt_embeds, + cpu_scheduling, + guidance_scale, + dtype, + ) + + # Img latents -> PIL images. + all_imgs = [] + self.load_vae() + # imgs = self.decode_latents_sdxl(None) + # all_imgs.extend(imgs) + for i in range(0, latents.shape[0], batch_size): + imgs = self.decode_latents_sdxl(latents[i : i + batch_size]) + all_imgs.extend(imgs) + if self.ondemand: + self.unload_vae() + + return all_imgs diff --git a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_utils.py b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_utils.py index dd81f55341..3266c73fb6 100644 --- a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_utils.py +++ b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_utils.py @@ -33,6 +33,7 @@ end_profiling, ) import sys +from typing import List, Optional SD_STATE_IDLE = "idle" SD_STATE_CANCEL = "cancel" @@ -63,6 +64,7 @@ def __init__( ): self.vae = None self.text_encoder = None + self.text_encoder_2 = None self.unet = None self.unet_512 = None self.model_max_length = 77 @@ -106,6 +108,34 @@ def unload_clip(self): del self.text_encoder self.text_encoder = None + def load_clip_sdxl(self): + if self.text_encoder and self.text_encoder_2: + return + + if self.import_mlir or self.use_lora: + if not self.import_mlir: + print( + "Warning: LoRA provided but import_mlir not specified. " + "Importing MLIR anyways." + ) + self.text_encoder, self.text_encoder_2 = self.sd_model.sdxl_clip() + else: + try: + # TODO: Fix this for SDXL + self.text_encoder = get_clip() + except Exception as e: + print(e) + print("download pipeline failed, falling back to import_mlir") + ( + self.text_encoder, + self.text_encoder_2, + ) = self.sd_model.sdxl_clip() + + def unload_clip_sdxl(self): + del self.text_encoder, self.text_encoder_2 + self.text_encoder = None + self.text_encoder_2 = None + def load_unet(self): if self.unet is not None: return @@ -160,6 +190,177 @@ def unload_vae(self): del self.vae self.vae = None + def encode_prompt_sdxl( + self, + prompt: str, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[str] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + ): + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + self.tokenizer_2 = get_tokenizer("tokenizer_2") + self.load_clip_sdxl() + tokenizers = ( + [self.tokenizer, self.tokenizer_2] + if self.tokenizer is not None + else [self.tokenizer_2] + ) + text_encoders = ( + [self.text_encoder, self.text_encoder_2] + if self.text_encoder is not None + else [self.text_encoder_2] + ) + + # textual inversion: procecss multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt] + for prompt, tokenizer, text_encoder in zip( + prompts, tokenizers, text_encoders + ): + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer( + prompt, padding="longest", return_tensors="pt" + ).input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[ + -1 + ] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = tokenizer.batch_decode( + untruncated_ids[:, tokenizer.model_max_length - 1 : -1] + ) + print( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + text_encoder_output = text_encoder("forward", (text_input_ids,)) + prompt_embeds = torch.from_numpy(text_encoder_output[0]) + pooled_prompt_embeds = torch.from_numpy(text_encoder_output[1]) + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = ( + negative_prompt is None + and self.config.force_zeros_for_empty_prompt + ) + if ( + do_classifier_free_guidance + and negative_prompt_embeds is None + and zero_out_negative_prompt + ): + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like( + pooled_prompt_embeds + ) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt + + uncond_tokens: List[str] + if prompt is not None and type(prompt) is not type( + negative_prompt + ): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt, negative_prompt_2] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip( + uncond_tokens, tokenizers, text_encoders + ): + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + text_encoder_output = text_encoder( + "forward", (uncond_input.input_ids,) + ) + negative_prompt_embeds = torch.from_numpy( + text_encoder_output[0] + ) + negative_pooled_prompt_embeds = torch.from_numpy( + text_encoder_output[1] + ) + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat( + negative_prompt_embeds_list, dim=-1 + ) + + if self.ondemand: + self.unload_clip_sdxl() + + # TODO: Look into dtype for text_encoder_2! + prompt_embeds = prompt_embeds.to(dtype=torch.float32) + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view( + bs_embed * num_images_per_prompt, seq_len, -1 + ) + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.to(dtype=torch.float32) + negative_prompt_embeds = negative_prompt_embeds.repeat( + 1, num_images_per_prompt, 1 + ) + negative_prompt_embeds = negative_prompt_embeds.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat( + 1, num_images_per_prompt + ).view(bs_embed * num_images_per_prompt, -1) + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat( + 1, num_images_per_prompt + ).view(bs_embed * num_images_per_prompt, -1) + + return ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) + def encode_prompts(self, prompts, neg_prompts, max_length): # Tokenize text and get embeddings text_input = self.tokenizer( @@ -306,6 +507,70 @@ def produce_img_latents( all_latents = torch.cat(latent_history, dim=0) return all_latents + def produce_img_latents_sdxl( + self, + latents, + total_timesteps, + add_text_embeds, + add_time_ids, + prompt_embeds, + cpu_scheduling, + guidance_scale, + dtype, + ): + # return None + self.status = SD_STATE_IDLE + step_time_sum = 0 + extra_step_kwargs = {"generator": None} + self.load_unet() + for i, t in tqdm(enumerate(total_timesteps)): + step_start_time = time.time() + timestep = torch.tensor([t]).to(dtype).detach().numpy() + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) + + latent_model_input = self.scheduler.scale_model_input( + latent_model_input, t + ).to(dtype) + + noise_pred = self.unet( + "forward", + ( + latent_model_input, + timestep, + prompt_embeds, + add_text_embeds, + add_time_ids, + guidance_scale, + ), + send_to_host=False, + ) + latents = self.scheduler.step( + noise_pred, t, latents, **extra_step_kwargs, return_dict=False + )[0] + + step_time = (time.time() - step_start_time) * 1000 + step_time_sum += step_time + + if self.status == SD_STATE_CANCEL: + break + if self.ondemand: + self.unload_unet() + avg_step_time = step_time_sum / len(total_timesteps) + self.log += f"\nAverage step time: {avg_step_time}ms/it" + + return latents + + def decode_latents_sdxl(self, latents): + latents = latents.to(torch.float32) + images = self.vae("forward", (latents,)) + images = (torch.from_numpy(images) / 2 + 0.5).clamp(0, 1) + images = images.cpu().permute(0, 2, 3, 1).float().numpy() + images = (images * 255).round().astype("uint8") + pil_images = [Image.fromarray(image[:, :, :3]) for image in images] + + return pil_images + @classmethod def from_pretrained( cls, @@ -355,6 +620,7 @@ def from_pretrained( "OutpaintPipeline", ] is_upscaler = cls.__name__ in ["UpscalerPipeline"] + is_sdxl = cls.__name__ in ["Text2ImageSDXLPipeline"] sd_model = SharkifyStableDiffusionModel( model_id, @@ -371,6 +637,7 @@ def from_pretrained( debug=debug, is_inpaint=is_inpaint, is_upscaler=is_upscaler, + is_sdxl=is_sdxl, use_stencil=use_stencil, use_lora=use_lora, use_quantize=use_quantize, diff --git a/apps/stable_diffusion/src/utils/resources/base_model.json b/apps/stable_diffusion/src/utils/resources/base_model.json index 3666119b97..cede8e0e6e 100644 --- a/apps/stable_diffusion/src/utils/resources/base_model.json +++ b/apps/stable_diffusion/src/utils/resources/base_model.json @@ -8,6 +8,15 @@ "dtype":"i64" } }, + "sdxl_clip": { + "token" : { + "shape" : [ + "1*batch_size", + "max_len" + ], + "dtype":"i64" + } + }, "vae_encode": { "image" : { "shape" : [ @@ -179,6 +188,49 @@ "shape": [2], "dtype": "i64" } + }, + "stabilityai/stable-diffusion-xl-base-1.0": { + "latents": { + "shape": [ + "2*batch_size", + 4, + "height", + "width" + ], + "dtype": "f32" + }, + "timesteps": { + "shape": [ + 1 + ], + "dtype": "f32" + }, + "prompt_embeds": { + "shape": [ + "2*batch_size", + "max_len", + 2048 + ], + "dtype": "f32" + }, + "text_embeds": { + "shape": [ + "2*batch_size", + 1280 + ], + "dtype": "f32" + }, + "time_ids": { + "shape": [ + "2*batch_size", + 6 + ], + "dtype": "f32" + }, + "guidance_scale": { + "shape": 1, + "dtype": "f32" + } } }, "stencil_adaptor": { diff --git a/apps/stable_diffusion/src/utils/stable_args.py b/apps/stable_diffusion/src/utils/stable_args.py index caf6aedf9d..2a1d45e7a8 100644 --- a/apps/stable_diffusion/src/utils/stable_args.py +++ b/apps/stable_diffusion/src/utils/stable_args.py @@ -85,7 +85,7 @@ def is_valid_file(arg): "--height", type=int, default=512, - choices=range(128, 769, 8), + choices=range(128, 1025, 8), help="The height of the output image.", ) @@ -93,7 +93,7 @@ def is_valid_file(arg): "--width", type=int, default=512, - choices=range(128, 769, 8), + choices=range(128, 1025, 8), help="The width of the output image.", ) diff --git a/apps/stable_diffusion/web/index.py b/apps/stable_diffusion/web/index.py index 32d40e3dad..b696a222b6 100644 --- a/apps/stable_diffusion/web/index.py +++ b/apps/stable_diffusion/web/index.py @@ -109,6 +109,12 @@ def resource_path(relative_path): txt2img_sendto_inpaint, txt2img_sendto_outpaint, txt2img_sendto_upscaler, + # SDXL + txt2img_sdxl_inf, + txt2img_sdxl_web, + txt2img_sdxl_custom_model, + txt2img_sdxl_gallery, + txt2img_sdxl_status, # h2ogpt_upload, # h2ogpt_web, img2img_web, @@ -253,6 +259,8 @@ def register_outputgallery_button(button, selectedid, inputs, outputs): # h2ogpt_upload.render() # with gr.TabItem(label="DocuChat(Experimental)", id=12): # h2ogpt_web.render() + with gr.TabItem(label="Text-to-Image-SDXL (Experimental)", id=13): + txt2img_sdxl_web.render() actual_port = app.usable_port() if actual_port != args.server_port: diff --git a/apps/stable_diffusion/web/ui/__init__.py b/apps/stable_diffusion/web/ui/__init__.py index 10cef374a1..7ba21b73ef 100644 --- a/apps/stable_diffusion/web/ui/__init__.py +++ b/apps/stable_diffusion/web/ui/__init__.py @@ -10,6 +10,13 @@ txt2img_sendto_outpaint, txt2img_sendto_upscaler, ) +from apps.stable_diffusion.web.ui.txt2img_sdxl_ui import ( + txt2img_sdxl_inf, + txt2img_sdxl_web, + txt2img_sdxl_custom_model, + txt2img_sdxl_gallery, + txt2img_sdxl_status, +) from apps.stable_diffusion.web.ui.img2img_ui import ( img2img_inf, img2img_web, diff --git a/apps/stable_diffusion/web/ui/txt2img_sdxl_ui.py b/apps/stable_diffusion/web/ui/txt2img_sdxl_ui.py new file mode 100644 index 0000000000..299d6c2884 --- /dev/null +++ b/apps/stable_diffusion/web/ui/txt2img_sdxl_ui.py @@ -0,0 +1,456 @@ +import os +import torch +import time +import sys +import gradio as gr +from PIL import Image +from math import ceil +from apps.stable_diffusion.web.ui.utils import ( + available_devices, + nodlogo_loc, + get_custom_model_path, + get_custom_model_files, + scheduler_list, + predefined_models, + cancel_sd, +) +from apps.stable_diffusion.web.utils.metadata import import_png_metadata +from apps.stable_diffusion.web.utils.common_label_calc import status_label +from apps.stable_diffusion.src import ( + args, + Text2ImageSDXLPipeline, + get_schedulers, + set_init_device_flags, + utils, + save_output_img, + prompt_examples, + Image2ImagePipeline, +) +from apps.stable_diffusion.src.utils import ( + get_generated_imgs_path, + get_generation_text_info, +) + +# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir. +init_iree_vulkan_target_triple = args.iree_vulkan_target_triple +init_iree_metal_target_platform = args.iree_metal_target_platform +init_use_tuned = args.use_tuned +init_import_mlir = args.import_mlir + + +def txt2img_sdxl_inf( + prompt: str, + negative_prompt: str, + height: int, + width: int, + steps: int, + guidance_scale: float, + seed: str | int, + batch_count: int, + batch_size: int, + scheduler: str, + model_id: str, + precision: str, + device: str, + max_length: int, + save_metadata_to_json: bool, + save_metadata_to_png: bool, + ondemand: bool, + repeatable_seeds: bool, +): + if precision != "fp16": + print("currently we support fp16 for SDXL") + precision = "fp16" + from apps.stable_diffusion.web.ui.utils import ( + get_custom_model_pathfile, + get_custom_vae_or_lora_weights, + Config, + ) + import apps.stable_diffusion.web.utils.global_obj as global_obj + from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import ( + SD_STATE_CANCEL, + ) + + args.prompts = [prompt] + args.negative_prompts = [negative_prompt] + args.guidance_scale = guidance_scale + args.steps = steps + args.scheduler = scheduler + args.ondemand = ondemand + + # set ckpt_loc and hf_model_id. + args.ckpt_loc = "" + args.hf_model_id = "" + args.custom_vae = "" + + # .safetensor or .chkpt on the custom model path + if model_id in get_custom_model_files(): + args.ckpt_loc = get_custom_model_pathfile(model_id) + # civitai download + elif "civitai" in model_id: + args.ckpt_loc = model_id + # either predefined or huggingface + else: + args.hf_model_id = model_id + + # if custom_vae != "None": + # args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae") + + args.save_metadata_to_json = save_metadata_to_json + args.write_metadata_to_png = save_metadata_to_png + + args.use_lora = "" + + dtype = torch.float32 if precision == "fp32" else torch.half + cpu_scheduling = not scheduler.startswith("Shark") + new_config_obj = Config( + "txt2img_sdxl", + args.hf_model_id, + args.ckpt_loc, + args.custom_vae, + precision, + batch_size, + max_length, + height, + width, + device, + use_lora=args.use_lora, + use_stencil=None, + ondemand=ondemand, + ) + if ( + not global_obj.get_sd_obj() + or global_obj.get_cfg_obj() != new_config_obj + ): + global_obj.clear_cache() + global_obj.set_cfg_obj(new_config_obj) + args.precision = precision + args.batch_count = batch_count + args.batch_size = batch_size + args.max_length = max_length + args.height = height + args.width = width + args.device = device.split("=>", 1)[1].strip() + args.iree_vulkan_target_triple = init_iree_vulkan_target_triple + args.iree_metal_target_platform = init_iree_metal_target_platform + args.use_tuned = init_use_tuned + args.import_mlir = init_import_mlir + args.img_path = None + set_init_device_flags() + model_id = ( + args.hf_model_id + if args.hf_model_id + else "stabilityai/stable-diffusion-xl-base-1.0" + ) + global_obj.set_schedulers(get_schedulers(model_id)) + scheduler_obj = global_obj.get_scheduler(scheduler) + # For SDXL we set max_length as 77. + print("Setting max_length = 77") + max_length = 77 + txt2img_sdxl_obj = Text2ImageSDXLPipeline.from_pretrained( + scheduler=scheduler_obj, + import_mlir=args.import_mlir, + model_id=args.hf_model_id, + ckpt_loc=args.ckpt_loc, + precision=precision, + max_length=max_length, + batch_size=batch_size, + height=height, + width=width, + use_base_vae=args.use_base_vae, + use_tuned=args.use_tuned, + custom_vae=args.custom_vae, + low_cpu_mem_usage=args.low_cpu_mem_usage, + debug=args.import_debug if args.import_mlir else False, + use_lora=args.use_lora, + use_quantize=args.use_quantize, + ondemand=args.ondemand, + ) + global_obj.set_sd_obj(txt2img_sdxl_obj) + + global_obj.set_sd_scheduler(scheduler) + + start_time = time.time() + global_obj.get_sd_obj().log = "" + generated_imgs = [] + text_output = "" + try: + seeds = utils.batch_seeds(seed, batch_count, repeatable_seeds) + except TypeError as error: + raise gr.Error(str(error)) from None + + for current_batch in range(batch_count): + out_imgs = global_obj.get_sd_obj().generate_images( + prompt, + negative_prompt, + batch_size, + height, + width, + steps, + guidance_scale, + seeds[current_batch], + args.max_length, + dtype, + args.use_base_vae, + cpu_scheduling, + args.max_embeddings_multiples, + ) + + total_time = time.time() - start_time + text_output = get_generation_text_info( + seeds[: current_batch + 1], device + ) + text_output += "\n" + global_obj.get_sd_obj().log + text_output += f"\nTotal image(s) generation time: {total_time:.4f}sec" + + if global_obj.get_sd_status() == SD_STATE_CANCEL: + break + else: + save_output_img(out_imgs[0], seeds[current_batch]) + generated_imgs.extend(out_imgs) + yield generated_imgs, text_output, status_label( + "Text-to-Image-SDXL", + current_batch + 1, + batch_count, + batch_size, + ) + + return generated_imgs, text_output, "" + + +with gr.Blocks(title="Text-to-Image-SDXL") as txt2img_sdxl_web: + with gr.Row(elem_id="ui_title"): + nod_logo = Image.open(nodlogo_loc) + with gr.Row(): + with gr.Column(scale=1, elem_id="demo_title_outer"): + gr.Image( + value=nod_logo, + show_label=False, + interactive=False, + elem_id="top_logo", + width=150, + height=50, + ) + with gr.Row(elem_id="ui_body"): + with gr.Row(): + with gr.Column(scale=1, min_width=600): + with gr.Row(): + with gr.Column(scale=10): + with gr.Row(): + t2i_model_info = f"Custom Model Path: {str(get_custom_model_path())}" + txt2img_sdxl_custom_model = gr.Dropdown( + label=f"Models", + info="Select, or enter HuggingFace Model ID or Civitai model download URL", + elem_id="custom_model", + value=os.path.basename(args.ckpt_loc) + if args.ckpt_loc + else "stabilityai/stable-diffusion-xl-base-1.0", + choices=[ + "stabilityai/stable-diffusion-xl-base-1.0" + ], + allow_custom_value=True, + scale=2, + ) + + with gr.Group(elem_id="prompt_box_outer"): + prompt = gr.Textbox( + label="Prompt", + value=args.prompts[0], + lines=2, + elem_id="prompt_box", + ) + negative_prompt = gr.Textbox( + label="Negative Prompt", + value=args.negative_prompts[0], + lines=2, + elem_id="negative_prompt_box", + ) + + with gr.Accordion(label="Advanced Options", open=False): + with gr.Row(): + scheduler = gr.Dropdown( + elem_id="scheduler", + label="Scheduler", + value="DDIM", + choices=["DDIM"], + allow_custom_value=True, + visible=False, + ) + with gr.Column(): + save_metadata_to_png = gr.Checkbox( + label="Save prompt information to PNG", + value=args.write_metadata_to_png, + interactive=True, + ) + save_metadata_to_json = gr.Checkbox( + label="Save prompt information to JSON file", + value=args.save_metadata_to_json, + interactive=True, + ) + with gr.Row(): + height = gr.Slider( + 1024, + value=1024, + step=8, + label="Height", + visible=False, + ) + width = gr.Slider( + 1024, + value=1024, + step=8, + label="Width", + visible=False, + ) + precision = gr.Radio( + label="Precision", + value="fp16", + choices=[ + "fp16", + "fp32", + ], + visible=False, + ) + max_length = gr.Radio( + label="Max Length", + value=args.max_length, + choices=[ + 64, + 77, + ], + visible=False, + ) + with gr.Row(): + with gr.Column(scale=3): + steps = gr.Slider( + 1, 100, value=args.steps, step=1, label="Steps" + ) + with gr.Column(scale=3): + guidance_scale = gr.Slider( + 0, + 50, + value=args.guidance_scale, + step=0.1, + label="CFG Scale", + ) + ondemand = gr.Checkbox( + value=args.ondemand, + label="Low VRAM", + interactive=True, + ) + with gr.Row(): + with gr.Column(scale=3): + batch_count = gr.Slider( + 1, + 100, + value=args.batch_count, + step=1, + label="Batch Count", + interactive=True, + ) + with gr.Column(scale=3): + batch_size = gr.Slider( + 1, + 4, + value=args.batch_size, + step=1, + label="Batch Size", + interactive=True, + ) + repeatable_seeds = gr.Checkbox( + args.repeatable_seeds, + label="Repeatable Seeds", + ) + with gr.Row(): + seed = gr.Textbox( + value=args.seed, + label="Seed", + info="An integer or a JSON list of integers, -1 for random", + ) + device = gr.Dropdown( + elem_id="device", + label="Device", + value=available_devices[0], + choices=available_devices, + allow_custom_value=True, + ) + with gr.Accordion(label="Prompt Examples!", open=False): + ex = gr.Examples( + examples=prompt_examples, + inputs=prompt, + cache_examples=False, + elem_id="prompt_examples", + ) + + with gr.Column(scale=1, min_width=600): + with gr.Group(): + txt2img_sdxl_gallery = gr.Gallery( + label="Generated images", + show_label=False, + elem_id="gallery", + columns=[2], + object_fit="contain", + ) + std_output = gr.Textbox( + value=f"{t2i_model_info}\n" + f"Images will be saved at " + f"{get_generated_imgs_path()}", + lines=1, + elem_id="std_output", + show_label=False, + ) + txt2img_sdxl_status = gr.Textbox(visible=False) + with gr.Row(): + stable_diffusion = gr.Button("Generate Image(s)") + random_seed = gr.Button("Randomize Seed") + random_seed.click( + lambda: -1, + inputs=[], + outputs=[seed], + queue=False, + ) + stop_batch = gr.Button("Stop Batch") + with gr.Row(): + blank_thing_for_row = None + + kwargs = dict( + fn=txt2img_sdxl_inf, + inputs=[ + prompt, + negative_prompt, + height, + width, + steps, + guidance_scale, + seed, + batch_count, + batch_size, + scheduler, + txt2img_sdxl_custom_model, + precision, + device, + max_length, + save_metadata_to_json, + save_metadata_to_png, + ondemand, + repeatable_seeds, + ], + outputs=[txt2img_sdxl_gallery, std_output, txt2img_sdxl_status], + show_progress="minimal" if args.progress_bar else "none", + ) + + status_kwargs = dict( + fn=lambda bc, bs: status_label("Text-to-Image-SDXL", 0, bc, bs), + inputs=[batch_count, batch_size], + outputs=txt2img_sdxl_status, + ) + + prompt_submit = prompt.submit(**status_kwargs).then(**kwargs) + neg_prompt_submit = negative_prompt.submit(**status_kwargs).then( + **kwargs + ) + generate_click = stable_diffusion.click(**status_kwargs).then(**kwargs) + stop_batch.click( + fn=cancel_sd, + cancels=[prompt_submit, neg_prompt_submit, generate_click], + )