From 46e5866cb0c7a6e417f6b35b09d17a1f36c4aac9 Mon Sep 17 00:00:00 2001 From: IASZHT Date: Thu, 25 Jul 2024 17:36:45 +0800 Subject: [PATCH] feat: add device_target config and set default to Ascend. --- config.py | 6 +++++- validate.py | 8 ++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/config.py b/config.py index b2eeeaa1..c2b9084b 100644 --- a/config.py +++ b/config.py @@ -43,6 +43,8 @@ def create_parser(): help='Interval for print training log. Unit: step (default=100)') group.add_argument('--seed', type=int, default=42, help='Seed value for determining randomness in numpy, random, and mindspore (default=42)') + group.add_argument('--device_target', type=str, default='Ascend', + help='Device target for validating, which can be Ascend, GPU or CPU. (default=Ascend)') # Dataset parameters group = parser.add_argument_group('Dataset parameters') @@ -94,7 +96,7 @@ def create_parser(): 'Example: "randaug-m10-n2-w0-mstd0.5-mmax10-inc0", "autoaug-mstd0.5" or autoaugr-mstd0.5.') group.add_argument('--aug_splits', type=int, default=0, help='Number of augmentation splits (default: 0, valid: 3 (currently, only support 3 splits))' - 'it should be set with one auto_augment') + 'it should be set with one auto_augment') group.add_argument('--re_prob', type=float, default=0.0, help='Probability of performing erasing (default=0.0)') group.add_argument('--re_scale', type=tuple, default=(0.02, 0.33), @@ -267,6 +269,8 @@ def create_parser(): help='Whether to shuffle the evaluation data (default=False)') return parser_config, parser + + # fmt: on diff --git a/validate.py b/validate.py index fc428577..4c7a67fa 100644 --- a/validate.py +++ b/validate.py @@ -26,6 +26,14 @@ def check_batch_size(num_samples, ori_batch_size=32, refine=True): def validate(args): + try: + ms.set_context(device_target=args.device_target) + except Exception as e: + raise e + print( + "Please check whether the Ascend environment is installed and configured correctly. Now the process will use the CPU to perform calculations.") + ms.set_context(device_target="CPU") + ms.set_context(mode=args.mode) if args.mode == ms.GRAPH_MODE: ms.set_context(jit_config={"jit_level": "O2"})