gpetters-amd committed Sep 26, 2024
1 parent 5388722 commit 38b6248
Showing 1 changed file with 87 additions and 129 deletions: apps/shark_studio/api/shark_api.py
@@ -1,51 +1,26 @@
 
 # Internal API
+pipelines = {
+    "sd1.5": ("", None),
+    "sd2": ("", None),
+    "sdxl": ("", None),
+    "sd3": ("", None),
+}
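Note on the shape of this new cache: each value is an (args, pipeline) tuple, which is why the code further down destructures it as (existing_args, pipeline) = pipelines[base_model]; the initial ("", None) entries mean "not yet initialized".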

-# Used for filenames as well as the key for the global cache
-def safe_name():
-    pass
-
-def local_path():
-    pass
-
-def generate_sd_vmfb(
-    model: str,
+# Used for filenames as well as the key for the global cache
+def safe_name(
+    model_name: str,
     height: int,
     width: int,
     steps: int,
     strength: float,
     guidance_scale: float,
-    batch_size: int = 1,
-    base_model_id: str,
-    precision: str,
-    controlled: bool,
-    **kwargs,
+    batch_size: int,
 ):
     pass
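safe_name is still a stub at this point. Below is a minimal sketch of what it might do, given the comment above; the sanitization rule and joining scheme are illustrative assumptions, not part of this commit:

import re

def safe_name(
    model_name: str,
    height: int,
    width: int,
    steps: int,
    strength: float,
    guidance_scale: float,
    batch_size: int,
):
    # Strip filesystem-hostile characters from the model name, then join the
    # generation parameters into one deterministic filename/cache-key stem.
    stem = re.sub(r"[^A-Za-z0-9_.-]", "_", model_name)
    return (
        f"{stem}_{height}x{width}_{steps}steps"
        f"_cfg{guidance_scale}_str{strength}_bs{batch_size}"
    )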

-def load_sd_vmfb(
-    model: str,
-    weight_file: str,
-    height: int,
-    width: int,
-    steps: int,
-    strength: float,
-    guidance_scale: float,
-    batch_size: int = 1,
-    base_model: str,
-    precision: str,
-    controlled: bool,
-    try_download: bool,
-    **kwargs,
-):
-    # Check if the file is already loaded and cached
-    # Check if the file already exists on disk
-    # Try to download from the web
-    # Generate the vmfb (generate_sd_vmfb)
-    # Load the vmfb and weights
-    # Return wrapper
 
+def local_path():
+    pass
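The comment outline in the removed load_sd_vmfb stub describes a lookup cascade: in-memory cache, then disk, then download, then local compilation. A minimal sketch of that cascade under stated assumptions; the cache dict, the url argument, and the injected compile_fn/load_fn helpers are hypothetical stand-ins, not APIs from this repository:

import os
import urllib.request

_vmfb_cache: dict = {}

def load_sd_vmfb_sketch(key, local_file, url, compile_fn, load_fn):
    # Check if the file is already loaded and cached
    if key in _vmfb_cache:
        return _vmfb_cache[key]
    # Check if the file already exists on disk
    if not os.path.exists(local_file):
        try:
            # Try to download from the web
            urllib.request.urlretrieve(url, local_file)
        except Exception:
            # Fall back to generating the vmfb locally (generate_sd_vmfb's job)
            compile_fn(local_file)
    # Load the vmfb and weights, cache the result, and return the wrapper
    _vmfb_cache[key] = load_fn(local_file)
    return _vmfb_cache[key]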


 # External API
 def generate_images(
     prompt: str,
@@ -78,123 +53,106 @@ def generate_images(

     # Handle img2img
     if not isinstance(sd_init_image, list):
-        sd_init_image = [sd_init_image]
+        sd_init_image = [sd_init_image] * batch_count
     is_img2img = True if sd_init_image[0] is not None else False
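For illustration (not part of the commit), replicating the single init image means the loop below can index sd_init_image[current_batch] uniformly:

>>> [None] * 3                    # txt2img: no init image
[None, None, None]
>>> ["cat.png"] * 3               # img2img: one image reused per batch
['cat.png', 'cat.png', 'cat.png']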

     # Generate seed if < 0
     # TODO
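A sketch of what this TODO might become, assuming a negative seed means "pick one at random", a common Stable Diffusion UI convention that this commit does not spell out:

import secrets

if seed < 0:
    # Draw a fresh 32-bit seed so the run can be reproduced afterwards.
    seed = secrets.randbelow(2**32)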

     # Cache dir
     # TODO
+    pipeline_dir = None
 
     # Sanity checks
-    # Scheduler
-    # Base model
+    assert scheduler in ["EulerDiscrete"]
+    assert base_model in ["sd1.5", "sd2", "sdxl", "sd3"]
+    assert precision in ["fp16", "fp32"]
+    assert device in [
+        "cpu",
+        "vulkan",
+        "rocm",
+        "hip",
+        "cuda",
+    ]  # and (IREE check if the device exists)
+    assert resample_type in ["Nearest Neighbor"]
 
     # Custom weights
     # TODO
     # Custom VAE
     # Precision
     # Device
     # TODO
     # Target triple
     # Resample type
     # TODO
 
     adapters = {}
     is_controlled = False
     control_mode = None
     hints = []
     num_loras = 0
     import_ir = True

-    # Populate model map
-    if model == "sd1.5":
-        submodels = {
-            "clip": None,
-            "scheduler": None,
-            "unet": None,
-            "vae_decode": None,
-        }
-    elif model == "sd2":
-        submodels = {
-            "clip": None,
-            "scheduler": None,
-            "unet": None,
-            "vae_decode": None,
-        }
-    elif model == "sdxl":
-        submodels = {
-            "prompt_encoder": None,
-            "scheduled_unet": None,
-            "vae_decode": None,
-            "pipeline": None,
-            "full_pipeline": None,
-        }
-    elif model == "sd3":
+    # (Re)initialize pipeline
+    pipeline_args = {
+        "height": height,
+        "width": width,
+        "batch_size": batch_size,
+        "precision": precision,
+        "device": device,
+        "target_triple": target_triple,
+    }
+    (existing_args, pipeline) = pipelines[base_model]
+    if not existing_args or not pipeline or not pipeline_args == existing_args:
+        # TODO: Initialize new pipeline
+        if base_model == "sd1.5":
+            pass
+        elif base_model == "sd2":
+            new_pipeline = SharkSDPipeline(
+                hf_model_name="stabilityai/stable-diffusion-2-1",
+                scheduler_id=scheduler,
+                height=height,
+                width=width,
+                precision=precision,
+                max_length=64,
+                batch_size=batch_size,
+                num_inference_steps=steps,
+                device=device,  # TODO: Get the IREE device ID?
+                iree_target_triple=target_triple,
+                ireec_flags={},
+                attn_spec=None,  # TODO: Find a better way to figure this out than hardcoding
+                decomp_attn=True,  # TODO: Ditto
+                pipeline_dir=pipeline_dir,
+                external_weights_dir=weights,  # TODO: Are both necessary still?
+                external_weights=weights,
+                custom_vae=custom_vae,
+            )
+        elif base_model == "sdxl":
+            pass
+        elif base_model == "sd3":
+            pass
+        # existing_args = pipeline_args
+        pass
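Note that the commented-out existing_args assignment leaves the cache write-back unfinished: as written, the new pipeline is never stored back into pipelines, and the pipeline name used in the loop below still refers to the possibly-None value destructured above. A sketch of the update this code appears to be heading toward (an assumption, not in this commit):

# Hypothetical write-back; the tuple shape matches the destructuring above.
pipelines[base_model] = (pipeline_args, new_pipeline)
pipeline = new_pipeline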

-    # TODO: generate and load submodel vmfbs
-    for submodel in submodels:
-        submodels[submodel] = load_sd_vmfb(
-            submodel,
-            custom_weights,
-            height,
-            width,
-            steps,
-            strength,
-            guidance_scale,
-            batch_size,
-            model,
-            precision,
-            not controlnets.keys(),
-            True,
-        )

-    generated_imgs = []
+    generated_images = []
     for current_batch in range(batch_count):
 
-        # TODO: Batch size > 1
-
-        # TODO: random sample (or img2img input)
-        sample = None
-
-        # TODO: encode input
-        prompt_embeds, negative_prompt_embeds = encode(prompt, negative_prompt)
-
-        start_time = time.time()
-        for t in range(steps):
-
-            # Prepare latents
-
-            # Scale model input
-            latent_model_input = submodels["scheduler"].scale_model_input(
-                sample,
-                t
-            )
-
-            # Run unet
-            latents = submodels["unet"](
-                latent_model_input,
-                t,
-                (negative_prompt_embeds, prompt_embeds),
-                guidance_scale,
-            )
+        out_images = pipeline.generate_images(
+            prompt=prompt,
+            negative_prompt=negative_prompt,
+            image=sd_init_image[current_batch],
+            strength=strength,
+            guidance_scale=guidance_scale,
+            seed=seed,
+            ondemand=ondemand,
+            resample_type=resample_type,
+            control_mode=control_mode,
+            hints=hints,
+        )
 
-            # Step scheduler
-            sample = submodels["scheduler"].step(
-                latents,
-                t,
-                sample
-            )
-
-        # VAE decode
-        out_img = submodels["vae_decode"](
-            sample
-        )
-
-        # Processing time
-        total_time = time.time() - start_time
-        # text_output = f"Total image(s) generation time: {total_time:.4f}sec"
-        # print(f"\n[LOG] {text_output}")
-
-        # TODO: Add to output list
-        generated_imgs.append(out_img)
+        if not isinstance(out_images, list):
+            out_images = [out_images]
+        generated_images.extend(out_images)
 
     # TODO: Allow the user to halt the process
 
-    return generated_imgs
+    return generated_images
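For orientation, a hypothetical call. The full generate_images signature is collapsed in this diff, so the keyword names below are taken from the parameters the body actually uses, and the values mirror the asserts above; treat this as a sketch, not a documented API:

out = generate_images(
    prompt="a photo of an astronaut riding a horse",
    negative_prompt="low quality",
    sd_init_image=None,  # txt2img; pass an image (or list) for img2img
    height=512,
    width=512,
    steps=20,
    strength=0.8,
    guidance_scale=7.5,
    seed=-1,  # negative seed: randomize (per the TODO above)
    batch_count=1,
    scheduler="EulerDiscrete",
    base_model="sd2",
    precision="fp16",
    device="rocm",
    resample_type="Nearest Neighbor",
)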
