(Studio) Update gradio and multicontrolnet UI. #2001

Merged
12 commits merged on Dec 4, 2023
3 changes: 3 additions & 0 deletions apps/stable_diffusion/shark_sd.spec
@@ -19,6 +19,9 @@ a = Analysis(
win_private_assemblies=False,
cipher=block_cipher,
noarchive=False,
module_collection_mode={
'gradio': 'py', # Collect gradio package as source .py files
Member:

I don't even want to know how you figured out this was required...

Collaborator (author):

Just wasn't collecting some random source files; tried adding metadata and hidden files, nope, ok just gimme it all then.

},
)
pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)

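For context, module_collection_mode is a standard option on PyInstaller's Analysis (available since PyInstaller 5.5) that tells the bundler to ship a package as plain .py source files instead of frozen bytecode, which is what gradio needs here. A minimal sketch of a spec using it (the entry-point name below is illustrative, not taken from this repo):

    a = Analysis(
        ["studio_entry.py"],              # hypothetical entry-point script
        module_collection_mode={
            "gradio": "py",               # collect gradio as source .py files, not bytecode
        },
    )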
4 changes: 3 additions & 1 deletion apps/stable_diffusion/shark_studio_imports.py
@@ -31,6 +31,7 @@
datas += copy_metadata("sentencepiece")
datas += copy_metadata("pyyaml")
datas += copy_metadata("huggingface-hub")
datas += copy_metadata("gradio")
datas += collect_data_files("torch")
datas += collect_data_files("tokenizers")
datas += collect_data_files("tiktoken")
@@ -75,6 +76,7 @@
# hidden imports for pyinstaller
hiddenimports = ["shark", "shark.shark_inference", "apps"]
hiddenimports += [x for x in collect_submodules("skimage") if "tests" not in x]
hiddenimports += [x for x in collect_submodules("gradio") if "tests" not in x]
hiddenimports += [
x for x in collect_submodules("diffusers") if "tests" not in x
]
@@ -85,4 +87,4 @@
if not any(kw in x for kw in blacklist)
]
hiddenimports += [x for x in collect_submodules("iree") if "tests" not in x]
hiddenimports += ["iree._runtime", "iree.compiler._mlir_libs._mlir.ir"]
hiddenimports += ["iree._runtime"]
61 changes: 49 additions & 12 deletions apps/stable_diffusion/src/models/model_wrappers.py
@@ -109,7 +109,7 @@ def process_vmfb_ir_sdxl(extended_model_name, model_name, device, precision):
if "vulkan" in device:
_device = args.iree_vulkan_target_triple
_device = _device.replace("-", "_")
vmfb_path = Path(extended_model_name_for_vmfb + f"_{_device}.vmfb")
vmfb_path = Path(extended_model_name_for_vmfb + f"_vulkan.vmfb")
monorimet marked this conversation as resolved.
if vmfb_path.exists():
shark_module = SharkInference(
None,
@@ -436,24 +436,48 @@ def __init__(
super().__init__()
self.vae = None
if custom_vae == "":
print(f"Loading default vae, with target {model_id}")
self.vae = AutoencoderKL.from_pretrained(
model_id,
subfolder="vae",
low_cpu_mem_usage=low_cpu_mem_usage,
)
elif not isinstance(custom_vae, dict):
self.vae = AutoencoderKL.from_pretrained(
custom_vae,
subfolder="vae",
low_cpu_mem_usage=low_cpu_mem_usage,
)
precision = "fp16" if "fp16" in custom_vae else None
print(f"Loading custom vae, with target {custom_vae}")
if os.path.exists(custom_vae):
self.vae = AutoencoderKL.from_pretrained(
custom_vae,
low_cpu_mem_usage=low_cpu_mem_usage,
)
else:
custom_vae = "/".join(
[
custom_vae.split("/")[-2].split("\\")[-1],
custom_vae.split("/")[-1],
]
)
print("Using hub to get custom vae")
try:
self.vae = AutoencoderKL.from_pretrained(
custom_vae,
low_cpu_mem_usage=low_cpu_mem_usage,
variant=precision,
)
except:
self.vae = AutoencoderKL.from_pretrained(
custom_vae,
low_cpu_mem_usage=low_cpu_mem_usage,
)
else:
print(f"Loading custom vae, with state {custom_vae}")
self.vae = AutoencoderKL.from_pretrained(
model_id,
subfolder="vae",
low_cpu_mem_usage=low_cpu_mem_usage,
)
self.vae.load_state_dict(custom_vae)
self.base_vae = base_vae

def forward(self, latents):
image = self.vae.decode(latents / 0.13025, return_dict=False)[
@@ -465,7 +489,12 @@ def forward(self, latents):
inputs = tuple(self.inputs["vae"])
# Make sure the VAE is in float32 mode, as it overflows in float16 as per SDXL
# pipeline.
is_f16 = False
if not self.custom_vae:
is_f16 = False
elif "16" in self.custom_vae:
is_f16 = True
else:
is_f16 = False
save_dir = os.path.join(self.sharktank_dir, self.model_name["vae"])
if self.debug:
os.makedirs(save_dir, exist_ok=True)
@@ -917,11 +946,19 @@ def __init__(
low_cpu_mem_usage=False,
):
super().__init__()
self.unet = UNet2DConditionModel.from_pretrained(
model_id,
subfolder="unet",
low_cpu_mem_usage=low_cpu_mem_usage,
)
try:
self.unet = UNet2DConditionModel.from_pretrained(
model_id,
subfolder="unet",
low_cpu_mem_usage=low_cpu_mem_usage,
variant="fp16",
)
except:
self.unet = UNet2DConditionModel.from_pretrained(
model_id,
subfolder="unet",
low_cpu_mem_usage=low_cpu_mem_usage,
)
if (
args.attention_slicing is not None
and args.attention_slicing != "none"
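Both the custom-VAE path and this UNet change use the same loading pattern: ask diffusers for the fp16 weight variant first and fall back to the default weights when the repo doesn't publish one. A minimal sketch of that pattern (narrower exceptions are used here for clarity; the diff itself uses a bare except):

    from diffusers import UNet2DConditionModel

    def load_unet(model_id: str, low_cpu_mem_usage: bool = False):
        # Prefer fp16 weights when the Hub repo ships an "fp16" variant.
        try:
            return UNet2DConditionModel.from_pretrained(
                model_id,
                subfolder="unet",
                low_cpu_mem_usage=low_cpu_mem_usage,
                variant="fp16",
            )
        except (OSError, ValueError):
            # No fp16 variant published; load the default (usually fp32) weights.
            return UNet2DConditionModel.from_pretrained(
                model_id,
                subfolder="unet",
                low_cpu_mem_usage=low_cpu_mem_usage,
            )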
5 changes: 4 additions & 1 deletion apps/stable_diffusion/src/models/opt_params.py
@@ -123,7 +123,10 @@ def get_clip():
return get_shark_model(bucket, model_name, iree_flags)


def get_tokenizer(subfolder="tokenizer"):
def get_tokenizer(subfolder="tokenizer", hf_model_id=None):
if hf_model_id is not None:
args.hf_model_id = hf_model_id

tokenizer = CLIPTokenizer.from_pretrained(
args.hf_model_id, subfolder=subfolder
)
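A quick usage note on the new parameter: passing hf_model_id lets SDXL code fetch a tokenizer from an explicit repo instead of whatever args.hf_model_id currently holds (and, as written, it also overwrites args.hf_model_id as a side effect). For example, using the SDXL base repo that appears later in this PR:

    # default behaviour: tokenizer taken from args.hf_model_id
    tokenizer = get_tokenizer()

    # SDXL's second tokenizer, pinned to an explicit repo
    tokenizer_2 = get_tokenizer(
        "tokenizer_2", "stabilityai/stable-diffusion-xl-base-1.0"
    )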
(unnamed file)
@@ -158,7 +158,10 @@ def generate_images(
use_base_vae,
cpu_scheduling,
max_embeddings_multiples,
stencils,
images,
resample_type,
control_mode,
):
# prompts and negative prompts must be a list.
if isinstance(prompts, str):
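These four extra parameters are how the updated UI hands multi-ControlNet state to the pipeline call. A hedged illustration of the kinds of values involved (the specific stencil names, filter, and mode strings below are hypothetical examples, not taken from this diff):

    from PIL import Image

    stencils = ["canny", "openpose"]          # one stencil per active ControlNet
    images = [Image.open("edges.png"), Image.open("pose.png")]  # conditioning image per stencil
    resample_type = "Lanczos"                 # how conditioning/init images are resized
    control_mode = "Balanced"                 # how strongly the ControlNets steer generation
    # these are forwarded to generate_images(...) right after max_embeddings_multiples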
(unnamed file)
@@ -18,7 +18,10 @@
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
)
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.schedulers import (
SharkEulerDiscreteScheduler,
SharkEulerAncestralDiscreteScheduler,
)
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
StableDiffusionPipeline,
)
(unnamed file)
@@ -16,7 +16,10 @@
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
)
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.schedulers import (
SharkEulerDiscreteScheduler,
SharkEulerAncestralDiscreteScheduler,
)
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
StableDiffusionPipeline,
)
@@ -38,6 +41,7 @@ def __init__(
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
SharkEulerAncestralDiscreteScheduler,
DEISMultistepScheduler,
DDPMScheduler,
DPMSolverSinglestepScheduler,
@@ -48,8 +52,10 @@
import_mlir: bool,
use_lora: str,
ondemand: bool,
is_fp32_vae: bool,
):
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
self.is_fp32_vae = is_fp32_vae

def prepare_latents(
self,
@@ -203,10 +209,10 @@ def generate_images(
# Img latents -> PIL images.
all_imgs = []
self.load_vae()
# imgs = self.decode_latents_sdxl(None)
# all_imgs.extend(imgs)
for i in range(0, latents.shape[0], batch_size):
imgs = self.decode_latents_sdxl(latents[i : i + batch_size])
imgs = self.decode_latents_sdxl(
latents[i : i + batch_size], is_fp32_vae=self.is_fp32_vae
)
all_imgs.extend(imgs)
if self.ondemand:
self.unload_vae()
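The is_fp32_vae flag added here controls whether latents are cast before the VAE decode: SDXL's stock VAE overflows in float16, so the fp32 build needs fp32 latents, while an fp16-safe custom VAE can take the UNet's fp16 output directly. A minimal sketch of that decision, assuming a compiled VAE module with the same "forward" calling convention used in this codebase:

    import torch

    def decode_with_vae(vae, latents: torch.Tensor, is_fp32_vae: bool):
        # Cast only when the VAE was compiled for float32; otherwise keep the UNet dtype.
        if is_fp32_vae:
            latents = latents.to(torch.float32)
        return vae("forward", (latents,))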
(unnamed file)
@@ -20,7 +20,10 @@
HeunDiscreteScheduler,
)
from shark.shark_inference import SharkInference
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.schedulers import (
SharkEulerDiscreteScheduler,
SharkEulerAncestralDiscreteScheduler,
)
from apps.stable_diffusion.src.models import (
SharkifyStableDiffusionModel,
get_vae,
@@ -52,6 +55,7 @@ def __init__(
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
SharkEulerAncestralDiscreteScheduler,
DEISMultistepScheduler,
DDPMScheduler,
DPMSolverSinglestepScheduler,
@@ -62,21 +66,23 @@
import_mlir: bool,
use_lora: str,
ondemand: bool,
is_f32_vae: bool = False,
):
self.vae = None
self.text_encoder = None
self.text_encoder_2 = None
self.unet = None
self.unet_512 = None
self.model_max_length = 77
self.scheduler = scheduler
# TODO: Implement using logging python utility.
self.log = ""
self.status = SD_STATE_IDLE
self.sd_model = sd_model
self.scheduler = scheduler
self.import_mlir = import_mlir
self.use_lora = use_lora
self.ondemand = ondemand
self.is_f32_vae = is_f32_vae
# TODO: Find a better workaround for fetching base_model_id early
# enough for CLIPTokenizer.
try:
@@ -202,6 +208,9 @@ def encode_prompt_sdxl(
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
hf_model_id: Optional[
str
] = "stabilityai/stable-diffusion-xl-base-1.0",
):
if prompt is not None and isinstance(prompt, str):
batch_size = 1
@@ -211,7 +220,7 @@
batch_size = prompt_embeds.shape[0]

# Define tokenizers and text encoders
self.tokenizer_2 = get_tokenizer("tokenizer_2")
self.tokenizer_2 = get_tokenizer("tokenizer_2", hf_model_id)
self.load_clip_sdxl()
tokenizers = (
[self.tokenizer, self.tokenizer_2]
@@ -332,7 +341,7 @@ def encode_prompt_sdxl(
gc.collect()

# TODO: Look into dtype for text_encoder_2!
prompt_embeds = prompt_embeds.to(dtype=torch.float32)
prompt_embeds = prompt_embeds.to(dtype=torch.float16)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
@@ -523,6 +532,9 @@ def produce_img_latents_sdxl(
cpu_scheduling,
guidance_scale,
dtype,
mask=None,
masked_image_latents=None,
return_all_latents=False,
):
# return None
self.status = SD_STATE_IDLE
@@ -533,11 +545,22 @@
step_start_time = time.time()
timestep = torch.tensor([t]).to(dtype).detach().numpy()
# expand the latents if we are doing classifier free guidance
if isinstance(latents, np.ndarray):
latents = torch.tensor(latents)
latent_model_input = torch.cat([latents] * 2)

latent_model_input = self.scheduler.scale_model_input(
latent_model_input, t
).to(dtype)
)
if mask is not None and masked_image_latents is not None:
latent_model_input = torch.cat(
[
torch.from_numpy(np.asarray(latent_model_input)),
mask,
masked_image_latents,
],
dim=1,
).to(dtype)

noise_pred = self.unet(
"forward",
@@ -549,11 +572,17 @@
add_time_ids,
guidance_scale,
),
send_to_host=False,
send_to_host=True,
)
if not isinstance(latents, torch.Tensor):
latents = torch.from_numpy(latents).to("cpu")
noise_pred = torch.from_numpy(noise_pred).to("cpu")

latents = self.scheduler.step(
noise_pred, t, latents, **extra_step_kwargs, return_dict=False
)[0]
latents = latents.detach().numpy()
noise_pred = noise_pred.detach().numpy()

step_time = (time.time() - step_start_time) * 1000
step_time_sum += step_time
@@ -569,11 +598,15 @@

return latents

def decode_latents_sdxl(self, latents):
latents = latents.to(torch.float32)
def decode_latents_sdxl(self, latents, is_fp32_vae):
# latents are in unet dtype here so switch if we want to use fp32
if is_fp32_vae:
print("Casting latents to float32 for VAE")
latents = latents.to(torch.float32)
images = self.vae("forward", (latents,))
images = (torch.from_numpy(images) / 2 + 0.5).clamp(0, 1)
images = images.cpu().permute(0, 2, 3, 1).float().numpy()

images = (images * 255).round().astype("uint8")
pil_images = [Image.fromarray(image[:, :, :3]) for image in images]

@@ -666,6 +699,17 @@ def from_pretrained(
return cls(
scheduler, sd_model, import_mlir, use_lora, ondemand, stencils
)
if cls.__name__ == "Text2ImageSDXLPipeline":
is_fp32_vae = True if "16" not in custom_vae else False
return cls(
scheduler,
sd_model,
import_mlir,
use_lora,
ondemand,
is_fp32_vae,
)

return cls(scheduler, sd_model, import_mlir, use_lora, ondemand)

# #####################################################
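One behavioural change in the denoising loop is worth calling out: the UNet result is now copied back to the host (send_to_host=True), both tensors are moved into torch so the diffusers scheduler can step on the CPU, and the stepped latents are handed back to the compiled UNet as numpy. A compact sketch of that round-trip, using the same names as the loop above (the scheduler, UNet output, and extra_step_kwargs are assumed to exist):

    import numpy as np
    import torch

    def cpu_scheduler_step(scheduler, noise_pred_np, t, latents, extra_step_kwargs):
        # Bring both operands onto the CPU as torch tensors before stepping.
        if not isinstance(latents, torch.Tensor):
            latents = torch.from_numpy(latents).to("cpu")
        noise_pred = torch.from_numpy(noise_pred_np).to("cpu")
        latents = scheduler.step(
            noise_pred, t, latents, **extra_step_kwargs, return_dict=False
        )[0]
        # Return numpy so the next iteration can feed the compiled UNet directly.
        return latents.detach().numpy()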
3 changes: 3 additions & 0 deletions apps/stable_diffusion/src/schedulers/__init__.py
@@ -2,3 +2,6 @@
from apps.stable_diffusion.src.schedulers.shark_eulerdiscrete import (
SharkEulerDiscreteScheduler,
)
from apps.stable_diffusion.src.schedulers.shark_eulerancestraldiscrete import (
SharkEulerAncestralDiscreteScheduler,
)