diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml
index 7599562a55f..cc215caec29 100644
--- a/docs/source/_toctree.yml
+++ b/docs/source/_toctree.yml
@@ -43,6 +43,8 @@
title: Monitoring TGI with Prometheus and Grafana
- local: basic_tutorials/train_medusa
title: Train Medusa
+ - local: basic_tutorials/fp8_kv_cache
+ title: FP8 KV Cache
title: Tutorials
- sections:
- local: conceptual/streaming
diff --git a/docs/source/basic_tutorials/fp8_kv_cache.md b/docs/source/basic_tutorials/fp8_kv_cache.md
new file mode 100644
index 00000000000..35dad187520
--- /dev/null
+++ b/docs/source/basic_tutorials/fp8_kv_cache.md
@@ -0,0 +1,102 @@
+# Accelerating Inference with FP8 KV Cache
+
+Text Generation Inference (TGI) supports an FP8 KV cache, improving inference throughput on both Nvidia and AMD GPUs.
+
+Quantizing the KV cache to FP8 reduces its memory footprint, which allows more tokens to be kept in cache and improves overall throughput in text generation tasks.
+
+With the FP8 KV cache, keys and values are stored in FP8 for memory efficiency, while attention computations are still performed in FP16. This strikes a balance between saving memory and maintaining computational accuracy.
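+
+The sketch below illustrates the idea in plain PyTorch. It is a conceptual example only (TGI performs the quantization and dequantization inside its attention kernels, and the exact scale convention may differ):
+
+```python
+import torch
+
+# Conceptual FP8 KV cache round trip: store in float8_e4m3fn, compute in FP16.
+kv = torch.randn(16, 128, dtype=torch.float16)  # a slice of K or V
+kv_scale = kv.abs().max().float() / 448.0       # 448 is the largest finite float8_e4m3fn value
+
+# Store: scale down and cast to FP8 (half the memory of FP16)
+kv_fp8 = (kv.float() / kv_scale).to(torch.float8_e4m3fn)
+
+# Compute: dequantize back to FP16 before the attention matmuls
+kv_fp16 = kv_fp8.to(torch.float16) * kv_scale.to(torch.float16)
+print((kv - kv_fp16).abs().max())  # small quantization error
+```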
+
+## FP8 Formats: E4M3 and E5M2
+The Open Compute Project (OCP) defines two common 8-bit floating point data formats:
+
+E4M3:
+
+* 1 sign bit
+* 4 biased exponent bits
+* 3 mantissa bits
+
+E5M2:
+
+* 1 sign bit
+* 5 biased exponent bits
+* 2 mantissa bits
+
+Compared to E5M2, E4M3 offers higher precision but a smaller dynamic range. Because of this limited range, E4M3 typically requires a higher-precision (usually FP32) scaling factor alongside each quantized tensor. Currently, TGI supports only per-tensor (scalar) scaling factors.
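+
+As a rough illustration, a per-tensor scaling factor is commonly derived from the absolute maximum observed over calibration data divided by the largest finite E4M3 value (448). The snippet below is only a sketch of that max-calibration recipe, not the exact procedure of any particular quantizer:
+
+```python
+import torch
+
+FP8_E4M3_MAX = 448.0  # largest finite value representable in float8_e4m3fn
+
+def per_tensor_scale(calibration_tensors):
+    """Single FP32 scalar scale from max calibration."""
+    amax = max(t.abs().max().item() for t in calibration_tensors)
+    return amax / FP8_E4M3_MAX
+
+scale = per_tensor_scale([torch.randn(4, 64) for _ in range(8)])
+print(f"kv_scale = {scale:.6f}")
+```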
+
+## Current Hardware Support
+
+* Nvidia GPUs: support both FP8 E4M3 (`fp8`) and FP8 E5M2 (`fp8_e5m2`).
+* AMD GPUs: support FP8 E4M3FNUZ (`fp8`).
+
+## FP8 E5M2 KV Cache
+Example usage:
+```bash
+model=meta-llama/Llama-2-70b-chat-hf
+volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
+tag=<...> # TGI docker tag
+
+docker run --gpus all --shm-size 64g -p 8080:80 -v $volume:/data \
+ ghcr.io/huggingface/text-generation-inference:$tag \
+ --model-id $model \
+ --kv-cache-dtype fp8_e5m2
+```
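+
+Once the container is running, the server can be queried as usual, for example with the `requests` package (assuming the `8080:80` port mapping above):
+
+```python
+import requests
+
+response = requests.post(
+    "http://localhost:8080/generate",
+    json={
+        "inputs": "What is FP8 quantization?",
+        "parameters": {"max_new_tokens": 64},
+    },
+)
+print(response.json()["generated_text"])
+```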
+
+## FP8 E4M3 KV Cache
+While E4M3 offers higher precision, it requires careful handling of scaling factors to maintain accuracy. Therefore, it is recommended to provide KV cache scaling factors as part of the FP16 checkpoint. If scaling factors are not provided, a default factor of 1.0 is used, which may lead to accuracy loss.
+
+Example usage on Nvidia GPUs:
+
+```bash
+model=mohitsha/Llama-2-70b-chat-hf-FP8-KV-AMMO
+volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
+tag=<...> # TGI docker tag
+
+docker run --gpus all --shm-size 64g -p 8080:80 -v $volume:/data \
+ ghcr.io/huggingface/text-generation-inference:$tag \
+ --model-id $model \
+ --kv-cache-dtype fp8
+```
+
+Example usage on AMD GPUs:
+
+```bash
+model=mohitsha/Llama-2-70b-chat-hf-FP8-KV-AMMO
+volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
+tag=<...> # TGI docker tag
+
+docker run --rm -it --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \
+ --device=/dev/kfd --device=/dev/dri --group-add video \
+ --ipc=host --shm-size 256g --net host -v $volume:/data \
+ ghcr.io/huggingface/text-generation-inference:$tag \
+ --model-id $model \
+ --kv-cache-dtype fp8
+```
+
+`mohitsha/Llama-2-70b-chat-hf-FP8-KV-AMMO`: a Llama 2 70B model with FP8 KV cache scales generated using Nvidia AMMO.
+
+### Checkpoint structure for KV scales
+The FP8 KV cache scaling factors required by the model are specified through the `.kv_scale` parameter in each `Attention` module, for example:
+
+```
+model.layers.0.self_attn.kv_scale < F32
+model.layers.1.self_attn.kv_scale < F32
+...
+```
+
+When `.kv_scale` parameters are provided in the model, the config should also specify the `kv_cache_torch_dtype` that was used to generate the scales (`float8_e4m3fn` or `float8_e4m3fnuz`).
+
+Example config: [Llama-2-7b-chat-hf-FP8-KV#config.json](https://huggingface.co/mohitsha/Llama-2-7b-chat-hf-FP8-KV/blob/main/config.json#L14)
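+
+To check whether a checkpoint already ships KV cache scales, you can list its `.kv_scale` tensors. A small sketch using `huggingface_hub` and `safetensors`, assuming a single-shard `model.safetensors` file:
+
+```python
+from huggingface_hub import hf_hub_download
+from safetensors import safe_open
+
+path = hf_hub_download("mohitsha/Llama-2-7b-chat-hf-FP8-KV", "model.safetensors")
+with safe_open(path, framework="pt") as f:
+    for name in f.keys():
+        if name.endswith(".kv_scale"):
+            print(name, f.get_tensor(name).item())
+```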
+
+### Generating model with KV Cache scales
+
+TGI provides a utility to generate a model with FP8 KV cache scales using Nvidia AMMO. For more information, see [create_fp8_kv_scales_model.py](https://github.com/huggingface/text-generation-inference/examples/fp8_kvcache/create_fp8_kv_scales_model.py).
+
+Alternatively, you can use other quantizer tools to obtain these scaling factors.
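+
+If the scales come from another tool, they can be attached in the same layout. The following is a rough sketch for a Llama checkpoint; the `scales` dictionary is hypothetical and stands in for whatever your calibration produces:
+
+```python
+import torch
+from transformers import AutoModelForCausalLM
+
+scales = {0: 0.0123, 1: 0.0117}  # hypothetical {layer_idx: float} per-tensor scales
+
+model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
+for idx, scale in scales.items():
+    model.model.layers[idx].self_attn.kv_scale = torch.nn.Parameter(
+        torch.tensor(scale, dtype=torch.float32), requires_grad=False
+    )
+model.config.kv_cache_torch_dtype = "float8_e4m3fn"  # dtype the scales were computed for
+model.save_pretrained("Llama-2-7b-chat-hf-FP8-KV")
+```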
diff --git a/docs/source/basic_tutorials/launcher.md b/docs/source/basic_tutorials/launcher.md
index 9246093e2a8..bf6a6934dbc 100644
--- a/docs/source/basic_tutorials/launcher.md
+++ b/docs/source/basic_tutorials/launcher.md
@@ -87,6 +87,15 @@ Options:
[env: DTYPE=]
[possible values: float16, bfloat16]
+```
+## KV_CACHE_DTYPE
+```shell
+ --kv-cache-dtype
+          Specify the data type for the KV cache. By default, it uses the model's data type. CUDA 11.8+ supports `fp8` (`fp8_e4m3`) and `fp8_e5m2`, while ROCm (AMD GPU) supports `fp8` (`fp8_e4m3fnuz`). If `fp8` is chosen, a model checkpoint with scales for the KV cache should be provided. If scales are not provided, the KV cache scaling factors default to 1.0, which may impact accuracy
+
+ [env: KV_CACHE_DTYPE=]
+ [possible values: fp8, fp8_e5m2]
+
```
## TRUST_REMOTE_CODE
```shell
diff --git a/examples/fp8_kvcache/README.md b/examples/fp8_kvcache/README.md
new file mode 100644
index 00000000000..23781b754ef
--- /dev/null
+++ b/examples/fp8_kvcache/README.md
@@ -0,0 +1,52 @@
+# FP8 (fp8_e4m3) KV Cache Scaling Factor Utility
+
+This utility generates per-layer `FP8 (fp8_e4m3)` KV cache scaling factors for a model. The scaling factors are saved into the corresponding HF model, which can then be used with Text Generation Inference (TGI).
+
+The KV scales are stored in the HF model in the following format: the FP8 KV cache scaling factors are specified through a `.kv_scale` parameter within each `Attention` module, as shown below:
+
+
+```
+model.layers.0.self_attn.kv_scale < F32
+model.layers.1.self_attn.kv_scale < F32
+...
+```
+
+Additionally, a `kv_cache_torch_dtype` attribute is added to `config.json`, indicating the torch dtype (`float8_e4m3fn` in this utility) used to generate the scales.
+
+Example config: [Llama-2-7b-chat-hf-FP8-KV#config.json](https://huggingface.co/mohitsha/Llama-2-7b-chat-hf-FP8-KV/blob/main/config.json#L14)
+
+Note: the utility supports only selected Llama-type models. Please adapt the script for other models.
+
+## Prerequisites
+
+- Nvidia AMMO (nvidia-ammo==0.7.1)
+- Hugging Face Transformers
+
+```bash
+pip install --no-cache-dir --extra-index-url https://pypi.nvidia.com nvidia-ammo==0.7.1
+```
+
+## CLI options
+```
+usage: create_fp8_kv_scales_model.py [-h] --model_dir MODEL_DIR [--device DEVICE] [--dtype DTYPE] [--batch_size BATCH_SIZE] [--calib_size CALIB_SIZE] [--output_dir OUTPUT_DIR]
+
+Adapted from examples/quantization/hf_ptq.py
+
+options:
+ -h, --help show this help message and exit
+ --model_dir MODEL_DIR
+ Specify where the HuggingFace model is
+ --device DEVICE
+ --dtype DTYPE Model data type.
+ --batch_size BATCH_SIZE
+ Batch size for calibration.
+ --calib_size CALIB_SIZE
+ Number of samples for calibration.
+ --output_dir OUTPUT_DIR
+
+```
+
+## Example usage
+```
+python create_fp8_kv_scales_model.py --model_dir meta-llama/Llama-2-70b-chat-hf --output_dir output
+```
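+
+After export, a quick sanity check of the generated model can look like the following (paths are illustrative; for large models adjust the shard file name to whatever `save_pretrained` wrote into `output/`):
+
+```python
+import json
+from safetensors import safe_open
+
+with open("output/config.json") as f:
+    print(json.load(f).get("kv_cache_torch_dtype"))  # expected: "float8_e4m3fn"
+
+with safe_open("output/model.safetensors", framework="pt") as f:
+    kv_scales = [name for name in f.keys() if name.endswith(".kv_scale")]
+    print(f"found {len(kv_scales)} kv_scale tensors")
+```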
diff --git a/examples/fp8_kvcache/create_fp8_kv_scales_model.py b/examples/fp8_kvcache/create_fp8_kv_scales_model.py
new file mode 100644
index 00000000000..88f4601406e
--- /dev/null
+++ b/examples/fp8_kvcache/create_fp8_kv_scales_model.py
@@ -0,0 +1,278 @@
+# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # noqa: E501
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Adapted from examples/quantization/hf_ptq.py
+"""
+
+import argparse
+import copy
+import json
+import random
+import time
+from safetensors.torch import safe_open
+
+import ammo.torch.quantization as atq
+import numpy as np
+import torch
+from ammo.torch.export import export_model_config
+from datasets import load_dataset
+from torch.utils.data import DataLoader
+from transformers import AutoModelForCausalLM, AutoTokenizer
+import tqdm
+import tempfile
+
+RAND_SEED = 1234
+MAX_SEQ_LEN = 2048
+
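+# Quantization config summary (explanatory note): weights, activations, the LM head
+# and everything else are left unquantized; only the outputs of the attention K/V
+# projections are quantized. `num_bits=(4, 3)` selects the FP8 E4M3 format,
+# `axis=None` requests a single per-tensor scaling factor, and the "max" algorithm
+# calibrates that factor from the largest absolute activation value observed.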
+QUANT_CONFIG = {
+ "quant_cfg": {
+ "*weight_quantizer": {"enable": False},
+ "*input_quantizer": {"enable": False},
+ "*lm_head*": {"enable": False},
+ "*output_layer*": {"enable": False},
+ "default": {"enable": False},
+ "*.query_key_value.output_quantizer": {
+ "num_bits": (4, 3),
+ "axis": None,
+ "enable": True,
+ },
+ "*.Wqkv.output_quantizer": {"num_bits": (4, 3), "axis": None, "enable": True},
+ "*.W_pack.output_quantizer": {"num_bits": (4, 3), "axis": None, "enable": True},
+ "*.c_attn.output_quantizer": {"num_bits": (4, 3), "axis": None, "enable": True},
+ "*.k_proj.output_quantizer": {"num_bits": (4, 3), "axis": None, "enable": True},
+ "*.v_proj.output_quantizer": {"num_bits": (4, 3), "axis": None, "enable": True},
+ },
+ "algorithm": "max",
+}
+
+
+MODEL_NAME_PATTERN_MAP = {
+ "Llama": "llama",
+ "Mistral": "llama",
+ "baichuan": "baichuan",
+ "QWen": "qwen",
+}
+
+
+def get_tokenizer(ckpt_path, max_seq_len=MAX_SEQ_LEN, model_type=None):
+ print(f"Initializing tokenizer from {ckpt_path}")
+ tokenizer = AutoTokenizer.from_pretrained(
+ ckpt_path,
+ model_max_length=max_seq_len,
+ padding_side="left",
+ trust_remote_code=True,
+ )
+ if model_type and model_type == "qwen":
+ # qwen use token id 151643 as pad and eos tokens
+ tokenizer.pad_token = tokenizer.convert_ids_to_tokens(151643)
+ tokenizer.eos_token = tokenizer.convert_ids_to_tokens(151643)
+
+ # can't set attribute 'pad_token' for ""
+ if tokenizer.pad_token != "":
+ tokenizer.pad_token = tokenizer.eos_token
+ if tokenizer.pad_token is None:
+ tokenizer.pad_token = tokenizer.eos_token
+ assert tokenizer.pad_token is not None, f"Pad token for {model_type} cannot be set!"
+
+ return tokenizer
+
+
+def get_model(ckpt_path, dtype="fp16", device="cuda"):
+ print(f"Initializing model from {ckpt_path}")
+ if dtype == "bf16" or dtype == "bfloat16":
+ dtype = torch.bfloat16
+ elif dtype == "fp16" or dtype == "float16":
+ dtype = torch.float16
+ elif dtype == "fp32" or dtype == "float32":
+ dtype = torch.float32
+ else:
+ raise NotImplementedError(f"Unknown dtype {dtype}")
+
+ model_kwargs = {"torch_dtype": "auto"}
+
+ model = AutoModelForCausalLM.from_pretrained(
+ ckpt_path, device_map="auto", **model_kwargs, trust_remote_code=True
+ )
+ model.eval()
+
+ model_dtype = next(model.parameters()).dtype
+ if dtype != model_dtype:
+ print(
+ "[TensorRT-LLM][WARNING] The manually set model data type is "
+ f"{dtype}, but the data type of the HuggingFace model is "
+ f"{model_dtype}."
+ )
+
+ return model
+
+
+def get_model_type(model):
+ for k, v in MODEL_NAME_PATTERN_MAP.items():
+ if k.lower() in type(model).__name__.lower():
+ return v
+ return None
+
+
+def get_calib_dataloader(
+ data="cnn_dailymail",
+ tokenizer=None,
+ batch_size=1,
+ calib_size=512,
+ block_size=512,
+ device=None,
+):
+ print("Loading calibration dataset")
+ if data == "pileval":
+ dataset = load_dataset(
+ "json",
+ data_files="https://the-eye.eu/public/AI/pile/val.jsonl.zst",
+ split="train",
+ )
+ dataset = dataset["text"][:calib_size]
+ elif data == "cnn_dailymail":
+ dataset = load_dataset("cnn_dailymail", name="3.0.0", split="train")
+ dataset = dataset["article"][:calib_size]
+ else:
+ raise NotImplementedError
+
+ batch_encoded = tokenizer.batch_encode_plus(
+ dataset,
+ return_tensors="pt",
+ padding="max_length",
+ truncation=True,
+ max_length=block_size,
+ )
+ if device:
+ batch_encoded = batch_encoded.to(device)
+ batch_encoded = batch_encoded["input_ids"]
+
+ calib_dataloader = DataLoader(batch_encoded, batch_size=batch_size, shuffle=False)
+
+ return calib_dataloader
+
+
+def quantize_model(model, quant_cfg, num_calib_samples, calib_dataloader=None):
+
+ def calibrate_loop():
+ if calib_dataloader is None:
+ return
+ """Adjusts weights and scaling factors based on selected algorithms."""
+ for idx, data in tqdm.tqdm(
+ enumerate(calib_dataloader), total=num_calib_samples
+ ):
+ model(data)
+
+ print("Starting quantization...")
+ start_time = time.time()
+ atq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
+ end_time = time.time()
+ print("Quantization done. Total time used: {:.2f} s.".format(end_time - start_time))
+
+ return model
+
+
+def set_kv_scales(model, scales):
+ for i, scale in scales.items():
+ scale_param = torch.nn.Parameter(torch.tensor(scale), requires_grad=False)
+ model.model.layers[int(i)].self_attn.kv_scale = scale_param
+
+ if hasattr(model.model.layers[int(i)].self_attn.k_proj, "output_quantizer"):
+ del model.model.layers[int(i)].self_attn.k_proj.output_quantizer
+ if hasattr(model.model.layers[int(i)].self_attn.v_proj, "output_quantizer"):
+ del model.model.layers[int(i)].self_attn.v_proj.output_quantizer
+
+
+def main(args):
+ if not torch.cuda.is_available():
+ raise EnvironmentError("GPU is required for inference.")
+
+ random.seed(RAND_SEED)
+ np.random.seed(RAND_SEED)
+
+ model = get_model(args.model_dir, args.dtype, args.device)
+ model_type = get_model_type(model)
+ tokenizer = get_tokenizer(args.model_dir, model_type=model_type)
+
+ calib_dataloader = get_calib_dataloader(
+ tokenizer=tokenizer,
+ batch_size=args.batch_size,
+ calib_size=args.calib_size,
+ device=args.device,
+ )
+
+ model = quantize_model(model, QUANT_CONFIG, args.calib_size, calib_dataloader)
+
+ with torch.inference_mode():
+ if model_type is None:
+ print(
+ f"Unknown model type {type(model).__name__}. Continue " "exporting..."
+ )
+ model_type = f"unknown:{type(model).__name__}"
+
+ export_path = args.output_dir
+
+ with tempfile.TemporaryDirectory() as temp_dir:
+ # export safetensors
+ export_model_config(
+ model,
+ model_type,
+ getattr(torch, args.dtype),
+ export_dir=temp_dir,
+ inference_tensor_parallel=1,
+ inference_pipeline_parallel=1,
+ export_tensorrt_llm_config=False,
+ export_npz=False,
+ )
+
+ def load_safetensor(filename: str):
+ with safe_open(filename, framework="pt") as f:
+ for name in f.keys():
+ param = f.get_tensor(name)
+ yield name, param
+
+ layer_scales_map = {}
+ for name, param in load_safetensor(temp_dir + "/rank0.safetensors"):
+ if "kv_cache" in name:
+ nums = [int(s) for s in name.split(".") if s.isdecimal()]
+ if len(nums) != 1:
+ raise ValueError(f"Could not determine layer idx for {name}")
+
+ layer_idx = nums[0]
+ layer_scales_map[layer_idx] = param.item()
+
+ set_kv_scales(model, layer_scales_map)
+ model.config.kv_cache_torch_dtype = "float8_e4m3fn"
+
+ model.save_pretrained(export_path)
+ tokenizer.save_pretrained(export_path)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description=__doc__)
+ parser.add_argument(
+ "--model_dir", help="Specify where the HuggingFace model is", required=True
+ )
+ parser.add_argument("--device", default="cuda")
+ parser.add_argument("--dtype", help="Model data type.", default="float16")
+ parser.add_argument(
+ "--batch_size", help="Batch size for calibration.", type=int, default=1
+ )
+ parser.add_argument(
+ "--calib_size", help="Number of samples for calibration.", type=int, default=512
+ )
+ parser.add_argument("--output_dir", default="exported_model")
+ args = parser.parse_args()
+
+ main(args)
diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index e4d5bb85107..af5ab066e67 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -144,6 +144,28 @@ impl std::fmt::Display for Dtype {
}
}
+#[derive(Clone, Copy, Debug, ValueEnum)]
+enum KvDtype {
+ #[clap(name = "fp8")]
+ Fp8,
+ #[clap(name = "fp8_e5m2")]
+ Fp8e5m2,
+}
+
+impl std::fmt::Display for KvDtype {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ // Keep in sync with `server`.
+ match self {
+ KvDtype::Fp8 => {
+ write!(f, "fp8")
+ }
+ KvDtype::Fp8e5m2 => {
+ write!(f, "fp8_e5m2")
+ }
+ }
+ }
+}
+
#[derive(Clone, Copy, Debug, ValueEnum)]
enum RopeScaling {
Linear,
@@ -214,6 +236,13 @@ struct Args {
#[clap(long, env, value_enum)]
dtype: Option<Dtype>,
+ /// Specify the data type for the KV cache. By default, it uses the model's data type.
+ /// CUDA 11.8+ supports `fp8` (`fp8_e4m3`) and `fp8_e5m2`, while ROCm (AMD GPU) supports `fp8` (`fp8_e4m3fnuz`).
+ /// If `fp8` is chosen, a model checkpoint with scales for the KV cache should be provided.
+ /// If scales are not provided, the KV cache scaling factors default to 1.0, which may impact accuracy.
+ #[clap(long, env, value_enum)]
+ kv_cache_dtype: Option<KvDtype>,
+
/// Whether you want to execute hub modelling code. Explicitly passing a `revision` is
/// encouraged when loading a model with custom code to ensure no malicious code has been
/// contributed in a newer revision.
@@ -464,6 +493,7 @@ fn shard_manager(
quantize: Option<Quantization>,
speculate: Option<usize>,
dtype: Option<Dtype>,
+ kv_cache_dtype: Option<KvDtype>,
trust_remote_code: bool,
uds_path: String,
rank: usize,
@@ -535,6 +565,11 @@ fn shard_manager(
shard_args.push(dtype.to_string())
}
+ if let Some(kv_cache_dtype) = kv_cache_dtype {
+ shard_args.push("--kv-cache-dtype".to_string());
+ shard_args.push(kv_cache_dtype.to_string());
+ }
+
// Model optional revision
if let Some(revision) = revision {
shard_args.push("--revision".to_string());
@@ -1038,6 +1073,7 @@ fn spawn_shards(
let quantize = args.quantize;
let speculate = args.speculate;
let dtype = args.dtype;
+ let kv_cache_dtype = args.kv_cache_dtype;
let trust_remote_code = args.trust_remote_code;
let master_port = args.master_port;
let disable_custom_kernels = args.disable_custom_kernels;
@@ -1055,6 +1091,7 @@ fn spawn_shards(
quantize,
speculate,
dtype,
+ kv_cache_dtype,
trust_remote_code,
uds_path,
rank,
diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py
index 430323bcd5b..8dd1e6e8a33 100644
--- a/server/text_generation_server/cli.py
+++ b/server/text_generation_server/cli.py
@@ -7,7 +7,7 @@
from typing import Optional
from enum import Enum
from huggingface_hub import hf_hub_download
+from text_generation_server.utils.import_utils import SYSTEM
+
app = typer.Typer()
@@ -29,6 +29,12 @@ class Dtype(str, Enum):
bloat16 = "bfloat16"
+class KVDtype(str, Enum):
+ auto = "auto"
+ fp8 = "fp8"
+ fp8_e5m2 = "fp8_e5m2"
+
+
@app.command()
def serve(
model_id: str,
@@ -37,6 +43,7 @@ def serve(
quantize: Optional[Quantization] = None,
speculate: Optional[int] = None,
dtype: Optional[Dtype] = None,
+ kv_cache_dtype: KVDtype = "auto",
trust_remote_code: bool = False,
uds_path: Path = "/tmp/text-generation-server",
logger_level: str = "INFO",
@@ -90,6 +97,15 @@ def serve(
raise RuntimeError(
"Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
)
+
+ if kv_cache_dtype in {"fp8", "fp8_e5m2"}:
+ if SYSTEM not in {"cuda", "rocm"}:
+ raise RuntimeError(
+ f"`{kv_cache_dtype}` KV cache is only supported on Nvidia and AMD GPUs."
+ )
+ if kv_cache_dtype == "fp8_e5m2" and SYSTEM != "cuda":
+ raise RuntimeError("`fp8_e5m2` KV cache is only supported on Nvidia GPUs.")
+
server.serve(
model_id,
revision,
@@ -97,6 +113,7 @@ def serve(
quantize,
speculate,
dtype,
+ kv_cache_dtype,
trust_remote_code,
uds_path,
max_input_tokens,
diff --git a/server/text_generation_server/layers/attention/cuda.py b/server/text_generation_server/layers/attention/cuda.py
index 583337bdb81..e55a0e771b1 100644
--- a/server/text_generation_server/layers/attention/cuda.py
+++ b/server/text_generation_server/layers/attention/cuda.py
@@ -20,8 +20,12 @@ def reshape_and_cache(
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slots: torch.Tensor,
+ kv_cache_dtype: str = "auto",
+ kv_scale: float = 1.0,
):
- cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)
+ cache_ops.reshape_and_cache(
+ key, value, key_cache, value_cache, slots, kv_cache_dtype, kv_scale
+ )
def paged_attention(
@@ -34,6 +38,8 @@ def paged_attention(
block_tables: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
+ kv_cache_dtype: str = "auto",
+ kv_scale: float = 1.0,
):
# Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
# Copyright 2023 The vLLM team. All rights
@@ -78,8 +84,8 @@ def paged_attention(
block_size,
max_s,
None,
- "auto",
- 1.0,
+ kv_cache_dtype,
+ kv_scale,
)
else:
# Run PagedAttention V2.
@@ -111,8 +117,8 @@ def paged_attention(
block_size,
max_s,
None,
- "auto",
- 1.0,
+ kv_cache_dtype,
+ kv_scale,
)
diff --git a/server/text_generation_server/layers/attention/rocm.py b/server/text_generation_server/layers/attention/rocm.py
index 91ed5818eb6..e407e1a8745 100644
--- a/server/text_generation_server/layers/attention/rocm.py
+++ b/server/text_generation_server/layers/attention/rocm.py
@@ -25,8 +25,12 @@ def reshape_and_cache(
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slots: torch.Tensor,
+ kv_cache_dtype: str = "auto",
+ kv_scale: float = 1.0,
):
- cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)
+ cache_ops.reshape_and_cache(
+ key, value, key_cache, value_cache, slots, kv_cache_dtype, kv_scale
+ )
def paged_attention(
@@ -39,6 +43,8 @@ def paged_attention(
block_tables: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
+ kv_cache_dtype: str = "auto",
+ kv_scale: float = 1.0,
):
# Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
# Copyright 2023 The vLLM team. All rights
@@ -83,8 +89,8 @@ def paged_attention(
block_size,
max_s,
None,
- "auto",
- 1.0,
+ kv_cache_dtype,
+ kv_scale,
)
else:
# Run PagedAttention V2.
@@ -116,8 +122,8 @@ def paged_attention(
block_size,
max_s,
None,
- "auto",
- 1.0,
+ kv_cache_dtype,
+ kv_scale,
)
diff --git a/server/text_generation_server/layers/attention/xpu.py b/server/text_generation_server/layers/attention/xpu.py
index 8b6cb87b5f3..628e578960c 100644
--- a/server/text_generation_server/layers/attention/xpu.py
+++ b/server/text_generation_server/layers/attention/xpu.py
@@ -39,6 +39,8 @@ def reshape_and_cache(
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slots: torch.Tensor,
+ kv_cache_dtype: str = "auto",
+ kv_scale: float = 1.0,
):
ipex.llm.modules.PagedAttention.reshape_and_cache(
key, value, key_cache, value_cache, slots
@@ -55,6 +57,8 @@ def paged_attention(
block_tables: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
+ kv_cache_dtype: str = "auto",
+ kv_scale: float = 1.0,
):
query = query.contiguous()
block_size = value_cache.shape[3]
diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 76dca3dc81c..a2c240d9820 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -246,6 +246,12 @@ class ModelType(enum.Enum):
}
+FP8_KVCACHE_SUPPORTED_MODELS = {
+ "llama",
+ "baichun",
+ "phi3",
+}
+
__GLOBALS = locals()
for data in ModelType:
__GLOBALS[data.name] = data.value["type"]
@@ -258,6 +264,7 @@ def get_model(
quantize: Optional[str],
speculate: Optional[int],
dtype: Optional[str],
+ kv_cache_dtype: Optional[str],
trust_remote_code: bool,
max_input_tokens: int,
) -> Model:
@@ -287,6 +294,11 @@ def get_model(
)
model_type = config_dict.get("model_type", None)
+ if model_type not in FP8_KVCACHE_SUPPORTED_MODELS and kv_cache_dtype != "auto":
+ raise RuntimeError(
+ f"kv_cache_dtype is only supported for {', '.join(FP8_KVCACHE_SUPPORTED_MODELS)} models. Got model_type: {model_type}, kv_cache_dtype: {kv_cache_dtype}"
+ )
+
speculator = None
if "medusa_num_heads" in config_dict:
medusa_model_id = model_id
@@ -594,6 +606,7 @@ def get_model(
quantize=quantize,
speculator=speculator,
dtype=dtype,
+ kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
)
elif sharded:
diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index 0d06d1048c0..1950ea98104 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -44,6 +44,8 @@
FastRMSNorm,
)
+from loguru import logger
+
if SYSTEM == "rocm":
try:
from vllm import _custom_C
@@ -134,6 +136,16 @@ def __init__(
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
).repeat_interleave(self.num_groups)
+ self.kv_cache_dtype = config.kv_cache_dtype
+
+ if self.kv_cache_dtype == "fp8":
+ self.kv_scale = weights.get_kv_cache_scaling_factor(
+ prefix, self.kv_cache_dtype, config.kv_cache_torch_dtype
+ )
+ else:
+ self.kv_scale = 1.0
+ logger.info(f"kv_cache_dtype: {self.kv_cache_dtype}, kv_scale: {self.kv_scale}")
+
def forward(
self,
hidden_states,
@@ -159,7 +171,15 @@ def forward(
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
- reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
+ reshape_and_cache(
+ kv[:, 0],
+ kv[:, 1],
+ kv_cache[0],
+ kv_cache[1],
+ slots,
+ self.kv_cache_dtype,
+ self.kv_scale,
+ )
# output tensor
attn_output = torch.empty_like(query)
@@ -188,6 +208,8 @@ def forward(
block_tables,
input_lengths,
max_s,
+ self.kv_cache_dtype,
+ self.kv_scale,
)
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index d16d371068e..470084c9824 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -729,6 +729,7 @@ def __init__(
rank: int = 0,
world_size: int = 1,
sliding_window: Optional[int] = None,
+ kv_cache_dtype: Optional[torch.dtype] = None,
):
self.num_layers = num_layers
self.num_kv_heads = num_kv_heads
@@ -737,6 +738,8 @@ def __init__(
self.cuda_graphs = {}
self.kv_cache = []
+ self.kv_cache_dtype = kv_cache_dtype if kv_cache_dtype else dtype
+
super(FlashCausalLM, self).__init__(
model=model,
tokenizer=tokenizer,
@@ -854,7 +857,7 @@ def warmup(self, batch: FlashCausalLMBatch):
self.num_layers,
self.num_kv_heads,
self.head_size,
- self.dtype,
+ self.kv_cache_dtype,
self.device,
)
max_bt = batch.max_blocks
@@ -873,7 +876,7 @@ def warmup(self, batch: FlashCausalLMBatch):
# Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
# Calculate the number of blocks that can be allocated with the free memory
- dtype_size = torch.tensor([], dtype=self.dtype).element_size()
+ dtype_size = torch.tensor([], dtype=self.kv_cache_dtype).element_size()
cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
@@ -893,7 +896,7 @@ def warmup(self, batch: FlashCausalLMBatch):
self.num_layers,
self.num_kv_heads,
self.head_size,
- self.dtype,
+ self.kv_cache_dtype,
self.device,
)
diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py
index e27f0da28a4..37628513616 100644
--- a/server/text_generation_server/models/flash_llama.py
+++ b/server/text_generation_server/models/flash_llama.py
@@ -28,6 +28,7 @@ def __init__(
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
+ kv_cache_dtype: Optional[str] = "auto",
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
@@ -62,6 +63,9 @@ def __init__(
)
config.quantize = quantize
config.speculator = speculator
+ config.kv_cache_dtype = kv_cache_dtype
+ if not hasattr(config, "kv_cache_torch_dtype"):
+ config.kv_cache_torch_dtype = None
torch.distributed.barrier(group=self.process_group)
@@ -72,6 +76,7 @@ def __init__(
prefix = ""
model = FlashLlamaForCausalLM(prefix, config, weights)
+
torch.distributed.barrier(group=self.process_group)
super(FlashLlama, self).__init__(
model=model,
@@ -83,4 +88,5 @@ def __init__(
device=device,
rank=rank,
world_size=world_size,
+ kv_cache_dtype=torch.uint8 if "fp8" in kv_cache_dtype else dtype,
)
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index a0347cd8e73..d3fc464c7dc 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -197,6 +197,7 @@ def serve(
quantize: Optional[str],
speculate: Optional[int],
dtype: Optional[str],
+ kv_cache_dtype: Optional[str],
trust_remote_code: bool,
uds_path: Path,
max_input_tokens: int,
@@ -208,6 +209,7 @@ async def serve_inner(
quantize: Optional[str] = None,
speculate: Optional[int] = None,
dtype: Optional[str] = None,
+ kv_cache_dtype: Optional[str] = "auto",
trust_remote_code: bool = False,
):
unix_socket_template = "unix://{}-{}"
@@ -229,6 +231,7 @@ async def serve_inner(
quantize,
speculate,
dtype,
+ kv_cache_dtype,
trust_remote_code,
max_input_tokens,
)
@@ -266,6 +269,13 @@ async def serve_inner(
set_model_id(model_id)
asyncio.run(
serve_inner(
- model_id, revision, sharded, quantize, speculate, dtype, trust_remote_code
+ model_id,
+ revision,
+ sharded,
+ quantize,
+ speculate,
+ dtype,
+ kv_cache_dtype,
+ trust_remote_code,
)
)
diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py
index e61425254a0..68f89f645ca 100644
--- a/server/text_generation_server/utils/weights.py
+++ b/server/text_generation_server/utils/weights.py
@@ -7,6 +7,7 @@
from loguru import logger
from huggingface_hub import hf_hub_download
import json
+from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.log import log_once
@@ -88,7 +89,11 @@ def get_tensor(self, tensor_name: str, to_device=True):
# Special case for gptq which shouldn't convert
# u4 which are disguised as int32. Exl2 uses int16
# as well.
- if tensor.dtype not in [torch.int16, torch.int32, torch.int64]:
+ if tensor.dtype not in [
+ torch.int16,
+ torch.int32,
+ torch.int64,
+ ] and not tensor_name.endswith("kv_scale"):
tensor = tensor.to(dtype=self.dtype)
if to_device:
tensor = tensor.to(device=self.device)
@@ -762,6 +767,50 @@ def _set_gptq_params(self, model_id, revision):
except Exception:
pass
+ def get_kv_cache_scaling_factor(
+ self, prefix: str, kv_cache_dtype: str, kv_cache_torch_dtype: str
+ ):
+ try:
+ kv_scale = self.get_tensor(f"{prefix}.kv_scale").cpu().tolist()
+ except RuntimeError:
+ if kv_cache_dtype == "fp8":
+ log_once(
+ logger.warning,
+ "Could not find the `kv_scale` in checkpoint for `fp8_e4m3`. Using scaling factor"
+ "`1.0`. This may result in accuracy issues. Please ensure the checkpoint includes "
+ "the correct KV cache scaling factor.",
+ )
+
+ kv_scale = 1.0
+ else:
+ if kv_cache_dtype == "fp8_e5m2":
+ raise RuntimeError(
+ "Found `kv_scale` in the checkpoint, but `fp8_e5m2` KV dtype do not support `kv_scale` > 1.0"
+ )
+
+ if not isinstance(kv_scale, float):
+ raise RuntimeError(
+ "Only support per-tensor scaling factor for `fp8 (fp8_e4m3)` KV cache"
+ )
+
+ if kv_cache_torch_dtype not in {"float8_e4m3fn", "float8_e4m3fnuz"}:
+ raise RuntimeError(
+ f"Found `kv_scale` in the checkpoint, the config must specify the `kv_cache_torch_dtype` "
+ f"used for generating kv scales. Expected 'float8_e4m3fn' or 'float8_e4m3fnuz', but got '{kv_cache_torch_dtype}'."
+ )
+
+ # Nvidia GPUs use the OCP `float8_e4m3fn` format, whereas ROCm uses `float8_e4m3fnuz`,
+ # whose exponent bias is 8 instead of 7, so each representable value is half as large.
+ # Multiplying (or dividing) the scale by 2 compensates for this difference, so the
+ # effective scaling matches the format the scales were originally calibrated for.
+ if SYSTEM == "rocm" and kv_cache_torch_dtype == "float8_e4m3fn":
+ kv_scale *= 2.0
+ elif SYSTEM == "cuda" and kv_cache_torch_dtype == "float8_e4m3fnuz":
+ kv_scale /= 2.0
+
+ return kv_scale
+
def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> List[int]:
"""