Skip to content

Commit

Permalink
(Shark 1.0) UI/SD UX improvements for SDXL (#2057)
Browse files Browse the repository at this point in the history
* SDXL Tab
  * Filter VAEs in dropdown in the same manner as models
  * Set default VAE selection to 'madebyollin/sdxl-vae-fp16-fix'
  * Set default image size to 768x768 to match current Vulkan constraints
* SharkifySDModel Base Unet Model Determination
  * Always use the model_to_run as the base model for unet, if it is in
base_model.json, instead of potentially trying to compile for other base
models.
  * Allow SharkSDPipelines to define a 'favor_base_models' @classmethod,
answering a list of sane base model names for the pipeline. Exclude base
models not in that list from compilation attempts when trying to determine
a base unet model.
  *  Add a 'favor_base_models' method for both Normal and SDXL Txt2Img
Pipelines. Define the method as answering 'None' in the base class.
  • Loading branch information
one-lithe-rune authored Jan 6, 2024
1 parent 0a6f6fa commit 7fdd195
Show file tree
Hide file tree
Showing 6 changed files with 58 additions and 6 deletions.
19 changes: 17 additions & 2 deletions apps/stable_diffusion/src/models/model_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ def __init__(
lora_strength: float = 0.75,
use_quantize: str = None,
return_mlir: bool = False,
favored_base_models=None,
):
self.check_params(max_len, width, height)
self.max_len = max_len
Expand Down Expand Up @@ -191,6 +192,7 @@ def __init__(
)

self.model_id = model_id if custom_weights == "" else custom_weights
self.favored_base_models = favored_base_models
self.custom_vae = custom_vae
self.precision = precision
self.base_vae = use_base_vae
Expand Down Expand Up @@ -1288,6 +1290,10 @@ def unet(self, use_large=False):
compiled_unet = None
unet_inputs = base_models[model]

# if the model to run *is* a base model, then we should treat it as such
if self.model_to_run in unet_inputs:
self.base_model_id = self.model_to_run

if self.base_model_id != "":
self.inputs["unet"] = self.get_input_info_for(
unet_inputs[self.base_model_id]
Expand All @@ -1296,7 +1302,16 @@ def unet(self, use_large=False):
model, use_large=use_large, base_model=self.base_model_id
)
else:
for model_id in unet_inputs:
# restrict base models to check if we were given a specific list of valid ones
allowed_base_model_ids = unet_inputs
if self.favored_base_models != None:
allowed_base_model_ids = self.favored_base_models

print(f"self.favored_base_models: {self.favored_base_models}")
print(f"allowed_base_model_ids: {allowed_base_model_ids}")

            # try compiling with each base model until we find one that works (or not)
for model_id in allowed_base_model_ids:
self.base_model_id = model_id
self.inputs["unet"] = self.get_input_info_for(
unet_inputs[model_id]
Expand All @@ -1309,7 +1324,7 @@ def unet(self, use_large=False):
except Exception as e:
print(e)
print(
"Retrying with a different base model configuration"
f"Retrying with a different base model configuration, as {model_id} did not work"
)
continue

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,13 @@ def __init__(
scheduler, sd_model, import_mlir, use_lora, lora_strength, ondemand
)

@classmethod
def favored_base_models(cls, model_id):
    """Answer the ordered list of base model ids to try when determining
    a base unet model for this pipeline.

    The list is fixed regardless of `model_id`; entries are tried in
    order by the unet compilation loop, so earlier entries are favored.
    """
    candidates = (
        "stabilityai/stable-diffusion-2-1",
        "CompVis/stable-diffusion-v1-4",
    )
    return list(candidates)

def prepare_latents(
self,
batch_size,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,19 @@ def __init__(
)
self.is_fp32_vae = is_fp32_vae

@classmethod
def favored_base_models(cls, model_id):
    """Answer the ordered list of base model ids to try when determining
    a base unet model for an SDXL pipeline.

    Both known SDXL base models are always candidates; the order depends
    on `model_id` so the more likely match is compiled first ("turbo"
    models favor sdxl-turbo, everything else favors sdxl-base-1.0).
    """
    turbo = "stabilityai/sdxl-turbo"
    base = "stabilityai/stable-diffusion-xl-base-1.0"
    if "turbo" in model_id:
        return [turbo, base]
    return [base, turbo]

def prepare_latents(
self,
batch_size,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,10 @@ def __init__(
self.unload_unet()
self.tokenizer = get_tokenizer()

@classmethod
def favored_base_models(cls, model_id):
    """Answer the base model ids to consider when determining a base
    unet model, or None to allow all base models as candidates.

    Subclass pipelines override this to restrict the candidate list
    (and may vary it by `model_id`); the base-class default places no
    restriction.
    """
    # @classmethod was missing here: the call site uses
    # cls.favored_base_models(model_id), and the overrides in the
    # pipeline subclasses are declared as classmethods — without the
    # decorator this call would raise TypeError on the base class.
    # None means "all base models are candidates" downstream.
    return None

def load_clip(self):
if self.text_encoder is not None:
return
Expand Down Expand Up @@ -667,6 +671,9 @@ def from_pretrained(
is_upscaler = cls.__name__ in ["UpscalerPipeline"]
is_sdxl = cls.__name__ in ["Text2ImageSDXLPipeline"]

print(f"model_id", model_id)
print(f"ckpt_loc", ckpt_loc)
print(f"favored_base_models:", cls.favored_base_models(model_id))
sd_model = SharkifyStableDiffusionModel(
model_id,
ckpt_loc,
Expand All @@ -687,6 +694,9 @@ def from_pretrained(
use_lora=use_lora,
lora_strength=lora_strength,
use_quantize=use_quantize,
favored_base_models=cls.favored_base_models(
model_id if model_id != "" else ckpt_loc
),
)

if cls.__name__ in ["UpscalerPipeline"]:
Expand Down
11 changes: 7 additions & 4 deletions apps/stable_diffusion/web/ui/txt2img_sdxl_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ def txt2img_sdxl_inf(
if args.hf_model_id
else "stabilityai/stable-diffusion-xl-base-1.0"
)

global_obj.set_schedulers(get_schedulers(model_id))
scheduler_obj = global_obj.get_scheduler(scheduler)
if global_obj.get_cfg_obj().ondemand:
Expand Down Expand Up @@ -280,12 +281,14 @@ def txt2img_sdxl_inf(
label=f"VAE Models",
info=t2i_sdxl_vae_info,
elem_id="custom_model",
value="None",
value="madebyollin/sdxl-vae-fp16-fix",
choices=[
None,
"madebyollin/sdxl-vae-fp16-fix",
]
+ get_custom_model_files("vae"),
+ get_custom_model_files(
"vae", custom_checkpoint_type="sdxl"
),
allow_custom_value=True,
scale=4,
)
Expand Down Expand Up @@ -375,7 +378,7 @@ def txt2img_sdxl_inf(
height = gr.Slider(
512,
1024,
value=1024,
value=768,
step=256,
label="Height",
visible=True,
Expand All @@ -384,7 +387,7 @@ def txt2img_sdxl_inf(
width = gr.Slider(
512,
1024,
value=1024,
value=768,
step=256,
label="Width",
visible=True,
Expand Down
4 changes: 4 additions & 0 deletions apps/stable_diffusion/web/ui/txt2img_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,10 @@ def txt2img_inf(
use_lora=args.use_lora,
lora_strength=args.lora_strength,
ondemand=args.ondemand,
valid_base_models=[
"stabilityai/stable-diffusion-2-1",
"CompVis/stable-diffusion-v1-4",
],
)
)

Expand Down

0 comments on commit 7fdd195

Please sign in to comment.