PyTorch implementation of Soft MoE by Google Brain in "From Sparse to Soft Mixtures of Experts" (https://arxiv.org/pdf/2308.00951.pdf)

soft-mixture-of-experts

PyTorch implementation of Soft MoE by Google Brain in From Sparse to Soft Mixtures of Experts

Thanks to lucidrains for his excellent x-transformers library! 🎉

The ViT implementations here are heavily based on his ViTransformerWrapper.

TODO

  • Implement Soft MoE layer (Usage, Code)
  • Example end-to-end Transformer models
    • vision transformer (Usage, Code)
    • language model (skip for now)
    • add to README
  • Set up unit tests
    • SoftMoE
    • Transformer layers
    • ViT models
  • Reproduce parameter counts from Table 3 (Ablations)
  • Reproduce inference benchmarks from Tables 1, 2 (Ablations)
  • Release on PyPI
    • Prerelease
    • Stable

Install

PyPI:

pip install soft-mixture-of-experts

From source:

pip install "soft-mixture-of-experts @ git+ssh://git@github.com/fkodom/soft-mixture-of-experts.git"

For contributors:

# Clone/fork this repo. Example:
gh repo clone fkodom/soft-mixture-of-experts
cd soft-mixture-of-experts
# Install all dev dependencies (tests etc.) in editable mode
pip install -e ".[test]"
# Setup pre-commit hooks
pre-commit install

Usage

Vision Transformers

Using the ViT and SoftMoEViT classes directly:

import torch

from soft_mixture_of_experts.vit import ViT, SoftMoEViT

vit = ViT(num_classes=1000, device="cuda")
moe_vit = SoftMoEViT(num_classes=1000, num_experts=32, device="cuda")

# image shape: (batch_size, channels, height, width)
image = torch.randn(1, 3, 224, 224, device="cuda")

# classification prediction
# output shape: (batch_size, num_classes)
y_vit = vit(image)
y_moe = moe_vit(image)

# feature embeddings
# output shape: (batch_size, num_patches, d_model)
features_vit = vit(image, return_features=True)
features_moe = moe_vit(image, return_features=True)

or using pre-configured models:

from soft_mixture_of_experts.vit import soft_moe_vit_small

# Available models:
# - soft_moe_vit_small
# - soft_moe_vit_base
# - soft_moe_vit_large
# - vit_small
# - vit_base
# - vit_large
# - vit_huge

# Roughly 930M parameters 👀
moe_vit = soft_moe_vit_small(num_classes=1000, device="cuda")

# Everything else works the same as above...

Transformer Layers

import torch

from soft_mixture_of_experts.transformer import (
    TransformerEncoder,
    TransformerEncoderLayer,
    TransformerDecoder,
    TransformerDecoderLayer,
)

# Move both modules to the same device as the input tensor below.
encoder = TransformerEncoder(
    TransformerEncoderLayer(d_model=512, nhead=8),
    num_layers=6,
).to("cuda")
decoder = TransformerDecoder(
    TransformerDecoderLayer(d_model=512, nhead=8),
    num_layers=6,
).to("cuda")

# input shape: (batch_size, seq_len, d_model)
x = torch.randn(2, 128, 512, device="cuda")

mem = encoder(x)
print(mem.shape)
# torch.Size([2, 128, 512])

y = decoder(x, mem)
print(y.shape)
# torch.Size([2, 128, 512])

Soft MoE

import torch

from soft_mixture_of_experts.soft_moe import SoftMoE

# SoftMoE with 32 experts, 2 slots per expert (64 total):
moe = SoftMoE(
    in_features=512,
    out_features=512,
    num_experts=32,
    slots_per_expert=2,
    bias=False,  # optional, default: True
    device="cuda",  # optional, default: None
)

# input shape: (batch_size, seq_len, embed_dim)
x = torch.randn(2, 128, 512, device="cuda")

y = moe(x)
print(y.shape)
# torch.Size([2, 128, 512])
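
For intuition, the routing that a Soft MoE layer performs can be sketched in a few lines of PyTorch. This is a minimal sketch based on the description in the paper, not this repository's implementation: the class name `TinySoftMoE`, its arguments, and the use of a single `nn.Linear` per expert are assumptions made for illustration.

```python
import torch
from torch import nn


class TinySoftMoE(nn.Module):
    """Illustrative Soft MoE layer (hypothetical, not the package's SoftMoE)."""

    def __init__(self, dim: int, num_experts: int, slots_per_expert: int):
        super().__init__()
        # One learnable dim-sized vector per slot.
        self.phi = nn.Parameter(
            torch.randn(dim, num_experts, slots_per_expert) * dim**-0.5
        )
        # Assumption: each expert is one linear layer, to keep the sketch short.
        self.experts = nn.ModuleList(
            [nn.Linear(dim, dim) for _ in range(num_experts)]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, tokens, dim) -> logits: (batch, tokens, experts, slots)
        logits = torch.einsum("bmd,dnp->bmnp", x, self.phi)
        # Dispatch weights: softmax over tokens, so every slot is a
        # weighted average of all input tokens.
        dispatch = logits.softmax(dim=1)
        # Combine weights: softmax over all (expert, slot) pairs, so every
        # output token is a weighted average of all slot outputs.
        combine = logits.flatten(start_dim=2).softmax(dim=-1).reshape(logits.shape)
        # slots: (batch, experts, slots, dim)
        slots = torch.einsum("bmd,bmnp->bnpd", x, dispatch)
        # Expert i only processes its own slots.
        outputs = torch.stack(
            [expert(slots[:, i]) for i, expert in enumerate(self.experts)],
            dim=1,
        )
        # Mix slot outputs back into one vector per input token.
        return torch.einsum("bnpd,bmnp->bmd", outputs, combine)


x = torch.randn(2, 16, 64)
y = TinySoftMoE(dim=64, num_experts=4, slots_per_expert=2)(x)
print(y.shape)
# torch.Size([2, 16, 64])
```

Because both sets of weights are dense softmaxes, every token contributes to every slot, and the output keeps the input's sequence length regardless of the number of slots.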

Ablations

I closely reproduce the parameter counts and (relative) inference times from the paper.

Table 3

All models are benchmarked with:

batch_size = 8  # see note below
image_size = 224
num_channels = 3
num_classes = 21000  # as in ImageNet 21k

$\dagger$ The authors benchmark "eval ms/img" on TPUv3, while I use a single A100 40GB. The paper does not clearly state the batch size used for inference, but Figure 6 specifically mentions batch size 8. I therefore assume a batch size of 8, and observe that inference times are similar to what is reported in the paper.
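
As a rough illustration of how an "eval ms/img" number can be measured, the sketch below times repeated forward passes and divides by the number of images processed. It is not the benchmark script behind the table: the helper name `eval_ms_per_image` and all of its parameters are hypothetical and not part of this package.

```python
import time
from typing import Callable, Optional


def eval_ms_per_image(
    forward: Callable[[], object],
    batch_size: int,
    num_warmup: int = 3,
    num_runs: int = 10,
    sync: Optional[Callable[[], None]] = None,
) -> float:
    """Average wall-clock milliseconds per image over `num_runs` calls."""
    # Warm-up calls are excluded from the measurement.
    for _ in range(num_warmup):
        forward()
    if sync is not None:
        sync()  # wait for pending asynchronous work before starting the clock
    start = time.perf_counter()
    for _ in range(num_runs):
        forward()
    if sync is not None:
        sync()  # wait for the timed calls to actually finish
    elapsed = time.perf_counter() - start
    return 1000.0 * elapsed / (num_runs * batch_size)


# Hypothetical GPU usage (assumes `model` and `images` already live on the GPU):
# with torch.no_grad():
#     ms = eval_ms_per_image(
#         lambda: model(images), batch_size=8, sync=torch.cuda.synchronize
#     )
```

The `sync` hook matters on GPU because CUDA kernels launch asynchronously; without a synchronization point the timer can stop before the work is done.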

| Model | Params | Params (paper) | Eval ms/img $\dagger$ | Eval ms/img (paper) |
|---|---|---|---|---|
| ViT S/16 | 30 M | 33 M | 0.9 | 0.5 |
| Soft MoE S/16 128E | 932 M | 933 M | 1.3 | 0.7 |
| Soft MoE S/14 128E | 1.8 B | 1.8 B | 1.5 | 0.9 |
| ViT B/16 | 102 M | 108 M | 1.0 | 0.9 |
| Soft MoE B/16 128E | 3.7 B | 3.7 B | 1.5 | 1.5 |
| ViT L/16 | 325 M | 333 M | 1.8 | 4.9 |
| Soft MoE L/16 128E | 13.1 B | 13.1 B | 3.5 | 4.8 |
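
Parameter counts like those in the "Params" column can be obtained with the usual PyTorch idiom of summing `numel()` over `model.parameters()`. The helpers below are a small sketch of that idea; the names `count_parameters` and `format_count` are hypothetical and not part of this package.

```python
def count_parameters(model) -> int:
    """Total number of elements across all parameters of a PyTorch module."""
    return sum(p.numel() for p in model.parameters())


def format_count(n: int) -> str:
    """Format a count in the table's style, e.g. '932 M' or '13.1 B'."""
    if n >= 1_000_000_000:
        return f"{n / 1e9:.1f} B"
    return f"{n / 1e6:.0f} M"


# Hypothetical usage with one of the pre-configured models:
# moe_vit = soft_moe_vit_small(num_classes=21000)
# print(format_count(count_parameters(moe_vit)))
```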

Test

Tests run automatically through GitHub Actions on each git push.

You can also run tests manually with pytest:

pytest

Citations

@misc{puigcerver2023sparse,
      title={From Sparse to Soft Mixtures of Experts}, 
      author={Joan Puigcerver and Carlos Riquelme and Basil Mustafa and Neil Houlsby},
      year={2023},
      eprint={2308.00951},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}