Implement bf16xint16 kernel (#2349)
Summary:
Pull Request resolved: #2349

Modify the bf16xint16 kernel from the previous PR so it actually does what was intended: load the int16 input, convert it to bf16 inside the kernel, and then do the matmul.
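
To make the change concrete, here is a hedged, minimal sketch of the pattern described above. It is simplified relative to the real kernel in kernel.py: no autotuning, no boundary masking, no grouped scheduling, M/N/K are assumed to be multiples of the block sizes, and the `bf16xint16_sketch*` names are illustrative. The key line is the in-kernel cast of the int16 tile to bf16 right before `tl.dot`.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def bf16xint16_sketch_kernel(
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for _ in range(0, tl.cdiv(K, BLOCK_K)):
        a = tl.load(a_ptrs)         # bf16 tile
        b = tl.load(b_ptrs)         # int16 tile, kept as int16 in memory
        b_bf16 = b.to(tl.bfloat16)  # cast happens inside the kernel
        acc = tl.dot(a, b_bf16, acc)
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk
    c = acc.to(tl.bfloat16)
    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    tl.store(c_ptrs, c)


def bf16xint16_sketch(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # x: (M, K) bf16, w: (K, N) int16 (layout assumed for this sketch)
    M, K = x.shape
    _, N = w.shape
    out = torch.empty((M, N), device=x.device, dtype=torch.bfloat16)
    BLOCK_M, BLOCK_N, BLOCK_K = 64, 64, 32
    grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))
    bf16xint16_sketch_kernel[grid](
        x, w, out, M, N, K,
        x.stride(0), x.stride(1), w.stride(0), w.stride(1),
        out.stride(0), out.stride(1),
        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
    )
    return out
```

The actual kernel changed in this PR keeps the tutorial's masking, autotuning, and GROUP_SIZE_M swizzling; see the diff below.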

On H100:

```
$ python run_benchmark.py triton -- --op bf16xint16_gemm
              x_val                                                                                                             bf16xbf16-best_config    bf16xbf16-gbps    bf16xbf16-latency    bf16xbf16-tflops                                                                                                             bf16xint16-best_config    bf16xint16-gbps    bf16xint16-latency    bf16xint16-tflops   bf16xint16_casted-best_config    bf16xint16_casted-gbps    bf16xint16_casted-latency    bf16xint16_casted-tflops
-------------------  --------------------------------------------------------------------------------------------------------------------------------  ----------------  -------------------  ------------------  ---------------------------------------------------------------------------------------------------------------------------------  -----------------  --------------------  ------------------- -------------------------------  ------------------------  ---------------------------  --------------------------
...
(16384, 1280, 8192)  BLOCK_SIZE_M: 128, BLOCK_SIZE_N: 256, BLOCK_SIZE_K: 64, GROUP_SIZE_M: 8, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None          557.04             0.561899            611.494     BLOCK_SIZE_M: 256, BLOCK_SIZE_N: 128, BLOCK_SIZE_K: 128, GROUP_SIZE_M: 8, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None           225.308              1.38921             247.333                                                      524.85                      0.596361                   576.157
(16384, 8192, 1024)  BLOCK_SIZE_M: 128, BLOCK_SIZE_N: 256, BLOCK_SIZE_K: 64, GROUP_SIZE_M: 8, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None          527.772            0.576171            477.077     BLOCK_SIZE_M: 256, BLOCK_SIZE_N: 128, BLOCK_SIZE_K: 128, GROUP_SIZE_M: 8, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None           263.431              1.15433             238.127                                                      498.965                     0.609436                   451.037
(16384, 7168, 8192)  BLOCK_SIZE_M: 128, BLOCK_SIZE_N: 256, BLOCK_SIZE_K: 64, GROUP_SIZE_M: 8, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None          165.303            3.13361             614.034     BLOCK_SIZE_M: 256, BLOCK_SIZE_N: 128, BLOCK_SIZE_K: 128, GROUP_SIZE_M: 8, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None            65.4821             7.91051             243.239                                                      156.515                     3.30957                    581.389
(16384, 8192, 3584)  BLOCK_SIZE_M: 128, BLOCK_SIZE_N: 256, BLOCK_SIZE_K: 64, GROUP_SIZE_M: 8, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None          239.774            1.63995             586.649     BLOCK_SIZE_M: 256, BLOCK_SIZE_N: 128, BLOCK_SIZE_K: 128, GROUP_SIZE_M: 8, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None           101.813              3.86212             249.105                                                      225.613                     1.74288                    552.003
(65536, 1280, 8192)  BLOCK_SIZE_M: 128, BLOCK_SIZE_N: 256, BLOCK_SIZE_K: 64, GROUP_SIZE_M: 8, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None          562.389            2.21223             621.268     BLOCK_SIZE_M: 256, BLOCK_SIZE_N: 128, BLOCK_SIZE_K: 128, GROUP_SIZE_M: 8, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None           222.597              5.58917             245.902                                                      552.792                     2.25064                    610.666
(65536, 8192, 1024)  BLOCK_SIZE_M: 128, BLOCK_SIZE_N: 256, BLOCK_SIZE_K: 64, GROUP_SIZE_M: 8, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None          539.745            2.2419              490.437     BLOCK_SIZE_M: 256, BLOCK_SIZE_N: 128, BLOCK_SIZE_K: 128, GROUP_SIZE_M: 8, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None           270.808              4.46832             246.068                                                      532.576                     2.27208                    483.922
(65536, 7168, 8192)  BLOCK_SIZE_M: 128, BLOCK_SIZE_N: 256, BLOCK_SIZE_K: 64, GROUP_SIZE_M: 8, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None          166.428           12.1851              631.639     BLOCK_SIZE_M: 256, BLOCK_SIZE_N: 128, BLOCK_SIZE_K: 128, GROUP_SIZE_M: 8, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None            66.2297            30.6199              251.359                                                      163.514                    12.4023                     620.577
(65536, 8192, 3584)  BLOCK_SIZE_M: 128, BLOCK_SIZE_N: 256, BLOCK_SIZE_K: 64, GROUP_SIZE_M: 8, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None          243.299            6.37422             603.727     BLOCK_SIZE_M: 256, BLOCK_SIZE_N: 128, BLOCK_SIZE_K: 128, GROUP_SIZE_M: 8, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None           107.069             14.4846              265.682                                                      239.25                      6.48212                    593.678
```

On A100, this currently crashes due to some bugs/unhandled cases in Triton.

imported-using-ghimport

Test Plan: Imported from OSS

Reviewed By: xuzhao9

Differential Revision: D59234866

Pulled By: davidberard98

fbshipit-source-id: 46f0d671ce7bf9315d7ea7551663b03a36da3bc3
davidberard98 authored and facebook-github-bot committed Sep 11, 2024
1 parent ebd00aa commit 1e6003d
Showing 2 changed files with 7 additions and 7 deletions.
2 changes: 0 additions & 2 deletions torchbenchmark/operators/bf16xint16_gemm/bf16xint16_gemm.py
@@ -83,8 +83,6 @@ def bf16xbf16(self, x, w):
     @register_benchmark()
     def bf16xint16(self, x, w):
         x = x.reshape(-1, x.size(-1))
-        # TODO(davidberard98) fix this to pass in an int16
-        w = w.to(torch.bfloat16)
         return lambda: bf16xint16_matmul(x, w)
 
     @register_benchmark()
12 changes: 7 additions & 5 deletions torchbenchmark/operators/bf16xint16_gemm/kernel.py
@@ -347,9 +347,10 @@ def bf16xbf16_matmul_kernel(
     tl.store(c_ptrs, c, mask=c_mask)
 
 
-# TODO(davidberard98): right now this is just a copy of the triton tutorial.
-# TODO is to implement the int16 part.
-# https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
+# NOTE(TritonBench): this is a modified version of the triton tutorial matmul:
+# https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
+# It is modified to take a bf16 and an int16 input; then cast the int16 to bf16;
+# then perform the bf16xbf16 matmul.
 @triton.autotune(
     configs=get_autotune_config(),
     key=["M", "N", "K"],
@@ -419,9 +420,10 @@ def bf16xint16_matmul_kernel(
         # Load the next block of A and B, generate a mask by checking the K dimension.
         # If it is out of bounds, set it to 0.
         a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
-        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0)
+        b_bf16 = b.to(tl.bfloat16)
         # We accumulate along the K dimension.
-        accumulator = tl.dot(a, b, accumulator)
+        accumulator = tl.dot(a, b_bf16, accumulator)
         # Advance the ptrs to the next K block.
         a_ptrs += BLOCK_SIZE_K * stride_ak
         b_ptrs += BLOCK_SIZE_K * stride_bk
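Not part of the commit, but a hedged sketch of how one might sanity-check the new path against a host-side cast. The import path, the (K, N) weight layout, and the bf16 output dtype of `bf16xint16_matmul` are assumptions; `torch.matmul` on host-cast weights serves as the reference.

```python
import torch

# Assumed import path for the wrapper exercised by the benchmark above.
from torchbenchmark.operators.bf16xint16_gemm.kernel import bf16xint16_matmul

# Small integer weights keep bf16 rounding error negligible for the comparison.
x = torch.randn(512, 256, device="cuda", dtype=torch.bfloat16)
w = torch.randint(-8, 8, (256, 128), device="cuda", dtype=torch.int16)

out_kernel = bf16xint16_matmul(x, w)             # int16 -> bf16 cast happens inside the kernel
out_ref = torch.matmul(x, w.to(torch.bfloat16))  # cast on the host, then a regular bf16 matmul

torch.testing.assert_close(out_kernel.float(), out_ref.float(), rtol=1e-2, atol=1e-2)
```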
