From 1e6003dd7605983fc6c838b24ea90aeb6ccda1ca Mon Sep 17 00:00:00 2001
From: David Berard
Date: Wed, 11 Sep 2024 10:18:30 -0700
Subject: [PATCH] Implement bf16xint16 kernel (#2349)

Summary:
Pull Request resolved: https://github.com/pytorch/benchmark/pull/2349

Modify the bf16xint16 kernel from the previous PR to actually do as
intended: load the int16 input, convert it to bf16 inside the kernel, and
then do the matmul.

On H100:
```
$ python run_benchmark.py triton -- --op bf16xint16_gemm
x_val  bf16xbf16-best_config  bf16xbf16-gbps  bf16xbf16-latency  bf16xbf16-tflops  bf16xint16-best_config  bf16xint16-gbps  bf16xint16-latency  bf16xint16-tflops  bf16xint16_casted-best_config  bf16xint16_casted-gbps  bf16xint16_casted-latency  bf16xint16_casted-tflops
-------------------  ----------------------------------------------------------------------------------------------------------------------------  --------------  -----------------  ----------------  ----------------------------------------------------------------------------------------------------------------------------  ---------------  ------------------  -----------------  -----------------------------  ----------------------  -------------------------  ------------------------
...
(16384, 1280, 8192)  BLOCK_SIZE_M: 128, BLOCK_SIZE_N: 256, BLOCK_SIZE_K: 64, GROUP_SIZE_M: 8, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None  557.04  0.561899  611.494  BLOCK_SIZE_M: 256, BLOCK_SIZE_N: 128, BLOCK_SIZE_K: 128, GROUP_SIZE_M: 8, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None  225.308  1.38921  247.333  524.85  0.596361  576.157
(16384, 8192, 1024)  BLOCK_SIZE_M: 128, BLOCK_SIZE_N: 256, BLOCK_SIZE_K: 64, GROUP_SIZE_M: 8, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None  527.772  0.576171  477.077  BLOCK_SIZE_M: 256, BLOCK_SIZE_N: 128, BLOCK_SIZE_K: 128, GROUP_SIZE_M: 8, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None  263.431  1.15433  238.127  498.965  0.609436  451.037
(16384, 7168, 8192)  BLOCK_SIZE_M: 128, BLOCK_SIZE_N: 256, BLOCK_SIZE_K: 64, GROUP_SIZE_M: 8, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None  165.303  3.13361  614.034  BLOCK_SIZE_M: 256, BLOCK_SIZE_N: 128, BLOCK_SIZE_K: 128, GROUP_SIZE_M: 8, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None  65.4821  7.91051  243.239  156.515  3.30957  581.389
(16384, 8192, 3584)  BLOCK_SIZE_M: 128, BLOCK_SIZE_N: 256, BLOCK_SIZE_K: 64, GROUP_SIZE_M: 8, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None  239.774  1.63995  586.649  BLOCK_SIZE_M: 256, BLOCK_SIZE_N: 128, BLOCK_SIZE_K: 128, GROUP_SIZE_M: 8, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None  101.813  3.86212  249.105  225.613  1.74288  552.003
(65536, 1280, 8192)  BLOCK_SIZE_M: 128, BLOCK_SIZE_N: 256, BLOCK_SIZE_K: 64, GROUP_SIZE_M: 8, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None  562.389  2.21223  621.268  BLOCK_SIZE_M: 256, BLOCK_SIZE_N: 128, BLOCK_SIZE_K: 128, GROUP_SIZE_M: 8, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None  222.597  5.58917  245.902  552.792  2.25064  610.666
(65536, 8192, 1024)  BLOCK_SIZE_M: 128, BLOCK_SIZE_N: 256, BLOCK_SIZE_K: 64, GROUP_SIZE_M: 8, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None  539.745  2.2419  490.437  BLOCK_SIZE_M: 256, BLOCK_SIZE_N: 128, BLOCK_SIZE_K: 128, GROUP_SIZE_M: 8, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None  270.808  4.46832  246.068  532.576  2.27208  483.922
(65536, 7168, 8192)  BLOCK_SIZE_M: 128, BLOCK_SIZE_N: 256, BLOCK_SIZE_K: 64, GROUP_SIZE_M: 8, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None  166.428  12.1851  631.639  BLOCK_SIZE_M: 256, BLOCK_SIZE_N: 128, BLOCK_SIZE_K: 128, GROUP_SIZE_M: 8, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None  66.2297  30.6199  251.359  163.514  12.4023  620.577
(65536, 8192, 3584)  BLOCK_SIZE_M: 128, BLOCK_SIZE_N: 256, BLOCK_SIZE_K: 64, GROUP_SIZE_M: 8, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None  243.299  6.37422  603.727  BLOCK_SIZE_M: 256, BLOCK_SIZE_N: 128, BLOCK_SIZE_K: 128, GROUP_SIZE_M: 8, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None  107.069  14.4846  265.682  239.25  6.48212  593.678
```

On A100, this crashes due to some bugs/unhandled cases in triton.

imported-using-ghimport

Test Plan: Imported from OSS

Reviewed By: xuzhao9

Differential Revision: D59234866

Pulled By: davidberard98

fbshipit-source-id: 46f0d671ce7bf9315d7ea7551663b03a36da3bc3
---
 .../operators/bf16xint16_gemm/bf16xint16_gemm.py   |  2 --
 torchbenchmark/operators/bf16xint16_gemm/kernel.py | 12 +++++++-----
 2 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/torchbenchmark/operators/bf16xint16_gemm/bf16xint16_gemm.py b/torchbenchmark/operators/bf16xint16_gemm/bf16xint16_gemm.py
index 5d7796b7c..51fea9519 100644
--- a/torchbenchmark/operators/bf16xint16_gemm/bf16xint16_gemm.py
+++ b/torchbenchmark/operators/bf16xint16_gemm/bf16xint16_gemm.py
@@ -83,8 +83,6 @@ def bf16xbf16(self, x, w):
     @register_benchmark()
     def bf16xint16(self, x, w):
         x = x.reshape(-1, x.size(-1))
-        # TODO(davidberard98) fix this to pass in an int16
-        w = w.to(torch.bfloat16)
         return lambda: bf16xint16_matmul(x, w)
 
     @register_benchmark()
diff --git a/torchbenchmark/operators/bf16xint16_gemm/kernel.py b/torchbenchmark/operators/bf16xint16_gemm/kernel.py
index 92dda9099..69cd0cdf7 100644
--- a/torchbenchmark/operators/bf16xint16_gemm/kernel.py
+++ b/torchbenchmark/operators/bf16xint16_gemm/kernel.py
@@ -347,9 +347,10 @@ def bf16xbf16_matmul_kernel(
     tl.store(c_ptrs, c, mask=c_mask)
 
 
-# TODO(davidberard98): right now this is just a copy of the triton tutorial.
-# TODO is to implement the int16 part.
-# https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
+# NOTE(TritonBench): this is a modified version of the triton tutorial matmul:
+# https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
+# It is modified to take a bf16 and an int16 input; then cast the int16 to bf16;
+# then perform the bf16xbf16 matmul.
 @triton.autotune(
     configs=get_autotune_config(),
     key=["M", "N", "K"],
@@ -419,9 +420,10 @@ def bf16xint16_matmul_kernel(
         # Load the next block of A and B, generate a mask by checking the K dimension.
         # If it is out of bounds, set it to 0.
         a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
-        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0)
+        b_bf16 = b.to(tl.bfloat16)
         # We accumulate along the K dimension.
-        accumulator = tl.dot(a, b, accumulator)
+        accumulator = tl.dot(a, b_bf16, accumulator)
         # Advance the ptrs to the next K block.
         a_ptrs += BLOCK_SIZE_K * stride_ak
         b_ptrs += BLOCK_SIZE_K * stride_bk
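
For reference, here is a minimal standalone sketch of the technique the diff implements: load the int16 operand, cast it to bf16 inside the kernel, then run a bf16xbf16 `tl.dot`. This is illustrative only and not code from the patch: it drops the autotuning and GROUP_SIZE_M program reordering of the real kernel, hard-codes block sizes, and the names `bf16xint16_sketch_kernel` and `bf16xint16_sketch` are invented for this example.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def bf16xint16_sketch_kernel(
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)
    # Wrap row/column indices so edge tiles stay in bounds (as in the Triton tutorial).
    offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M
    offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N
    offs_k = tl.arange(0, BLOCK_K)
    a_ptrs = a_ptr + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_K, other=0.0)  # bf16 tile
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0)    # int16 tile
        # The key step: cast int16 -> bf16 in-kernel, then do a bf16xbf16 dot.
        acc = tl.dot(a, b.to(tl.bfloat16), acc)
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk

    offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    c_ptrs = c_ptr + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, acc.to(tl.bfloat16), mask=c_mask)


def bf16xint16_sketch(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # a: (M, K) bf16, b: (K, N) int16 -> returns (M, N) bf16
    assert a.dtype == torch.bfloat16 and b.dtype == torch.int16
    M, K = a.shape
    _, N = b.shape
    c = torch.empty((M, N), device=a.device, dtype=torch.bfloat16)
    BLOCK_M, BLOCK_N, BLOCK_K = 64, 64, 64
    grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))
    bf16xint16_sketch_kernel[grid](
        a, b, c, M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
    )
    return c
```

Casting inside the kernel keeps the weight in int16 in global memory; pre-casting on the host (roughly what the removed `w = w.to(torch.bfloat16)` line did, and what the `bf16xint16_casted` variant appears to measure) materializes a full bf16 copy of the weight before the matmul.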