diff --git a/examples/opensora_pku/opensora/models/causalvideovae/model/__init__.py b/examples/opensora_pku/opensora/models/causalvideovae/model/__init__.py index c3fac22450..5d3351be6e 100644 --- a/examples/opensora_pku/opensora/models/causalvideovae/model/__init__.py +++ b/examples/opensora_pku/opensora/models/causalvideovae/model/__init__.py @@ -11,6 +11,7 @@ # ) from .causal_vae import CausalVAEModelWrapper from .causal_vae.modeling_causalvae import CausalVAEModel +from .ema_model import EMA videobase_ae_stride = { "CausalVAEModel_4x8x8": [4, 8, 8], diff --git a/examples/opensora_pku/opensora/models/causalvideovae/model/ema_model.py b/examples/opensora_pku/opensora/models/causalvideovae/model/ema_model.py new file mode 100644 index 0000000000..6f62a6043e --- /dev/null +++ b/examples/opensora_pku/opensora/models/causalvideovae/model/ema_model.py @@ -0,0 +1,38 @@ +from mindspore.ops import composite as C +from mindspore.ops import functional as F + +from mindone.trainers.ema import EMA as EMA_ + +_ema_op = C.MultitypeFuncGraph("grad_ema_op") + + +@_ema_op.register("Number", "Tensor", "Tensor") +def _ema_weights(factor, ema_weight, weight): + return F.assign(ema_weight, ema_weight * factor + weight * (1 - factor)) + + +class EMA(EMA_): + def ema_update(self): + """Update EMA parameters.""" + self.updates += 1 + # update trainable parameters + success = self.hyper_map(F.partial(_ema_op, self.ema_decay), self.ema_weight, self.net_weight) + self.updates = F.depend(self.updates, success) + + return self.updates + + +def save_ema_ckpts(net, ema, ckpt_manager, ckpt_name): + if ema is not None: + ema.swap_before_eval() + + ckpt_manager.save(net, None, ckpt_name=ckpt_name, append_dict=None) + + if ema is not None: + ema.swap_after_eval() + ckpt_manager.save( + net, + None, + ckpt_name=ckpt_name.replace(".ckpt", "_nonema.ckpt"), + append_dict=None, + ) diff --git a/examples/opensora_pku/opensora/models/causalvideovae/model/losses/net_with_loss.py b/examples/opensora_pku/opensora/models/causalvideovae/model/losses/net_with_loss.py index ef148ebe73..c2e07bbb32 100644 --- a/examples/opensora_pku/opensora/models/causalvideovae/model/losses/net_with_loss.py +++ b/examples/opensora_pku/opensora/models/causalvideovae/model/losses/net_with_loss.py @@ -46,7 +46,8 @@ def __init__( self.perceptual_loss = perceptual_loss l1 = nn.L1Loss(reduction="none") - l2 = nn.L2Loss(reduction="none") + # l2 = nn.L2Loss(reduction="none") + l2 = ops.L2Loss() self.loss_func = l1 if loss_type == "l1" else l2 # TODO: is self.logvar trainable? self.logvar = ms.Parameter(ms.Tensor([logvar_init], dtype=dtype)) diff --git a/examples/opensora_pku/opensora/train/train_causalvae.py b/examples/opensora_pku/opensora/train/train_causalvae.py index ac45fa19f5..44e3617fce 100644 --- a/examples/opensora_pku/opensora/train/train_causalvae.py +++ b/examples/opensora_pku/opensora/train/train_causalvae.py @@ -18,13 +18,12 @@ sys.path.append(".") mindone_lib_path = os.path.abspath("../../") sys.path.insert(0, mindone_lib_path) -from opensora.models.causalvideovae.model import CausalVAEModel +from opensora.models.causalvideovae.model import EMA, CausalVAEModel from opensora.models.causalvideovae.model.dataset_videobase import VideoDataset, create_dataloader from opensora.models.causalvideovae.model.losses.net_with_loss import DiscriminatorWithLoss, GeneratorWithLoss from opensora.models.causalvideovae.model.modules.updownsample import TrilinearInterpolate from opensora.models.causalvideovae.model.utils.model_utils import resolve_str_to_obj from opensora.train.commons import create_loss_scaler, parse_args -from opensora.utils.ema import EMA from opensora.utils.ms_utils import init_env from opensora.utils.utils import get_precision @@ -448,10 +447,13 @@ def main(args): step_time = time.time() - start_time_s if step % args.log_interval == 0: loss_ae = float(loss_ae_t.asnumpy()) - logger.info(f"E: {epoch+1}, S: {step+1}, Loss ae: {loss_ae:.4f}, Step time: {step_time*1000:.2f}ms") + logger.info( + f"E: {epoch+1}, S: {step+1}, Loss ae: {loss_ae:.4f}, ae loss scaler {loss_scaler_ae.loss_scale_value},\ + Step time: {step_time*1000:.2f}ms" + ) if global_step >= disc_start: loss_disc = float(loss_disc_t.asnumpy()) - logger.info(f"Loss disc: {loss_disc:.4f}") + logger.info(f"Loss disc: {loss_disc:.4f}, disc loss scaler {loss_scaler_disc.loss_scale_value}") loss_log_file.write(f"{cur_global_step}\t{loss_ae:.7f}\t{loss_disc:.7f}\t{step_time:.2f}\n") else: loss_log_file.write(f"{cur_global_step}\t{loss_ae:.7f}\t{0.0}\t{step_time:.2f}\n") diff --git a/examples/opensora_pku/scripts/causalvae/train_with_gan_loss.sh b/examples/opensora_pku/scripts/causalvae/train_with_gan_loss.sh index 8e5f17f49a..b28c4bb136 100644 --- a/examples/opensora_pku/scripts/causalvae/train_with_gan_loss.sh +++ b/examples/opensora_pku/scripts/causalvae/train_with_gan_loss.sh @@ -1,20 +1,20 @@ python opensora/train/train_causalvae.py \ - --exp_name "9x256x256" \ + --exp_name "25x256x256" \ --batch_size 1 \ - --precision fp32 \ + --precision bf16 \ --max_steps 100000 \ --save_steps 2000 \ --output_dir results/causalvae \ --video_path /remote-home1/dataset/data_split_tt \ --video_num_frames 25 \ --resolution 256 \ - --sample_rate 1 \ + --sample_rate 2 \ --dataloader_num_workers 8 \ --load_from_checkpoint pretrained/causal_vae_488_init.ckpt \ --start_learning_rate 1e-5 \ --lr_scheduler constant \ --optim adam \ - --betas 0.5 0.9 \ + --betas 0.9 0.999 \ --clip_grad True \ --weight_decay 0.0 \ --mode 0 \ diff --git a/examples/opensora_pku/scripts/causalvae/train_with_gan_loss_multi_device.sh b/examples/opensora_pku/scripts/causalvae/train_with_gan_loss_multi_device.sh index e23bf8f6cf..979b7922f7 100644 --- a/examples/opensora_pku/scripts/causalvae/train_with_gan_loss_multi_device.sh +++ b/examples/opensora_pku/scripts/causalvae/train_with_gan_loss_multi_device.sh @@ -3,25 +3,25 @@ export MS_ENABLE_NUMA=0 export MS_MEMORY_STATISTIC=1 export GLOG_v=2 output_dir="results/causalvae" -exp_name="9x256x256" +exp_name="25x256x256" msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 --log_dir=$output_dir/$exp_name/parallel_logs opensora/train/train_causalvae.py \ --exp_name $exp_name \ --batch_size 1 \ - --precision fp32 \ + --precision bf16 \ --max_steps 100000 \ --save_steps 2000 \ --output_dir $output_dir \ --video_path /remote-home1/dataset/data_split_tt \ --video_num_frames 25 \ --resolution 256 \ - --sample_rate 1 \ + --sample_rate 2 \ --dataloader_num_workers 8 \ --load_from_checkpoint pretrained/causal_vae_488_init.ckpt \ --start_learning_rate 1e-5 \ --lr_scheduler constant \ --optim adam \ - --betas 0.5 0.9 \ + --betas 0.9 0.999 \ --clip_grad True \ --weight_decay 0.0 \ --mode 0 \