
Commit

hotfix: fix the train_step optimizer.global_step type (int32 -> float32) bug (mindspore-lab#686)

The-truthh authored Jun 14, 2023
1 parent fc1f000 commit dbc49e7
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions mindcv/utils/train_step.py
@@ -7,7 +7,7 @@
 """

 import mindspore as ms
-from mindspore import RowTensor, boost, nn, ops
+from mindspore import Parameter, RowTensor, Tensor, boost, nn, ops
 from mindspore.boost.grad_accumulation import gradient_accumulation_op, gradient_clear_op
 from mindspore.ops import functional as F

@@ -108,6 +108,7 @@ def __init__(
         super(TrainStep, self).__init__(network, optimizer, scale_sense)
         self.ema = ema
         self.ema_decay = ema_decay
+        self.updates = Parameter(Tensor(0.0, ms.float32))
         self.clip_grad = clip_grad
         self.clip_value = clip_value
         if self.ema:
@@ -119,8 +120,9 @@ def __init__(
             self.gradient_accumulation = GradientAccumulation(gradient_accumulation_steps, optimizer, self.grad_reducer)

     def ema_update(self):
+        self.updates += 1
         # ema factor is corrected by (1 - exp(-t/T)), where `t` means time and `T` means temperature.
-        ema_decay = self.ema_decay * (1 - F.exp(-self.optimizer.global_step / 2000))
+        ema_decay = self.ema_decay * (1 - F.exp(-self.updates / 2000))
         # update trainable parameters
         success = self.hyper_map(F.partial(_ema_op, ema_decay), self.ema_weight, self.weights_all)
         return success
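For context, the hunk above replaces the int32 `optimizer.global_step` with a dedicated float32 `updates` Parameter when computing the warmup-corrected EMA factor. Below is a minimal standalone sketch (plain Python with made-up step counts, not part of the commit) of the schedule `ema_decay * (1 - exp(-t/T))` that the fixed code evaluates:

# Minimal sketch (plain Python, not part of the commit) of the warmup-corrected
# EMA decay computed in ema_update(): the factor starts near 0 and approaches
# the target ema_decay as the float update counter grows.
import math

def corrected_ema_decay(ema_decay: float, updates: float, temperature: float = 2000.0) -> float:
    # ema_decay * (1 - exp(-t/T)); `updates` is kept as a float so the division
    # and exp run in floating point, mirroring the float32 Parameter in the diff.
    return ema_decay * (1.0 - math.exp(-updates / temperature))

# Hypothetical step counts, just to show the ramp-up behaviour.
for step in (1, 100, 2000, 10000):
    print(step, round(corrected_ema_decay(0.9999, float(step)), 6))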
