Skip to content

Commit

Permalink
[SDXL] Add SDXL pipeline to SHARK
Browse files Browse the repository at this point in the history
-- This commit adds SDXL pipeline to SHARK.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
  • Loading branch information
Abhishek-Varma committed Nov 22, 2023
1 parent d051c3a commit 560350a
Show file tree
Hide file tree
Showing 12 changed files with 1,443 additions and 11 deletions.
96 changes: 96 additions & 0 deletions apps/stable_diffusion/scripts/txt2img_sdxl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import torch
import time
from apps.stable_diffusion.src import (
args,
Text2ImageSDXLPipeline,
get_schedulers,
set_init_device_flags,
utils,
clear_all,
save_output_img,
)


def main():
if args.clear_all:
clear_all()

# TODO: prompt_embeds and text_embeds form base_model.json requires fixing
args.precision = "fp16"
args.height = 1024
args.width = 1024
args.max_length = 77
args.scheduler = "DDIM"
print(
"Using default supported configuration for SDXL :-\nprecision=fp16, width*height= 1024*1024, max_length=77 and scheduler=DDIM"
)
dtype = torch.float32 if args.precision == "fp32" else torch.half
cpu_scheduling = not args.scheduler.startswith("Shark")
set_init_device_flags()
schedulers = get_schedulers(args.hf_model_id)
scheduler_obj = schedulers[args.scheduler]
seed = args.seed
txt2img_obj = Text2ImageSDXLPipeline.from_pretrained(
scheduler=scheduler_obj,
import_mlir=args.import_mlir,
model_id=args.hf_model_id,
ckpt_loc=args.ckpt_loc,
precision=args.precision,
max_length=args.max_length,
batch_size=args.batch_size,
height=args.height,
width=args.width,
use_base_vae=args.use_base_vae,
use_tuned=args.use_tuned,
custom_vae=args.custom_vae,
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
use_quantize=args.use_quantize,
ondemand=args.ondemand,
)

seeds = utils.batch_seeds(seed, args.batch_count, args.repeatable_seeds)
for current_batch in range(args.batch_count):
start_time = time.time()
generated_imgs = txt2img_obj.generate_images(
args.prompts,
args.negative_prompts,
args.batch_size,
args.height,
args.width,
args.steps,
args.guidance_scale,
seeds[current_batch],
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
)
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += (
f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
)
text_output += f"\nscheduler={args.scheduler}, device={args.device}"
text_output += (
f"\nsteps={args.steps}, guidance_scale={args.guidance_scale},"
)
text_output += (
f"seed={seeds[current_batch]}, size={args.height}x{args.width}"
)
text_output += (
f", batch size={args.batch_size}, max_length={args.max_length}"
)
# TODO: if using --batch_count=x txt2img_obj.log will output on each display every iteration infos from the start
text_output += txt2img_obj.log
text_output += f"\nTotal image generation time: {total_time:.4f}sec"

save_output_img(generated_imgs[0], seed)
print(text_output)


if __name__ == "__main__":
main()
1 change: 1 addition & 0 deletions apps/stable_diffusion/src/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
)
from apps.stable_diffusion.src.pipelines import (
Text2ImagePipeline,
Text2ImageSDXLPipeline,
Image2ImagePipeline,
InpaintPipeline,
OutpaintPipeline,
Expand Down
Loading

0 comments on commit 560350a

Please sign in to comment.