ControlNet: Adding Conditional Control to Text-to-Image Diffusion Models
ControlNet controls pretrained large diffusion models to support additional input conditions. The ControlNet learns task-specific conditions in an end-to-end way, and the learning is robust even when the training dataset is small. Large diffusion models like Stable Diffusion can be augmented with ControlNets to enable conditional inputs like canny edge maps, segmentation maps, keypoints, etc.
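As a rough sketch of the core mechanism (not the actual SDXL code), the snippet below shows why ControlNet's zero-initialized "zero convolution" lets the trainable branch start training without perturbing the frozen pretrained model; all shapes and names are illustrative.

```python
import numpy as np

def zero_conv(x, weight, bias):
    # 1x1 "zero convolution" modeled as a linear map whose weight and bias
    # start at zero, so the ControlNet branch contributes nothing at step 0
    return x @ weight + bias

# toy feature maps standing in for the frozen SDXL block and the trainable copy
frozen_out = np.random.rand(4, 8)    # output of the locked pretrained block
control_feat = np.random.rand(4, 8)  # output of the trainable ControlNet copy

w = np.zeros((8, 8))  # zero-initialized weight
b = np.zeros(8)       # zero-initialized bias

# the ControlNet output is added to the frozen path through the zero conv
combined = frozen_out + zero_conv(control_feat, w, b)

# before any training, the model behaves exactly like the pretrained one
assert np.allclose(combined, frozen_out)
```

As training updates `w` and `b` away from zero, the control branch gradually starts to influence the output, which is what makes the learning robust even on small datasets.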
Fig 1. Illustration of a ControlNet [1]
- AI framework: MindSpore >= 2.2
- Hardware: Ascend 910*
cd examples/stable_diffusion_xl
pip install -r requirements.txt
1. Convert trained weight from Diffusers
Step 1: Convert the SDXL-base-1.0 model weight from Diffusers to MindONE, refer to here, to get `sd_xl_base_1.0_ms.ckpt`.
Step 2: Since ControlNet acts as a plug-in to SDXL, we convert the ControlNet weight `diffusion_pytorch_model.safetensors` from diffusers/controlnet-canny-sdxl-1.0 to the MindSpore version, then merge it into the SDXL-base-1.0 MindONE model weight (`sd_xl_base_1.0_ms.ckpt`). Eventually, we get the ControlNet + SDXL-base-1.0 MindONE model weight (`sd_xl_base_1.0_controlnet_canny_ms.ckpt`).
cd tools/controlnet_conversion
python convert_weight.py \
--weight_torch_controlnet /PATH TO/diffusion_pytorch_model.safetensors \
--weight_ms_sdxl /PATH TO/sd_xl_base_1.0_ms.ckpt \
--output_ms_ckpt_path /PATH TO/sd_xl_base_1.0_controlnet_canny_ms.ckpt
Note: the parameter name mapping of ControlNet weights between Diffusers and MindONE is provided in `tools/controlnet_conversion/controlnet_ms2torch_mapping.yaml`.
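Conceptually, the conversion renames each Diffusers parameter to its MindSpore counterpart using that mapping before merging. A minimal sketch with made-up parameter names (the real pairs live in the YAML file above):

```python
# hypothetical MindSpore-name -> torch-name pairs; illustrative only
ms2torch = {
    "controlnet.input_hint_block.0.conv.weight": "controlnet_cond_embedding.conv_in.weight",
}

# toy stand-in for the loaded safetensors checkpoint
torch_ckpt = {"controlnet_cond_embedding.conv_in.weight": [1.0, 2.0]}

# rename every torch parameter to its MindSpore name before merging
ms_ckpt = {ms_name: torch_ckpt[torch_name] for ms_name, torch_name in ms2torch.items()}
```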
2. Or train your own ControlNet with MindONE; see the Training section below
Stable Diffusion XL with ControlNet can generate images following an input control signal (e.g. a canny edge map). You can prepare either (1) a raw image (Fig 2) from which the control signal will be extracted, or (2) the control signal image itself (Fig 3).
Fig 2. Raw image from which the control signal is extracted. Fig 3. Control signal image (canny edge)
Please refer to `scripts/run_infer_base_controlnet.sh`.
python demo/sampling_without_streamlit.py \
--task txt2img \
--config configs/inference/sd_xl_base_controlnet.yaml \
--weight checkpoints/sd_xl_base_1.0_controlnet_canny_ms.ckpt \
--guidance_scale 9.0 \
--controlnet_mode canny \
--control_image_path /PATH TO/dog2.png \
--prompt "cute dog, best quality, extremely detailed"
Key arguments:
- `weight`: path to the model weight; refer to the Prepare model weight section.
- `guidance_scale`: the guidance scale for txt2img and img2img tasks. For NoDynamicThresholding, the output is computed as uncond + guidance_scale * (cond - uncond). Note that this scale can heavily impact the inference result.
- `controlnet_mode`: control mode for ControlNet. Supported modes: 'raw': use the image itself as the control signal; 'canny': use a canny edge detector to extract the control signal from the input image.
- `control_image_path`: path of the input image for ControlNet.
- `prompt`: positive text prompt for image generation.
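The guidance combination used for `guidance_scale` can be sketched numerically; the arrays below are toy denoiser outputs, not real model tensors:

```python
import numpy as np

# classifier-free guidance: push the unconditional prediction toward the
# conditional one, scaled by guidance_scale
uncond = np.array([0.1, 0.2])  # toy denoiser output without the prompt
cond = np.array([0.3, 0.1])    # toy denoiser output with the prompt
guidance_scale = 9.0

guided = uncond + guidance_scale * (cond - uncond)
```

Larger scales push the sample harder toward the prompt, which is why the result is so sensitive to this value.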
You can check descriptions of all arguments by running `python demo/sampling_without_streamlit.py -h`.
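For `--controlnet_mode canny`, the control signal is an edge map extracted from the raw image. The sketch below substitutes a simple gradient-magnitude threshold for a real Canny detector (e.g. OpenCV's `cv2.Canny`), just to show the idea and the shapes involved:

```python
import numpy as np

def edge_map(gray, thresh=0.2):
    # simplified stand-in for a Canny edge detector: mark pixels whose
    # gradient magnitude exceeds a fraction of the image maximum
    gy, gx = np.gradient(gray.astype(np.float32))
    mag = np.hypot(gx, gy)
    return (mag > thresh * mag.max()).astype(np.uint8) * 255

# toy grayscale input: a white square on a black background
img = np.zeros((64, 64), dtype=np.uint8)
img[16:48, 16:48] = 255

edges = edge_map(img)  # edges fire on the square's border only
```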
Fig 4. From left to right: raw image - extracted canny edge - inference result.
Prompt: "cute dog, best quality, extremely detailed".
Fig 5. From left to right: raw image - extracted canny edge - inference result.
Prompt: "beautiful bird standing on a trunk, natural color, best quality, extremely detailed".
Step 1: Convert the SDXL-base-1.0 model weight from Diffusers to MindONE, refer to here, to get `sd_xl_base_1.0_ms.ckpt`.
Step2:
cd tools/controlnet_conversion
python init_weight.py
The parameters of the `zero_conv`, `input_hint_block` and `middle_block_out` blocks are randomly initialized in ControlNet. The other ControlNet parameters are copied from the SDXL pretrained weight `sd_xl_base_1.0_ms.ckpt` (referring to here).
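A conceptual sketch of what `init_weight.py` does, with made-up parameter names: parameters shared with the SDXL UNet are copied from the pretrained checkpoint, while the new blocks keep their fresh initialization.

```python
import numpy as np

# the blocks that do NOT have pretrained twins (per the note above)
NEW_BLOCKS = ("zero_conv", "input_hint_block", "middle_block_out")

# toy stand-ins; real parameter names differ
sdxl_ckpt = {"input_blocks.0.weight": np.ones(4)}
controlnet_params = {
    "controlnet.input_blocks.0.weight": np.zeros(4),  # has a pretrained twin
    "controlnet.zero_convs.0.weight": np.zeros(4),    # new block, stays as-is
}

for name in controlnet_params:
    base = name.replace("controlnet.", "", 1)
    if base in sdxl_ckpt and not any(b in name for b in NEW_BLOCKS):
        controlnet_params[name] = sdxl_ckpt[base].copy()  # copy shared weight
```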
We use the Fill50k dataset to train the model to generate images following the edge control. The directory structure of the Fill50k dataset is shown below.
DATA_PATH
├── prompt.json
├── source
│ ├── 0.png
│ ├── 1.png
│ └── ...
└── target
├── 0.png
├── 1.png
└── ...
Images in `target/` are raw images. Images in `source/` are the canny edge/segmentation/other control images extracted from the corresponding raw images. For example, `source/img0.png` is the canny edge image of `target/img0.png`.
`prompt.json` is the annotation file with the following format.
{"source": "source/0.png", "target": "target/0.png", "prompt": "pale golden rod circle with old lace background"}
{"source": "source/1.png", "target": "target/1.png", "prompt": "light coral circle with white background"}
{"source": "source/2.png", "target": "target/2.png", "prompt": "aqua circle with light pink background"}
{"source": "source/3.png", "target": "target/3.png", "prompt": "cornflower blue circle with light golden rod yellow background"}
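Note that `prompt.json` is line-delimited JSON (one sample per line), so it should be read with `json.loads` per line rather than a single `json.load`. A minimal sketch:

```python
import json

# two lines from the Fill50k-style annotation file shown above
lines = [
    '{"source": "source/0.png", "target": "target/0.png", "prompt": "pale golden rod circle with old lace background"}',
    '{"source": "source/1.png", "target": "target/1.png", "prompt": "light coral circle with white background"}',
]

# each parsed sample pairs a control image, a raw image, and a caption
samples = [json.loads(line) for line in lines]
```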
Note: if you want to use your own dataset for training, please follow the directory and file structure shown above.
Please refer to `scripts/run_train_base_controlnet.sh`.
nohup mpirun -n 8 --allow-run-as-root python train_controlnet.py \
--data_path DATA_PATH \
--weight /PATH TO/sd_xl_base_1.0_ms_controlnet_init.ckpt \
--config configs/training/sd_xl_base_finetune_controlnet_910b.yaml \
--total_step 300000 \
--per_batch_size 2 \
--group_lr_scaler 10.0 \
--save_ckpt_interval 10000 \
--max_num_ckpt 5 \
> train.log 2>&1 &
- The parameters of the `zero_conv`, `input_hint_block` and `middle_block_out` blocks are randomly initialized in ControlNet, which makes them hard to train. We therefore scale up the base learning rate (x10 by default) for these parameters specifically. You can set the scale value with `args.group_lr_scaler`.
- As mentioned in the ControlNet paper [1] and repo, there is a sudden convergence phenomenon in ControlNet training: the number of training steps must be large enough for training to converge SUDDENLY, after which the generated images follow the control signals. For ControlNet + SDXL, even more training steps are required.
- As mentioned in the ControlNet paper [1], randomly dropping 50% of the text prompts during training is very helpful for ControlNet to learn the control signals. Don't miss that.
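The 50% prompt dropping can be sketched as below; with an empty caption the model must rely on the control image alone. The helper name is made up for illustration, and the 0.5 probability follows the ControlNet paper.

```python
import random

def maybe_drop_prompt(prompt, drop_prob=0.5, rng=random):
    # with probability drop_prob, replace the caption with an empty string
    return "" if rng.random() < drop_prob else prompt

rng = random.Random(0)  # seeded for reproducibility
prompts = ["cute dog"] * 1000
dropped = [maybe_drop_prompt(p, rng=rng) for p in prompts]
# roughly half of the captions end up empty
```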
Key settings:
| base_learning_rate | group_lr_scaler | global batch size (#NPUs x bs per NPU) | total_step | inference guidance_scale |
| --- | --- | --- | --- | --- |
| 4.0e-5 | 10.0 | 64 (32 x 2) | 220k | 15.0 |
Ground truth:
Our prediction:
Prompts (correspond to the images above from left to right):
"light coral circle with white background"
"light sea green circle with dark salmon background"
"medium sea green circle with black background"
"dark turquoise circle with medium spring green background"
[1] ControlNet: Adding Conditional Control to Text-to-Image Diffusion Models