Skip to content

Commit

Permalink
feat: supports the augmentations splits (mindspore-lab#658)
Browse files Browse the repository at this point in the history
  • Loading branch information
sageyou authored Jul 10, 2023
1 parent e32643a commit 2f24f90
Show file tree
Hide file tree
Showing 6 changed files with 168 additions and 31 deletions.
3 changes: 3 additions & 0 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,9 @@ def create_parser():
'"trivialaugwide" for TrivialAugmentWide. '
'If apply, recommend for imagenet: randaug-m7-mstd0.5 (default=None).'
'Example: "randaug-m10-n2-w0-mstd0.5-mmax10-inc0", "autoaug-mstd0.5" or autoaugr-mstd0.5.')
group.add_argument('--aug_splits', type=int, default=0,
help='Number of augmentation splits (default: 0, valid: 3 (currently, only support 3 splits))'
'it should be set with one auto_augment')
group.add_argument('--re_prob', type=float, default=0.0,
help='Probability of performing erasing (default=0.0)')
group.add_argument('--re_scale', type=tuple, default=(0.02, 0.33),
Expand Down
135 changes: 120 additions & 15 deletions mindcv/data/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@
Create dataloader
"""

import inspect
import warnings

import numpy as np

import mindspore as ms
from mindspore.dataset import transforms

Expand All @@ -26,6 +29,7 @@ def create_loader(
target_transform=None,
num_parallel_workers=None,
python_multiprocessing=False,
separate=False,
):
r"""Creates dataloader.
Expand Down Expand Up @@ -54,6 +58,7 @@ def create_loader(
(default=None).
python_multiprocessing (bool, optional): Parallelize Python operations with multiple worker processes. This
option could be beneficial if the Python operation is computational heavy (default=False).
separate(bool, optional): separate the image origin and the image been transformed
Note:
1. cutmix is now experimental (which means performance gain is not guarantee)
Expand All @@ -66,20 +71,6 @@ def create_loader(
BatchDataset, dataset batched.
"""

if transform is None:
warnings.warn(
"Using None as the default value of transform will set it back to "
"traditional image transform, which is not recommended. "
"You should explicitly call `create_transforms` and pass it to `create_loader`."
)
transform = create_transforms("imagenet", is_training=False)
dataset = dataset.map(
operations=transform,
input_columns="image",
num_parallel_workers=num_parallel_workers,
python_multiprocessing=python_multiprocessing,
)

if target_transform is None:
target_transform = transforms.TypeCast(ms.int32)
target_input_columns = "label" if "label" in dataset.get_col_names() else "fine_label"
Expand All @@ -90,7 +81,47 @@ def create_loader(
python_multiprocessing=python_multiprocessing,
)

dataset = dataset.batch(batch_size=batch_size, drop_remainder=drop_remainder)
if transform is None:
warnings.warn(
"Using None as the default value of transform will set it back to "
"traditional image transform, which is not recommended. "
"You should explicitly call `create_transforms` and pass it to `create_loader`."
)
transform = create_transforms("imagenet", is_training=False)

# only apply augment splits to train dataset
if separate and is_training:
assert isinstance(transform, tuple) and len(transform) == 3

# Note: mindspore-2.0 delete the parameter column_order
sig = inspect.signature(dataset.map)
pass_column_order = False if "kwargs" in sig.parameters else True

# map all the transform
dataset = map_transform_splits(
dataset, transform, num_parallel_workers, python_multiprocessing, pass_column_order
)
# after batch, datasets has 4 columns
dataset = dataset.batch(batch_size=batch_size, drop_remainder=drop_remainder)
# concat the 3 columns of image
dataset = dataset.map(
operations=concat_per_batch_map,
input_columns=["image_clean", "image_aug1", "image_aug2", "label"],
output_columns=["image", "label"],
column_order=["image", "label"] if pass_column_order else None,
num_parallel_workers=num_parallel_workers,
python_multiprocessing=python_multiprocessing,
)

else:
dataset = dataset.map(
operations=transform,
input_columns="image",
num_parallel_workers=num_parallel_workers,
python_multiprocessing=python_multiprocessing,
)

dataset = dataset.batch(batch_size=batch_size, drop_remainder=drop_remainder)

if is_training:
if (mixup + cutmix > 0.0) and batch_size > 1:
Expand All @@ -113,3 +144,77 @@ def create_loader(
)

return dataset


def map_transform_splits(dataset, transform, num_parallel_workers, python_multiprocessing, pass_column_order):
    """Apply a 3-way augmentation-split pipeline to ``dataset``.

    ``transform`` is a 3-tuple ``(primary_tfl, secondary_tfl, final_tfl)``:
    the primary transforms run on the original ``image`` column, the image is
    then duplicated into three columns (``image_clean``, ``image_aug1``,
    ``image_aug2``), the secondary (auto-augment) transforms run only on the
    two augmented copies, and the final transforms run on all three copies.

    Args:
        dataset: MindSpore dataset whose columns include ``image`` and ``label``.
        transform: tuple of three transform lists as produced with ``separate=True``.
        num_parallel_workers: forwarded to every ``dataset.map`` call.
        python_multiprocessing: forwarded to every ``dataset.map`` call.
        pass_column_order: whether ``map`` still accepts ``column_order``
            (removed in MindSpore 2.0; detected by the caller via signature probing).

    Returns:
        The dataset with columns ``image_clean``, ``image_aug1``, ``image_aug2``, ``label``.
    """
    # Keyword arguments shared by every map call below.
    common = dict(
        num_parallel_workers=num_parallel_workers,
        python_multiprocessing=python_multiprocessing,
    )

    # Primary transforms (decode/resize/flip) applied once, before duplication.
    dataset = dataset.map(operations=transform[0], input_columns="image", **common)

    # Duplicate the image column twice: "image" -> image_clean + image_aug2,
    # then image_clean -> image_clean + image_aug1.
    # NOTE(review): when pass_column_order is False we still pass
    # column_order=None; this relies on the newer `map` signature accepting
    # it via **kwargs — confirm against the supported MindSpore versions.
    dataset = dataset.map(
        operations=transforms.Duplicate(),
        input_columns=["image"],
        output_columns=["image_clean", "image_aug2"],
        column_order=["image_clean", "image_aug2", "label"] if pass_column_order else None,
        **common,
    )
    dataset = dataset.map(
        operations=transforms.Duplicate(),
        input_columns=["image_clean"],
        output_columns=["image_clean", "image_aug1"],
        column_order=["image_clean", "image_aug1", "image_aug2", "label"] if pass_column_order else None,
        **common,
    )

    # Secondary (auto-augment) transforms: augmented copies only.
    for aug_column in ("image_aug1", "image_aug2"):
        dataset = dataset.map(operations=transform[1], input_columns=aug_column, **common)

    # Final transforms (normalize, HWC2CHW, optional erasing): all three copies.
    for image_column in ("image_clean", "image_aug1", "image_aug2"):
        dataset = dataset.map(operations=transform[2], input_columns=image_column, **common)

    return dataset


def concat_per_batch_map(image_clean, image_aug1, image_aug2, label):
    """Merge the three augmentation splits back into a single batch.

    Stacks the clean and two augmented image batches along the batch axis
    (tripling the effective batch size) and repeats the label batch three
    times so every image copy keeps its original label.

    Args:
        image_clean: batch of images with only primary+final transforms applied.
        image_aug1: first auto-augmented batch.
        image_aug2: second auto-augmented batch.
        label: label batch shared by all three image batches.

    Returns:
        Tuple ``(image, label)`` with leading dimension 3x the input batch size.
    """
    splits = (image_clean, image_aug1, image_aug2)
    merged_image = np.concatenate(splits)
    merged_label = np.concatenate([label] * len(splits))
    return merged_image, merged_label
32 changes: 20 additions & 12 deletions mindcv/data/transforms_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def transforms_imagenet_train(
re_ratio=(0.3, 3.3),
re_value=0,
re_max_attempts=10,
separate=False,
):
"""Transform operation list when training on ImageNet."""
# Define map operations for training dataset
Expand All @@ -44,7 +45,7 @@ def transforms_imagenet_train(
else:
interpolation = Inter.BILINEAR

trans_list = [
primary_tfl = [
vision.RandomCropDecodeResize(
size=image_resize,
scale=scale,
Expand All @@ -53,10 +54,11 @@ def transforms_imagenet_train(
)
]
if hflip > 0.0:
trans_list += [vision.RandomHorizontalFlip(prob=hflip)]
primary_tfl += [vision.RandomHorizontalFlip(prob=hflip)]
if vflip > 0.0:
trans_list += [vision.RandomVerticalFlip(prob=vflip)]
primary_tfl += [vision.RandomVerticalFlip(prob=vflip)]

secondary_tfl = []
if auto_augment is not None:
assert isinstance(auto_augment, str)
if isinstance(image_resize, (tuple, list)):
Expand All @@ -69,14 +71,14 @@ def transforms_imagenet_train(
)
augement_params["interpolation"] = interpolation
if auto_augment.startswith("randaug"):
trans_list += [rand_augment_transform(auto_augment, augement_params)]
secondary_tfl += [rand_augment_transform(auto_augment, augement_params)]
elif auto_augment.startswith("autoaug") or auto_augment.startswith("3a"):
trans_list += [auto_augment_transform(auto_augment, augement_params)]
secondary_tfl += [auto_augment_transform(auto_augment, augement_params)]
elif auto_augment.startswith("trivialaugwide"):
trans_list += [trivial_augment_wide_transform(auto_augment, augement_params)]
secondary_tfl += [trivial_augment_wide_transform(auto_augment, augement_params)]
elif auto_augment.startswith("augmix"):
augement_params["translate_pct"] = 0.3
trans_list += [augment_and_mix_transform(auto_augment, augement_params)]
secondary_tfl += [augment_and_mix_transform(auto_augment, augement_params)]
else:
assert False, "Unknown auto augment policy (%s)" % auto_augment
elif color_jitter is not None:
Expand All @@ -86,14 +88,15 @@ def transforms_imagenet_train(
assert len(color_jitter) in (3, 4)
else:
color_jitter = (float(color_jitter),) * 3
trans_list += [vision.RandomColorAdjust(*color_jitter)]
secondary_tfl += [vision.RandomColorAdjust(*color_jitter)]

trans_list += [
final_tfl = []
final_tfl += [
vision.Normalize(mean=mean, std=std),
vision.HWC2CHW(),
]
if re_prob > 0.0:
trans_list.append(
final_tfl.append(
vision.RandomErasing(
prob=re_prob,
scale=re_scale,
Expand All @@ -103,7 +106,9 @@ def transforms_imagenet_train(
)
)

return trans_list
if separate:
return primary_tfl, secondary_tfl, final_tfl
return primary_tfl + secondary_tfl + final_tfl


def transforms_imagenet_eval(
Expand Down Expand Up @@ -179,6 +184,7 @@ def create_transforms(
image_resize=224,
is_training=False,
auto_augment=None,
separate=False,
**kwargs,
):
r"""Creates a list of transform operation on image data.
Expand All @@ -189,6 +195,8 @@ def create_transforms(
Default: ''.
image_resize (int): the image size after resize for adapting to network. Default: 224.
is_training (bool): if True, augmentation will be applied if support. Default: False.
auto_augment(str):augmentation strategies, such as "augmix", "autoaug" etc.
separate: separate the image origin and the image been transformed.
**kwargs: additional args parsed to `transforms_imagenet_train` and `transforms_imagenet_eval`
Returns:
Expand All @@ -200,7 +208,7 @@ def create_transforms(
if dataset_name in ("imagenet", ""):
trans_args = dict(image_resize=image_resize, **kwargs)
if is_training:
return transforms_imagenet_train(auto_augment=auto_augment, **trans_args)
return transforms_imagenet_train(auto_augment=auto_augment, separate=separate, **trans_args)

return transforms_imagenet_eval(**trans_args)
elif dataset_name in ("cifar10", "cifar100"):
Expand Down
2 changes: 1 addition & 1 deletion tests/modules/test_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from mindcv.data import create_dataset, create_loader, get_dataset_download_root
from mindcv.utils.download import DownLoad

num_classes = 1
num_classes = 2


@pytest.mark.parametrize("mode", [0, 1])
Expand Down
19 changes: 16 additions & 3 deletions tests/modules/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
@pytest.mark.parametrize("image_resize", [224, 256])
@pytest.mark.parametrize("is_training", [True, False])
@pytest.mark.parametrize("auto_augment", [None, "autoaug", "autoaugr", "3a", "randaug", "augmix", "trivialaugwide"])
def test_transforms_standalone_imagenet(mode, name, image_resize, is_training, auto_augment):
@pytest.mark.parametrize("batch_size", [32, 64, 128])
@pytest.mark.parametrize("aug_splits", [0, 3])
def test_transforms_standalone_imagenet(mode, name, image_resize, is_training, auto_augment, batch_size, aug_splits):
"""
test transform_list API(distribute)
command: pytest -s test_transforms.py::test_transforms_standalone_imagenet
Expand All @@ -27,6 +29,7 @@ def test_transforms_standalone_imagenet(mode, name, image_resize, is_training, a
dataset_name='',
image_resize=224,
is_training=False,
auto_augment=None
**kwargs
"""
ms.set_context(mode=mode)
Expand All @@ -49,25 +52,35 @@ def test_transforms_standalone_imagenet(mode, name, image_resize, is_training, a
download=False,
)

num_aug_splits = 0
if aug_splits > 0 and auto_augment is not None:
assert aug_splits == 3, "Currently, only support 3 splits of augmentation"
num_aug_splits = aug_splits

# create transforms
transform_list = create_transforms(
dataset_name=name,
image_resize=image_resize,
is_training=is_training,
auto_augment=auto_augment,
separate=num_aug_splits > 0,
)

# load dataset
loader = create_loader(
dataset=dataset,
batch_size=32,
batch_size=batch_size,
drop_remainder=True,
is_training=is_training,
transform=transform_list,
num_parallel_workers=2,
separate=num_aug_splits > 0,
)

assert loader.output_shapes()[0][2] == image_resize, "image_resize error !"
output_shape = loader.output_shapes()
assert output_shape[0][2] == image_resize, "image_resize error !"
if num_aug_splits == 3 and is_training:
assert output_shape[0][0] == 3 * batch_size and output_shape[1][0] == 3 * batch_size, "augment splits error!"


# test mnist cifar10
Expand Down
8 changes: 8 additions & 0 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,12 @@ def train(args):
num_classes = args.num_classes

# create transforms
num_aug_splits = 0
if args.aug_splits > 0:
assert args.aug_splits == 3, "Currently, only support 3 splits of augmentation"
assert args.auto_augment is not None, "aug_splits should be set with one auto_augment"
num_aug_splits = args.aug_splits

transform_list = create_transforms(
dataset_name=args.dataset,
is_training=True,
Expand All @@ -89,6 +95,7 @@ def train(args):
re_ratio=args.re_ratio,
re_value=args.re_value,
re_max_attempts=args.re_max_attempts,
separate=num_aug_splits > 0,
)

# load dataset
Expand All @@ -103,6 +110,7 @@ def train(args):
num_classes=num_classes,
transform=transform_list,
num_parallel_workers=args.num_parallel_workers,
separate=num_aug_splits > 0,
)

if args.val_while_train:
Expand Down

0 comments on commit 2f24f90

Please sign in to comment.