Skip to content

Commit

Permalink
add gather + gemv
Browse files Browse the repository at this point in the history
Summary: Based on PT2 test case: pytorch/pytorch#121661

Reviewed By: bertmaher

Differential Revision: D56437647

fbshipit-source-id: 0d735c443c9cea419dce971fd6ea796ca444c0ff
  • Loading branch information
manman-ren authored and facebook-github-bot committed May 9, 2024
1 parent 6f3faa0 commit 4c7ec3a
Show file tree
Hide file tree
Showing 3 changed files with 190 additions and 0 deletions.
1 change: 1 addition & 0 deletions torchbenchmark/operators/gather_gemv/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .operator import Operator
73 changes: 73 additions & 0 deletions torchbenchmark/operators/gather_gemv/operator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
"""
Based on PT2 test case: https://github.com/pytorch/pytorch/issues/121661
Motivated by https://www.thonking.ai/p/short-supporting-mixtral-in-gpt-fast,
gather + gemv is the primary kernel driving mixtral perf.
"""

import csv
import os
import statistics
from typing import Any, Callable, Generator, List, Optional

import numpy
import torch
import triton


from torchbenchmark.util.triton_op import (
BenchmarkOperator,
BenchmarkOperatorMetrics,
register_benchmark,
register_metric,
)

from .triton_gather_gemv import triton_gemv_0 as triton_test_0
from torch._dynamo.testing import rand_strided

class Operator(BenchmarkOperator):

@register_metric()
def gbps(self, fn_name, example_inputs, metrics: BenchmarkOperatorMetrics):
arg0_1, arg1_1, arg2_1 = example_inputs
gbps = (
lambda ms: 2
* arg2_1.size(0) * arg2_1.size(0)
* arg0_1.element_size()
/ ms
* 1e-6
)
return list(map(gbps, metrics.latency))

def __init__(self, mode: str, device: str, extra_args: List[str] = []):
super().__init__(mode=mode, device=device, extra_args=extra_args)

@register_benchmark(baseline=True)
def test_0(self, p1, p2, p3) -> Callable:
return lambda: triton_test_0(p1, p2, p3)

@register_benchmark(baseline=True)
def test_eager(self, w, idx, x):
return lambda: w[idx].to(x.dtype) @ x

@register_benchmark()
def test_inductor(self, w, idx, x):
@torch.compile
def gather_gemv(w, idx, x):
return w[idx].to(x.dtype) @ x

gather_gemv(w, idx, x) # warmup
return lambda: gather_gemv(w, idx, x)

def get_x_val(self, example_inputs) -> float:
arg0_1, arg1_1, arg2_1 = example_inputs
s = arg2_1.size(0)
return s

def get_input_iter(self) -> Generator:
for i in range(11, 15):
S = 2 ** i
arg0_1 = rand_strided((8, S, S), (S*S, S, 1), device='cuda:0', dtype=torch.int8)
arg1_1 = rand_strided((2, ), (1, ), device='cuda:0', dtype=torch.int64)
arg2_1 = rand_strided((S, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
yield arg0_1, arg1_1, arg2_1

116 changes: 116 additions & 0 deletions torchbenchmark/operators/gather_gemv/triton_gather_gemv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
"""
Based on https://github.com/pytorch/pytorch/issues/121661
"""

import torch

import triton
import triton.language as tl
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
reinterpret_tensor = torch.ops.inductor._reinterpret_tensor
assert_size_stride = torch._C._dynamo.guards.assert_size_stride


@triton.autotune(
configs=[
triton.Config(
{
"XBLOCK": 1,
"RBLOCK": 2048,
},
num_stages=1,
num_warps=8,
),
triton.Config(
{
"XBLOCK": 64,
"RBLOCK": 8,
},
num_stages=1,
num_warps=8,
),
triton.Config(
{
"XBLOCK": 64,
"RBLOCK": 4,
},
num_stages=1,
num_warps=8,
),
triton.Config(
{
"XBLOCK": 8,
"RBLOCK": 512,
},
num_stages=1,
num_warps=8,
),
triton.Config(
{
"XBLOCK": 8,
"RBLOCK": 256,
},
num_stages=1,
num_warps=8,
),
triton.Config(
{
"XBLOCK": 64,
"RBLOCK": 64,
},
num_stages=1,
num_warps=8,
),
],
key=["xnumel", "rnumel"],
)
@triton.jit
def triton_red_fused_mv_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
xoffset = tl.program_id(0).to(tl.int64) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None].to(tl.int64)
xmask = xindex < xnumel
rbase = tl.arange(0, RBLOCK)[None, :].to(tl.int64)
x0 = xindex
# x0 // rnumel should have the same value of either 0 or 1
tmp0 = tl.load(in_ptr0 + ((x0 // rnumel)), None, eviction_policy='evict_last')
_tmp11 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r1 = rindex # size (1, RBLOCK)
tmp7 = tl.load(in_ptr2 + (r1), None, eviction_policy='evict_last').to(tl.float32)
tmp1 = tmp0 + 8
tmp2 = tmp0 < 0
tmp3 = tl.where(tmp2, tmp1, tmp0) # size (XBLOCK, 1)
# in_ptr1 has (B, S, S) shape, tmp3 is the 2nd dimension with stride of S * S.
tmp4 = tl.load(in_ptr1 + (r1 + (rnumel*(x0 % rnumel)) + (rnumel*rnumel*tmp3)), None, eviction_policy='evict_first')
tmp5 = tmp4.to(tl.float32)
tmp6 = tmp5.to(tl.float32)
tmp8 = tmp7.to(tl.float32)
tmp9 = tmp6 * tmp8 # (XBLOCK, RBLOCK) * (1, RBLOCK)
tmp10 = tl.broadcast_to(tmp9, [XBLOCK, RBLOCK])
tmp12 = _tmp11 + tmp10
_tmp11 = tmp12
tmp11 = tl.sum(_tmp11, 1)[:, None]
tmp13 = tmp11.to(tl.float32)
tl.store(out_ptr1 + (x0), tmp13, None)


def triton_gemv_0(arg0_1, arg1_1, arg2_1):
S, = arg2_1.shape
assert_size_stride(arg0_1, (8, S, S), (S*S, S, 1))
assert_size_stride(arg1_1, (2, ), (1, ))
assert_size_stride(arg2_1, (S, ), (1, ))
xnumel = 2*S
rnumel = S
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
# size will be double
buf1 = empty_strided_cuda((2*S, ), (1, ), torch.bfloat16)

grid = lambda META: (
triton.cdiv(2*S, META["XBLOCK"]),
)
triton_red_fused_mv_0[grid](arg1_1, arg0_1, arg2_1, buf1, xnumel, rnumel)
return (reinterpret_tensor(buf1, (2, S), (S, 1), 0), )

0 comments on commit 4c7ec3a

Please sign in to comment.