Dataset Distillation in Large Data Era

Official PyTorch implementation of the paper:

"Dataset Distillation in Large Data Era"
Zeyuan Yin, Zhiqiang Shen
MBZUAI

Table of Contents

  • Abstract
  • Requirements
  • Usage
  • Results
  • Download
  • Citation

Abstract

Dataset distillation aims to generate a smaller but representative subset from a large dataset, which allows a model to be trained efficiently while still being evaluated on the original test data distribution with decent performance. Many prior works aim to align with diverse aspects of the original datasets, such as matching the training weight trajectories, gradients, or feature/BatchNorm distributions. In this work, we show how to distill various large-scale datasets such as the full ImageNet-1K/21K under a conventional input resolution of 224 $\times$ 224 to achieve the best accuracy over all previous approaches, including SRe2L, TESLA, and MTT. To achieve this, we introduce a simple yet effective Curriculum Data Augmentation (CDA) during data synthesis that reaches 63.2% accuracy on large-scale ImageNet-1K under IPC (Images Per Class) 50 and 36.1% on ImageNet-21K under IPC 20. Finally, we show that, by integrating all our enhancements together, the proposed model beats the current state-of-the-art by more than 4% top-1 accuracy on ImageNet-1K and, for the first time, reduces the gap to its full-data training counterpart to less than an absolute 15%. Moreover, this work represents the first successful dataset distillation on the larger-scale ImageNet-21K under the standard 224 $\times$ 224 resolution.
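
The curriculum in CDA is applied to the cropping augmentation used while synthesizing images: early iterations see easy (large, near-global) crops and later iterations see progressively harder (smaller, more local) crops. Below is a minimal, illustrative sketch of a cosine easy-to-hard schedule for the lower bound of RandomResizedCrop's scale range; the function names and the start/end values (0.5 to 0.08) are our own placeholders, and the exact schedule, selected by --easy2hard-mode and --milestone, lives in the recover_cda_*.py scripts.

    import math
    import torchvision.transforms as T

    def easy2hard_min_scale(step, total_steps, start=0.5, end=0.08):
        """Cosine easy-to-hard schedule for the lower bound of the crop scale:
        early steps use large (easy) crops, later steps small (hard) crops."""
        progress = min(step / max(total_steps, 1), 1.0)
        return end + 0.5 * (start - end) * (1.0 + math.cos(math.pi * progress))

    def curriculum_crop(step, total_steps, size=224):
        # Query the schedule once per recovery iteration and build the transform
        # applied to the synthetic images before they are fed to the model.
        min_scale = easy2hard_min_scale(step, total_steps)
        return T.Compose([
            T.RandomResizedCrop(size, scale=(min_scale, 1.0)),
            T.RandomHorizontalFlip(),
        ])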

Requirements

  • Python Environment

    • Python 3.8
    • CUDA 11.7
    • torch 1.13.1
    • torchvision 0.14.1
  • Hardware

    • NVIDIA RTX 4090 24GB GPU
    • 4x NVIDIA A100 40GB GPUs (recommended for squeezing ImageNet-21K)

Usage

Squeeze
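
In SRe2L-style distillation, the squeeze step compresses the original dataset into a model: a backbone such as ResNet-18 is trained on the full training set, and the resulting checkpoint (including its BatchNorm statistics) is what --arch-path points to in the recover commands below. The following is only a minimal sketch of such a training loop, assuming a standard torchvision setup; it is not the repository's squeeze script, and hyperparameters are illustrative.

    import torch
    import torch.nn as nn
    import torchvision
    from torchvision import transforms

    def squeeze(train_dir, epochs=50, lr=0.1, batch_size=256, out="squeezed_model.pth"):
        """Illustrative squeeze: train ResNet-18 on the original dataset and
        save the checkpoint for later use as --arch-path."""
        tf = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ])
        dataset = torchvision.datasets.ImageFolder(train_dir, tf)
        loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                             shuffle=True, num_workers=8)
        model = torchvision.models.resnet18(num_classes=len(dataset.classes)).cuda()
        opt = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=1e-4)
        sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
        criterion = nn.CrossEntropyLoss()
        for _ in range(epochs):
            for images, targets in loader:
                images, targets = images.cuda(), targets.cuda()
                opt.zero_grad()
                criterion(model(images), targets).backward()
                opt.step()
            sched.step()
        torch.save(model.state_dict(), out)  # later passed to --arch-path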

Recover

  • Tiny-ImageNet

    python recover_cda_tiny.py \
    --arch-name "resnet18" \
    --arch-path 'path/to/squeezed_model.pth' \
    --exp-name "cda_tiny_rn18E50_4K_ipc50" \
    --syn-data-path './syn-data' \
    --batch-size 100 \
    --lr 0.4 \
    --r-bn 0.05 \
    --iteration 4000 \
    --store-best-images \
    --easy2hard-mode "cosine" --milestone 1 \
    --ipc-start 0 --ipc-end 50

    It will take about 4.5 hours to recover the distilled Tiny-ImageNet with IPC50 on a single 4090 GPU.

  • ImageNet-1K

    python recover_cda_in1k.py \
    --arch-name "resnet18" \
    --exp-name "cda_in1k_rn18_4K_ipc50" \
    --syn-data-path './syn-data' \
    --batch-size 100 \
    --lr 0.25 \
    --r-bn 0.01 \
    --iteration 4000 \
    --store-best-images \
    --easy2hard-mode "cosine" --milestone 1 \
    --ipc-start 0 --ipc-end 50

    It will take about 29 hours to recover the distilled ImageNet-1K with IPC50 on a single 4090 GPU.

  • ImageNet-21K

    python recover_cda_in21k.py \
    --arch-name "resnet18" \
    --arch-path 'path/to/squeezed_model.pth' \
    --exp-name "cda_in21k_rn18E80_2K_ipc20" \
    --syn-data-path './syn-data' \
    --batch-size 100 \
    --lr 0.05 \
    --r-bn 0.25 \
    --iteration 2000 \
    --store-best-images \
    --easy2hard-mode "cosine" --milestone 1 \
    --ipc-start 0 --ipc-end 20

    It will take about 55 hours to recover the distilled ImageNet-21K with IPC20 on 4x 4090 GPUs.
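
The recover scripts above all share the same basic objective: the synthetic images themselves are the optimized parameters, updated to minimize a cross-entropy term on their assigned labels plus a BatchNorm statistic-matching term that --r-bn appears to weight (note its value differs per dataset), while the crops fed to the squeezed model follow the easy-to-hard curriculum. The sketch below shows the general shape of one such update under our own naming; the actual losses and hooks are defined in recover_cda_*.py.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    def attach_bn_hooks(model):
        """Record per-batch mean/variance of each BatchNorm input via forward hooks."""
        stats = {}
        def make_hook(bn):
            def hook(module, inputs, output):
                x = inputs[0]
                stats[bn] = (x.mean(dim=[0, 2, 3]), x.var(dim=[0, 2, 3], unbiased=False))
            return hook
        for m in model.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.register_forward_hook(make_hook(m))
        return stats

    def recover_step(model, syn_images, targets, optimizer, stats, r_bn=0.01):
        # `model` is the frozen squeezed network in eval mode; `syn_images` is a
        # leaf tensor with requires_grad=True and is the only thing being optimized.
        optimizer.zero_grad()
        logits = model(syn_images)  # the forward pass fills `stats` via the hooks
        bn_loss = sum(F.mse_loss(mean, bn.running_mean) + F.mse_loss(var, bn.running_var)
                      for bn, (mean, var) in stats.items())
        loss = F.cross_entropy(logits, targets) + r_bn * bn_loss
        loss.backward()
        optimizer.step()
        return loss.item()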

Relabel

We follow the SRe2L relabeling method and use the squeezed model above to relabel the distilled datasets.
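
Relabeling assigns soft labels to the distilled images by running the squeezed (teacher) model over them. The snippet below is only a minimal, illustrative sketch of that teacher-inference pattern, with our own naming rather than the repository's script.

    import torch

    @torch.no_grad()
    def relabel(teacher, loader, temperature=1.0, device="cuda"):
        """Generate soft labels for the distilled images with the squeezed model."""
        teacher.eval().to(device)
        soft_labels = []
        for images, _ in loader:  # loader yields the distilled/synthetic images
            logits = teacher(images.to(device))
            soft_labels.append(torch.softmax(logits / temperature, dim=1).cpu())
        return torch.cat(soft_labels)

In SRe2L-style relabeling, soft labels are typically generated per augmented crop and stored alongside the images for student training; the loop above only shows the basic idea.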

Results

Our Top-1 accuracy (%) under different IPC settings on the Tiny-ImageNet, ImageNet-1K, and ImageNet-21K datasets:

We also present a comparative visualization of the synthetic images at recovery steps {100, 500, 1,000, 2,000} to illustrate the differences between SRe2L (upper) and our CDA (lower) within the dataset distillation process.

Download

You can download distilled data from Hugging Face Datasets.

Dataset             Resolution  Iterations  IPC  Files
Tiny-ImageNet-200   64x64       1K          50   images
ImageNet-1K         224x224     4K          200  images
ImageNet-21K        224x224     2K          20   images
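
If the distilled images are hosted as a Hugging Face dataset repository, they can be fetched with huggingface_hub; the repository ID below is a placeholder for illustration only, so substitute the actual dataset linked above.

    from huggingface_hub import snapshot_download

    # Hypothetical repo ID shown for illustration; use the dataset linked above.
    local_dir = snapshot_download(
        repo_id="user/cda-distilled-imagenet-1k",  # placeholder
        repo_type="dataset",
    )
    print(local_dir)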

Citation

@article{yin2023dataset,
  title={Dataset Distillation in Large Data Era},
  author={Yin, Zeyuan and Shen, Zhiqiang},
  journal={arXiv preprint arXiv:2311.18838},
  year={2023}
}