From a7a073fadf87647af465355a6e20622abc144de9 Mon Sep 17 00:00:00 2001
From: Samit <285365963@qq.com>
Date: Fri, 20 Sep 2024 14:37:56 +0800
Subject: [PATCH] add comments

---
 .../opensora/models/layers/blocks.py       |  8 ++------
 .../opensora/models/vae/modules.py         | 14 --------------
 .../opensora/schedulers/rectified_flow.py  |  2 --
 examples/opensora_hpcai/scripts/train.py   | 18 ++++++++----------
 4 files changed, 10 insertions(+), 32 deletions(-)

diff --git a/examples/opensora_hpcai/opensora/models/layers/blocks.py b/examples/opensora_hpcai/opensora/models/layers/blocks.py
index c4b17422c8..4926d2c2c5 100644
--- a/examples/opensora_hpcai/opensora/models/layers/blocks.py
+++ b/examples/opensora_hpcai/opensora/models/layers/blocks.py
@@ -26,9 +26,7 @@ def __init__(self, hidden_size, eps=1e-6):
         self.variance_epsilon = eps
 
     def construct(self, hidden_states: Tensor):
-        # variance = hidden_states.pow(2).mean(-1, keep_dims=True)
-        # hidden_states = hidden_states * ops.rsqrt(variance + self.variance_epsilon)
-        # return self.gamma * hidden_states
+
         return ops.rms_norm(hidden_states, self.gamma, self.variance_epsilon)[0]
 
 
@@ -326,12 +324,10 @@ def __init__(self, normalized_shape, eps=1e-5, elementwise_affine: bool = True,
         else:
             self.gamma = ops.ones(normalized_shape, dtype=dtype)
             self.beta = ops.zeros(normalized_shape, dtype=dtype)
-        # self.layer_norm = ops.LayerNorm(-1, -1, epsilon=eps)
 
     def construct(self, x: Tensor):
-        # x, _, _ = self.layer_norm(x, self.gamma, self.beta)
-
         normalized_shape = x.shape[-1:]
+        # mint layer_norm fuses the operations in layer normalization and is faster than ops.LayerNorm
         x = mint.nn.functional.layer_norm(x, normalized_shape, self.gamma, self.beta, self.eps)
         return x
 
diff --git a/examples/opensora_hpcai/opensora/models/vae/modules.py b/examples/opensora_hpcai/opensora/models/vae/modules.py
index 83703faef6..b98a7ca45a 100644
--- a/examples/opensora_hpcai/opensora/models/vae/modules.py
+++ b/examples/opensora_hpcai/opensora/models/vae/modules.py
@@ -276,20 +276,6 @@ def construct(self, x):
         temb = None
 
         # downsampling
-        """
-        hs = [self.conv_in(x)]
-        for i_level in range(self.num_resolutions):
-            for i_block in range(self.num_res_blocks):
-                h = self.down[i_level].block[i_block](hs[-1], temb)
-                if len(self.down[i_level].attn) > 0:
-                    h = self.down[i_level].attn[i_block](h)
-                hs.append(h)
-            if i_level != self.num_resolutions - 1:
-                hs.append(self.down[i_level].downsample(hs[-1]))
-
-        # middle
-        h = hs[-1]
-        """
         hs = self.conv_in(x)
         for i_level in range(self.num_resolutions):
             for i_block in range(self.num_res_blocks):
diff --git a/examples/opensora_hpcai/opensora/schedulers/rectified_flow.py b/examples/opensora_hpcai/opensora/schedulers/rectified_flow.py
index 0464dd5387..0be394ca92 100644
--- a/examples/opensora_hpcai/opensora/schedulers/rectified_flow.py
+++ b/examples/opensora_hpcai/opensora/schedulers/rectified_flow.py
@@ -5,8 +5,6 @@
 except ImportError:
     from typing_extensions import Literal  # FIXME: python 3.7
 
-# import numpy as np
-# import mindspore as ms
 from tqdm import tqdm
 
 from mindspore import Tensor, dtype, ops
diff --git a/examples/opensora_hpcai/scripts/train.py b/examples/opensora_hpcai/scripts/train.py
index d749b8f6e9..595eae95c1 100644
--- a/examples/opensora_hpcai/scripts/train.py
+++ b/examples/opensora_hpcai/scripts/train.py
@@ -147,10 +147,6 @@ def init_env(
     if dynamic_shape:
         logger.info("Dynamic shape mode enabled, repeat_interleave/split/chunk will be called from mint module")
         set_dynamic_mode(True)
-        # if mode == 0:
-        #     FIXME: this is a temp fix for dynamic shape training in graph mode. may remove in future version.
-        #     can append adamw fusion flag if use nn.AdamW optimzation for acceleration
-        #     ms.set_context(graph_kernel_flags="--disable_packet_ops=Reshape")
 
     return rank_id, device_num
 
@@ -563,14 +559,16 @@ def main(args):
 
     # compute total steps and data epochs (in unit of data sink size)
     if args.dataset_sink_mode and args.sink_size != -1:
-        steps_per_sink = args.sink_size
+        # in data sink mode, the data sink size determines the number of training steps per epoch.
+        steps_per_epoch = args.sink_size
     else:
-        steps_per_sink = dataset_size
+        # without data sink, the number of training steps per epoch is determined by the number of data batches in the whole training set.
+        steps_per_epoch = dataset_size
 
     if args.train_steps == -1:
         assert args.epochs != -1
         total_train_steps = args.epochs * dataset_size
-        sink_epochs = math.ceil(total_train_steps / steps_per_sink)
+        sink_epochs = math.ceil(total_train_steps / steps_per_epoch)
     else:
         total_train_steps = args.train_steps
         # asume one step need one whole epoch data to ensure enough batch loading for training
@@ -585,11 +583,11 @@
             ckpt_save_interval = args.ckpt_save_steps
         else:
             # still need to count interval in sink epochs
-            ckpt_save_interval = max(1, args.ckpt_save_steps // steps_per_sink)
-            if args.ckpt_save_steps % steps_per_sink != 0:
+            ckpt_save_interval = max(1, args.ckpt_save_steps // steps_per_epoch)
+            if args.ckpt_save_steps % steps_per_epoch != 0:
                 logger.warning(
                     f"`ckpt_save_steps` must be times of sink size or dataset_size under dataset sink mode."
-                    f"Checkpoint will be saved every {ckpt_save_interval * steps_per_sink} steps."
+                    f"Checkpoint will be saved every {ckpt_save_interval * steps_per_epoch} steps."
                 )
 
     step_mode = step_mode if args.step_mode is None else args.step_mode
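
Note: the comment added in blocks.py above refers to MindSpore's fused norm kernels. The minimal sketch below is not part of the patch; it is an illustration under the assumption of MindSpore >= 2.3 on a backend (e.g. Ascend) where the mint and fused norm ops are supported, and the tensor names are only examples.

    import numpy as np
    from mindspore import Tensor, mint, ops

    x = Tensor(np.random.randn(2, 16).astype(np.float32))
    gamma = Tensor(np.ones(16, dtype=np.float32))
    beta = Tensor(np.zeros(16, dtype=np.float32))

    # fused layer norm over the last dimension, as used in LayerNorm.construct above
    y = mint.nn.functional.layer_norm(x, x.shape[-1:], gamma, beta, 1e-5)

    # fused RMSNorm returns (output, rstd); the RMSNorm construct above keeps only the output, hence [0]
    z = ops.rms_norm(x, gamma, 1e-6)[0]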