Skip to content

Commit

Permalink
Merge pull request #187 from daiyuxin0511/main
Browse files Browse the repository at this point in the history
update jit level
  • Loading branch information
LiTingyu1997 authored Jun 25, 2024
2 parents a53326b + 22952cb commit 21f7422
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 10 deletions.
9 changes: 6 additions & 3 deletions examples/conformer/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@

import os

+import mindspore
import numpy as np
from asr_model import creadte_asr_model
from dataset import create_asr_predict_dataset, load_language_dict
-from mindspore import context
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net

Expand Down Expand Up @@ -44,9 +44,12 @@ def main():
os.makedirs(decode_dir, exist_ok=True)
result_file = open(os.path.join(decode_dir, "result.txt"), "w")

-context.set_context(
-mode=context.GRAPH_MODE, device_target="Ascend", device_id=get_device_id()
+mindspore.set_context(
+mode=0,
+device_target="Ascend",
+device_id=get_device_id(),
)
+mindspore.set_context(jit_config={"jit_level": "O2"})

# load test data
test_dataset = create_asr_predict_dataset(
Expand Down
12 changes: 7 additions & 5 deletions examples/conformer/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@

import os

+import mindspore
from asr_model import creadte_asr_model, create_asr_eval_net
from dataset import create_dataset
-from mindspore import ParameterTuple, context, set_seed
+from mindspore import ParameterTuple, set_seed
from mindspore.communication.management import init
from mindspore.context import ParallelMode
from mindspore.nn.optim import Adam
Expand Down Expand Up @@ -57,21 +58,22 @@ def train():
model_dir = os.path.join(exp_dir, "model")
graph_dir = os.path.join(exp_dir, "graph")
summary_dir = os.path.join(exp_dir, "summary")
-context.set_context(
-mode=context.GRAPH_MODE,
+mindspore.set_context(
+mode=0,
device_target="Ascend",
device_id=get_device_id(),
save_graphs=config.save_graphs,
save_graphs_path=graph_dir,
)
+mindspore.set_context(jit_config={"jit_level": "O2"})

device_num = get_device_num()
rank = get_rank_id()
# configurations for distributed training
if config.is_distributed:
init()
-context.reset_auto_parallel_context()
-context.set_auto_parallel_context(
+mindspore.reset_auto_parallel_context()
+mindspore.set_auto_parallel_context(
parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True,
device_num=device_num,
Expand Down
2 changes: 1 addition & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
mindspore==2.0.0
-numpy>=1.17.0
+numpy>=1.17.0, <2
scipy>=1.6.0
pyyaml>=5.3
tqdm
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
mindspore==2.0.0
-numpy>=1.17.0
+numpy>=1.17.0, <2
scipy>=1.6.0
pyyaml>=5.3
tqdm
Expand Down

0 comments on commit 21f7422

Please sign in to comment.