From 11ebd4ffe40aa97517a6e8732487dc23bb0b6713 Mon Sep 17 00:00:00 2001 From: li-tingyu <605979840@qq.com> Date: Thu, 7 Mar 2024 16:02:23 +0800 Subject: [PATCH] fix trainonestepbug for 2.3 --- examples/deepspeech2/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/deepspeech2/train.py b/examples/deepspeech2/train.py index 1627a4e..4c01086 100644 --- a/examples/deepspeech2/train.py +++ b/examples/deepspeech2/train.py @@ -6,7 +6,6 @@ from mindspore import ParameterTuple, Tensor, context from mindspore.communication.management import get_group_size, get_rank, init from mindspore.context import ParallelMode -from mindspore.nn import TrainOneStepCell from mindspore.nn.optim import Adam from mindspore.train import Model from mindspore.train.callback import ( @@ -21,6 +20,7 @@ from mindaudio.models.deepspeech2 import DeepSpeechModel from mindaudio.scheduler.scheduler_factory import step_lr from mindaudio.utils.hparams import parse_args +from mindaudio.utils.train_one_step import TrainOneStepWithLossScaleCell def train(args): @@ -61,7 +61,7 @@ def train(args): eps=args.OptimConfig.epsilon, loss_scale=args.OptimConfig.loss_scale, ) - train_net = TrainOneStepCell(loss_net, optimizer) + train_net = TrainOneStepWithLossScaleCell(loss_net, optimizer, Tensor(1024)) train_net.set_train(True) if args.Pretrained_model != "": param_dict = load_checkpoint(args.Pretrained_model)