Fix OpenSora 1.2 training #643

Merged · 35 commits · Sep 27, 2024

Changes from 7 commits

Commits (35)
3359694
lazy_inline to reduce compile time
SamitHuang Aug 23, 2024
db6a4bf
lazy_inline to reduce compile time
SamitHuang Aug 23, 2024
6252f31
leave large mem for communication, max_device_memory 55GB
SamitHuang Aug 23, 2024
21fff2c
fix train_steps bug
SamitHuang Aug 28, 2024
d814975
Merge branch 'os1.2_stable_fix' of https://github.com/samithuang/mind…
SamitHuang Aug 28, 2024
5b5fe82
fix linting
SamitHuang Aug 28, 2024
22aed35
add lazy_inline to vae encoder/decoder
SamitHuang Aug 28, 2024
6c220c7
rm lazy inline for vae due to perf drop
SamitHuang Aug 31, 2024
60e7705
rm lazy_inline in vae
SamitHuang Sep 2, 2024
907dde4
Merge branch 'os1.2_stable_fix' of github.com:SamitHuang/mindone into…
SamitHuang Sep 2, 2024
b3a4ff9
only require decord when backend selected
SamitHuang Sep 2, 2024
8f918f6
fix logging
SamitHuang Sep 3, 2024
a52d80e
fix logging
SamitHuang Sep 3, 2024
a0d6632
x1: rm duplicated norm
SamitHuang Sep 3, 2024
1b5f3f1
x-1: use ops.rms_norm, mint.layer_norm
SamitHuang Sep 3, 2024
9205100
x-2: rm hs list in vae encode
SamitHuang Sep 3, 2024
c444d62
x-3: use self-impl repeat interleave
SamitHuang Sep 3, 2024
f59d166
fix layernorm
SamitHuang Sep 3, 2024
4594f9c
record shape
SamitHuang Sep 5, 2024
8855f4e
balance bucket config for A+M
SamitHuang Sep 5, 2024
560cd67
revert repeat interleave for safe
SamitHuang Sep 5, 2024
1f72aa5
increase bs for 256 res
SamitHuang Sep 5, 2024
b757917
add shape step time analysis script
SamitHuang Sep 5, 2024
147435e
fix stop
SamitHuang Sep 9, 2024
94cbbb4
rm pdb
SamitHuang Sep 9, 2024
277410c
acc by add Symbol
SamitHuang Sep 12, 2024
0a83a37
clear mem in the end of epoch
SamitHuang Sep 13, 2024
e706e2c
update doc
SamitHuang Sep 13, 2024
a9e7126
impr bucket analysis
SamitHuang Sep 13, 2024
f60148c
fix
SamitHuang Sep 13, 2024
86d6fef
add stage3 balanced bucket
SamitHuang Sep 17, 2024
f9c1c65
fix lint
SamitHuang Sep 19, 2024
2e9d7f1
fix linting
SamitHuang Sep 19, 2024
067bb3d
Update README.md
SamitHuang Sep 20, 2024
5c5e382
add comments
SamitHuang Sep 20, 2024
2 changes: 2 additions & 0 deletions examples/opensora_hpcai/opensora/models/stdit/stdit3.py
@@ -35,6 +35,8 @@


class STDiT3Block(nn.Cell):
+    # to reduce compilation time
+    @ms.lazy_inline(policy="front")
Collaborator: Please add it to the VAE encoder and decoder as well, to fix the dynamic-shape OOM issue.

    def __init__(
        self,
        hidden_size,
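For context, a minimal sketch of the technique follows (the Block cell and its layers are illustrative, not code from this PR): decorating a Cell's __init__ with ms.lazy_inline makes the graph compiler compile that cell once as a reusable subgraph instead of re-inlining its graph at every call site, which is what cuts compile time when the block is stacked many times.

import mindspore as ms
from mindspore import nn

class Block(nn.Cell):
    # Compiled once as a reusable subgraph rather than re-inlined at every
    # call site, shortening graph-mode compilation for repeated blocks.
    @ms.lazy_inline(policy="front")
    def __init__(self, hidden_size: int):
        super().__init__()
        self.dense = nn.Dense(hidden_size, hidden_size)
        self.act = nn.GELU()

    def construct(self, x):
        return self.act(self.dense(x))

As the commit history shows (6c220c7, 60e7705), the same decorator was later removed from the VAE encoder/decoder because it cost runtime performance there, so the compile-time gain has to be weighed against per-step speed.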
3 changes: 3 additions & 0 deletions examples/opensora_hpcai/opensora/models/vae/modules.py
@@ -2,6 +2,7 @@

import numpy as np

+import mindspore as ms
from mindspore import nn, ops

_logger = logging.getLogger(__name__)
@@ -177,6 +178,7 @@ def make_attn(in_channels, attn_type="vanilla"):

# used in vae
class Encoder(nn.Cell):
+    @ms.lazy_inline()
    def __init__(
        self,
        *,
@@ -299,6 +301,7 @@ def construct(self, x):

# used in vae
class Decoder(nn.Cell):
+    @ms.lazy_inline()
    def __init__(
        self,
        *,
2 changes: 2 additions & 0 deletions examples/opensora_hpcai/opensora/models/vae/vae_temporal.py
@@ -150,6 +150,7 @@ def get_activation_fn(activation):
class Encoder(nn.Cell):
"""Encoder Blocks."""

@ms.lazy_inline()
def __init__(
self,
in_out_channels=4,
@@ -260,6 +261,7 @@ def construct(self, x):
class Decoder(nn.Cell):
"""Decoder Blocks."""

@ms.lazy_inline()
def __init__(
self,
in_out_channels=4,
@@ -18,7 +18,7 @@ python scripts/train.py \
--pretrained_model_path="models/OpenSora-STDiT-v3/opensora_stdit_v3.ckpt" \
--mode=0 \
--jit_level O1 \
-    --max_device_memory 59GB \
+    --max_device_memory 55GB \
Collaborator: What does this affect?

--config configs/opensora-v1-2/train/train_stage2.yaml \
--csv_path datasets/mixkit-100videos/video_caption_train.csv \
--video_folder datasets/mixkit-100videos/mixkit \
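Regarding the question above: per the commit message "leave large mem for communication, max_device_memory 55GB", lowering the cap leaves more device memory for communication rather than the framework's own allocation pool. A minimal sketch of what the flag ends up controlling, assuming it is forwarded to ms.set_context (the configure_device helper below is hypothetical, not the PR's code):

import mindspore as ms

def configure_device(max_device_memory: str = "55GB") -> None:
    # Cap how much device memory MindSpore's allocator may claim; memory
    # above the cap remains available to e.g. collective-communication ops.
    ms.set_context(max_device_memory=max_device_memory)

configure_device("55GB")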
14 changes: 8 additions & 6 deletions examples/opensora_hpcai/scripts/train.py
Collaborator: Set drop_remainder here to False as well, to keep things consistent?

@@ -562,17 +562,19 @@ def main(args):
    )

    # compute total steps and data epochs (in unit of data sink size)
+    if args.dataset_sink_mode and args.sink_size != -1:
+        steps_per_sink = args.sink_size
+    else:
+        steps_per_sink = dataset_size
+
    if args.train_steps == -1:
        assert args.epochs != -1
        total_train_steps = args.epochs * dataset_size
+        sink_epochs = math.ceil(total_train_steps / steps_per_sink)
    else:
        total_train_steps = args.train_steps
-
-    if args.dataset_sink_mode and args.sink_size != -1:
-        steps_per_sink = args.sink_size
-    else:
-        steps_per_sink = dataset_size
-    sink_epochs = math.ceil(total_train_steps / steps_per_sink)
+        # assume one step needs one whole epoch of data, to ensure enough batches are loaded for training
+        sink_epochs = total_train_steps

    if args.ckpt_save_steps == -1:
        ckpt_save_interval = args.ckpt_save_interval
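The reordered logic computes steps_per_sink before either branch and, when --train_steps is given, sets sink_epochs = total_train_steps so that every step can draw on up to a full epoch of data and the iterator is not exhausted before train_steps is reached. A standalone sketch of the resulting schedule (the function name and signature are ours, for illustration):

import math

def compute_sink_schedule(train_steps, epochs, dataset_size,
                          dataset_sink_mode, sink_size):
    # Steps per data-sink cycle: the explicit sink_size if one is set,
    # otherwise one full pass over the dataset.
    if dataset_sink_mode and sink_size != -1:
        steps_per_sink = sink_size
    else:
        steps_per_sink = dataset_size

    if train_steps == -1:
        # Epoch-driven run: derive total steps, round up to whole sink cycles.
        assert epochs != -1
        total_train_steps = epochs * dataset_size
        sink_epochs = math.ceil(total_train_steps / steps_per_sink)
    else:
        # Step-driven run: assume each step may need a whole epoch of data,
        # so enough batches are always loaded (the bug this PR fixes).
        total_train_steps = train_steps
        sink_epochs = total_train_steps
    return steps_per_sink, total_train_steps, sink_epochs

# e.g. compute_sink_schedule(-1, 2, 100, True, 50) -> (50, 200, 4)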