Commit e91267b: update causalvae training

wtomin committed Oct 16, 2024 · 1 parent 90ef9b9

Showing 6 changed files with 66 additions and 21 deletions.
@@ -11,6 +11,7 @@
 # )
 from .causal_vae import CausalVAEModelWrapper
 from .causal_vae.modeling_causalvae import CausalVAEModel
+from .ema_model import EMA

 videobase_ae_stride = {
     "CausalVAEModel_4x8x8": [4, 8, 8],
@@ -0,0 +1,38 @@
+from mindspore.ops import composite as C
+from mindspore.ops import functional as F
+
+from mindone.trainers.ema import EMA as EMA_
+
+_ema_op = C.MultitypeFuncGraph("grad_ema_op")
+
+
+@_ema_op.register("Number", "Tensor", "Tensor")
+def _ema_weights(factor, ema_weight, weight):
+    return F.assign(ema_weight, ema_weight * factor + weight * (1 - factor))
+
+
+class EMA(EMA_):
+    def ema_update(self):
+        """Update EMA parameters."""
+        self.updates += 1
+        # update trainable parameters
+        success = self.hyper_map(F.partial(_ema_op, self.ema_decay), self.ema_weight, self.net_weight)
+        self.updates = F.depend(self.updates, success)
+
+        return self.updates
+
+
+def save_ema_ckpts(net, ema, ckpt_manager, ckpt_name):
+    if ema is not None:
+        ema.swap_before_eval()
+
+    ckpt_manager.save(net, None, ckpt_name=ckpt_name, append_dict=None)
+
+    if ema is not None:
+        ema.swap_after_eval()
+        ckpt_manager.save(
+            net,
+            None,
+            ckpt_name=ckpt_name.replace(".ckpt", "_nonema.ckpt"),
+            append_dict=None,
+        )
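For orientation, a minimal sketch of how the new ema_model helpers are typically wired into a training loop. The `autoencoder`, `train_step`, `dataloader`, `save_steps`, and `ckpt_manager` names, and the `ema_decay` constructor argument, are illustrative assumptions rather than code from this commit:

    # hypothetical wiring; EMA and save_ema_ckpts are the objects defined above
    ema = EMA(autoencoder, ema_decay=0.9999)

    for step, batch in enumerate(dataloader):
        loss = train_step(batch)   # forward/backward + optimizer update
        ema.ema_update()           # shadow = decay * shadow + (1 - decay) * weight

        if (step + 1) % save_steps == 0:
            # saves "<name>.ckpt" with the EMA weights swapped in, then
            # "<name>_nonema.ckpt" with the raw training weights
            save_ema_ckpts(autoencoder, ema, ckpt_manager, "causalvae.ckpt")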
@@ -46,7 +46,8 @@ def __init__(
         self.perceptual_loss = perceptual_loss

         l1 = nn.L1Loss(reduction="none")
-        l2 = nn.L2Loss(reduction="none")
+        # l2 = nn.L2Loss(reduction="none")
+        l2 = ops.L2Loss()
         self.loss_func = l1 if loss_type == "l1" else l2
         # TODO: is self.logvar trainable?
         self.logvar = ms.Parameter(ms.Tensor([logvar_init], dtype=dtype))
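A note on this swap, from my reading of the MindSpore API (the commit itself gives no rationale): `nn` does not provide an `L2Loss` cell, which is presumably why the old line is commented out, and the replacement `ops.L2Loss` is a primitive that returns half the squared L2 norm as a single scalar rather than an element-wise loss like the `l1` branch:

    import mindspore as ms
    from mindspore import ops

    l2 = ops.L2Loss()
    x = ms.Tensor([1.0, 2.0, 3.0], ms.float32)
    print(l2(x))  # sum(x ** 2) / 2 = (1 + 4 + 9) / 2 = 7.0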
21 changes: 13 additions & 8 deletions examples/opensora_pku/opensora/train/train_causalvae.py
@@ -18,13 +18,12 @@
 sys.path.append(".")
 mindone_lib_path = os.path.abspath("../../")
 sys.path.insert(0, mindone_lib_path)
-from opensora.models.causalvideovae.model import CausalVAEModel
+from opensora.models.causalvideovae.model import EMA, CausalVAEModel
 from opensora.models.causalvideovae.model.dataset_videobase import VideoDataset, create_dataloader
 from opensora.models.causalvideovae.model.losses.net_with_loss import DiscriminatorWithLoss, GeneratorWithLoss
 from opensora.models.causalvideovae.model.modules.updownsample import TrilinearInterpolate
 from opensora.models.causalvideovae.model.utils.model_utils import resolve_str_to_obj
 from opensora.train.commons import create_loss_scaler, parse_args
-from opensora.utils.ema import EMA
 from opensora.utils.ms_utils import init_env
 from opensora.utils.utils import get_precision

@@ -153,7 +152,7 @@ def main(args):
         dataset,
         shuffle=True,
         num_parallel_workers=args.dataloader_num_workers,
-        batch_size=args.batch_size,
+        batch_size=args.train_batch_size,
         drop_remainder=True,
         device_num=device_num,
         rank_id=rank_id,
@@ -164,8 +163,10 @@
     # 5. build training utils
     # torch scale lr by: model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr
     if args.scale_lr:
-        learning_rate = args.start_learning_rate * args.batch_size * args.gradient_accumulation_steps * device_num
-        end_learning_rate = args.end_learning_rate * args.batch_size * args.gradient_accumulation_steps * device_num
+        learning_rate = args.start_learning_rate * args.train_batch_size * args.gradient_accumulation_steps * device_num
+        end_learning_rate = (
+            args.end_learning_rate * args.train_batch_size * args.gradient_accumulation_steps * device_num
+        )
     else:
         learning_rate = args.start_learning_rate
         end_learning_rate = args.end_learning_rate
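To make the linear scaling rule concrete, a worked example using the values from the 8-NPU script below (illustrative arithmetic; `gradient_accumulation_steps = 1` is an assumed default):

    # scaled_lr = base_lr * batch_size * grad_accum_steps * num_devices
    start_learning_rate = 1e-5
    train_batch_size = 1
    gradient_accumulation_steps = 1
    device_num = 8
    print(start_learning_rate * train_batch_size * gradient_accumulation_steps * device_num)  # 8e-05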
@@ -348,7 +349,7 @@
         f"dtype: {args.precision}",
         f"Use discriminator: {args.use_discriminator}",
         f"Learning rate: {learning_rate}",
-        f"Batch size: {args.batch_size}",
+        f"Batch size: {args.train_batch_size}",
         f"Rescale size: {args.resolution}",
         f"Crop size: {args.resolution}",
         f"Number of frames: {args.video_num_frames}",
@@ -448,10 +449,13 @@ def main(args):
             step_time = time.time() - start_time_s
             if step % args.log_interval == 0:
                 loss_ae = float(loss_ae_t.asnumpy())
-                logger.info(f"E: {epoch+1}, S: {step+1}, Loss ae: {loss_ae:.4f}, Step time: {step_time*1000:.2f}ms")
+                logger.info(
+                    f"E: {epoch+1}, S: {step+1}, Loss ae: {loss_ae:.4f}, ae loss scaler {loss_scaler_ae.loss_scale_value},"
+                    + f" Step time: {step_time*1000:.2f}ms"
+                )
                 if global_step >= disc_start:
                     loss_disc = float(loss_disc_t.asnumpy())
-                    logger.info(f"Loss disc: {loss_disc:.4f}")
+                    logger.info(f"Loss disc: {loss_disc:.4f}, disc loss scaler {loss_scaler_disc.loss_scale_value}")
                     loss_log_file.write(f"{cur_global_step}\t{loss_ae:.7f}\t{loss_disc:.7f}\t{step_time:.2f}\n")
                 else:
                     loss_log_file.write(f"{cur_global_step}\t{loss_ae:.7f}\t{0.0}\t{step_time:.2f}\n")
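Logging the loss scale makes bf16/fp16 overflow visible at a glance. A sketch of the dynamic scaler these values presumably come from, assuming `create_loss_scaler` wraps MindSpore's standard cell (the actual wiring lives in `opensora/train/commons.py`, outside this diff):

    from mindspore.nn import DynamicLossScaleUpdateCell

    # 65536 = 2 ** 16 is the conventional starting scale: it is halved on an
    # overflowing step and doubled after `scale_window` clean steps, so a
    # steadily shrinking value in the training log signals recurring overflow.
    scale_sense = DynamicLossScaleUpdateCell(
        loss_scale_value=65536, scale_factor=2, scale_window=1000
    )
    print(scale_sense.get_loss_scale())  # 65536.0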
@@ -545,6 +549,7 @@ def parse_causalvae_train_args(parser):
         help="Whether to use the discriminator in the training process. "
         "Phase 1 training does not use discriminator, set False to reduce memory cost in graph mode.",
     )
+
     parser.add_argument(
         "--model_config",
         default="scripts/causalvae/release.json",
12 changes: 6 additions & 6 deletions examples/opensora_pku/scripts/causalvae/train_with_gan_loss.sh
@@ -1,24 +1,24 @@
 python opensora/train/train_causalvae.py \
-    --exp_name "9x256x256" \
-    --batch_size 1 \
-    --precision fp32 \
+    --exp_name "25x256x256" \
+    --train_batch_size 1 \
+    --precision bf16 \
     --max_steps 100000 \
     --save_steps 2000 \
     --output_dir results/causalvae \
     --video_path /remote-home1/dataset/data_split_tt \
     --video_num_frames 25 \
     --resolution 256 \
-    --sample_rate 1 \
+    --sample_rate 2 \
     --dataloader_num_workers 8 \
     --load_from_checkpoint pretrained/causal_vae_488_init.ckpt \
     --start_learning_rate 1e-5 \
     --lr_scheduler constant \
     --optim adam \
-    --betas 0.5 0.9 \
+    --betas 0.9 0.999 \
     --clip_grad True \
     --weight_decay 0.0 \
     --mode 0 \
-    --init_loss_scale 1 \
+    --init_loss_scale 65536 \
     --jit_level "O0" \
     --use_discriminator True \
     --use_ema True\
@@ -3,29 +3,29 @@ export MS_ENABLE_NUMA=0
 export MS_MEMORY_STATISTIC=1
 export GLOG_v=2
 output_dir="results/causalvae"
-exp_name="9x256x256"
+exp_name="25x256x256"

 msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 --log_dir=$output_dir/$exp_name/parallel_logs opensora/train/train_causalvae.py \
     --exp_name $exp_name \
-    --batch_size 1 \
-    --precision fp32 \
+    --train_batch_size 1 \
+    --precision bf16 \
     --max_steps 100000 \
     --save_steps 2000 \
     --output_dir $output_dir \
     --video_path /remote-home1/dataset/data_split_tt \
     --video_num_frames 25 \
     --resolution 256 \
-    --sample_rate 1 \
+    --sample_rate 2 \
     --dataloader_num_workers 8 \
     --load_from_checkpoint pretrained/causal_vae_488_init.ckpt \
     --start_learning_rate 1e-5 \
     --lr_scheduler constant \
     --optim adam \
-    --betas 0.5 0.9 \
+    --betas 0.9 0.999 \
     --clip_grad True \
     --weight_decay 0.0 \
     --mode 0 \
-    --init_loss_scale 1 \
+    --init_loss_scale 65536 \
     --jit_level "O0" \
     --use_discriminator True \
     --use_parallel True \
