Add persistent+TMA matmul to fp8 gemm benchmark (#2377)
Summary: Pull Request resolved: #2377

Reviewed By: xuzhao9, sijiac

Differential Revision: D59812172

Pulled By: bertmaher

fbshipit-source-id: 450229888e09c22b9cd11a37015e6d601ec919ce
bertmaher authored and facebook-github-bot committed Jul 16, 2024
1 parent 8a3bf01 commit 4f245a8
Showing 3 changed files with 734 additions and 5 deletions.
23 changes: 18 additions & 5 deletions torchbenchmark/operators/fp8_gemm/fp8_gemm.py
@@ -2,7 +2,6 @@
 import os
 import statistics
 import torch
-import triton.ops
 import triton.language as tl
 
 from triton.runtime.jit import reinterpret
@@ -17,10 +16,15 @@
     register_metric,
 )
 
+from .tutorial import matmul as tutorial_matmul
+from .persistent import matmul_persistent, matmul_tma_persistent, allocate_matmul_tma
 
 def parse_args(args):
     parser = argparse.ArgumentParser(description="TritonBench fp8_gemm")
     parser.add_argument("--llama", action="store_true")
+    parser.add_argument("--m", type=int)
+    parser.add_argument("--k", type=int)
+    parser.add_argument("--n", type=int)
     return parser.parse_args(args)
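
The three new flags select a single GEMM shape, and the single-shape branch in the next hunk consumes all of them, so they are evidently meant to be passed together. A quick illustration of the parsing, with hypothetical values not taken from this PR:

# parse_args is the function in the hunk above.
args = parse_args(["--m", "4096", "--k", "4096", "--n", "4096"])
assert (args.m, args.k, args.n) == (4096, 4096, 4096)
assert not args.llama  # store_true flags default to False when absent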


@@ -45,7 +49,8 @@ def args(m, n, k):
         if self.extra_args.llama:
             for m, n, k, _bias in llama_shapes():
                 yield args(m, n, k)
-
+        elif self.extra_args.m:
+            yield args(self.extra_args.m, self.extra_args.n, self.extra_args.k)
         else:
             for i in range(10, 15):
                 for j in range(0, 4):
@@ -70,9 +75,17 @@ def torch_fp8_gemm(self, a, b):
 
     @register_benchmark()
     def triton_fp8_gemm(self, a, b):
-        a = reinterpret(a, tl.float8e4nv)
-        b = reinterpret(b, tl.float8e4nv)
-        return lambda: triton.ops.matmul(a, b)
+        return lambda: tutorial_matmul(a, b)
+
+    @register_benchmark()
+    def triton_persistent_fp8_gemm(self, a, b):
+        return lambda: matmul_persistent(a, b)
+
+    @register_benchmark()
+    def triton_tma_persistent_fp8_gemm(self, a, b):
+        b = b.T.contiguous()
+        c, desc_a, desc_b, desc_c = allocate_matmul_tma(a, b)
+        return lambda: matmul_tma_persistent(a, b, c, desc_a, desc_b, desc_c)
 
     @register_metric()
     def gbps(
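
One detail worth noting in the TMA variant: the b.T.contiguous() copy (suggesting the TMA kernel consumes B in a transposed layout) and the descriptor allocation both run eagerly in the method body, and only the returned closure is timed by the harness, so that one-time setup is excluded from the measured kernel time. A minimal runnable sketch of this closure pattern, with torch.mm standing in for the real matmul_tma_persistent kernel and every name below being illustrative rather than from this PR:

import torch

def tma_style_bench(a: torch.Tensor, b: torch.Tensor):
    # Setup runs once, outside the timed region: fix B's layout and
    # allocate the output (the real code also builds TMA descriptors here).
    b = b.T.contiguous()
    c = torch.empty(a.size(0), b.size(0), device=a.device, dtype=a.dtype)

    def run():
        # Stand-in for matmul_tma_persistent(a, b, c, desc_a, desc_b, desc_c).
        torch.mm(a, b.T, out=c)
        return c

    return run

fn = tma_style_bench(torch.randn(256, 64), torch.randn(64, 128))
out = fn()  # only this call is what the benchmark harness would time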
[Diffs for the other two changed files, evidently torchbenchmark/operators/fp8_gemm/tutorial.py and torchbenchmark/operators/fp8_gemm/persistent.py given the imports above, are not rendered on this page.]
