From dac2427fcf02cbe77368fed13d4cbd512bb474d2 Mon Sep 17 00:00:00 2001 From: litingyu <67683219+LiTingyu1997@users.noreply.github.com> Date: Tue, 19 Mar 2024 17:08:31 +0800 Subject: [PATCH] fix trainonestepbug for 2.3 (#183) --- examples/deepspeech2/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/deepspeech2/train.py b/examples/deepspeech2/train.py index 1627a4e..4c01086 100644 --- a/examples/deepspeech2/train.py +++ b/examples/deepspeech2/train.py @@ -6,7 +6,6 @@ from mindspore import ParameterTuple, Tensor, context from mindspore.communication.management import get_group_size, get_rank, init from mindspore.context import ParallelMode -from mindspore.nn import TrainOneStepCell from mindspore.nn.optim import Adam from mindspore.train import Model from mindspore.train.callback import ( @@ -21,6 +20,7 @@ from mindaudio.models.deepspeech2 import DeepSpeechModel from mindaudio.scheduler.scheduler_factory import step_lr from mindaudio.utils.hparams import parse_args +from mindaudio.utils.train_one_step import TrainOneStepWithLossScaleCell def train(args): @@ -61,7 +61,7 @@ def train(args): eps=args.OptimConfig.epsilon, loss_scale=args.OptimConfig.loss_scale, ) - train_net = TrainOneStepCell(loss_net, optimizer) + train_net = TrainOneStepWithLossScaleCell(loss_net, optimizer, Tensor(1024)) train_net.set_train(True) if args.Pretrained_model != "": param_dict = load_checkpoint(args.Pretrained_model)