Commit

add comments

SamitHuang committed Sep 20, 2024
1 parent 26626e6 commit a7a073f
Showing 4 changed files with 10 additions and 32 deletions.
8 changes: 2 additions & 6 deletions examples/opensora_hpcai/opensora/models/layers/blocks.py
@@ -26,9 +26,7 @@ def __init__(self, hidden_size, eps=1e-6):
self.variance_epsilon = eps

def construct(self, hidden_states: Tensor):
# variance = hidden_states.pow(2).mean(-1, keep_dims=True)
# hidden_states = hidden_states * ops.rsqrt(variance + self.variance_epsilon)
# return self.gamma * hidden_states

return ops.rms_norm(hidden_states, self.gamma, self.variance_epsilon)[0]
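
For reference, the commented-out lines deleted above were the unfused RMSNorm math; ops.rms_norm computes the same thing in one fused kernel and returns an (output, rstd) tuple, hence the [0] index. A minimal sketch of the equivalence, reconstructed from the removed comments:

import mindspore.ops as ops
from mindspore import Tensor

def manual_rms_norm(hidden_states: Tensor, gamma: Tensor, eps: float) -> Tensor:
    # unfused form: scale by the reciprocal root mean square over the last axis
    variance = hidden_states.pow(2).mean(-1, keep_dims=True)
    hidden_states = hidden_states * ops.rsqrt(variance + eps)
    return gamma * hidden_states

# fused form used above; returns (output, rstd), so take element [0]
# out = ops.rms_norm(hidden_states, gamma, eps)[0]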


@@ -326,12 +324,10 @@ def __init__(self, normalized_shape, eps=1e-5, elementwise_affine: bool = True,
else:
self.gamma = ops.ones(normalized_shape, dtype=dtype)
self.beta = ops.zeros(normalized_shape, dtype=dtype)
# self.layer_norm = ops.LayerNorm(-1, -1, epsilon=eps)

def construct(self, x: Tensor):
# x, _, _ = self.layer_norm(x, self.gamma, self.beta)

normalized_shape = x.shape[-1:]
# mint layer_norm fuses the operations in layer normalization and is faster than ops.LayerNorm
x = mint.nn.functional.layer_norm(x, normalized_shape, self.gamma, self.beta, self.eps)

return x
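
As the new comment notes, mint.nn.functional.layer_norm runs the whole normalization as a single fused operation. A minimal standalone usage sketch (shapes and dtype are illustrative, not taken from the repo):

import mindspore as ms
from mindspore import mint, ops

x = ops.randn(2, 16, 128)                          # (batch, seq, hidden)
gamma = ops.ones(x.shape[-1:], dtype=ms.float32)   # affine scale
beta = ops.zeros(x.shape[-1:], dtype=ms.float32)   # affine shift
# normalize over the last dimension, as in LayerNorm.construct above
y = mint.nn.functional.layer_norm(x, x.shape[-1:], gamma, beta, 1e-5)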
14 changes: 0 additions & 14 deletions examples/opensora_hpcai/opensora/models/vae/modules.py
@@ -276,20 +276,6 @@ def construct(self, x):
temb = None

# downsampling
"""
hs = [self.conv_in(x)]
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](hs[-1], temb)
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
hs.append(h)
if i_level != self.num_resolutions - 1:
hs.append(self.down[i_level].downsample(hs[-1]))
# middle
h = hs[-1]
"""
hs = self.conv_in(x)
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
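
The triple-quoted block deleted above preserved the old U-Net-style loop that appended every activation to a list hs; the live code already threads a single tensor through the same loop, since this encoder never consumes skip features. For contrast, a sketch of the single-tensor form, reconstructed by analogy with the removed block (the diff is truncated here, so the loop body is an assumption):

hs = self.conv_in(x)
for i_level in range(self.num_resolutions):
    for i_block in range(self.num_res_blocks):
        hs = self.down[i_level].block[i_block](hs, temb)  # overwrite instead of append
        if len(self.down[i_level].attn) > 0:
            hs = self.down[i_level].attn[i_block](hs)
    if i_level != self.num_resolutions - 1:
        hs = self.down[i_level].downsample(hs)
# middle
h = hs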
2 changes: 0 additions & 2 deletions examples/opensora_hpcai/opensora/schedulers/rectified_flow.py
@@ -5,8 +5,6 @@
except ImportError:
from typing_extensions import Literal # FIXME: python 3.7

# import numpy as np
# import mindspore as ms
from tqdm import tqdm

from mindspore import Tensor, dtype, ops
18 changes: 8 additions & 10 deletions examples/opensora_hpcai/scripts/train.py
@@ -147,10 +147,6 @@ def init_env(
if dynamic_shape:
logger.info("Dynamic shape mode enabled, repeat_interleave/split/chunk will be called from mint module")
set_dynamic_mode(True)
# if mode == 0:
# FIXME: this is a temp fix for dynamic shape training in graph mode. may remove in future version.
# can append the AdamW fusion flag if using nn.AdamW optimization for acceleration
# ms.set_context(graph_kernel_flags="--disable_packet_ops=Reshape")

return rank_id, device_num

@@ -563,14 +559,16 @@ def main(args):

# compute total steps and data epochs (in unit of data sink size)
if args.dataset_sink_mode and args.sink_size != -1:
steps_per_sink = args.sink_size
# in data sink mode, the data sink size determines the number of training steps per epoch.
steps_per_epoch = args.sink_size
else:
steps_per_sink = dataset_size
# without data sink, the number of training steps is determined by the number of data batches in the whole training set.
steps_per_epoch = dataset_size

if args.train_steps == -1:
assert args.epochs != -1
total_train_steps = args.epochs * dataset_size
sink_epochs = math.ceil(total_train_steps / steps_per_sink)
sink_epochs = math.ceil(total_train_steps / steps_per_epoch)
else:
total_train_steps = args.train_steps
# assume one step needs one whole epoch of data to ensure enough batches are loaded for training
@@ -585,11 +583,11 @@
ckpt_save_interval = args.ckpt_save_steps
else:
# still need to count interval in sink epochs
ckpt_save_interval = max(1, args.ckpt_save_steps // steps_per_sink)
if args.ckpt_save_steps % steps_per_sink != 0:
ckpt_save_interval = max(1, args.ckpt_save_steps // steps_per_epoch)
if args.ckpt_save_steps % steps_per_epoch != 0:
logger.warning(
f"`ckpt_save_steps` must be times of sink size or dataset_size under dataset sink mode."
f"Checkpoint will be saved every {ckpt_save_interval * steps_per_sink} steps."
f"Checkpoint will be saved every {ckpt_save_interval * steps_per_epoch} steps."
)
step_mode = step_mode if args.step_mode is None else args.step_mode
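
Putting the sink-mode accounting together, a small worked sketch of how the quantities above relate (plain Python; the numbers are illustrative, not from any real config):

import math

sink_size = 100      # args.sink_size: steps executed per sink epoch
dataset_size = 1000  # batches in the full training set
epochs = 3           # args.epochs

steps_per_epoch = sink_size                                   # sink mode, sink_size != -1
total_train_steps = epochs * dataset_size                     # 3000
sink_epochs = math.ceil(total_train_steps / steps_per_epoch)  # 30 sink epochs

# the checkpoint interval is counted in sink epochs, not raw steps
ckpt_save_steps = 250
ckpt_save_interval = max(1, ckpt_save_steps // steps_per_epoch)  # 2
# 250 % 100 != 0, so the warning fires: checkpoints land every 2 * 100 = 200 steps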

