Summary:
Pull Request resolved: #2349

Modify the bf16xint16 kernel from the previous PR to actually do as intended: load the int16 input, convert it to bf16 inside the kernel, and then do the matmul.

On H100:

```
$ python run_benchmark.py triton -- --op bf16xint16_gemm
```

In every row shown below, the best bf16xbf16 config is BLOCK_SIZE_M: 128, BLOCK_SIZE_N: 256, BLOCK_SIZE_K: 64 and the best bf16xint16 config is BLOCK_SIZE_M: 256, BLOCK_SIZE_N: 128, BLOCK_SIZE_K: 128 (both with GROUP_SIZE_M: 8, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None), so the best_config columns are folded into this note. Earlier rows of the output are elided ("...") as in the original run.

```
x_val                bf16xbf16-gbps  bf16xbf16-latency  bf16xbf16-tflops  bf16xint16-gbps  bf16xint16-latency  bf16xint16-tflops  bf16xint16_casted-gbps  bf16xint16_casted-latency  bf16xint16_casted-tflops
-------------------  --------------  -----------------  ----------------  ---------------  ------------------  -----------------  ----------------------  -------------------------  ------------------------
...
(16384, 1280, 8192)          557.04           0.561899           611.494          225.308             1.38921            247.333                  524.85                   0.596361                   576.157
(16384, 8192, 1024)         527.772           0.576171           477.077          263.431             1.15433            238.127                 498.965                   0.609436                   451.037
(16384, 7168, 8192)         165.303           3.13361            614.034          65.4821             7.91051            243.239                 156.515                   3.30957                    581.389
(16384, 8192, 3584)         239.774           1.63995            586.649          101.813             3.86212            249.105                 225.613                   1.74288                    552.003
(65536, 1280, 8192)         562.389           2.21223            621.268          222.597             5.58917            245.902                 552.792                   2.25064                    610.666
(65536, 8192, 1024)         539.745           2.2419             490.437          270.808             4.46832            246.068                 532.576                   2.27208                    483.922
(65536, 7168, 8192)         166.428          12.1851             631.639          66.2297            30.6199             251.359                 163.514                  12.4023                     620.577
(65536, 8192, 3584)         243.299           6.37422            603.727          107.069            14.4846             265.682                  239.25                   6.48212                    593.678
```

On A100, this crashes due to some bugs/unhandled cases in Triton.

imported-using-ghimport

Test Plan: Imported from OSS

Reviewed By: xuzhao9

Differential Revision: D59234866

Pulled By: davidberard98

fbshipit-source-id: 46f0d671ce7bf9315d7ea7551663b03a36da3bc3
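The fixed behavior (load the int16 operand, convert it to bf16 inside the kernel, then matmul) can be sketched in NumPy terms. This is not the actual Triton kernel: the helper names `to_bf16` and `bf16xint16_matmul` are hypothetical, and bf16 is emulated here by zeroing the low 16 bits of float32 (round-toward-zero, whereas hardware conversion typically rounds to nearest-even).

```python
import numpy as np


def to_bf16(x):
    # Emulate bf16 by keeping only the top 16 bits of the float32
    # representation (sign + exponent + 7 mantissa bits). This truncates
    # toward zero; it is a sketch of the numerics, not hardware-exact.
    u = np.ascontiguousarray(x, dtype=np.float32).view(np.uint32)
    return (u & np.uint32(0xFFFF0000)).view(np.float32)


def bf16xint16_matmul(a, b_int16):
    # What the kernel now does per tile: the int16 operand is loaded as
    # int16 and converted to bf16 *inside* the kernel, then both operands
    # feed a matmul that accumulates in float32.
    a_bf16 = to_bf16(a)
    b_bf16 = to_bf16(b_int16.astype(np.float32))
    return a_bf16.astype(np.float32) @ b_bf16.astype(np.float32)


# Small integers are exactly representable in bf16, so this matches a
# plain float matmul on tiny inputs.
a = np.ones((2, 3), dtype=np.float32)
b = np.arange(6, dtype=np.int16).reshape(3, 2)
out = bf16xint16_matmul(a, b)
```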
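As a sanity check on the benchmark numbers, the tflops column appears consistent with counting 2*M*N*K floating-point operations over the measured latency (in ms). A quick check against the first H100 row, assuming that convention:

```python
# Recompute tflops for shape (M, N, K) = (16384, 1280, 8192) from the
# reported bf16xbf16 latency of 0.561899 ms.
M, N, K = 16384, 1280, 8192
latency_ms = 0.561899
flops = 2 * M * N * K                      # multiply-accumulate = 2 ops
tflops = flops / (latency_ms * 1e-3) / 1e12
# tflops comes out near the reported 611.494
```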