Add FP8 KVCache support #2028
Conversation
Thanks for this PR.
I think a lot has to be changed (to simplify it).
Also, I don't see any core logic to actually handle the FP8. Are the kernels ready?
Is it possible to test / add tests?
launcher/src/main.rs
Outdated
```
/// version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead
/// supported for common inference criteria.
#[clap(default_value = "auto", long, env)]
kv_cache_dtype: Option<String>,
```
Put a real enum.

```
enum KvDtype {
    Fp8(Path),
}
...
#[clap(long, env, value_enum)]
Option<KvDtype>
```

This should work much better. `None` is equivalent to `auto` (it just means the user hasn't specified anything, and we can do whatever we want with it).
`KvDtype` will automatically be sanitized/error-checked (`String` isn't, since all strings are accepted).
I tried putting `Fp8(Path)` directly in clap. I'm not sure it actually works in clap internals, but this is what we want: if fp8 is chosen we need a path for the scales, and `Path` should also ensure the string is a valid path.
Maybe clap doesn't support algebraic enums and we can't have `Fp8(Path)` and need `Fp8` instead.
In that case you need to handle validation early (there are other forms of validation in that layer, before pushing the args to the shard).
All CLI validation should happen here, as early as possible, with the best possible error messages.
launcher/src/main.rs
Outdated
```
/// greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead
/// supported for common inference criteria.
#[clap(long, env)]
quantization_param_path: Option<String>,
```
Why is an external file needed? Can't the user specify a value?
If this is linked per model, shouldn't the model config/repo contain that information?
If it's needed with kvcache=fp8, let's make sure it actually is. Ideally it's one option for users; if that's not possible we need manual validation here (and validation can be skipped later).
(I'm trying to avoid adding too many flags; TGI already has too many, and since we don't break things, we never remove stuff that was added. That's why, if we can read the information from some consistent config in the repo, it keeps the interface simpler for the user.)
Updated to load the scales from the checkpoint!
```
class KVCacheQuantSchema(BaseModel):
    dtype: str
    # Each key is a TP rank. Each value is a dictionary mapping a TP rank's
    # layer indices to their per-tensor KV cache scaling factor.
    # TODO: Consider pulling this and its validation methods out into its
    # own schema class (tricky as its members are variable)
    scaling_factor: Dict[int, Dict[int, float]]

    @model_validator(mode="after")
    def check_is_fp8(self) -> "KVCacheQuantSchema":
        assert self.dtype == "float8_e4m3fn", (
            "Loaded scaling factors intended for KV cache dtype = "
            f"{self.dtype} rather than float8_e4m3fn!"
        )
        return self
```
This should probably be done higher in the stack (ideally in the launcher directly).
Rust is much more efficient at running these kinds of checks, but most importantly errors should happen as early as possible (and the launcher has all the user flags too).
```
        return self

    @model_validator(mode="after")
    def check_current_rank(self, info: ValidationInfo) -> "KVCacheQuantSchema":
```
Why are you not fusing this with the first validator?
```
        if context:
            model_type = context.get("model_type", None)
            if model_type is not None:
                assert model_type == self.model_type, (
```
I'm wondering what kind of bug this is supposed to prevent.
If the scalings are contained directly within a single repo (so not user-supplied), the validity should be obvious (no need for extra keys).
If it is user-supplied, well, it is unlikely to contain a model_type, no?
```
        kv_cache_dtype: str = "auto",
        quantization_param_path: Optional[str] = None,
```
Fused?
```
        if callable(getattr(self.model, "load_kv_cache_scales", None)):
            self.model.load_kv_cache_scales(quantization_param_path)
```
If done correctly, you don't need any of that, because you handle this at load time (and every model will need to be updated to support the new values).
"provided. Defaulting to scaling factors of 1.0. " | ||
"This may lead to less accurate results!" | ||
) | ||
elif quantization_param_path is not None: |
This should be handled all the way back in the launcher (again, ideally directly by clap; if not possible, manually in the Rust part).
And it should be a hard error, not a soft one (if parameters sent by the user don't make sense, we never silently ignore them).
```
        kv_cache_dtype: Optional[str] = "auto",
        quantization_param_path: Optional[str] = None,
```
Fused?
```
    except FileNotFoundError:
        logger.error(f"File or directory '{filename}' not found.")
    except json.JSONDecodeError:
        logger.error(f"Error decoding JSON in file '{filename}'.")
    except Exception as e:
        logger.error(f"An error occurred while reading '{filename}': {e}")
    # This section is reached if and only if any of the excepts are hit
    # Return an empty iterable (list) => no KV cache scales are loaded
    # which ultimately defaults to 1.0 scales
    logger.warning(
        "Defaulting to KV cache scaling factors = 1.0 "
        f"for all layers in TP rank {tp_rank} "
        "as an error occurred during loading."
    )
```
Everything here should be a hard error (I think the standard ones would do fine).
If a user sends invalid information, we shouldn't silently ignore it. They should fix it.
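A minimal illustration of what a hard error could look like here (my sketch, not code from the PR): drop the fallback-to-1.0 path and let the failure propagate, re-raising only to add context:

```
import json

def load_kv_cache_scales(filename: str) -> dict:
    # No try/except fallback to 1.0: if the user passed a path, it must be valid.
    # FileNotFoundError propagates as-is; invalid JSON gets a clearer message.
    with open(filename) as f:
        try:
            return json.load(f)
        except json.JSONDecodeError as e:
            raise ValueError(f"Invalid JSON in KV cache scales file '{filename}'") from e
```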
```
            kv_cache[0],
            kv_cache[1],
            slots,
            self.kv_cache_dtype,
```
This one isn't necessary here; it should already be inferrable from `kv_cache[0]`.
Happy to help with the rebase btw.
nothing shocking to me for benchmarking!
```
        # The scaling factor convention we are assuming is
        # quantized_value * scaling_factor ~= true_value
        # which is consistent with the practice of setting
        # scaling_factor = tensor_amax / FPtype_max
```
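As a reference for readers, a minimal, self-contained sketch of that convention (my own illustration, not code from this PR), assuming torch exposes `float8_e4m3fn` and using a per-tensor scale of `amax / fp8_max`:

```
import torch

def quantize_fp8(x: torch.Tensor):
    # scaling_factor = tensor_amax / FPtype_max, so that
    # quantized_value * scaling_factor ~= true_value
    fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0
    scale = x.abs().max().clamp(min=1e-12) / fp8_max
    q = (x / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return q, scale

def dequantize_fp8(q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    return q.to(torch.float32) * scale

x = torch.randn(8, dtype=torch.float32)
q, scale = quantize_fp8(x)
print((dequantize_fp8(q, scale) - x).abs().max())  # small quantization error
```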
@Narsil I believe it is related to https://onnx.ai/onnx/technical/float8.html and the diff between e4m3fn and e4m3 (different exponent bias). Is that so @mht-sharma?
But shouldn't it be based on this param https://github.com/vllm-project/vllm/blob/319ad7f1d386699e94f629341c9988a926821f24/tests/fp8_kv/llama2-7b-fp8-kv/kv_cache_scales.json#L4 ? Also, we have to kind of decide whether we want to follow the scheme in vllm to store quantization params in a json and keep compatibility for e.g. models on the Hub, or not.
Thanks for the review @Narsil @fxmarty. I will rebase and address the comments. Regarding the format for loading the FP8 scales, VLLM offers two methods:
VLLM intends to deprecate the
Removed the
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
The core logic for handling FP8 is managed by the Paged Attention kernel in VLLM, with the necessary kernel tests. If you have any specific tests in mind, we can discuss them. VLLM includes tests that compare the output with the expected FP8 output, as seen at https://github.com/comaniac/vllm/blob/main/tests/models/test_fp8.py. We can add a similar test if required.
Great work, that's cool! I left some general comments, @Narsil may have more about the design
@@ -0,0 +1,56 @@
# Accelerating Inference with FP8 KV Cache

Text Generation Inference (TGI) now supports FP8 KV Cache, enhancing inference speed on both Nvidia and AMD GPUs. This feature significantly boosts performance and memory efficiency, enabling faster and more scalable text generation. By quantizing the KV cache to 8-bit floating point (FP8) formats, we can greatly reduce the memory footprint. This reduction allows for improved throughput in text generation tasks
Suggested change:

> Text Generation Inference (TGI) supports FP8 KV Cache, enhancing inference speed on both Nvidia and AMD GPUs. This feature significantly boosts performance and memory efficiency, enabling faster and more scalable text generation. By quantizing the KV cache to 8-bit floating point (FP8) formats, we can greatly reduce the memory footprint. This reduction allows for improved throughput in text generation tasks.
It would be worth explaining what FP8 KV Cache is. Readers may not be familiar with it (does it mean attention computation is in fp8? etc.)

> enhancing inference speed on both Nvidia and AMD GPUs. This feature significantly boosts performance and memory efficiency, enabling faster and more scalable text generation.

This is kind of vague.

> we can greatly reduce the memory footprint. This reduction allows for improved throughput in text generation tasks

It would be worth IMO to show numbers / a chart here to get a grasp of what "greatly" means, in which cases there is indeed a speedup, etc.
We have a couple of options for numbers and charts:

- Max Batch Total Tokens: The logs display the max batch total tokens, which increase when using the FP8 KV cache. We could create a chart showing the max batch total tokens in both cases (with and without the FP8 KV cache).
- Throughput: Currently, I have created a custom script using `AsyncClient` to send 500 requests simultaneously with `asyncio.gather` (a rough sketch of that kind of script is shown below). This provides a rough estimate of throughput. @Narsil, do you have any suggestions on calculating throughput more precisely?
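For illustration, a minimal version of that kind of script, assuming a TGI server already running at `http://127.0.0.1:8080` and the `text-generation` Python client (the actual script used for the measurements may differ):

```
import asyncio
import time

from text_generation import AsyncClient

async def main(n_requests: int = 500):
    client = AsyncClient("http://127.0.0.1:8080")
    prompt = "What is deep learning?"

    start = time.perf_counter()
    # Fire all requests concurrently and wait for every response.
    responses = await asyncio.gather(
        *[client.generate(prompt, max_new_tokens=128) for _ in range(n_requests)]
    )
    elapsed = time.perf_counter() - start

    generated = sum(r.details.generated_tokens for r in responses)
    print(f"{n_requests} requests in {elapsed:.1f}s")
    print(f"~{generated / elapsed:.1f} generated tokens/s")

asyncio.run(main())
```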
## Current Hardware Support

* Nvidia GPUs: Supports both FP8E4M3 and FP8E5M2.
* AMD GPUs: Supports FP8E4M3.
Technically, Instinct supports E4M3FNUZ, not E4M3. https://onnx.ai/onnx/technical/float8.html
## FP8 E5M2 KV Cache
Example usage:
```
text-generation-launcher --model-id <> --kv-cache-dtype fp8_e5m2
```
Let's maybe put a full runnable command with `docker run` etc.? Good inspiration could be https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one?install=NVIDIA#flashattention-2
Example usage:
```
text-generation-launcher --model-id <> --kv-cache-dtype fp8
```
same + maybe have a model_id which indeed has scaling factors here.
The FP8 kv cache scaling factors required in the FP16 checkpoints are specified through the `.kv_scale` parameter present on the `Attention` module, such as:

```
model.layers.0.self_attn.kv_scale < F32
model.layers.1.self_attn.kv_scale < F32
...
```
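For illustration only (the tensor names follow the layout above; the file name is hypothetical), the per-layer scales could be read back from such a checkpoint with safetensors:

```
import torch
from safetensors import safe_open

# Collect the per-layer KV cache scales stored on the attention modules.
kv_scales = {}
with safe_open("model.safetensors", framework="pt", device="cpu") as f:
    for name in f.keys():
        if name.endswith("self_attn.kv_scale"):
            # name looks like: model.layers.<idx>.self_attn.kv_scale
            layer_idx = int(name.split(".")[2])
            kv_scales[layer_idx] = f.get_tensor(name).to(torch.float32).item()

print(kv_scales)
```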
How do we know whether they are for E4M3FNUZ or E4M3FN format?
I guess the current tools, Nvidia AMMO and AutoFP8, use `E4M3FN`. Currently there is no flag to determine the format (unless of course the weight is quantized). But I could add a check and add a parameter for this in the checkpoint.
Updated to read this from the config. See: a7909e6
```
        torch.distributed.barrier(group=self.process_group)
        super(FlashLlama, self).__init__(
            model=model,
            tokenizer=tokenizer,
            num_layers=len(model.model.layers),
            num_kv_heads=model.model.num_key_value_heads,
            head_size=model.model.head_size,
-           dtype=dtype,
+           dtype=torch.uint8 if "fp8" in kv_cache_dtype else dtype,
```
Things are a bit harder to read if the `dtype` attribute is used to mean the KV cache storage pytorch dtype. For some other models (gemma, idefics), `self.dtype` is used with another meaning, being the weights dtype.
```
        # ROCm uses FP8 format with fp8_e4m3fn, whereas Nvidia GPUs use fp8_e4m3.
        # The multiplication by 2 compensates for the different numeric representation
        # between ROCm and Nvidia GPUs, ensuring consistent effective scaling across platforms.
        # After this adjustment, the overall effect is equivalent to the scaling applied without
        # it on Nvidia GPUs.
        if SYSTEM == "rocm":
            kv_scale *= 2.0
```
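For context (my own illustration, not part of the PR): the two FP8 variants differ by one unit of exponent bias, which is where the factor of 2 comes from. This is easy to check with torch, assuming a build that exposes both dtypes:

```
import torch

fn = torch.finfo(torch.float8_e4m3fn)      # variant used on Nvidia GPUs
fnuz = torch.finfo(torch.float8_e4m3fnuz)  # variant used on AMD Instinct

print(fn.max, fnuz.max)  # 448.0 vs 240.0: the fnuz range is roughly half
# A scale calibrated for e4m3fn therefore needs to be doubled so that values
# land at the same relative position of the e4m3fnuz range.
```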
How do we know whether serialized scales are in E4M3FN or E4M3FNUZ format? I think depending on that, the logic should be different here.
Updated the code to read the corresponding dtype as `kv_cache_torch_dtype` from the config.
Added the format to the README.md and a utility script to add `kv_cache_torch_dtype` when quantising the model.
See: a7909e6
examples/fp8_kvcache/README.md
Outdated
## Example usage
To extract KV cache scaling factors from a quantized FP8 model and save them to an unquantized FP16 model, use the following command:

```
python extract_fp8_kv_scales.py --quantized-model neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV --model meta-llama/Meta-Llama-3-8B-Instruct --save-path Meta-Llama-3-8B-Instruct
```
If we target vllm/tgi intercompatibility, why couldn't we load e.g. `neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV` directly in TGI? Is it not loadable in vllm?
Also, by "quantized FP8 model", do you mean a model whose weights are quantized to fp8? How does it relate to the FP8 KV cache? To me, to obtain the KV cache scales you would simply need to pass calibration data through the network and collect stats on the KV cache.
It feels like the KV cache scales from a quantized model (like `neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV`), whose KV cache scales may have been obtained after the weight quantization (?), may have an unnecessary bias due to being inferred from calibration on the quantized model, not the unquantized one.
The model `neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV` has both the weights quantized and the KV scales. However, I don't think we can load the FP8 weights in TGI yet. Also, one would run an FP16 model with an FP8 KV cache, so we need such a checkpoint.
You are right; it may have an additional bias due to calibration after weight quantization. The bias might have been low, as I couldn't find a noticeable difference during inference.
I have tested two other models generated using scales from Nvidia AMMO (VLLM uses this also), which may not have this bias. The quantizer can treat the quantized format as full precision and provide the model with KV scales. We can extract scales from that.
Here is the link to the models: Llama-2-70b-chat-hf-FP8-KV
I will add these models and provide an example for this.
Updated the example with Nvidia AMMO
examples/fp8_kvcache/README.md
Outdated
@@ -0,0 +1,40 @@
# FP8 (fp8_e4m3) KV Cache Scaling Factor Extraction Utility

This utility is designed to extract KV cache scaling factors from a quantized `FP8(fp8_e4m3)` Hugging Face (HF) model. The extracted scaling factors are then saved to the corresponding unquantized HF model, which can be used with Text Generation Inference (TGI).
Do you have an example of such a model?
Updated and added a model to readme.
@@ -0,0 +1,56 @@
# Accelerating Inference with FP8 KV Cache
I think it would be worth having some evaluation metrics for an example model.
```
@@ -246,6 +246,12 @@ class ModelType(enum.Enum):
}


FP8_KVCACHE_SUPPORTED_MODELS = {
    "llama",
    "baichun",
```
is this a typo of "baichuan"?
Yes. Thanks for pointing it out
f"used for generating kv scales. Expected 'float8_e4m3fn' or 'float8_e4m3fnuz', but got '{kv_cache_torch_dtype}'." | ||
) | ||
|
||
# ROCm uses FP8 format with fp8_e4m3fn, whereas Nvidia GPUs use fp8_e4m3. |
Should it be "ROCm uses FP8 format with fp8_e4m3fnuz, whereas NVIDIA GPUs use fp8_e4m3fn"?
Yes that's correct
Closing this as we added FP8 KV cache support in #2603. More support is coming (for pre-scaled FP8 KV cache).
What does this PR do?
This PR introduces support for FP8 KV Cache in Text Generation Inference (TGI), significantly enhancing performance and memory efficiency on both Nvidia and AMD GPUs. By quantizing the KV cache to 8-bit floating point (FP8) formats, we can greatly reduce the memory footprint, leading to faster and more scalable text generation.
Hardware Compatibility:
Example Usage:
KV cache scaling factors should be included in the FP16 checkpoint for the E4M3 format to maintain accuracy. The default scaling factor is set to 1.0 if not provided, which may lead to accuracy loss.
Checkpoint Structure for KV Scales:
The FP8 KV cache scaling factors are specified through the `.kv_scale` parameter in the attention module. This follows the structure proposed in vllm: https://docs.vllm.ai/en/stable/quantization/fp8.html#fp8-checkpoint-structure-explanation
When providing `.kv_scale` in the model, the config should specify the proper `kv_cache_torch_dtype` used to generate the scales (`float8_e4m3fn` or `float8_e4m3fnuz`).
Currently, users need to extract the KV scales from an FP8 checkpoint and add them to the FP16 model. A helper script is provided in the PR for this.
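A rough sketch of what such a helper could do (paths and tensor names are assumptions based on the checkpoint structure above; the actual script in this PR may differ):

```
from safetensors import safe_open
from safetensors.torch import save_file

fp8_ckpt = "fp8-model/model.safetensors"    # hypothetical paths
fp16_ckpt = "fp16-model/model.safetensors"

# Pull the per-layer .kv_scale tensors out of the FP8-quantized checkpoint.
with safe_open(fp8_ckpt, framework="pt", device="cpu") as f:
    kv_scales = {
        name: f.get_tensor(name)
        for name in f.keys()
        if name.endswith("self_attn.kv_scale")
    }

# Merge them into the unquantized FP16 checkpoint and save it back.
with safe_open(fp16_ckpt, framework="pt", device="cpu") as f:
    tensors = {name: f.get_tensor(name) for name in f.keys()}

tensors.update(kv_scales)
save_file(tensors, fp16_ckpt)
```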
Sample Models with KV scales: Models with FP8 KV Cache
Todos: