Fix OpenSora 1.2 training (#643)
* lazy_inline to reduce compile time

* lazy_inline to reduce compile time

* leave large mem for communication, max_device_memory 55GB

* fix train_steps bug

* fix linting

* add lazy_inline to vae encoder/decoder

* rm lazy inline for vae due to perf drop

* rm lazy_inline in vae

* only require decord when backend selected

* fix logging

* fix logging

* x1: rm duplicated norm

* x-1: use ops.rms_norm, mint.layer_norm

* x-2: rm hs list in vae encode

* x-3: use self-impl repeat interleave

* fix layernorm

* record shape

* balance bucket config for A+M

* revert repeat interleave for safe

* increase bs for 256 res

* add shape step time analysis script

* fix stop

* rm pdb

* acc by add Symbol

* clear mem in the end of epoch

* update doc

* impr bucket analysis

* add stage3 balanced bucket

* fix lint

* fix linting

* Update README.md

* add comments
SamitHuang authored Sep 27, 2024
1 parent 2ea7619 commit 7057ddb
Showing 16 changed files with 373 additions and 56 deletions.
29 changes: 21 additions & 8 deletions examples/opensora_hpcai/README.md
@@ -159,9 +159,18 @@ Other useful documents and links are listed below.
## Installation

1. Install MindSpore according to the [official instructions](https://www.mindspore.cn/install).
For Ascend devices, please install **CANN driver C18 (0705)** from [here](https://repo.mindspore.cn/ascend/ascend910/20240705/) and install **MindSpore 2.3** from [here](https://www.mindspore.cn/install).
For Ascend devices, please install [CANN 8.0.RC2.beta1](https://www.hiascend.com/developer/download/community/result?module=cann&cann=8.0.RC2.beta1) and install [MindSpore 2.3.1](https://www.mindspore.cn/install).
> To reduce compilation time and training time, you may install the MindSpore 2.4 daily build (20240904) from [here](https://repo.mindspore.cn/mindspore/mindspore/version/202409/20240904/master_20240904010023_67b5df247045f509c4ca2169bac6a551291a3111_newest/unified/aarch64/).
2. Install requirements
You may check your versions by running the following commands. The default installation path of CANN is usually `/usr/local/Ascend/ascend-toolkit` unless you specify a custom one.

```bash
cat /usr/local/Ascend/ascend-toolkit/latest/version.cfg

python -c "import mindspore;mindspore.set_context(device_target='Ascend');mindspore.run_check()"
```

3. Install requirements
```bash
pip install -r requirements.txt
```
@@ -517,9 +526,11 @@ bucket_config:
"2048": { 1: [ 0.01, 5 ] }
```
Since the optimal bucket config varies from device to device, we have tuned and provided bucket configs that are more balanced on Ascend + MindSpore in `configs/opensora-v1-2/train/{stage}_ms.yaml`. You may use them for better training performance.
More details on the bucket configuration can be found in [Multi-resolution Training with Buckets](./docs/quick_start.md#4-multi-resolution-training-with-buckets-opensora-v11-and-above).
Then you can launch the dynamic training task following the previous section. An example running script is `scripts/run/run_train_os1.2_stage2.sh`.
The instructions for launching the dynamic training task are similar to the previous section. An example running script is `scripts/run/run_train_os1.2_stage2.sh`.
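For intuition, the sketch below shows how a clip could be routed through such a bucket config: it is matched against resolutions from large to small, kept in a bucket with probability `keep_prob`, and otherwise falls through to a smaller bucket. The table and names here are illustrative only, not the repo's actual dataloader API.

```python
import random

# Illustrative table: resolution -> (pixel budget, {num_frames: (keep_prob, batch_size)})
BUCKETS = {
    "240p": (240 * 426, {51: (0.4, 16), 1: (0.3, 297)}),
    "144p": (144 * 256, {51: (1.0, 40), 1: (1.0, 475)}),
}

def pick_bucket(height: int, width: int, frames: int):
    """Route a clip to (resolution, num_frames, batch_size), or drop it."""
    for name, (budget, frame_cfg) in BUCKETS.items():  # ordered large -> small
        if height * width < budget:
            continue  # clip is too small for this resolution bucket
        for nf, (keep_prob, bs) in sorted(frame_cfg.items(), reverse=True):
            if frames >= nf and random.random() < keep_prob:
                return name, nf, bs
            # otherwise fall through to fewer frames / a smaller resolution
    return None  # clip is dropped for this epoch

print(pick_bucket(270, 480, 60))
```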
### Open-Sora 1.1
@@ -622,11 +633,13 @@ Here ✅ means that the data is seen during training, and 🆗 means although no
We evaluate the training performance of Open-Sora v1.2 on the MixKit dataset with high-resolution videos (1080P, duration 12s to 100s). The results are as follows.
| Model | Context | jit_level | Precision | BS | NPUs | Size (TxHxW) | Train T. (s/step) |
|:------------|:-------------|:--------|:---------:|:--:|:----:|:----------------------:|:-----------------:|
| STDiT3-XL/2 | D910\*-[C18](https://repo.mindspore.cn/ascend/ascend910/20240705/)-[MS2.3](https://www.mindspore.cn/install) | O1 | BF16 | 1 | 8 | 51x720x1280 | **14.60** |
| STDiT3-XL/2 | D910\*-[C18](https://repo.mindspore.cn/ascend/ascend910/20240705/)-[MS2.3.1(0726)](https://repo.mindspore.cn/mindspore/mindspore/version/202407/20240726/master_20240726220021_4c913fb116c83b9ad28666538483264da8aebe8c_newest/unified/) | O1 | BF16 | 1 | 8 | Stage 2 Dyn. | **33.10** |
| STDiT3-XL/2 | D910\*-[C18](https://repo.mindspore.cn/ascend/ascend910/20240705/)-[MS2.3.1(0726)](https://repo.mindspore.cn/mindspore/mindspore/version/202407/20240726/master_20240726220021_4c913fb116c83b9ad28666538483264da8aebe8c_newest/unified/) | O1 | BF16 | 1 | 8 | Stage 3 Dyn. | **37.7** |
| Model | Context | jit_level | Precision | BS | NPUs | Size (TxHxW) | Train T. (s/step) | config |
|:------------|:-------------|:--------|:---------:|:--:|:----:|:----------------------:|:-----------------:|:-----------------:|
| STDiT3-XL/2 | D910\*-[C18](https://repo.mindspore.cn/ascend/ascend910/20240705/)-[MS2.3](https://www.mindspore.cn/install) | O1 | BF16 | 1 | 8 | 51x720x1280 | **14.60** | [yaml](configs/opensora-v1-2/train/train_720x1280x51.yaml) |
| STDiT3-XL/2 | D910\*-[C18](https://repo.mindspore.cn/ascend/ascend910/20240705/)-[MS2.3.1(0726)](https://repo.mindspore.cn/mindspore/mindspore/version/202407/20240726/master_20240726220021_4c913fb116c83b9ad28666538483264da8aebe8c_newest/unified/) | O1 | BF16 | 1 | 8 | Stage 2 Dyn. | **33.10** | [yaml](configs/opensora-v1-2/train/train_stage2.yaml) |
| STDiT3-XL/2 | D910\*-[C18](https://repo.mindspore.cn/ascend/ascend910/20240705/)-[MS2.3.1(0726)](https://repo.mindspore.cn/mindspore/mindspore/version/202407/20240726/master_20240726220021_4c913fb116c83b9ad28666538483264da8aebe8c_newest/unified/) | O1 | BF16 | 1 | 8 | Stage 3 Dyn. | **34** | [yaml](configs/opensora-v1-2/train/train_stage3.yaml) |
> Context: {G:GPU, D:Ascend}{chip type}-{CANN version}-{mindspore version}; "Dyn." is short for dynamic shape.
Note that the step time of dynamic training can be influenced by the resolution and duration distribution of the source videos. Training performance is under optimization.
@@ -0,0 +1,84 @@
# model
model_version: v1.2
pretrained_model_path: PATH_TO_YOUR_MODEL
model_max_length: 300
freeze_y_embedder: True

noise_scheduler: rflow
sample_method: logit-normal
use_timestep_transform: True

vae_type: OpenSoraVAE_V1_2
vae_checkpoint: models/OpenSora-VAE-v1.2/model.ckpt
vae_dtype: bf16
vae_micro_batch_size: 4
vae_micro_frame_size: 17 # keep it unchanged for the best results

enable_flash_attention: True
use_recompute: True

# data
num_parallel_workers: 2
num_workers_dataset: 2
prefetch_size: 2
max_rowsize: 256

# mindspore params, refer to https://www.mindspore.cn/docs/zh-CN/r2.3.1/api_python/mindspore/mindspore.set_context.html
max_device_memory: "59GB"
jit_level: "O1"
manual_pad: True

# precision
amp_level: "O2"
dtype: bf16
loss_scaler_type: static
init_loss_scale: 1

# training hyper-params
scheduler: "constant"
start_learning_rate: 1.e-4
end_learning_rate: 1.e-4
# warmup_steps: 1000

clip_grad: True
max_grad_norm: 1.0
use_ema: True
# ema_decay: 0.99 # default 0.9999 gives better result in our experiments

optim: "adamw_re"
optim_eps: 1e-15
weight_decay: 0.

# epochs: 2
train_steps: 23000
ckpt_save_steps: 500

mask_ratios:
random: 0.005
interpolate: 0.002
quarter_random: 0.007
quarter_head: 0.002
quarter_tail: 0.002
quarter_head_tail: 0.002
image_random: 0.0
image_head: 0.22
image_tail: 0.005
image_head_tail: 0.005

bucket_config:
# Structure: "resolution": { num_frames: [ keep_prob, batch_size ] }
# Setting [ keep_prob, batch_size ] to [ 0.0, 0 ] forces longer videos into smaller resolution buckets
"144p": { 1: [ 1.0, 475 ], 51: [ 1.0, 40 ], 102: [ [ 1.0, 0.33 ], 20 ], 204: [ [ 1.0, 0.1 ], 10 ], 408: [ [ 1.0, 0.1 ], 6 ] }
"256": { 1: [ 0.4, 297 ], 51: [ 0.5, 24 ], 102: [ [ 0.5, 0.33 ], 12 ], 204: [ [ 0.5, 1.0 ], 6 ], 408: [ [ 0.5, 1.0 ], 2 ] }
"240p": { 1: [ 0.3, 297 ], 51: [ 0.4, 16 ], 102: [ [ 0.4, 0.33 ], 8 ], 204: [ [ 0.4, 1.0 ], 4 ], 408: [ [ 0.4, 1.0 ], 2 ] }
"360p": { 1: [ 0.5, 141 ], 51: [ 0.15, 6 ], 102: [ [ 0.3, 0.5 ], 3 ], 204: [ [ 0.3, 1.0 ], 2 ], 408: [ [ 0.5, 0.5 ], 1 ] }
"512": { 1: [ 0.4, 141 ], 51: [ 0.15, 6 ], 102: [ [ 0.2, 0.4 ], 3 ], 204: [ [ 0.2, 1.0 ], 1 ], 408: [ [ 0.4, 0.5 ], 1 ] }
"480p": { 1: [ 0.5, 89 ], 51: [ 0.2, 5 ], 102: [ 0.2, 2 ], 204: [ 0.1, 1 ] }
"720p": { 1: [ 0.1, 36 ], 51: [ 0.03, 1 ] }
"1024": { 1: [ 0.1, 36 ], 51: [ 0.02, 1 ] }
"1080p": { 1: [ 0.01, 5 ] }
"2048": { 1: [ 0.01, 5 ] }


# ---------- Validation ----------
validate: False
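For reference, the `# mindspore params` block above maps onto `mindspore.set_context`. A minimal sketch of that mapping, assuming the training script forwards these YAML keys (`manual_pad` is consumed by the model code instead):

```python
import mindspore as ms

ms.set_context(
    mode=ms.GRAPH_MODE,
    device_target="Ascend",
    max_device_memory="59GB",        # cap model memory to leave room for communication
    jit_config={"jit_level": "O1"},  # graph-optimization level used in the perf tables
)
```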
@@ -0,0 +1,83 @@
# model
model_version: v1.2
pretrained_model_path: PATH_TO_YOUR_MODEL
model_max_length: 300
freeze_y_embedder: True

noise_scheduler: rflow
sample_method: logit-normal
use_timestep_transform: True

vae_type: OpenSoraVAE_V1_2
vae_checkpoint: models/OpenSora-VAE-v1.2/model.ckpt
vae_dtype: bf16
vae_micro_batch_size: 4
vae_micro_frame_size: 17 # keep it unchanged for the best results

enable_flash_attention: True
use_recompute: True

# data
num_parallel_workers: 2
num_workers_dataset: 2
prefetch_size: 2
max_rowsize: 256

# precision
amp_level: "O2"
dtype: bf16
loss_scaler_type: static
init_loss_scale: 1

# mindspore params, refer to https://www.mindspore.cn/docs/zh-CN/r2.3.1/api_python/mindspore/mindspore.set_context.html
max_device_memory: "59GB"
jit_level: "O1"
manual_pad: True

# training hyper-params
scheduler: "constant"
start_learning_rate: 1.e-4
end_learning_rate: 1.e-4
warmup_steps: 1000

clip_grad: True
max_grad_norm: 1.0
use_ema: True
# ema_decay: 0.99

optim: "adamw_re"
optim_eps: 1e-15
weight_decay: 0.

# epochs: 15
train_steps: 15000
ckpt_save_steps: 500

mask_ratios:
random: 0.01
interpolate: 0.002
quarter_random: 0.002
quarter_head: 0.002
quarter_tail: 0.002
quarter_head_tail: 0.002
image_random: 0.0
image_head: 0.22
image_tail: 0.005
image_head_tail: 0.005

bucket_config:
# Structure: "resolution": { num_frames: [ keep_prob, batch_size ] }
# Setting [ keep_prob, batch_size ] to [ 0.0, 0 ] forces longer videos into smaller resolution buckets
"144p": {1: [1.0, 475], 51: [1.0, 51], 102: [1.0, 27], 204: [1.0, 13], 408: [1.0, 6]}
"256": {1: [1.0, 297], 51: [0.5, 20], 102: [0.5, 10], 204: [0.5, 6], 408: [[0.5, 0.5], 2]}
"240p": {1: [1.0, 297], 51: [0.5, 20], 102: [0.5, 10], 204: [0.5, 5], 408: [[0.5, 0.4], 2]}
"360p": {1: [1.0, 141], 51: [0.5, 8], 102: [0.5, 4], 204: [0.5, 2], 408: [[0.5, 0.3], 1]}
"512": {1: [1.0, 141], 51: [0.5, 8], 102: [0.5, 4], 204: [0.5, 2], 408: [[0.5, 0.2], 1]}
"480p": {1: [1.0, 89], 51: [0.5, 5], 102: [0.5, 2], 204: [[0.5, 0.5], 1], 408: [0.0, 0]}
"720p": {1: [0.3, 36], 51: [0.2, 2], 102: [0.1, 1], 204: [0.0, 0]}
"1024": {1: [0.3, 36], 51: [0.1, 2], 102: [0.1, 1], 204: [0.0, 0]}
"1080p": {1: [0.1, 5]}
"2048": {1: [0.05, 5]}

# ---------- Validation ----------
validate: False
@@ -10,7 +10,6 @@

import cv2
import numpy as np
from decord import VideoReader
from tqdm import tqdm

import mindspore as ms
@@ -252,6 +251,8 @@ def _get_item(self, idx: int) -> Tuple[Any, ...]:

else:
if self.video_backend == "decord":
from decord import VideoReader

reader = VideoReader(video_path)
min_length = self._min_length
video_length = len(reader)
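The deferred import above makes `decord` an optional dependency: it is only required when that video backend is actually selected. A sketch of the pattern in isolation (`make_reader` is a hypothetical helper, not the repo's API):

```python
def make_reader(video_path: str, backend: str = "cv2"):
    # Import decord lazily so only users who select this backend need it installed.
    if backend == "decord":
        from decord import VideoReader
        return VideoReader(video_path)
    import cv2
    return cv2.VideoCapture(video_path)
```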
12 changes: 6 additions & 6 deletions examples/opensora_hpcai/opensora/models/layers/blocks.py
@@ -5,7 +5,7 @@
import numpy as np

import mindspore as ms
from mindspore import Parameter, Tensor, nn, ops
from mindspore import Parameter, Tensor, mint, nn, ops
from mindspore.common.initializer import initializer

from mindone.models.modules.flash_attention import FLASH_IS_AVAILABLE, MSFlashAttention
@@ -26,9 +26,7 @@ def __init__(self, hidden_size, eps=1e-6):
self.variance_epsilon = eps

def construct(self, hidden_states: Tensor):
variance = hidden_states.pow(2).mean(-1, keep_dims=True)
hidden_states = hidden_states * ops.rsqrt(variance + self.variance_epsilon)
return self.gamma * hidden_states
return ops.rms_norm(hidden_states, self.gamma, self.variance_epsilon)[0]
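The fused `ops.rms_norm` returns an `(output, rstd)` tuple, hence the `[0]`. A quick equivalence sketch against the manual `pow/mean/rsqrt` formulation removed above (assumes a device where the fused kernel is available):

```python
import numpy as np
from mindspore import Tensor, ops

x = Tensor(np.random.randn(2, 8).astype(np.float32))
gamma = Tensor(np.ones(8, np.float32))
eps = 1e-6

manual = gamma * (x * ops.rsqrt(x.pow(2).mean(-1, keep_dims=True) + eps))
fused = ops.rms_norm(x, gamma, eps)[0]  # fused kernel returns (output, rstd)
print(np.allclose(manual.asnumpy(), fused.asnumpy(), atol=1e-5))
```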


class Attention(nn.Cell):
@@ -325,10 +323,12 @@ def __init__(self, normalized_shape, eps=1e-5, elementwise_affine: bool = True,
else:
self.gamma = ops.ones(normalized_shape, dtype=dtype)
self.beta = ops.zeros(normalized_shape, dtype=dtype)
self.layer_norm = ops.LayerNorm(-1, -1, epsilon=eps)

def construct(self, x: Tensor):
x, _, _ = self.layer_norm(x, self.gamma, self.beta)
normalized_shape = x.shape[-1:]
# mint layer_norm fuses the operations in layer normalization and is faster than ops.LayerNorm
x = mint.nn.functional.layer_norm(x, normalized_shape, self.gamma, self.beta, self.eps)

return x
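Similarly, `mint.nn.functional.layer_norm` should reproduce `ops.LayerNorm` numerically while dispatching to a single fused kernel on Ascend; a quick check sketch:

```python
import numpy as np
from mindspore import Tensor, mint, ops

x = Tensor(np.random.randn(2, 4, 16).astype(np.float32))
gamma = Tensor(np.ones(16, np.float32))
beta = Tensor(np.zeros(16, np.float32))

old, _, _ = ops.LayerNorm(-1, -1, epsilon=1e-5)(x, gamma, beta)
new = mint.nn.functional.layer_norm(x, (16,), gamma, beta, 1e-5)
print(np.allclose(old.asnumpy(), new.asnumpy(), atol=1e-5))
```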


@@ -56,6 +56,8 @@ def get_repeat_interleave_op():
# provide better performance for static shape in graph mode
return ops.repeat_interleave
else:
# FIXME: check overflow for v2
# return repeat_interleave_ext_v2
return repeat_interleave_ext
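The self-implemented fallback is not shown in this hunk; one way to express repeat-interleave with shape-friendly primitives is sketched below (an illustration of the idea, not necessarily the repo's `repeat_interleave_ext`):

```python
from mindspore import Tensor, ops

def repeat_interleave_manual(x: Tensor, repeats: int, axis: int) -> Tensor:
    # (..., n, ...) -> (..., n, 1, ...) -> (..., n, repeats, ...) -> (..., n*repeats, ...)
    axis = axis % x.ndim
    x = x.expand_dims(axis + 1)
    reps = [1] * x.ndim
    reps[axis + 1] = repeats
    x = ops.tile(x, tuple(reps))
    shape = x.shape[:axis] + (x.shape[axis] * repeats,) + x.shape[axis + 2:]
    return x.reshape(shape)
```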


12 changes: 8 additions & 4 deletions examples/opensora_hpcai/opensora/models/stdit/stdit3.py
@@ -35,6 +35,8 @@


class STDiT3Block(nn.Cell):
# to reduce compilation time
@ms.lazy_inline(policy="front")
def __init__(
self,
hidden_size,
@@ -100,9 +102,10 @@ def construct(
)

# modulate (attention)
x_m = t2i_modulate(self.norm1(x), shift_msa, scale_msa)
norm1 = self.norm1(x)
x_m = t2i_modulate(norm1, shift_msa, scale_msa)
# frames mask branch
x_m_zero = t2i_modulate(self.norm1(x), shift_msa_zero, scale_msa_zero)
x_m_zero = t2i_modulate(norm1, shift_msa_zero, scale_msa_zero)
x_m = t_mask_select(frames_mask, x_m, x_m_zero, T, S)

# attention
@@ -128,9 +131,10 @@
x = x + self.cross_attn(x, y, mask)

# modulate (MLP)
x_m = t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)
norm2 = self.norm2(x)
x_m = t2i_modulate(norm2, shift_mlp, scale_mlp)
# frames mask branch
x_m_zero = t2i_modulate(self.norm2(x), shift_mlp_zero, scale_mlp_zero)
x_m_zero = t2i_modulate(norm2, shift_mlp_zero, scale_mlp_zero)
x_m = t_mask_select(frames_mask, x_m, x_m_zero, T, S)

# MLP
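For context, `t2i_modulate` is the usual DiT-style affine modulation (a sketch assuming the upstream definition). Because the regular branch and the frames-mask branch modulate the same normalized tensor, hoisting `self.norm1(x)` / `self.norm2(x)` into a local variable runs each LayerNorm once per block instead of twice:

```python
from mindspore import Tensor

def t2i_modulate(x: Tensor, shift: Tensor, scale: Tensor) -> Tensor:
    # DiT-style adaptive modulation of normalized activations
    return x * (1 + scale) + shift
```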
14 changes: 8 additions & 6 deletions examples/opensora_hpcai/opensora/models/vae/modules.py
@@ -2,6 +2,7 @@

import numpy as np

# import mindspore as ms
from mindspore import nn, ops

_logger = logging.getLogger(__name__)
@@ -177,6 +178,7 @@ def make_attn(in_channels, attn_type="vanilla"):

# used in vae
class Encoder(nn.Cell):
# @ms.lazy_inline()
def __init__(
self,
*,
@@ -274,18 +276,18 @@ def construct(self, x):
temb = None

# downsampling
hs = [self.conv_in(x)]
hs = self.conv_in(x)
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](hs[-1], temb)
h = self.down[i_level].block[i_block](hs, temb)
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
hs.append(h)
hs = h
if i_level != self.num_resolutions - 1:
hs.append(self.down[i_level].downsample(hs[-1]))
hs = self.down[i_level].downsample(hs)

# middle
h = hs[-1]
h = hs
h = self.mid.block_1(h, temb)
h = self.mid.attn_1(h)
h = self.mid.block_2(h, temb)
@@ -297,8 +299,8 @@ def construct(self, x):
return h


# used in vae
class Decoder(nn.Cell):
# @ms.lazy_inline()
def __init__(
self,
*,
2 changes: 2 additions & 0 deletions examples/opensora_hpcai/opensora/models/vae/vae_temporal.py
@@ -150,6 +150,7 @@ def get_activation_fn(activation):
class Encoder(nn.Cell):
"""Encoder Blocks."""

# @ms.lazy_inline()
def __init__(
self,
in_out_channels=4,
@@ -260,6 +261,7 @@ def construct(self, x):
class Decoder(nn.Cell):
"""Decoder Blocks."""

# @ms.lazy_inline()
def __init__(
self,
in_out_channels=4,
@@ -72,6 +72,7 @@ def __call__(
mask_t = frames_mask * self.num_timesteps
x0 = z.copy()
x_noise = self.scheduler.add_noise(x0, ops.randn_like(x0), t)
# x_noise = self.scheduler.add_noise(x0, ms.Tensor(np.random.randn(*x0.shape), dtype=ms.float32), t)

model_kwargs["frames_mask"] = mask_t_upper = mask_t >= t.unsqueeze(1)
mask_add_noise = (mask_t_upper * (1 - noise_added)).astype(dtype.bool_)
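For orientation, rectified-flow noising linearly interpolates between the clean latent and Gaussian noise; the masked-frame logic above decides per frame which timestep applies. A sketch of the assumed `add_noise` form:

```python
import numpy as np

def rflow_add_noise(x0: np.ndarray, noise: np.ndarray, t: np.ndarray, num_timesteps: int = 1000):
    # Assumed form: x_t = (1 - t/T) * x0 + (t/T) * noise,
    # i.e. clean latents at t=0 and pure noise at t=T.
    ratio = (t / num_timesteps).reshape(-1, 1, 1, 1, 1)  # broadcast over (C, T, H, W)
    return (1.0 - ratio) * x0 + ratio * noise
```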