Add multicontrolnet
gpetters94 committed Nov 17, 2023
1 parent 192b3b2 commit 5a31eca
Showing 10 changed files with 306 additions and 154 deletions.
59 changes: 32 additions & 27 deletions apps/stable_diffusion/src/models/model_wrappers.py
@@ -86,7 +86,7 @@ def __init__(
generate_vmfb: bool = True,
is_inpaint: bool = False,
is_upscaler: bool = False,
use_stencil: str = None,
stencils: list[str] = [],
use_lora: str = "",
use_quantize: str = None,
return_mlir: bool = False,
@@ -144,7 +144,7 @@ def __init__(
self.low_cpu_mem_usage = low_cpu_mem_usage
self.is_inpaint = is_inpaint
self.is_upscaler = is_upscaler
self.use_stencil = get_stencil_model_id(use_stencil)
self.stencils = [get_stencil_model_id(x) for x in stencils]
if use_lora != "":
self.model_name = self.model_name + "_" + get_path_stem(use_lora)
self.use_lora = use_lora
@@ -195,8 +195,9 @@ def get_extended_name_for_all_model(self):
)
if self.base_vae:
sub_model = "base_vae"
if "stencil_adaptor" == model and self.use_stencil is not None:
model_config = model_config + get_path_stem(self.use_stencil)
# TODO: Fix this
# if "stencil_adaptor" == model and self.use_stencil is not None:
# model_config = model_config + get_path_stem(self.use_stencil)
model_name[model] = get_extended_name(sub_model + model_config)
index += 1
return model_name
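
The commented-out block above drops the per-stencil naming rather than porting it. A minimal sketch of how the TODO might be resolved for a stencil list (hypothetical, reusing the existing get_path_stem helper):

# Hypothetical resolution of the TODO above, not part of this commit:
# fold every stencil's path stem into the config string so each
# stencil combination yields a distinct extended model name.
if model == "stencil_adaptor" and self.stencils:
    for stencil_id in self.stencils:
        model_config = model_config + get_path_stem(stencil_id)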
@@ -381,23 +382,24 @@ def forward(
control12,
control13,
):
# TODO: Average pooling
db_res_samples = [
control1,
control2,
control3,
control4,
control5,
control6,
control7,
control8,
control9,
control10,
control11,
control12,
]

# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
db_res_samples = tuple(
[
control1,
control2,
control3,
control4,
control5,
control6,
control7,
control8,
control9,
control10,
control11,
control12,
]
)
db_res_samples = tuple(db_res_samples)
mb_res_samples = control13
latents = torch.cat([latent] * 2)
unet_out = self.unet.forward(
@@ -462,11 +464,11 @@ def forward(
)
return shark_controlled_unet, controlled_unet_mlir

def get_control_net(self, use_large=False):
def get_control_net(self, stencil_id, use_large=False):
stencil_id = get_stencil_model_id(stencil_id)

class StencilControlNetModel(torch.nn.Module):
def __init__(
self, model_id=self.use_stencil, low_cpu_mem_usage=False
):
def __init__(self, model_id=stencil_id, low_cpu_mem_usage=False):
super().__init__()
self.cnet = ControlNetModel.from_pretrained(
model_id,
@@ -811,7 +813,10 @@ def clip(self):

def unet(self, use_large=False):
try:
model = "stencil_unet" if self.use_stencil is not None else "unet"
model = "stencil_unet" if len(self.stencils) > 0 else "unet"
compiled_unet = None
unet_inputs = base_models[model]

@@ -880,13 +885,13 @@ def vae(self):
except Exception as e:
sys.exit(e)

def controlnet(self, use_large=False):
def controlnet(self, stencil_id, use_large=False):
try:
self.inputs["stencil_adaptor"] = self.get_input_info_for(
base_models["stencil_adaptor"]
)
compiled_stencil_adaptor, controlnet_mlir = self.get_control_net(
use_large=use_large
stencil_id, use_large=use_large
)

check_compilation(compiled_stencil_adaptor, "Stencil")
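
Since controlnet() now takes the stencil id explicitly, callers compile one adaptor per stencil. A usage sketch (hypothetical stencil ids, assuming the wrapper instance is named sd_model):

# Hypothetical caller-side loop over the new per-stencil API:
stencil_ids = ["canny", "openpose"]  # illustrative stencil names
compiled_adaptors = [sd_model.controlnet(s) for s in stencil_ids]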
@@ -158,7 +158,6 @@ def generate_images(
use_base_vae,
cpu_scheduling,
max_embeddings_multiples,
use_stencil,
resample_type,
):
# prompts and negative prompts must be a list.
@@ -55,28 +55,47 @@ def __init__(
import_mlir: bool,
use_lora: str,
ondemand: bool,
controlnet_names: list[str],
):
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
self.controlnet = None
self.controlnet_512 = None
self.controlnet = [None] * len(controlnet_names)
self.controlnet_512 = [None] * len(controlnet_names)
self.controlnet_id = [None] * len(controlnet_names)
self.controlnet_512_id = [None] * len(controlnet_names)
self.controlnet_names = controlnet_names

def load_controlnet(self):
if self.controlnet is not None:
def load_controlnet(self, index, model_name):
if model_name is None:
return
self.controlnet = self.sd_model.controlnet()
if (
self.controlnet[index] is not None
and self.controlnet_id[index] is not None
and self.controlnet_id[index] == model_name
):
return
self.controlnet_id[index] = model_name
self.controlnet[index] = self.sd_model.controlnet(model_name)

def unload_controlnet(self):
del self.controlnet
self.controlnet = None
def unload_controlnet(self, index):
del self.controlnet[index]
self.controlnet_id[index] = None
self.controlnet[index] = None

def load_controlnet_512(self):
if self.controlnet_512 is not None:
def load_controlnet_512(self, index, model_name):
if (
self.controlnet_512[index] is not None
and self.controlnet_512_id[index] == model_name
):
return
self.controlnet_512 = self.sd_model.controlnet(use_large=True)
self.controlnet_512_id[index] = model_name
self.controlnet_512[index] = self.sd_model.controlnet(
model_name, use_large=True
)

def unload_controlnet_512(self):
del self.controlnet_512
self.controlnet_512 = None
def unload_controlnet_512(self, index):
del self.controlnet_512[index]
self.controlnet_512_id[index] = None
self.controlnet_512[index] = None
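
The four methods above implement a per-slot cache: a ControlNet slot is reloaded only when its model id changes, and unloading clears both the module and its id so a later load repopulates it. A self-contained sketch of the same pattern (generic names, not the project's API):

# Generic per-slot model cache mirroring load/unload_controlnet above.
class SlotCache:
    def __init__(self, n, loader):
        self.items = [None] * n   # loaded models, one slot per stencil
        self.ids = [None] * n     # model id currently held by each slot
        self.loader = loader      # callable: model_name -> loaded model

    def load(self, index, model_name):
        if model_name is None:
            return
        if self.items[index] is not None and self.ids[index] == model_name:
            return  # slot already holds this model; skip the reload
        self.ids[index] = model_name
        self.items[index] = self.loader(model_name)

    def unload(self, index):
        self.items[index] = None
        self.ids[index] = None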

def prepare_latents(
self,
@@ -111,7 +130,7 @@ def produce_stencil_latents(
total_timesteps,
dtype,
cpu_scheduling,
controlnet_hint=None,
stencil_hints=[None],
controlnet_conditioning_scale: float = 1.0,
mask=None,
masked_image_latents=None,
@@ -123,10 +142,15 @@
text_embeddings_numpy = text_embeddings.detach().numpy()
if text_embeddings.shape[1] <= self.model_max_length:
self.load_unet()
self.load_controlnet()
else:
self.load_unet_512()
self.load_controlnet_512()

for i, name in enumerate(self.controlnet_names):
if text_embeddings.shape[1] <= self.model_max_length:
self.load_controlnet(i, name)
else:
self.load_controlnet_512(i, name)

for i, t in tqdm(enumerate(total_timesteps)):
step_start_time = time.time()
timestep = torch.tensor([t]).to(dtype)
@@ -149,33 +173,49 @@
).to(dtype)
else:
latent_model_input_1 = latent_model_input
if text_embeddings.shape[1] <= self.model_max_length:
control = self.controlnet(
"forward",
(
latent_model_input_1,
timestep,
text_embeddings,
controlnet_hint,
),
send_to_host=False,
)
else:
control = self.controlnet_512(
"forward",
(
latent_model_input_1,
timestep,
text_embeddings,
controlnet_hint,
),
send_to_host=False,
)

# Multicontrolnet
control_steps = []
for i, controlnet_hint in enumerate(stencil_hints):
if controlnet_hint is None:
continue
if text_embeddings.shape[1] <= self.model_max_length:
subcontrol = self.controlnet[i](
"forward",
(
latent_model_input_1,
timestep,
text_embeddings,
controlnet_hint,
),
send_to_host=False,
)
else:
subcontrol = self.controlnet_512[i](
"forward",
(
latent_model_input_1,
timestep,
text_embeddings,
controlnet_hint,
),
send_to_host=False,
)
control_steps.append(subcontrol)

timestep = timestep.detach().numpy()
# Profiling Unet.
profile_device = start_profiling(file_path="unet.rdc")
# TODO: Pass `control` as it is to Unet. Same as TODO mentioned in model_wrappers.py.

control = []
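# Convert each ControlNet's residuals to torch tensors, then
# accumulate them blockwise by elementwise addition.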
for subcontrol in control_steps:
subcontrol = [torch.from_numpy(np.asarray(x)) for x in subcontrol]
if not control:
control = subcontrol
else:
control = [x + y for (x, y) in zip(control, subcontrol)]

if text_embeddings.shape[1] <= self.model_max_length:
noise_pred = self.unet(
"forward",
Expand Down Expand Up @@ -245,8 +285,9 @@ def produce_stencil_latents(
if self.ondemand:
self.unload_unet()
self.unload_unet_512()
self.unload_controlnet()
self.unload_controlnet_512()
for i in range(len(self.controlnet_names)):
self.unload_controlnet(i)
self.unload_controlnet_512(i)
avg_step_time = step_time_sum / len(total_timesteps)
self.log += f"\nAverage step time: {avg_step_time}ms/it"

@@ -272,14 +313,29 @@ def generate_images(
use_base_vae,
cpu_scheduling,
max_embeddings_multiples,
use_stencil,
stencils,
stencil_images,
resample_type,
):
# Control Embedding check & conversion
# TODO: 1. Change `num_images_per_prompt`.
controlnet_hint = controlnet_hint_conversion(
image, use_stencil, height, width, dtype, num_images_per_prompt=1
)
# controlnet_hint = controlnet_hint_conversion(
# image, use_stencil, height, width, dtype, num_images_per_prompt=1
# )
stencil_hints = []
for i, stencil in enumerate(stencils):
image = stencil_images[i]
stencil_hints.append(
controlnet_hint_conversion(
image,
stencil,
height,
width,
dtype,
num_images_per_prompt=1,
)
)
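
generate_images now receives parallel stencils and stencil_images lists and produces one hint per pair (note that the loop rebinds the function's image argument in passing). The indexed loop is equivalent to a zip, sketched here with hypothetical inputs:

from PIL import Image

# Hypothetical inputs: stencil i conditions on stencil_images[i].
stencils = ["canny", "openpose"]  # illustrative stencil ids
stencil_images = [Image.open("edges.png"), Image.open("pose.png")]
stencil_hints = [
    controlnet_hint_conversion(
        img, stencil, height, width, dtype, num_images_per_prompt=1
    )
    for stencil, img in zip(stencils, stencil_images)
]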

# prompts and negative prompts must be a list.
if isinstance(prompts, str):
prompts = [prompts]
Expand Down Expand Up @@ -327,7 +383,7 @@ def generate_images(
total_timesteps=final_timesteps,
dtype=dtype,
cpu_scheduling=cpu_scheduling,
controlnet_hint=controlnet_hint,
stencil_hints=stencil_hints,
)

# Img latents -> PIL images
@@ -338,7 +338,8 @@ def from_pretrained(
ondemand: bool,
low_cpu_mem_usage: bool = False,
debug: bool = False,
use_stencil: str = None,
stencils: list[str] = [],
# stencil_images: list[Image] = []
use_lora: str = "",
ddpm_scheduler: DDPMScheduler = None,
use_quantize=None,
Expand Down Expand Up @@ -371,7 +372,7 @@ def from_pretrained(
debug=debug,
is_inpaint=is_inpaint,
is_upscaler=is_upscaler,
use_stencil=use_stencil,
stencils=stencils,
use_lora=use_lora,
use_quantize=use_quantize,
)
@@ -386,6 +387,10 @@
ondemand,
)

if cls.__name__ == "StencilPipeline":
return cls(
scheduler, sd_model, import_mlir, use_lora, ondemand, stencils
)
return cls(scheduler, sd_model, import_mlir, use_lora, ondemand)
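
The cls.__name__ string check only matches StencilPipeline itself, and avoids importing that class here. An issubclass-based variant would also cover its subclasses, sketched below as an assumption rather than part of this commit:

# Hypothetical alternative to the __name__ check above, usable when
# StencilPipeline can be imported without creating an import cycle:
if issubclass(cls, StencilPipeline):
    return cls(scheduler, sd_model, import_mlir, use_lora, ondemand, stencils)
return cls(scheduler, sd_model, import_mlir, use_lora, ondemand)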

# #####################################################
(Diffs for the remaining six changed files are not shown.)
