Dynamic Structure Pruning

Dynamic Structure Pruning for Compressing CNNs [AAAI 2023 Oral]

Jun-Hyung Park, Yeachan Kim, Junho Kim, Joon-Young Choi, and SangKeun Lee

Generates compact and efficient CNNs via grouping and pruning, transforming single-branch convolutional layers into multi-branch convolutional layers [AAAI] [arXiv]

Introduction

Dynamic Structure Pruning automatically learns intra-channel sparsity by optimizing filter groups and regularizing group channels. This enables the higher efficiency of fine-grained pruning granularities while producing generally accelerable (i.e., regular) structures.
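
To make the "regular structure" point concrete, here is a rough PyTorch illustration (not DSP's actual code): after group learning and pruning, a dense convolution effectively becomes a standard grouped convolution with fewer channels per group, which standard libraries can accelerate directly.

import torch.nn as nn

# Rough illustration only (not taken from this repository):
# a dense 64->64 convolution ...
dense = nn.Conv2d(64, 64, kernel_size=3, padding=1)

# ... effectively becomes a grouped convolution in which each of the 2 groups
# keeps only its surviving input channels (here 16 per group), so the result
# is still a regular, hardware-friendly layer.
pruned = nn.Conv2d(32, 64, kernel_size=3, padding=1, groups=2)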

Requirements

  • Python 3.7
  • PyTorch 1.10.0
  • TorchVision 0.11.0
  • tqdm

How to use DSP in your code

First, run group learning on a pre-trained model; then prune and finetune the group-learned model.

Our group-learning and pruning modules are used in three steps:

  1. Defining a wrapper
  2. Initializing it
  3. Processing after every model update (i.e., after each optimizer step)

The following sections show code examples for both modules.

Differentiable Group Learning

from dsp_module import *

...

# After defining your model, optimizer, criterion, etc.
group_trainer = GroupWrapper(model, optimizer, criterion, regularization_power, total_num_iterations, num_groups, temperature)

...

# Training iteration
for epoch in range(args.epochs):
    for x, y in train_dataloader:
        # Before forward (model(x))
        group_trainer.initialize()
        out = model(x)
        ...

        # After model update (optimizer.step())
        group_trainer.after_step(x, y)

...
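
For reference, below is a minimal sketch of one complete group-learning iteration, assuming a standard classification setup (model, optimizer, criterion, and train_dataloader defined as usual). The hyperparameter values are placeholders, not the paper's settings.

from dsp_module import *

# Minimal sketch; hyperparameter values below are placeholders.
group_trainer = GroupWrapper(model, optimizer, criterion,
                             2e-3,                                  # regularization_power
                             args.epochs * len(train_dataloader),   # total_num_iterations
                             4,                                     # num_groups
                             1.0)                                   # temperature

for epoch in range(args.epochs):
    for x, y in train_dataloader:
        group_trainer.initialize()      # before the forward pass
        out = model(x)
        loss = criterion(out, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        group_trainer.after_step(x, y)  # after the model update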

Group Channel Pruning

from dsp_module import *

...

# Before loading group-learned checkpoints
pruner = PruneWrapper(model, num_groups, fp_every_nth_conv)

# fp_every_nth_conv means that filters are pruned in every nth convolution layer.
# In our paper, we prune filters of the final layer in each residual block.
# On CIFAR-10, fp_every_nth_conv = 2; on ImageNet, fp_every_nth_conv = 2 (ResNet18) or 3 (ResNet50).
# If your model has an irregular number of layers in each residual block,
# you can instead specify the layer indices directly, e.g. fp_layer_indices=[1, 3, 5, 8, 11, 14, ...]
# If both fp_every_nth_conv and fp_layer_indices are set, the latter takes priority.

...

# Before training starts
flops, params = pruner.initialize(pruning_rate)

# Training iteration
for epoch in range(args.epochs):
    for x, y in train_dataloader:
        
        ...

        # After model update (optimizer.step())
        pruner.after_step()
...
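
Similarly, here is a minimal sketch of the full pruning and finetuning loop under the same assumed setup; the checkpoint path, num_groups, fp_every_nth_conv, and pruning rate below are placeholders.

from dsp_module import *
import torch

# Minimal sketch; path and values are placeholders.
pruner = PruneWrapper(model, 4, 2)                        # num_groups=4, prune every 2nd conv
model.load_state_dict(torch.load('group_learned.pth'))    # group-learned checkpoint (placeholder path)

flops, params = pruner.initialize(0.5)                    # pruning_rate = 0.5

for epoch in range(args.epochs):
    for x, y in train_dataloader:
        out = model(x)
        loss = criterion(out, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        pruner.after_step()                               # post-update hook required by the pruner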

Please refer to our CIFAR-10 pruning code (cifar_dsp.py and cifar_finetune.py) for complete usage examples of these modules.

Pruning on CIFAR-10

Pretraining

# pretrain ResNet20
python cifar_pretrain.py -l 20 [--save-dir ./cifarmodel] [--epochs 164] [--batch-size 128] [--lr 0.1] [--momentum 0.9] [--wd 1e-4]

# pretrain ResNet56
python cifar_pretrain.py -l 56 [--save-dir ./cifarmodel] [--epochs 164] [--batch-size 128] [--lr 0.1] [--momentum 0.9] [--wd 1e-4]

Differentiable Group Learning

# ResNet20 with group 4, lambda=2e-3
python cifar_dsp.py -l 20 -g 4 -r 2e-3

# ResNet20 with group 2, lambda=2e-3
python cifar_dsp.py -l 20 -g 2 -r 2e-3

# ResNet56 with group 4, lambda=5e-4
python cifar_dsp.py -l 56 -g 4 -r 5e-4

Group Channel Pruning

# ResNet20 with group 4, pruning rate=0.5
python cifar_finetune.py -l 20 -g 4 -p 0.5

# ResNet56 with group 4, pruning rate=0.5
python cifar_finetune.py -l 56 -g 4 -p 0.5

Packing Pruned Models

python pack_model.py --ckpt [pruned_model_path] --save [save_path]

We provide checkpoints corresponding to the median accuracy over five runs. P.FLOPS and P.PARAMS denote the percentages of pruned FLOPs and parameters.

Model ACC (%) P.FLOPS (%) P.PARAMS (%) CKPT
ResNet20 (g=4) 92.22 63.57 50.45 Link
ResNet20 (g=3) 92.14 62.43 49.15 Link
ResNet20 (g=2) 92.07 61.23 48.35 Link
ResNet56 (g=4) 94.25 65.11 56.30 Link
ResNet56 (g=3) 94.07 64.14 55.01 Link
ResNet56 (g=2) 93.99 63.26 56.24 Link

How to use checkpoints

import torch
cnn = torch.jit.load('[CKPT_PATH]')
# You can use TensorRT or torch.jit.optimize_for_inference to achieve further acceleration.
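
For example (placeholder file name and CIFAR-sized dummy input; torch.jit.optimize_for_inference is standard TorchScript functionality, not part of this repository):

import torch

cnn = torch.jit.load('resnet20_g4_packed.pt')     # placeholder checkpoint path
cnn.eval()
cnn = torch.jit.optimize_for_inference(cnn)       # optional TorchScript optimization pass

with torch.no_grad():
    out = cnn(torch.randn(1, 3, 32, 32))          # CIFAR-10-sized dummy input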

We slightly changed the implementation of regularization scaling to obtain a better speedup.

As a result, pruned results may differ from those reported in the paper (usually more pruned FLOPs and fewer pruned parameters).

Pruning on ImageNet

Pruned Models

Model Top-1 ACC (%) P.FLOPS (%) P.PARAMS (%) CKPT
ResNet18 (g=2) 69.55 60.08 45.70 Link
ResNet18 (g=2) 67.95 70.05 57.25 Link
ResNet18 (g=2) 65.81 80.00 69.92 Link
ResNet50 (g=2) 76.54 70.00 52.54 Link
ResNet50 (g=2) 75.45 80.03 66.00 Link
ResNet50 (g=2) 73.29 90.02 81.64 Link

TODO

  • Implement model-agnostic pruner
  • Release ImageNet models
