From ebd00aac63d2c762eceb2aa148981277937c72bb Mon Sep 17 00:00:00 2001 From: David Berard Date: Wed, 11 Sep 2024 10:18:30 -0700 Subject: [PATCH] tritonbench bf16xint16 matmul template (#2348) Summary: Pull Request resolved: https://github.com/pytorch/benchmark/pull/2348 Overall context: Before looking further into the bf16xint4 matmul, I'm planning to look into a bf16xint16 matmul first. The idea of this matmul is that it will just be the same as a bf16xbf16 matmul, except the second operand needs to be casted from int16 to bf16 in the triton kernel before executing. This PR: is NOT fully functional yet. It's just implemented this way to make review easier. There's 3 kernels that will be benchmarked here: 1. bf16xbf16 triton kernel - I've selected this kernel as the "baseline" because, ideally, we'd like the bf16xint16 kernel to be as close as possible to this kernel. 2. bf16xint16 triton kernel - this is NOT implemented yet, will be implemented in the follow-up PR. 3. bf16x(convert(int16 -> bf16)) triton kernel - i.e. convert the int16->bf16, write to global memory, and then run the bf16xbf16 kernel. Differential Revision: D59234085 imported-using-ghimport D59234085 Test Plan: Imported from OSS Reviewed By: xuzhao9 Pulled By: davidberard98 fbshipit-source-id: 75a493dbd78ee1aa1f63926f6dd61a2e7388816c --- .../operators/bf16xint16_gemm/__init__.py | 1 + .../bf16xint16_gemm/bf16xint16_gemm.py | 158 ++++++ .../operators/bf16xint16_gemm/kernel.py | 496 ++++++++++++++++++ 3 files changed, 655 insertions(+) create mode 100644 torchbenchmark/operators/bf16xint16_gemm/__init__.py create mode 100644 torchbenchmark/operators/bf16xint16_gemm/bf16xint16_gemm.py create mode 100644 torchbenchmark/operators/bf16xint16_gemm/kernel.py diff --git a/torchbenchmark/operators/bf16xint16_gemm/__init__.py b/torchbenchmark/operators/bf16xint16_gemm/__init__.py new file mode 100644 index 000000000..3a66ad270 --- /dev/null +++ b/torchbenchmark/operators/bf16xint16_gemm/__init__.py @@ -0,0 +1 @@ +from .bf16xint16_gemm import Operator diff --git a/torchbenchmark/operators/bf16xint16_gemm/bf16xint16_gemm.py b/torchbenchmark/operators/bf16xint16_gemm/bf16xint16_gemm.py new file mode 100644 index 000000000..5d7796b7c --- /dev/null +++ b/torchbenchmark/operators/bf16xint16_gemm/bf16xint16_gemm.py @@ -0,0 +1,158 @@ +""" +Compute a bf16 (activation) x int16 (weight) gemm. +A stepping stone to a fast int4_gemm (another TritonBench kernel) +bf16xbf16 baseline implementation taken from the triton tutorial + https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html +and the bf16xint16 implementation is a modified version of the same + tutorial kernel. +The benchmarking file (i.e. this file) is mostly copied from the + int4_gemm benchmarking file. +""" + +import argparse +import os +import statistics + +from typing import Any, List, Optional + +import torch +import triton +import triton.language as tl + +from torchbenchmark.util.triton_op import ( + BenchmarkOperator, + BenchmarkOperatorMetrics, + register_benchmark, + register_metric, +) + +from .kernel import ( + bf16xbf16_matmul, + bf16xbf16_matmul_kernel, + bf16xint16_matmul, + bf16xint16_matmul_kernel, +) + + +class Operator(BenchmarkOperator): + DEFAULT_METRICS = ["tflops", "gbps", "latency"] + + def __init__( + self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None + ): + super().__init__(tb_args=tb_args, extra_args=extra_args) + # `Group size` and `inner K tiles` are defaults from gpt-fast. 
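+ # (This benchmark file is adapted from the int4_gemm benchmark; these two
+ # attributes are kept for parity with it and are not referenced by the
+ # bf16xbf16 / bf16xint16 kernels below.)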
+ self.group_size = 32 + self.inner_k_tiles = 8 + + def get_input_iter(self): + def args(B, Dout, Din): + x = torch.randn(B, Din, device=self.device, dtype=torch.bfloat16) + w = torch.randint( + -(2**15), + 2**15 - 1, + (Din, Dout), + device=self.device, + dtype=torch.int16, + ) + return (x, w) + + # LLama-2 shapes w/ 8-way tensor parallelism. + name_to_shapes_70b = { + "attn.wqkv": (8192, 1280), + "attn.w0": (1024, 8192), + "ffn.w13": (8192, 7168), + "ffn.w2": (3584, 8192), + } + for bsz in (1, 4, 16, 64, 256, 1024, 2**12, 2**14, 2**16): + for name, (k, n) in name_to_shapes_70b.items(): + yield args(bsz, n, k) + + def get_x_val(self, example_inputs) -> float: + x, w = example_inputs + m, k = x.size() + _, n = w.size() + return (m, n, k) + + @register_benchmark(baseline=True) + def bf16xbf16(self, x, w): + x = x.reshape(-1, x.size(-1)) + w_bf16 = w.to(torch.bfloat16) + return lambda: bf16xbf16_matmul(x, w_bf16) + + @register_benchmark() + def bf16xint16(self, x, w): + x = x.reshape(-1, x.size(-1)) + # TODO(davidberard98) fix this to pass in an int16 + w = w.to(torch.bfloat16) + return lambda: bf16xint16_matmul(x, w) + + @register_benchmark() + def bf16xint16_casted(self, x, w): + x = x.reshape(-1, x.size(-1)) + return lambda: bf16xbf16_matmul(x, w.to(torch.bfloat16)) + + @register_metric() + def best_config(self, fn, inputs, metrics): + if "bf16xbf16" in str(fn): + return str(bf16xbf16_matmul_kernel.best_config) + if "bf16xint16" in str(fn) and "casted" not in str(fn): + return str(bf16xint16_matmul_kernel.best_config) + return "" + + @register_metric() + def gbps(self, fn, example_inputs: Any, metrics: BenchmarkOperatorMetrics) -> float: + def nbytes(t): + return t.numel() * t.element_size() + + x, w = example_inputs + c = fn() + + gb = (sum(nbytes(t) for t in (x, c)) + nbytes(w) // 8) / 1e9 + return gb / metrics.latency * 1e3 + + @register_metric() + def tflops( + self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics + ) -> float: + a, b = example_inputs + m, k = a.size() + _, n = b.size() + flops = 2 * m * n * k + return flops / metrics.latency / 1e12 * 1e3 + + def plot(self): + @triton.testing.perf_report( + triton.testing.Benchmark( + x_names=[ + "B", + "m", + "n", + "k", + ], # argument names to use as an x-axis for the plot + x_vals=self.output.x_vals, # different possible values for `x_name` + line_arg="provider", # argument name whose value corresponds to a different line in the plot + line_vals=[ + "torch", + "triton", + ], # possible values for `line_arg`` + line_names=[ + "torch", + "triton", + ], # label name for the lines + styles=[("blue", "-"), ("green", "-")], + ylabel="tflops", # label name for the y-axis + plot_name="int4-gemm-performance", # name for the plot. Used also as a file name for saving the plot. 
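+ # NOTE: the plot configuration above (line labels and plot_name) is
+ # inherited from the int4_gemm benchmark that this file was copied from.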
+ args={}, # values for function arguments not in `x_names` and `y_name` + ) + ) + def _plot(B, m, n, k, provider): + tflops = self.output.get_y_vals((B, m, n, k), provider, "tflops") + return tflops + + save_path = "/tmp/bf16xint16_gemm" + + if not os.path.exists(save_path): + os.mkdir(save_path) + + _plot.run(show_plots=True, print_data=True, save_path=save_path) diff --git a/torchbenchmark/operators/bf16xint16_gemm/kernel.py b/torchbenchmark/operators/bf16xint16_gemm/kernel.py new file mode 100644 index 000000000..92dda9099 --- /dev/null +++ b/torchbenchmark/operators/bf16xint16_gemm/kernel.py @@ -0,0 +1,496 @@ +""" +Triton implementation by @jlebar: https://gist.github.com/jlebar/3435b2c00deea53258887ce37231e5e2 +""" + +import torch + +import triton +import triton.language as tl + + +def is_cuda(): + return triton.runtime.driver.active.get_current_target().backend == "cuda" + + +def is_hip_mi200(): + target = triton.runtime.driver.active.get_current_target() + return target.backend == "hip" and target.arch == "gfx90a" + + +def get_cuda_autotune_config(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + }, + num_stages=3, + num_warps=8, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=5, + num_warps=2, + ), + triton.Config( + { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=5, + num_warps=2, + ), + # Good config for fp8 inputs. 
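+ # (The configs below, with BLOCK_SIZE_K of 64/128, come from the fp8 portion
+ # of the tutorial's config list and are kept unchanged here.)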
+ triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + }, + num_stages=3, + num_warps=8, + ), + triton.Config( + { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + }, + num_stages=3, + num_warps=8, + ), + triton.Config( + { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + ] + + +def get_hip_autotune_config(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 16, + "GROUP_SIZE_M": 1, + "waves_per_eu": 2, + }, + num_warps=4, + num_stages=0, + ), + triton.Config( + { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 16, + "GROUP_SIZE_M": 4, + "waves_per_eu": 2, + }, + num_warps=8, + num_stages=0, + ), + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "waves_per_eu": 2, + }, + num_warps=8, + num_stages=0, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + "waves_per_eu": 3, + }, + num_warps=4, + num_stages=0, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "waves_per_eu": 8, + }, + num_warps=4, + num_stages=0, + ), + ] + + +def get_autotune_config(): + if is_cuda(): + return get_cuda_autotune_config() + else: + return get_hip_autotune_config() + + +# NOTE(TritonBench): this is copied from the triton tutorial as the baseline +# https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html + + +# `triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes: +# - A list of `triton.Config` objects that define different configurations of +# meta-parameters (e.g., `BLOCK_SIZE_M`) and compilation options (e.g., `num_warps`) to try +# - An auto-tuning *key* whose change in values will trigger evaluation of all the +# provided configs +@triton.autotune( + configs=get_autotune_config(), + key=["M", "N", "K"], +) +@triton.jit +def bf16xbf16_matmul_kernel( + # Pointers to matrices + a_ptr, + b_ptr, + c_ptr, + # Matrix dimensions + M, + N, + K, + # The stride variables represent how much to increase the ptr by when moving by 1 + # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr` + # by to get the element one row down (A has M rows). + stride_am, + stride_ak, # + stride_bk, + stride_bn, # + stride_cm, + stride_cn, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, # + GROUP_SIZE_M: tl.constexpr, # +): + """Kernel for computing the matmul C = A x B. 
+ A has shape (M, K), B has shape (K, N) and C has shape (M, N) + """ + # ----------------------------------------------------------- + # Map program ids `pid` to the block of C it should compute. + # This is done in a grouped ordering to promote L2 data reuse. + # See above `L2 Cache Optimizations` section for details. + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # ---------------------------------------------------------- + # Create pointers for the first blocks of A and B. + # We will advance this pointer as we move in the K direction + # and accumulate + # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers + # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers + # See above `Pointer Arithmetic` section for details + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + # ----------------------------------------------------------- + # Iterate to compute a block of the C matrix. + # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block + # of fp32 values for higher accuracy. + # `accumulator` will be converted back to fp16 after the loop. + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + # Load the next block of A and B, generate a mask by checking the K dimension. + # If it is out of bounds, set it to 0. + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + # We accumulate along the K dimension. + accumulator = tl.dot(a, b, accumulator) + # Advance the ptrs to the next K block. + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + # You can fuse arbitrary activation functions here + # while the accumulator is still in FP32! + c = accumulator.to(tl.bfloat16) + + # ----------------------------------------------------------- + # Write back the block of the output matrix C with masks. + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + +# TODO(davidberard98): right now this is just a copy of the triton tutorial. +# TODO is to implement the int16 part. +# https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html +@triton.autotune( + configs=get_autotune_config(), + key=["M", "N", "K"], +) +@triton.jit +def bf16xint16_matmul_kernel( + # Pointers to matrices + a_ptr, + b_ptr, + c_ptr, + # Matrix dimensions + M, + N, + K, + # The stride variables represent how much to increase the ptr by when moving by 1 + # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr` + # by to get the element one row down (A has M rows). 
+ stride_am, + stride_ak, # + stride_bk, + stride_bn, # + stride_cm, + stride_cn, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, # + GROUP_SIZE_M: tl.constexpr, # +): + """Kernel for computing the matmul C = A x B. + A has shape (M, K), B has shape (K, N) and C has shape (M, N) + """ + # ----------------------------------------------------------- + # Map program ids `pid` to the block of C it should compute. + # This is done in a grouped ordering to promote L2 data reuse. + # See above `L2 Cache Optimizations` section for details. + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # ---------------------------------------------------------- + # Create pointers for the first blocks of A and B. + # We will advance this pointer as we move in the K direction + # and accumulate + # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers + # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers + # See above `Pointer Arithmetic` section for details + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + # ----------------------------------------------------------- + # Iterate to compute a block of the C matrix. + # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block + # of fp32 values for higher accuracy. + # `accumulator` will be converted back to fp16 after the loop. + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + # Load the next block of A and B, generate a mask by checking the K dimension. + # If it is out of bounds, set it to 0. + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + # We accumulate along the K dimension. + accumulator = tl.dot(a, b, accumulator) + # Advance the ptrs to the next K block. + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + # You can fuse arbitrary activation functions here + # while the accumulator is still in FP32! + c = accumulator.to(tl.bfloat16) + + # ----------------------------------------------------------- + # Write back the block of the output matrix C with masks. + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + +def bf16xbf16_matmul(a, b): + # Check constraints. + assert a.shape[1] == b.shape[0], "Incompatible dimensions" + assert a.is_contiguous(), "Matrix A must be contiguous" + M, K = a.shape + K, N = b.shape + # Allocates output. + c = torch.empty((M, N), device=a.device, dtype=torch.bfloat16) + # 1D launch kernel where each block gets its own program. 
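+ # The grid has cdiv(M, BLOCK_SIZE_M) * cdiv(N, BLOCK_SIZE_N) programs in
+ # total; each program computes one BLOCK_SIZE_M x BLOCK_SIZE_N tile of C.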
+ grid = lambda META: ( + triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), + ) + bf16xbf16_matmul_kernel[grid]( + a, + b, + c, # + M, + N, + K, # + a.stride(0), + a.stride(1), # + b.stride(0), + b.stride(1), # + c.stride(0), + c.stride(1), # + ) + return c + + +def bf16xint16_matmul(a, b): + # Check constraints. + assert a.shape[1] == b.shape[0], "Incompatible dimensions" + assert a.is_contiguous(), "Matrix A must be contiguous" + M, K = a.shape + K, N = b.shape + # Allocates output. + c = torch.empty((M, N), device=a.device, dtype=torch.bfloat16) + # 1D launch kernel where each block gets its own program. + grid = lambda META: ( + triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), + ) + bf16xint16_matmul_kernel[grid]( + a, + b, + c, # + M, + N, + K, # + a.stride(0), + a.stride(1), # + b.stride(0), + b.stride(1), # + c.stride(0), + c.stride(1), # + ) + return c
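Note (illustrative sketch, not part of this diff): the summary above describes the follow-up
bf16xint16 kernel as identical to the bf16xbf16 kernel except that the second operand is
cast from int16 to bf16 inside the Triton kernel before the dot. Assuming the kernel keeps
the same structure as bf16xint16_matmul_kernel above, the inner K-loop would change roughly
as follows (the `b_int16` name is illustrative only):

    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        # Load the int16 tile of B and cast it to bf16 in registers before the dot.
        b_int16 = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0)
        b = b_int16.to(tl.bfloat16)
        accumulator = tl.dot(a, b, accumulator)
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

With that change, bf16xint16_matmul could accept an int16 `b` directly, and the temporary
`w = w.to(torch.bfloat16)` workaround in the bf16xint16 benchmark method (see its TODO)
could be dropped.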