diff --git a/examples/conformer/predict.py b/examples/conformer/predict.py index c429136..c4b8230 100644 --- a/examples/conformer/predict.py +++ b/examples/conformer/predict.py @@ -5,10 +5,10 @@ import os +import mindspore import numpy as np from asr_model import creadte_asr_model from dataset import create_asr_predict_dataset, load_language_dict -from mindspore import context from mindspore.train.model import Model from mindspore.train.serialization import load_checkpoint, load_param_into_net @@ -44,9 +44,12 @@ def main(): os.makedirs(decode_dir, exist_ok=True) result_file = open(os.path.join(decode_dir, "result.txt"), "w") - context.set_context( - mode=context.GRAPH_MODE, device_target="Ascend", device_id=get_device_id() + mindspore.set_context( + mode=0, + device_target="Ascend", + device_id=get_device_id(), ) + mindspore.set_context(jit_config={"jit_level": "O2"}) # load test data test_dataset = create_asr_predict_dataset( diff --git a/examples/conformer/train.py b/examples/conformer/train.py index 60cf24f..1b20a32 100644 --- a/examples/conformer/train.py +++ b/examples/conformer/train.py @@ -5,9 +5,10 @@ import os +import mindspore from asr_model import creadte_asr_model, create_asr_eval_net from dataset import create_dataset -from mindspore import ParameterTuple, context, set_seed +from mindspore import ParameterTuple, set_seed from mindspore.communication.management import init from mindspore.context import ParallelMode from mindspore.nn.optim import Adam @@ -57,21 +58,22 @@ def train(): model_dir = os.path.join(exp_dir, "model") graph_dir = os.path.join(exp_dir, "graph") summary_dir = os.path.join(exp_dir, "summary") - context.set_context( - mode=context.GRAPH_MODE, + mindspore.set_context( + mode=0, device_target="Ascend", device_id=get_device_id(), save_graphs=config.save_graphs, save_graphs_path=graph_dir, ) + mindspore.set_context(jit_config={"jit_level": "O2"}) device_num = get_device_num() rank = get_rank_id() # configurations for distributed training if config.is_distributed: init() - context.reset_auto_parallel_context() - context.set_auto_parallel_context( + mindspore.reset_auto_parallel_context() + mindspore.set_auto_parallel_context( parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True, device_num=device_num, diff --git a/requirements-dev.txt b/requirements-dev.txt index cc6cae0..6eb672e 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,5 +1,5 @@ mindspore==2.0.0 -numpy>=1.17.0 +numpy>=1.17.0, <2 scipy>=1.6.0 pyyaml>=5.3 tqdm diff --git a/requirements.txt b/requirements.txt index 5296538..121944f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ mindspore==2.0.0 -numpy>=1.17.0 +numpy>=1.17.0, <2 scipy>=1.6.0 pyyaml>=5.3 tqdm