diff --git a/torchbenchmark/operators/bf16xint16_gemm/bf16xint16_gemm.py b/torchbenchmark/operators/bf16xint16_gemm/bf16xint16_gemm.py
index 51fea9519..ed42ae0ad 100644
--- a/torchbenchmark/operators/bf16xint16_gemm/bf16xint16_gemm.py
+++ b/torchbenchmark/operators/bf16xint16_gemm/bf16xint16_gemm.py
@@ -41,6 +41,22 @@ def __init__(
         self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
     ):
         super().__init__(tb_args=tb_args, extra_args=extra_args)
+
+        parser = argparse.ArgumentParser()
+        parser.add_argument(
+            "--transpose",
+            action="store_true",
+            help="Instead of computing A @ B, compute B.T @ A.T.",
+        )
+
+        gemm_args = parser.parse_args(extra_args)
+
+        # Normally, we compute x @ w, where x is bf16 and w is int16.
+        # If transposed, we compute w.T @ x.T, where x.T/w.T are contiguous and w.T is int16.
+        # Motivation: on H100, only the lhs of the wgmma instruction can be read from registers,
+        # so the order (which input is int16) matters.
+        self.transpose = gemm_args.transpose
+
         # `Group size` and `inner K tiles` are defaults from gpt-fast.
         self.group_size = 32
         self.inner_k_tiles = 8
@@ -55,6 +71,12 @@ def args(B, Dout, Din):
                 device=self.device,
                 dtype=torch.int16,
             )
+
+            if self.transpose:
+                # transpose logic below is only valid for 2D tensors.
+                assert x.dim() == 2
+                assert w.dim() == 2
+                return (w.T.contiguous(), x.T.contiguous())
             return (x, w)
 
         # LLama-2 shapes w/ 8-way tensor parallelism.
@@ -64,6 +86,10 @@ def args(B, Dout, Din):
             "ffn.w13": (8192, 7168),
             "ffn.w2": (3584, 8192),
         }
+
+        yield args(2**16, 1280, 8192)
+        return
+
         for bsz in (1, 4, 16, 64, 256, 1024, 2**12, 2**14, 2**16):
             for name, (k, n) in name_to_shapes_70b.items():
                 yield args(bsz, n, k)
@@ -76,19 +102,25 @@ def get_x_val(self, example_inputs) -> float:
 
     @register_benchmark(baseline=True)
     def bf16xbf16(self, x, w):
-        x = x.reshape(-1, x.size(-1))
-        w_bf16 = w.to(torch.bfloat16)
-        return lambda: bf16xbf16_matmul(x, w_bf16)
+        x_bf16 = x.to(torch.bfloat16) if self.transpose else x
+        w_bf16 = w if self.transpose else w.to(torch.bfloat16)
+        return lambda: bf16xbf16_matmul(x_bf16, w_bf16)
 
     @register_benchmark()
     def bf16xint16(self, x, w):
         x = x.reshape(-1, x.size(-1))
-        return lambda: bf16xint16_matmul(x, w)
+        return lambda: bf16xint16_matmul(x, w, transpose=self.transpose)
 
     @register_benchmark()
     def bf16xint16_casted(self, x, w):
         x = x.reshape(-1, x.size(-1))
-        return lambda: bf16xbf16_matmul(x, w.to(torch.bfloat16))
+
+        def fn():
+            x_bf16 = x.to(torch.bfloat16) if self.transpose else x
+            w_bf16 = w if self.transpose else w.to(torch.bfloat16)
+            return bf16xbf16_matmul(x_bf16, w_bf16)
+
+        return fn
 
     @register_metric()
     def best_config(self, fn, inputs, metrics):
diff --git a/torchbenchmark/operators/bf16xint16_gemm/kernel.py b/torchbenchmark/operators/bf16xint16_gemm/kernel.py
index 69cd0cdf7..c3590ad2f 100644
--- a/torchbenchmark/operators/bf16xint16_gemm/kernel.py
+++ b/torchbenchmark/operators/bf16xint16_gemm/kernel.py
@@ -379,6 +379,7 @@ def bf16xint16_matmul_kernel(
     BLOCK_SIZE_N: tl.constexpr,
     BLOCK_SIZE_K: tl.constexpr,  #
     # GROUP_SIZE_M: tl.constexpr,  #
+    TRANSPOSE: tl.constexpr,  # if true, assume a_ptr is int16; otherwise assume b_ptr is int16
 ):
     """Kernel for computing the matmul C = A x B.
     A has shape (M, K), B has shape (K, N) and C has shape (M, N)
@@ -421,9 +422,19 @@ def bf16xint16_matmul_kernel(
         # If it is out of bounds, set it to 0.
         a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
         b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0)
-        b_bf16 = b.to(tl.bfloat16)
+        if TRANSPOSE:
+            tl.static_assert(a.dtype == tl.int16)
+            tl.static_assert(b.dtype == tl.bfloat16)
+            a_bf16 = a.to(tl.bfloat16)
+            b_bf16 = b
+        else:
+            tl.static_assert(a.dtype == tl.bfloat16)
+            tl.static_assert(b.dtype == tl.int16)
+            a_bf16 = a
+            b_bf16 = b.to(tl.bfloat16)
+
         # We accumulate along the K dimension.
-        accumulator = tl.dot(a, b_bf16, accumulator)
+        accumulator = tl.dot(a_bf16, b_bf16, accumulator)
         # Advance the ptrs to the next K block.
         a_ptrs += BLOCK_SIZE_K * stride_ak
         b_ptrs += BLOCK_SIZE_K * stride_bk
@@ -469,7 +480,7 @@ def bf16xbf16_matmul(a, b):
     return c
 
 
-def bf16xint16_matmul(a, b):
+def bf16xint16_matmul(a, b, transpose=False):
     # Check constraints.
     assert a.shape[1] == b.shape[0], "Incompatible dimensions"
     assert a.is_contiguous(), "Matrix A must be contiguous"
@@ -494,5 +505,6 @@ def bf16xint16_matmul(a, b):
         b.stride(1),  #
         c.stride(0),
         c.stride(1),  #
+        TRANSPOSE=transpose,
     )
     return c
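
For context, a minimal PyTorch sketch (not part of the patch; shapes are made up for illustration) of the identity the --transpose path relies on: with contiguous transposed inputs, w.T @ x.T equals (x @ w).T, so the int16 operand moves to the lhs of the matmul, which is what the TRANSPOSE=True kernel assumes about a_ptr.

    import torch

    # Hypothetical small shapes, for illustration only (the benchmark's real shapes
    # come from the Llama-2 table in get_input_iter).
    B, Din, Dout = 16, 64, 32
    x = torch.randn(B, Din, dtype=torch.bfloat16)             # bf16 activations
    w = torch.randint(-8, 8, (Din, Dout), dtype=torch.int16)  # int16 weights

    # Normal path: x @ w, upcasting the int16 rhs to bf16.
    ref = x @ w.to(torch.bfloat16)

    # Transposed path: build contiguous w.T / x.T and compute w.T @ x.T, which is
    # (x @ w).T. The int16 operand is now the lhs of the matmul.
    w_t, x_t = w.T.contiguous(), x.T.contiguous()
    out = w_t.to(torch.bfloat16) @ x_t

    torch.testing.assert_close(out, ref.T)

In the benchmark itself this behavior is toggled by passing --transpose through extra_args, which __init__ parses into self.transpose.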