diff --git a/config.py b/config.py
index 87e18751..e0208b07 100644
--- a/config.py
+++ b/config.py
@@ -92,6 +92,9 @@ def create_parser():
                            '"trivialaugwide" for TrivialAugmentWide. '
                            'If apply, recommend for imagenet: randaug-m7-mstd0.5 (default=None).'
                            'Example: "randaug-m10-n2-w0-mstd0.5-mmax10-inc0", "autoaug-mstd0.5" or autoaugr-mstd0.5.')
+    group.add_argument('--aug_splits', type=int, default=0,
+                       help='Number of augmentation splits (default: 0, valid: 3 (currently, only support 3 splits)); '
+                            'it should be set with one auto_augment')
     group.add_argument('--re_prob', type=float, default=0.0,
                        help='Probability of performing erasing (default=0.0)')
     group.add_argument('--re_scale', type=tuple, default=(0.02, 0.33),
diff --git a/mindcv/data/loader.py b/mindcv/data/loader.py
index 69e966bc..5d56c907 100644
--- a/mindcv/data/loader.py
+++ b/mindcv/data/loader.py
@@ -2,8 +2,11 @@
 Create dataloader
 """
+import inspect
 import warnings
 
+import numpy as np
+
 import mindspore as ms
 from mindspore.dataset import transforms
 
@@ -26,6 +29,7 @@ def create_loader(
     target_transform=None,
     num_parallel_workers=None,
     python_multiprocessing=False,
+    separate=False,
 ):
     r"""Creates dataloader.
 
@@ -54,6 +58,7 @@
             (default=None).
         python_multiprocessing (bool, optional): Parallelize Python operations with multiple worker processes.
             This option could be beneficial if the Python operation is computational heavy (default=False).
+        separate (bool, optional): whether to separately return the original image and its augmented copies (default=False).
 
     Note:
         1. cutmix is now experimental (which means performance gain is not guarantee)
@@ -66,20 +71,6 @@
     Returns:
         BatchDataset, dataset batched.
     """
-    if transform is None:
-        warnings.warn(
-            "Using None as the default value of transform will set it back to "
-            "traditional image transform, which is not recommended. "
-            "You should explicitly call `create_transforms` and pass it to `create_loader`."
-        )
-        transform = create_transforms("imagenet", is_training=False)
-    dataset = dataset.map(
-        operations=transform,
-        input_columns="image",
-        num_parallel_workers=num_parallel_workers,
-        python_multiprocessing=python_multiprocessing,
-    )
-
     if target_transform is None:
         target_transform = transforms.TypeCast(ms.int32)
     target_input_columns = "label" if "label" in dataset.get_col_names() else "fine_label"
@@ -90,7 +81,47 @@
         python_multiprocessing=python_multiprocessing,
     )
 
-    dataset = dataset.batch(batch_size=batch_size, drop_remainder=drop_remainder)
+    if transform is None:
+        warnings.warn(
+            "Using None as the default value of transform will set it back to "
+            "traditional image transform, which is not recommended. "
+            "You should explicitly call `create_transforms` and pass it to `create_loader`."
+        )
+        transform = create_transforms("imagenet", is_training=False)
+
+    # only apply augment splits to train dataset
+    if separate and is_training:
+        assert isinstance(transform, tuple) and len(transform) == 3
+
+        # Note: mindspore-2.0 delete the parameter column_order
+        sig = inspect.signature(dataset.map)
+        pass_column_order = False if "kwargs" in sig.parameters else True
+
+        # map all the transform
+        dataset = map_transform_splits(
+            dataset, transform, num_parallel_workers, python_multiprocessing, pass_column_order
+        )
+        # after batch, datasets has 4 columns
+        dataset = dataset.batch(batch_size=batch_size, drop_remainder=drop_remainder)
+        # concat the 3 columns of image
+        dataset = dataset.map(
+            operations=concat_per_batch_map,
+            input_columns=["image_clean", "image_aug1", "image_aug2", "label"],
+            output_columns=["image", "label"],
+            column_order=["image", "label"] if pass_column_order else None,
+            num_parallel_workers=num_parallel_workers,
+            python_multiprocessing=python_multiprocessing,
+        )
+
+    else:
+        dataset = dataset.map(
+            operations=transform,
+            input_columns="image",
+            num_parallel_workers=num_parallel_workers,
+            python_multiprocessing=python_multiprocessing,
+        )
+
+        dataset = dataset.batch(batch_size=batch_size, drop_remainder=drop_remainder)
 
     if is_training:
         if (mixup + cutmix > 0.0) and batch_size > 1:
@@ -113,3 +144,77 @@
         )
 
     return dataset
+
+
+def map_transform_splits(dataset, transform, num_parallel_workers, python_multiprocessing, pass_column_order):
+    # map the primary_tfl such as to all the images
+    dataset = dataset.map(
+        operations=transform[0],
+        input_columns="image",
+        num_parallel_workers=num_parallel_workers,
+        python_multiprocessing=python_multiprocessing,
+    )
+
+    # duplicate the columns 'image' twice for the auto_augmentation
+    dataset = dataset.map(
+        operations=transforms.Duplicate(),
+        input_columns=["image"],
+        output_columns=["image_clean", "image_aug2"],
+        column_order=["image_clean", "image_aug2", "label"] if pass_column_order else None,
+        num_parallel_workers=num_parallel_workers,
+        python_multiprocessing=python_multiprocessing,
+    )
+    dataset = dataset.map(
+        operations=transforms.Duplicate(),
+        input_columns=["image_clean"],
+        output_columns=["image_clean", "image_aug1"],
+        column_order=["image_clean", "image_aug1", "image_aug2", "label"] if pass_column_order else None,
+        num_parallel_workers=num_parallel_workers,
+        python_multiprocessing=python_multiprocessing,
+    )
+
+    # map the secondary_tfl (auto_augmentation for the image_aug1 and img_aug2)
+    dataset = dataset.map(
+        operations=transform[1],
+        input_columns="image_aug1",
+        num_parallel_workers=num_parallel_workers,
+        python_multiprocessing=python_multiprocessing,
+    )
+
+    dataset = dataset.map(
+        operations=transform[1],
+        input_columns="image_aug2",
+        num_parallel_workers=num_parallel_workers,
+        python_multiprocessing=python_multiprocessing,
+    )
+
+    # map the final_tfl to all the images
+
+    dataset = dataset.map(
+        operations=transform[2],
+        input_columns="image_clean",
+        num_parallel_workers=num_parallel_workers,
+        python_multiprocessing=python_multiprocessing,
+    )
+
+    dataset = dataset.map(
+        operations=transform[2],
+        input_columns="image_aug1",
+        num_parallel_workers=num_parallel_workers,
+        python_multiprocessing=python_multiprocessing,
+    )
+
+    dataset = dataset.map(
+        operations=transform[2],
+        input_columns="image_aug2",
+        num_parallel_workers=num_parallel_workers,
+        python_multiprocessing=python_multiprocessing,
+    )
+
+    return dataset
+
+
+def concat_per_batch_map(image_clean, image_aug1, image_aug2, label):
+    image = np.concatenate((image_clean, image_aug1, image_aug2))
+    label = np.concatenate((label, label, label))
+    return image, label
diff --git a/mindcv/data/transforms_factory.py b/mindcv/data/transforms_factory.py
index 85ddf0e7..e30e21cb 100644
--- a/mindcv/data/transforms_factory.py
+++ b/mindcv/data/transforms_factory.py
@@ -36,6 +36,7 @@ def transforms_imagenet_train(
     re_ratio=(0.3, 3.3),
     re_value=0,
     re_max_attempts=10,
+    separate=False,
 ):
     """Transform operation list when training on ImageNet."""
     # Define map operations for training dataset
@@ -44,7 +45,7 @@ def transforms_imagenet_train(
     else:
         interpolation = Inter.BILINEAR
 
-    trans_list = [
+    primary_tfl = [
         vision.RandomCropDecodeResize(
             size=image_resize,
             scale=scale,
@@ -53,10 +54,11 @@ def transforms_imagenet_train(
         )
     ]
     if hflip > 0.0:
-        trans_list += [vision.RandomHorizontalFlip(prob=hflip)]
+        primary_tfl += [vision.RandomHorizontalFlip(prob=hflip)]
     if vflip > 0.0:
-        trans_list += [vision.RandomVerticalFlip(prob=vflip)]
+        primary_tfl += [vision.RandomVerticalFlip(prob=vflip)]
 
+    secondary_tfl = []
     if auto_augment is not None:
         assert isinstance(auto_augment, str)
         if isinstance(image_resize, (tuple, list)):
@@ -69,14 +71,14 @@ def transforms_imagenet_train(
         )
         augement_params["interpolation"] = interpolation
         if auto_augment.startswith("randaug"):
-            trans_list += [rand_augment_transform(auto_augment, augement_params)]
+            secondary_tfl += [rand_augment_transform(auto_augment, augement_params)]
         elif auto_augment.startswith("autoaug") or auto_augment.startswith("3a"):
-            trans_list += [auto_augment_transform(auto_augment, augement_params)]
+            secondary_tfl += [auto_augment_transform(auto_augment, augement_params)]
         elif auto_augment.startswith("trivialaugwide"):
-            trans_list += [trivial_augment_wide_transform(auto_augment, augement_params)]
+            secondary_tfl += [trivial_augment_wide_transform(auto_augment, augement_params)]
         elif auto_augment.startswith("augmix"):
             augement_params["translate_pct"] = 0.3
-            trans_list += [augment_and_mix_transform(auto_augment, augement_params)]
+            secondary_tfl += [augment_and_mix_transform(auto_augment, augement_params)]
         else:
             assert False, "Unknown auto augment policy (%s)" % auto_augment
     elif color_jitter is not None:
@@ -86,14 +88,15 @@ def transforms_imagenet_train(
             assert len(color_jitter) in (3, 4)
         else:
             color_jitter = (float(color_jitter),) * 3
-        trans_list += [vision.RandomColorAdjust(*color_jitter)]
+        secondary_tfl += [vision.RandomColorAdjust(*color_jitter)]
 
-    trans_list += [
+    final_tfl = []
+    final_tfl += [
         vision.Normalize(mean=mean, std=std),
         vision.HWC2CHW(),
     ]
     if re_prob > 0.0:
-        trans_list.append(
+        final_tfl.append(
             vision.RandomErasing(
                 prob=re_prob,
                 scale=re_scale,
@@ -103,7 +106,9 @@ def transforms_imagenet_train(
             )
         )
 
-    return trans_list
+    if separate:
+        return primary_tfl, secondary_tfl, final_tfl
+    return primary_tfl + secondary_tfl + final_tfl
 
 
 def transforms_imagenet_eval(
@@ -179,6 +184,7 @@ def create_transforms(
     image_resize=224,
     is_training=False,
     auto_augment=None,
+    separate=False,
     **kwargs,
 ):
     r"""Creates a list of transform operation on image data.
@@ -189,6 +195,8 @@ def create_transforms(
             Default: ''.
         image_resize (int): the image size after resize for adapting to network. Default: 224.
         is_training (bool): if True, augmentation will be applied if support. Default: False.
+        auto_augment (str): augmentation strategies, such as "augmix", "autoaug" etc.
+        separate (bool): separate the original image and the transformed images.
         **kwargs: additional args parsed to `transforms_imagenet_train` and `transforms_imagenet_eval`
 
     Returns:
@@ -200,7 +208,7 @@ def create_transforms(
     if dataset_name in ("imagenet", ""):
         trans_args = dict(image_resize=image_resize, **kwargs)
         if is_training:
-            return transforms_imagenet_train(auto_augment=auto_augment, **trans_args)
+            return transforms_imagenet_train(auto_augment=auto_augment, separate=separate, **trans_args)
         return transforms_imagenet_eval(**trans_args)
 
     elif dataset_name in ("cifar10", "cifar100"):
diff --git a/tests/modules/test_loader.py b/tests/modules/test_loader.py
index 764b3b19..4ba30139 100644
--- a/tests/modules/test_loader.py
+++ b/tests/modules/test_loader.py
@@ -11,7 +11,7 @@
 from mindcv.data import create_dataset, create_loader, get_dataset_download_root
 from mindcv.utils.download import DownLoad
 
-num_classes = 1
+num_classes = 2
 
 
 @pytest.mark.parametrize("mode", [0, 1])
diff --git a/tests/modules/test_transforms.py b/tests/modules/test_transforms.py
index e10070f2..5f5936d5 100644
--- a/tests/modules/test_transforms.py
+++ b/tests/modules/test_transforms.py
@@ -18,7 +18,9 @@
 @pytest.mark.parametrize("image_resize", [224, 256])
 @pytest.mark.parametrize("is_training", [True, False])
 @pytest.mark.parametrize("auto_augment", [None, "autoaug", "autoaugr", "3a", "randaug", "augmix", "trivialaugwide"])
-def test_transforms_standalone_imagenet(mode, name, image_resize, is_training, auto_augment):
+@pytest.mark.parametrize("batch_size", [32, 64, 128])
+@pytest.mark.parametrize("aug_splits", [0, 3])
+def test_transforms_standalone_imagenet(mode, name, image_resize, is_training, auto_augment, batch_size, aug_splits):
     """
     test transform_list API(distribute)
     command: pytest -s test_transforms.py::test_transforms_standalone_imagenet
@@ -27,6 +29,7 @@ def test_transforms_standalone_imagenet(mode, name, image_resize, is_training, a
         dataset_name='',
         image_resize=224,
         is_training=False,
+        auto_augment=None
         **kwargs
     """
     ms.set_context(mode=mode)
@@ -49,25 +52,35 @@ def test_transforms_standalone_imagenet(mode, name, image_resize, is_training, a
         download=False,
     )
 
+    num_aug_splits = 0
+    if aug_splits > 0 and auto_augment is not None:
+        assert aug_splits == 3, "Currently, only support 3 splits of augmentation"
+        num_aug_splits = aug_splits
+
     # create transforms
     transform_list = create_transforms(
         dataset_name=name,
         image_resize=image_resize,
         is_training=is_training,
         auto_augment=auto_augment,
+        separate=num_aug_splits > 0,
     )
 
     # load dataset
     loader = create_loader(
         dataset=dataset,
-        batch_size=32,
+        batch_size=batch_size,
         drop_remainder=True,
         is_training=is_training,
         transform=transform_list,
         num_parallel_workers=2,
+        separate=num_aug_splits > 0,
     )
 
-    assert loader.output_shapes()[0][2] == image_resize, "image_resize error !"
+    output_shape = loader.output_shapes()
+    assert output_shape[0][2] == image_resize, "image_resize error !"
+    if num_aug_splits == 3 and is_training:
+        assert output_shape[0][0] == 3 * batch_size and output_shape[1][0] == 3 * batch_size, "augment splits error!"
 
 
 # test mnist cifar10
diff --git a/train.py b/train.py
index 80a0db6f..ed22748d 100644
--- a/train.py
+++ b/train.py
@@ -71,6 +71,12 @@ def train(args):
     num_classes = args.num_classes
 
     # create transforms
+    num_aug_splits = 0
+    if args.aug_splits > 0:
+        assert args.aug_splits == 3, "Currently, only support 3 splits of augmentation"
+        assert args.auto_augment is not None, "aug_splits should be set with one auto_augment"
+        num_aug_splits = args.aug_splits
+
     transform_list = create_transforms(
         dataset_name=args.dataset,
         is_training=True,
@@ -89,6 +95,7 @@ def train(args):
         re_ratio=args.re_ratio,
         re_value=args.re_value,
         re_max_attempts=args.re_max_attempts,
+        separate=num_aug_splits > 0,
     )
 
     # load dataset
@@ -103,6 +110,7 @@ def train(args):
         num_classes=num_classes,
         transform=transform_list,
         num_parallel_workers=args.num_parallel_workers,
+        separate=num_aug_splits > 0,
     )
 
     if args.val_while_train: