From 4c7ec3ac4f84b19d3eb6accf4edcfcddb4a5e160 Mon Sep 17 00:00:00 2001 From: Manman Ren Date: Thu, 9 May 2024 09:05:57 -0700 Subject: [PATCH] add gather + gemv Summary: Based on PT2 test case: https://github.com/pytorch/pytorch/issues/121661 Reviewed By: bertmaher Differential Revision: D56437647 fbshipit-source-id: 0d735c443c9cea419dce971fd6ea796ca444c0ff --- .../operators/gather_gemv/__init__.py | 1 + .../operators/gather_gemv/operator.py | 73 +++++++++++ .../gather_gemv/triton_gather_gemv.py | 116 ++++++++++++++++++ 3 files changed, 190 insertions(+) create mode 100644 torchbenchmark/operators/gather_gemv/__init__.py create mode 100644 torchbenchmark/operators/gather_gemv/operator.py create mode 100644 torchbenchmark/operators/gather_gemv/triton_gather_gemv.py diff --git a/torchbenchmark/operators/gather_gemv/__init__.py b/torchbenchmark/operators/gather_gemv/__init__.py new file mode 100644 index 0000000000..a77a295cc4 --- /dev/null +++ b/torchbenchmark/operators/gather_gemv/__init__.py @@ -0,0 +1 @@ +from .operator import Operator diff --git a/torchbenchmark/operators/gather_gemv/operator.py b/torchbenchmark/operators/gather_gemv/operator.py new file mode 100644 index 0000000000..c7cb069a8f --- /dev/null +++ b/torchbenchmark/operators/gather_gemv/operator.py @@ -0,0 +1,73 @@ +""" +Based on PT2 test case: https://github.com/pytorch/pytorch/issues/121661 +Motivated by https://www.thonking.ai/p/short-supporting-mixtral-in-gpt-fast, +gather + gemv is the primary kernel driving mixtral perf. 
import csv
import os
import statistics
from typing import Any, Callable, Generator, List, Optional

import numpy
import torch
import triton

from torch._dynamo.testing import rand_strided

from torchbenchmark.util.triton_op import (
    BenchmarkOperator,
    BenchmarkOperatorMetrics,
    register_benchmark,
    register_metric,
)

from .triton_gather_gemv import triton_gemv_0 as triton_test_0


class Operator(BenchmarkOperator):
    """Benchmark for the fused gather + GEMV pattern ``w[idx].to(x.dtype) @ x``.

    Each example input is a ``(w, idx, x)`` triple produced by
    :meth:`get_input_iter`: an int8 weight of shape ``(8, S, S)``, an int64
    index vector of length 2 and a bfloat16 vector of length ``S``.
    """

    def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = None):
        # A ``None`` default avoids sharing one mutable list between instances;
        # the base class still always receives a list.
        super().__init__(
            mode=mode,
            device=device,
            extra_args=[] if extra_args is None else extra_args,
        )

    @register_metric()
    def gbps(self, fn_name, example_inputs, metrics: BenchmarkOperatorMetrics):
        """Return the weight-read bandwidth in GB/s for every recorded latency.

        Only the gathered weight matrices are counted: one ``(S, S)`` slice of
        ``w`` per entry of ``idx``. Latencies are treated as milliseconds, as in
        the original formula (bytes / ms * 1e-6 == GB/s).
        """
        w, idx, _x = example_inputs
        bytes_read = idx.numel() * w.size(1) * w.size(2) * w.element_size()
        return [bytes_read / ms * 1e-6 for ms in metrics.latency]

    # Not a baseline: the eager implementation below is the reference.
    # NOTE(review): triton_test_0 returns a 1-tuple while test_eager returns a
    # tensor -- confirm any accuracy comparison copes with that.
    @register_benchmark()
    def test_0(self, p1, p2, p3) -> Callable:
        """Hand-written Triton kernel; arguments are ``(w, idx, x)``."""
        return lambda: triton_test_0(p1, p2, p3)

    @register_benchmark(baseline=True)
    def test_eager(self, w, idx, x):
        """Eager PyTorch reference: gather, cast to ``x.dtype``, then matmul."""
        return lambda: w[idx].to(x.dtype) @ x

    @register_benchmark()
    def test_inductor(self, w, idx, x):
        """Same computation as :meth:`test_eager`, compiled with ``torch.compile``."""

        @torch.compile
        def gather_gemv(w, idx, x):
            return w[idx].to(x.dtype) @ x

        # Trigger compilation up front so it is not part of the measured call.
        gather_gemv(w, idx, x)
        return lambda: gather_gemv(w, idx, x)

    def get_x_val(self, example_inputs) -> float:
        """Return the x-axis value for an input: the length ``S`` of the vector."""
        _w, _idx, x = example_inputs
        return x.size(0)

    def get_input_iter(self) -> Generator:
        """Yield ``(w, idx, x)`` for ``S`` in 2**11 .. 2**14.

        Tensors are created on ``cuda:0`` because the Triton wrapper used by
        :meth:`test_0` runs under a device guard for device 0.
        """
        for exponent in range(11, 15):
            S = 2**exponent
            w = rand_strided((8, S, S), (S * S, S, 1), device='cuda:0', dtype=torch.int8)
            idx = rand_strided((2,), (1,), device='cuda:0', dtype=torch.int64)
            x = rand_strided((S,), (1,), device='cuda:0', dtype=torch.bfloat16)
            yield w, idx, x
"""
Based on https://github.com/pytorch/pytorch/issues/121661
"""

import torch

import triton
import triton.language as tl

empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
reinterpret_tensor = torch.ops.inductor._reinterpret_tensor
assert_size_stride = torch._C._dynamo.guards.assert_size_stride


def _mv_config(xblock, rblock):
    """Build one autotune candidate; all candidates share stages and warps."""
    return triton.Config(
        {"XBLOCK": xblock, "RBLOCK": rblock},
        num_stages=1,
        num_warps=8,
    )


@triton.autotune(
    configs=[
        _mv_config(1, 2048),
        _mv_config(64, 8),
        _mv_config(64, 4),
        _mv_config(8, 512),
        _mv_config(8, 256),
        _mv_config(64, 64),
    ],
    key=["xnumel", "rnumel"],
)
@triton.jit
def triton_red_fused_mv_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    # Fused gather + matrix-vector product.
    #   in_ptr0  : index vector (one entry per gathered matrix)
    #   in_ptr1  : weight of shape (8, rnumel, rnumel), contiguous
    #   in_ptr2  : input vector of length rnumel
    #   out_ptr1 : output of length xnumel (gathered matrices laid out back to back)
    # Each program reduces XBLOCK output rows over rnumel in RBLOCK-wide steps.
    xoffset = tl.program_id(0).to(tl.int64) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None].to(tl.int64)  # (XBLOCK, 1)
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :].to(tl.int64)  # (1, RBLOCK)
    x0 = xindex
    # x0 // rnumel is the position in the index vector for this output row.
    tmp0 = tl.load(in_ptr0 + (x0 // rnumel), mask=xmask, other=0, eviction_policy='evict_last')
    # Wrap negative indices by the leading dimension of the weight (8). This is
    # loop-invariant, so it is computed once before the reduction loop.
    tmp3 = tl.where(tmp0 < 0, tmp0 + 8, tmp0)  # (XBLOCK, 1)
    # tmp3 selects along the leading dimension (stride rnumel * rnumel);
    # x0 % rnumel is the row inside the selected matrix (stride rnumel).
    row_base = (rnumel * (x0 % rnumel)) + (rnumel * rnumel * tmp3)
    _tmp11 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        # Masked lanes load 0, so they add nothing to the accumulator.
        tmp7 = tl.load(in_ptr2 + rindex, mask=rmask, other=0.0, eviction_policy='evict_last').to(tl.float32)
        tmp4 = tl.load(in_ptr1 + (rindex + row_base), mask=rmask & xmask, other=0, eviction_policy='evict_first')
        tmp9 = tmp4.to(tl.float32) * tmp7  # (XBLOCK, RBLOCK) * (1, RBLOCK)
        _tmp11 = _tmp11 + tl.broadcast_to(tmp9, [XBLOCK, RBLOCK])
    tmp11 = tl.sum(_tmp11, 1)[:, None]
    # The store converts the float32 sum to the output buffer's element type.
    tl.store(out_ptr1 + x0, tmp11, mask=xmask)


def triton_gemv_0(arg0_1, arg1_1, arg2_1):
    """Compute ``arg0_1[arg1_1] @ arg2_1`` with the fused Triton kernel.

    Args:
        arg0_1: contiguous weight tensor of shape ``(8, S, S)``.
        arg1_1: contiguous index tensor of shape ``(2,)``.
        arg2_1: contiguous vector of shape ``(S,)``.

    Returns:
        A 1-tuple holding a bfloat16 tensor of shape ``(2, S)``.

    The shape and stride requirements are enforced by ``assert_size_stride``.
    The work is issued on CUDA device 0.
    """
    S, = arg2_1.shape
    assert_size_stride(arg0_1, (8, S, S), (S*S, S, 1))
    assert_size_stride(arg1_1, (2, ), (1, ))
    assert_size_stride(arg2_1, (S, ), (1, ))
    # One output row per (gathered matrix, matrix row) pair.
    xnumel = 2*S
    rnumel = S
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        # Flat output buffer; viewed as (2, S) on return.
        buf1 = empty_strided_cuda((xnumel, ), (1, ), torch.bfloat16)

        grid = lambda META: (
            triton.cdiv(xnumel, META["XBLOCK"]),
        )
        # Kernel argument order is (index, weight, vector, out).
        triton_red_fused_mv_0[grid](arg1_1, arg0_1, arg2_1, buf1, xnumel, rnumel)
    return (reinterpret_tensor(buf1, (2, S), (S, 1), 0), )