Skip to content

Commit

Permalink
More fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
monorimet committed Dec 4, 2023
1 parent 4e06aad commit 0785c51
Show file tree
Hide file tree
Showing 11 changed files with 73 additions and 38 deletions.
2 changes: 1 addition & 1 deletion apps/stable_diffusion/src/models/model_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def process_vmfb_ir_sdxl(extended_model_name, model_name, device, precision):
if "vulkan" in device:
_device = args.iree_vulkan_target_triple
_device = _device.replace("-", "_")
vmfb_path = Path(extended_model_name_for_vmfb + f"_{_device}.vmfb")
vmfb_path = Path(extended_model_name_for_vmfb + f"_vulkan.vmfb")
if vmfb_path.exists():
shark_module = SharkInference(
None,
Expand Down
5 changes: 4 additions & 1 deletion apps/stable_diffusion/src/models/opt_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,10 @@ def get_clip():
return get_shark_model(bucket, model_name, iree_flags)


def get_tokenizer(subfolder="tokenizer"):
def get_tokenizer(subfolder="tokenizer", hf_model_id=None):
if hf_model_id is not None:
args.hf_model_id = hf_model_id

tokenizer = CLIPTokenizer.from_pretrained(
args.hf_model_id, subfolder=subfolder
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,9 @@ def encode_prompt_sdxl(
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
hf_model_id: Optional[
str
] = "stabilityai/stable-diffusion-xl-base-1.0",
):
if prompt is not None and isinstance(prompt, str):
batch_size = 1
Expand All @@ -217,7 +220,7 @@ def encode_prompt_sdxl(
batch_size = prompt_embeds.shape[0]

# Define tokenizers and text encoders
self.tokenizer_2 = get_tokenizer("tokenizer_2")
self.tokenizer_2 = get_tokenizer("tokenizer_2", hf_model_id)
self.load_clip_sdxl()
tokenizers = (
[self.tokenizer, self.tokenizer_2]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,6 @@ def step(
device="cpu",
generator=generator,
)
self._step_index += 1
step_inputs = [
noise_pred,
latent,
Expand All @@ -224,9 +223,10 @@ def step(
sigma_to,
noise,
]
print(step_inputs)
# TODO: Might not be proper behavior here... deal with dynamic inputs.
# TODO: deal with dynamic inputs in turbine flow.
# update step index since we're done with the variable and will return with compiled module output.
self._step_index += 1

if noise_pred.shape[0] < self.batch_size:
for i in [0, 1, 5]:
try:
Expand Down
24 changes: 15 additions & 9 deletions apps/stable_diffusion/web/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,11 +110,15 @@ def resource_path(relative_path):
txt2img_sendto_outpaint,
txt2img_sendto_upscaler,
# SDXL
txt2img_sdxl_inf,
txt2img_sdxl_web,
txt2img_sdxl_custom_model,
txt2img_sdxl_gallery,
txt2img_sdxl_png_info_img,
txt2img_sdxl_status,
txt2img_sdxl_sendto_img2img,
txt2img_sdxl_sendto_inpaint,
txt2img_sdxl_sendto_outpaint,
txt2img_sdxl_sendto_upscaler,
# h2ogpt_upload,
# h2ogpt_web,
img2img_web,
Expand Down Expand Up @@ -151,7 +155,7 @@ def resource_path(relative_path):
upscaler_sendto_outpaint,
# lora_train_web,
# model_web,
# model_config_web,
model_config_web,
hf_models,
modelmanager_sendto_txt2img,
modelmanager_sendto_img2img,
Expand All @@ -165,6 +169,7 @@ def resource_path(relative_path):
outputgallery_watch,
outputgallery_filename,
outputgallery_sendto_txt2img,
outputgallery_sendto_txt2img_sdxl,
outputgallery_sendto_img2img,
outputgallery_sendto_inpaint,
outputgallery_sendto_outpaint,
Expand Down Expand Up @@ -241,6 +246,7 @@ def register_outputgallery_button(button, selectedid, inputs, outputs):
inpaint_status,
outpaint_status,
upscaler_status,
txt2img_sdxl_status,
]
)
# with gr.TabItem(label="Model Manager", id=6):
Expand All @@ -249,17 +255,17 @@ def register_outputgallery_button(button, selectedid, inputs, outputs):
# lora_train_web.render()
with gr.TabItem(label="Chat Bot", id=8):
stablelm_chat.render()
# with gr.TabItem(
# label="Generate Sharding Config (Experimental)", id=9
# ):
# model_config_web.render()
with gr.TabItem(label="MultiModal (Experimental)", id=10):
minigpt4_web.render()
with gr.TabItem(
label="Generate Sharding Config (Experimental)", id=9
):
model_config_web.render()
# with gr.TabItem(label="MultiModal (Experimental)", id=10):
# minigpt4_web.render()
# with gr.TabItem(label="DocuChat Upload", id=11):
# h2ogpt_upload.render()
# with gr.TabItem(label="DocuChat(Experimental)", id=12):
# h2ogpt_web.render()
with gr.TabItem(label="Text-to-Image-SDXL (Experimental)", id=13):
with gr.TabItem(label="Text-to-Image (SDXL)", id=13):
txt2img_sdxl_web.render()

actual_port = app.usable_port()
Expand Down
6 changes: 6 additions & 0 deletions apps/stable_diffusion/web/ui/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@
txt2img_sdxl_custom_model,
txt2img_sdxl_gallery,
txt2img_sdxl_status,
txt2img_sdxl_png_info_img,
txt2img_sdxl_sendto_img2img,
txt2img_sdxl_sendto_inpaint,
txt2img_sdxl_sendto_outpaint,
txt2img_sdxl_sendto_upscaler,
)
from apps.stable_diffusion.web.ui.img2img_ui import (
img2img_inf,
Expand Down Expand Up @@ -83,6 +88,7 @@
outputgallery_watch,
outputgallery_filename,
outputgallery_sendto_txt2img,
outputgallery_sendto_txt2img_sdxl,
outputgallery_sendto_img2img,
outputgallery_sendto_inpaint,
outputgallery_sendto_outpaint,
Expand Down
21 changes: 8 additions & 13 deletions apps/stable_diffusion/web/ui/img2img_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,19 +426,14 @@ def cnet_preview(
return (None, stencils, images)

def create_canvas(width, height):
    """Return a blank white canvas image as a NumPy array.

    The merged diff text contained two implementations back to back (an old
    dict-returning body followed by the new array body), which made the newer
    code unreachable; this keeps only the newer implementation.

    Args:
        width: Canvas width in pixels.
        height: Canvas height in pixels.

    Returns:
        numpy.ndarray of shape (height, width, 3), dtype uint8, with every
        channel set to 255 (solid white).
    """
    data = (
        np.zeros(
            shape=(height, width, 3),
            dtype=np.uint8,
        )
        + 255  # 0 + 255 per channel -> all-white RGB canvas
    )
    return data

def update_cn_input(model, width, height):
if model == "scribble":
Expand Down
8 changes: 8 additions & 0 deletions apps/stable_diffusion/web/ui/outputgallery_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ def output_subdirs() -> list[str]:
wrap=True,
elem_classes="output_parameters_dataframe",
value=[["Status", "No image selected"]],
interactive=True,
)

with gr.Accordion(label="Send To", open=True):
Expand All @@ -162,6 +163,12 @@ def output_subdirs() -> list[str]:
elem_classes="outputgallery_sendto",
size="sm",
)
outputgallery_sendto_txt2img_sdxl = gr.Button(
value="Txt2Img XL",
interactive=False,
elem_classes="outputgallery_sendto",
size="sm",
)

outputgallery_sendto_img2img = gr.Button(
value="Img2Img",
Expand Down Expand Up @@ -414,6 +421,7 @@ def on_select_tab(subdir_paths, request: gr.Request):
[outputgallery_filename],
[
outputgallery_sendto_txt2img,
outputgallery_sendto_txt2img_sdxl,
outputgallery_sendto_img2img,
outputgallery_sendto_inpaint,
outputgallery_sendto_outpaint,
Expand Down
15 changes: 13 additions & 2 deletions apps/stable_diffusion/web/ui/txt2img_sdxl_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,7 @@ def txt2img_sdxl_inf(
show_label=False,
elem_id="gallery",
columns=[2],
object_fit="contain",
object_fit="scale_down",
)
std_output = gr.Textbox(
value=f"{t2i_sdxl_model_info}\n"
Expand All @@ -483,7 +483,18 @@ def txt2img_sdxl_inf(
)
stop_batch = gr.Button("Stop Batch")
with gr.Row():
blank_thing_for_row = None
txt2img_sdxl_sendto_img2img = gr.Button(
value="Send To Img2Img"
)
txt2img_sdxl_sendto_inpaint = gr.Button(
value="Send To Inpaint"
)
txt2img_sdxl_sendto_outpaint = gr.Button(
value="Send To Outpaint"
)
txt2img_sdxl_sendto_upscaler = gr.Button(
value="Send To Upscaler"
)

kwargs = dict(
fn=txt2img_sdxl_inf,
Expand Down
17 changes: 10 additions & 7 deletions apps/stable_diffusion/web/ui/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,18 +263,19 @@ def cancel_sd():
def set_model_default_configs(model_ckpt_or_id, jsonconfig=None):
import gradio as gr

config_modelname = default_config_exists(model_ckpt_or_id)
if jsonconfig:
return get_config_from_json(jsonconfig)
elif default_config_exists(model_ckpt_or_id):
return default_configs[model_ckpt_or_id]
elif config_modelname:
return default_configs[config_modelname]
# TODO: Use HF metadata to setup pipeline if available
# elif is_valid_hf_id(model_ckpt_or_id):
# return get_HF_default_configs(model_ckpt_or_id)
else:
# We don't have default metadata to setup a good config. Do not change configs.
return [
gr.Textbox(label="Prompt", interactive=True, visible=True),
gr.update(),
gr.Textbox(label="Negative Prompt", interactive=True),
gr.update(),
gr.update(),
gr.update(),
Expand Down Expand Up @@ -304,19 +305,21 @@ def default_config_exists(model_ckpt_or_id):
"stabilityai/sdxl-turbo",
"stabilityai/stable_diffusion-xl-base-1.0",
]:
return True
return model_ckpt_or_id
elif "turbo" in model_ckpt_or_id.lower():
return "stabilityai/sdxl-turbo"
else:
return False
return None


default_configs = {
"stabilityai/sdxl-turbo": [
gr.Textbox(label="", interactive=False, value=None, visible=False),
gr.Textbox(
label="Prompt",
value="An anthropomorphic shark writing code on an old tube monitor, macro shot, in an office filled with water, stop-animation style, claymation",
value="role-playing game (RPG) style fantasy, An enchanting image featuring an adorable kitten mage wearing intricate ancient robes, holding an ancient staff, hard at work in her fantastical workshop, magic runes floating in the air",
),
gr.Slider(0, 5, value=2),
gr.Slider(0, 10, value=2),
gr.Dropdown(value="EulerAncestralDiscrete"),
gr.Slider(0, value=0),
512,
Expand Down
2 changes: 1 addition & 1 deletion shark/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def __call__(self, parser, namespace, values, option_string=None):

parser.add_argument(
"--vulkan_debug_utils",
default=True,
default=False,
action=argparse.BooleanOptionalAction,
help="Profiles vulkan device and collects the .rdc info.",
)
Expand Down

0 comments on commit 0785c51

Please sign in to comment.