Skip to content

Commit

Permalink
Save intermediate values of controlnet (#1981)
Browse files Browse the repository at this point in the history
  • Loading branch information
gpetters94 authored Nov 18, 2023
1 parent 4125a26 commit 80a33d4
Showing 1 changed file with 35 additions and 0 deletions.
35 changes: 35 additions & 0 deletions apps/stable_diffusion/src/utils/stencils/stencil_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import numpy as np
from PIL import Image
import torch
import os
from pathlib import Path
import torchvision
import time
from apps.stable_diffusion.src.utils.stencils import (
CannyDetector,
OpenposeDetector,
Expand All @@ -10,6 +14,33 @@
stencil = {}


def save_img(img):
from apps.stable_diffusion.src.utils import (
get_generated_imgs_path,
get_generated_imgs_todays_subdir,
)

subdir = Path(
get_generated_imgs_path(), get_generated_imgs_todays_subdir()
)
os.makedirs(subdir, exist_ok=True)
if isinstance(img, Image.Image):
img.save(
os.path.join(
subdir, "controlnet_" + str(int(time.time())) + ".png"
)
)
elif isinstance(img, np.ndarray):
img = Image.fromarray(img)
img.save(os.path.join(subdir, str(int(time.time())) + ".png"))
else:
converter = torchvision.transforms.ToPILImage()
for i in img:
converter(i).save(
os.path.join(subdir, str(int(time.time())) + ".png")
)


def HWC3(x):
assert x.dtype == np.uint8
if x.ndim == 2:
Expand Down Expand Up @@ -161,6 +192,7 @@ def hint_canny(
detected_map = stencil["canny"](
input_image, low_threshold, high_threshold
)
save_img(detected_map)
detected_map = HWC3(detected_map)
return detected_map

Expand All @@ -176,6 +208,7 @@ def hint_openpose(
stencil["openpose"] = OpenposeDetector()

detected_map, _ = stencil["openpose"](input_image)
save_img(detected_map)
detected_map = HWC3(detected_map)
return detected_map

Expand All @@ -187,6 +220,7 @@ def hint_scribble(image: Image.Image):

detected_map = np.zeros_like(input_image, dtype=np.uint8)
detected_map[np.min(input_image, axis=2) < 127] = 255
save_img(detected_map)
return detected_map


Expand All @@ -199,5 +233,6 @@ def hint_zoedepth(image: Image.Image):
stencil["depth"] = ZoeDetector()

detected_map = stencil["depth"](input_image)
save_img(detected_map)
detected_map = HWC3(detected_map)
return detected_map

0 comments on commit 80a33d4

Please sign in to comment.