diff --git a/apps/stable_diffusion/src/utils/stencils/stencil_utils.py b/apps/stable_diffusion/src/utils/stencils/stencil_utils.py index 685d542405..41800f0cb7 100644 --- a/apps/stable_diffusion/src/utils/stencils/stencil_utils.py +++ b/apps/stable_diffusion/src/utils/stencils/stencil_utils.py @@ -1,6 +1,10 @@ import numpy as np from PIL import Image import torch +import os +from pathlib import Path +import torchvision +import time from apps.stable_diffusion.src.utils.stencils import ( CannyDetector, OpenposeDetector, @@ -10,6 +14,33 @@ stencil = {} +def save_img(img): + from apps.stable_diffusion.src.utils import ( + get_generated_imgs_path, + get_generated_imgs_todays_subdir, + ) + + subdir = Path( + get_generated_imgs_path(), get_generated_imgs_todays_subdir() + ) + os.makedirs(subdir, exist_ok=True) + if isinstance(img, Image.Image): + img.save( + os.path.join( + subdir, "controlnet_" + str(int(time.time())) + ".png" + ) + ) + elif isinstance(img, np.ndarray): + img = Image.fromarray(img) + img.save(os.path.join(subdir, str(int(time.time())) + ".png")) + else: + converter = torchvision.transforms.ToPILImage() + for i in img: + converter(i).save( + os.path.join(subdir, str(int(time.time())) + ".png") + ) + + def HWC3(x): assert x.dtype == np.uint8 if x.ndim == 2: @@ -161,6 +192,7 @@ def hint_canny( detected_map = stencil["canny"]( input_image, low_threshold, high_threshold ) + save_img(detected_map) detected_map = HWC3(detected_map) return detected_map @@ -176,6 +208,7 @@ def hint_openpose( stencil["openpose"] = OpenposeDetector() detected_map, _ = stencil["openpose"](input_image) + save_img(detected_map) detected_map = HWC3(detected_map) return detected_map @@ -187,6 +220,7 @@ def hint_scribble(image: Image.Image): detected_map = np.zeros_like(input_image, dtype=np.uint8) detected_map[np.min(input_image, axis=2) < 127] = 255 + save_img(detected_map) return detected_map @@ -199,5 +233,6 @@ def hint_zoedepth(image: Image.Image): stencil["depth"] = ZoeDetector() detected_map = stencil["depth"](input_image) + save_img(detected_map) detected_map = HWC3(detected_map) return detected_map