- [2024.12.18] 🔥🔥🔥 The checkpoints are available here!
- [2024.12.18] Code is available now! Watch 👀 this repository for the latest updates.
- [2024.12.17] The paper has been published on arXiv 🎉. The PDF version is available here!
GRAM learns and then aligns modalities directly in the higher-dimensional space in which the modality embeddings lie. It does so by minimizing the Gramian volume of the k-dimensional parallelotope spanned by the k modality vectors (the square root of the determinant of their Gram matrix), which ensures the geometric alignment of all modalities simultaneously.
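To make the volume concrete, here is a minimal PyTorch sketch of the Gramian volume for one set of k modality embeddings. It assumes each embedding is L2-normalized first, so that the volume depends only on the angles between modalities; `gram_volume` is an illustrative name, not a function from this repository. For k = 2 the volume reduces to |sin θ|, so perfectly aligned vectors give a volume of 0.

```python
import torch
import torch.nn.functional as F


def gram_volume(vectors: torch.Tensor) -> torch.Tensor:
    """Volume of the parallelotope spanned by k modality vectors.

    vectors: tensor of shape (..., k, d), one row per modality embedding.
    Returns a tensor of shape (...,).
    """
    v = F.normalize(vectors, dim=-1)   # assumption of this sketch: unit-norm embeddings
    gram = v @ v.transpose(-1, -2)     # (..., k, k) matrix of pairwise inner products
    det = torch.linalg.det(gram)       # det(G) equals the squared volume
    return det.clamp_min(0.0).sqrt()   # clamp guards against tiny negative round-off
```

Lower volume means the k modalities point in more similar directions, which is why minimizing it aligns all of them at once rather than pair by pair.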
GRAM can replace cosine similarity in any downstream method, holds for any number of modalities from 2 to n, and provides a more meaningful alignment than previous similarity measures. Moreover, the novel GRAM-based contrastive loss function enhances the alignment of multimodal models in the higher-dimensional embedding space, leading to new state-of-the-art performance in downstream tasks such as video-audio-text retrieval and audio-video classification.
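The idea of swapping cosine similarity for the volume inside a contrastive loss can be sketched as a symmetric InfoNCE objective whose logits are the negated volumes divided by a temperature. This is a minimal sketch under stated assumptions, not the loss implemented in this repository: it assumes text is the anchor modality, that the i-th anchor should be matched with the other modalities of the i-th sample, and that the temperature is 0.07. The function names are hypothetical.

```python
from typing import List

import torch
import torch.nn.functional as F


def pairwise_gram_volume(anchor: torch.Tensor, others: List[torch.Tensor],
                         eps: float = 1e-8) -> torch.Tensor:
    """Entry [i, j] = volume spanned by anchor[i] and the other modalities of sample j.

    anchor: (B, d) embeddings of the anchor modality (assumed here to be text).
    others: list of k-1 tensors, each (B, d), for the remaining modalities.
    """
    B = anchor.shape[0]
    a = F.normalize(anchor, dim=-1)[:, None, None, :].expand(B, B, 1, -1)  # (B, B, 1, d)
    o = torch.stack([F.normalize(m, dim=-1) for m in others], dim=1)      # (B, k-1, d)
    o = o[None].expand(B, -1, -1, -1)                                      # (B, B, k-1, d)
    v = torch.cat([a, o], dim=2)                                           # (B, B, k, d)
    gram = v @ v.transpose(-1, -2)                                         # (B, B, k, k)
    # Clamp to eps (not 0) so the sqrt keeps a finite gradient for degenerate sets.
    return torch.linalg.det(gram).clamp_min(eps).sqrt()


def gram_contrastive_loss(anchor: torch.Tensor, others: List[torch.Tensor],
                          temperature: float = 0.07) -> torch.Tensor:
    """Symmetric InfoNCE-style loss using negative volume in place of cosine similarity."""
    vol = pairwise_gram_volume(anchor, others)        # small volume = well aligned
    logits = -vol / temperature                       # matching tuples get the largest logit
    targets = torch.arange(anchor.shape[0], device=anchor.device)
    loss_rows = F.cross_entropy(logits, targets)      # anchor i vs. all candidate tuples j
    loss_cols = F.cross_entropy(logits.t(), targets)  # tuple j vs. all candidate anchors i
    return 0.5 * (loss_rows + loss_cols)
```

The B×B×k×d intermediate grows quadratically with the batch size, so this sketch favors readability over memory efficiency.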
An aligned shared latent space among n modalities is a strong baseline for any downstream task that relies on embedding extraction. The results of this paper will not only lead to superior performance in existing downstream tasks (T2I, T2V, V2A, etc.) but also unlock new tasks such as image-to-audio generation or image generation conditioned on both text and audio.
GRAM is implemented in PyTorch. We use Python 3.9 and CUDA 11.7; other versions may also be compatible. The other required packages are listed in preinstall.sh.
```bash
conda create -n gram python=3.9
conda activate gram
sh preinstall.sh
```
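Since the setup above was tested with Python 3.9 and CUDA 11.7, a quick check of the active environment can save a failed run. This is a minimal sketch; `environment_report` is a hypothetical helper, not part of the repository, and it only reports versions without requiring PyTorch to be installed.

```python
import importlib.util
import sys

REFERENCE = {"python": "3.9", "cuda": "11.7"}  # versions stated in this README


def environment_report() -> dict:
    """Collect Python / PyTorch / CUDA versions; torch is optional."""
    report = {"python": "{}.{}".format(*sys.version_info[:2]), "torch": None, "cuda": None}
    if importlib.util.find_spec("torch") is not None:
        import torch
        report["torch"] = torch.__version__
        report["cuda"] = torch.version.cuda  # None on CPU-only builds
    report["matches_reference"] = all(report[k] == v for k, v in REFERENCE.items())
    return report


if __name__ == "__main__":
    for key, value in environment_report().items():
        print(f"{key}: {value}")
```

A `matches_reference` value of False is not necessarily a problem, since other versions may also work.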
Create a directory named pretrained_weights under the main working directory, then download the following weights into it:
- EVA-CLIP weights:
  ```bash
  wget -P pretrained_weights/clip/ https://huggingface.co/QuanSun/EVA-CLIP/resolve/main/EVA01_CLIP_g_14_psz14_s11B.pt
  ```
- BEATs weights: download BEATs_iter3_plus_AS2M.pt from https://github.com/microsoft/unilm/tree/master/beats and place it in pretrained_weights/beats/.
- BERT weights:
  ```python
  from transformers import BertModel, BertTokenizer

  # Fetch bert-base-uncased from the Hugging Face Hub and save a local copy
  bert = BertModel.from_pretrained('bert-base-uncased')
  bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
  bert.save_pretrained('pretrained_weights/bert/bert-base-uncased')
  bert_tokenizer.save_pretrained('pretrained_weights/bert/bert-base-uncased')
  ```
The resulting pretrained_weights directory should look as follows:
```
├── pretrained_weights
│   ├── beats
│   │   └── BEATs_iter3_plus_AS2M.pt
│   ├── bert
│   │   └── bert-base-uncased
│   ├── clip
│   │   └── EVA01_CLIP_g_14_psz14_s11B.pt
```
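A minimal sketch for verifying this layout before launching training. `check_pretrained_weights` is a hypothetical helper, not part of the repository; the expected paths are copied from the tree above, with bert-base-uncased treated as a directory (it is what `save_pretrained` writes).

```python
from pathlib import Path
from typing import List

# Expected entries, relative to pretrained_weights/, taken from the tree above
EXPECTED = {
    "beats/BEATs_iter3_plus_AS2M.pt": "file",
    "bert/bert-base-uncased": "dir",
    "clip/EVA01_CLIP_g_14_psz14_s11B.pt": "file",
}


def check_pretrained_weights(root: str = "pretrained_weights") -> List[str]:
    """Return the expected entries that are missing (an empty list means the layout is OK)."""
    base = Path(root)
    missing = []
    for rel, kind in EXPECTED.items():
        path = base / rel
        ok = path.is_file() if kind == "file" else path.is_dir()
        if not ok:
            missing.append(rel)
    return missing


if __name__ == "__main__":
    problems = check_pretrained_weights()
    print("pretrained_weights OK" if not problems else f"Missing: {problems}")
```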
All models are available here!
Name | Training dataset | Test dataset | R@1 on test dataset | Link |
---|---|---|---|---|
GRAM_pretrained_5modalities | VAST-27M 150k subset (TVAS) | MSRVTT | 54.8 | link |
GRAM_pretrained_4modalities | VAST-27M 150k subset (TVASD) | MSRVTT | 55.3 | link |
GRAM_finetuned_MSRVTT | MSRVTT | MSRVTT | 64.0 | link |
GRAM_finetuned_DIDEMO | DiDeMo | DiDeMo | 67.3 | link |
GRAM_finetuned_ANET | ActivityNet | ActivityNet | 69.9 | link |
GRAM_finetuned_VATEX | VATEX | VATEX | 87.7 | link |
Download the entire folder, which contains two subfolders, "log" and "ckpt". Place it wherever you prefer and note its location for the commands below.
For example, after the download the paths could look as follows:
```
├── pretrained_models
│   ├── GRAM_pretrained_4modalities
│   │   ├── log
│   │   └── ckpt
```
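Before passing a downloaded checkpoint folder as `--pretrain_dir`, it can help to confirm that it has the "log" and "ckpt" subfolders described above. A minimal sketch; `resolve_pretrain_dir` is a hypothetical helper, not part of the repository, and it does not inspect the files inside ckpt.

```python
from pathlib import Path


def resolve_pretrain_dir(path: str) -> Path:
    """Return the absolute checkpoint folder, checking that log/ and ckpt/ exist."""
    root = Path(path).expanduser().resolve()
    missing = [name for name in ("log", "ckpt") if not (root / name).is_dir()]
    if missing:
        raise FileNotFoundError(f"{root} is missing subfolder(s): {', '.join(missing)}")
    return root


if __name__ == "__main__":
    # Example path from the layout above; adjust to where you placed the folder
    print(resolve_pretrain_dir("pretrained_models/GRAM_pretrained_4modalities"))
```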
The VAST-27M dataset can be downloaded by following the instructions in its official repository.
We used a subset of VAST-27M for the pretraining phase of GRAM. This is the annotation file used here.
Download the annotations150k.json subset file, reference it in scripts/gram/finetune_ret.sh and in config/gram/finetune_cfg/finetune-area.json, and then run:
```bash
sh scripts/gram/finetune_ret.sh
```
To run a different configuration, edit the settings inside scripts/gram/finetune_ret.sh and then run:
```bash
sh scripts/gram/finetune_ret.sh
```
For example, suppose the command for finetuning a retrieval model is the following (replace the placeholder paths with your own):
```bash
python3 -m torch.distributed.launch \
--nnodes 1 \
--node_rank 0 \
--nproc_per_node 8 \
--master_port 9834 \
./run.py \
--learning_rate 2e-5 \
--checkpointing true \
--first_eval true \
--save_best true \
--config ./config/gram/finetune_cfg/retrieval-msrvtt.json \
--pretrain_dir /PATH/TO/CKPT_FOLDER \
--output_dir /PATH/WHERE/TO/STORE/RESULTS
```
To test a model, append the following two arguments to the command, making sure the preceding line ends with a line-continuation backslash:
```bash
--mode 'testing' \
--checkpoint /PATH/TO/SAVED_CHECKPOINT.pt
```
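If you script your runs, the same command can be assembled in Python. This is a minimal sketch; `build_gram_command` is a hypothetical helper that only uses the flags and values shown in the command above, and the placeholder paths must be replaced with your own. Passing `checkpoint` appends the two testing arguments.

```python
import shlex
from typing import List, Optional


def build_gram_command(config: str, pretrain_dir: str, output_dir: str,
                       nproc_per_node: int = 8, master_port: int = 9834,
                       checkpoint: Optional[str] = None) -> List[str]:
    """Assemble the distributed launch command as an argv list."""
    cmd = [
        "python3", "-m", "torch.distributed.launch",
        "--nnodes", "1",
        "--node_rank", "0",
        "--nproc_per_node", str(nproc_per_node),
        "--master_port", str(master_port),
        "./run.py",
        "--learning_rate", "2e-5",
        "--checkpointing", "true",
        "--first_eval", "true",
        "--save_best", "true",
        "--config", config,
        "--pretrain_dir", pretrain_dir,
        "--output_dir", output_dir,
    ]
    if checkpoint is not None:
        # Testing mode: same command plus the two extra arguments
        cmd += ["--mode", "testing", "--checkpoint", checkpoint]
    return cmd


if __name__ == "__main__":
    cmd = build_gram_command(
        "./config/gram/finetune_cfg/retrieval-msrvtt.json",
        "/PATH/TO/CKPT_FOLDER", "/PATH/WHERE/TO/STORE/RESULTS",
        checkpoint="/PATH/TO/SAVED_CHECKPOINT.pt")
    print(shlex.join(cmd))  # copy-pasteable shell command
```

To launch it from the repository root instead of printing it, pass the list to `subprocess.run(cmd, check=True)`; using an argv list avoids shell-quoting issues with paths.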
If you find this code useful for your research, please consider citing the following paper:
```bibtex
@misc{cicchetti2024gramianmultimodalrepresentationlearning,
      title={Gramian Multimodal Representation Learning and Alignment},
      author={Giordano Cicchetti and Eleonora Grassucci and Luigi Sigillo and Danilo Comminiello},
      year={2024},
      eprint={2412.11959},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2412.11959},
}
```
For the full list of third-party licenses used in this project, please see the THIRD_PARTY_LICENSES.md file.