
update causalvae training
wtomin committed Oct 16, 2024
1 parent 90ef9b9 commit 606bb6a
Showing 6 changed files with 55 additions and 13 deletions.
@@ -11,6 +11,7 @@
 # )
 from .causal_vae import CausalVAEModelWrapper
 from .causal_vae.modeling_causalvae import CausalVAEModel
+from .ema_model import EMA
 
 videobase_ae_stride = {
     "CausalVAEModel_4x8x8": [4, 8, 8],
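
A side note on the re-export above (the file path is not shown in this view; the relative imports suggest it is the model package's __init__.py): it is what lets downstream code import EMA from the package root, which the training-script change further below relies on. A one-line illustration, under that assumption:

# Hypothetical import site; works once the package __init__ re-exports EMA.
from opensora.models.causalvideovae.model import EMA, CausalVAEModel
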
@@ -0,0 +1,38 @@
from mindspore.ops import composite as C
from mindspore.ops import functional as F

from mindone.trainers.ema import EMA as EMA_

_ema_op = C.MultitypeFuncGraph("grad_ema_op")


@_ema_op.register("Number", "Tensor", "Tensor")
def _ema_weights(factor, ema_weight, weight):
    return F.assign(ema_weight, ema_weight * factor + weight * (1 - factor))


class EMA(EMA_):
    def ema_update(self):
        """Update EMA parameters."""
        self.updates += 1
        # update trainable parameters
        success = self.hyper_map(F.partial(_ema_op, self.ema_decay), self.ema_weight, self.net_weight)
        self.updates = F.depend(self.updates, success)

        return self.updates


def save_ema_ckpts(net, ema, ckpt_manager, ckpt_name):
    if ema is not None:
        ema.swap_before_eval()

    ckpt_manager.save(net, None, ckpt_name=ckpt_name, append_dict=None)

    if ema is not None:
        ema.swap_after_eval()
        ckpt_manager.save(
            net,
            None,
            ckpt_name=ckpt_name.replace(".ckpt", "_nonema.ckpt"),
            append_dict=None,
        )
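
Below is a minimal usage sketch (not part of the commit) of how the EMA wrapper and save_ema_ckpts defined above might be wired into a training loop. It assumes the new file is the ema_model module re-exported earlier, that mindone's EMA base class accepts the network plus an ema_decay argument, and that autoencoder, dataloader, train_step_fn, ckpt_manager and save_steps are placeholders provided elsewhere:

from opensora.models.causalvideovae.model import EMA
from opensora.models.causalvideovae.model.ema_model import save_ema_ckpts  # assumed module path

ema = EMA(autoencoder, ema_decay=0.999)      # keeps a shadow copy of the trainable weights
for step, batch in enumerate(dataloader, start=1):
    loss = train_step_fn(batch)              # updates the autoencoder's weights
    ema.ema_update()                         # ema_w <- decay * ema_w + (1 - decay) * w
    if step % save_steps == 0:
        # saves "<name>.ckpt" with EMA weights swapped in, then "<name>_nonema.ckpt" with the raw weights
        save_ema_ckpts(autoencoder, ema, ckpt_manager, f"vae-s{step}.ckpt")
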
@@ -46,7 +46,8 @@ def __init__(
         self.perceptual_loss = perceptual_loss
 
         l1 = nn.L1Loss(reduction="none")
-        l2 = nn.L2Loss(reduction="none")
+        # l2 = nn.L2Loss(reduction="none")
+        l2 = ops.L2Loss()
         self.loss_func = l1 if loss_type == "l1" else l2
         # TODO: is self.logvar trainable?
         self.logvar = ms.Parameter(ms.Tensor([logvar_init], dtype=dtype))
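
A hedged note on the swap above: MindSpore's ops.L2Loss is a reducing primitive that returns a scalar equal to sum(x ** 2) / 2 over the whole input, while nn.L1Loss(reduction="none") stays element-wise and takes two tensors, so the two branches of self.loss_func differ in both call signature and output shape. A small check of the documented behavior:

import mindspore as ms
from mindspore import nn, ops

x = ms.Tensor([[1.0, -2.0], [3.0, 0.0]], ms.float32)
y = ms.Tensor([[0.0, 0.0], [0.0, 0.0]], ms.float32)

l1 = nn.L1Loss(reduction="none")
print(l1(x, y).shape)       # (2, 2): per-element |x - y|
print(ops.L2Loss()(x - y))  # scalar: (1 + 4 + 9 + 0) / 2 = 7.0
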
examples/opensora_pku/opensora/train/train_causalvae.py (10 changes: 6 additions & 4 deletions)
@@ -18,13 +18,12 @@
 sys.path.append(".")
 mindone_lib_path = os.path.abspath("../../")
 sys.path.insert(0, mindone_lib_path)
-from opensora.models.causalvideovae.model import CausalVAEModel
+from opensora.models.causalvideovae.model import EMA, CausalVAEModel
 from opensora.models.causalvideovae.model.dataset_videobase import VideoDataset, create_dataloader
 from opensora.models.causalvideovae.model.losses.net_with_loss import DiscriminatorWithLoss, GeneratorWithLoss
 from opensora.models.causalvideovae.model.modules.updownsample import TrilinearInterpolate
 from opensora.models.causalvideovae.model.utils.model_utils import resolve_str_to_obj
 from opensora.train.commons import create_loss_scaler, parse_args
-from opensora.utils.ema import EMA
 from opensora.utils.ms_utils import init_env
 from opensora.utils.utils import get_precision
 
@@ -448,10 +447,13 @@ def main(args):
             step_time = time.time() - start_time_s
             if step % args.log_interval == 0:
                 loss_ae = float(loss_ae_t.asnumpy())
-                logger.info(f"E: {epoch+1}, S: {step+1}, Loss ae: {loss_ae:.4f}, Step time: {step_time*1000:.2f}ms")
+                logger.info(
+                    f"E: {epoch+1}, S: {step+1}, Loss ae: {loss_ae:.4f}, ae loss scaler {loss_scaler_ae.loss_scale_value},\
+                        Step time: {step_time*1000:.2f}ms"
+                )
                 if global_step >= disc_start:
                     loss_disc = float(loss_disc_t.asnumpy())
-                    logger.info(f"Loss disc: {loss_disc:.4f}")
+                    logger.info(f"Loss disc: {loss_disc:.4f}, disc loss scaler {loss_scaler_disc.loss_scale_value}")
                     loss_log_file.write(f"{cur_global_step}\t{loss_ae:.7f}\t{loss_disc:.7f}\t{step_time:.2f}\n")
                 else:
                     loss_log_file.write(f"{cur_global_step}\t{loss_ae:.7f}\t{0.0}\t{step_time:.2f}\n")
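
For reference, a hypothetical rendering of the expanded log message (sample numbers, not from a real run). Because the backslash continuation sits inside the f-string literal, the leading spaces of the continued line become part of the logged text:

epoch, step, loss_ae, step_time = 0, 99, 0.1234, 0.523
loss_scale_value = 65536.0  # stand-in for loss_scaler_ae.loss_scale_value

msg = f"E: {epoch+1}, S: {step+1}, Loss ae: {loss_ae:.4f}, ae loss scaler {loss_scale_value},\
    Step time: {step_time*1000:.2f}ms"
print(msg)
# E: 1, S: 100, Loss ae: 0.1234, ae loss scaler 65536.0,    Step time: 523.00ms
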
@@ -1,20 +1,20 @@
 python opensora/train/train_causalvae.py \
-    --exp_name "9x256x256" \
+    --exp_name "25x256x256" \
     --batch_size 1 \
-    --precision fp32 \
+    --precision bf16 \
     --max_steps 100000 \
     --save_steps 2000 \
     --output_dir results/causalvae \
     --video_path /remote-home1/dataset/data_split_tt \
     --video_num_frames 25 \
     --resolution 256 \
-    --sample_rate 1 \
+    --sample_rate 2 \
     --dataloader_num_workers 8 \
     --load_from_checkpoint pretrained/causal_vae_488_init.ckpt \
     --start_learning_rate 1e-5 \
     --lr_scheduler constant \
     --optim adam \
-    --betas 0.5 0.9 \
+    --betas 0.9 0.999 \
     --clip_grad True \
     --weight_decay 0.0 \
     --mode 0 \
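
A quick, hypothetical sanity check of the updated sampling settings (assuming --video_num_frames frames are taken every --sample_rate-th frame, the usual convention for these flags):

video_num_frames, sample_rate = 25, 2
raw_frames_spanned = (video_num_frames - 1) * sample_rate + 1
print(raw_frames_spanned)  # each 25-frame clip now spans 49 raw video frames
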
@@ -3,25 +3,25 @@ export MS_ENABLE_NUMA=0
 export MS_MEMORY_STATISTIC=1
 export GLOG_v=2
 output_dir="results/causalvae"
-exp_name="9x256x256"
+exp_name="25x256x256"
 
 msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 --log_dir=$output_dir/$exp_name/parallel_logs opensora/train/train_causalvae.py \
     --exp_name $exp_name \
     --batch_size 1 \
-    --precision fp32 \
+    --precision bf16 \
     --max_steps 100000 \
     --save_steps 2000 \
     --output_dir $output_dir \
     --video_path /remote-home1/dataset/data_split_tt \
     --video_num_frames 25 \
     --resolution 256 \
-    --sample_rate 1 \
+    --sample_rate 2 \
     --dataloader_num_workers 8 \
     --load_from_checkpoint pretrained/causal_vae_488_init.ckpt \
     --start_learning_rate 1e-5 \
     --lr_scheduler constant \
     --optim adam \
-    --betas 0.5 0.9 \
+    --betas 0.9 0.999 \
     --clip_grad True \
     --weight_decay 0.0 \
     --mode 0 \
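
Finally, a hypothetical sketch (not this repo's actual construction code) of how the optimizer flags shared by both scripts (--optim adam, --betas 0.9 0.999, --start_learning_rate 1e-5, --weight_decay 0.0) would typically map onto a MindSpore optimizer:

import mindspore.nn as nn

def build_optimizer(params, betas=(0.9, 0.999), lr=1e-5, weight_decay=0.0):
    # The new betas match Adam's defaults; the previous 0.5/0.9 pair is the setting
    # more commonly seen in GAN-style training recipes.
    beta1, beta2 = betas
    return nn.Adam(params, learning_rate=lr, beta1=beta1, beta2=beta2, weight_decay=weight_decay)
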
