Skip to content

Commit

Permalink
fix trainonestepbug for 2.3
Browse files Browse the repository at this point in the history
  • Loading branch information
LiTingyu1997 committed Mar 7, 2024
1 parent 6fd25c3 commit 11ebd4f
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions examples/deepspeech2/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from mindspore import ParameterTuple, Tensor, context
from mindspore.communication.management import get_group_size, get_rank, init
from mindspore.context import ParallelMode
from mindspore.nn import TrainOneStepCell
from mindspore.nn.optim import Adam
from mindspore.train import Model
from mindspore.train.callback import (
Expand All @@ -21,6 +20,7 @@
from mindaudio.models.deepspeech2 import DeepSpeechModel
from mindaudio.scheduler.scheduler_factory import step_lr
from mindaudio.utils.hparams import parse_args
from mindaudio.utils.train_one_step import TrainOneStepWithLossScaleCell


def train(args):
Expand Down Expand Up @@ -61,7 +61,7 @@ def train(args):
eps=args.OptimConfig.epsilon,
loss_scale=args.OptimConfig.loss_scale,
)
train_net = TrainOneStepCell(loss_net, optimizer)
train_net = TrainOneStepWithLossScaleCell(loss_net, optimizer, Tensor(1024))
train_net.set_train(True)
if args.Pretrained_model != "":
param_dict = load_checkpoint(args.Pretrained_model)
Expand Down

0 comments on commit 11ebd4f

Please sign in to comment.