We have put a large emphasis on making training as fast as possible. Consequently, some pre-processing steps are required.
Namely, before starting any training, we:
- Encode training audio into spectrograms, then with the VAE into mean/std latents
- Extract CLIP and synchronization features from videos
- Extract CLIP features from text (captions)
- Store all extracted features as MemoryMappedTensors with TensorDict
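Storing the VAE output as mean/std, rather than as a single encoded latent, presumably lets a fresh latent be drawn each time a clip is loaded, using the usual reparameterization z = mean + std * eps. The sketch below illustrates that idea with plain Python lists. It assumes the stored values are per-element standard deviations (not log-variances), and sample_latent is a hypothetical helper, not part of the training code.

```python
import random
from typing import List, Optional


def sample_latent(mean: List[float], std: List[float],
                  rng: Optional[random.Random] = None) -> List[float]:
    """Draw one latent z = mean + std * eps with eps ~ N(0, 1).

    Hypothetical helper; assumes `std` holds standard deviations
    (not log-variances) and has the same length as `mean`.
    """
    if len(mean) != len(std):
        raise ValueError("mean and std must have the same length")
    rng = rng or random.Random()
    # Independent standard-normal noise per element (reparameterization trick).
    return [m + s * rng.gauss(0.0, 1.0) for m, s in zip(mean, std)]


# Usage: two draws from the same stored posterior generally differ.
mean, std = [0.1, -0.3, 0.5], [0.05, 0.05, 0.05]
z1 = sample_latent(mean, std, random.Random(0))
z2 = sample_latent(mean, std, random.Random(1))
```

With std set to zero, the sample collapses to the mean, which is what a deterministic encoder would return.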
NOTE: For maximum training speed (e.g., when training the base model on two H100 GPUs), you would need around 3 to 5 GB/s of random read throughput. Spinning disks cannot keep up, and most consumer-grade SSDs would struggle. In my experience, the best option is to have enough system memory for the OS to cache the data, so that it is read from RAM instead of disk.
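To get a rough idea of whether a given disk reaches the 3 to 5 GB/s mentioned above, one can time random block reads from a large file, such as one of the extracted memory-mapped files. This is a minimal standard-library sketch (os.pread is Unix-only); the function name, block size, and read count are illustrative choices, not part of the training code. Because reads go through the OS page cache, point it at a file that is larger than system RAM, or that has not been read recently, to measure the disk rather than RAM.

```python
import os
import random
import time


def random_read_throughput(path: str, block_size: int = 1 << 20,
                           num_reads: int = 256, seed: int = 0) -> float:
    """Return the observed random-read throughput of `path` in GB/s.

    Reads `num_reads` blocks of `block_size` bytes at random block-aligned
    offsets. Cached data will report RAM-like speeds.
    """
    size = os.path.getsize(path)
    if size < block_size:
        raise ValueError("file is smaller than one block")
    rng = random.Random(seed)
    num_blocks = size // block_size
    total = 0
    fd = os.open(path, os.O_RDONLY)
    try:
        start = time.perf_counter()
        for _ in range(num_reads):
            offset = rng.randrange(num_blocks) * block_size
            total += len(os.pread(fd, block_size, offset))
        elapsed = max(time.perf_counter() - start, 1e-9)  # avoid division by zero
    finally:
        os.close(fd)
    return total / elapsed / 1e9
```

Usage: print(f"{random_read_throughput('/path/to/large_file'):.2f} GB/s"), where the path is a placeholder.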
The current training script does not support _v2 training.
Install av-benchmark. We use this library to evaluate automatically on the validation set during training and on the test set after training. Use av-benchmark to extract evaluation features for the validation and test sets; these serve as the validation cache and the test cache, respectively. You can also download the precomputed evaluation cache here.
You will also need ffmpeg for video frame extraction. Note that torchaudio imposes a maximum version limit (ffmpeg<7). You can install it as follows:
conda install -c conda-forge 'ffmpeg<7'
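Before extracting features, it may help to confirm that the ffmpeg on PATH satisfies the <7 constraint. This is a small sketch that assumes the standard `ffmpeg -version` banner format (for example "ffmpeg version 6.1.1 Copyright ..."). Development builds that report versions such as "N-12345-g..." are not parsed. Both function names are hypothetical.

```python
import re
import shutil
import subprocess
from typing import Optional


def ffmpeg_major_version(version_output: str) -> Optional[int]:
    """Parse the major version from `ffmpeg -version` output, or return None."""
    match = re.search(r"ffmpeg version n?(\d+)\.", version_output)
    return int(match.group(1)) if match else None


def check_ffmpeg() -> None:
    """Raise RuntimeError if ffmpeg is missing, unparsable, or version 7+."""
    exe = shutil.which("ffmpeg")
    if exe is None:
        raise RuntimeError("ffmpeg not found on PATH")
    out = subprocess.run([exe, "-version"], capture_output=True,
                         text=True, check=True).stdout
    major = ffmpeg_major_version(out)
    if major is None:
        raise RuntimeError("could not parse the ffmpeg version")
    if major >= 7:
        raise RuntimeError(f"ffmpeg {major}.x found; torchaudio needs ffmpeg<7")
```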
Download the corresponding VAE (v1-16.pth for 16kHz training and v1-44.pth for 44.1kHz training), the vocoder model (best_netG.pt for 16kHz training; the vocoder for 44.1kHz training is downloaded automatically), the empty string encoding, and the Synchformer weights from MODELS.md, and place them in ext_weights/.
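A quick way to confirm that the downloads landed in the right place is to check for the expected files. This sketch only covers the file names given above. The empty string encoding and the Synchformer weights are left out because their file names are not stated here, so consult MODELS.md for those. missing_weights is a hypothetical helper.

```python
from pathlib import Path
from typing import List

# Only the file names mentioned in this README; see MODELS.md for the rest.
REQUIRED_WEIGHTS = {
    16000: ["v1-16.pth", "best_netG.pt"],
    44100: ["v1-44.pth"],  # the 44.1kHz vocoder is downloaded automatically
}


def missing_weights(sample_rate: int, weights_dir: str = "ext_weights") -> List[str]:
    """Return the required weight files that are absent from `weights_dir`."""
    if sample_rate not in REQUIRED_WEIGHTS:
        raise ValueError(f"unsupported sample rate: {sample_rate}")
    root = Path(weights_dir)
    return [name for name in REQUIRED_WEIGHTS[sample_rate]
            if not (root / name).is_file()]
```

Usage: print(missing_weights(16000)) prints an empty list once everything for 16kHz training is in place.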
We have prepared some example data in training/example_videos.
Running the training/extract_video_training_latents.py script will extract the audio, video, and text features and save them to disk as a TensorDict, along with a .tsv file containing metadata.
To run this script, use the torchrun utility:
torchrun --standalone training/extract_video_training_latents.py
You can run this with multiple GPUs (with --nproc_per_node=<n>) to speed up extraction.
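torchrun launches one process per GPU and sets the RANK and WORLD_SIZE environment variables in each of them. A common way for such a script to split its work is to let each rank take every WORLD_SIZE-th item. The sketch below shows that pattern. It is an assumption about how work can be divided, not a description of how extract_video_training_latents.py actually shards its data, and shard_for_this_rank is a hypothetical name.

```python
import os
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def shard_for_this_rank(items: Sequence[T]) -> List[T]:
    """Return the strided subset of `items` assigned to the current process.

    Reads RANK and WORLD_SIZE as set by torchrun and falls back to a single
    process when they are absent (e.g. when run with plain `python`).
    """
    rank = int(os.environ.get("RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    if not 0 <= rank < world_size:
        raise ValueError(f"invalid RANK={rank} for WORLD_SIZE={world_size}")
    return list(items[rank::world_size])
```

Striding, rather than contiguous chunks, keeps the per-rank workload balanced even when the item count is not divisible by the number of GPUs.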
Check the top of the script to switch between 16kHz/44.1kHz extraction and data path definitions.
Arguments:
- latent_dir -- where intermediate latent outputs are saved. It is safe to delete this directory afterwards.
- output_dir -- where the TensorDict and the metadata file are saved.
We have prepared some example data in training/example_audios.
We first need to run training/partition_clips.py to partition each audio file into clips.
Then, we run the training/extract_audio_training_latents.py script, which extracts the audio and text features and saves them to disk as a TensorDict, along with a .tsv file containing metadata.
To partition the audio files, run:
python training/partition_clips.py
Arguments:
- data_path -- path to the audio files (.flac or .wav)
- output_dir -- path to the output .csv file
- start -- optional; useful when you need to run multiple processes to speed up processing. This defines the beginning of the chunk to be processed.
- end -- optional; useful when you need to run multiple processes to speed up processing. This defines the end of the chunk to be processed.
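When you split partition_clips.py across several processes with start and end, each process needs a non-overlapping chunk. This minimal sketch computes such ranges. It assumes start and end index into the list of audio files with end exclusive, so check the script to confirm its exact convention. chunk_bounds is a hypothetical helper, and how the values are passed on the command line is left to the script's own argument parser.

```python
from typing import List, Tuple


def chunk_bounds(num_files: int, num_procs: int) -> List[Tuple[int, int]]:
    """Split [0, num_files) into `num_procs` contiguous (start, end) ranges.

    The first `num_files % num_procs` chunks get one extra file, so chunk
    sizes differ by at most one. `end` is exclusive.
    """
    if num_procs <= 0:
        raise ValueError("num_procs must be positive")
    base, extra = divmod(num_files, num_procs)
    bounds = []
    start = 0
    for i in range(num_procs):
        end = start + base + (1 if i < extra else 0)
        bounds.append((start, end))
        start = end
    return bounds


# Usage: 10 files over 3 processes.
print(chunk_bounds(10, 3))  # [(0, 4), (4, 7), (7, 10)]
```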
Then, run extract_audio_training_latents.py with torchrun:
torchrun --standalone training/extract_audio_training_latents.py
You can run this with multiple GPUs (with --nproc_per_node=<n>) to speed up extraction.
Check the top of the script to switch between 16kHz/44.1kHz extraction.
Arguments:
- data_dir -- path to the audio files (.flac or .wav), same as in the previous step
- captions_tsv -- path to the captions file, a csv file with at least the columns id and caption
- clips_tsv -- path to the clips file generated in the previous step
- latent_dir -- where intermediate latent outputs are saved. It is safe to delete this directory afterwards.
- output_dir -- where the TensorDict and the metadata file are saved.
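A quick check that the captions file has the required id and caption columns can prevent a failed extraction run. This sketch uses Python's csv module. The tab delimiter is an assumption suggested by the argument name captions_tsv, so pass delimiter="," for a comma-separated file. check_captions is a hypothetical helper.

```python
import csv
from typing import Dict


def check_captions(path: str, delimiter: str = "\t") -> Dict[str, str]:
    """Load an id -> caption mapping after validating the required columns.

    Extra columns are ignored. Raises ValueError if `id` or `caption`
    is missing from the header row.
    """
    with open(path, newline="", encoding="utf-8") as f:
        reader = csv.DictReader(f, delimiter=delimiter)
        missing = {"id", "caption"} - set(reader.fieldnames or [])
        if missing:
            raise ValueError(f"missing columns: {sorted(missing)}")
        return {row["id"]: row["caption"] for row in reader}
```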
The reference tsv files (with overlaps removed, as mentioned in the paper) can be found here. Note that audioset_sl.tsv, bbcsound.tsv, and freesound.tsv are subsets of WavCaps. These subsets might be smaller than the original datasets.
We use Distributed Data Parallel (DDP) for training.
First, specify the data paths in config/data/base.yaml. If you used the default parameters in the scripts above to extract features for the example data, the Example_video and Example_audio entries should already be correct.
To run training on the example data, use the following command:
OMP_NUM_THREADS=4 torchrun --standalone --nproc_per_node=1 train.py exp_id=debug compile=False debug=True example_train=True batch_size=1
This will not train a useful model, but it will check if everything is set up correctly.
For full training of the base model on two GPUs, use the following command:
OMP_NUM_THREADS=4 torchrun --standalone --nproc_per_node=2 train.py exp_id=exp_1 model=small_16k
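For scripted runs, such as several experiments in sequence, the launch command above can be assembled programmatically. This sketch only reproduces the flags and overrides shown in this section. build_train_command is a hypothetical helper, and any further overrides should be taken from config/base_config.yaml and config/train_config.yaml.

```python
import os
import subprocess  # used in the commented-out usage below
from typing import Dict, List, Optional, Tuple


def build_train_command(exp_id: str, nproc: int = 2,
                        overrides: Optional[Dict[str, str]] = None
                        ) -> Tuple[List[str], Dict[str, str]]:
    """Return (argv, env) for launching train.py through torchrun."""
    argv = ["torchrun", "--standalone", f"--nproc_per_node={nproc}",
            "train.py", f"exp_id={exp_id}"]
    # Config overrides are passed as key=value arguments, as in the commands above.
    argv += [f"{key}={value}" for key, value in (overrides or {}).items()]
    env = dict(os.environ, OMP_NUM_THREADS="4")
    return argv, env


# Usage (actually launches training; requires the repository and GPUs):
# argv, env = build_train_command("exp_1", overrides={"model": "small_16k"})
# subprocess.run(argv, env=env, check=True)
```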
Any outputs from training will be stored in output/<exp_id>.
More configuration options can be found in config/base_config.yaml and config/train_config.yaml.
Model checkpoints, including optimizer states and the latest EMA weights, are available here: https://huggingface.co/hkchengrex/MMAudio
Godspeed!