feat: hpcai opensora 1.2 - VAE 3D training #621
Changes from 22 commits
@@ -148,6 +148,7 @@ Your contributions are welcome.
* [Data Processing](#data-processing)
* [Training](#training)
* [Evaluation](#evaluation)
* [VAE Training & Evaluation](#vae-training--evaluation)
* [Contribution](#contribution)
* [Acknowledgement](#acknowledgement)
@@ -284,6 +285,7 @@ parameters is 724M. More information about training can be found in HPC-AI Tech'
</details>

## Inference

### Open-Sora 1.2 and 1.1 Command Line Inference
@@ -759,7 +761,74 @@ Here are some generation results after fine-tuning STDiT on a subset of WebVid d
#### Quality Evaluation
For quality evaluation, please refer to the original HPC-AI Tech [evaluation doc](https://github.com/hpcaitech/Open-Sora/blob/main/eval/README.md).

</details>

## VAE Training & Evaluation

A 3D-VAE pipeline consisting of a spatial VAE followed by a temporal VAE is trained in Open-Sora v1.2. For more details, refer to the [VAE documentation](https://github.com/hpcaitech/Open-Sora/blob/main/docs/vae.md).

### Prepare Pretrained Weights

- Download the pretrained VAE-2D checkpoint from [PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers](https://huggingface.co/PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers/tree/main/vae) if you aim to train VAE-3D from a spatial VAE initialization.

Convert it to a MindSpore checkpoint:
```
python tools/convert_vae1.2.py --src /path/to/pixart_sigma_sdxlvae_T5_diffusers/vae/diffusion_pytorch_model.safetensors --target models/sdxl_vae.ckpt --from_vae2d
```

- Download the pretrained VAE-3D checkpoint from [hpcai-tech/OpenSora-VAE-v1.2](https://huggingface.co/hpcai-tech/OpenSora-VAE-v1.2/tree/main) if you aim to train VAE-3D from the VAE-3D model pre-trained with 3 stages.

Convert it to a MindSpore checkpoint:
```
python tools/convert_vae1.2.py --src /path/OpenSora-VAE-v1.2/models.safetensors --target models/OpenSora-VAE-v1.2/sdxl_vae.ckpt
```

- Download the LPIPS MindSpore checkpoint from [here](https://download-mindspore.osinfra.cn/toolkits/mindone/autoencoders/lpips_vgg-426bf45c.ckpt) and put it under `models/`.

### Data Preprocess
If you want to train your own VAE, prepare your data in a CSV file following the [data processing](#data-processing) pipeline, then run the commands below.
Note that you need to adjust the number of training epochs (`epochs`) in the config file according to the size of your CSV data.
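One rough way to pick `epochs` for a new CSV is to target a fixed number of optimizer steps; the helper below is an illustrative sketch, not part of the repo (the 9,537 figure is the standard UCF-101 split-1 train size, and your CSV size may differ):

```python
import math

def adjust_epochs(num_samples: int, batch_size: int, num_devices: int, target_steps: int) -> int:
    """Pick an epoch count that yields roughly `target_steps` optimizer steps.

    With data-parallel training, each step consumes `batch_size` samples per
    device, so one epoch is ceil(num_samples / (batch_size * num_devices)) steps.
    """
    steps_per_epoch = math.ceil(num_samples / (batch_size * num_devices))
    return max(1, math.ceil(target_steps / steps_per_epoch))

# e.g. ~9.5k train clips, batch size 1 on 8 NPUs, targeting 100k steps
print(adjust_epochs(9537, 1, 8, 100_000))  # → 84
```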

Take UCF-101 as an example. After downloading the [UCF-101](https://www.crcv.ucf.edu/data/UCF101.php) dataset and extracting it to the `datasets/UCF-101` folder, we can generate the CSV annotations by running `python tools/annotate_vae_ucf101.py`.

The resulting train/test annotation CSV files, which contain the relative video paths for the train/test splits, will be saved as `datasets/ucf101_train.csv` and `datasets/ucf101_test.csv`.
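For custom datasets without an annotation script, a minimal sketch of producing such train/test CSVs of relative video paths follows; the column name `video` and the split logic are assumptions for illustration, not taken from `tools/annotate_vae_ucf101.py`:

```python
import csv
import os
import random

def write_annotation(root: str, train_csv: str, test_csv: str,
                     test_ratio: float = 0.2, seed: int = 42) -> None:
    """Collect relative video paths under `root` and split them into train/test CSVs."""
    paths = []
    for dirpath, _, filenames in os.walk(root):
        for name in filenames:
            if name.lower().endswith((".avi", ".mp4")):
                paths.append(os.path.relpath(os.path.join(dirpath, name), root))
    rng = random.Random(seed)
    rng.shuffle(paths)
    n_test = int(len(paths) * test_ratio)
    for out_path, rows in ((test_csv, paths[:n_test]), (train_csv, paths[n_test:])):
        with open(out_path, "w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow(["video"])  # hypothetical header: relative video path
            writer.writerows([[p] for p in sorted(rows)])
```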

### Training
```bash
# stage 1 training, 8 NPUs
msrun --worker_num=8 --local_worker_num=8 \
    python scripts/train_vae.py --config configs/vae/train/stage1.yaml --use_parallel=True --csv_path datasets/ucf101_train.csv --video_folder datasets/UCF-101

# stage 2 training, 8 NPUs
msrun --worker_num=8 --local_worker_num=8 \
    python scripts/train_vae.py --config configs/vae/train/stage2.yaml --use_parallel=True --csv_path datasets/ucf101_train.csv --video_folder datasets/UCF-101

# stage 3 training, 8 NPUs
msrun --worker_num=8 --local_worker_num=8 \
    python scripts/train_vae.py --config configs/vae/train/stage3.yaml --use_parallel=True --csv_path datasets/ucf101_train.csv --video_folder datasets/UCF-101
```

You can change `csv_path` and `video_folder` to train on your own data.

### Performance Evaluation
To evaluate the VAE performance, first run VAE inference to generate the reconstructed videos, then calculate scores on them:

```bash
# video generation and evaluation
python scripts/inference_vae.py --ckpt_path /path/to/your_vae_ckpt --image_size 256 --num_frames=17 --csv_path datasets/ucf101_test.csv --video_folder datasets/UCF-101
```

You can change `csv_path` and `video_folder` to evaluate on your own data.
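For reference, PSNR (one of the metrics reported below) has the standard definition sketched here; this is the generic formula, not the repo's evaluation code:

```python
import numpy as np

def psnr(ref: np.ndarray, rec: np.ndarray, data_range: float = 255.0) -> float:
    """Peak signal-to-noise ratio in dB between a reference and a reconstruction."""
    mse = np.mean((ref.astype(np.float64) - rec.astype(np.float64)) ** 2)
    if mse == 0:
        return float("inf")
    return 10.0 * np.log10(data_range ** 2 / mse)

# Frames off by a constant 16 gray levels: 20*log10(255/16) ≈ 24.05 dB
a = np.zeros((17, 256, 256, 3), dtype=np.uint8)
b = np.full_like(a, 16)
print(round(psnr(a, b), 2))  # → 24.05
```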

Here, we report the training performance and evaluation results on the UCF-101 dataset.

| Model | Context | jit_level | Precision | BS | NPUs | Resolution (frames×H×W) | Train T. (s/step) | PSNR | SSIM |
|:------|:--------|:---------:|:---------:|:--:|:----:|:-----------------------:|:-----------------:|:----:|:----:|
| OpenSora-VAE-v1.2 | D910\*-[CANN C18(0705)](https://repo.mindspore.cn/ascend/ascend910/20240705/)-[MS2.3](https://www.mindspore.cn/install) | O1 | BF16 | 1 | 8 | 17x256x256 | 0.97 | 29.29 | 0.88 |

> **Review comment:** 1. If the performance of all 3 stages is available, please add it as well.

> Context: {G:GPU, D:Ascend}{chip type}-{mindspore version}.

Note that for stage 3 we train with the mixed video and image strategy, i.e. `--mixed_strategy=mixed_video_image`, instead of a random number of frames (`mixed_video_random`). Random-frame training will be supported in the future.

> **Review reply:** Indeed, the parameter names currently follow the torch naming.
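One plausible reading of `mixed_video_image` with `mixed_image_ratio: 0.2` is that roughly one batch in five is trained as single-frame images rather than video clips; the toy sampler below illustrates that interpretation and is not the repo's implementation:

```python
import random

def sample_batch_type(rng: random.Random, mixed_image_ratio: float = 0.2) -> str:
    """Decide whether the next batch is trained as images (1 frame) or video clips."""
    return "image" if rng.random() < mixed_image_ratio else "video"

rng = random.Random(0)
kinds = [sample_batch_type(rng) for _ in range(10_000)]
print(round(kinds.count("image") / len(kinds), 2))  # close to mixed_image_ratio
```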

## Training and Inference Using the FiT-Like Pipeline
`configs/vae/train/stage1.yaml` (new file; filename inferred from the stage 1 training command):

# model
model_type: "OpenSoraVAE_V1_2"
freeze_vae_2d: True
pretrained_model_path: "models/sdxl_vae.ckpt"

# loss
perceptual_loss_weight: 0.1
kl_loss_weight: 1.e-6
use_real_rec_loss: False
use_z_rec_loss: True
use_image_identity_loss: True
mixed_strategy: "mixed_video_image"
mixed_image_ratio: 0.2

# data
dataset_name: "video"
csv_path: "../videocomposer/datasets/webvid5_copy.csv"
video_folder: "../videocomposer/datasets/webvid5"
frame_stride: 1
num_frames: 17
image_size: 256

micro_frame_size: null
micro_batch_size: null

# training recipe
seed: 42
use_discriminator: False
dtype: "fp16"
batch_size: 1
clip_grad: True
max_grad_norm: 1.0
start_learning_rate: 1.e-5
scale_lr: False
use_recompute: False

epochs: 2000
ckpt_save_interval: 100
init_loss_scale: 1.

scheduler: "constant"
use_ema: False

output_path: "outputs/causal_vae"

# ms setting
jit_level: O1
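Command-line flags such as `--csv_path` and `--video_folder` in the training commands override the values in the YAML config above. A minimal sketch of that precedence, assuming the trainer merges YAML first and CLI second (the usual pattern, not the repo's exact code):

```python
from typing import Any, Dict

def merge_config(yaml_cfg: Dict[str, Any], cli_args: Dict[str, Any]) -> Dict[str, Any]:
    """CLI values that are not None take precedence over YAML values."""
    cfg = dict(yaml_cfg)
    for key, value in cli_args.items():
        if value is not None:
            cfg[key] = value
    return cfg

cfg = merge_config(
    {"csv_path": "../videocomposer/datasets/webvid5_copy.csv", "epochs": 2000},
    {"csv_path": "datasets/ucf101_train.csv", "epochs": None},
)
print(cfg["csv_path"], cfg["epochs"])  # datasets/ucf101_train.csv 2000
```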
`configs/vae/train/stage2.yaml` (new file; filename inferred from the stage 2 training command):

# model
model_type: "OpenSoraVAE_V1_2"
freeze_vae_2d: False
pretrained_model_path: "outputs/vae_stage1.ckpt"

# loss
perceptual_loss_weight: 0.1
kl_loss_weight: 1.e-6
use_real_rec_loss: False
use_z_rec_loss: True
use_image_identity_loss: False
mixed_strategy: "mixed_video_image"
mixed_image_ratio: 0.2

# data
dataset_name: "video"
csv_path: "../videocomposer/datasets/webvid5_copy.csv"
video_folder: "../videocomposer/datasets/webvid5"
frame_stride: 1
num_frames: 17
image_size: 256

micro_frame_size: null
micro_batch_size: null
# flip: True

# training recipe
seed: 42
use_discriminator: False
dtype: "bf16"
batch_size: 1
clip_grad: True
max_grad_norm: 1.0
start_learning_rate: 1.e-5
scale_lr: False
use_recompute: True

epochs: 500
ckpt_save_interval: 100
init_loss_scale: 1.

scheduler: "constant"
use_ema: False

output_path: "outputs/vae_stage2"

# ms setting
jit_level: O1
`configs/vae/train/stage3.yaml` (new file; filename inferred from the stage 3 training command):

# model
model_type: "OpenSoraVAE_V1_2"
freeze_vae_2d: False
pretrained_model_path: "outputs/vae_stage2.ckpt"

# loss
perceptual_loss_weight: 0.1
kl_loss_weight: 1.e-6
use_real_rec_loss: True
use_z_rec_loss: False
use_image_identity_loss: False
mixed_strategy: "mixed_video_image" # TODO: use mixed_video_random after dynamic shape adaptation
mixed_image_ratio: 0.2

# data
dataset_name: "video"
csv_path: "../videocomposer/datasets/webvid5_copy.csv"
video_folder: "../videocomposer/datasets/webvid5"
frame_stride: 1
num_frames: 33 # TODO: set 33 after dynamic shape adaptation and posterior concat fixed
image_size: 256

micro_frame_size: 17
micro_batch_size: 4
# flip: True

# training recipe
seed: 42
use_discriminator: False
dtype: "fp16"
batch_size: 1
clip_grad: True
max_grad_norm: 1.0
start_learning_rate: 1.e-5
scale_lr: False
weight_decay: 0.
use_recompute: True

epochs: 400
ckpt_save_interval: 100
init_loss_scale: 1.

scheduler: "constant"
use_ema: False

output_path: "outputs/vae_stage3"

# ms setting
jit_level: O1

> **Review comment** (on `csv_path`/`video_folder`): Referencing the parent videocomposer directory feels a bit odd; could the files be copied into the current directory instead?
>
> **Reply:** The video files are fairly large; this avoids bloating the repo.
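In the stage 3 config, `micro_frame_size: 17` lets the temporal VAE process a 33-frame clip in smaller consecutive temporal chunks to bound memory. A sketch of that slicing, illustrative only and not the repo's implementation:

```python
from typing import List

def split_frames(num_frames: int, micro_frame_size: int) -> List[range]:
    """Split a clip of `num_frames` into consecutive chunks of at most `micro_frame_size` frames."""
    return [range(start, min(start + micro_frame_size, num_frames))
            for start in range(0, num_frames, micro_frame_size)]

chunks = split_frames(33, 17)
print([len(c) for c in chunks])  # → [17, 16]
```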

> **Review comment:** Does this mean the epoch size is set during data processing? Could the repeat be applied at actual training time instead?
>
> **Reply:** fixed