diff --git a/apps/stable_diffusion/src/models/model_wrappers.py b/apps/stable_diffusion/src/models/model_wrappers.py
index 57998de598..5c8478eca5 100644
--- a/apps/stable_diffusion/src/models/model_wrappers.py
+++ b/apps/stable_diffusion/src/models/model_wrappers.py
@@ -162,6 +162,7 @@ def __init__(
         lora_strength: float = 0.75,
         use_quantize: str = None,
         return_mlir: bool = False,
+        favored_base_models=None,
     ):
         self.check_params(max_len, width, height)
         self.max_len = max_len
@@ -191,6 +192,7 @@
         )
 
         self.model_id = model_id if custom_weights == "" else custom_weights
+        self.favored_base_models = favored_base_models
         self.custom_vae = custom_vae
         self.precision = precision
         self.base_vae = use_base_vae
@@ -1288,6 +1290,10 @@ def unet(self, use_large=False):
         compiled_unet = None
         unet_inputs = base_models[model]
+        # if the model to run *is* a base model, then we should treat it as such
+        if self.model_to_run in unet_inputs:
+            self.base_model_id = self.model_to_run
+
         if self.base_model_id != "":
             self.inputs["unet"] = self.get_input_info_for(
                 unet_inputs[self.base_model_id]
             )
@@ -1296,7 +1302,13 @@ model, use_large=use_large, base_model=self.base_model_id
             )
         else:
-            for model_id in unet_inputs:
+            # restrict base models to check if we were given a specific list of valid ones
+            allowed_base_model_ids = unet_inputs
+            if self.favored_base_models is not None:
+                allowed_base_model_ids = self.favored_base_models
+
+            # try compiling with each base model until we find one that works (or not)
+            for model_id in allowed_base_model_ids:
                 self.base_model_id = model_id
                 self.inputs["unet"] = self.get_input_info_for(
                     unet_inputs[model_id]
                 )
@@ -1309,7 +1321,7 @@
             except Exception as e:
                 print(e)
                 print(
-                    "Retrying with a different base model configuration"
+                    f"Retrying with a different base model configuration, as {model_id} did not work"
                 )
                 continue
 
diff --git a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_txt2img.py b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_txt2img.py
index 4872e262bd..70fa314534 100644
--- a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_txt2img.py
+++ b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_txt2img.py
@@ -56,6 +56,13 @@ def __init__(
             scheduler, sd_model, import_mlir, use_lora, lora_strength, ondemand
         )
 
+    @classmethod
+    def favored_base_models(cls, model_id):
+        return [
+            "stabilityai/stable-diffusion-2-1",
+            "CompVis/stable-diffusion-v1-4",
+        ]
+
     def prepare_latents(
         self,
         batch_size,
diff --git a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_txt2img_sdxl.py b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_txt2img_sdxl.py
index 2d52b0b0c9..901d4f505a 100644
--- a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_txt2img_sdxl.py
+++ b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_txt2img_sdxl.py
@@ -60,6 +60,19 @@ def __init__(
         )
         self.is_fp32_vae = is_fp32_vae
 
+    @classmethod
+    def favored_base_models(cls, model_id):
+        if "turbo" in model_id:
+            return [
+                "stabilityai/sdxl-turbo",
+                "stabilityai/stable-diffusion-xl-base-1.0",
+            ]
+        else:
+            return [
+                "stabilityai/stable-diffusion-xl-base-1.0",
+                "stabilityai/sdxl-turbo",
+            ]
+
     def prepare_latents(
         self,
         batch_size,
diff --git a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_utils.py b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_utils.py
index 2144fd85e4..e88c4dd027 100644
--- a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_utils.py
+++ b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_utils.py
@@ -94,6 +94,11 @@
         self.unload_unet()
         self.tokenizer = get_tokenizer()
 
+    @classmethod
+    def favored_base_models(cls, model_id):
+        # all base models can be candidate base models for unet compilation
+        return None
+
     def load_clip(self):
         if self.text_encoder is not None:
             return
@@ -687,6 +692,9 @@
             use_lora=use_lora,
             lora_strength=lora_strength,
             use_quantize=use_quantize,
+            favored_base_models=cls.favored_base_models(
+                model_id if model_id != "" else ckpt_loc
+            ),
         )
 
         if cls.__name__ in ["UpscalerPipeline"]:
diff --git a/apps/stable_diffusion/web/ui/txt2img_sdxl_ui.py b/apps/stable_diffusion/web/ui/txt2img_sdxl_ui.py
index 2e383af729..e91b8e8c11 100644
--- a/apps/stable_diffusion/web/ui/txt2img_sdxl_ui.py
+++ b/apps/stable_diffusion/web/ui/txt2img_sdxl_ui.py
@@ -153,6 +153,7 @@ def txt2img_sdxl_inf(
         if args.hf_model_id
         else "stabilityai/stable-diffusion-xl-base-1.0"
     )
+    global_obj.set_schedulers(get_schedulers(model_id))
     scheduler_obj = global_obj.get_scheduler(scheduler)
 
     if global_obj.get_cfg_obj().ondemand:
@@ -280,12 +281,14 @@
                     label=f"VAE Models",
                     info=t2i_sdxl_vae_info,
                     elem_id="custom_model",
-                    value="None",
+                    value="madebyollin/sdxl-vae-fp16-fix",
                     choices=[
                         None,
                         "madebyollin/sdxl-vae-fp16-fix",
                     ]
-                    + get_custom_model_files("vae"),
+                    + get_custom_model_files(
+                        "vae", custom_checkpoint_type="sdxl"
+                    ),
                     allow_custom_value=True,
                     scale=4,
                 )
@@ -375,7 +378,7 @@
                     height = gr.Slider(
                         512,
                         1024,
-                        value=1024,
+                        value=768,
                         step=256,
                         label="Height",
                         visible=True,
@@ -384,7 +387,7 @@
                     width = gr.Slider(
                         512,
                         1024,
-                        value=1024,
+                        value=768,
                         step=256,
                         label="Width",
                         visible=True,