From 7057ddba12ae4837e73c888a126b62ecf5f9505a Mon Sep 17 00:00:00 2001 From: Samit <285365963@qq.com> Date: Fri, 27 Sep 2024 12:04:13 +0800 Subject: [PATCH] Fix OpenSora 1.2 training (#643) * lazy_inline to reduce compile time * lazy_inline to reduce compile time * leave large mem for communication, max_device_memory 55GB * fix train_steps bug * fix linting * add lazy_inline to vae encoder/decoder * rm lazy inline for vae due to perf drop * rm lazy_inline in vae * only require decord when backend selected * fix logging * fix logging * x1: rm duplicated norm * x-1: use ops.rms_norm, mint.layer_norm * x-2: rm hs list in vae encode * x-3: use self-impl repeat interleave * fix layernorm * record shape * balance bucket config for A+M * revert repeat interleave for safe * increase bs for 256 res * add shape step time analysis script * fix stop * rm pdb * acc by add Symbol * clear mem in the end of epoch * update doc * impr bucket analysis * add stage3 balanced bucket * fix lint * fix linting * Update README.md * add comments --- examples/opensora_hpcai/README.md | 29 +++++-- .../opensora-v1-2/train/train_stage2_ms.yaml | 84 +++++++++++++++++++ .../opensora-v1-2/train/train_stage3_ms.yaml | 83 ++++++++++++++++++ .../datasets/video_dataset_refactored.py | 3 +- .../opensora/models/layers/blocks.py | 12 +-- .../models/layers/operation_selector.py | 2 + .../opensora/models/stdit/stdit3.py | 12 ++- .../opensora/models/vae/modules.py | 14 ++-- .../opensora/models/vae/vae_temporal.py | 2 + .../opensora/schedulers/rectified_flow.py | 1 + .../scripts/run/run_train_os1.2_stage2.sh | 4 +- examples/opensora_hpcai/scripts/train.py | 59 ++++++------- .../tools/analyze_bucket_result.py | 64 ++++++++++++++ .../tools/annotate_stdit_mix100.py | 56 +++++++++++++ mindone/trainers/callback.py | 2 +- mindone/trainers/recorder.py | 2 + 16 files changed, 373 insertions(+), 56 deletions(-) create mode 100644 examples/opensora_hpcai/configs/opensora-v1-2/train/train_stage2_ms.yaml create mode 100644 examples/opensora_hpcai/configs/opensora-v1-2/train/train_stage3_ms.yaml create mode 100644 examples/opensora_hpcai/tools/analyze_bucket_result.py create mode 100644 examples/opensora_hpcai/tools/annotate_stdit_mix100.py diff --git a/examples/opensora_hpcai/README.md b/examples/opensora_hpcai/README.md index 6b4bfd1bef..8055800671 100644 --- a/examples/opensora_hpcai/README.md +++ b/examples/opensora_hpcai/README.md @@ -159,9 +159,18 @@ Other useful documents and links are listed below. ## Installation 1. Install MindSpore according to the [official instructions](https://www.mindspore.cn/install). - For Ascend devices, please install **CANN driver C18 (0705)** from [here](https://repo.mindspore.cn/ascend/ascend910/20240705/) and install **MindSpore 2.3** from [here](https://www.mindspore.cn/install). + For Ascend devices, please install [CANN8.0.RC2.beta1](https://www.hiascend.com/developer/download/community/result?module=cann&cann=8.0.RC2.beta1) and install [MindSpore 2.3.1](https://www.mindspore.cn/install). + > To reduce compilation time and training time, you may install MindSpore2.4-20240904 from [here](https://repo.mindspore.cn/mindspore/mindspore/version/202409/20240904/master_20240904010023_67b5df247045f509c4ca2169bac6a551291a3111_newest/unified/aarch64/) -2. Install requirements + You may check your versions by running the following commands. The default installation path of CANN is usually `/usr/local/Ascend/ascend-toolkit` unless you specify a custom one. + + ```bash + cat /usr/local/Ascend/ascend-toolkit/latest/version.cfg + + python -c "import mindspore;mindspore.set_context(device_target='Ascend');mindspore.run_check()" + ``` + +3. Install requirements ```bash pip install -r requirements.txt ``` @@ -517,9 +526,11 @@ bucket_config: "2048": { 1: [ 0.01, 5 ] } ``` +Knowing that the optimal bucket config can varies from device to device, we have tuned and provided bucket config that are more balanced on Ascend + MindSpore in `configs/opensora-v1-2/train/{stage}_ms.yaml`. You may use them for better training performance. + More details on the bucket configuration can be found in [Multi-resolution Training with Buckets](./docs/quick_start.md#4-multi-resolution-training-with-buckets-opensora-v11-and-above). -Then you can launch the dynamic training task following the previous section. An example running script is `scripts/run/run_train_os1.2_stage2.sh`. +The instruction for launching the dynamic training task is smilar to the previous section. An example running script is `scripts/run/run_train_os1.2_stage2.sh`. ### Open-Sora 1.1 @@ -622,11 +633,13 @@ Here ✅ means that the data is seen during training, and 🆗 means although no We evaluate the training performance of Open-Sora v1.2 on the MixKit dataset with high-resolution videos (1080P, duration 12s to 100s). The results are as follows. -| Model | Context | jit_level | Precision | BS | NPUs | Size (TxHxW) | Train T. (s/step) | -|:------------|:-------------|:--------|:---------:|:--:|:----:|:----------------------:|:-----------------:| -| STDiT3-XL/2 | D910\*-[C18](https://repo.mindspore.cn/ascend/ascend910/20240705/)-[MS2.3](https://www.mindspore.cn/install) | O1 | BF16 | 1 | 8 | 51x720x1280 | **14.60** | -| STDiT3-XL/2 | D910\*-[C18](https://repo.mindspore.cn/ascend/ascend910/20240705/)-[MS2.3.1(0726)](https://repo.mindspore.cn/mindspore/mindspore/version/202407/20240726/master_20240726220021_4c913fb116c83b9ad28666538483264da8aebe8c_newest/unified/) | O1 | BF16 | 1 | 8 | Stage 2 Dyn. | **33.10** | -| STDiT3-XL/2 | D910\*-[C18](https://repo.mindspore.cn/ascend/ascend910/20240705/)-[MS2.3.1(0726)](https://repo.mindspore.cn/mindspore/mindspore/version/202407/20240726/master_20240726220021_4c913fb116c83b9ad28666538483264da8aebe8c_newest/unified/) | O1 | BF16 | 1 | 8 | Stage 3 Dyn. | **37.7** | +| Model | Context | jit_level | Precision | BS | NPUs | Size (TxHxW) | Train T. (s/step) | config | +|:------------|:-------------|:--------|:---------:|:--:|:----:|:----------------------:|:-----------------:|:-----------------:| +| STDiT3-XL/2 | D910\*-[C18](https://repo.mindspore.cn/ascend/ascend910/20240705/)-[MS2.3](https://www.mindspore.cn/install) | O1 | BF16 | 1 | 8 | 51x720x1280 | **14.60** | [yaml](configs/opensora-v1-2/train/train_720x1280x51.yaml) | +| STDiT3-XL/2 | D910\*-[C18](https://repo.mindspore.cn/ascend/ascend910/20240705/)-[MS2.3.1(0726)](https://repo.mindspore.cn/mindspore/mindspore/version/202407/20240726/master_20240726220021_4c913fb116c83b9ad28666538483264da8aebe8c_newest/unified/) | O1 | BF16 | 1 | 8 | Stage 2 Dyn. | **33.10** | [yaml](configs/opensora-v1-2/train/train_stage2.yaml) | +| STDiT3-XL/2 | D910\*-[C18](https://repo.mindspore.cn/ascend/ascend910/20240705/)-[MS2.3.1(0726)](https://repo.mindspore.cn/mindspore/mindspore/version/202407/20240726/master_20240726220021_4c913fb116c83b9ad28666538483264da8aebe8c_newest/unified/) | O1 | BF16 | 1 | 8 | Stage 3 Dyn. | **34** | [yaml](configs/opensora-v1-2/train/train_stage3.yaml) | + + > Context: {G:GPU, D:Ascend}{chip type}-{CANN version}-{mindspore version}; "Dyn." is short for dynamic shape. Note that the step time of dynamic training can be influenced by the resolution and duration distribution of the source videos. Training performance is under optimization. diff --git a/examples/opensora_hpcai/configs/opensora-v1-2/train/train_stage2_ms.yaml b/examples/opensora_hpcai/configs/opensora-v1-2/train/train_stage2_ms.yaml new file mode 100644 index 0000000000..b5f3c6bc59 --- /dev/null +++ b/examples/opensora_hpcai/configs/opensora-v1-2/train/train_stage2_ms.yaml @@ -0,0 +1,84 @@ +# model +model_version: v1.2 +pretrained_model_path: PATH_TO_YOUR_MODEL +model_max_length: 300 +freeze_y_embedder: True + +noise_scheduler: rflow +sample_method: logit-normal +use_timestep_transform: True + +vae_type: OpenSoraVAE_V1_2 +vae_checkpoint: models/OpenSora-VAE-v1.2/model.ckpt +vae_dtype: bf16 +vae_micro_batch_size: 4 +vae_micro_frame_size: 17 # keep it unchanged for the best results + +enable_flash_attention: True +use_recompute: True + +# data +num_parallel_workers: 2 +num_workers_dataset: 2 +prefetch_size: 2 +max_rowsize: 256 + +# mindspore params, refer to https://www.mindspore.cn/docs/zh-CN/r2.3.1/api_python/mindspore/mindspore.set_context.html +max_device_memory: "59GB" +jit_level: "O1" +manual_pad: True + +# precision +amp_level: "O2" +dtype: bf16 +loss_scaler_type: static +init_loss_scale: 1 + +# training hyper-params +scheduler: "constant" +start_learning_rate: 1.e-4 +end_learning_rate: 1.e-4 +# warmup_steps: 1000 + +clip_grad: True +max_grad_norm: 1.0 +use_ema: True +# ema_decay: 0.99 # default 0.9999 gives better result in our experiments + +optim: "adamw_re" +optim_eps: 1e-15 +weight_decay: 0. + +# epochs: 2 +train_steps: 23000 +ckpt_save_steps: 500 + +mask_ratios: + random: 0.005 + interpolate: 0.002 + quarter_random: 0.007 + quarter_head: 0.002 + quarter_tail: 0.002 + quarter_head_tail: 0.002 + image_random: 0.0 + image_head: 0.22 + image_tail: 0.005 + image_head_tail: 0.005 + +bucket_config: + # Structure: "resolution": { num_frames: [ keep_prob, batch_size ] } + # Setting [ keep_prob, batch_size ] to [ 0.0, 0 ] forces longer videos into smaller resolution buckets + "144p": { 1: [ 1.0, 475 ], 51: [ 1.0, 40 ], 102: [ [ 1.0, 0.33 ], 20 ], 204: [ [ 1.0, 0.1 ], 10 ], 408: [ [ 1.0, 0.1 ], 6 ] } + "256": { 1: [ 0.4, 297 ], 51: [ 0.5, 24 ], 102: [ [ 0.5, 0.33 ], 12 ], 204: [ [ 0.5, 1.0 ], 6 ], 408: [ [ 0.5, 1.0 ], 2 ] } + "240p": { 1: [ 0.3, 297 ], 51: [ 0.4, 16 ], 102: [ [ 0.4, 0.33 ], 8 ], 204: [ [ 0.4, 1.0 ], 4 ], 408: [ [ 0.4, 1.0 ], 2 ] } + "360p": { 1: [ 0.5, 141 ], 51: [ 0.15, 6 ], 102: [ [ 0.3, 0.5 ], 3 ], 204: [ [ 0.3, 1.0 ], 2 ], 408: [ [ 0.5, 0.5 ], 1 ] } + "512": { 1: [ 0.4, 141 ], 51: [ 0.15, 6 ], 102: [ [ 0.2, 0.4 ], 3 ], 204: [ [ 0.2, 1.0 ], 1 ], 408: [ [ 0.4, 0.5 ], 1 ] } + "480p": { 1: [ 0.5, 89 ], 51: [ 0.2, 5 ], 102: [ 0.2, 2 ], 204: [ 0.1, 1 ] } + "720p": { 1: [ 0.1, 36 ], 51: [ 0.03, 1 ] } + "1024": { 1: [ 0.1, 36 ], 51: [ 0.02, 1 ] } + "1080p": { 1: [ 0.01, 5 ] } + "2048": { 1: [ 0.01, 5 ] } + + +# ---------- Validation ---------- +validate: False diff --git a/examples/opensora_hpcai/configs/opensora-v1-2/train/train_stage3_ms.yaml b/examples/opensora_hpcai/configs/opensora-v1-2/train/train_stage3_ms.yaml new file mode 100644 index 0000000000..65708812e0 --- /dev/null +++ b/examples/opensora_hpcai/configs/opensora-v1-2/train/train_stage3_ms.yaml @@ -0,0 +1,83 @@ +# model +model_version: v1.2 +pretrained_model_path: PATH_TO_YOUR_MODEL +model_max_length: 300 +freeze_y_embedder: True + +noise_scheduler: rflow +sample_method: logit-normal +use_timestep_transform: True + +vae_type: OpenSoraVAE_V1_2 +vae_checkpoint: models/OpenSora-VAE-v1.2/model.ckpt +vae_dtype: bf16 +vae_micro_batch_size: 4 +vae_micro_frame_size: 17 # keep it unchanged for the best results + +enable_flash_attention: True +use_recompute: True + +# data +num_parallel_workers: 2 +num_workers_dataset: 2 +prefetch_size: 2 +max_rowsize: 256 + +# precision +amp_level: "O2" +dtype: bf16 +loss_scaler_type: static +init_loss_scale: 1 + +# mindspore params, refer to https://www.mindspore.cn/docs/zh-CN/r2.3.1/api_python/mindspore/mindspore.set_context.html +max_device_memory: "59GB" +jit_level: "O1" +manual_pad: True + +# training hyper-params +scheduler: "constant" +start_learning_rate: 1.e-4 +end_learning_rate: 1.e-4 +warmup_steps: 1000 + +clip_grad: True +max_grad_norm: 1.0 +use_ema: True +# ema_decay: 0.99 + +optim: "adamw_re" +optim_eps: 1e-15 +weight_decay: 0. + +# epochs: 15 +train_steps: 15000 +ckpt_save_steps: 500 + +mask_ratios: + random: 0.01 + interpolate: 0.002 + quarter_random: 0.002 + quarter_head: 0.002 + quarter_tail: 0.002 + quarter_head_tail: 0.002 + image_random: 0.0 + image_head: 0.22 + image_tail: 0.005 + image_head_tail: 0.005 + +bucket_config: + # Structure: "resolution": { num_frames: [ keep_prob, batch_size ] } + # Setting [ keep_prob, batch_size ] to [ 0.0, 0 ] forces longer videos into smaller resolution buckets + "144p": {1: [1.0, 475], 51: [1.0, 51], 102: [1.0, 27], 204: [1.0, 13], 408: [1.0, 6]} + "256": {1: [1.0, 297], 51: [0.5, 20], 102: [0.5, 10], 204: [0.5, 6], 408: [[0.5, 0.5], 2]} + "240p": {1: [1.0, 297], 51: [0.5, 20], 102: [0.5, 10], 204: [0.5, 5], 408: [[0.5, 0.4], 2]} + "360p": {1: [1.0, 141], 51: [0.5, 8], 102: [0.5, 4], 204: [0.5, 2], 408: [[0.5, 0.3], 1]} + "512": {1: [1.0, 141], 51: [0.5, 8], 102: [0.5, 4], 204: [0.5, 2], 408: [[0.5, 0.2], 1]} + "480p": {1: [1.0, 89], 51: [0.5, 5], 102: [0.5, 2], 204: [[0.5, 0.5], 1], 408: [0.0, 0]} + "720p": {1: [0.3, 36], 51: [0.2, 2], 102: [0.1, 1], 204: [0.0, 0]} + "1024": {1: [0.3, 36], 51: [0.1, 2], 102: [0.1, 1], 204: [0.0, 0]} + "1080p": {1: [0.1, 5]} + "2048": {1: [0.05, 5]} + +# ---------- Validation ---------- +validate: False diff --git a/examples/opensora_hpcai/opensora/datasets/video_dataset_refactored.py b/examples/opensora_hpcai/opensora/datasets/video_dataset_refactored.py index a0be3759d2..8c91f917c0 100644 --- a/examples/opensora_hpcai/opensora/datasets/video_dataset_refactored.py +++ b/examples/opensora_hpcai/opensora/datasets/video_dataset_refactored.py @@ -10,7 +10,6 @@ import cv2 import numpy as np -from decord import VideoReader from tqdm import tqdm import mindspore as ms @@ -252,6 +251,8 @@ def _get_item(self, idx: int) -> Tuple[Any, ...]: else: if self.video_backend == "decord": + from decord import VideoReader + reader = VideoReader(video_path) min_length = self._min_length video_length = len(reader) diff --git a/examples/opensora_hpcai/opensora/models/layers/blocks.py b/examples/opensora_hpcai/opensora/models/layers/blocks.py index 86c1ded846..725f9e0bb5 100644 --- a/examples/opensora_hpcai/opensora/models/layers/blocks.py +++ b/examples/opensora_hpcai/opensora/models/layers/blocks.py @@ -5,7 +5,7 @@ import numpy as np import mindspore as ms -from mindspore import Parameter, Tensor, nn, ops +from mindspore import Parameter, Tensor, mint, nn, ops from mindspore.common.initializer import initializer from mindone.models.modules.flash_attention import FLASH_IS_AVAILABLE, MSFlashAttention @@ -26,9 +26,7 @@ def __init__(self, hidden_size, eps=1e-6): self.variance_epsilon = eps def construct(self, hidden_states: Tensor): - variance = hidden_states.pow(2).mean(-1, keep_dims=True) - hidden_states = hidden_states * ops.rsqrt(variance + self.variance_epsilon) - return self.gamma * hidden_states + return ops.rms_norm(hidden_states, self.gamma, self.variance_epsilon)[0] class Attention(nn.Cell): @@ -325,10 +323,12 @@ def __init__(self, normalized_shape, eps=1e-5, elementwise_affine: bool = True, else: self.gamma = ops.ones(normalized_shape, dtype=dtype) self.beta = ops.zeros(normalized_shape, dtype=dtype) - self.layer_norm = ops.LayerNorm(-1, -1, epsilon=eps) def construct(self, x: Tensor): - x, _, _ = self.layer_norm(x, self.gamma, self.beta) + normalized_shape = x.shape[-1:] + # mint layer_norm fuses the operations in layer normorlization and it's faster than ops.LayerNorm + x = mint.nn.functional.layer_norm(x, normalized_shape, self.gamma, self.beta, self.eps) + return x diff --git a/examples/opensora_hpcai/opensora/models/layers/operation_selector.py b/examples/opensora_hpcai/opensora/models/layers/operation_selector.py index b50c7f655d..269f3966c3 100644 --- a/examples/opensora_hpcai/opensora/models/layers/operation_selector.py +++ b/examples/opensora_hpcai/opensora/models/layers/operation_selector.py @@ -56,6 +56,8 @@ def get_repeat_interleave_op(): # provide better performance for static shape in graph mode return ops.repeat_interleave else: + # FIXME: check overflow for v2 + # return repeat_interleave_ext_v2 return repeat_interleave_ext diff --git a/examples/opensora_hpcai/opensora/models/stdit/stdit3.py b/examples/opensora_hpcai/opensora/models/stdit/stdit3.py index a1294168da..6a4b20cabb 100644 --- a/examples/opensora_hpcai/opensora/models/stdit/stdit3.py +++ b/examples/opensora_hpcai/opensora/models/stdit/stdit3.py @@ -35,6 +35,8 @@ class STDiT3Block(nn.Cell): + # to reduce compilation time + @ms.lazy_inline(policy="front") def __init__( self, hidden_size, @@ -100,9 +102,10 @@ def construct( ) # modulate (attention) - x_m = t2i_modulate(self.norm1(x), shift_msa, scale_msa) + norm1 = self.norm1(x) + x_m = t2i_modulate(norm1, shift_msa, scale_msa) # frames mask branch - x_m_zero = t2i_modulate(self.norm1(x), shift_msa_zero, scale_msa_zero) + x_m_zero = t2i_modulate(norm1, shift_msa_zero, scale_msa_zero) x_m = t_mask_select(frames_mask, x_m, x_m_zero, T, S) # attention @@ -128,9 +131,10 @@ def construct( x = x + self.cross_attn(x, y, mask) # modulate (MLP) - x_m = t2i_modulate(self.norm2(x), shift_mlp, scale_mlp) + norm2 = self.norm2(x) + x_m = t2i_modulate(norm2, shift_mlp, scale_mlp) # frames mask branch - x_m_zero = t2i_modulate(self.norm2(x), shift_mlp_zero, scale_mlp_zero) + x_m_zero = t2i_modulate(norm2, shift_mlp_zero, scale_mlp_zero) x_m = t_mask_select(frames_mask, x_m, x_m_zero, T, S) # MLP diff --git a/examples/opensora_hpcai/opensora/models/vae/modules.py b/examples/opensora_hpcai/opensora/models/vae/modules.py index 539afc26c9..b98a7ca45a 100644 --- a/examples/opensora_hpcai/opensora/models/vae/modules.py +++ b/examples/opensora_hpcai/opensora/models/vae/modules.py @@ -2,6 +2,7 @@ import numpy as np +# import mindspore as ms from mindspore import nn, ops _logger = logging.getLogger(__name__) @@ -177,6 +178,7 @@ def make_attn(in_channels, attn_type="vanilla"): # used in vae class Encoder(nn.Cell): + # @ms.lazy_inline() def __init__( self, *, @@ -274,18 +276,18 @@ def construct(self, x): temb = None # downsampling - hs = [self.conv_in(x)] + hs = self.conv_in(x) for i_level in range(self.num_resolutions): for i_block in range(self.num_res_blocks): - h = self.down[i_level].block[i_block](hs[-1], temb) + h = self.down[i_level].block[i_block](hs, temb) if len(self.down[i_level].attn) > 0: h = self.down[i_level].attn[i_block](h) - hs.append(h) + hs = h if i_level != self.num_resolutions - 1: - hs.append(self.down[i_level].downsample(hs[-1])) + hs = self.down[i_level].downsample(hs) # middle - h = hs[-1] + h = hs h = self.mid.block_1(h, temb) h = self.mid.attn_1(h) h = self.mid.block_2(h, temb) @@ -297,8 +299,8 @@ def construct(self, x): return h -# used in vae class Decoder(nn.Cell): + # @ms.lazy_inline() def __init__( self, *, diff --git a/examples/opensora_hpcai/opensora/models/vae/vae_temporal.py b/examples/opensora_hpcai/opensora/models/vae/vae_temporal.py index 733f1c20c8..1151a36446 100644 --- a/examples/opensora_hpcai/opensora/models/vae/vae_temporal.py +++ b/examples/opensora_hpcai/opensora/models/vae/vae_temporal.py @@ -150,6 +150,7 @@ def get_activation_fn(activation): class Encoder(nn.Cell): """Encoder Blocks.""" + # @ms.lazy_inline() def __init__( self, in_out_channels=4, @@ -260,6 +261,7 @@ def construct(self, x): class Decoder(nn.Cell): """Decoder Blocks.""" + # @ms.lazy_inline() def __init__( self, in_out_channels=4, diff --git a/examples/opensora_hpcai/opensora/schedulers/rectified_flow.py b/examples/opensora_hpcai/opensora/schedulers/rectified_flow.py index 445e105ad4..0be394ca92 100644 --- a/examples/opensora_hpcai/opensora/schedulers/rectified_flow.py +++ b/examples/opensora_hpcai/opensora/schedulers/rectified_flow.py @@ -72,6 +72,7 @@ def __call__( mask_t = frames_mask * self.num_timesteps x0 = z.copy() x_noise = self.scheduler.add_noise(x0, ops.randn_like(x0), t) + # x_noise = self.scheduler.add_noise(x0, ms.Tensor(np.random.randn(*x0.shape), dtype=ms.float32), t) model_kwargs["frames_mask"] = mask_t_upper = mask_t >= t.unsqueeze(1) mask_add_noise = (mask_t_upper * (1 - noise_added)).astype(dtype.bool_) diff --git a/examples/opensora_hpcai/scripts/run/run_train_os1.2_stage2.sh b/examples/opensora_hpcai/scripts/run/run_train_os1.2_stage2.sh index 5c6c440bc8..13a92392a0 100644 --- a/examples/opensora_hpcai/scripts/run/run_train_os1.2_stage2.sh +++ b/examples/opensora_hpcai/scripts/run/run_train_os1.2_stage2.sh @@ -18,8 +18,8 @@ python scripts/train.py \ --pretrained_model_path="models/OpenSora-STDiT-v3/opensora_stdit_v3.ckpt" \ --mode=0 \ --jit_level O1 \ ---max_device_memory 59GB \ ---config configs/opensora-v1-2/train/train_stage2.yaml \ +--max_device_memory 55GB \ +--config configs/opensora-v1-2/train/train_stage2_ms.yaml \ --csv_path datasets/mixkit-100videos/video_caption_train.csv \ --video_folder datasets/mixkit-100videos/mixkit \ --text_embed_folder datasets/mixkit-100videos/t5_emb_300 \ diff --git a/examples/opensora_hpcai/scripts/train.py b/examples/opensora_hpcai/scripts/train.py index 0e42e70590..d554b62e88 100644 --- a/examples/opensora_hpcai/scripts/train.py +++ b/examples/opensora_hpcai/scripts/train.py @@ -147,10 +147,6 @@ def init_env( if dynamic_shape: logger.info("Dynamic shape mode enabled, repeat_interleave/split/chunk will be called from mint module") set_dynamic_mode(True) - if mode == 0: - # FIXME: this is a temp fix for dynamic shape training in graph mode. may remove in future version. - # can append adamw fusion flag if use nn.AdamW optimzation for acceleration - ms.set_context(graph_kernel_flags="--disable_packet_ops=Reshape") return rank_id, device_num @@ -299,7 +295,7 @@ def initialize_dataset( bucket_boundaries, bucket_batch_sizes, element_length_function=hash_func, - drop_remainder=not validation, + drop_remainder=False, ) return dataloader, num_src_samples @@ -562,17 +558,21 @@ def main(args): ) # compute total steps and data epochs (in unit of data sink size) + if args.dataset_sink_mode and args.sink_size != -1: + # in data sink mode, data sink size determines the number of training steps per epoch. + steps_per_epoch = args.sink_size + else: + # without data sink, number of training steps is determined by number of data batches of the whole training set. + steps_per_epoch = dataset_size + if args.train_steps == -1: assert args.epochs != -1 total_train_steps = args.epochs * dataset_size + sink_epochs = math.ceil(total_train_steps / steps_per_epoch) else: total_train_steps = args.train_steps - - if args.dataset_sink_mode and args.sink_size != -1: - steps_per_sink = args.sink_size - else: - steps_per_sink = dataset_size - sink_epochs = math.ceil(total_train_steps / steps_per_sink) + # asume one step need one whole epoch data to ensure enough batch loading for training + sink_epochs = total_train_steps if args.ckpt_save_steps == -1: ckpt_save_interval = args.ckpt_save_interval @@ -583,11 +583,11 @@ def main(args): ckpt_save_interval = args.ckpt_save_steps else: # still need to count interval in sink epochs - ckpt_save_interval = max(1, args.ckpt_save_steps // steps_per_sink) - if args.ckpt_save_steps % steps_per_sink != 0: + ckpt_save_interval = max(1, args.ckpt_save_steps // steps_per_epoch) + if args.ckpt_save_steps % steps_per_epoch != 0: logger.warning( f"`ckpt_save_steps` must be times of sink size or dataset_size under dataset sink mode." - f"Checkpoint will be saved every {ckpt_save_interval * steps_per_sink} steps." + f"Checkpoint will be saved every {ckpt_save_interval * steps_per_epoch} steps." ) step_mode = step_mode if args.step_mode is None else args.step_mode @@ -677,16 +677,17 @@ def main(args): resume_train_net(net_with_grads, resume_ckpt) if (args.mode == 0) and (args.bucket_config is not None): - video = ms.Tensor(shape=[None, None, 3, None, None], dtype=ms.float32) - caption = ms.Tensor(shape=[None, args.model_max_length, 4096], dtype=ms.float32) - mask = ms.Tensor(shape=[None, args.model_max_length], dtype=ms.uint8) - frames_mask = ms.Tensor(shape=[None, None], dtype=ms.bool_) + _bs = ms.Symbol(unique=True) + video = ms.Tensor(shape=[_bs, None, 3, None, None], dtype=ms.float32) + caption = ms.Tensor(shape=[_bs, args.model_max_length, 4096], dtype=ms.float32) + mask = ms.Tensor(shape=[_bs, args.model_max_length], dtype=ms.uint8) + frames_mask = ms.Tensor(shape=[_bs, None], dtype=ms.bool_) # fmt: off - num_frames = ms.Tensor(shape=[None, ], dtype=ms.float32) - height = ms.Tensor(shape=[None, ], dtype=ms.float32) - width = ms.Tensor(shape=[None, ], dtype=ms.float32) - fps = ms.Tensor(shape=[None, ], dtype=ms.float32) - ar = ms.Tensor(shape=[None, ], dtype=ms.float32) + num_frames = ms.Tensor(shape=[_bs, ], dtype=ms.float32) + height = ms.Tensor(shape=[_bs, ], dtype=ms.float32) + width = ms.Tensor(shape=[_bs, ], dtype=ms.float32) + fps = ms.Tensor(shape=[_bs, ], dtype=ms.float32) + ar = ms.Tensor(shape=[_bs, ], dtype=ms.float32) # fmt: on net_with_grads.set_inputs(video, caption, mask, frames_mask, num_frames, height, width, fps, ar) logger.info("Dynamic inputs are initialized for bucket config training in Graph mode!") @@ -731,11 +732,12 @@ def main(args): resume=args.resume, ) callbacks.extend([save_cb, rec_cb]) - if args.train_steps > 0: - callbacks.append(StopAtStepCallback(args.train_steps, global_step=cur_iter)) if args.profile: callbacks.append(ProfilerCallbackEpoch(2, 3, "./profile_data")) + if args.train_steps > 0: + callbacks.append(StopAtStepCallback(args.train_steps, global_step=cur_iter)) + # 5. log and save config if rank_id == 0: if vae is not None: @@ -819,7 +821,7 @@ def main(args): ckpt_manager = CheckpointManager(ckpt_dir, "latest_k", k=args.ckpt_max_keep) if not os.path.exists(ckpt_dir): os.makedirs(ckpt_dir) - perf_columns = ["step", "loss", "train_time(s)"] + perf_columns = ["step", "loss", "train_time(s)", "shape"] output_dir = ckpt_dir.replace("/ckpt", "") if start_epoch == 0: record = PerfRecorder(output_dir, metric_names=perf_columns) @@ -844,13 +846,14 @@ def main(args): # print(data[0].shape) loss_val = float(loss.asnumpy()) logger.info( - f"Epoch {epoch}, Step {step}, loss {loss_val:.5f}, Global step {global_step}, Step time {step_time*1000:.2f}ms" + f"Epoch {epoch}, Step {step}, loss {loss_val:.5f}, Global step {global_step}," + + f" Shape: {tuple(data[0].shape)}, Step time {step_time*1000:.2f}ms" ) if overflow: logger.warning("overflow detected") if rank_id == 0: - step_pref_value = [global_step, loss_val, step_time] + step_pref_value = [global_step, loss_val, step_time, tuple(data[0].shape)] record.add(*step_pref_value) # save and eval in step if save_by_step and rank_id == 0: diff --git a/examples/opensora_hpcai/tools/analyze_bucket_result.py b/examples/opensora_hpcai/tools/analyze_bucket_result.py new file mode 100644 index 0000000000..11e02eb1d9 --- /dev/null +++ b/examples/opensora_hpcai/tools/analyze_bucket_result.py @@ -0,0 +1,64 @@ +import argparse +import math + +import pandas as pd + +# result_path = 'outputs/analyze_os1.2_stage2_vcg200_ms231/merged_result.log' + + +def analyze(result_path, save_path): + warmup_steps = 50 + max_step_time = 100 # step time larger than this value will be dropped, considering checkpoint saving + + data = pd.read_csv(result_path, sep="\t") + + # filter warmup stage + data = data.iloc[warmup_steps - 1 :] + + # filter out outliers + data = data[data["train_time(s)"] < max_step_time] + global_avg_step_time = data["train_time(s)"].mean() + num_steps = data.shape[0] + + res = data.groupby("shape")["train_time(s)"].agg(["mean", "std", "count"]).reset_index() + + res.columns = ["shape", "mean_step_time", "std_step_time", "occurence"] + + res["std_step_time"].fillna(0, inplace=True) + + res_sorted = res.sort_values(by="mean_step_time", ascending=False) + + percent_col = [] + bs_col = [] + sug_bs_col = [] + rnd_bs_col = [] + tot_occur = res_sorted["occurence"].sum() + # for shape_str in res_sorted['shape'].tolist(): + for idx, row in res_sorted.iterrows(): + percent_col.append(row["occurence"] / tot_occur) + bs = int(row["shape"].split(",")[0][1:]) + bs_col.append(bs) + + suggest_bs = (global_avg_step_time / row["mean_step_time"]) * bs + rounded_bs = max(1, math.floor(suggest_bs)) + sug_bs_col.append(suggest_bs) + rnd_bs_col.append(rounded_bs) + res_sorted["occ_percent"] = percent_col + res_sorted["bs"] = bs_col + res_sorted["suggested_bs"] = sug_bs_col + res_sorted["rounded_bs"] = rnd_bs_col + + res_sorted.to_csv(save_path, index=False) + + print(res_sorted) + print(f"\nAverage step time(s) (in {num_steps} steps starting from step {warmup_steps}): ", global_avg_step_time) + print("Analysis csv saved in", save_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--input", "-i", type=str, default=None, help="path to result log") + parser.add_argument("--output", "-o", type=str, default="shape_step_time.csv", help="path to save analysis output") + + args = parser.parse_args() + analyze(args.input, args.output) diff --git a/examples/opensora_hpcai/tools/annotate_stdit_mix100.py b/examples/opensora_hpcai/tools/annotate_stdit_mix100.py new file mode 100644 index 0000000000..610700ad7e --- /dev/null +++ b/examples/opensora_hpcai/tools/annotate_stdit_mix100.py @@ -0,0 +1,56 @@ +import glob +import json +import math +import os +import random + +root_dir = "/home_host/ddd/workspace/datasets/mixkit-100videos/mixkit/" +annot_fp = "/home_host/ddd/workspace/datasets/mixkit-100videos/anno_jsons/video_mixkit_65f_54735.json" +output_csv = "video_caption.csv" + +# read paths +video_fps = sorted(glob.glob(os.path.join(root_dir, "*/*.mp4"))) +# remove header +video_fps = [fp.replace(root_dir, "") for fp in video_fps] +print(video_fps) + +fp_out = open(output_csv, "w") +fp_out.write("video,caption\n") + +matched_videos = [] +matched_captions = [] +with open(annot_fp, "r") as fp: + annot_list = json.load(fp) + for i, annot in enumerate(annot_list): + video_path = annot["path"] + if video_path in video_fps and video_path not in matched_videos: + caption = annot["cap"] + fp_out.write('{},"{}"\n'.format(video_path, caption)) + matched_videos.append(video_path) + matched_captions.append(caption) + +fp_out.close() +print("Num samples", len(matched_videos)) +print("csv saved in ", output_csv) + +# split into train and test +train_ratio = 0.8 +num_samples = len(matched_videos) +num_train = math.ceil(num_samples * train_ratio) +num_test = num_samples - num_train +vc_list = [(matched_videos[i], matched_captions[i]) for i in range(num_samples)] +random.shuffle(vc_list) + +train_set = sorted(vc_list[:num_train]) +test_set = sorted(vc_list[num_train:]) + + +def write_csv(vcl, save_path): + with open(save_path, "w") as fp: + fp.write("video,caption\n") + for vc in vcl: + fp.write('{},"{}"\n'.format(vc[0], vc[1])) + + +write_csv(train_set, output_csv.replace(".csv", "_train.csv")) +write_csv(test_set, output_csv.replace(".csv", "_test.csv")) diff --git a/mindone/trainers/callback.py b/mindone/trainers/callback.py index f1d0d6f5ac..bdd2e7c8a8 100755 --- a/mindone/trainers/callback.py +++ b/mindone/trainers/callback.py @@ -9,7 +9,7 @@ from .checkpoint import CheckpointManager from .recorder import PerfRecorder -_logger = logging.getLogger(__name__) +_logger = logging.getLogger("") __all__ = ["OverflowMonitor", "EvalSaveCallback", "ProfilerCallback", "StopAtStepCallback"] diff --git a/mindone/trainers/recorder.py b/mindone/trainers/recorder.py index ce43ce98e9..61e509d18c 100644 --- a/mindone/trainers/recorder.py +++ b/mindone/trainers/recorder.py @@ -42,6 +42,8 @@ def add(self, step, *measures): if isinstance(m, float) or isinstance(m, np.ndarray): line += f"{m:.7f}" + elif isinstance(m, tuple): + line += f"{m}" elif m is None: line += "NA" else: