Skip to content

Commit

Permalink
(SDXL) Fix --ondemand and vae scale factor use, and fix VAE flags.
Browse files Browse the repository at this point in the history
  • Loading branch information
monorimet authored and Abhishek-Varma committed Dec 1, 2023
1 parent 9ba0b9b commit 7a6d096
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 12 deletions.
9 changes: 2 additions & 7 deletions apps/stable_diffusion/src/models/model_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,13 +165,8 @@ def __init__(
self.check_params(max_len, width, height)
self.max_len = max_len
self.is_sdxl = is_sdxl
self.height = height
self.width = width
if is_sdxl:
# We need to scale down the height/width by vae_scale_factor, which
# happens to be 8 in this case.
self.height = height // 8
self.width = width // 8
self.height = height // 8
self.width = width // 8
self.batch_size = batch_size
self.custom_weights = custom_weights.strip()
self.use_quantize = use_quantize
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
end_profiling,
)
import sys
import gc
from typing import List, Optional

SD_STATE_IDLE = "idle"
Expand Down Expand Up @@ -189,6 +190,7 @@ def load_vae(self):
def unload_vae(self):
del self.vae
self.vae = None
gc.collect()

def encode_prompt_sdxl(
self,
Expand Down Expand Up @@ -327,6 +329,7 @@ def encode_prompt_sdxl(

if self.ondemand:
self.unload_clip_sdxl()
gc.collect()

# TODO: Look into dtype for text_encoder_2!
prompt_embeds = prompt_embeds.to(dtype=torch.float32)
Expand Down Expand Up @@ -387,6 +390,7 @@ def encode_prompts(self, prompts, neg_prompts, max_length):
clip_inf_time = (time.time() - clip_inf_start) * 1000
if self.ondemand:
self.unload_clip()
gc.collect()
self.log += f"\nClip Inference time (ms) = {clip_inf_time:.3f}"

return text_embeddings
Expand Down Expand Up @@ -499,6 +503,8 @@ def produce_img_latents(
if self.ondemand:
self.unload_unet()
self.unload_unet_512()
gc.collect()

avg_step_time = step_time_sum / len(total_timesteps)
self.log += f"\nAverage step time: {avg_step_time}ms/it"

Expand Down Expand Up @@ -556,6 +562,8 @@ def produce_img_latents_sdxl(
break
if self.ondemand:
self.unload_unet()
gc.collect()

avg_step_time = step_time_sum / len(total_timesteps)
self.log += f"\nAverage step time: {avg_step_time}ms/it"

Expand Down Expand Up @@ -652,7 +660,6 @@ def from_pretrained(
use_lora,
ondemand,
)

return cls(scheduler, sd_model, import_mlir, use_lora, ondemand)

# #####################################################
Expand Down Expand Up @@ -765,6 +772,7 @@ def encode_prompts_weight(
clip_inf_time = (time.time() - clip_inf_start) * 1000
if self.ondemand:
self.unload_clip()
gc.collect()
self.log += f"\nClip Inference time (ms) = {clip_inf_time:.3f}"

return text_embeddings.numpy()
Expand Down
7 changes: 4 additions & 3 deletions apps/stable_diffusion/src/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,9 +565,10 @@ def get_opt_flags(model, precision="fp16"):
iree_flags += opt_flags[model][is_tuned][precision][
"specified_compilation_flags"
][device]
# Due to lack of support for multi-reduce, we always collapse reduction
# dims before dispatch formation right now.
iree_flags += ["--iree-flow-collapse-reduction-dims"]
if "vae" not in model:
# Due to lack of support for multi-reduce, we always collapse reduction
# dims before dispatch formation right now.
iree_flags += ["--iree-flow-collapse-reduction-dims"]
return iree_flags


Expand Down
4 changes: 3 additions & 1 deletion apps/stable_diffusion/web/ui/txt2img_sdxl_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,8 @@ def txt2img_sdxl_inf(
# For SDXL we set max_length as 77.
print("Setting max_length = 77")
max_length = 77
if global_obj.get_cfg_obj().ondemand:
print("Running txt2img in memory efficient mode.")
txt2img_sdxl_obj = Text2ImageSDXLPipeline.from_pretrained(
scheduler=scheduler_obj,
import_mlir=args.import_mlir,
Expand All @@ -164,7 +166,7 @@ def txt2img_sdxl_inf(
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
use_quantize=args.use_quantize,
ondemand=args.ondemand,
ondemand=global_obj.get_cfg_obj().ondemand,
)
global_obj.set_sd_obj(txt2img_sdxl_obj)

Expand Down

0 comments on commit 7a6d096

Please sign in to comment.