diff --git a/examples/opensora_pku/examples/rec_video.py b/examples/opensora_pku/examples/rec_video.py
index f905627c00..5c3a4d1f6f 100644
--- a/examples/opensora_pku/examples/rec_video.py
+++ b/examples/opensora_pku/examples/rec_video.py
@@ -134,7 +134,7 @@ def main(args):
             [k.replace("network.", "") if k.startswith("network.") else k, v] for k, v in state_dict.items()
         )
     else:
-        None
+        state_dict = None
     kwarg = {"state_dict": state_dict, "use_safetensors": True}
     vae = CausalVAEModelWrapper(args.ae_path, **kwarg)
 
diff --git a/examples/opensora_pku/examples/rec_video_folder.py b/examples/opensora_pku/examples/rec_video_folder.py
index 74795799dd..20426592c7 100644
--- a/examples/opensora_pku/examples/rec_video_folder.py
+++ b/examples/opensora_pku/examples/rec_video_folder.py
@@ -71,7 +71,7 @@ def main(args):
             [k.replace("network.", "") if k.startswith("network.") else k, v] for k, v in state_dict.items()
         )
     else:
-        None
+        state_dict = None
     kwarg = {"state_dict": state_dict, "use_safetensors": True}
     vae = CausalVAEModelWrapper(args.ae_path, **kwarg)
 
diff --git a/examples/opensora_pku/opensora/train/train_t2v_diffusers.py b/examples/opensora_pku/opensora/train/train_t2v_diffusers.py
index 02391ab4d1..3c02305d96 100644
--- a/examples/opensora_pku/opensora/train/train_t2v_diffusers.py
+++ b/examples/opensora_pku/opensora/train/train_t2v_diffusers.py
@@ -314,7 +314,8 @@ def main(args):
 
     val_dataloader = None
     if args.validate:
-        assert args.val_data is not None and os.path.exists(args.val_data), "validation dataset must be specified!"
+        assert args.val_data is not None, f"validation dataset must be specified, but got {args.val_data}"
+        assert os.path.exists(args.val_data), f"validation dataset file must exist, but got {args.val_data}"
         print_banner("Validation dataset Loading...")
         val_dataset = getdataset(args, dataset_file=args.val_data)
         sampler = (
@@ -406,12 +407,12 @@ def main(args):
 
     if steps_per_sink == dataloader_size:
         logger.info(
-            f"Number of training steps: {total_train_steps}; Number of epochs: {args.num_train_epochs}; "
+            f"Number of training steps: {total_train_steps}, Number of epochs: {args.num_train_epochs}, "
             f"Number of batches in a epoch (dataloader size): {dataloader_size}"
         )
     else:
         logger.info(
-            f"Number of training steps: {total_train_steps}; Number of sink epochs: {sink_epochs}; Number of batches in a sink (sink_size): {steps_per_sink}"
+            f"Number of training steps: {total_train_steps}, Number of sink epochs: {sink_epochs}, Number of batches in a sink (sink_size): {steps_per_sink}"
         )
     if args.checkpointing_steps is None:
         ckpt_save_interval = args.ckpt_save_interval
@@ -425,7 +426,7 @@ def main(args):
             ckpt_save_interval = max(1, args.checkpointing_steps // steps_per_sink)
             if args.checkpointing_steps % steps_per_sink != 0:
                 logger.warning(
-                    f"`checkpointing_steps` must be times of sink size or dataset_size under dataset sink mode."
+                    "`checkpointing_steps` must be a multiple of the sink size or the dataset size under dataset sink mode. "
                     f"Checkpoint will be saved every {ckpt_save_interval * steps_per_sink} steps."
                 )
         if step_mode != args.step_mode:
@@ -610,13 +611,16 @@ def main(args):
             integrated_save=integrated_save,
             save_training_resume=save_training_resume,
         )
-        rec_cb = PerfRecorderCallback(
-            save_dir=args.output_dir,
-            file_name="result_val.log",
-            resume=args.resume_from_checkpoint,
-            metric_names=list(metrics.keys()),
-        )
-        callback.extend([save_cb, rec_cb])
+        callback.append(save_cb)
+        if args.validate:
+            assert metrics is not None, "Validation during training requires metric functions to be set"
+            rec_cb = PerfRecorderCallback(
+                save_dir=args.output_dir,
+                file_name="result_val.log",
+                resume=args.resume_from_checkpoint,
+                metric_names=list(metrics.keys()),
+            )
+            callback.append(rec_cb)
     if args.profile:
         callback.append(ProfilerCallbackEpoch(2, 2, "./profile_data"))
     # Train!
@@ -637,10 +641,15 @@ def main(args):
                 f"MindSpore mode[GRAPH(0)/PYNATIVE(1)]: {args.mode}",
                 f"Jit level: {args.jit_level}",
                 f"Distributed mode: {args.use_parallel}"
-                + (f"\nParallel mode: {args.parallel_mode}" if args.use_parallel else "")
+                + (
+                    f"\nParallel mode: {args.parallel_mode}"
+                    + (f"{args.zero_stage}" if args.parallel_mode == "zero" else "")
+                    if args.use_parallel
+                    else ""
+                )
                 + (f"\nsp_size: {args.sp_size}" if args.sp_size != 1 else ""),
                 f"Num params: {num_params:,} (transformer: {num_params_transformer:,}, vae: {num_params_vae:,})",
                 f"Num trainable params: {num_params_trainable:,}",
                 f"Transformer model dtype: {model_dtype}",
                 f"Transformer AMP level: {args.amp_level}" if not args.global_bf16 else "Global BF16: True",
                 f"VAE dtype: {vae_dtype} (amp level O2)"