diff --git a/torchbenchmark/operators/fused_ffn/__init__.py b/torchbenchmark/operators/fused_ffn/__init__.py
new file mode 100644
index 000000000..a77a295cc
--- /dev/null
+++ b/torchbenchmark/operators/fused_ffn/__init__.py
@@ -0,0 +1 @@
+from .operator import Operator
diff --git a/torchbenchmark/operators/fused_ffn/kernel.py b/torchbenchmark/operators/fused_ffn/kernel.py
new file mode 100644
index 000000000..3cc3ab8ac
--- /dev/null
+++ b/torchbenchmark/operators/fused_ffn/kernel.py
@@ -0,0 +1,328 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+import triton
+import triton.language as tl
+
+
+def get_cuda_autotune_config():
+    return [
+        triton.Config(
+            {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 64}, num_stages=2, num_warps=2
+        ),
+        # triton.Config(
+        #     {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 64}, num_stages=4, num_warps=4
+        # ),
+        # triton.Config(
+        #     {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64}, num_stages=4, num_warps=4
+        # ),
+        # triton.Config(
+        #     {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64}, num_stages=4, num_warps=4
+        # ),
+        # triton.Config(
+        #     {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64}, num_stages=4, num_warps=4
+        # ),
+        # triton.Config(
+        #     {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64}, num_stages=4, num_warps=4
+        # ),
+        # triton.Config(
+        #     {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 64}, num_stages=5, num_warps=2
+        # ),
+        # triton.Config(
+        #     {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 64}, num_stages=5, num_warps=2
+        # ),
+    ]
+
+
+def get_autotune_config():
+    return get_cuda_autotune_config()
+
+
+@triton.autotune(
+    configs=get_autotune_config(),
+    key=["M", "D", "H_D"],
+)
+@triton.jit
+def fused_ffn_kernel(
+    X_ptr,
+    W13_ptr,
+    W2_ptr,
+    Y_ptr,
+    P_out_ptr,  # output buffer for the intermediate SwiGLU activation
+    M,
+    D,
+    H_D,  # P is not passed separately because P == D
+    stride_xm,
+    stride_xd,
+    stride_w13a,
+    stride_w13b,
+    stride_w2n,
+    stride_w2d,
+    stride_ym,
+    stride_yd,
+    stride_poutm,
+    stride_poutn,
+    BLOCK_M: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    BLOCK_K: tl.constexpr,  # tile size along the reduction (D) dimension
+    BLOCK_K_D: tl.constexpr,  # tile size along the output (D) dimension; the host sets it to D
+):
+    # Program ID over the M dimension
+    pid_m = tl.program_id(0)
+
+    # Offsets for M
+    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+    mask_m = offs_m < M
+
+    # Initialize the accumulator in float32 precision
+    acc = tl.zeros((BLOCK_M, BLOCK_K_D), dtype=tl.float32)
+
+    # Loop over H_D in BLOCK_N chunks
+    for start_n in range(0, H_D, BLOCK_N):
+        offs_n = start_n + tl.arange(0, BLOCK_N)
+        mask_n = offs_n < H_D
+
+        # Initialize partial results
+        p1_block = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+        p2_block = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+        # Block pointers into the transposed view of W13, shape (D, 2 * H_D):
+        # columns [0, H_D) hold W1^T and columns [H_D, 2 * H_D) hold W3^T.
+        w1t_bptr = tl.make_block_ptr(
+            base=W13_ptr,
+            shape=(D, 2 * H_D),
+            strides=(stride_w13b, stride_w13a),
+            offsets=(0, start_n),
+            block_shape=(BLOCK_K, BLOCK_N),
+            order=(1, 0),
+        )
+        w3t_bptr = tl.make_block_ptr(
+            base=W13_ptr,
+            shape=(D, 2 * H_D),
+            strides=(stride_w13b, stride_w13a),
+            offsets=(0, H_D + start_n),
+            block_shape=(BLOCK_K, BLOCK_N),
+            order=(1, 0),
+        )
+
+        # Loop over the reduction dimension K (== D) in BLOCK_K chunks
+        for k in range(0, D, BLOCK_K):
+            offs_k = k + tl.arange(0, BLOCK_K)
+            mask_k = offs_k < D
+
+            # Load an X block
+            x_bptr = tl.make_block_ptr(
+                base=X_ptr,
+                shape=(M, D),
+                strides=(stride_xm, stride_xd),
+                offsets=(pid_m * BLOCK_M, k),
+                block_shape=(BLOCK_M, BLOCK_K),
+                order=(1, 0),
+            )
+            X_block = tl.load(x_bptr, boundary_check=(0, 1), padding_option="zero")
+            # X_block = tl.where(mask_m[:, None] & mask_k[None, :], X_block, 0.0).to(
+            #     tl.float16
+            # )
+
+            # Load W1 and W3 blocks
+            W1_block = tl.load(w1t_bptr)
+            W3_block = tl.load(w3t_bptr)
+
+            # First GEMM: accumulate X @ W1^T and X @ W3^T
+            p1_block += tl.dot(X_block, W1_block)
+            p2_block += tl.dot(X_block, W3_block)
+
+            # Advance the block pointers along the reduction dimension
+            w1t_bptr = tl.advance(w1t_bptr, (BLOCK_K, 0))
+            w3t_bptr = tl.advance(w3t_bptr, (BLOCK_K, 0))
+
+        # Apply SiLU to p1 and multiply with p2 (SwiGLU)
+        p_out_block = p1_block * tl.sigmoid(p1_block) * p2_block
+        # p_out_block = tl.where(mask_m[:, None] & mask_n[None, :], p_out_block, 0.0)
+
+        # Store the intermediate activation
+        P_out_offs = P_out_ptr + (
+            offs_m[:, None] * stride_poutm + offs_n[None, :] * stride_poutn
+        )
+        tl.store(
+            P_out_offs,
+            p_out_block.to(tl.float16),
+            mask=mask_m[:, None] & mask_n[None, :],
+        )
+
+        # Block pointer into W2^T (passed in transposed, shape (H_D, D))
+        w2_bptr = tl.make_block_ptr(
+            base=W2_ptr,
+            shape=(H_D, D),
+            strides=(stride_w2n, stride_w2d),
+            offsets=(start_n, 0),
+            block_shape=(BLOCK_N, BLOCK_K_D),
+            order=(0, 1),
+        )
+        W2_block = tl.load(w2_bptr, boundary_check=(0, 1), padding_option="zero")
+
+        # Second GEMM: accumulate the output tile
+        acc += tl.dot(p_out_block.to(tl.float16), W2_block)
+
+    offs_d = tl.arange(0, BLOCK_K_D)
+    mask_d = offs_d < D
+    y_offs = Y_ptr + offs_m[:, None] * stride_ym + offs_d[None, :] * stride_yd
+    tl.store(y_offs, acc.to(tl.float16), mask=mask_m[:, None] & mask_d[None, :])
+
+
+def fused_ffn(
+    x: torch.Tensor, w13: torch.Tensor, w2: torch.Tensor, has_p: bool = False
+):
+    # x:      [B_T, D]
+    # w13:    [2 * H_D, D]  (W1 and W3 stacked along the first dimension)
+    # w2:     [P, H_D] with P == D
+    # p_out:  [B_T, H_D]    (intermediate SwiGLU activation, returned if has_p)
+    # output: [B_T, P]
+    B_T, D = x.shape
+    H_D_2, _ = w13.shape
+    P, H_D = w2.shape
+    assert D == P, f"D and P must be equal but got {D=} and {P=}"
+    assert H_D_2 == 2 * H_D, f"H_D_2 must be twice H_D but got {H_D_2=} and {H_D=}"
+
+    def grid(META):
+        return (triton.cdiv(B_T, META["BLOCK_M"]),)
+
+    output = torch.empty((B_T, P), dtype=x.dtype, device=x.device)
+    if has_p:
+        p_out = torch.empty((B_T, H_D), dtype=x.dtype, device=x.device)
+    else:
+        p_out = torch.empty(1, dtype=x.dtype, device=x.device)  # dummy tensor
+
+    w2_t = w2.t().contiguous()
+
+    # Each program keeps the full output D dimension in registers, so BLOCK_K_D is D.
+    BLOCK_K_D = D
+
+    fused_ffn_kernel[grid](
+        x,
+        w13,
+        w2_t,
+        output,
+        p_out,
+        B_T,
+        D,
+        H_D,
+        x.stride(0),
+        x.stride(1),
+        w13.stride(0),
+        w13.stride(1),
+        w2_t.stride(0),
+        w2_t.stride(1),
+        output.stride(0),
+        output.stride(1),
+        p_out.stride(0) if has_p else 0,
+        p_out.stride(1) if has_p else 0,
+        BLOCK_K_D=BLOCK_K_D,
+    )
+
+    return output, p_out if has_p else None
+
+
+def eager_ffn(x, w13, w2):
+    p = torch.matmul(x, w13.t())  # [B_T, 2 * H_D]
+    H_D_2, D = w13.shape
+    H_D = H_D_2 // 2
+    p1 = p[:, :H_D]  # [B_T, H_D]
+    p2 = p[:, H_D:]  # [B_T, H_D]
+    p_out = p1 * torch.sigmoid(p1) * p2  # SwiGLU
+    out = torch.matmul(p_out, w2.t())  # [B_T, D]
+    return out, p_out
+
+
+def numerics_check(shape):
+    B_T, H_D, D = shape
+    print(f"Running numerics check for {shape}")
+    x = torch.randn((B_T, D), dtype=torch.float16, device="cuda")
+    w13 = torch.randn((H_D * 2, D), dtype=torch.float16, device="cuda") * 0.1
+    w2 = torch.randn((D, H_D), dtype=torch.float16, device="cuda") * 0.1
+    triton_out, triton_p = fused_ffn(x, w13, w2, has_p=True)
+    eager_out, eager_p = eager_ffn(x, w13, w2)
+
+    if not torch.allclose(triton_p, eager_p, atol=1e-2, rtol=1e-2):
+        print("P numerics check failed")
+        print(f"triton output: {triton_p.flatten()[0:10]}")
+        print(f"eager output: {eager_p.flatten()[0:10]}")
+    else:
+        print("P numerics check passed")
+    if not torch.allclose(triton_out, eager_out, atol=1e-2, rtol=1e-2):
+        print("Y numerics check failed")
+        print(f"triton output: {triton_out.flatten()[0:10]}")
+        print(f"eager output: {eager_out.flatten()[0:10]}")
+    else:
+        print("Y numerics check passed")
+
+    torch.testing.assert_close(triton_out, eager_out, atol=1e-2, rtol=1e-2)
+
+
+def do_benchmark():
+    configs = []
+    configs.append(
+        triton.testing.Benchmark(
+            # Argument names to use as an x-axis for the plot
+            x_names=["B_T", "H_D", "D"],
+            # Different possible values for `x_names`
+            x_vals=[
+                (i, H_D, D)
+                for H_D, D in [(5325, 4096)]
+                for i in [1024, 2048, 4096, 8192, 16384]
+            ],
+            # Argument name whose value corresponds to a different line in the plot
+            line_arg="provider",
+            line_vals=["eager", "fused"],
+            line_names=["Eager", "Fused"],
+            styles=[("green", "-"), ("blue", "-")],
+            ylabel="Latency (ms)",  # Label name for the y-axis
+            plot_name="fused_ffn-benchmark",
+            args={},
+        )
+    )
+
+    @triton.testing.perf_report(configs)
+    def benchmark(B_T, H_D, D, provider):
+        x = torch.randn((B_T, D), dtype=torch.float16, device="cuda")
+        w13 = torch.randn((H_D * 2, D), dtype=torch.float16, device="cuda")
+        w2 = torch.randn((D, H_D), dtype=torch.float16, device="cuda")
+        quantiles = [0.5, 0.2, 0.8]
+        if provider == "eager":
+            return triton.testing.do_bench(
+                lambda: eager_ffn(x, w13, w2), quantiles=quantiles
+            )
+        if provider == "fused":
+            return triton.testing.do_bench(
+                lambda: fused_ffn(x, w13, w2), quantiles=quantiles
+            )
+
+    benchmark.run(show_plots=True, print_data=True)
+
+
+if __name__ == "__main__":
+    # B_T, H_D, D
+    torch.manual_seed(0)
+    numerics_check((1024, 1024, 128))
+
+    # do_benchmark()
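Note (not part of the patch): a minimal standalone usage sketch of the new fused_ffn entry point, mirroring the shapes and tolerances of numerics_check above. It assumes a CUDA device with fp16 support and that torchbenchmark is importable from the current environment.

    import torch
    from torchbenchmark.operators.fused_ffn.kernel import eager_ffn, fused_ffn

    torch.manual_seed(0)
    B_T, H_D, D = 1024, 1024, 128  # w2 is laid out as [D, H_D]; the output width P equals D
    x = torch.randn((B_T, D), dtype=torch.float16, device="cuda")
    w13 = torch.randn((2 * H_D, D), dtype=torch.float16, device="cuda") * 0.1
    w2 = torch.randn((D, H_D), dtype=torch.float16, device="cuda") * 0.1

    y_fused, p_fused = fused_ffn(x, w13, w2, has_p=True)  # y_fused: [B_T, D], p_fused: [B_T, H_D]
    y_eager, p_eager = eager_ffn(x, w13, w2)
    torch.testing.assert_close(y_fused, y_eager, atol=1e-2, rtol=1e-2)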
diff --git a/torchbenchmark/operators/fused_ffn/operator.py b/torchbenchmark/operators/fused_ffn/operator.py
new file mode 100644
index 000000000..bf4fc538d
--- /dev/null
+++ b/torchbenchmark/operators/fused_ffn/operator.py
@@ -0,0 +1,120 @@
+import argparse
+import os
+from typing import Any, Callable, Generator, List, Optional, Tuple
+
+import torch
+import triton
+
+from torchbenchmark.util.triton_op import (
+    BenchmarkOperator,
+    BenchmarkOperatorMetrics,
+    register_benchmark,
+    register_metric,
+    register_x_val,
+)
+
+from .kernel import eager_ffn, fused_ffn
+
+
+def parse_args(args: List[str]) -> argparse.Namespace:
+    parser = argparse.ArgumentParser(
+        description="TorchBench fused FFN operator benchmark"
+    )
+    parser.add_argument("--b-t", type=int)
+    parser.add_argument("--h-d", type=int)
+    parser.add_argument("--d", type=int)
+    args = parser.parse_args(args)
+    return args
+
+
+BUILTIN_SHAPES = [
+    (b_t, h_d, d)
+    for h_d, d in [(128, 256), (1024, 512), (8192, 2048)]
+    for b_t in [1024, 2048, 4096, 8192, 16384]
+]
+
+
+class Operator(BenchmarkOperator):
+    DEFAULT_METRICS = ["latency"]
+    DEFAULT_PRECISION = "bf16"
+
+    def __init__(
+        self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
+    ):
+        super().__init__(tb_args, extra_args)
+        ffn_args = parse_args(self.extra_args)
+        if ffn_args.b_t and ffn_args.h_d and ffn_args.d:
+            self.shapes = [(ffn_args.b_t, ffn_args.h_d, ffn_args.d)]
+        else:
+            self.shapes = BUILTIN_SHAPES
+
+    @register_benchmark()
+    def fused_ffn_op(self, x, w13, w2) -> Callable:
+        return lambda: fused_ffn(x, w13, w2)
+
+    @register_benchmark()
+    def eager_ffn_op(self, x, w13, w2) -> Callable:
+        return lambda: eager_ffn(x, w13, w2)
+
+    @register_metric()
+    def tflops(
+        self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics
+    ) -> List[float]:
+        x, w13, w2 = example_inputs
+        B_T, _ = x.size()
+        H_D_2, _ = w13.size()
+        D, H_D = w2.size()
+        # gemm #1: x @ w13.t(), [B_T, D] x [D, 2 * H_D]
+        flops = 2 * B_T * H_D_2 * D
+        # gemm #2: p_out @ w2.t(), [B_T, H_D] x [H_D, D]
+        flops += 2 * B_T * H_D * D
+        return flops / metrics.latency / 1e12 * 1e3
+
+    @register_x_val(label="(B_T, Hidden_D, D)")
+    def get_x_val(self, example_inputs) -> Tuple[int, int, int]:
+        x, w13, w2 = example_inputs
+        B_T, D = x.size()
+        _, H_D = w2.size()
+        return (B_T, H_D, D)
+
+    def get_input_iter(self) -> Generator:
+        for shape in self.shapes:
+            b_t, h_d, d = shape
+            x = torch.randn((b_t, d), device=self.device, dtype=self.dtype)
+            w13 = torch.randn((2 * h_d, d), device=self.device, dtype=self.dtype)
+            # fused_ffn/eager_ffn expect w2 laid out as [D, H_D]
+            w2 = torch.randn((d, h_d), device=self.device, dtype=self.dtype)
+            yield x, w13, w2
+
+    def plot(self):
+        @triton.testing.perf_report(
+            triton.testing.Benchmark(
+                x_names=["shape"],  # argument names to use as an x-axis for the plot
+                x_vals=self.output.x_vals,  # different possible values for `x_name`
+                line_arg="provider",  # argument name whose value corresponds to a different line in the plot
+                line_vals=[
+                    "fused_ffn",
+                    "eager_ffn",
+                ],  # possible values for `line_arg`
+                line_names=[
+                    "Fused FFN",
+                    "Eager FFN",
+                ],  # label names for the lines
+                styles=[
+                    ("blue", "-"),
+                    ("green", "-"),
+                ],  # line styles
+                ylabel="tflops",  # label name for the y-axis
+                plot_name="fused-ffn-performance",  # name for the plot; also used as the file name when saving
+                args={},  # values for function arguments not in `x_names` and `y_name`
+            )
+        )
+        def _plot(density, provider):
+            tflops = self.output.get_y_vals(density, provider, "tflops")
+            return tflops
+
+        save_path = self.get_temp_path()
+        os.makedirs(save_path, exist_ok=True)
+        _plot.run(show_plots=True, print_data=True, save_path=save_path)
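Note (not part of the patch): a worked instance of the flop count used by Operator.tflops, for one of the built-in shapes. This is a sketch of the arithmetic only and assumes metrics.latency is reported in milliseconds, which is what the `* 1e3` factor in the metric implies.

    # (B_T, H_D, D) = (1024, 8192, 2048), taken from BUILTIN_SHAPES
    B_T, H_D, D = 1024, 8192, 2048
    gemm1 = 2 * B_T * (2 * H_D) * D  # x @ w13.t():   [B_T, D] x [D, 2 * H_D] = 68_719_476_736 FLOPs
    gemm2 = 2 * B_T * H_D * D        # p_out @ w2.t(): [B_T, H_D] x [H_D, D]  = 34_359_738_368 FLOPs
    total = gemm1 + gemm2            # ~1.03e11 FLOPs
    # With latency in milliseconds, TFLOPS = total / (latency * 1e-3) / 1e12,
    # which matches `flops / metrics.latency / 1e12 * 1e3`; e.g. 1 ms -> ~103 TFLOPS.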