diff --git a/config.py b/config.py index b2eeeaa1..4d8390a1 100644 --- a/config.py +++ b/config.py @@ -43,6 +43,8 @@ def create_parser(): help='Interval for print training log. Unit: step (default=100)') group.add_argument('--seed', type=int, default=42, help='Seed value for determining randomness in numpy, random, and mindspore (default=42)') + group.add_argument('--device_target', type=str, default='Ascend', + help='Device target for computing, which can be Ascend, GPU or CPU. (default=Ascend)') # Dataset parameters group = parser.add_argument_group('Dataset parameters') @@ -94,7 +96,7 @@ def create_parser(): 'Example: "randaug-m10-n2-w0-mstd0.5-mmax10-inc0", "autoaug-mstd0.5" or autoaugr-mstd0.5.') group.add_argument('--aug_splits', type=int, default=0, help='Number of augmentation splits (default: 0, valid: 3 (currently, only support 3 splits))' - 'it should be set with one auto_augment') + 'it should be set with one auto_augment') group.add_argument('--re_prob', type=float, default=0.0, help='Probability of performing erasing (default=0.0)') group.add_argument('--re_scale', type=tuple, default=(0.02, 0.33), @@ -267,6 +269,8 @@ def create_parser(): help='Whether to shuffle the evaluation data (default=False)') return parser_config, parser + + # fmt: on diff --git a/tests/tasks/test_train_val_imagenet_subset.py b/tests/tasks/test_train_val_imagenet_subset.py index 2a2163e8..63b5a269 100644 --- a/tests/tasks/test_train_val_imagenet_subset.py +++ b/tests/tasks/test_train_val_imagenet_subset.py @@ -30,6 +30,7 @@ def test_train(mode, val_while_train, model="resnet18"): DownLoad().download_and_extract_archive(dataset_url, root_dir) # ---------------- test running train.py using the toy data --------- + device_target = "CPU" dataset = "imagenet" num_classes = 2 ckpt_dir = "./tests/ckpt_tmp" @@ -48,7 +49,8 @@ def test_train(mode, val_while_train, model="resnet18"): f"python {train_file} --dataset={dataset} --num_classes={num_classes} --model={model} " f"--epoch_size={num_epochs} --ckpt_save_interval=2 --lr=0.0001 --num_samples={num_samples} --loss=CE " f"--weight_decay=1e-6 --ckpt_save_dir={ckpt_dir} {download_str} --train_split=train --batch_size={batch_size} " - f"--pretrained --num_parallel_workers=2 --val_while_train={val_while_train} --val_split=val --val_interval=1" + f"--pretrained --num_parallel_workers=2 --val_while_train={val_while_train} --val_split=val --val_interval=1 " + f"--device_target={device_target}" ) print(f"Running command: \n{cmd}") @@ -57,10 +59,11 @@ def test_train(mode, val_while_train, model="resnet18"): # --------- Test running validate.py using the trained model ------------- # # begin_ckpt = os.path.join(ckpt_dir, f'{model}-1_1.ckpt') - end_ckpt = os.path.join(ckpt_dir, f"{model}-{num_epochs}_{num_samples//batch_size}.ckpt") + end_ckpt = os.path.join(ckpt_dir, f"{model}-{num_epochs}_{num_samples // batch_size}.ckpt") cmd = ( f"python validate.py --model={model} --dataset={dataset} --val_split=val --data_dir={data_dir} " - f"--num_classes={num_classes} --ckpt_path={end_ckpt} --batch_size=40 --num_parallel_workers=2" + f"--num_classes={num_classes} --ckpt_path={end_ckpt} --batch_size=40 --num_parallel_workers=2 " + f"--device_target={device_target}" ) # ret = subprocess.call(cmd.split(), stdout=sys.stdout, stderr=sys.stderr) print(f"Running command: \n{cmd}") diff --git a/validate.py b/validate.py index fc428577..a4d33486 100644 --- a/validate.py +++ b/validate.py @@ -26,9 +26,10 @@ def check_batch_size(num_samples, ori_batch_size=32, refine=True): def validate(args): + ms.set_context(device_target=args.device_target) ms.set_context(mode=args.mode) - if args.mode == ms.GRAPH_MODE: - ms.set_context(jit_config={"jit_level": "O2"}) + # if args.mode == ms.GRAPH_MODE: + # ms.set_context(jit_config={"jit_level": "O2"}) # create dataset dataset_eval = create_dataset(