diff --git a/configs/edgenext/edgenext_base_ascend.yaml b/configs/edgenext/edgenext_base_ascend.yaml index f08740c6..c37716ae 100644 --- a/configs/edgenext/edgenext_base_ascend.yaml +++ b/configs/edgenext/edgenext_base_ascend.yaml @@ -3,6 +3,7 @@ mode: 0 distribute: True num_parallel_workers: 8 val_while_train: True +seed: 1 # dataset dataset: 'imagenet' @@ -60,4 +61,5 @@ filter_bias_and_bn: True momentum: 0.9 weight_decay: 0.05 loss_scale: 1024 +drop_overflow_update: True use_nesterov: False diff --git a/configs/edgenext/edgenext_small_ascend.yaml b/configs/edgenext/edgenext_small_ascend.yaml index 9962664f..c9cb6d97 100644 --- a/configs/edgenext/edgenext_small_ascend.yaml +++ b/configs/edgenext/edgenext_small_ascend.yaml @@ -3,6 +3,7 @@ mode: 0 distribute: True num_parallel_workers: 8 val_while_train: True +seed: 1 # dataset dataset: 'imagenet' @@ -59,4 +60,5 @@ filter_bias_and_bn: True momentum: 0.9 weight_decay: 0.05 loss_scale: 1024 +drop_overflow_update: True use_nesterov: False diff --git a/configs/edgenext/edgenext_x_small_ascend.yaml b/configs/edgenext/edgenext_x_small_ascend.yaml index 12fd8bd5..b21b7585 100644 --- a/configs/edgenext/edgenext_x_small_ascend.yaml +++ b/configs/edgenext/edgenext_x_small_ascend.yaml @@ -3,6 +3,7 @@ mode: 0 distribute: True num_parallel_workers: 8 val_while_train: True +seed: 1 # dataset dataset: 'imagenet' @@ -59,4 +60,5 @@ filter_bias_and_bn: True momentum: 0.9 weight_decay: 0.05 loss_scale: 1024 +drop_overflow_update: True use_nesterov: False diff --git a/configs/edgenext/edgenext_xx_small_ascend.yaml b/configs/edgenext/edgenext_xx_small_ascend.yaml index 47bea3ac..f8a97084 100644 --- a/configs/edgenext/edgenext_xx_small_ascend.yaml +++ b/configs/edgenext/edgenext_xx_small_ascend.yaml @@ -3,6 +3,7 @@ mode: 0 distribute: True num_parallel_workers: 8 val_while_train: True +seed: 1 # dataset dataset: 'imagenet' @@ -58,4 +59,5 @@ filter_bias_and_bn: True momentum: 0.9 weight_decay: 0.05 loss_scale: 1024 +drop_overflow_update: True use_nesterov: False diff --git a/mindcv/utils/train_step.py b/mindcv/utils/train_step.py index a23e427e..e091e8c6 100644 --- a/mindcv/utils/train_step.py +++ b/mindcv/utils/train_step.py @@ -152,6 +152,8 @@ def construct(self, *inputs): # if there is no overflow, do optimize if not overflow: loss = self.gradient_accumulation(loss, grads) + if self.ema: + loss = F.depend(loss, self.ema_update()) else: # apply grad reducer on grads grads = self.grad_reducer(grads) @@ -161,6 +163,8 @@ def construct(self, *inputs): # if there is no overflow, do optimize if not overflow: loss = F.depend(loss, self.optimizer(grads)) + if self.ema: + loss = F.depend(loss, self.ema_update()) else: # scale_sense = loss_scale: Tensor --> TrainOneStepCell.construct if self.accumulate_grad: loss = self.gradient_accumulation(loss, grads) @@ -168,7 +172,7 @@ def construct(self, *inputs): grads = self.grad_reducer(grads) loss = F.depend(loss, self.optimizer(grads)) - if self.ema: - loss = F.depend(loss, self.ema_update()) + if self.ema: + loss = F.depend(loss, self.ema_update()) return loss