Add FP8 KVCache support #2028

Closed · wants to merge 22 commits
2 changes: 2 additions & 0 deletions docs/source/_toctree.yml
@@ -43,6 +43,8 @@
title: Monitoring TGI with Prometheus and Grafana
- local: basic_tutorials/train_medusa
title: Train Medusa
- local: basic_tutorials/fp8_kv_cache
title: FP8 KV Cache
title: Tutorials
- sections:
- local: conceptual/streaming
56 changes: 56 additions & 0 deletions docs/source/basic_tutorials/fp8_kv_cache.md
@@ -0,0 +1,56 @@
# Accelerating Inference with FP8 KV Cache
Contributor: I think it would be worth having some evaluation metrics for an example model.


Text Generation Inference (TGI) now supports FP8 KV Cache, enhancing inference speed on both Nvidia and AMD GPUs. This feature significantly boosts performance and memory efficiency, enabling faster and more scalable text generation. By quantizing the KV cache to 8-bit floating point (FP8) formats, we can greatly reduce the memory footprint. This reduction allows for improved throughput in text generation tasks.
Contributor:

Suggested change:
- Text Generation Inference (TGI) now supports FP8 KV Cache, enhancing inference speed on both Nvidia and AMD GPUs. This feature significantly boosts performance and memory efficiency, enabling faster and more scalable text generation. By quantizing the KV cache to 8-bit floating point (FP8) formats, we can greatly reduce the memory footprint. This reduction allows for improved throughput in text generation tasks
+ Text Generation Inference (TGI) supports FP8 KV Cache, enhancing inference speed on both Nvidia and AMD GPUs. This feature significantly boosts performance and memory efficiency, enabling faster and more scalable text generation. By quantizing the KV cache to 8-bit floating point (FP8) formats, we can greatly reduce the memory footprint. This reduction allows for improved throughput in text generation tasks.

It would be worth explaining what FP8 KV Cache is. Readers may not be familiar with it (does it mean attention computation is in FP8? etc.).

> enhancing inference speed on both Nvidia and AMD GPUs. This feature significantly boosts performance and memory efficiency, enabling faster and more scalable text generation.

This is kind of vague.

> we can greatly reduce the memory footprint. This reduction allows for improved throughput in text generation tasks

IMO it would be worth showing numbers or a chart here to get a grasp of what "greatly" means, and in which cases there is indeed a speedup.

Collaborator (author):

We have a couple of options for numbers and charts:

* Max Batch Total Tokens: the logs display the max batch total tokens, which increases when using the FP8 KV cache. We could create a chart showing the max batch total tokens in both cases (with and without the FP8 KV cache).

* Throughput: currently, I have created a custom script using `AsyncClient` to send 500 requests simultaneously with `asyncio.gather`. This provides a rough estimate of throughput. @Narsil, do you have any suggestions on calculating throughput more precisely?
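For reference, a rough version of such a throughput script might look like this (a sketch only: it assumes the `text-generation` Python client and a TGI endpoint already running on `localhost:8080`, and it gives only an approximate tokens-per-second figure):

```python
import asyncio
import time

from text_generation import AsyncClient  # pip install text-generation


async def main(n_requests: int = 500) -> None:
    client = AsyncClient("http://127.0.0.1:8080")
    prompts = ["What is deep learning?"] * n_requests

    start = time.perf_counter()
    # Fire all requests concurrently, as described above.
    responses = await asyncio.gather(
        *(client.generate(p, max_new_tokens=128) for p in prompts)
    )
    elapsed = time.perf_counter() - start

    generated = sum(r.details.generated_tokens for r in responses)
    print(f"{generated / elapsed:.1f} generated tokens/s over {n_requests} requests")


asyncio.run(main())
```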


## FP8 Formats: E4M3 and E5M2
The Open Compute Project (OCP) defines two common 8-bit floating point data formats:

E4M3:

* 1 sign bit
* 4 biased exponent bits
* 3 mantissa bits

E5M2:

* 1 sign bit
* 5 biased exponent bits
* 2 mantissa bits

E4M3 offers higher precision for representing floating point numbers. However, due to its limited range, E4M3 typically requires a higher-precision (usually FP32) scaling factor alongside each quantized tensor. Currently, TGI supports only per-tensor (scalar) scaling factors.
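To make the precision/range tradeoff concrete, here is a small standalone PyTorch sketch (illustration only, not TGI kernel code; requires a PyTorch version with FP8 dtypes, 2.1+) that quantizes a tensor to each format using a per-tensor scale:

```python
import torch

x = torch.randn(4, 8, dtype=torch.float16) * 100  # toy stand-in for a K/V tensor

for dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
    finfo = torch.finfo(dtype)
    # The per-tensor scale maps the observed dynamic range onto the FP8 range.
    scale = x.abs().max().float() / finfo.max
    x_fp8 = (x.float() / scale).clamp(finfo.min, finfo.max).to(dtype)
    x_deq = x_fp8.float() * scale  # dequantize to compare against the original
    max_err = (x.float() - x_deq).abs().max().item()
    print(f"{dtype}: representable max={finfo.max}, max abs error={max_err:.3f}")
```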

## Current Hardware Support

* Nvidia GPUs: Supports both FP8E4M3 and FP8E5M2.
* AMD GPUs: Supports FP8E4M3.
Contributor: Technically, Instinct supports E4M3FNUZ, not E4M3. https://onnx.ai/onnx/technical/float8.html


## FP8 E5M2 KV Cache
Example usage:
```
text-generation-launcher --model-id <> --kv-cache-dtype fp8_e5m2
```

Contributor: Let's maybe put a full runnable command with docker run etc.? Good inspiration could be https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one?install=NVIDIA#flashattention-2
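Along those lines, a fully runnable invocation could look like the following (a sketch: the Docker image tag and model ID are placeholders, and `--kv-cache-dtype` is the flag added in this PR):

```shell
model=meta-llama/Meta-Llama-3-8B-Instruct
volume=$PWD/data  # share a volume with the container to avoid re-downloading weights

docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
    ghcr.io/huggingface/text-generation-inference:latest \
    --model-id $model \
    --kv-cache-dtype fp8_e5m2
```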

## FP8 E4M3 KV Cache
While E4M3 offers higher precision, it requires careful handling of scaling factors to maintain accuracy. Therefore, it is recommended to provide KV cache scaling factors as part of the FP16 checkpoint. If scaling factors are not provided, a default factor of 1.0 is used, which may lead to accuracy loss.

Example usage:
```
text-generation-launcher --model-id <> --kv-cache-dtype fp8
```

Contributor: Same as above, and maybe use a model_id which indeed has scaling factors here.
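Along the same lines, here is a sketch of serving a checkpoint that already contains KV scales (in this case a local directory produced by the extraction utility described below; all paths are placeholders):

```shell
volume=$PWD  # contains ./Meta-Llama-3-8B-Instruct with .kv_scale tensors added

docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
    ghcr.io/huggingface/text-generation-inference:latest \
    --model-id /data/Meta-Llama-3-8B-Instruct \
    --kv-cache-dtype fp8
```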

### Checkpoint structure for KV scales
The FP8 KV cache scaling factors required in the FP16 checkpoint are specified through the `.kv_scale` parameter present on the `Attention` module, for example:

```
model.layers.0.self_attn.kv_scale < F32
model.layers.1.self_attn.kv_scale < F32
...
```
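To check whether a given checkpoint already ships these scales, a quick look with `safetensors` (the same library the extraction utility below uses) is enough; the shard filename here is a placeholder:

```python
from safetensors import safe_open

kv_scales = {}
with safe_open("model.safetensors", framework="pt") as f:  # one shard of the checkpoint
    for name in f.keys():
        if name.endswith(".kv_scale"):
            kv_scales[name] = f.get_tensor(name).item()

print(kv_scales)  # e.g. {"model.layers.0.self_attn.kv_scale": 0.021, ...} if scales are present
```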
Contributor: How do we know whether they are for E4M3FNUZ or E4M3FN format?

Collaborator (author): I guess the current tools (Nvidia AMMO and AutoFP8) use E4M3FN. Currently there is no flag to determine the format (unless of course the weights are quantized), but I could add a check and a parameter for this in the checkpoint.

Collaborator (author): Updated to read this from the config. See: a7909e6


### Generating model with KV Cache scales

Use [AutoFP8](https://github.com/neuralmagic/AutoFP8) with calibration data to generate per-tensor scales for FP8 quantized KV Cache. For more details, see the following example: https://github.com/neuralmagic/AutoFP8/blob/main/example_dataset.py
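A rough sketch of that flow, loosely following the linked example (treat this as an outline: the exact `BaseQuantizeConfig` fields, in particular `kv_cache_quant_targets`, and the dataset choice are assumptions to verify against the AutoFP8 repository):

```python
from datasets import load_dataset
from transformers import AutoTokenizer
from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"  # placeholder model

tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token

# Small calibration set used to collect activation/KV statistics.
ds = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft[:512]")
texts = [tokenizer.apply_chat_template(x["messages"], tokenize=False) for x in ds]
examples = tokenizer(texts, padding=True, truncation=True, return_tensors="pt").to("cuda")

quantize_config = BaseQuantizeConfig(
    quant_method="fp8",
    activation_scheme="static",
    # Assumed option: also calibrate the K/V projections so per-tensor KV scales are emitted.
    kv_cache_quant_targets=("k_proj", "v_proj"),
)

model = AutoFP8ForCausalLM.from_pretrained(model_id, quantize_config)
model.quantize(examples)
model.save_quantized("Meta-Llama-3-8B-Instruct-FP8-KV")
```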

TGI provides a utility to extract the FP8 KV cache scales from an `AutoFP8` quantized model and save them to the FP16 model for use with TGI. For more information: <path to script>
Contributor: Todo


Alternatively, you can use other quantizer tools, such as Nvidia AMMO, to obtain these scaling factors.
7 changes: 7 additions & 0 deletions docs/source/basic_tutorials/launcher.md
@@ -87,6 +87,13 @@ Options:
[env: DTYPE=]
[possible values: float16, bfloat16]

```
## KV_CACHE_DTYPE
```shell
--kv-cache-dtype <KV_CACHE_DTYPE>
[env: KV_CACHE_DTYPE=]
[possible values: fp8, fp8_e5m2]

```

Contributor: Needs a description
## TRUST_REMOTE_CODE
```shell
40 changes: 40 additions & 0 deletions examples/fp8_kvcache/README.md
@@ -0,0 +1,40 @@
# FP8 (fp8_e4m3) KV Cache Scaling Factor Extraction Utility

This utility is designed to extract KV cache scaling factors from a quantized `FP8(fp8_e4m3)` Hugging Face (HF) model. The extracted scaling factors are then saved to the corresponding unquantized HF model, which can be used with Text Generation Inference (TGI).
Contributor: Do you have an example of such a model?

Collaborator (author): Updated and added a model to the README.


Note: This tool specifically works with models quantized using the [AutoFP8](https://github.com/neuralmagic/AutoFP8/tree/main) repository.

The KV scales are integrated into the unquantized HF model in the following format: the FP8 KV cache scaling factors are added to the FP16 checkpoint via a `.kv_scale` parameter on each `Attention` module, as shown below:

```
model.layers.0.self_attn.kv_scale < F32
model.layers.1.self_attn.kv_scale < F32
...
```

## Prerequisites

- text-generation-server
- AutoFP8

## CLI options
```
usage: extract_fp8_kv_scales.py [-h] [--quantized-model QUANTIZED_MODEL] [--model MODEL] [--save-path SAVE_PATH]

Extract FP8 KV cache scales and add them to an FP16 model.

options:
  -h, --help            show this help message and exit
  --quantized-model QUANTIZED_MODEL
                        Path to the FP8 model checkpoint to extract KV cache scales
  --model MODEL         Model ID of the FP16 model to save the KV cache scales
  --save-path SAVE_PATH
                        Path to save the FP16 model with the kv scales
```

## Example usage
To extract KV cache scaling factors from a quantized FP8 model and save them to an unquantized FP16 model, use the following command:

```
python extract_fp8_kv_scales.py --quantized-model neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV --model meta-llama/Meta-Llama-3-8B-Instruct --save-path Meta-Llama-3-8B-Instruct
```
Contributor:

If we target vLLM/TGI intercompatibility, why couldn't we load e.g. neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV directly in TGI? Is it not loadable in vLLM?

Also, by "quantized FP8 model", do you mean a model whose weights are quantized to FP8? How does that relate to the FP8 KV cache? To me, to obtain the KV cache scales you would simply need to pass calibration data through the network and collect stats on the KV cache.

It feels like the KV cache scales from a quantized model (like neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV), whose KV cache scales may have been obtained after the weight quantization (?), may carry an unnecessary bias from being calibrated on the quantized model rather than the unquantized one.

Collaborator (author), @mht-sharma, Jun 25, 2024:

The model neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV has the weights quantized along with the KV scales. However, I don't think we can load the FP8 weights in TGI yet. Also, one would run an FP16 model with an FP8 KV cache, so we need such a checkpoint.

You are right; it may have an additional bias due to calibration after weight quantization. The bias might have been low, as I couldn't find a noticeable difference during inference.

I have tested two other models generated using scales from Nvidia AMMO (vLLM also uses this), which may not have this bias. The quantizer can accept "full precision" as the quantized format and still provide the model with KV scales, and we can extract the scales from that.

Here is the link to the models: Llama-2-70b-chat-hf-FP8-KV

I will add these models and provide an example for this.

Collaborator (author): Updated the example with Nvidia AMMO.

97 changes: 97 additions & 0 deletions examples/fp8_kvcache/extract_fp8_kv_scales.py
@@ -0,0 +1,97 @@
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from pathlib import Path
from text_generation_server.utils.hub import (
    weight_files,
    download_weights,
    weight_hub_files,
)
from safetensors import safe_open
import argparse


def load_model(ckpt_path):
    model_args = {"torch_dtype": "auto"}

    model = AutoModelForCausalLM.from_pretrained(
        ckpt_path, device_map="auto", **model_args, trust_remote_code=True
    )
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(ckpt_path)

    return model, tokenizer


def set_nested_attribute(obj, attribute_path, value):
    keys = attribute_path.split(".")
    current_obj = obj
    for key in keys[:-1]:
        current_obj = getattr(current_obj, key)
    setattr(current_obj, keys[-1], value)


def apply_kv_scales_to_model(model, layer_scales_map):
    # Attach each scale as a non-trainable parameter, e.g.
    # model.layers.0.self_attn.kv_scale
    for layer_name, scale_value in layer_scales_map.items():
        scale_param = torch.nn.Parameter(torch.tensor(scale_value), requires_grad=False)
        set_nested_attribute(model, layer_name, scale_param)


def extract_kv_scales(quantized_model):
    def fetch_parameters(filename):
        with safe_open(filename, framework="pt") as f:
            for name in f.keys():
                param_tensor = f.get_tensor(name)
                yield name, param_tensor

    checkpoint_dir = Path(quantized_model)
    if checkpoint_dir.is_dir():
        # Local checkpoint: read the safetensors shards directly.
        checkpoint_files = list(checkpoint_dir.glob("*.safetensors"))
    else:
        # Hub model ID: download the safetensors shards first.
        hub_filenames = weight_hub_files(quantized_model)
        download_weights(hub_filenames, quantized_model)
        checkpoint_files = weight_files(quantized_model, extension=".safetensors")

    layer_scales_map = {}
    for tensor_file in checkpoint_files:
        for name, param in fetch_parameters(tensor_file):
            if ".kv_scale" in name:
                layer_scales_map[name] = param.item()

    return layer_scales_map


def main(quantized_model, model_id, save_path):
    layer_scales_map = extract_kv_scales(quantized_model)

    model, tokenizer = load_model(model_id)

    apply_kv_scales_to_model(model, layer_scales_map)

    model.save_pretrained(save_path)
    tokenizer.save_pretrained(save_path)

    print(f"Model saved to {save_path}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Extract FP8 KV cache scales and add them to an FP16 model."
    )
    parser.add_argument(
        "--quantized-model",
        type=str,
        help="Path to the FP8 model checkpoint to extract KV cache scales",
    )
    parser.add_argument(
        "--model",
        type=str,
        help="Model ID of the FP16 model to save the KV cache scales",
    )
    parser.add_argument(
        "--save-path",
        type=str,
        help="Path to save the FP16 model with the kv scales",
    )

    args = parser.parse_args()

    main(args.quantized_model, args.model, args.save_path)
37 changes: 37 additions & 0 deletions launcher/src/main.rs
@@ -144,6 +144,28 @@ impl std::fmt::Display for Dtype {
}
}

#[derive(Clone, Copy, Debug, ValueEnum)]
enum KvDtype {
    #[clap(name = "fp8")]
    Fp8,
    #[clap(name = "fp8_e5m2")]
    Fp8e5m2,
}

impl std::fmt::Display for KvDtype {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Keep in sync with `server`.
        match self {
            KvDtype::Fp8 => {
                write!(f, "fp8")
            }
            KvDtype::Fp8e5m2 => {
                write!(f, "fp8_e5m2")
            }
        }
    }
}

#[derive(Clone, Copy, Debug, ValueEnum)]
enum RopeScaling {
Linear,
@@ -214,6 +236,13 @@ struct Args {
#[clap(long, env, value_enum)]
dtype: Option<Dtype>,

/// Specify the data type for the KV cache. By default, it uses the model's data type.
/// CUDA 11.8+ supports `fp8` (fp8_e4m3) and `fp8_e5m2`, while ROCm (AMD GPU) supports `fp8` (fp8_e4m3).
/// If `fp8` is chosen, a model checkpoint with scales for the KV cache should be provided.
/// If not provided, the KV cache scaling factors default to 1.0, which may impact accuracy.
#[clap(long, env, value_enum)]
kv_cache_dtype: Option<KvDtype>,

/// Whether you want to execute hub modelling code. Explicitly passing a `revision` is
/// encouraged when loading a model with custom code to ensure no malicious code has been
/// contributed in a newer revision.
@@ -464,6 +493,7 @@ fn shard_manager(
quantize: Option<Quantization>,
speculate: Option<usize>,
dtype: Option<Dtype>,
kv_cache_dtype: Option<KvDtype>,
trust_remote_code: bool,
uds_path: String,
rank: usize,
@@ -535,6 +565,11 @@
shard_args.push(dtype.to_string())
}

if let Some(kv_cache_dtype) = kv_cache_dtype {
shard_args.push("--kv-cache-dtype".to_string());
shard_args.push(kv_cache_dtype.to_string());
}

// Model optional revision
if let Some(revision) = revision {
shard_args.push("--revision".to_string());
@@ -1038,6 +1073,7 @@ fn spawn_shards(
let quantize = args.quantize;
let speculate = args.speculate;
let dtype = args.dtype;
let kv_cache_dtype = args.kv_cache_dtype;
let trust_remote_code = args.trust_remote_code;
let master_port = args.master_port;
let disable_custom_kernels = args.disable_custom_kernels;
@@ -1055,6 +1091,7 @@
quantize,
speculate,
dtype,
kv_cache_dtype,
trust_remote_code,
uds_path,
rank,
13 changes: 12 additions & 1 deletion server/text_generation_server/cli.py
@@ -7,7 +7,7 @@
from typing import Optional
from enum import Enum
from huggingface_hub import hf_hub_download

from text_generation_server.utils.import_utils import SYSTEM

app = typer.Typer()

@@ -37,6 +37,7 @@ def serve(
quantize: Optional[Quantization] = None,
speculate: Optional[int] = None,
dtype: Optional[Dtype] = None,
kv_cache_dtype: str = "auto",
trust_remote_code: bool = False,
uds_path: Path = "/tmp/text-generation-server",
logger_level: str = "INFO",
@@ -90,13 +91,23 @@ def serve(
raise RuntimeError(
"Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
)

if kv_cache_dtype in {"fp8", "fp8_e5m2"}:
if SYSTEM not in {"cuda", "rocm"}:
raise RuntimeError(
f"`{kv_cache_dtype}` KV cache is only supported on Nvidia and AMD GPUs."
)
if kv_cache_dtype == "fp8_e5m2" and SYSTEM != "cuda":
raise RuntimeError("`fp8_e5m2` KV cache is only supported on Nvidia GPUs.")

server.serve(
model_id,
revision,
sharded,
quantize,
speculate,
dtype,
kv_cache_dtype,
trust_remote_code,
uds_path,
max_input_tokens,
16 changes: 11 additions & 5 deletions server/text_generation_server/layers/attention/cuda.py
@@ -20,8 +20,12 @@ def reshape_and_cache(
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slots: torch.Tensor,
kv_cache_dtype: str = "auto",
kv_scale: float = 1.0,
):
cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)
cache_ops.reshape_and_cache(
key, value, key_cache, value_cache, slots, kv_cache_dtype, kv_scale
)


def paged_attention(
@@ -34,6 +38,8 @@
block_tables: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
kv_cache_dtype: str = "auto",
kv_scale: float = 1.0,
):
# Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
# Copyright 2023 The vLLM team. All rights
@@ -78,8 +84,8 @@ def paged_attention(
block_size,
max_s,
None,
"auto",
1.0,
kv_cache_dtype,
kv_scale,
)
else:
# Run PagedAttention V2.
@@ -111,8 +117,8 @@
block_size,
max_s,
None,
"auto",
1.0,
kv_cache_dtype,
kv_scale,
)

