This is the code repository of the paper:
Adaptive Length Image Tokenization via Recurrent Allocation
Shivam Duggal, Phillip Isola, Antonio Torralba, William T. Freeman
MIT CSAIL
Abstract
Approach Overview
Setup
Datasets
Pretrained Checkpoints
Training
Evaluation
Citation
mamba env create -f environment.yaml
mamba activate adaptive_representations
Training the adaptive tokenizer requires pretrained checkpoints of the base 2D image tokenizers. We use VQGAN or VAE as the base tokenizers. We acknowledge Mage / Mar for releasing ImageNet-trained checkpoints of VQGAN / VAE. Run the following to download the pretrained base tokenizers to base_tokenizers/pretrained_models:
python base_tokenizers/pretrained_models/download.py
To use a custom base tokenizer, add the tokenizer code in base_tokenizers, the corresponding pretrained checkpoint in base_tokenizers/pretrained_models, and a wrapper in modules/base_tokenizers.py. See VQGANWrapper or LDMVAEWrapper for reference.
We mainly used ImageNet and ImageNet100 (a subset of ImageNet) for training. Download the ImageNet dataset and place it in $IMAGENET_DIR. To create the ImageNet100 subset, run the following:
python run_scripts/create_imagenet100.py --imagenet_dir $IMAGENET_DIR --imagenet100_dir datasets/imagenet100/
set -x IMAGENET100_DIR datasets/imagenet100/
We also evaluated ALIT on the COCO val2017, NYUv2, and Wikipedia Image-Text (WIT) datasets.
Download the required checkpoint and place it at adaptive_tokenizers/pretrained_models/imagenet100/. Optionally, run the following to download all the models:
python adaptive_tokenizers/pretrained_models/download.py
Figure 9 of the paper shows the benefit of scaling the adaptive tokenizer to larger model sizes, longer training, and larger datasets. Due to limited compute resources in an academic setting, we are not able to do that ourselves. Please feel free to reach out if you're interested in scaling up the approach; we would be happy to help!
Adaptive Tokenizer | Base Tokenizer | Dataset | Latent Quantization | Latent Factorization | Pretrained Checkpoint |
---|---|---|---|---|---|
alit_small | vqgan | ImageNet100 | Download Link | ||
alit_base | vqgan | ImageNet100 | Download Link | ||
alit_semilarge | vqgan | ImageNet100 | Download Link | ||
alit_small | vqgan | ImageNet100 | Download Link | ||
alit_small | vae | ImageNet100 | Download Link | ||
alit_small | vae | ImageNet100 | Download Link |
ALIT is trained in two stages: latent distillation pretrain and full finetuning (with GAN loss).
We train the latent-distillation encoder / decoder modules in this stage, keeping the image encoder / decoder fixed.
set -x TRAIN_DATA_DIR $IMAGENET100_DIR # Set to $IMAGENET_DIR or some other dataset to change the training dataset.
bash run_scripts/latent_distillation_pretrain.sh
Reference guide for adaptive tokenizer arguments:
- --base_tokenizer selects the 2D image tokenizer; current options are vqgan or vae.
- --model selects the adaptive tokenizer configuration. Options: alit_tiny | alit_small | alit_base | alit_semilarge. Note: our semilarge is smaller than the usual ViT large with 24 layers.
- --quantize_latent quantizes the learned 1D tokens before decoding (this helps create compressed image representations).
- --factorize_latent performs feature dimension factorization of the learned 1D tokens before quantization. If --quantize_latent is set to True, --factorize_latent is set to True automatically.
- For the rest of the arguments, please refer to (and directly edit) the config files at adaptive_tokenizers/configs/adaptive_vqgan.yaml and adaptive_tokenizers/configs/adaptive_vae.yaml.
- See --output_dir for training logs and checkpoints.
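The interaction between the two latent flags can be illustrated with a small argparse sketch. This is not the repository's parser: the function names, the defaults, and the use of store_true flags are assumptions made for illustration. Only the flag names, the listed options, and the rule that quantization implies factorization come from the guide above.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the flags in the reference guide."""
    parser = argparse.ArgumentParser(description="Illustrative adaptive tokenizer arguments")
    # Defaults below are assumptions, not the repository's actual defaults.
    parser.add_argument("--base_tokenizer", choices=["vqgan", "vae"], default="vqgan")
    parser.add_argument(
        "--model",
        choices=["alit_tiny", "alit_small", "alit_base", "alit_semilarge"],
        default="alit_small",
    )
    parser.add_argument("--quantize_latent", action="store_true")
    parser.add_argument("--factorize_latent", action="store_true")
    return parser


def parse_tokenizer_args(argv=None) -> argparse.Namespace:
    """Parse arguments and apply the 'quantize implies factorize' rule."""
    args = build_parser().parse_args(argv)
    if args.quantize_latent:
        # Quantized latents are always factorized first, per the guide.
        args.factorize_latent = True
    return args


if __name__ == "__main__":
    print(parse_tokenizer_args(["--model", "alit_base", "--quantize_latent"]))
```

Under this sketch, passing --quantize_latent alone yields a namespace with both quantize_latent and factorize_latent set to True, while omitting both leaves the latents continuous and unfactorized.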
This stage performs full finetuning of the latent-distillation encoder / decoder and the image encoder / decoder with GAN losses.
bash run_scripts/full_finetuning.sh
- --finetune loads the checkpoint trained in the previous stage (set the argument to the corresponding path).
- See --output_dir for training logs and checkpoints.
The following command will load your trained model checkpoint for alit_small with the vqgan base tokenizer and quantize_latent=True.
python evaluate_rfid.py \
--model alit_small \
--base_tokenizer vqgan \
--quantize_latent \
--output_dir ./output_dir/full_finetuning/alit_small_vqgan_quantized_latents/ \
--ckpt ./output_dir/full_finetuning/alit_small_vqgan_quantized_latents/checkpoint-last.pth \
--data_path $TRAIN_DATA_DIR
To evaluate on a custom dataset different from the training set, simply change --data_path to your $CUSTOM_DATA_DIR. The code requires $CUSTOM_DATA_DIR to have a val folder containing the images to be evaluated.
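Before launching an evaluation, it can help to confirm that the dataset directory has the val folder the code expects. The helper below is a sketch and not part of the repository: the function name count_val_images and the set of accepted image suffixes are assumptions.

```python
from pathlib import Path

# Assumption: which file suffixes count as images; adjust for your dataset.
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}


def count_val_images(data_dir: str) -> int:
    """Return the number of image files found under <data_dir>/val.

    Raises FileNotFoundError if the val folder is missing.
    """
    val_dir = Path(data_dir) / "val"
    if not val_dir.is_dir():
        raise FileNotFoundError(f"expected a 'val' folder at {val_dir}")
    # Search recursively so both flat and per-class layouts are counted.
    return sum(
        1
        for path in val_dir.rglob("*")
        if path.is_file() and path.suffix.lower() in IMAGE_SUFFIXES
    )
```

Calling count_val_images on $CUSTOM_DATA_DIR before running evaluate_rfid.py gives an early, readable error if the folder layout is wrong.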
Create a folder with an arbitrary name inside assets/custom_images/ and place the custom images inside it. See assets/custom_images/birds for reference.
python evaluate_rfid.py \
--model alit_small \
--base_tokenizer vae \
--output_dir ./output_dir/custom_images/alit_small_vae_continuous_latents/ \
--ckpt adaptive_tokenizers/pretrained_models/imagenet100/alit_small_vae_continuous_latents.pth \
--data_path ./assets/custom_images/ \
--testing_custom_images
If you can provide compute resources for scaling Adaptive Tokenizers to larger datasets, bigger model sizes, extended training periods, or Video ALITs, please reach out! We are very excited by this direction.
To sample the minimum-length encoding for the input image:
(We currently support only "Reconstruction Loss < Threshold" as the automatic token selection criterion.)
min_length_embedding, _ = adaptive_tokenizer.encode(image_tensor, return_min_length_embedding=True) # default threshold=0.07 for reconstruction loss
To sample all encodings for the input image:
all_length_embeddings, _ = adaptive_tokenizer.encode(image_tensor, return_min_length_embedding=False)
See adaptive_tokenizer_demo.ipynb for detailed API calls.
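The "Reconstruction Loss < Threshold" criterion can be sketched in plain Python as picking the smallest token count whose reconstruction loss falls below the threshold. This is an illustration of the selection rule only, not the repository's implementation: the function name, the list-based inputs, and returning None when no length qualifies are all assumptions. The 0.07 default is taken from the comment in the encode call above.

```python
from typing import Optional, Sequence


def select_min_token_count(
    token_counts: Sequence[int],
    recon_losses: Sequence[float],
    threshold: float = 0.07,
) -> Optional[int]:
    """Return the smallest token count whose loss is strictly below threshold.

    token_counts[i] is a candidate encoding length and recon_losses[i] is the
    reconstruction loss measured at that length. Returns None if no candidate
    satisfies the criterion (an assumption; the repository may fall back
    differently).
    """
    if len(token_counts) != len(recon_losses):
        raise ValueError("token_counts and recon_losses must have equal length")
    passing = [n for n, loss in zip(token_counts, recon_losses) if loss < threshold]
    return min(passing) if passing else None


if __name__ == "__main__":
    # Made-up losses for illustration; real values come from the tokenizer.
    counts = [32, 64, 96, 128]
    losses = [0.15, 0.09, 0.06, 0.04]
    print(select_min_token_count(counts, losses))
```

Lowering the threshold trades longer encodings for better reconstructions; raising it does the opposite.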
If you use our code or the paper, please consider citing the following:
@article{duggal2024adaptivetokenizer,
author = {Shivam Duggal and Phillip Isola and Antonio Torralba and William T. Freeman},
title = {Adaptive Length Image Tokenization via Recurrent Allocation},
journal= {arxiv},
year = {2024}
}