
feat: hpcai opensora 1.2 - VAE 3D training #621

**Merged** · 29 commits · Sep 27, 2024
- `07c32a8` — add vae3d training (SamitHuang, Aug 1, 2024)
- `9d1b10b` — rm redundant (SamitHuang, Aug 1, 2024)
- `585111b` — fix format (SamitHuang, Aug 1, 2024)
- `a2af062` — vae split tuple (SamitHuang, Aug 1, 2024)
- `6d4d228` — rewrite micro_frame_size impl, vae3d reconstruct ok (SamitHuang, Aug 1, 2024)
- `d404a39` — use make_tuple, reconstruct & vae static train ok (SamitHuang, Aug 2, 2024)
- `38400aa` — better micro batch/frame size writing (SamitHuang, Aug 3, 2024)
- `c571585` — add dynamic shape script (SamitHuang, Aug 7, 2024)
- `8c5ac03` — fix empty ckpt (SamitHuang, Aug 7, 2024)
- `22f8075` — fix empty ckpt (SamitHuang, Aug 7, 2024)
- `1c8af61` — debug: add dynamic shape support (SamitHuang, Aug 7, 2024)
- `18b5288` — update (SamitHuang, Aug 8, 2024)
- `f0ab936` — fix min(a,b) in dynamic shape (SamitHuang, Aug 8, 2024)
- `86426e3` — Merge branch 'pr_vae1.2_train' of https://github.com/samithuang/mindo… (SamitHuang, Aug 8, 2024)
- `54ac478` — default dynamic shape in script (SamitHuang, Aug 8, 2024)
- `8b7c9c9` — debug: update dyanmic shape train script (SamitHuang, Aug 8, 2024)
- `fac4272` — Merge branch 'master' into pr_vae1.2_train (SamitHuang, Aug 24, 2024)
- `1f5aba1` — download lpips auto (SamitHuang, Aug 24, 2024)
- `7a4bf32` — fix typo and linting (SamitHuang, Aug 26, 2024)
- `88de7e2` — rm redundancy (SamitHuang, Aug 26, 2024)
- `a93bcb5` — small fix (SamitHuang, Aug 26, 2024)
- `2b5a31e` — linting (SamitHuang, Aug 26, 2024)
- `a0b0dcc` — jit_level O0 for less overflow (SamitHuang, Aug 28, 2024)
- `9979d8d` — fix docs (SamitHuang, Aug 28, 2024)
- `6c00ed3` — rm redundancy (SamitHuang, Aug 28, 2024)
- `11b80cc` — rm file (SamitHuang, Aug 28, 2024)
- `2f95e96` — update doc (SamitHuang, Sep 26, 2024)
- `65baffc` — update perf (SamitHuang, Sep 27, 2024)
- `ca06f32` — update config (SamitHuang, Sep 27, 2024)
71 changes: 70 additions & 1 deletion examples/opensora_hpcai/README.md
@@ -148,6 +148,7 @@ Your contributions are welcome.
* [Data Processing](#data-processing)
* [Training](#training)
* [Evaluation](#evaluation)
* [VAE Training & Evaluation](#vae-training--evaluation)
* [Contribution](#contribution)
* [Acknowledgement](#acknowledgement)

@@ -284,6 +285,7 @@ parameters is 724M. More information about training can be found in HPC-AI Tech'
</details>



## Inference

### Open-Sora 1.2 and 1.1 Command Line Inference
@@ -759,7 +761,74 @@ Here are some generation results after fine-tuning STDiT on a subset of WebVid d
#### Quality Evaluation
For quality evaluation, please refer to the original HPC-AI Tech [evaluation doc](https://github.com/hpcaitech/Open-Sora/blob/main/eval/README.md) for video generation quality evaluation.

</details>

## VAE Training & Evaluation

OpenSora v1.2 trains a 3D-VAE pipeline consisting of a spatial VAE followed by a temporal VAE. For more details, refer to the [VAE Documentation](https://github.com/hpcaitech/Open-Sora/blob/main/docs/vae.md).

### Prepare Pretrained Weights

- Download the pretrained VAE-2D checkpoint from [PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers](https://huggingface.co/PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers/tree/main/vae) if you aim to train VAE-3D from a spatial-VAE initialization.

Convert it to a MindSpore checkpoint:
```bash
python tools/convert_vae1.2.py --src /path/to/pixart_sigma_sdxlvae_T5_diffusers/vae/diffusion_pytorch_model.safetensors --target models/sdxl_vae.ckpt --from_vae2d
```

- Download the pretrained VAE-3D checkpoint from [hpcai-tech/OpenSora-VAE-v1.2](https://huggingface.co/hpcai-tech/OpenSora-VAE-v1.2/tree/main) if you aim to train VAE-3D from the VAE-3D model pre-trained with 3 stages.

Convert it to a MindSpore checkpoint:
```bash
python tools/convert_vae1.2.py --src /path/OpenSora-VAE-v1.2/models.safetensors --target models/OpenSora-VAE-v1.2/sdxl_vae.ckpt
```

- Download the LPIPS MindSpore checkpoint from [here](https://download-mindspore.osinfra.cn/toolkits/mindone/autoencoders/lpips_vgg-426bf45c.ckpt) and put it under `models/`.


### Data Preprocess
If you want to train your own VAE, you need to prepare the data in a CSV file following the [data processing](#data-processing) pipeline, then run the commands below.
Note that you need to adjust the number of training epochs (`epochs`) in the config file according to your own CSV data size.
> **Reviewer:** Does this mean the epoch size is set during data processing? Could the dataset repeat be applied at actual training time instead?
>
> **Author:** Fixed.
Take UCF-101 as an example. After downloading the [UCF-101](https://www.crcv.ucf.edu/data/UCF101.php) dataset and extracting it to the `datasets/UCF-101` folder, we can get the CSV annotation by running `python tools/annotate_vae_ucf101.py`.

The resulting train/test annotation CSV files, which contain the relative video paths for train/test, will be saved as `datasets/ucf101_train.csv` and `datasets/ucf101_test.csv`.
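For illustration, the annotation step can be sketched as follows: walk the dataset folder, collect relative video paths, and write them to a CSV. This is a hypothetical sketch, not the code of `tools/annotate_vae_ucf101.py`; in particular, the single `video` column is an assumed schema — check the script for the exact format it produces.

```python
# Hypothetical sketch of an annotation step: collect relative video
# paths under a dataset root and write them to a CSV file. The single
# "video" column is an assumption, not the verified schema of
# tools/annotate_vae_ucf101.py.
import csv
import os

def annotate(video_root: str, out_csv: str, exts=(".avi", ".mp4")) -> int:
    rows = []
    for dirpath, _, filenames in os.walk(video_root):
        for name in sorted(filenames):
            if name.lower().endswith(exts):
                # store paths relative to the dataset root, matching
                # how --video_folder is joined with csv entries
                rel = os.path.relpath(os.path.join(dirpath, name), video_root)
                rows.append(rel)
    with open(out_csv, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["video"])
        writer.writerows([r] for r in sorted(rows))
    return len(rows)
```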

### Training
```bash
# stage 1 training, 8 NPUs
msrun --worker_num=8 --local_worker_num=8 \
python scripts/train_vae.py --config configs/vae/train/stage1.yaml --use_parallel=True --csv_path datasets/ucf101_train.csv --video_folder datasets/UCF-101

# stage 2 training, 8 NPUs
msrun --worker_num=8 --local_worker_num=8 \
python scripts/train_vae.py --config configs/vae/train/stage2.yaml --use_parallel=True --csv_path datasets/ucf101_train.csv --video_folder datasets/UCF-101

# stage 3 training, 8 NPUs
msrun --worker_num=8 --local_worker_num=8 \
python scripts/train_vae.py --config configs/vae/train/stage3.yaml --use_parallel=True --csv_path datasets/ucf101_train.csv --video_folder datasets/UCF-101
```

You can change the `csv_path` and `video_folder` to train on your own data.

### Performance Evaluation
To evaluate the VAE performance, you need to run VAE inference first to generate the videos, then calculate scores on the generated videos:

```bash
# video generation and evaluation
python scripts/inference_vae.py --ckpt_path /path/to/your_vae_ckpt --image_size 256 --num_frames=17 --csv_path datasets/ucf101_test.csv --video_folder datasets/UCF-101
```

You can change the `csv_path` and `video_folder` to evaluate on your own data.
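PSNR, one of the reported reconstruction metrics, is derived from the mean squared error between original and reconstructed frames. A minimal NumPy sketch of the standard definition (illustrative only — not the repository's evaluation code):

```python
# Minimal PSNR sketch for 8-bit frames, using the standard definition
# PSNR = 10 * log10(MAX^2 / MSE). Illustrative only; not the
# evaluation code used by the repository.
import numpy as np

def psnr(original: np.ndarray, reconstructed: np.ndarray, max_val: float = 255.0) -> float:
    mse = np.mean((original.astype(np.float64) - reconstructed.astype(np.float64)) ** 2)
    if mse == 0:
        return float("inf")  # identical frames
    return 10.0 * np.log10(max_val ** 2 / mse)
```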

Here, we report the training performance and evaluation results on the UCF-101 dataset.

| Model | Context | jit_level | Precision | BS | NPUs | Resolution (frames x H x W) | Train T. (s/step) | PSNR | SSIM |
|:------------|:-------------|:--------|:---------:|:--:|:----:|:----------------------:|:-----------------:|:-----------------:|:-----------------:|
| STDiT2-XL/2 | D910\*-[CANN C18(0705)](https://repo.mindspore.cn/ascend/ascend910/20240705/)-[MS2.3](https://www.mindspore.cn/install) | O1 | BF16 | 1 | 8 | 17x256x256 | 0.97 | 29.29 | 0.88 |
> **Reviewer:** 1. If performance numbers for all three stages are available, please add them as well. 2. If a parallel strategy or data sinking is used, it is recommended to document those too.
> Context: {G:GPU, D:Ascend}{chip type}-{mindspore version}.

Note that for stage 3 we train with the mixed video-and-image strategy, i.e. `--mixed_strategy=mixed_video_image`, instead of a random number of frames (`mixed_video_random`). Random-frame training will be supported in the future.
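Conceptually, the mixed video/image strategy collapses a fraction of the training clips to single frames. A hypothetical sketch of such a sampling policy (the names and structure here are illustrative, not the repository's implementation):

```python
# Illustrative sketch of a mixed video/image sampling policy: with
# probability `mixed_image_ratio`, a clip is collapsed to a single
# frame (image sample); otherwise the full clip is kept. This mirrors
# the described behaviour; it is not the repo's code.
import random

def sample_num_frames(num_frames: int, mixed_image_ratio: float, rng: random.Random) -> int:
    if rng.random() < mixed_image_ratio:
        return 1             # image sample: a single frame
    return num_frames        # full video clip

rng = random.Random(42)
counts = [sample_num_frames(17, 0.2, rng) for _ in range(1000)]
```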
> **Reviewer:** The `--mixed_strategy` argument feels a bit unclear; it does not really convey the video/image sampling strategy.
>
> **Author:** Indeed; for now the argument name is kept aligned with the torch version.


## Training and Inference Using the FiT-Like Pipeline

Expand Down
47 changes: 47 additions & 0 deletions examples/opensora_hpcai/configs/vae/train/stage1.yaml
@@ -0,0 +1,47 @@
```yaml
# model
model_type: "OpenSoraVAE_V1_2"
freeze_vae_2d: True
pretrained_model_path: "models/sdxl_vae.ckpt"

# loss
perceptual_loss_weight: 0.1
kl_loss_weight: 1.e-6
use_real_rec_loss: False
use_z_rec_loss: True
use_image_identity_loss: True
mixed_strategy: "mixed_video_image"
mixed_image_ratio: 0.2

# data
dataset_name: "video"
csv_path: "../videocomposer/datasets/webvid5_copy.csv"
video_folder: "../videocomposer/datasets/webvid5"
frame_stride: 1
num_frames: 17
image_size: 256

micro_frame_size: null
micro_batch_size: null

# training recipe
seed: 42
use_discriminator: False
dtype: "fp16"
batch_size: 1
clip_grad: True
max_grad_norm: 1.0
start_learning_rate: 1.e-5
scale_lr: False
use_recompute: False

epochs: 2000
ckpt_save_interval: 100
init_loss_scale: 1.

scheduler: "constant"
use_ema: False

output_path: "outputs/causal_vae"

# ms setting
jit_level: O1
```
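The loss weights configured above combine into a weighted sum of reconstruction, perceptual (LPIPS), and KL terms. A schematic NumPy sketch using the closed-form KL of a diagonal Gaussian against a standard normal (illustrative only — the actual loss lives in the training code, and the function names here are assumptions):

```python
# Schematic combination of the configured VAE loss terms:
# total = reconstruction + perceptual_loss_weight * LPIPS
#         + kl_loss_weight * KL.
# The KL term uses the closed form for a diagonal Gaussian vs. N(0, I).
# Illustrative only; not the repository's implementation.
import numpy as np

def kl_divergence(mean: np.ndarray, logvar: np.ndarray) -> float:
    # KL(N(mean, exp(logvar)) || N(0, I)), summed over latent dims,
    # averaged over the batch
    return float(0.5 * np.mean(np.sum(mean**2 + np.exp(logvar) - 1.0 - logvar, axis=-1)))

def total_loss(rec_loss: float, perc_loss: float, mean: np.ndarray, logvar: np.ndarray,
               perceptual_loss_weight: float = 0.1, kl_loss_weight: float = 1e-6) -> float:
    return rec_loss + perceptual_loss_weight * perc_loss + kl_loss_weight * kl_divergence(mean, logvar)
```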
48 changes: 48 additions & 0 deletions examples/opensora_hpcai/configs/vae/train/stage2.yaml
@@ -0,0 +1,48 @@
```yaml
# model
model_type: "OpenSoraVAE_V1_2"
freeze_vae_2d: False
pretrained_model_path: "outputs/vae_stage1.ckpt"

# loss
perceptual_loss_weight: 0.1
kl_loss_weight: 1.e-6
use_real_rec_loss: False
use_z_rec_loss: True
use_image_identity_loss: False
mixed_strategy: "mixed_video_image"
mixed_image_ratio: 0.2

# data
dataset_name: "video"
csv_path: "../videocomposer/datasets/webvid5_copy.csv"
video_folder: "../videocomposer/datasets/webvid5"
frame_stride: 1
num_frames: 17
image_size: 256

micro_frame_size: null
micro_batch_size: null
# flip: True

# training recipe
seed: 42
use_discriminator: False
dtype: "bf16"
batch_size: 1
clip_grad: True
max_grad_norm: 1.0
start_learning_rate: 1.e-5
scale_lr: False
use_recompute: True

epochs: 500
ckpt_save_interval: 100
init_loss_scale: 1.

scheduler: "constant"
use_ema: False

output_path: "outputs/vae_stage2"

# ms setting
jit_level: O1
```
49 changes: 49 additions & 0 deletions examples/opensora_hpcai/configs/vae/train/stage3.yaml
@@ -0,0 +1,49 @@
```yaml
# model
model_type: "OpenSoraVAE_V1_2"
freeze_vae_2d: False
pretrained_model_path: "outputs/vae_stage2.ckpt"

# loss
perceptual_loss_weight: 0.1
kl_loss_weight: 1.e-6
use_real_rec_loss: True
use_z_rec_loss: False
use_image_identity_loss: False
mixed_strategy: "mixed_video_image" # TODO: use mixed_video_random after dynamic shape adaptation
mixed_image_ratio: 0.2

# data
dataset_name: "video"
csv_path: "../videocomposer/datasets/webvid5_copy.csv"
video_folder: "../videocomposer/datasets/webvid5"
```
> **Reviewer (on lines +17 to +18):** Referencing the sibling videocomposer directory feels a bit odd; could the data be copied into the current directory instead?
>
> **Author:** The video files are fairly large; this avoids bloating the repo.
```yaml
frame_stride: 1
num_frames: 33 # TODO: set 33 after dynamic shape adaptation and posterior concat fixed
image_size: 256

micro_frame_size: 17
micro_batch_size: 4
# flip: True

# training recipe
seed: 42
use_discriminator: False
dtype: "fp16"
batch_size: 1
clip_grad: True
max_grad_norm: 1.0
start_learning_rate: 1.e-5
scale_lr: False
weight_decay: 0.
use_recompute: True

epochs: 400
ckpt_save_interval: 100
init_loss_scale: 1.

scheduler: "constant"
use_ema: False

output_path: "outputs/vae_stage2"
```
> **Reviewer suggested change:**
> ```diff
> -output_path: "outputs/vae_stage2"
> +output_path: "outputs/vae_stage3"
> ```
```yaml
# ms setting
jit_level: O1
```
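The `micro_frame_size: 17` setting in stage 3 processes a long clip in temporal chunks rather than all at once. A minimal sketch of how frames could be split along the time axis (an illustration of the idea only; the real implementation encodes/decodes each chunk and stitches the results):

```python
# Illustrative temporal micro-chunking: a clip of `num_frames` frames
# is processed in windows of `micro_frame_size` frames, e.g. 33 frames
# with micro_frame_size=17 -> chunks of 17 and 16 frames. A sketch of
# the idea, not the repository's implementation.
def micro_frame_chunks(num_frames, micro_frame_size):
    if micro_frame_size is None:
        return [num_frames]  # process the whole clip at once
    return [min(micro_frame_size, num_frames - start)
            for start in range(0, num_frames, micro_frame_size)]
```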