Add AMD Rowwise FP8 Matmul (#2611)
Summary:
Pull Request resolved: #2611

This diff extends the `f8f8bf16_rowwise` GEMM operation to AMD through a new CK kernel. The new kernel requires stride support that is only available in developer branches of CK, so we must rely on the `ai_codesign/gen_ai` CK repo. I also extend the fp8 benchmarking suite to include rowwise measurements. Detailed benchmarking results are linked below; the quick summary is that performance looks quite good, typically in line with tensorwise quantization and sometimes faster, presumably due to using the latest and greatest CK pipelines.
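
For context, here is a rough sketch of what the two scaling schemes mean for the output. This is illustrative only: the real kernels fuse the scaling into the GEMM epilogue, and these helper names are not part of the diff.

    import torch

    def dequant_tensorwise(acc: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
        # acc: [M, N] accumulator of the fp8 GEMM; scale: scalar a_scale * b_scale.
        return (acc * scale).to(torch.bfloat16)

    def dequant_rowwise(
        acc: torch.Tensor, a_scale: torch.Tensor, b_scale: torch.Tensor
    ) -> torch.Tensor:
        # a_scale: [M] per-row scales of A; b_scale: [N] per-row scales of B
        # (one per output column), applied as an outer product.
        return (acc * a_scale[:, None] * b_scale[None, :]).to(torch.bfloat16)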

Full performance results can be found in the rowwise sheet linked below. Overall, the rowwise kernel performs on par with tensorwise scaling and does quite a bit better in some cases: https://docs.google.com/spreadsheets/d/1hK62EJIt9mZQJdZBxf5rGZ3u0hjMng7AkpwKnIjyn3k/edit#gid=881561326
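
As a usage sketch, the new path boils down to row-quantizing both operands and passing the per-row inverse scales to the new op. This is a minimal sketch assuming the gen_ai extension is built with the CK kernel; the E4M3_MAX_POS and EPS values are assumptions, since their definitions sit outside this diff.

    import torch

    E4M3_MAX_POS = 240.0  # assumed max finite value of torch.float8_e4m3fnuz (AMD fnuz variant)
    EPS = 1e-12           # assumed epsilon guarding against an all-zero row

    def fp8_row_quantize(x: torch.Tensor):
        # One scale per row: quantize to fp8 and return the tensor plus its inverse scale.
        row_max = torch.clamp(x.abs().amax(dim=1), min=EPS)
        scale = E4M3_MAX_POS / row_max
        xq = torch.clamp(x * scale[:, None], -E4M3_MAX_POS, E4M3_MAX_POS).to(
            torch.float8_e4m3fnuz
        )
        return xq, scale.to(torch.float32).reciprocal()

    M, N, K = 2048, 4096, 2304
    A = torch.randn(M, K).to(dtype=torch.bfloat16, device="cuda")
    B = torch.randn(N, K).to(dtype=torch.bfloat16, device="cuda")  # B is [N, K], i.e. pre-transposed
    QA, a_scale = fp8_row_quantize(A)
    QB, b_scale = fp8_row_quantize(B)
    out = torch.ops.fbgemm.f8f8bf16_rowwise(QA, QB, a_scale, b_scale)  # bf16, shape [M, N]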

Reviewed By: jianyuh, jiawenliu64

Differential Revision: D57600068

fbshipit-source-id: c5c533f1d332de293f788792816e535ab1131cec
jwfromm authored and facebook-github-bot committed May 22, 2024
1 parent f739e0a commit 8e335d3
Showing 5 changed files with 334 additions and 49 deletions.
107 changes: 86 additions & 21 deletions fbgemm_gpu/experimental/gen_ai/bench/ck_fp8_bench.py
@@ -31,6 +31,20 @@ def set_amd_env_vars() -> None:
os.environ["PYTORCH_TUNABLEOP_MAX_WARMUP_DURATION_MS"] = "30"


@torch.no_grad()
def fp8_row_quantize(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
# Quantize an input tensor and return the fp8 tensor and its inverse scale.
x_row_max = torch.max(torch.abs(x), dim=1).values
# pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
scale = E4M3_MAX_POS / torch.clamp(x_row_max, EPS)
# pyre-fixme[16]: Item `float` of `typing.Union[float, torch._tensor.Tensor]` has no attribute `__getitem__`.
xq = torch.clamp(x * scale[:, None], min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS).to(
torch.float8_e4m3fnuz
)
# pyre-fixme[16]: Item `float` of `typing.Union[float, torch._tensor.Tensor]` has no attribute `__getitem__`.
return xq, scale.to(torch.float32).reciprocal()


def fp8_quantize(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
f8_max = torch.tensor(E4M3_MAX_POS, device=x.device)
x_amax = torch.max(torch.abs(x))
@@ -65,12 +79,22 @@ def forward(
return output


class CKMatmul(torch.nn.Module):
class CKTensorMatmul(torch.nn.Module):
def forward(
self, a: torch.Tensor, b: torch.Tensor, scale: torch.Tensor
) -> torch.Tensor:
out = torch.ops.fbgemm.f8f8bf16_tensorwise(a, b, scale)
return out
return torch.ops.fbgemm.f8f8bf16_tensorwise(a, b, scale)


class CKRowMatmul(torch.nn.Module):
def forward(
self,
a: torch.Tensor,
b: torch.Tensor,
a_scale: torch.Tensor,
b_scale: torch.Tensor,
) -> torch.Tensor:
return torch.ops.fbgemm.f8f8bf16_rowwise(a, b, a_scale, b_scale)


@torch.no_grad()
@@ -82,13 +106,18 @@ def evaluate_impl(
baseline_func: Callable[
[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor
],
ck_func: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
) -> Tuple[float, float, float, float, float]:
ck_tensor_func: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
ck_row_func: Callable[
[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor
],
) -> Tuple[float, float, float, float, float, float, float]:
print(f"Evaluating {M=}, {N=}, {K=}")
A = torch.randn(M, K).to(dtype=torch.bfloat16, device="cuda")
QA, a_scale = fp8_quantize(A)
B = torch.randn(N, K).to(dtype=torch.bfloat16, device="cuda")
QB, b_scale = fp8_quantize(B)
QA_row, a_scale_row = fp8_row_quantize(A)
QB_row, b_scale_row = fp8_row_quantize(B)

# Check accuracy.
out_ref = fp_func(A.to(torch.float32), B.t().to(torch.float32))
@@ -97,9 +126,13 @@ def evaluate_impl(
baseline_sim = torch.mean(torch.pow(torch.abs(baseline_out - out_ref), 2))
print(f"Baseline accuracy: {baseline_sim}")

ck_out = ck_func(QA, QB, a_scale * b_scale)
ck_sim = torch.mean(torch.pow(torch.abs(ck_out - out_ref), 2))
print(f"CK accuracy: {ck_sim}")
ck_tensor_out = ck_tensor_func(QA, QB, a_scale * b_scale)
ck_tensor_sim = torch.mean(torch.pow(torch.abs(ck_tensor_out - out_ref), 2))
print(f"CK tensorwise accuracy: {ck_tensor_sim}")

ck_row_out = ck_row_func(QA_row, QB_row, a_scale_row, b_scale_row)
ck_row_sim = torch.mean(torch.pow(torch.abs(ck_row_out - out_ref), 2))
print(f"CK rowwise accuracy: {ck_row_sim}")

# Benchmark runtimes.
ms_ref: float = triton.testing.do_bench(lambda: fp_func(A, B.t()))
@@ -110,23 +143,45 @@ def evaluate_impl(
)
print(f"Baseline runtime: {ms_baseline} ms")

ms_ck: float = triton.testing.do_bench(lambda: ck_func(QA, QB, a_scale * b_scale))
print(f"CK runtime: {ms_ck} ms")
ms_tensor_ck: float = triton.testing.do_bench(
lambda: ck_tensor_func(QA, QB, a_scale * b_scale)
)
print(f"CK tensorwise runtime: {ms_tensor_ck} ms")

return float(baseline_sim.item()), float(ck_sim.item()), ms_baseline, ms_ck, ms_ref
ms_row_ck: float = triton.testing.do_bench(
lambda: ck_row_func(QA_row, QB_row, a_scale_row, b_scale_row)
)
print(f"CK rowwise runtime: {ms_row_ck} ms")

return (
float(baseline_sim.item()),
float(ck_tensor_sim.item()),
float(ck_row_sim.item()),
ms_baseline,
ms_tensor_ck,
ms_row_ck,
ms_ref,
)


def main(args: Any) -> None:
if args.enable_amd_env_vars:
set_amd_env_vars()

with torch.no_grad():
ck_mod = CKMatmul()
ck_tensor_mod = CKTensorMatmul()
ck_row_mod = CKRowMatmul()
baseline_mod = BaselineMatmul()
bf16_mod = FPMatMul()
if args.torch_compile_mode:
ck_mod = torch.compile(
ck_mod,
ck_tensor_mod = torch.compile(
ck_tensor_mod,
dynamic=False,
backend="inductor",
mode=args.torch_compile_mode,
)
ck_row_mod = torch.compile(
ck_row_mod,
dynamic=False,
backend="inductor",
mode=args.torch_compile_mode,
@@ -147,25 +202,35 @@ def main(args: Any) -> None:
benchmark_results = []

# Test over a bunch of shapes.
M = [13312, 16384, 16032, 2304, 2048]
N = [4096, 2304, 13312, 8192]
K = [16384, 6656, 2304, 2048, 13312]
M = [128, 2048, 2304, 13312, 16032, 16384]
N = [128, 2304, 4096, 8192, 13312]
K = [128, 2048, 2304, 6656, 13312, 16384]

for m in M:
for n in N:
for k in K:
baseline_sim, ck_sim, ms_baseline, ms_ck, ms_bf16 = evaluate_impl(
m, n, k, bf16_mod, baseline_mod, ck_mod
(
baseline_sim,
ck_tensor_sim,
ck_row_sim,
ms_baseline,
ms_tensor_ck,
ms_row_ck,
ms_bf16,
) = evaluate_impl(
m, n, k, bf16_mod, baseline_mod, ck_tensor_mod, ck_row_mod
)
benchmark_results.append(
{
"M": m,
"N": n,
"K": k,
"baseline_sim": baseline_sim,
"ck_sim": ck_sim,
"ck_tensor_sim": ck_tensor_sim,
"ck_row_sim": ck_row_sim,
"ms_baseline": ms_baseline,
"ms_ck": ms_ck,
"ms_tensor_ck": ms_tensor_ck,
"ms_row_ck": ms_row_ck,
"ms_bf16": ms_bf16,
}
)