
Commit

correct typo
wtomin committed Oct 17, 2024
1 parent 7b46343 commit f305998
Showing 3 changed files with 23 additions and 15 deletions.
examples/opensora_pku/examples/rec_video.py (1 addition, 1 deletion)
@@ -134,7 +134,7 @@ def main(args):
             [k.replace("network.", "") if k.startswith("network.") else k, v] for k, v in state_dict.items()
         )
     else:
-        None
+        state_dict = None
     kwarg = {"state_dict": state_dict, "use_safetensors": True}
     vae = CausalVAEModelWrapper(args.ae_path, **kwarg)
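A quick note on why this one-line fix matters: in the old branch, `None` is a bare expression, so `state_dict` is never assigned when the `else` path is taken, and the following `kwarg = {"state_dict": state_dict, ...}` line fails with an unbound-variable error. A minimal sketch of the corrected control flow (standalone, with a made-up dictionary standing in for the real checkpoint loading):

    def build_vae_kwarg(ms_checkpoint=None):
        if ms_checkpoint is not None:
            state_dict = {"weight": 1.0}  # stand-in for ms.load_checkpoint(ms_checkpoint)
        else:
            # The fix: before this commit the branch body was a bare `None` expression,
            # which left `state_dict` unbound on this path (UnboundLocalError below).
            state_dict = None
        return {"state_dict": state_dict, "use_safetensors": True}

    print(build_vae_kwarg(None))          # {'state_dict': None, 'use_safetensors': True}
    print(build_vae_kwarg("model.ckpt"))  # {'state_dict': {'weight': 1.0}, 'use_safetensors': True}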

examples/opensora_pku/examples/rec_video_folder.py (1 addition, 1 deletion)
@@ -71,7 +71,7 @@ def main(args):
             [k.replace("network.", "") if k.startswith("network.") else k, v] for k, v in state_dict.items()
         )
     else:
-        None
+        state_dict = None
     kwarg = {"state_dict": state_dict, "use_safetensors": True}
     vae = CausalVAEModelWrapper(args.ae_path, **kwarg)

examples/opensora_pku/opensora/train/train_t2v_diffusers.py (21 additions, 13 deletions)
@@ -406,12 +406,12 @@ def main(args):
 
     if steps_per_sink == dataloader_size:
         logger.info(
-            f"Number of training steps: {total_train_steps}; Number of epochs: {args.num_train_epochs}; "
+            f"Number of training steps: {total_train_steps}, Number of epochs: {args.num_train_epochs}, "
             f"Number of batches in a epoch (dataloader size): {dataloader_size}"
         )
     else:
         logger.info(
-            f"Number of training steps: {total_train_steps}; Number of sink epochs: {sink_epochs}; Number of batches in a sink (sink_size): {steps_per_sink}"
+            f"Number of training steps: {total_train_steps}, Number of sink epochs: {sink_epochs}, Number of batches in a sink (sink_size): {steps_per_sink}"
         )
     if args.checkpointing_steps is None:
         ckpt_save_interval = args.ckpt_save_interval
@@ -425,7 +425,7 @@ def main(args):
             ckpt_save_interval = max(1, args.checkpointing_steps // steps_per_sink)
             if args.checkpointing_steps % steps_per_sink != 0:
                 logger.warning(
-                    f"`checkpointing_steps` must be times of sink size or dataset_size under dataset sink mode."
+                    "`checkpointing_steps` must be times of sink size or dataset_size under dataset sink mode."
                     f"Checkpoint will be saved every {ckpt_save_interval * steps_per_sink} steps."
                 )
     if step_mode != args.step_mode:
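For readers unfamiliar with dataset sink mode: checkpoints can only be saved at sink boundaries, so a `checkpointing_steps` that is not a multiple of the sink size gets rounded down to whole sinks, which is what the warning above reports. A worked example with made-up numbers:

    checkpointing_steps = 500  # requested save interval in training steps (example value)
    steps_per_sink = 64        # batches per sink under dataset sink mode (example value)

    ckpt_save_interval = max(1, checkpointing_steps // steps_per_sink)  # 500 // 64 -> 7 sinks
    if checkpointing_steps % steps_per_sink != 0:
        print(f"Checkpoint will be saved every {ckpt_save_interval * steps_per_sink} steps.")
        # prints: Checkpoint will be saved every 448 steps.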
@@ -610,13 +610,16 @@ def main(args):
             integrated_save=integrated_save,
             save_training_resume=save_training_resume,
         )
-        rec_cb = PerfRecorderCallback(
-            save_dir=args.output_dir,
-            file_name="result_val.log",
-            resume=args.resume_from_checkpoint,
-            metric_names=list(metrics.keys()),
-        )
-        callback.extend([save_cb, rec_cb])
+        callback.append(save_cb)
+        if args.validate:
+            assert metrics is not None, "Val during training must set the metric functions"
+            rec_cb = PerfRecorderCallback(
+                save_dir=args.output_dir,
+                file_name="result_val.log",
+                resume=args.resume_from_checkpoint,
+                metric_names=list(metrics.keys()),
+            )
+            callback.append(rec_cb)
     if args.profile:
         callback.append(ProfilerCallbackEpoch(2, 2, "./profile_data"))
     # Train!
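The callback change above only registers the validation recorder when validation is requested, so `metrics` is never dereferenced when it is None. A small sketch of the pattern with placeholder objects (not the real MindSpore callbacks):

    def build_callbacks(save_cb, metrics=None, validate=False, profile=False):
        callbacks = [save_cb]
        if validate:
            # Mirrors the new assert: validation during training requires metric functions.
            assert metrics is not None, "Val during training must set the metric functions"
            callbacks.append(("PerfRecorder", list(metrics.keys())))  # placeholder for PerfRecorderCallback
        if profile:
            callbacks.append("Profiler")  # placeholder for ProfilerCallbackEpoch
        return callbacks

    print(build_callbacks("ckpt_saver"))                                 # ['ckpt_saver']
    print(build_callbacks("ckpt_saver", {"psnr": None}, validate=True))  # adds the recorder with ['psnr']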
@@ -637,10 +640,15 @@ def main(args):
             f"MindSpore mode[GRAPH(0)/PYNATIVE(1)]: {args.mode}",
             f"Jit level: {args.jit_level}",
             f"Distributed mode: {args.use_parallel}"
-            + (f"\nParallel mode: {args.parallel_mode}" if args.use_parallel else "")
+            + (
+                f"\nParallel mode: {args.parallel_mode}"
+                + (f"{args.zero_stage}" if args.parallel_mode == "zero" else "")
+                if args.use_parallel
+                else ""
+            )
             + (f"\nsp_size: {args.sp_size}" if args.sp_size != 1 else ""),
f"Num params: {num_params:,} (transformer: {num_params_transformer:,}, vae: {num_params_vae:,})",
f"Num trainable params: {num_params_trainable:,}",
f"Num params: {num_params: , } (transformer: {num_params_transformer: , }, vae: {num_params_vae: , })",
f"Num trainable params: {num_params_trainable: , }",
f"Transformer model dtype: {model_dtype}",
f"Transformer AMP level: {args.amp_level}" if not args.global_bf16 else "Global BF16: True",
f"VAE dtype: {vae_dtype} (amp level O2)"
