From 7acad50f066a13490d3f92533c287213152502c3 Mon Sep 17 00:00:00 2001
From: Xu Zhao
Date: Thu, 12 Sep 2024 09:26:59 -0700
Subject: [PATCH] Add HSTU kernel repo (#2452)

Summary:
Enable HSTU kernels in the OSS.

Pull Request resolved: https://github.com/pytorch/benchmark/pull/2452

Test Plan:
```
python run_benchmark.py triton --op addmm
```

Reviewed By: sijiac

Differential Revision: D62501334

Pulled By: xuzhao9

fbshipit-source-id: ce8258352f6fbbb9025c75942900144b2db581a9
---
 .gitmodules                                |  3 +
 submodules/generative-recommenders         |  1 +
 torchbenchmark/operators/addmm/hstu.py     | 88 ++++++++++++++++++++++
 torchbenchmark/operators/addmm/operator.py |  7 +-
 4 files changed, 98 insertions(+), 1 deletion(-)
 create mode 160000 submodules/generative-recommenders
 create mode 100644 torchbenchmark/operators/addmm/hstu.py

diff --git a/.gitmodules b/.gitmodules
index 86afb8835..2087b535a 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -19,3 +19,6 @@
 [submodule "submodules/ThunderKittens"]
 	path = submodules/ThunderKittens
 	url = https://github.com/HazyResearch/ThunderKittens.git
+[submodule "submodules/generative-recommenders"]
+	path = submodules/generative-recommenders
+	url = https://github.com/facebookresearch/generative-recommenders.git
diff --git a/submodules/generative-recommenders b/submodules/generative-recommenders
new file mode 160000
index 000000000..75864114d
--- /dev/null
+++ b/submodules/generative-recommenders
@@ -0,0 +1 @@
+Subproject commit 75864114d4bb9bc5c5dc3488b29376b8d0c77b59
diff --git a/torchbenchmark/operators/addmm/hstu.py b/torchbenchmark/operators/addmm/hstu.py
new file mode 100644
index 000000000..55fd0ba15
--- /dev/null
+++ b/torchbenchmark/operators/addmm/hstu.py
@@ -0,0 +1,88 @@
+import importlib
+
+from typing import Tuple
+
+import torch
+import triton
+from torchbenchmark import add_path, SUBMODULE_PATH
+
+with add_path(str(SUBMODULE_PATH)):
+    triton_addmm = importlib.import_module(
+        "generative-recommenders.ops.triton.triton_addmm"
+    )
+    _addmm_fwd = triton_addmm._addmm_fwd
+
+
+class _AddMmFunction(torch.autograd.Function):
+    @staticmethod
+    # pyre-ignore[14]
+    def forward(
+        ctx,
+        x: torch.Tensor,
+        w: torch.Tensor,
+        y: torch.Tensor,
+    ) -> torch.Tensor:
+        M, K = x.shape
+        KB, N = w.shape
+        assert K == KB, f"incompatible dimensions {K}, {KB}"
+
+        is_y_1d = y.dim() == 1
+        NY = y.shape[0] if is_y_1d else y.shape[1]
+        assert N == NY, f"incompatible dimensions {N}, {NY}"
+
+        # Allocate output
+        z = torch.empty((M, N), device=x.device, dtype=x.dtype)
+        if M == 0 or N == 0:
+            ctx.save_for_backward(x, w)
+            ctx.is_y_1d = False
+            return z
+
+        def grid(META):
+            return (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)
+
+        _addmm_fwd[grid](
+            x,
+            w,
+            y,
+            z,
+            M,
+            N,
+            K,
+            x.stride(0),
+            x.stride(1),
+            w.stride(0),
+            w.stride(1),
+            y.stride(0) if not is_y_1d else 0,
+            y.stride(1) if not is_y_1d else y.stride(0),
+            z.stride(0),
+            z.stride(1),
+            ALLOW_TF32=torch.backends.cuda.matmul.allow_tf32,
+            BROADCAST_Y=is_y_1d,
+        )
+        ctx.save_for_backward(x, w)
+        ctx.is_y_1d = is_y_1d
+        return z
+
+    @staticmethod
+    # pyre-ignore[14]
+    def backward(
+        ctx, dz: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        (x, w) = ctx.saved_tensors
+        if ctx.is_y_1d:
+            dy = torch.sum(dz, dim=0)
+        else:
+            dy = dz
+        dw = torch.mm(x.t(), dz)
+        dx = torch.mm(dz, w.t())
+
+        return dx, dw, dy
+
+
+@torch.fx.wrap
+def triton_addmm(
+    input: torch.Tensor,
+    mat1: torch.Tensor,
+    mat2: torch.Tensor,
+) -> torch.Tensor:
+    return _AddMmFunction.apply(mat1, mat2, input)
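For context, the hand-written backward in hstu.py above is the standard addmm gradient: for z = x @ w + y, dx = dz @ wᵀ, dw = xᵀ @ dz, and dy = dz summed over rows when y is a 1-D bias. A minimal pure-PyTorch sketch of the same contract (no Triton or CUDA needed; the helper name `addmm_reference` is ours for illustration, not part of the patch):

```
import torch

def addmm_reference(y: torch.Tensor, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # Same semantics as the wrapped kernel: z = x @ w + y, with a 1-D y broadcast across rows.
    return torch.addmm(y, x, w)

x = torch.randn(4, 3, requires_grad=True)
w = torch.randn(3, 5, requires_grad=True)
y = torch.randn(5, requires_grad=True)

z = addmm_reference(y, x, w)
z.sum().backward()
# Autograd reproduces the gradients written out in _AddMmFunction.backward:
#   x.grad == dz @ w.T, w.grad == x.T @ dz, y.grad == dz.sum(dim=0), with dz = ones_like(z) here.
```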
diff --git a/torchbenchmark/operators/addmm/operator.py b/torchbenchmark/operators/addmm/operator.py
index 654d4795f..d4b6d9435 100644
--- a/torchbenchmark/operators/addmm/operator.py
+++ b/torchbenchmark/operators/addmm/operator.py
@@ -7,7 +7,12 @@
 import torch
 import torch._inductor.config as inductor_config
 import triton
-from hammer.ops.triton.triton_hstu_linear import _addmm_fwd, triton_addmm
+from torchbenchmark import add_path, SUBMODULE_PATH
+
+try:
+    from hammer.ops.triton.triton_hstu_linear import triton_addmm
+except ModuleNotFoundError:
+    from .hstu import triton_addmm
 
 from torchbenchmark.util.triton_op import (
     BenchmarkOperator,
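With the change to operator.py, the benchmark prefers the internal hammer implementation and falls back to the OSS kernel from the new submodule when hammer is not installed. For a quick sanity check outside the benchmark harness, a minimal usage sketch of the fallback `triton_addmm`; it assumes a CUDA device and an initialized submodule (`git submodule update --init --recursive`), and the shapes and tolerances below are illustrative:

```
import torch

from torchbenchmark.operators.addmm.hstu import triton_addmm  # fallback path added in this patch

M, K, N = 512, 256, 128  # illustrative sizes
x = torch.randn(M, K, device="cuda", dtype=torch.float16, requires_grad=True)
w = torch.randn(K, N, device="cuda", dtype=torch.float16, requires_grad=True)
y = torch.randn(N, device="cuda", dtype=torch.float16, requires_grad=True)  # 1-D bias, broadcast over rows

out = triton_addmm(y, x, w)  # same argument order as torch.addmm(input, mat1, mat2)
ref = torch.addmm(y, x, w)
torch.testing.assert_close(out, ref, rtol=2e-2, atol=2e-2)  # fp16 tolerances are illustrative

out.sum().backward()  # exercises the custom backward (dx, dw, dy)
```

The end-to-end path remains the Test Plan command above, `python run_benchmark.py triton --op addmm`, which exercises whichever import the try/except resolves.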