Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add FP8 KVCache support #2028

Closed
wants to merge 22 commits into from
Closed

Add FP8 KVCache support #2028

wants to merge 22 commits into from

Conversation

mht-sharma
Copy link
Collaborator

@mht-sharma mht-sharma commented Jun 6, 2024

What does this PR do?

This PR introduces support for FP8 KV Cache in Text Generation Inference (TGI), significantly enhancing performance and memory efficiency on both Nvidia and AMD GPUs. By quantizing the KV cache to 8-bit floating point (FP8) formats, we can greatly reduce the memory footprint, leading to faster and more scalable text generation.

Hardware Compatibility:

  • Nvidia GPUs: Supports both FP8E4M3 and FP8E5M2 (TODO: Need VLLM update).
  • AMD GPUs: Supports FP8E4M3.

Example Usage:

text-generation-launcher --model-id <model_id> --kv-cache-dtype fp8/fp8_e5m2

KV cache scaling factors should be included in the FP16 checkpoint for E4M3 format to maintain accuracy. Default scaling factor is set to 1.0 if not provided, which may lead to accuracy loss.

Checkpoint Structure for KV Scales:

The FP8 KV cache scaling factors are specified through the .kv_scale parameter in the attention module

model.layers.0.self_attn.kv_scale                < F32
model.layers.1.self_attn.kv_scale                < F32

This follows a structure proposed in vllm - https://docs.vllm.ai/en/stable/quantization/fp8.html#fp8-checkpoint-structure-explanation

When providing .kv_scale in model, the config should specify proper kv_cache_torch_dtype used to generate scales (float8_e4m3fn or float8_e4m3fnuz).

Currently, users need to extract the KV scales from FP8 checkpoint and add to the FP16 model. A helper script is provided in the PR for the same.

Sample Models with KV scales: Models with FP8 KV Cache

Todos:

  • Documentation
  • Tests
  • Update VLLM for CUDA to support E5M2. @Narsil could you help with this!
  • Only supports LLAMA, will update same for other models in this or other PRs

Copy link
Collaborator

@Narsil Narsil left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for this PR.

I think a lot has to be changed (to simplify it).

Also I don't see any core logic to actually handle the fp8, are the kernels ready?
Is it possible to test/add tests ?

/// version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead
/// supported for common inference criteria.
#[clap(default_value = "auto", long, env)]
kv_cache_dtype: Option<String>,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Put a real enum.

enum KvDtype{
    Fp8(Path)
}

...
#[clap(long, env, value_enum)]
Option<KvDtype>

This should work much better. None is equivalent to auto (it just means the user hasn't specified anything we can do whateverwe want with it).

KvDtype will automatically be sanitized/error checked (String isn't since all strings are available).

I tried putting Fp8(Path) directly in clap, I'm not sure it actually works in clap internals but this is what we want, if fp8 is chosen we need a path for the scales. and Path should also ensure the string is a valid path.

Maybe clap doesn't support algebraic enunms and we can't have Fp8(Path) and need Fp8 instead.
In that case you need to handle validation early (There are other forms of validation in that layer, before pushing the args to the shard).

All CLI validation should happen here, as early as possible, with the best possible error messages.

/// greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead
/// supported for common inference criteria.
#[clap(long, env)]
quantization_param_path: Option<String>,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why an external file is needed ? Can't the user specify a value ?
If this is linked per model, shouldn't the model config/repo contain that information ?

If it's needed with kvcache=fp8 let's try to make sure it's actually. Ideally it's one option for users, if not possible we need manual validation here (and vaildation can be skipped later)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(I'm trying to avoid adding too many flags, TGI already has too many, and since we don't break, we never remove stuff that was added, that's why if we can read the information from some consistant config in the repo it keeps the interface for the user simpler)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated to load the scales from the checkpoint!

Comment on lines 19 to 33
class KVCacheQuantSchema(BaseModel):
dtype: str
# Each key is a TP rank. Each value is a dictionary mapping a TP rank's
# layer indices to their per-tensor KV cache scaling factor.
# TODO: Consider pulling this and its validation methods out into its
# own schema class (tricky as its members are variable)
scaling_factor: Dict[int, Dict[int, float]]

@model_validator(mode="after")
def check_is_fp8(self) -> "KVCacheQuantSchema":
assert self.dtype == "float8_e4m3fn", (
"Loaded scaling factors intended for KV cache dtype = "
f"{self.dtype} rather than float8_e4m3fn!"
)
return self
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should probably be done higher in the stack (ideally in launcher directly).

Rust is much more efficient at running these kind of checks but most importantly errors should happen as early as possible (and launcher has all the user flags too).

return self

@model_validator(mode="after")
def check_current_rank(self, info: ValidationInfo) -> "KVCacheQuantSchema":
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are you not fusing this with the first validator ?

if context:
model_type = context.get("model_type", None)
if model_type is not None:
assert model_type == self.model_type, (
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering what kind of bug this is supposed to prevent.

If the scaling are contained directly within a single repo (so not user supplied) the validity should be obvious (no need for extra keys).

If it is user supplied, well it is unlikely to contain a model_type, no ?

Comment on lines 695 to 696
kv_cache_dtype: str = "auto",
quantization_param_path: Optional[str] = None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fused?

Comment on lines 727 to 728
if callable(getattr(self.model, "load_kv_cache_scales", None)):
self.model.load_kv_cache_scales(quantization_param_path)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If done correctly, you don't need any of that because you handle this at load time (and every model will need to be updated to support the new values)

"provided. Defaulting to scaling factors of 1.0. "
"This may lead to less accurate results!"
)
elif quantization_param_path is not None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be handled all the way back in the launcher (again ideally directly by clap, if not possible manually in the rust part).

And it should be hard error, not a soft one (paramters sent by the user don't make sense, we never silently ignore.)

Comment on lines 32 to 33
kv_cache_dtype: Optional[str] = "auto",
quantization_param_path: Optional[str] = None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fused?

Comment on lines 34 to 47
except FileNotFoundError:
logger.error(f"File or directory '{filename}' not found.")
except json.JSONDecodeError:
logger.error(f"Error decoding JSON in file '{filename}'.")
except Exception as e:
logger.error(f"An error occurred while reading '{filename}': {e}")
# This section is reached if and only if any of the excepts are hit
# Return an empty iterable (list) => no KV cache scales are loaded
# which ultimately defaults to 1.0 scales
logger.warning(
"Defaulting to KV cache scaling factors = 1.0 "
f"for all layers in TP rank {tp_rank} "
"as an error occurred during loading."
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Everything here should be hard error (I think the standard ones would do fine).

If a user sends invalid information, we shouldn't silently ignore. They should fix it.

kv_cache[0],
kv_cache[1],
slots,
self.kv_cache_dtype,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This one isn't necessary here, it should already be inferrable from kv_cache[0].

@Narsil
Copy link
Collaborator

Narsil commented Jun 6, 2024

Happy to help with the rebase btw.

Copy link
Contributor

@fxmarty fxmarty left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nothing shocking to me for benchmarking!

server/text_generation_server/layers/schema.py Outdated Show resolved Hide resolved
server/text_generation_server/layers/schema.py Outdated Show resolved Hide resolved
Comment on lines 480 to 483
# The scaling factor convention we are assuming is
# quantized_value * scaling_factor ~= true_value
# which is consistent with the practice of setting
# scaling_factor = tensor_amax / FPtype_max
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Narsil I believe it is related to https://onnx.ai/onnx/technical/float8.html and the diff between e4m3fn and e4m3 (different exponent bias). Is that so @mht-sharma?

But shouldn't it be based on this param https://github.com/vllm-project/vllm/blob/319ad7f1d386699e94f629341c9988a926821f24/tests/fp8_kv/llama2-7b-fp8-kv/kv_cache_scales.json#L4 ? Also, we have to kind of decide whether we want to follow the scheme in vllm to store quantization params in a json and keep compatibility for e.g. models on the Hub, or not.

@mht-sharma
Copy link
Collaborator Author

Thanks for the review @Narsil @fxmarty I will rebase and address the comments.

Regarding the format for loading the FP8 scales:

VLLM offers two methods:

  • quantization-param-path: This uses a JSON file (kv_cache_scales.json) containing per-tensor scaling factors for each layer. Example can be found here. This file is generated using the Nvidia AMMO quantizer available here.

  • Direct loading from checkpoints: This method has been introduced in one of the recent PRs and is located here.

VLLM intends to deprecate the quantization-param-path method soon, favoring the use of checkpoints for loading scales. Therefore, I would update our approach to also load scales using checkpoints.

@mht-sharma
Copy link
Collaborator Author

Thanks for the review @Narsil @fxmarty I will rebase and address the comments.

Regarding the format for loading the FP8 scales:

VLLM offers two methods:

  • quantization-param-path: This uses a JSON file (kv_cache_scales.json) containing per-tensor scaling factors for each layer. Example can be found here. This file is generated using the Nvidia AMMO quantizer available here.
  • Direct loading from checkpoints: This method has been introduced in one of the recent PRs and is located here.

VLLM intends to deprecate the quantization-param-path method soon, favoring the use of checkpoints for loading scales. Therefore, I would update our approach to also load scales using checkpoints.

Removed the quantization-param-path altogether: This method is already deprecated in VLLM, based on discussions here: vllm-project/vllm#4532

@mht-sharma mht-sharma changed the title [WIP] Add kvcache fp8 support Add kvcache fp8 support Jun 24, 2024
@mht-sharma mht-sharma marked this pull request as ready for review June 24, 2024 14:24
@mht-sharma mht-sharma changed the title Add kvcache fp8 support Add FP8 KVCache support Jun 24, 2024
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@mht-sharma
Copy link
Collaborator Author

Also I don't see any core logic to actually handle the fp8, are the kernels ready? Is it possible to test/add tests ?

The core logic for handling FP8 is managed by the Paged Attention kernel in VLLM, with the necessary kernel tests. If you have any specific tests in mind, we can discuss them. VLLM includes tests that compare the output with the expected FP8 output, as seen https://github.com/comaniac/vllm/blob/main/tests/models/test_fp8.py. We can add a similar test if required.

Copy link
Contributor

@fxmarty fxmarty left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great work, that's cool! I left some general comments, @Narsil may have more about the design

@@ -0,0 +1,56 @@
# Accelerating Inference with FP8 KV Cache

Text Generation Inference (TGI) now supports FP8 KV Cache, enhancing inference speed on both Nvidia and AMD GPUs. This feature significantly boosts performance and memory efficiency, enabling faster and more scalable text generation. By quantizing the KV cache to 8-bit floating point (FP8) formats, we can greatly reduce the memory footprint. This reduction allows for improved throughput in text generation tasks
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
Text Generation Inference (TGI) now supports FP8 KV Cache, enhancing inference speed on both Nvidia and AMD GPUs. This feature significantly boosts performance and memory efficiency, enabling faster and more scalable text generation. By quantizing the KV cache to 8-bit floating point (FP8) formats, we can greatly reduce the memory footprint. This reduction allows for improved throughput in text generation tasks
Text Generation Inference (TGI) supports FP8 KV Cache, enhancing inference speed on both Nvidia and AMD GPUs. This feature significantly boosts performance and memory efficiency, enabling faster and more scalable text generation. By quantizing the KV cache to 8-bit floating point (FP8) formats, we can greatly reduce the memory footprint. This reduction allows for improved throughput in text generation tasks.

It would be worth to explain what is FP8 KV Cache. Readers may not be familiar with it (does it mean attention computation is in fp8? etc)

enhancing inference speed on both Nvidia and AMD GPUs. This feature significantly boosts performance and memory efficiency, enabling faster and more scalable text generation.

This is kind of vague.

we can greatly reduce the memory footprint. This reduction allows for improved throughput in text generation tasks

It would be worth IMO to show numbers / a chart here to get a grasp of what greatly means, in which case there is indeed a speedup, etc

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have a couple of options for numbers and charts

  • Max Batch Total Tokens: The logs display the max batch total tokens, which increase when using the FP8 KV cache. We could create a chart showing the max batch total tokens in both cases (with and without the FP8 KV cache).

  • Throughput: Currently, I have created a custom script using AsyncClient to send 500 requests simultaneously with asyncio.gather. This provides a rough estimate of throughput. @Narsil , do you have any suggestions on calculating throughput more precisely?

## Current Hardware Support

* Nvidia GPUs: Supports both FP8E4M3 and FP8E5M2.
* AMD GPUs: Supports FP8E4M3.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Technically, Instinct supports E4M3FNUZ, not E4M3. https://onnx.ai/onnx/technical/float8.html

## FP8 E5M2 KV Cache
Example usage:
```
text-generation-launcher --model-id <> --kv-cache-dtype fp8_e5m2
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's maybe put a full runnable command with docker run etc? Good inspiration could be https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one?install=NVIDIA#flashattention-2
image


Example usage:
```
text-generation-launcher --model-id <> --kv-cache-dtype fp8
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same + maybe have a model_id which indeed has scaling factors here.

Comment on lines 42 to 48
The FP8 kv cache scaling factors required in the FP16 checkpoints are specified through the .kv_scale parameter present on the `Attention` module, such as:

```
model.layers.0.self_attn.kv_scale < F32
model.layers.1.self_attn.kv_scale < F32
...
```
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How do we know whether they are for E4M3FNUZ or E4M3FN format?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess the current tools Nvidia AMMO and AutoFP8 uses E4M3FN. Currently there is no flag to determine the format (unless ofcourse the weight is quantized). But I could add a check and add a parameter for this in the checkpoint.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated reading this from config. See: a7909e6

torch.distributed.barrier(group=self.process_group)
super(FlashLlama, self).__init__(
model=model,
tokenizer=tokenizer,
num_layers=len(model.model.layers),
num_kv_heads=model.model.num_key_value_heads,
head_size=model.model.head_size,
dtype=dtype,
dtype=torch.uint8 if "fp8" in kv_cache_dtype else dtype,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Things are a bit harder to read if dtype attribute is used to mean the KV cache storage pytorch dtype. For some other models (gemma, idefics), self.dtype is used with an other meaning, being the weights dtype.

Comment on lines 794 to 800
# ROCm uses FP8 format with fp8_e4m3fn, whereas Nvidia GPUs use fp8_e4m3.
# The multiplication by 2 compensates for the different numeric representation
# between ROCm and Nvidia GPUs, ensuring consistent effective scaling across platforms.
# After this adjustment, the overall effect is equivalent to the scaling applied without
# it on Nvidia GPUs.
if SYSTEM == "rocm":
kv_scale *= 2.0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How do we know whether serialized scales are in E4M3FN or E4M3FNUZ format? I think depending on that, the logic should be different here.

Copy link
Collaborator Author

@mht-sharma mht-sharma Jun 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated the code to read the corresponding dtype as kv_cache_torch_dtype from config.

Added the format in the README.md and utility script to add kv_cache_torch_dtype when quantising model

See: a7909e6

Comment on lines 35 to 40
## Example usage
To extract KV cache scaling factors from a quantized FP8 model and save them to an unquantized FP16 model, use the following command:

```
python extract_fp8_kv_scales.py --quantized-model neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV --model meta-llama/Meta-Llama-3-8B-Instruct --save-path Meta-Llama-3-8B-Instruct
```
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we target vllm/tgi intercompatibility, why couldn't we load directly e.g. neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV in TGI? Is it not loadable in vllm?

Also, by quantized FP8 model, do you mean a model whose weights are quantized to fp8? How does it relate to FP8 KV cache? To me to obtain the KV cache scales you would simply need to have calibration data passing through the network & collecting stats on the KV cache.

It feels like the the KV cache scales from a quantized model (like neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV), whose KV cache scales may have been obtained after the weights quantization (?), may have an unnecessary bias due to being inferred from calibration on the quantized model, not the unquantized one.

Copy link
Collaborator Author

@mht-sharma mht-sharma Jun 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The model neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV has both the weights quantized along with KV scales. However, I don't think we can load the FP8 weights in TGI yet. And also one would run an FP16 model with FP8 KV cache so we need such a checkpoint

You are right; it may have an additional bias due to calibration after weight quantization. The bias might have been low, so I couldn't find a noticeable difference during inference.

I have tested two other models generated using scales from Nvidia AMMO (VLLM uses this also), which may not have this bias. The quantizer can accept the quantized format as full precision and provide the model with KV scales. We can extract scales from that.

Here is the link to the models: Llama-2-70b-chat-hf-FP8-KV

I will add these models and provide an example for this.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated the example with Nvidia AMMO

@@ -0,0 +1,40 @@
# FP8 (fp8_e4m3) KV Cache Scaling Factor Extraction Utility

This utility is designed to extract KV cache scaling factors from a quantized `FP8(fp8_e4m3)` Hugging Face (HF) model. The extracted scaling factors are then saved to the corresponding unquantized HF model, which can be used with Text Generation Inference (TGI).
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you have an example of such model?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated and added a model to readme.

@@ -0,0 +1,56 @@
# Accelerating Inference with FP8 KV Cache
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would be worth to have some evaluation metrics for an example model

@danieldk danieldk self-assigned this Jul 25, 2024
@@ -246,6 +246,12 @@ class ModelType(enum.Enum):
}


FP8_KVCACHE_SUPPORTED_MODELS = {
"llama",
"baichun",

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this a typo of "baichuan"?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. Thanks for pointing it out

f"used for generating kv scales. Expected 'float8_e4m3fn' or 'float8_e4m3fnuz', but got '{kv_cache_torch_dtype}'."
)

# ROCm uses FP8 format with fp8_e4m3fn, whereas Nvidia GPUs use fp8_e4m3.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it should be "ROCm uses FP8 format with fp8_e4m3fnuz, whereas NVIDIA GPU uses fp8_e4m3fn"?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes that's correct

@Narsil
Copy link
Collaborator

Narsil commented Oct 8, 2024

Closing this as we added support for FP8 kv cache support in #2603.

More support is coming (for pre-scaled kv-cache fp8)

@Narsil Narsil closed this Oct 8, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants