[SDXL] Add SDXL pipeline to SHARK
-- This commit adds the SDXL (Stable Diffusion XL) text-to-image pipeline to SHARK.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
Abhishek-Varma committed Nov 6, 2023
1 parent 5001db3 commit adf077d
Showing 10 changed files with 996 additions and 40 deletions.
71 changes: 51 additions & 20 deletions apps/stable_diffusion/scripts/txt2img.py
@@ -1,9 +1,9 @@
 import torch
-import transformers
 import time
 from apps.stable_diffusion.src import (
     args,
     Text2ImagePipeline,
+    Text2ImageSDXLPipeline,
     get_schedulers,
     set_init_device_flags,
     utils,
@@ -16,31 +16,62 @@ def main():
     if args.clear_all:
         clear_all()
 
+    # TODO: prompt_embeds and text_embeds from base_model.json require fixing
     dtype = torch.float32 if args.precision == "fp32" else torch.half
     cpu_scheduling = not args.scheduler.startswith("Shark")
     set_init_device_flags()
     schedulers = get_schedulers(args.hf_model_id)
     scheduler_obj = schedulers[args.scheduler]
     seed = args.seed
-    txt2img_obj = Text2ImagePipeline.from_pretrained(
-        scheduler=scheduler_obj,
-        import_mlir=args.import_mlir,
-        model_id=args.hf_model_id,
-        ckpt_loc=args.ckpt_loc,
-        precision=args.precision,
-        max_length=args.max_length,
-        batch_size=args.batch_size,
-        height=args.height,
-        width=args.width,
-        use_base_vae=args.use_base_vae,
-        use_tuned=args.use_tuned,
-        custom_vae=args.custom_vae,
-        low_cpu_mem_usage=args.low_cpu_mem_usage,
-        debug=args.import_debug if args.import_mlir else False,
-        use_lora=args.use_lora,
-        use_quantize=args.use_quantize,
-        ondemand=args.ondemand,
-    )
+    if args.height == 1024:
+        assert (
+            args.width == 1024
+        ), "currently we support only 1024x1024 image size via SDXL"
+        assert args.precision == "fp16", "currently we support fp16 for SDXL"
+        # For SDXL we set max_length as 77.
+        args.max_length = 77
+        txt2img_obj = Text2ImageSDXLPipeline.from_pretrained(
+            scheduler=scheduler_obj,
+            import_mlir=args.import_mlir,
+            model_id=args.hf_model_id,
+            ckpt_loc=args.ckpt_loc,
+            precision=args.precision,
+            max_length=args.max_length,
+            batch_size=args.batch_size,
+            height=args.height,
+            width=args.width,
+            use_base_vae=args.use_base_vae,
+            use_tuned=args.use_tuned,
+            custom_vae=args.custom_vae,
+            low_cpu_mem_usage=args.low_cpu_mem_usage,
+            debug=args.import_debug if args.import_mlir else False,
+            use_lora=args.use_lora,
+            use_quantize=args.use_quantize,
+            ondemand=args.ondemand,
+        )
+    else:
+        assert (
+            args.height <= 768 and args.width <= 768
+        ), "height/width not in supported range"
+        txt2img_obj = Text2ImagePipeline.from_pretrained(
+            scheduler=scheduler_obj,
+            import_mlir=args.import_mlir,
+            model_id=args.hf_model_id,
+            ckpt_loc=args.ckpt_loc,
+            precision=args.precision,
+            max_length=args.max_length,
+            batch_size=args.batch_size,
+            height=args.height,
+            width=args.width,
+            use_base_vae=args.use_base_vae,
+            use_tuned=args.use_tuned,
+            custom_vae=args.custom_vae,
+            low_cpu_mem_usage=args.low_cpu_mem_usage,
+            debug=args.import_debug if args.import_mlir else False,
+            use_lora=args.use_lora,
+            use_quantize=args.use_quantize,
+            ondemand=args.ondemand,
+        )
 
     seeds = utils.batch_seeds(seed, args.batch_count, args.repeatable_seeds)
     for current_batch in range(args.batch_count):
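
To summarize the branch added above, here is a condensed sketch of the selection rule; the helper select_pipeline_class is illustrative and not part of this commit. SDXL serves only 1024x1024 fp16 requests, with max_length forced to 77 by the caller, while all other requests fall back to Text2ImagePipeline and are limited to 768x768.

# Illustrative condensation of the dispatch added in txt2img.py above; the helper
# select_pipeline_class is hypothetical and does not appear in the commit.
from apps.stable_diffusion.src import Text2ImagePipeline, Text2ImageSDXLPipeline

def select_pipeline_class(height, width, precision):
    # SDXL path: only 1024x1024 at fp16 is supported; the caller also forces max_length to 77.
    if height == 1024:
        assert width == 1024, "currently we support only 1024x1024 image size via SDXL"
        assert precision == "fp16", "currently we support fp16 for SDXL"
        return Text2ImageSDXLPipeline
    # Default path: the existing pipeline, limited to 768x768 or smaller.
    assert height <= 768 and width <= 768, "height/width not in supported range"
    return Text2ImagePipeline
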
1 change: 1 addition & 0 deletions apps/stable_diffusion/src/__init__.py
@@ -9,6 +9,7 @@
 )
 from apps.stable_diffusion.src.pipelines import (
     Text2ImagePipeline,
+    Text2ImageSDXLPipeline,
     Image2ImagePipeline,
     InpaintPipeline,
     OutpaintPipeline,
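
With the export above, the SDXL pipeline is importable alongside the existing pipelines; a minimal, illustrative import (not part of this commit) mirrors what the updated txt2img.py does.

# Illustration only: downstream scripts import both classes from the package root
# and choose between them based on the requested resolution and precision.
from apps.stable_diffusion.src import Text2ImagePipeline, Text2ImageSDXLPipeline
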
