This repository is a PyTorch implementation of Speculative Decoding / Speculative Sampling (Leviathan et al., 2023; Chen et al., 2023). It contains the code for three generation strategies: classic auto-regressive decoding, beam search decoding (with length penalty) and speculative decoding. Auto-regressive decoding and Speculative Decoding can be used with greedy decoding or with nucleus sampling (temperature, top-k and top-p).
Figure 1: Example of generation, comparing Speculative Decoding and Vanilla Decoding. The generated texts differ because sampling relies on a pseudo-random number generator.
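Greedy decoding always picks the highest-scoring token, while nucleus sampling reshapes the next-token distribution before drawing from it. The following is a generic, dependency-free sketch of temperature, top-k and top-p filtering on a plain list of logits. It is not the repository's `NucleusProcessor`; the function names are hypothetical, and applying top-k before top-p is an assumption about the order of the filters.

```python
import math
import random
from typing import List, Optional, Sequence


def nucleus_probs(
    logits: Sequence[float],
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    top_p: float = 1.0,
) -> List[float]:
    """Turn raw logits into a filtered, renormalized next-token distribution."""
    if temperature <= 0:
        raise ValueError("temperature must be > 0 (use greedy decoding for T = 0)")
    scaled = [x / temperature for x in logits]
    # Token ids sorted by decreasing score.
    order = sorted(range(len(scaled)), key=lambda i: scaled[i], reverse=True)
    if top_k is not None:
        order = order[:top_k]  # top-k: keep only the k highest-scoring tokens
    m = scaled[order[0]]  # subtract the max for a numerically stable softmax
    weights = {i: math.exp(scaled[i] - m) for i in order}
    total = sum(weights.values())
    # top-p: keep the smallest prefix whose cumulative probability reaches top_p.
    kept: List[int] = []
    cum = 0.0
    for i in order:
        kept.append(i)
        cum += weights[i] / total
        if cum >= top_p:
            break
    norm = sum(weights[i] for i in kept)
    probs = [0.0] * len(logits)
    for i in kept:
        probs[i] = weights[i] / norm
    return probs


def greedy_token(logits: Sequence[float]) -> int:
    """Greedy decoding: the id of the highest logit."""
    return max(range(len(logits)), key=lambda i: logits[i])


def sample_token(probs: Sequence[float], rng: random.Random) -> int:
    """Draw one token id from a probability vector."""
    return rng.choices(range(len(probs)), weights=probs)[0]
```

A lower temperature sharpens the distribution, and a small `top_p` such as 0.5 can leave a single candidate, at which point sampling behaves like greedy decoding.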
Speculative Decoding is a decoding strategy for transformers that generates sequences faster than classic auto-regressive decoding, without changing the output distribution or requiring further fine-tuning. It uses a smaller, more efficient approximation model (called a "drafter") to generate speculative token prefixes. These prefixes are then evaluated in parallel by the larger target model, which reduces the number of serial decoding steps and leads to inference speedups.
The core process relies on a specific property of Transformer models: a single forward pass computes the next-token probability distribution at every position of the input sequence. The target model can therefore score all the drafted tokens at once, and these distributions are used to verify the drafts generated by the drafter model.
Figure 2: Overview of Speculative Decoding.
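The verification step can be sketched with the rejection-sampling rule of Leviathan et al. (2023) and Chen et al. (2023): a draft token x is accepted with probability min(1, p(x) / q(x)), where p is the target distribution and q the drafter distribution at that position. At the first rejection, a replacement token is drawn from the normalized residual max(0, p - q) and the remaining drafts are discarded; if every draft is accepted, one extra token is sampled from the target's distribution after the last draft. This is a minimal sketch on plain probability lists, not the repository's implementation; the function name `verify_drafts` and its arguments are hypothetical.

```python
import random
from typing import List, Sequence


def verify_drafts(
    draft_tokens: Sequence[int],
    q_probs: Sequence[Sequence[float]],
    p_probs: Sequence[Sequence[float]],
    rng: random.Random,
) -> List[int]:
    """Speculative sampling verification step.

    draft_tokens: the gamma tokens proposed by the drafter.
    q_probs[i]: drafter distribution from which draft_tokens[i] was sampled.
    p_probs[i]: target distribution at the same position; p_probs has gamma + 1
        entries because one target forward pass also scores the position after
        the last draft.
    Returns the accepted drafts followed by exactly one token from the target side.
    """
    vocab = range(len(p_probs[0]))
    out: List[int] = []
    for i, x in enumerate(draft_tokens):
        p_x, q_x = p_probs[i][x], q_probs[i][x]
        # Accept x with probability min(1, p(x) / q(x)); q_x > 0 for any token sampled from q.
        if q_x > 0 and rng.random() < min(1.0, p_x / q_x):
            out.append(x)
            continue
        # Rejected: resample from the normalized residual max(0, p - q) and stop.
        residual = [max(0.0, p - q) for p, q in zip(p_probs[i], q_probs[i])]
        out.append(rng.choices(vocab, weights=residual)[0])
        return out
    # Every draft was accepted: sample a bonus token from the target's last distribution.
    out.append(rng.choices(vocab, weights=p_probs[len(draft_tokens)])[0])
    return out
```

Because a rejected position is resampled from the residual, each emitted token follows the target distribution p however good or bad the drafter is. This is why Speculative Decoding leaves the output distribution unchanged; a better drafter only raises the acceptance rate.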
This project requires Python 3.7 or later and the following dependencies:
- rich
- tqdm
- termcolor
- tokenizers>=0.19.1
- torch>=2.3.0
- transformers>=4.41.1
- accelerate>=0.30.1
- bitsandbytes>=0.43.1
Simply fork this repository and install the dependencies.
The target model is the transformer model we want to accelerate, while the drafter model is the smaller model used to generate drafts for the target model.
Here are some requirements to make speculative decoding work:
- The target model must be a transformer model (decoder only or encoder-decoder).
- The drafter model must share the same tokenizer as the target model.
- The target model and the drafter model should output logits of the same shape (i.e., the same vocabulary size).
- The target model should be large enough to benefit from the acceleration (i.e., its decoding is bottlenecked by memory bandwidth).
- The drafter model should be small enough to be faster than the target model.
```python
from transformers import AutoTokenizer, AutoModelForCausalLM

# We will use Meta's Llama-3.2 3B Instruct as the model we want to accelerate (3B parameters)
target_model_name = "meta-llama/Llama-3.2-3B-Instruct"
target = AutoModelForCausalLM.from_pretrained(target_model_name)

# We will use Meta's Llama-3.2 1B Instruct as the drafter model (1B parameters)
drafter_model_name = "meta-llama/Llama-3.2-1B-Instruct"
drafter = AutoModelForCausalLM.from_pretrained(drafter_model_name)

# Don't forget to load the tokenizer (shared by both models)
tokenizer = AutoTokenizer.from_pretrained(target_model_name)
```
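Before generating, you may want to check two of the requirements listed above: the shared tokenizer and the same logits shape. The following is a hypothetical helper, not part of the repository; it assumes the vocabularies are plain token-to-id dictionaries and that the logits dimension equals the model's configured vocabulary size.

```python
from typing import Dict, List


def check_compatibility(
    target_vocab: Dict[str, int],
    drafter_vocab: Dict[str, int],
    target_vocab_size: int,
    drafter_vocab_size: int,
) -> List[str]:
    """Return a list of problems; an empty list means the pair looks usable."""
    problems: List[str] = []
    # Same tokenizer: every token string must map to the same id in both vocabularies.
    mismatched = [
        tok
        for tok in target_vocab.keys() | drafter_vocab.keys()
        if target_vocab.get(tok) != drafter_vocab.get(tok)
    ]
    if mismatched:
        problems.append(f"{len(mismatched)} tokens have different ids in the two tokenizers")
    # Same logits shape: both output layers must have the same vocabulary dimension.
    if target_vocab_size != drafter_vocab_size:
        problems.append(
            f"logits shapes differ: target vocab_size={target_vocab_size}, "
            f"drafter vocab_size={drafter_vocab_size}"
        )
    return problems
```

With the Hugging Face objects loaded above, the inputs could come from `tokenizer.get_vocab()`, `AutoTokenizer.from_pretrained(drafter_model_name).get_vocab()`, `target.config.vocab_size` and `drafter.config.vocab_size`.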
Before generating text, we need to prepare the input. The input should be tokenized and encoded using the tokenizer.
```python
prefix = "Translate to English: Je m'appelle Romain. N'hésitez pas à contribuer à mon projet !"
# Apply the model's own chat template (Llama 3.2 here) and append the assistant turn header
messages = [{"role": "user", "content": prefix}]
input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True)  # Generation methods require a list of ids
```
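Speculative Decoding uses one hyperparameter: `gamma` (γ), the number of draft tokens generated by the drafter model at each step.
Increasing the value of `gamma` lets the target model accept more tokens per forward pass when the drafter agrees with it, but wastes more drafter and verification work when the drafts are rejected.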
```python
from sampling import speculative_generate, autoregressive_generate
# from sampling import speculative_generate_encoder_decoder, autoregressive_generate_encoder_decoder
from utils.logits_processors import NucleusProcessor

# Parameters
gen_len = 100  # Maximum number of generated tokens (may be slightly exceeded with speculative decoding)
gamma = 4  # Number of drafts generated by the drafter model at each step
logits_processor = NucleusProcessor(temperature=.6, top_p=.9)  # Nucleus sampling with p=0.9 and T=0.6

# Generate text using the classic auto-regressive decoding (slow)
output_ids_ar = autoregressive_generate(  # or autoregressive_generate_encoder_decoder for encoder-decoder models
    input_ids,
    target,
    logits_processor=logits_processor,
    max_gen_len=gen_len,
    end_tokens_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
)
output_ar = tokenizer.decode(output_ids_ar, skip_special_tokens=True)

# Generate text using the speculative decoding (faster)
output_ids_sd, alpha = speculative_generate(  # or speculative_generate_encoder_decoder for encoder-decoder models
    input_ids,
    drafter,
    target,
    logits_processor=logits_processor,
    gamma=gamma,
    max_gen_len=gen_len,
    end_tokens_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
)
output_sd = tokenizer.decode(output_ids_sd, skip_special_tokens=True)

print("Auto-regressive decoding:", output_ar)
print("Speculative decoding:", output_sd)
print("Acceptance rate:", alpha)  # Number of drafts accepted by the target model divided by the number of drafts generated
```
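The acceptance rate printed above can help when choosing `gamma`. Leviathan et al. (2023) model each draft token as accepted independently with probability α. Under that model, the expected number of tokens produced per target forward pass is (1 - α^(γ+1)) / (1 - α). If one drafter step costs c times a target step, the expected walltime improvement is that quantity divided by (γc + 1). The sketch below only evaluates these two formulas. Treating the measured acceptance rate as the i.i.d. α of the model is an assumption, and the function names are hypothetical.

```python
def expected_tokens_per_step(alpha: float, gamma: int) -> float:
    """Expected tokens emitted per target forward pass, i.i.d. acceptance model."""
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must be in [0, 1]")
    if alpha == 1.0:
        return gamma + 1.0  # every draft accepted, plus the bonus token
    return (1.0 - alpha ** (gamma + 1)) / (1.0 - alpha)


def expected_speedup(alpha: float, gamma: int, c: float) -> float:
    """Expected walltime improvement; c = cost of one drafter step / cost of one target step."""
    return expected_tokens_per_step(alpha, gamma) / (gamma * c + 1.0)
```

For example, `max(range(1, 11), key=lambda g: expected_speedup(0.8, g, 0.1))` returns the `gamma` with the highest predicted speedup for an assumed α = 0.8 and c = 0.1. Larger values pay off only while the drafter stays cheap and accurate.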
To use Beam Search Decoding, you can use the `beam_search_generate` function. It requires the `top_k` (number of tokens to evaluate at each branch), `num_beams` (number of beams that run in parallel), `min_length` and `alpha` (for length penalty) hyperparameters.
```python
from sampling import beam_search_generate  # Beam Search Decoding is not compatible with encoder-decoder models yet.

output_ids_bs = beam_search_generate(
    input_ids,
    target,
    max_gen_len=gen_len,
    end_tokens_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
    top_k=3,
    num_beams=5,
    min_length=5,
    alpha=1.2,
)
```
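The role of the four beam search hyperparameters can be illustrated with a simplified, self-contained beam search over a toy next-token function. This is a sketch, not the repository's `beam_search_generate`. The `next_logprobs` callback is hypothetical, and so is the length penalty form `score / length ** alpha` (a common choice whose exact formula in the repository may differ).

```python
import math
from typing import Callable, List, Sequence, Tuple

Beam = Tuple[List[int], float]  # (token ids, summed log-probability of generated tokens)


def length_penalized(score: float, gen_len: int, alpha: float) -> float:
    # Assumed penalty: divide the summed log-probability by length ** alpha.
    return score / (max(gen_len, 1) ** alpha)


def beam_search(
    next_logprobs: Callable[[List[int]], Sequence[float]],
    prompt: List[int],
    eos_id: int,
    max_gen_len: int,
    num_beams: int,
    top_k: int,
    min_length: int,
    alpha: float,
) -> List[int]:
    beams: List[Beam] = [(list(prompt), 0.0)]
    finished: List[Beam] = []
    for _ in range(max_gen_len):
        candidates: List[Beam] = []
        for tokens, score in beams:
            logprobs = next_logprobs(tokens)
            ranked = sorted(range(len(logprobs)), key=lambda t: logprobs[t], reverse=True)
            if len(tokens) - len(prompt) < min_length:
                ranked = [t for t in ranked if t != eos_id]  # forbid EOS before min_length
            for t in ranked[:top_k]:  # expand each beam with its top_k tokens
                candidates.append((tokens + [t], score + logprobs[t]))
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = []
        for tokens, score in candidates:
            if tokens[-1] == eos_id:
                finished.append((tokens, score))  # completed hypothesis
            else:
                beams.append((tokens, score))
                if len(beams) == num_beams:  # keep only num_beams live beams
                    break
        if not beams:
            break
    finished.extend(beams)  # unfinished beams also compete at the end
    best, _ = max(finished, key=lambda b: length_penalized(b[1], len(b[0]) - len(prompt), alpha))
    return best
```

With negative log-probabilities, a larger `alpha` divides longer hypotheses by a larger factor and so favors them, while `min_length` simply masks the end token until enough tokens have been generated.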
You can run `infer.py` in your console to generate text using the console interface. You can easily change the hyperparameters of the generation, compare target and speculative generation, enable drafter generation and much more.
```bash
python infer.py
```
To change the models used, you can change `target_model_name` and `drafter_model_name` in the `infer.py` file.
If you use encoder-decoder models, be careful to switch to the encoder-decoder generation methods (`speculative_generate_encoder_decoder` and `autoregressive_generate_encoder_decoder`).
On top of this implementation, one of the works of my MSc thesis is implemented: Ngram Assisted Speculative Decoding (NASD). NASD replaces the drafter model of Speculative Decoding with an n-gram model [...] `None` if the context has never been seen.
Before generation, the
The advantage of NASD is that it allows for faster generation without the need for a second model. It is training-free and task- and model-agnostic, making it a versatile approach for accelerating sequence generation in transformers.
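To illustrate the general idea of drafting without a second model, here is a generic n-gram drafter in plain Python. It is a hypothetical sketch, not the NASD implementation. The class `NgramDrafter`, its methods, and the strategy of predicting the most frequent continuation of the last n-1 tokens are all assumptions. Like the NASD drafter described above, it returns `None` for a context it has never seen.

```python
from collections import Counter, defaultdict
from typing import DefaultDict, List, Optional, Sequence, Tuple


class NgramDrafter:
    """Hypothetical drafter: most frequent token seen after the last n-1 tokens."""

    def __init__(self, n: int = 3) -> None:
        if n < 2:
            raise ValueError("n must be >= 2")
        self.n = n
        self.table: DefaultDict[Tuple[int, ...], Counter] = defaultdict(Counter)

    def update(self, tokens: Sequence[int]) -> None:
        """Count every (n-1)-token context and the token that follows it."""
        k = self.n - 1
        for i in range(len(tokens) - k):
            self.table[tuple(tokens[i : i + k])][tokens[i + k]] += 1

    def predict(self, context: Sequence[int]) -> Optional[int]:
        """Next-token guess, or None if this context has never been seen."""
        counts = self.table.get(tuple(context[-(self.n - 1) :]))
        if not counts:
            return None
        return counts.most_common(1)[0][0]

    def draft(self, context: Sequence[int], gamma: int) -> List[int]:
        """Propose up to gamma tokens, stopping early at an unseen context."""
        ctx = list(context)
        drafts: List[int] = []
        for _ in range(gamma):
            tok = self.predict(ctx)
            if tok is None:
                break
            drafts.append(tok)
            ctx.append(tok)
        return drafts
```

The drafts proposed this way would still go through the target model's verification, so a poor n-gram guess costs some wasted verification work but does not change the output.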
As I did not publish this work, a similar approach was later introduced in ANPD (Ou et al., 2024).
To reproduce their results, you can use the NASD implementation by simply setting `top_k_filler=1`.
The documentation of NASD will be published soon...
Figure 3: Overview of NASD method.
The cache feature (key-value cache) is inconsistent and sometimes incorrectly implemented in Hugging Face Transformers, depending mainly on the model. This can lead to incorrect results or even errors. To avoid this issue, you can disable the cache by setting `use_cache=False` in the generate methods. This slows down generation but avoids any cache-related issues.
Please open an issue or submit a pull request if you find any bug. Contributions are welcome!
[1] Leviathan, Y., Kalman, M., & Matias, Y. (2023). Fast Inference from Transformers via Speculative Decoding. Proceedings of the 40th International Conference on Machine Learning, in Proceedings of Machine Learning Research 202:19274-19286. Available from https://proceedings.mlr.press/v202/leviathan23a.html.
[2] Chen, C., Borgeaud, S., Irving, G., Lespiau, J. B., Sifre, L., & Jumper, J. (2023). Accelerating large language model decoding with speculative sampling. arXiv preprint arXiv:2302.01318.
[3] Ou, J., Chen, Y., & Tian, W. (2024). Lossless Acceleration of Large Language Model via Adaptive N-gram Parallel Decoding. Proceedings of the 2024 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies (Volume 6: Industry Track), pages 10–22.