[Work in progress] Add FP8 support in fwd_prefill #115

Draft · wants to merge 62 commits into base: main_perf

Commits (62)
94b2da3
Alex fp8 work
alexkranias-amd Dec 4, 2024
7434112
fix mismatches
micmelesse Dec 9, 2024
2ea54b1
no navi for now
micmelesse Dec 9, 2024
89d3d7d
fix: ref uses scaling + added ENV VAR to enable/disable quantization …
alexkranias-amd Dec 9, 2024
c65af82
fix: fp8 ref matches kernel
alexkranias-amd Dec 9, 2024
9297d78
misc: added note about p_scale
alexkranias-amd Dec 9, 2024
f92ca5b
feat: added precision error test for various triton ops
alexkranias-amd Dec 9, 2024
9ed1d00
save
alexkranias-amd Dec 4, 2024
1c3f756
feat: added benchmark for fp8 flash attention
alexkranias-amd Dec 10, 2024
c4ca789
fix: quantization scaling in fp8 benchmark
alexkranias-amd Dec 10, 2024
937e814
checkpoint
alexkranias-amd Dec 4, 2024
fd342f7
feat: added fp8 to precision test
alexkranias-amd Dec 11, 2024
543736b
fix: refactor fp32 for torch, moved scaling of fp8 to out of kernel
alexkranias-amd Dec 13, 2024
210e2df
Fix test_op_fwd_prefill
brunomazzottiamd Dec 16, 2024
d8dd966
Document two tests that are failing with FP8
brunomazzottiamd Dec 16, 2024
1835390
Remove cast to fp16, output is already being cast to fp32
brunomazzottiamd Dec 17, 2024
a2624a9
Increase error tolerance for fp8
brunomazzottiamd Dec 17, 2024
8eab5e5
Enable more test cases
brunomazzottiamd Dec 17, 2024
0cf49ce
Fix bug for "bshd" layout
brunomazzottiamd Dec 17, 2024
f413f33
Take max fp8 value into account while computing scales
brunomazzottiamd Dec 18, 2024
4f3e633
Compute 1st FA GEMM without casting to fp16
brunomazzottiamd Dec 18, 2024
6773e3a
Remove redundant `v.to(v.type.element_ty)` cast
brunomazzottiamd Dec 18, 2024
b31cd5d
Fix global scaling for "bhsd" and "bshd" layouts
brunomazzottiamd Dec 18, 2024
3044d7b
[WIP] First attempt to support "thd" layout
brunomazzottiamd Dec 18, 2024
5856c6b
Refactor fp8 scale computation
brunomazzottiamd Dec 23, 2024
a170a08
Compute p scale factor and pass it to the kernel
brunomazzottiamd Dec 23, 2024
85c62ae
Fix minor coding mistakes
brunomazzottiamd Dec 23, 2024
13b07df
Use p scale factor in the kernel
brunomazzottiamd Dec 23, 2024
02a4d8f
Improve scale factor generation
brunomazzottiamd Dec 23, 2024
4796838
Compute per batch / head fp8 scale for varlen layout
brunomazzottiamd Dec 26, 2024
d767bdc
Scale "thd" varlen input tensors and enable related tests
brunomazzottiamd Dec 26, 2024
920ad12
Split test_op_prefill_fwd_impl in two (one for fp16 and other for fp8)
brunomazzottiamd Dec 27, 2024
478fa9c
Specify desired output dtype of ref. implementation
brunomazzottiamd Dec 27, 2024
6d6a0e4
Restore requires_grad of q, k, v after fp8 scaling
brunomazzottiamd Dec 27, 2024
a3db94d
Test global scaling + per batch / head scaling
brunomazzottiamd Dec 27, 2024
565fdde
Change output type to fp16 and document test failures
brunomazzottiamd Dec 27, 2024
eaa2dc6
Remove fp32 casts from reference implementation
brunomazzottiamd Dec 27, 2024
90eae97
[CLEANUP] Remove prototype files
brunomazzottiamd Dec 30, 2024
97319de
[CLEANUP] Revert exploratory changes in some tests
brunomazzottiamd Dec 30, 2024
19633cf
[CLEANUP] Minimize diff hunks in the fwd_prefill kernel
brunomazzottiamd Dec 30, 2024
40caf86
[ORG] Create new module to group fp8 related features
brunomazzottiamd Dec 30, 2024
4fe957d
[ORG] Add type annotations to fp8 module
brunomazzottiamd Dec 30, 2024
73382d9
[ORG] Make check_is_fp8 function more general
brunomazzottiamd Dec 30, 2024
5750d29
[ORG] Scale q k v outside Triton implementation
brunomazzottiamd Dec 30, 2024
59f76f0
[WIP] Implement fp8 scaling in ref. implementation
brunomazzottiamd Dec 30, 2024
68dbef9
[WIP] Implement fp8 scaling in ref. implementation
brunomazzottiamd Jan 2, 2025
1847bbd
Ref. impl. supporting MQA and GQA improved the situation
brunomazzottiamd Jan 2, 2025
7d457c4
[WIP] Implement fp8 scaling in ref. implementation
brunomazzottiamd Jan 2, 2025
4ba9d87
Ref. impl. supporting varlen improved the situation
brunomazzottiamd Jan 2, 2025
9b3a095
Decrease absolute error tolerance
brunomazzottiamd Jan 2, 2025
9e49dc8
Document and cleanup fp8.py module
brunomazzottiamd Jan 3, 2025
0d2f0c6
Reduce whitespace changes in fwd_prefill.py
brunomazzottiamd Jan 3, 2025
5b651b9
Code cleanup in fwd_ref.py
brunomazzottiamd Jan 3, 2025
e6a515e
Revert changes in test_op_prefill_fwd_impl test
brunomazzottiamd Jan 3, 2025
fbe8a37
Reduce minor whitespace changes
brunomazzottiamd Jan 3, 2025
6b0e777
Remove needless casts tl.float32 in 2nd GEMM
brunomazzottiamd Jan 3, 2025
9de6785
[TEMP] Collect fp8 error data
brunomazzottiamd Jan 6, 2025
de7194f
Compare fp8 Triton with fp16 Triton in unit test
brunomazzottiamd Jan 6, 2025
c0dd573
[CLEANUP] Remove fp8 benchmark
brunomazzottiamd Jan 7, 2025
0494786
Revert "[TEMP] Collect fp8 error data"
brunomazzottiamd Jan 7, 2025
9cef5c2
Fix typos in comments
brunomazzottiamd Jan 7, 2025
8a1630b
Invert scale factors and perform scaling with multiplication
brunomazzottiamd Jan 7, 2025
2 changes: 1 addition & 1 deletion .github/workflows/amd_tests.yml
@@ -27,7 +27,7 @@ jobs:
       id: set-matrix
       run: |
         if [ x"${{ github.repository }}" == x"ROCm/flash-attention" ]; then
-          echo '::set-output name=matrix-HIP::[["self-hosted", "gfx90a"], ["self-hosted", "gfx1100"]]'
+          echo '::set-output name=matrix-HIP::[["self-hosted", "gfx90a"]]'
         else
           echo '::set-output name=matrix-HIP::[["ubuntu-latest"]]'
         fi
263 changes: 263 additions & 0 deletions flash_attn/flash_attn_triton_amd/fp8.py
@@ -0,0 +1,263 @@
import os
from typing import Optional

import torch
from torch import Tensor

from .utils import get_shape_from_layout, MetaData


REMOVE_QUANTIZATION_SCALING: bool = os.environ.get("FLASH_ATTENTION_TRITON_AMD_REMOVE_QUANT_SCALE", "0").lower() in ("1", "true", "yes")

FP8_TYPES: set[torch.dtype] = {
    torch.float8_e4m3fnuz,
    torch.float8_e4m3fn,
    torch.float8_e5m2,
    torch.float8_e5m2fnuz,
}

FP8_MAX: dict[torch.dtype, float] = {
    dtype: torch.finfo(dtype).max
    for dtype in FP8_TYPES
}

DEFAULT_SCALE_PER_HEAD: bool = True


def check_is_fp8(x: Tensor, *xs: Tensor) -> bool:
    """
    Checks whether the given tensors are of FP8 data types.

    This function determines if any of the input tensors have a data type
    matching one of the FP8 types defined in `FP8_TYPES`. If the environment
    variable `FLASH_ATTENTION_TRITON_AMD_REMOVE_QUANT_SCALE` is set to a
    truthy value, the function always returns `False`, effectively disabling FP8
    type detection.

    Args:
        x (Tensor): The primary tensor to check.
        *xs (Tensor): Additional tensors to check.

    Returns:
        bool:
            - `True` if any of the input tensors have an FP8 data type and
              quantization scaling is not disabled.
            - `False` otherwise.
    """
    if REMOVE_QUANTIZATION_SCALING:
        return False  # makes all methods believe they aren't working with fp8s, so no scaling is applied
    return any(y.dtype in FP8_TYPES for y in (x,) + xs)
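

def _example_check_is_fp8() -> None:
    # Illustrative usage sketch only, not part of the diff above; it assumes a
    # GPU with float8 support (e.g. an MI300-class device), and the shapes are arbitrary.
    q = torch.randn(2, 8, 128, 64, device="cuda").to(torch.float8_e4m3fnuz)
    k = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
    assert check_is_fp8(q, k)   # True: at least one input is an fp8 tensor.
    assert not check_is_fp8(k)  # False: fp16/fp32 inputs need no fp8 scaling.
    # Exporting FLASH_ATTENTION_TRITON_AMD_REMOVE_QUANT_SCALE=1 before importing the
    # module makes check_is_fp8 return False unconditionally, disabling all scaling.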


def create_fp8_scale_tensors(
    q: Tensor, k: Tensor, v: Tensor, layout: str,
    cu_seqlens_q: Optional[Tensor] = None, cu_seqlens_k: Optional[Tensor] = None,
    max_seqlen_q: Optional[int] = None, max_seqlen_k: Optional[int] = None,
    scale_per_head: bool = DEFAULT_SCALE_PER_HEAD, eps: float = 1e-9,
) -> tuple[Tensor, Tensor, Tensor, Tensor]:
    """
    Create scale tensors for q, k and v based on the scaling configuration.

    Args:
        q (torch.Tensor): Query tensor.
        k (torch.Tensor): Key tensor.
        v (torch.Tensor): Value tensor.
        layout (str): Tensor layout, can be `"bhsd"`, `"bshd"` or `"thd"`.
        cu_seqlens_q (Optional[torch.Tensor]): Cumulative Q sequence length.
            Used with `"thd"` varlen layout.
        cu_seqlens_k (Optional[torch.Tensor]): Cumulative KV sequence length.
            Used with `"thd"` varlen layout.
        max_seqlen_q (Optional[int]): Max. Q sequence length.
            Used with `"thd"` varlen layout.
        max_seqlen_k (Optional[int]): Max. KV sequence length.
            Used with `"thd"` varlen layout.
        scale_per_head (bool): Whether to compute the scale per batch / head or globally.
            Defaults to `DEFAULT_SCALE_PER_HEAD`.
        eps (float): If the maximum absolute value of a tensor is zero, this
            constant avoids division by zero while scaling. Defaults to 1e-9.

    Returns:
        tuple of 2D torch.Tensor: `(q_scale, k_scale, v_scale, p_scale)`.
        To perform fp8 quantization, multiply by the scale factor
        `(x_quant = x * x_scale)`.
        To perform fp8 dequantization, multiply by the reciprocal of the scale factor
        `(x = x_quant * (1 / x_scale))`.
        `p_scale` is related to the intermediate FA computation
        `p = softmax(matmul(q, transpose(k)))`.
        All scale tensors are `float32` ones.
        `q_scale` and `p_scale` have `(BATCH, HEADS_Q)` shape.
        `k_scale` and `v_scale` have `(BATCH, HEADS_K)` shape.
    """
    assert layout in ["bhsd", "bshd", "thd"], "Unknown layout."
    is_varlen = layout == "thd"
    if is_varlen:
        assert cu_seqlens_q is not None, "cu_seqlens_q is required for varlen layout."
        assert cu_seqlens_k is not None, "cu_seqlens_k is required for varlen layout."
        assert max_seqlen_q is not None, "max_seqlen_q is required for varlen layout."
        assert max_seqlen_k is not None, "max_seqlen_k is required for varlen layout."

    is_fp8 = check_is_fp8(q, k, v)
    batch, head_q, head_k, _, _, _ = get_shape_from_layout(
        q, k, layout,
        cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, max_seqlen_q=max_seqlen_q, max_seqlen_k=max_seqlen_k,
    )

    if not is_fp8:
        # For non-float8 dtypes, use a default scale of 1.
        q_scale = torch.ones((batch, head_q), dtype=torch.float32, device=q.device)
        k_scale = torch.ones((batch, head_k), dtype=torch.float32, device=k.device)
        v_scale = torch.ones((batch, head_k), dtype=torch.float32, device=v.device)
        p_scale = torch.ones((batch, head_q), dtype=torch.float32, device="cuda")

    else:
        # Handle the float8 dtype special case.

        # Convert to float32 for scale computation.
        q_float32 = q.detach().to(torch.float32)
        k_float32 = k.detach().to(torch.float32)
        v_float32 = v.detach().to(torch.float32)

        if not scale_per_head:
            # Handle global scaling.

            # Compute the global max and create a tensor of that value.
            q_global_max = max(q_float32.abs().max().item(), eps)
            k_global_max = max(k_float32.abs().max().item(), eps)
            v_global_max = max(v_float32.abs().max().item(), eps)

            q_scale = torch.full((batch, head_q), q_global_max, dtype=torch.float32, device=q.device)
            k_scale = torch.full((batch, head_k), k_global_max, dtype=torch.float32, device=k.device)
            v_scale = torch.full((batch, head_k), v_global_max, dtype=torch.float32, device=v.device)

        else:
            # Handle per batch / head scaling.
            teps = torch.tensor(eps)

            if is_varlen:
                q_scale = torch.stack([
                    torch.maximum(q_float32[s:e].abs().amax(dim=(0, 2)), teps)
                    for s, e in zip(cu_seqlens_q[:-1], cu_seqlens_q[1:])
                ])
                k_scale = torch.stack([
                    torch.maximum(k_float32[s:e].abs().amax(dim=(0, 2)), teps)
                    for s, e in zip(cu_seqlens_k[:-1], cu_seqlens_k[1:])
                ])
                v_scale = torch.stack([
                    torch.maximum(v_float32[s:e].abs().amax(dim=(0, 2)), teps)
                    for s, e in zip(cu_seqlens_k[:-1], cu_seqlens_k[1:])
                ])

            else:
                if layout == "bhsd":
                    seqlen_loc = 2
                    dim_loc = 3
                elif layout == "bshd":
                    seqlen_loc = 1
                    dim_loc = 3

                # Compute the max for each batch-head pair across seqlen and dim.
                q_scale = torch.maximum(q_float32.abs().amax(dim=(seqlen_loc, dim_loc)), teps)
                k_scale = torch.maximum(k_float32.abs().amax(dim=(seqlen_loc, dim_loc)), teps)
                v_scale = torch.maximum(v_float32.abs().amax(dim=(seqlen_loc, dim_loc)), teps)

        # Divide the data type max by the respective max tensors, so that
        # multiplying by the scale maps each max onto the fp8 range.
        q_scale = FP8_MAX[q.dtype] / q_scale
        k_scale = FP8_MAX[k.dtype] / k_scale
        v_scale = FP8_MAX[v.dtype] / v_scale

        # Compute p_scale.
        p_scale = torch.full((batch, head_q), FP8_MAX[q.dtype], dtype=torch.float32, device="cuda")

    return q_scale, k_scale, v_scale, p_scale
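

def _example_fp8_scale_tensors() -> None:
    # Minimal sketch of the quantization contract documented above (illustration
    # only, not part of the diff; assumes float8 support on the current GPU).
    # With the default per batch / head scaling, each scale maps that slice's
    # absolute max onto FP8_MAX, so x_quant = x * x_scale and x ~= x_quant * (1 / x_scale).
    dtype = torch.float8_e4m3fnuz
    q = torch.randn(2, 8, 128, 64, device="cuda").to(dtype)  # "bhsd": (batch, heads, seqlen, dim)
    k = torch.randn(2, 8, 128, 64, device="cuda").to(dtype)
    v = torch.randn(2, 8, 128, 64, device="cuda").to(dtype)
    q_scale, k_scale, v_scale, p_scale = create_fp8_scale_tensors(q, k, v, "bhsd")
    assert q_scale.shape == (2, 8) and q_scale.dtype == torch.float32
    # p = softmax(...) lies in [0, 1], so its scale is simply the fp8 max value.
    assert torch.all(p_scale == FP8_MAX[dtype])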


def scale_fp8(
    x: Tensor, x_scale: Tensor, layout: str,
    cu_seqlens: Optional[Tensor] = None,
) -> Tensor:
    """
    Scales an FP8 tensor using a specified scaling factor.

    This function scales the input tensor `x` by multiplying it by the scaling factor
    `x_scale`, taking the specified tensor layout into account.

    If the input tensor `x` is not an FP8 tensor, the function returns it unchanged.
    For FP8 tensors, the result is clamped to the representable range of the FP8 data type.

    Args:
        x (Tensor): The input tensor to be scaled.
        x_scale (Tensor): The scaling factor tensor, broadcast as needed based on `layout`.
        layout (str): The data layout of the tensor. Must be one of `"bhsd"`, `"bshd"`, or `"thd"`.
        cu_seqlens (Optional[Tensor], optional): Cumulative sequence lengths for variable-length
            sequences. Required when `layout` is `"thd"`.

    Returns:
        Tensor: The scaled tensor.
    """
    assert layout in ["bhsd", "bshd", "thd"], "Unknown layout."
    assert layout != "thd" or cu_seqlens is not None, "cu_seqlens is required for varlen layout."
    if not check_is_fp8(x):
        return x
    # Upcast x to float32 before scaling.
    n = x.detach().to(torch.float32)
    # Detach the scale factor; it is broadcast according to the layout below.
    x_scale = x_scale.detach()
    if layout == "bhsd":
        x_scaled = n * x_scale[:, :, None, None]
    elif layout == "bshd":
        x_scaled = n * x_scale[:, None, :, None]
    elif layout == "thd":
        x_scaled = torch.cat([
            n[s:e] * x_scale[z, :].unsqueeze(0).unsqueeze(-1)
            for z, (s, e) in enumerate(zip(cu_seqlens[:-1], cu_seqlens[1:]))
        ], dim=0)
    # Clamp and convert back to float8.
    return torch.clamp(x_scaled, min=torch.finfo(x.dtype).min, max=torch.finfo(x.dtype).max).to(
        x.dtype).requires_grad_(x.requires_grad)
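

def _example_scale_fp8_roundtrip() -> None:
    # Illustration only, not part of the diff (assumes float8 support on the GPU).
    # scale_fp8 re-expresses an fp8 tensor so that its values span the full fp8
    # range; multiplying by the reciprocal scale afterwards recovers the original
    # values up to fp8 rounding error.
    dtype = torch.float8_e4m3fnuz
    x = (0.05 * torch.randn(2, 128, 8, 64, device="cuda")).to(dtype)        # "bshd"
    x_f32 = x.to(torch.float32)
    # Per batch / head scale, as produced by create_fp8_scale_tensors for "bshd".
    x_scale = FP8_MAX[dtype] / x_f32.abs().amax(dim=(1, 3)).clamp_min(1e-9)  # (batch, heads)
    x_scaled = scale_fp8(x, x_scale, "bshd")
    # Dequantize: broadcast the reciprocal scale back over seqlen and head_dim.
    x_back = x_scaled.to(torch.float32) * (1 / x_scale)[:, None, :, None]
    # x_back now approximates x_f32 up to fp8 quantization error.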


class Fp8MetaData:
    """
    Manages FP8 scaling metadata and scaled tensors for query, key, and value.

    Attributes:
        scale_per_head (bool): Indicates if scaling is applied per batch / head.
        q_scale (Tensor): Scaling factor for the query tensor.
        k_scale (Tensor): Scaling factor for the key tensor.
        v_scale (Tensor): Scaling factor for the value tensor.
        p_scale (Tensor): Scaling factor for intermediate FA computation
            `p = softmax(matmul(q, transpose(k)))`.
        q_inv_scale (Tensor): Inverse of `q_scale`.
        k_inv_scale (Tensor): Inverse of `k_scale`.
        v_inv_scale (Tensor): Inverse of `v_scale`.
        p_inv_scale (Tensor): Inverse of `p_scale`.
        q_scaled (Tensor): Scaled query tensor.
        k_scaled (Tensor): Scaled key tensor.
        v_scaled (Tensor): Scaled value tensor.

    Methods:
        __init__: Initializes scaling metadata and applies scaling to tensors.
    """

    scale_per_head: bool
    q_scale: Tensor
    k_scale: Tensor
    v_scale: Tensor
    p_scale: Tensor
    q_inv_scale: Tensor
    k_inv_scale: Tensor
    v_inv_scale: Tensor
    p_inv_scale: Tensor
    q_scaled: Tensor
    k_scaled: Tensor
    v_scaled: Tensor

    def __init__(
        self,
        q: Tensor, k: Tensor, v: Tensor, layout: str, metadata: MetaData,
        scale_per_head: bool = DEFAULT_SCALE_PER_HEAD,
    ) -> None:
        self.scale_per_head = scale_per_head
        self.q_scale, self.k_scale, self.v_scale, self.p_scale = create_fp8_scale_tensors(
            q, k, v, layout,
            cu_seqlens_q=metadata.cu_seqlens_q, cu_seqlens_k=metadata.cu_seqlens_k,
            max_seqlen_q=metadata.max_seqlens_q, max_seqlen_k=metadata.max_seqlens_k,
            scale_per_head=scale_per_head,
        )
        self.q_inv_scale = 1 / self.q_scale
        self.k_inv_scale = 1 / self.k_scale
        self.v_inv_scale = 1 / self.v_scale
        self.p_inv_scale = 1 / self.p_scale
        self.q_scaled = scale_fp8(q, self.q_scale, layout, cu_seqlens=metadata.cu_seqlens_q)
        self.k_scaled = scale_fp8(k, self.k_scale, layout, cu_seqlens=metadata.cu_seqlens_k)
        self.v_scaled = scale_fp8(v, self.v_scale, layout, cu_seqlens=metadata.cu_seqlens_k)
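
Taken together, the module above is meant to be consumed by the prefill path roughly as in the sketch below. This is a hedged illustration, not code from the PR: the real call sites build a MetaData object from flash_attn_triton_amd.utils, whose constructor is not shown here, so a SimpleNamespace stand-in exposing only the four attributes that Fp8MetaData reads is used instead, and float8 support on the current GPU is assumed.

from types import SimpleNamespace

import torch

from flash_attn.flash_attn_triton_amd.fp8 import FP8_MAX, Fp8MetaData


def demo_fp8_metadata() -> None:
    dtype = torch.float8_e4m3fnuz
    q = torch.randn(2, 8, 128, 64, device="cuda").to(dtype)  # "bhsd"
    k = torch.randn(2, 8, 128, 64, device="cuda").to(dtype)
    v = torch.randn(2, 8, 128, 64, device="cuda").to(dtype)
    # Stand-in for MetaData: only the attributes read by Fp8MetaData.__init__.
    meta = SimpleNamespace(cu_seqlens_q=None, cu_seqlens_k=None,
                           max_seqlens_q=128, max_seqlens_k=128)
    fp8_meta = Fp8MetaData(q, k, v, "bhsd", meta, scale_per_head=True)
    # q_scaled / k_scaled / v_scaled are the rescaled fp8 inputs that would be
    # handed to the Triton kernel, together with the inverse scale factors
    # (q_inv_scale, k_inv_scale, v_inv_scale, p_inv_scale) so that the kernel
    # can undo the quantization scaling on its intermediate products.
    assert fp8_meta.q_scaled.dtype == dtype
    assert fp8_meta.q_scale.shape == (2, 8) and fp8_meta.p_inv_scale.shape == (2, 8)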