Skip to content

Commit

Permalink
fix global step
Browse files Browse the repository at this point in the history
  • Loading branch information
SamitHuang committed Aug 26, 2024
1 parent 88de7e2 commit 06aded8
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion mindone/trainers/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,10 @@ def choice_func(x):
def on_train_step_end(self, run_context):
cb_params = run_context.original_args()
loss = _handle_loss(cb_params.net_outputs)
# cur_step = cb_params.cur_step_num + self.start_epoch * cb_params.batch_num
opt = self._get_optimizer_from_cbp(cb_params)
cur_step = int(opt.global_step.asnumpy().item())
if cur_step <= 0:
cur_step = cb_params.cur_step_num + self.start_epoch * cb_params.batch_num

step_num = (cb_params.batch_num * cb_params.epoch_num) if self.train_steps < 0 else self.train_steps

Expand Down

0 comments on commit 06aded8

Please sign in to comment.