Add FP8 rowwise and tensorwise unittests with cudagraph (pytorch#2609)
Summary:
Pull Request resolved: pytorch#2609

Add CUDA graph coverage to the FP8 rowwise and tensorwise unit tests to guard against kernel-level issues when the FP8 GEMM kernels are captured and replayed under CUDA graphs.
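For reference, a minimal sketch of the capture-and-replay pattern these tests now exercise (assuming the fbgemm_gpu gen_ai extension is already loaded so the `torch.ops.fbgemm.*` operators below are registered; the tensor shapes are illustrative). Kernels that implicitly synchronize with the host or otherwise violate CUDA graph capture rules fail at capture time, which is the class of issue this coverage is meant to catch:

```python
import torch

# Assumes the fbgemm_gpu gen_ai operators are loaded (import/library-load step omitted here).
x = torch.randn(128, 256, dtype=torch.bfloat16, device="cuda") * 0.1
w = torch.randn(512, 256, dtype=torch.bfloat16, device="cuda") * 0.01

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    # Quantize per-tensor and run the FP8 GEMM inside the capture region.
    xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(x)
    wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(w)
    zq = torch.ops.fbgemm.f8f8bf16(xq, wq, x_scale * w_scale)

g.replay()                 # re-run the captured kernels
torch.cuda.synchronize()
print(zq.shape)            # expected: torch.Size([128, 512]), i.e. (B_T, HD_L)
```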

TODO: Merge f8f8bf16 and f8f8bf16_tensorwise into one kernel.
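A purely illustrative sketch of what call-site consolidation could look like in the meantime; `f8f8bf16_unified` is a hypothetical helper (not part of fbgemm) that simply routes between the two existing ops depending on how the combined scale is supplied, and is not the planned kernel-level merge:

```python
from typing import Union

import torch


def f8f8bf16_unified(
    xq: torch.Tensor,
    wq: torch.Tensor,
    scale: Union[torch.Tensor, float],
) -> torch.Tensor:
    """Hypothetical dispatcher over the two existing tensorwise FP8 GEMM ops:
    a device-tensor scale goes to f8f8bf16, a host float to f8f8bf16_tensorwise."""
    if isinstance(scale, torch.Tensor):
        return torch.ops.fbgemm.f8f8bf16(xq, wq, scale)
    return torch.ops.fbgemm.f8f8bf16_tensorwise(xq, wq, scale)
```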

Reviewed By: jasonjk-park

Differential Revision: D57578841

fbshipit-source-id: 780481848ff3443f5addc1afdc1f563fa4f1df82
jiawenliu64 authored and facebook-github-bot committed May 20, 2024
1 parent 7792908 commit 581fcec
1 changed file with 56 additions and 14 deletions: fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py
@@ -113,9 +113,17 @@ def test_f8f8bf16(self, kernel: str, use_fast_accum: bool) -> None:
         Mode=st.sampled_from(["tensorwise", "tensorwise_broadcast", "rowwise"]),
         QType=st.sampled_from([torch.float8_e4m3fn, torch.float8_e5m2]),
         Bias=st.sampled_from([True, False]),
+        CudaGraph=st.sampled_from([True, False]),
     )
     def test_quantize_fp8_matmul(
-        self, B_T: int, D: int, HD_L: int, Mode: str, QType: torch.dtype, Bias: bool
+        self,
+        B_T: int,
+        D: int,
+        HD_L: int,
+        Mode: str,
+        QType: torch.dtype,
+        Bias: bool,
+        CudaGraph: bool,
     ) -> None:
         x = torch.randn(size=(B_T, D), dtype=torch.bfloat16, device="cuda") * 0.1
         w = torch.randn(size=(HD_L, D), dtype=torch.bfloat16, device="cuda") * 0.01
@@ -126,23 +134,57 @@ def test_quantize_fp8_matmul(
         )
 
         if Mode == "tensorwise":
-            xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(x)
-            wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(w)
-            zq = torch.ops.fbgemm.f8f8bf16(xq, wq, x_scale * w_scale)
-            if bias is not None:
-                zq += bias
+            if CudaGraph:
+                g = torch.cuda.CUDAGraph()
+                with torch.cuda.graph(g):
+                    xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(x)
+                    wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(w)
+                    zq = torch.ops.fbgemm.f8f8bf16(xq, wq, x_scale * w_scale)
+                    if bias is not None:
+                        zq += bias
+                g.replay()
+            else:
+                xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(x)
+                wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(w)
+                zq = torch.ops.fbgemm.f8f8bf16(xq, wq, x_scale * w_scale)
+                if bias is not None:
+                    zq += bias
         elif Mode == "tensorwise_broadcast":
             xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(x)
             wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(w)
-            zq = torch.ops.fbgemm.f8f8bf16_tensorwise(
-                xq, wq, (x_scale * w_scale).item()
-            )
-            if bias is not None:
-                zq += bias
+            x_scale = x_scale.item()
+            w_scale = w_scale.item()
+            if CudaGraph:
+                g = torch.cuda.CUDAGraph()
+                with torch.cuda.graph(g):
+                    zq = torch.ops.fbgemm.f8f8bf16_tensorwise(xq, wq, x_scale * w_scale)
+                    if bias is not None:
+                        zq += bias
+                g.replay()
+            else:
+                zq = torch.ops.fbgemm.f8f8bf16_tensorwise(xq, wq, x_scale * w_scale)
+                if bias is not None:
+                    zq += bias
         elif Mode == "rowwise":
-            xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(x, output_dtype=QType)
-            wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
-            zq = torch.ops.fbgemm.f8f8bf16_rowwise(xq, wq, x_scale, w_scale, bias=bias)
+            if CudaGraph:
+                g = torch.cuda.CUDAGraph()
+                with torch.cuda.graph(g):
+                    xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(
+                        x, output_dtype=QType
+                    )
+                    wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
+                    zq = torch.ops.fbgemm.f8f8bf16_rowwise(
+                        xq, wq, x_scale, w_scale, bias=bias
+                    )
+                g.replay()
+            else:
+                xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(
+                    x, output_dtype=QType
+                )
+                wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
+                zq = torch.ops.fbgemm.f8f8bf16_rowwise(
+                    xq, wq, x_scale, w_scale, bias=bias
+                )
         else:
             raise ValueError(f"Invalid mode {Mode}")
 
