Skip to content

Commit

Permalink
Studio2: Remove duplications from api/utils.py (#2035)
Browse files Browse the repository at this point in the history
* Remove duplicate os import
* Remove duplicate parse_seed_input function

Migrating to JSON requests in SD UI

More UI and app flow improvements, logging, shared device cache

Model loading

Complete SD pipeline.

Tweaks to VAE, pipeline states

Pipeline tweaks, add cmd_opts parsing to sd api
  • Loading branch information
one-lithe-rune authored and monorimet committed Jan 17, 2024
1 parent cdf2eb5 commit 7a0017d
Show file tree
Hide file tree
Showing 35 changed files with 2,236 additions and 1,147 deletions.
109 changes: 41 additions & 68 deletions apps/shark_studio/api/controlnet.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,15 @@
# from turbine_models.custom_models.controlnet import control_adapter, preprocessors
import os
import PIL
import numpy as np
from apps.shark_studio.web.utils.file_utils import (
get_generated_imgs_path,
)
from datetime import datetime
from PIL import Image
from gradio.components.image_editor import (
EditorValue,
)


class control_adapter:
Expand Down Expand Up @@ -29,20 +40,12 @@ def export_controlnet_model(model_keyword):
control_adapter_map = {
"sd15": {
"canny": {"initializer": control_adapter.export_control_adapter_model},
"openpose": {
"initializer": control_adapter.export_control_adapter_model
},
"scribble": {
"initializer": control_adapter.export_control_adapter_model
},
"zoedepth": {
"initializer": control_adapter.export_control_adapter_model
},
"openpose": {"initializer": control_adapter.export_control_adapter_model},
"scribble": {"initializer": control_adapter.export_control_adapter_model},
"zoedepth": {"initializer": control_adapter.export_control_adapter_model},
},
"sdxl": {
"canny": {
"initializer": control_adapter.export_xl_control_adapter_model
},
"canny": {"initializer": control_adapter.export_xl_control_adapter_model},
},
}
preprocessor_model_map = {
Expand All @@ -57,78 +60,48 @@ class PreprocessorModel:
def __init__(
self,
hf_model_id,
device,
device="cpu",
):
self.model = None
self.model = hf_model_id
self.device = device

def compile(self, device):
def compile(self):
print("compile not implemented for preprocessor.")
return

def run(self, inputs):
print("run not implemented for preprocessor.")
return
return inputs


def cnet_preview(model, input_img, stencils, images, preprocessed_hints):
if isinstance(input_image, PIL.Image.Image):
img_dict = {
"background": None,
"layers": [None],
"composite": input_image,
}
input_image = EditorValue(img_dict)
images[index] = input_image
if model:
stencils[index] = model
def cnet_preview(model, input_image):
curr_datetime = datetime.now().strftime("%Y-%m-%d.%H-%M-%S")
control_imgs_path = os.path.join(get_generated_imgs_path(), "control_hints")
if not os.path.exists(control_imgs_path):
os.mkdir(control_imgs_path)
img_dest = os.path.join(control_imgs_path, model + curr_datetime + ".png")
match model:
case "canny":
canny = CannyDetector()
canny = PreprocessorModel("canny")
result = canny(
np.array(input_image["composite"]),
np.array(input_image),
100,
200,
)
preprocessed_hints[index] = Image.fromarray(result)
return (
Image.fromarray(result),
stencils,
images,
preprocessed_hints,
)
Image.fromarray(result).save(fp=img_dest)
return result, img_dest
case "openpose":
openpose = OpenposeDetector()
result = openpose(np.array(input_image["composite"]))
preprocessed_hints[index] = Image.fromarray(result[0])
return (
Image.fromarray(result[0]),
stencils,
images,
preprocessed_hints,
)
openpose = PreprocessorModel("openpose")
result = openpose(np.array(input_image))
Image.fromarray(result[0]).save(fp=img_dest)
return result, img_dest
case "zoedepth":
zoedepth = ZoeDetector()
result = zoedepth(np.array(input_image["composite"]))
preprocessed_hints[index] = Image.fromarray(result)
return (
Image.fromarray(result),
stencils,
images,
preprocessed_hints,
)
zoedepth = PreprocessorModel("ZoeDepth")
result = zoedepth(np.array(input_image))
Image.fromarray(result).save(fp=img_dest)
return result, img_dest
case "scribble":
preprocessed_hints[index] = input_image["composite"]
return (
input_image["composite"],
stencils,
images,
preprocessed_hints,
)
input_image.save(fp=img_dest)
return input_image, img_dest
case _:
preprocessed_hints[index] = None
return (
None,
stencils,
images,
preprocessed_hints,
)
return None, None
32 changes: 21 additions & 11 deletions apps/shark_studio/api/initializers.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
import importlib
import logging
import os
import signal
import sys
import re
import warnings
import json
from threading import Thread

from apps.shark_studio.modules.timer import startup_timer
from apps.shark_studio.web.utils.tmp_configs import (
config_tmp,
clear_tmp_mlir,
clear_tmp_imgs,
)


def imports():
Expand All @@ -18,9 +21,8 @@ def imports():
warnings.filterwarnings(
action="ignore", category=DeprecationWarning, module="torch"
)
warnings.filterwarnings(
action="ignore", category=UserWarning, module="torchvision"
)
warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision")
warnings.filterwarnings(action="ignore", category=UserWarning, module="torch")

import gradio # noqa: F401

Expand All @@ -34,20 +36,28 @@ def imports():
from apps.shark_studio.modules import (
img_processing,
) # noqa: F401
from apps.shark_studio.modules.schedulers import scheduler_model_map

startup_timer.record("other imports")


def initialize():
configure_sigint_handler()
# Setup to use shark_tmp for gradio's temporary image files and clear any
# existing temporary images there if they exist. Then we can import gradio.
# It has to be in this order or gradio ignores what we've set up.

config_tmp()
clear_tmp_mlir()
clear_tmp_imgs()

from apps.shark_studio.web.utils.file_utils import (
create_checkpoint_folders,
)

# from apps.shark_studio.modules import modelloader
# modelloader.cleanup_models()
# Create custom models folders if they don't exist
create_checkpoint_folders()

# from apps.shark_studio.modules import sd_models
# sd_models.setup_model()
# startup_timer.record("setup SD model")
import gradio as gr

# initialize_rest(reload_script_modules=False)

Expand Down
2 changes: 1 addition & 1 deletion apps/shark_studio/api/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
get_iree_compiled_module,
load_vmfb_using_mmap,
)
from apps.shark_studio.api.utils import get_resource_path
from apps.shark_studio.web.utils.file_utils import get_resource_path
import iree.runtime as ireert
from itertools import chain
import gc
Expand Down
Loading

0 comments on commit 7a0017d

Please sign in to comment.