Official PyTorch implementation of the paper:
"Dataset Distillation in Large Data Era"
Zeyuan Yin, Zhiqiang Shen
MBZUAI
Dataset distillation aims to generate a smaller but representative subset from a large dataset, so that a model can be trained on it efficiently while still achieving decent performance when evaluated on the original test data distribution. Many prior works aim to align with diverse aspects of the original datasets, such as matching the training weight trajectories, gradients, or feature/BatchNorm distributions.
In this work, we show how to distill various large-scale datasets such as full ImageNet-1K/21K under a conventional input resolution of 224×224.
- Python Environment
  - Python 3.8
  - CUDA 11.7
  - torch 1.13.1
  - torchvision 0.14.1
- Hardware
  - NVIDIA RTX 4090 24GB GPU
  - 4x NVIDIA A100 40GB GPUs (recommended for squeezing ImageNet-21K)
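A quick way to confirm the local setup matches the versions above is a short check like the following (a minimal sketch, not part of the repo):

```python
# Print the installed versions and confirm a CUDA-capable GPU is visible.
import torch
import torchvision

print("torch       :", torch.__version__)        # expected 1.13.1
print("torchvision :", torchvision.__version__)  # expected 0.14.1
print("CUDA build  :", torch.version.cuda)       # expected 11.7
print("GPU visible :", torch.cuda.is_available())
```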
- Tiny-ImageNet

  We adapt the official torchvision classification code to train squeezing models on Tiny-ImageNet. You can find the training code and checkpoints in the tiny-imagenet repo.
- ImageNet-1K

  We directly use the off-the-shelf PyTorch pretrained models with the IMAGENET1K_V1 weights as squeezing models; see the loading sketch after this list.
- ImageNet-21K

  We follow the ImageNet-21K-P training pipeline to train squeezing models on ImageNet-21K (Winter 2021 version). You can find the checkpoints at .
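For ImageNet-1K, loading the squeezing model is a one-liner with torchvision; a minimal sketch (ResNet-18 shown as an example):

```python
# Load the off-the-shelf torchvision ResNet-18 with IMAGENET1K_V1 weights;
# this pretrained network is used directly as the ImageNet-1K squeezing model.
import torchvision

weights = torchvision.models.ResNet18_Weights.IMAGENET1K_V1
squeeze_model = torchvision.models.resnet18(weights=weights)
squeeze_model.eval()

# The weights object also bundles the matching preprocessing transforms.
preprocess = weights.transforms()
```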
- Tiny-ImageNet

      python recover_cda_tiny.py \
          --arch-name "resnet18" \
          --arch-path 'path/to/squeezed_model.pth' \
          --exp-name "cda_tiny_rn18E50_4K_ipc50" \
          --syn-data-path './syn-data' \
          --batch-size 100 \
          --lr 0.4 \
          --r-bn 0.05 \
          --iteration 4000 \
          --store-best-images \
          --easy2hard-mode "cosine" --milestone 1 \
          --ipc-start 0 --ipc-end 50
It will take about 4.5 hours to recover the distilled Tiny-ImageNet with IPC50 on a single 4090 GPU.
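The `--easy2hard-mode "cosine"` and `--milestone` flags control CDA's curriculum during recovery: early iterations see easy (large) crops and later iterations progressively harder (smaller) ones. A rough sketch of what such a cosine schedule for the lower bound of `RandomResizedCrop`'s scale could look like (the function name and default values here are illustrative, not the repo's exact code):

```python
import math
import torchvision.transforms as T

def crop_scale_lower_bound(step, total_steps, milestone=1.0,
                           easy_low=0.5, hard_low=0.08):
    """Cosine-annealed lower bound for RandomResizedCrop's scale range.

    Up to `milestone * total_steps` the bound decays from `easy_low`
    (large, easy crops) to `hard_low` (small, hard crops), then stays there.
    The default values are illustrative only.
    """
    horizon = max(1, int(milestone * total_steps))
    t = min(step / horizon, 1.0)
    return hard_low + 0.5 * (easy_low - hard_low) * (1.0 + math.cos(math.pi * t))

# Example: augmentation used at recovery step 1000 of 4000 (64x64 Tiny-ImageNet).
low = crop_scale_lower_bound(step=1000, total_steps=4000, milestone=1.0)
augment = T.RandomResizedCrop(64, scale=(low, 1.0))
```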
- ImageNet-1K

      python recover_cda_in1k.py \
          --arch-name "resnet18" \
          --exp-name "cda_in1k_rn18_4K_ipc50" \
          --syn-data-path './syn-data' \
          --batch-size 100 \
          --lr 0.25 \
          --r-bn 0.01 \
          --iteration 4000 \
          --store-best-images \
          --easy2hard-mode "cosine" --milestone 1 \
          --ipc-start 0 --ipc-end 50
It will take about 29 hours to recover the distilled ImageNet-1K with IPC50 on a single 4090 GPU.
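The `--r-bn` coefficient weights a BatchNorm-statistics matching term during recovery, in the spirit of SRe2L: at every BN layer, the batch statistics of the synthetic images are pulled toward the running statistics stored in the squeezing model. A hedged sketch of such a regularizer via forward hooks (illustrative, not the repo's exact implementation):

```python
import torch
import torch.nn as nn

class BNStatsHook:
    """Accumulate a BN-statistics matching loss at one BatchNorm2d layer."""

    def __init__(self, bn: nn.BatchNorm2d):
        self.bn = bn
        self.r_feature = torch.tensor(0.0)
        self.handle = bn.register_forward_hook(self._hook)

    def _hook(self, module, inputs, output):
        x = inputs[0]
        mean = x.mean(dim=[0, 2, 3])
        var = x.var(dim=[0, 2, 3], unbiased=False)
        # Distance between the synthetic batch statistics and the
        # running statistics of the (frozen) squeezing model.
        self.r_feature = (torch.norm(module.running_mean - mean, 2)
                          + torch.norm(module.running_var - var, 2))

    def remove(self):
        self.handle.remove()

def attach_bn_hooks(model: nn.Module):
    return [BNStatsHook(m) for m in model.modules()
            if isinstance(m, nn.BatchNorm2d)]

# Inside the recovery loop (sketch):
#   total_loss = ce_loss + r_bn * sum(h.r_feature for h in hooks)
```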
- ImageNet-21K

      python recover_cda_in21k.py \
          --arch-name "resnet18" \
          --arch-path 'path/to/squeezed_model.pth' \
          --exp-name "cda_in21k_rn18E80_2K_ipc20" \
          --syn-data-path './syn-data' \
          --batch-size 100 \
          --lr 0.05 \
          --r-bn 0.25 \
          --iteration 2000 \
          --store-best-images \
          --easy2hard-mode "cosine" --milestone 1 \
          --ipc-start 0 --ipc-end 20
It will take about 55 hours to recover the distilled ImageNet-21K with IPC20 on 4x 4090 GPUs.
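Since the `--ipc-start`/`--ipc-end` flags select a range of IPC indices, one way to use the 4 GPUs is to launch the script once per GPU on disjoint IPC ranges, assuming the recovery of different IPC indices is independent (a hypothetical launcher, not part of the repo):

```python
# Hypothetical launcher: one recover process per GPU, each on a disjoint IPC range.
import os
import subprocess

ranges = [(0, 5), (5, 10), (10, 15), (15, 20)]
procs = []
for gpu, (start, end) in enumerate(ranges):
    env = dict(os.environ, CUDA_VISIBLE_DEVICES=str(gpu))
    cmd = ["python", "recover_cda_in21k.py",
           "--arch-name", "resnet18",
           "--arch-path", "path/to/squeezed_model.pth",
           "--exp-name", "cda_in21k_rn18E80_2K_ipc20",
           "--syn-data-path", "./syn-data",
           "--batch-size", "100", "--lr", "0.05", "--r-bn", "0.25",
           "--iteration", "2000", "--store-best-images",
           "--easy2hard-mode", "cosine", "--milestone", "1",
           "--ipc-start", str(start), "--ipc-end", str(end)]
    procs.append(subprocess.Popen(cmd, env=env))

for p in procs:
    p.wait()
```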
We follow the SRe2L relabeling method and use the above squeezing models to relabel the distilled datasets.
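Conceptually, relabeling replaces the hard class indices with soft predictions from the squeezing (teacher) model on augmented crops of the distilled images; these soft labels are then stored and reused during post-training. A simplified sketch is below (the actual SRe2L/FKD-style pipeline also records per-epoch crop parameters, which is omitted here; the paths are illustrative):

```python
import torch
import torchvision
import torchvision.transforms as T
from torch.utils.data import DataLoader

# Teacher = the squeezing model (ImageNet-1K ResNet-18 shown as an example).
weights = torchvision.models.ResNet18_Weights.IMAGENET1K_V1
teacher = torchvision.models.resnet18(weights=weights).cuda().eval()

transform = T.Compose([
    T.RandomResizedCrop(224),
    T.RandomHorizontalFlip(),
    T.ToTensor(),
])
# Illustrative path: point this at your --syn-data-path / --exp-name output.
dataset = torchvision.datasets.ImageFolder("./syn-data/cda_in1k_rn18_4K_ipc50",
                                           transform=transform)
loader = DataLoader(dataset, batch_size=100, shuffle=False, num_workers=8)

soft_labels = []
with torch.no_grad():
    for images, _ in loader:
        logits = teacher(images.cuda())
        soft_labels.append(torch.softmax(logits, dim=1).cpu())
torch.save(torch.cat(soft_labels), "soft_labels.pt")
```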
Our Top-1 accuracy (%) under different IPC settings on Tiny-ImageNet, ImageNet-1K and ImageNet-21K datasets:
We present a comparative visualization of the synthetic images at recovery steps {100, 500, 1,000, 2,000} to illustrate the differences between SRe2L (upper) and our CDA (lower) during the dataset distillation process.
You can download distilled data from .
dataset | resolution | iteration | IPC | files |
---|---|---|---|---|
Tiny-ImageNet-200 | 64x64 | 1K | 50 | images |
ImageNet-1K | 224x224 | 4K | 200 | images |
ImageNet-21K | 224x224 | 2K | 20 | images |
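Assuming the downloaded archives keep an ImageFolder-style layout (one subfolder per class), the distilled images can be loaded for post-training with a standard torchvision pipeline (a minimal sketch; in the full pipeline the stored soft labels from the relabeling step replace the hard folder labels):

```python
import torchvision
import torchvision.transforms as T
from torch.utils.data import DataLoader

# Illustrative path: wherever the downloaded distilled ImageNet-1K was extracted.
transform = T.Compose([
    T.RandomResizedCrop(224),
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
distilled = torchvision.datasets.ImageFolder("path/to/distilled-imagenet-1k",
                                             transform=transform)
train_loader = DataLoader(distilled, batch_size=128, shuffle=True, num_workers=8)
```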
@article{yin2023dataset,
title={Dataset Distillation in Large Data Era},
author={Yin, Zeyuan and Shen, Zhiqiang},
journal={arXiv preprint arXiv:2311.18838},
year={2023}
}