- The optimal augmentation policy is a latent variable: it cannot be directly observed.
- LatentAugment estimates the latent probability of each augmentation policy using the EM algorithm.
- LatentAugment is simple and computationally efficient: the model can be trained with standard stochastic gradient descent, without an adversarial network.
- LatentAugment has higher test accuracy than previous augmentation methods on the CIFAR-10, CIFAR-100, SVHN, and ImageNet datasets.
Figure 1. An overview of the proposed LatentAugment. The loss functions under the candidate augmentation policies are calculated from the input data and the unconditional probabilities of the augmentation policies. The model parameters are updated by the EM algorithm. In the E-step, the expected weighted loss is computed using the conditional probability that each policy yields the highest loss. In the M-step, the expected loss is minimized with standard stochastic gradient descent. The conditional probabilities of the highest loss are then recomputed from the loss functions with the updated parameters and the input data. The unconditional probabilities of the augmentation policies are obtained as a moving average of the conditional probabilities.
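To make the E/M steps above concrete, below is a minimal PyTorch-style sketch of one training step. All names here (`latent_augment_step`, `policies`, `pi`, `ema_decay`) and the exact form of the loss weighting are illustrative assumptions for this README, not the repository's actual API; see `train.py` for the real implementation.

```python
import torch
import torch.nn.functional as F

def latent_augment_step(model, optimizer, x, y, policies, pi, ema_decay=0.99):
    """One EM-style update following the description in Figure 1.

    policies : list of K augmentation callables (hypothetical interface)
    pi       : tensor of shape (K,), unconditional policy probabilities
    """
    # Loss under each candidate augmentation policy.
    losses = torch.stack(
        [F.cross_entropy(model(policy(x)), y) for policy in policies]
    )  # shape (K,)

    # E-step: conditional probability that each policy gives the highest loss,
    # combining the current losses with the unconditional prior pi.
    with torch.no_grad():
        cond = F.softmax(losses.detach() + torch.log(pi), dim=0)

    # Expected (weighted) loss over the policies.
    weighted_loss = (cond * losses).sum()

    # M-step: minimize the expected loss with standard SGD.
    optimizer.zero_grad()
    weighted_loss.backward()
    optimizer.step()

    # Unconditional probabilities: moving average of the conditional
    # probabilities, renormalized to sum to one.
    pi.mul_(ema_decay).add_((1.0 - ema_decay) * cond)
    pi.div_(pi.sum())
    return weighted_loss.item(), pi
```

In this sketch, policies that currently produce larger losses receive more weight, so training emphasizes the harder augmentations while `pi` slowly tracks which policies those are.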
| Dataset | Model | Baseline | AA | AdvAA | UBS | MA | LA (proposed) |
|---|---|---|---|---|---|---|---|
| CIFAR-10 | Wide-ResNet-40-2 | 94.70 | 96.30 | - | - | 96.79 | 97.27 |
| | Wide-ResNet-28-10 | 96.13 | 97.32 | 98.10 | 97.89 | 97.76 | 98.25 |
| | Shake-Shake (26 2x32d) | 96.45 | 97.53 | 97.64 | - | - | 97.68 |
| | Shake-Shake (26 2x96d) | 97.14 | 98.01 | 98.15 | 98.27 | 98.29 | 98.42 |
| | Shake-Shake (26 2x112d) | 97.18 | 98.11 | 98.22 | - | 98.28 | 98.44 |
| | PyramidNet+ShakeDrop | 97.33 | 98.52 | 98.64 | 98.66 | 98.57 | 98.72 |
| CIFAR-100 | Wide-ResNet-40-2 | 74.00 | 79.30 | - | - | 80.60 | 80.90 |
| | Wide-ResNet-28-10 | 81.20 | 82.91 | 84.51 | 84.54 | 83.79 | 84.98 |
| | Shake-Shake (26 2x96d) | 82.95 | 85.72 | 85.90 | - | 85.97 | 85.88 |
| SVHN | Wide-ResNet-28-10 | 98.50 | 98.93 | - | - | - | 98.96 |
| ImageNet | ResNet-50 | 75.30 / 92.20 | 77.63 / 93.82 | 79.40 / 94.47 | - | 79.74 / 94.64 | 80.02 / 94.88 |
Test accuracy (%); ImageNet entries report Top-1 / Top-5 accuracy. Abbreviations: AutoAugment (AA) (Cubuk et al., 2018), Adversarial AutoAugment (AdvAA) (Zhang et al., 2019), Uncertainty-Based Sampling (UBS) (Wu et al., 2020), MetaAugment (MA) (Zhou et al., 2020), and the proposed LatentAugment (LA).
Training the Wide-ResNet-40-2 (wrn40x2) model on the CIFAR-10 dataset:
```bash
$ python train.py --dataset cifar10 \
    --name cifar10-wrn40x2 \
    --dataroot /home/user/data/ \
    --checkpoint /home/user/runs/latent/ \
    --num_k 6 \
    --epochs 200 \
    --batch-size 128 \
    --lr 0.1 \
    --weight-decay 0.0002 \
    --model wrn \
    --layers 40 \
    --widen-factor 2 \
    --cutmix_prob 0.5 \
    --cutmix True
```
For other models and datasets, you can find script files in the script folder.
MIT License
```bibtex
@misc{kuriyama2023latentaugment,
  title={LatentAugment: Dynamically Optimized Latent Probabilities of Data Augmentation},
  author={Koichi Kuriyama},
  year={2023},
  eprint={2305.02668},
  archivePrefix={arXiv},
  primaryClass={cs.CV}
}
```