diff --git a/torchbenchmark/operators/bf16xint16_gemm/bf16xint16_gemm.py b/torchbenchmark/operators/bf16xint16_gemm/bf16xint16_gemm.py index 5d7796b7c..51fea9519 100644 --- a/torchbenchmark/operators/bf16xint16_gemm/bf16xint16_gemm.py +++ b/torchbenchmark/operators/bf16xint16_gemm/bf16xint16_gemm.py @@ -83,8 +83,6 @@ def bf16xbf16(self, x, w): @register_benchmark() def bf16xint16(self, x, w): x = x.reshape(-1, x.size(-1)) - # TODO(davidberard98) fix this to pass in an int16 - w = w.to(torch.bfloat16) return lambda: bf16xint16_matmul(x, w) @register_benchmark() diff --git a/torchbenchmark/operators/bf16xint16_gemm/kernel.py b/torchbenchmark/operators/bf16xint16_gemm/kernel.py index 92dda9099..69cd0cdf7 100644 --- a/torchbenchmark/operators/bf16xint16_gemm/kernel.py +++ b/torchbenchmark/operators/bf16xint16_gemm/kernel.py @@ -347,9 +347,10 @@ def bf16xbf16_matmul_kernel( tl.store(c_ptrs, c, mask=c_mask) -# TODO(davidberard98): right now this is just a copy of the triton tutorial. -# TODO is to implement the int16 part. -# https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html +# NOTE(TritonBench): this is a modified version of the triton tutorial matmul: +# https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html +# It is modified to take a bf16 and an int16 input; then cast the int16 to bf16; +# then perform the bf16xbf16 matmul. @triton.autotune( configs=get_autotune_config(), key=["M", "N", "K"], @@ -419,9 +420,10 @@ def bf16xint16_matmul_kernel( # Load the next block of A and B, generate a mask by checking the K dimension. # If it is out of bounds, set it to 0. a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) - b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0) + b_bf16 = b.to(tl.bfloat16) # We accumulate along the K dimension. - accumulator = tl.dot(a, b, accumulator) + accumulator = tl.dot(a, b_bf16, accumulator) # Advance the ptrs to the next K block. a_ptrs += BLOCK_SIZE_K * stride_ak b_ptrs += BLOCK_SIZE_K * stride_bk