Add AMD CK FP8 Kernels to Llama Dispatch
Summary:
This diff replaces the slow `torch._scaled_mm` implementation with the faster CK kernels added in D56963018 and extends AMD support to rowwise FP8 quantization.

Previously we didn't have an accelerated kernel to quantize tensors to FP8 on AMD. I updated the Triton quantization functions to support AMD and use them for rowwise quantization, which gives excellent performance.
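
For context, here is a minimal PyTorch sketch of the rowwise scheme the updated Triton kernels implement (scale each row by `MAX_FP8 / row_max`, quantize, and hand the GEMM the reciprocal scale). `quantize_fp8_row_reference` is a hypothetical name used only for illustration; the real fast path is `quantize_fp8_row` in `fp8_gemm.py` below.

```python
from typing import Tuple

import torch

# Mirror the PT_FP8_DTYPE selection added in this diff: AMD (HIP) uses the
# fnuz variant of e4m3, NVidia uses the standard e4m3fn type.
PT_FP8_DTYPE = (
    torch.float8_e4m3fnuz if torch.version.hip is not None else torch.float8_e4m3fn
)
MAX_FP8 = torch.finfo(PT_FP8_DTYPE).max


def quantize_fp8_row_reference(a: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Reference (non-Triton) rowwise FP8 quantization sketch."""
    row_max = a.abs().amax(dim=1).to(torch.float32)
    scale = MAX_FP8 / row_max            # per-row quantization scale
    scale[scale == float("inf")] = 1.0   # guard rows that are all zeros
    a_fp8 = (a * scale[:, None]).to(PT_FP8_DTYPE)
    # Return the reciprocal so the GEMM epilogue can rescale with
    # a_scale[:, None] * b_scale[None, :] (a multiply) instead of a divide.
    return a_fp8, 1.0 / scale
```

Returning the reciprocal scale is why `_kernel_matmul_fp8_row` below now multiplies `a_scale[:, None] * b_scale[None, :]` directly rather than inverting the scale vectors in the kernel.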

Before this change we have:
```
INFO:root:BF16 T: 5966.47us, FLOPS: 483.74TF/s
INFO:root:BF16 (G) T: 6017.79us, FLOPS: 479.61TF/s
INFO:root:FP8D AMD T: 32761.27us, FLOPS: 88.10TF/s
INFO:root:FP8D AMD (G) T: 33014.49us, FLOPS: 87.42TF/s
```

After this change we have:
```
INFO:root:BF16 T: 6006.43us, FLOPS: 480.52TF/s
INFO:root:BF16 (G) T: 6045.48us, FLOPS: 477.42TF/s
INFO:root:FP8D AMD CK T: 5894.74us, FLOPS: 489.63TF/s
INFO:root:FP8D AMD CK (G) T: 5870.20us, FLOPS: 491.67TF/s
INFO:root:FP8D rowwise AMD CK T: 3877.73us, FLOPS: 744.31TF/s
INFO:root:FP8D rowwise AMD CK (G) T: 3892.07us, FLOPS: 741.56TF/s
```

When using Llama3 shapes, the performance is:
```
INFO:root:BF16 T: 90416.92us, FLOPS: 474.26TF/s
INFO:root:BF16 (G) T: 91022.73us, FLOPS: 471.10TF/s
INFO:root:FP8D AMD CK T: 65453.99us, FLOPS: 655.13TF/s
INFO:root:FP8D AMD CK (G) T: 63632.00us, FLOPS: 673.89TF/s
INFO:root:FP8D rowwise AMD CK T: 60791.51us, FLOPS: 705.38TF/s
INFO:root:FP8D rowwise AMD CK (G) T: 60916.44us, FLOPS: 703.93TF/s
```

Reviewed By: jianyuh, jiawenliu64

Differential Revision: D57739339
jwfromm authored and facebook-github-bot committed May 24, 2024
1 parent ab05ca9 commit 262d4b9
Showing 2 changed files with 27 additions and 18 deletions.
2 changes: 1 addition & 1 deletion fbgemm_gpu/experimental/gemm/test/fp8_gemm_test.py
@@ -44,7 +44,7 @@ def _test_quantize_fp8_row(

# Undo scaling.
a_torch = a_fp8.base.to(torch.bfloat16)
a_torch /= a_scale[:, None]
a_torch *= a_scale[:, None]

self.assertTrue(
torch.allclose(
43 changes: 26 additions & 17 deletions fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py
@@ -21,7 +21,15 @@
)
from triton.runtime.jit import reinterpret as tl_reinterpret, TensorWrapper # @manual

MAX_FP8 = 448.0
# NVidia and AMD use different FP8 types, define them here for use throughout.
if torch.version.hip is not None:
PT_FP8_DTYPE = torch.float8_e4m3fnuz
TL_FP8_DTYPE = tl.float8e4b8
else:
PT_FP8_DTYPE = torch.float8_e4m3fn
TL_FP8_DTYPE = tl.float8e4nv

MAX_FP8 = torch.finfo(PT_FP8_DTYPE).max

logger: logging.Logger = logging.getLogger(__name__)

@@ -36,7 +44,7 @@ def convert_fp8_type(tensor) -> triton.TensorWrapper:
Returns:
triton.TensorWrapper: fp8 tensor.
"""
return tl_reinterpret(tensor, dtype=tl.float8e4nv)
return tl_reinterpret(tensor, dtype=TL_FP8_DTYPE)


def init_to_zero(name):
@@ -313,10 +321,8 @@ def _kernel_matmul_fp8_row(
a_scale = tl.load(A_scale + rm, mask=rm < M)
b_scale = tl.load(B_scale + rn, mask=rn < N)
# Invert vector, then multiply on matrix for speed.
inv_a_scale = 1.0 / a_scale
inv_b_scale = 1.0 / b_scale
# pyre-ignore[16]: Undefined attribute [16]: `float` has no attribute `__getitem__`.
scale = inv_a_scale[:, None] * inv_b_scale[None, :]
scale = a_scale[:, None] * b_scale[None, :]
acc *= scale

acc = acc.to(C.dtype.element_ty)
@@ -736,10 +742,16 @@ def prep_matmul(
m_key, n_key, k_key = get_matmul_tune(M, N, K)

# allocates output
assert a.dtype in [tl.float8e4nv, tl.float8e4b15, tl.float8e5] and b.dtype in [
assert a.dtype in [
tl.float8e4nv,
tl.float8e4b15,
tl.float8e5,
tl.float8e4b8,
] and b.dtype in [
tl.float8e4nv,
tl.float8e4b15,
tl.float8e5,
tl.float8e4b8,
]
c_dtype = torch.bfloat16

@@ -817,14 +829,14 @@ def _kernel_quantize_fp8_row(

# Scale and quantize.
a_scale = MAX_FP8 / cur_max
tl.store(A_scale + pid, a_scale)
tl.store(A_scale + pid, 1.0 / a_scale)
n_offset = tl.arange(0, BLOCK_SIZE)
for _k in range(0, tl.cdiv(N, BLOCK_SIZE)):
a = tl.load(
A + pid * stride_am + n_offset * stride_an, mask=n_offset < N, other=0.0
)
a_fp8 = a * a_scale
a_fp8.to(tl.float8e4nv)
a_fp8.to(TL_FP8_DTYPE)
tl.store(
A_fp8 + pid * stride_am + n_offset * stride_an, a_fp8, mask=n_offset < N
)
@@ -844,9 +856,7 @@ def triton_quantize_fp8_row(a: Tensor) -> Tuple[TensorWrapper, torch.Tensor]:
"""
num_rows = a.shape[0]
a_scale = torch.empty((num_rows), dtype=torch.float32, device=a.device)
a_fp8 = torch.empty(
(a.shape[0], a.shape[1]), device=a.device, dtype=torch.float8_e4m3fn
)
a_fp8 = torch.empty((a.shape[0], a.shape[1]), device=a.device, dtype=PT_FP8_DTYPE)

a_fp8 = convert_fp8_type(a_fp8)
grid = (num_rows,)
@@ -885,15 +895,14 @@ def quantize_fp8_row(

row_max: torch.Tensor = torch.max(torch.abs(a), dim=1)[0]
a_scale = torch.empty((a.shape[0]), dtype=torch.float32, device=output_device)
max_fp8 = torch.finfo(torch.float8_e4m3fn).max
a_scale = max_fp8 / row_max.to(torch.float32) # pyre-ignore
a_scale = MAX_FP8 / row_max.to(torch.float32) # pyre-ignore
a_scale[a_scale == float("inf")] = 1.0 # pyre-ignore
a_fp8 = a * a_scale[:, None] # pyre-ignore
# Cast and move data to output device (for cpu weight loading).
a_fp8 = convert_fp8_type(a_fp8.to(device=output_device, dtype=torch.float8_e4m3fn))
a_fp8 = convert_fp8_type(a_fp8.to(device=output_device, dtype=PT_FP8_DTYPE))
a_scale = a_scale.to(output_device) # pyre-ignore
del a
return a_fp8, a_scale
return a_fp8, 1 / a_scale # pyre-ignore


@triton.jit
@@ -946,7 +955,7 @@ def _kernel_quantize_fp8_block(

tl.store(A_scale + block_m * stride_a_scale_m + block_k * stride_a_scale_k, scale)
a_fp8 = a_block * scale
a_fp8.to(tl.float8e4nv)
a_fp8.to(TL_FP8_DTYPE)
tl.store(A_fp8 + a_offset, a_fp8, mask=a_mask)


@@ -974,7 +983,7 @@ def quantize_fp8_block(
grid_m = triton.cdiv(M, block_m)
grid_k = triton.cdiv(K, block_k)
x_scale = torch.ones((grid_m, grid_k), device=x.device, dtype=torch.float32)
x_fp8 = torch.empty((M, K), device=x.device, dtype=torch.float8_e4m3fn)
x_fp8 = torch.empty((M, K), device=x.device, dtype=PT_FP8_DTYPE)
x_fp8 = convert_fp8_type(x_fp8)

_kernel_quantize_fp8_block[(grid_m * grid_k,)](
