Add HSTU kernel repo (#2452)
Summary:
Enable the HSTU kernels in the OSS build of TorchBench.

Pull Request resolved: #2452

Test Plan:
```
python run_benchmark.py triton --op addmm
```

Reviewed By: sijiac

Differential Revision: D62501334

Pulled By: xuzhao9

fbshipit-source-id: ce8258352f6fbbb9025c75942900144b2db581a9
xuzhao9 authored and facebook-github-bot committed Sep 12, 2024
1 parent 1e6003d commit 7acad50
Showing 4 changed files with 98 additions and 1 deletion.
3 changes: 3 additions & 0 deletions .gitmodules
@@ -19,3 +19,6 @@
[submodule "submodules/ThunderKittens"]
path = submodules/ThunderKittens
url = https://github.com/HazyResearch/ThunderKittens.git
[submodule "submodules/generative-recommenders"]
path = submodules/generative-recommenders
url = https://github.com/facebookresearch/generative-recommenders.git
1 change: 1 addition & 0 deletions submodules/generative-recommenders
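Note: in an existing checkout, the new submodule has to be fetched before the OSS fallback below can import it; `git submodule update --init --recursive` (or the repository's usual setup step) is the typical way to do so.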
88 changes: 88 additions & 0 deletions torchbenchmark/operators/addmm/hstu.py
@@ -0,0 +1,88 @@
import importlib

from typing import Tuple

import torch
import triton
from torchbenchmark import add_path, SUBMODULE_PATH

with add_path(str(SUBMODULE_PATH)):
    triton_addmm = importlib.import_module(
        "generative-recommenders.ops.triton.triton_addmm"
    )
    _addmm_fwd = triton_addmm._addmm_fwd


class _AddMmFunction(torch.autograd.Function):
    @staticmethod
    # pyre-ignore[14]
    def forward(
        ctx,
        x: torch.Tensor,
        w: torch.Tensor,
        y: torch.Tensor,
    ) -> torch.Tensor:
        M, K = x.shape
        KB, N = w.shape
        assert K == KB, f"incompatible dimensions {K}, {KB}"

        is_y_1d = y.dim() == 1
        NY = y.shape[0] if is_y_1d else y.shape[1]
        assert N == NY, f"incompatible dimensions {N}, {NY}"

        # Allocate output
        z = torch.empty((M, N), device=x.device, dtype=x.dtype)
        if M == 0 or N == 0:
            ctx.save_for_backward(x, w)
            ctx.is_y_1d = False
            return z

        def grid(META):
            return (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)

        _addmm_fwd[grid](
            x,
            w,
            y,
            z,
            M,
            N,
            K,
            x.stride(0),
            x.stride(1),
            w.stride(0),
            w.stride(1),
            y.stride(0) if not is_y_1d else 0,
            y.stride(1) if not is_y_1d else y.stride(0),
            z.stride(0),
            z.stride(1),
            ALLOW_TF32=torch.backends.cuda.matmul.allow_tf32,
            BROADCAST_Y=is_y_1d,
        )
        ctx.save_for_backward(x, w)
        ctx.is_y_1d = is_y_1d
        return z

    @staticmethod
    # pyre-ignore[14]
    def backward(
        ctx, dz: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        (x, w) = ctx.saved_tensors
        if ctx.is_y_1d:
            dy = torch.sum(dz, dim=0)
        else:
            dy = dz
        dw = torch.mm(x.t(), dz)
        dx = torch.mm(dz, w.t())

        return dx, dw, dy


@torch.fx.wrap
def triton_addmm(
    input: torch.Tensor,
    mat1: torch.Tensor,
    mat2: torch.Tensor,
) -> torch.Tensor:
    return _AddMmFunction.apply(mat1, mat2, input)
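
For reference, a minimal smoke test of the wrapper above could look like the sketch below. It is not part of this commit and assumes a CUDA device, the `generative-recommenders` submodule checked out, and illustrative names/shapes (`x`, `w`, `bias`); `triton_addmm` takes the same `(input, mat1, mat2)` argument order as `torch.addmm`.

```
# Illustrative only -- not part of this commit.
import torch
from torchbenchmark.operators.addmm.hstu import triton_addmm

x = torch.randn(128, 64, device="cuda", dtype=torch.float16)   # mat1, shape (M, K)
w = torch.randn(64, 32, device="cuda", dtype=torch.float16)    # mat2, shape (K, N)
bias = torch.randn(32, device="cuda", dtype=torch.float16)     # 1-D input, broadcast over rows

out = triton_addmm(bias, x, w)   # same argument order as torch.addmm(input, mat1, mat2)
ref = torch.addmm(bias, x, w)
torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2)
```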
7 changes: 6 additions & 1 deletion torchbenchmark/operators/addmm/operator.py
@@ -7,7 +7,12 @@
 import torch
 import torch._inductor.config as inductor_config
 import triton
-from hammer.ops.triton.triton_hstu_linear import _addmm_fwd, triton_addmm
+from torchbenchmark import add_path, SUBMODULE_PATH
+
+try:
+    from hammer.ops.triton.triton_hstu_linear import triton_addmm
+except ModuleNotFoundError:
+    from .hstu import triton_addmm
 
 from torchbenchmark.util.triton_op import (
     BenchmarkOperator,
