Skip to content

Commit

Permalink
add sd multi controlnet support
Browse files Browse the repository at this point in the history
  • Loading branch information
JingyaHuang committed Sep 6, 2024
1 parent 281d9bb commit 6a7132c
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 7 deletions.
7 changes: 6 additions & 1 deletion optimum/neuron/modeling_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -1189,9 +1189,14 @@ def forward(
encoder_hidden_states: torch.Tensor,
controlnet_cond: torch.Tensor,
conditioning_scale: float = 1.0,
guess_mode: bool = False,
return_dict: bool = True,
) -> Union["ControlNetOutput", Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]:
for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.model)):
if guess_mode:
logger.info(
"Guess mode is not yet supported. File us an issue on: https://github.com/huggingface/optimum-neuron/issues."
)
for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
inputs = (sample, timestep, encoder_hidden_states, image, scale)
down_samples, mid_sample = controlnet(*inputs)

Expand Down
17 changes: 11 additions & 6 deletions optimum/neuron/pipelines/diffusers/pipeline_controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ def __call__(
global_pool_conditions = (
controlnet.config.global_pool_conditions
if controlnet.__class__.__name__ == "NeuronControlNetModel"
else controlnet.nets[0].config.global_pool_conditions
else controlnet.config[0].global_pool_conditions
)
guess_mode = guess_mode or global_pool_conditions
# TODO: support guess mode of ControlNet
Expand Down Expand Up @@ -502,11 +502,16 @@ def __call__(

# Duplicate inputs for ddp
t = torch.tensor([t] * 2) if self.data_parallel_mode == "unet" else t
cond_scale = (
torch.tensor([cond_scale]).repeat(2)
if self.data_parallel_mode == "unet"
else torch.tensor(cond_scale)
)
if controlnet.__class__.__name__ == "NeuronControlNetModel":
cond_scale = (
torch.tensor([cond_scale]).repeat(2)
if self.data_parallel_mode == "unet"
else torch.tensor(cond_scale)
)
else:
for i, scale in enumerate(cond_scale):
new_scale = torch.tensor([scale]).repeat(2) if self.data_parallel_mode == "unet" else torch.tensor(scale)
cond_scale[i] = new_scale

down_block_res_samples, mid_block_res_sample = self.controlnet(
control_model_input,
Expand Down
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@
"soundfile",
"librosa",
"opencv-python-headless",
"controlnet-aux",
"mediapipe",
]

QUALITY_REQUIRES = [
Expand Down

0 comments on commit 6a7132c

Please sign in to comment.