S3Tokenizer was first introduced in CosyVoice [Paper] [Repo]. It is a Supervised Semantic Speech Tokenizer built on the pre-trained SenseVoice-Large model. It makes the extracted tokens more closely related to textual and paralinguistic information, is robust to data noise, and reduces reliance on clean data collection, so a broader range of data can be used for model training.
However, as noted in this [issue], the authors do not intend to open-source the PyTorch implementation of S3Tokenizer and only plan to release an ONNX file. In addition, users who want to fine-tune CosyVoice must extract speech codes offline with a batch size of 1, which is very time-consuming (see [cosyvoice/tools/extract_speech_token.py]).
This repository reverse-engineers S3Tokenizer and offers:
- A pure PyTorch implementation of S3Tokenizer (see [model.py]), compatible with initializing weights from the released ONNX file (see [utils.py::onnx2torch()]).
- High-throughput (distributed) batch inference, achieving a ~790x speedup compared to the original inference pipeline in [cosyvoice/tools/extract_speech_token.py].
- The capability to perform online speech code extraction during SpeechLLM training.
```sh
pip install s3tokenizer
```
```python
import s3tokenizer

tokenizer = s3tokenizer.load_model("speech_tokenizer_v1").cuda()  # or "speech_tokenizer_v1_25hz"

mels = []
wav_paths = ["s3tokenizer/assets/BAC009S0764W0121.wav", "s3tokenizer/assets/BAC009S0764W0122.wav"]
for wav_path in wav_paths:
    audio = s3tokenizer.load_audio(wav_path)
    mels.append(s3tokenizer.log_mel_spectrogram(audio))
mels, mels_lens = s3tokenizer.padding(mels)
codes, codes_lens = tokenizer.quantize(mels.cuda(), mels_lens.cuda())

for i in range(len(wav_paths)):
    print(codes[i, :codes_lens[i].item()])
```
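Batching works because variable-length mel spectrograms are zero-padded to a common length, and the returned lengths let each row be trimmed back afterwards (which is what `codes[i, :codes_lens[i]]` does above). The sketch below is a hypothetical re-implementation of that padding idea, `pad_mels`, assuming each mel is a `(n_mels, T)` tensor; it is not the library's own `s3tokenizer.padding`.

```python
import torch
from torch.nn.utils.rnn import pad_sequence


def pad_mels(mels):
    """Hypothetical stand-in for s3tokenizer.padding.

    Assumes every element of `mels` is a (n_mels, T_i) tensor and returns
    a zero-padded (B, n_mels, T_max) batch plus the original lengths.
    """
    lens = torch.tensor([m.shape[-1] for m in mels], dtype=torch.long)
    # pad_sequence pads along dim 0, so put time first: (T_i, n_mels)
    batch = pad_sequence([m.transpose(0, 1) for m in mels], batch_first=True)
    # (B, T_max, n_mels) -> (B, n_mels, T_max)
    return batch.transpose(1, 2), lens


mels = [torch.randn(128, 5), torch.randn(128, 3)]
padded, lens = pad_mels(mels)
print(padded.shape, lens.tolist())  # torch.Size([2, 128, 5]) [5, 3]

# Slicing each row by its length recovers the original, unpadded input.
for i, m in enumerate(mels):
    assert torch.equal(padded[i, :, : lens[i].item()], m)
```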
```sh
s3tokenizer --wav_scp xxx.scp \
            --device "cpu" \
            --output_dir "./" \
            --batch_size 32 \
            --model "speech_tokenizer_v1"  # or "speech_tokenizer_v1_25hz"
```
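The `--wav_scp` argument points to a list of audio files. Assuming it follows the Kaldi-style `wav.scp` convention of one `<utterance-id> <path>` pair per line (an assumption; the exact format the CLI accepts is not spelled out here), such a file could be built and read back as follows. `write_wav_scp` and `read_wav_scp` are hypothetical helpers, not part of s3tokenizer.

```python
import tempfile
from pathlib import Path


def write_wav_scp(wav_paths, scp_path):
    """Write one '<utt_id> <path>' line per file (assumed Kaldi-style format).

    The utterance ID is the file name without its extension.
    """
    with open(scp_path, "w", encoding="utf-8") as f:
        for p in wav_paths:
            f.write(f"{Path(p).stem} {p}\n")


def read_wav_scp(scp_path):
    """Parse '<utt_id> <path>' lines into a dict, skipping blank lines."""
    entries = {}
    with open(scp_path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            utt_id, path = line.split(maxsplit=1)
            entries[utt_id] = path
    return entries


wavs = [
    "s3tokenizer/assets/BAC009S0764W0121.wav",
    "s3tokenizer/assets/BAC009S0764W0122.wav",
]
with tempfile.TemporaryDirectory() as tmp:
    scp = Path(tmp) / "wav.scp"
    write_wav_scp(wavs, scp)
    parsed = read_wav_scp(scp)
print(parsed)
```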
```sh
torchrun --nproc_per_node=8 --nnodes=1 \
         --rdzv_id=2024 --rdzv_backend="c10d" --rdzv_endpoint="localhost:0" \
         `which s3tokenizer` --wav_scp xxx.scp \
                             --device "cuda" \
                             --output_dir "./" \
                             --batch_size 32 \
                             --model "speech_tokenizer_v1"  # or "speech_tokenizer_v1_25hz"
```
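Under `torchrun`, every process gets its own `RANK` and `WORLD_SIZE` environment variables, and each process needs a disjoint slice of the input list. How the CLI actually splits the work is not described here; the round-robin `shard_entries` helper below is a hypothetical illustration of one common scheme.

```python
import os


def shard_entries(entries, rank, world_size):
    """Round-robin split: process `rank` takes items rank, rank + world_size, ..."""
    if not 0 <= rank < world_size:
        raise ValueError(f"rank {rank} out of range for world_size {world_size}")
    return entries[rank::world_size]


# torchrun sets RANK and WORLD_SIZE; fall back to a single process otherwise.
rank = int(os.environ.get("RANK", 0))
world_size = int(os.environ.get("WORLD_SIZE", 1))

utts = [f"utt{i}" for i in range(10)]
my_utts = shard_entries(utts, rank, world_size)

# With 8 processes, every utterance is assigned to exactly one of them.
shards = [shard_entries(utts, r, 8) for r in range(8)]
assert sorted(u for s in shards for u in s) == sorted(utts)
```

Round-robin keeps shard sizes within one item of each other even when the list length is not divisible by the number of processes.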
Method | Time on Aishell test set | Relative speedup | Miss rate |
---|---|---|---|
[cosyvoice/tools/extract_speech_token.py], CPU | 9 h | 1x (baseline) | N/A (reference) |
S3Tokenizer, CPU, batch size 32 | 1.5 h | ~6x | 0.76% |
S3Tokenizer, 4 GPUs (RTX 3090), batch size 32 per GPU | 41 s | ~790x | 0.76% |
The miss rate represents the proportion of tokens that are inconsistent between the batch inference predictions and the ONNX (batch=1) inference predictions.
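The speedups in the table follow directly from the reported times, and the miss rate can be computed by comparing token sequences position by position. In the sketch below the times are copied from the table; `miss_rate` is a hypothetical helper, and its treatment of length mismatches (extra or missing tokens count as misses) is an assumption rather than the exact rule used for the reported 0.76%.

```python
# Times from the table above, converted to seconds.
baseline_s = 9 * 3600      # extract_speech_token.py on CPU: 9 hours
cpu_batch_s = 1.5 * 3600   # CPU, batch size 32: 1.5 hours
gpu_batch_s = 41           # 4x RTX 3090, batch size 32 per GPU: 41 seconds

print(f"CPU speedup: {baseline_s / cpu_batch_s:.1f}x")  # CPU speedup: 6.0x
print(f"GPU speedup: {baseline_s / gpu_batch_s:.0f}x")  # GPU speedup: 790x


def miss_rate(batch_codes, reference_codes):
    """Fraction of tokens where batch predictions differ from the reference.

    Hypothetical helper: both arguments map utterance IDs to token lists
    (reference = ONNX, batch size 1). Tokens present in only one of the two
    sequences are counted as misses.
    """
    misses = total = 0
    for utt, ref in reference_codes.items():
        hyp = batch_codes.get(utt, [])
        misses += sum(a != b for a, b in zip(hyp, ref))
        misses += abs(len(ref) - len(hyp))
        total += max(len(ref), len(hyp))
    return misses / total if total else 0.0


print(miss_rate({"utt1": [1, 2, 3, 4]}, {"utt1": [1, 2, 0, 4]}))  # 0.25
```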
Before (extract code offline):

```python
from torch import Tensor, nn


class SpeechLLM(nn.Module):
    ...

    def __init__(self, ...):
        ...

    def forward(self, speech_codes: Tensor, text_ids: Tensor, ...):
        ...
```

After (extract code online):

```python
import s3tokenizer
from torch import Tensor, nn


class SpeechLLM(nn.Module):
    ...

    def __init__(self, ...):
        ...
        self.speech_tokenizer = s3tokenizer.load_model("speech_tokenizer_v1")  # or "speech_tokenizer_v1_25hz"
        self.speech_tokenizer.freeze()

    def forward(self, speech: Tensor, speech_lens: Tensor, text_ids: Tensor, ...):
        ...
        speech_codes, speech_codes_lens = self.speech_tokenizer.quantize(speech, speech_lens)
        speech_codes = speech_codes.clone()  # for backward compatibility
        speech_codes_lens = speech_codes_lens.clone()  # for backward compatibility
```
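A plausible reason for the two `.clone()` calls, assuming `quantize` runs under `torch.inference_mode()` (an assumption; this README does not say so): tensors created in inference mode cannot be saved for autograd, and PyTorch's documented workaround is to clone them outside inference mode, which yields a normal tensor. The toy `frozen_quantize` below is hypothetical and only stands in for the real tokenizer.

```python
import torch
from torch import nn


@torch.inference_mode()
def frozen_quantize(feats: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for a frozen tokenizer: pick the max bin per frame."""
    return feats.argmax(dim=-1)


feats = torch.randn(2, 7, 16)   # (batch, frames, bins)
codes = frozen_quantize(feats)  # created inside inference mode
print(codes.is_inference())     # True

codes = codes.clone()           # clone outside inference mode -> normal tensor
print(codes.is_inference())     # False

# The cloned codes can feed a trainable embedding and take part in backward().
embedding = nn.Embedding(16, 8)
loss = embedding(codes).sum()
loss.backward()
print(embedding.weight.grad.shape)  # torch.Size([16, 8])
```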
- Usage-1: Offline batch inference
- Usage-2: Distributed offline batch inference via command-line tools
- Usage-3: Online speech code extraction