Skip to content

Commit

Permalink
Fix fp8 gemm (#2257)
Browse files Browse the repository at this point in the history
Summary:
- triton.ops.matmul signature changed; but we don't really need the wrapper anyways
- report all of {latency, gbps, tflops} by default

Pull Request resolved: #2257

Reviewed By: adamomainz

Differential Revision: D56999596

Pulled By: bertmaher

fbshipit-source-id: 8fd2f4b390b77e12affef9fca26104ab6b15006d
  • Loading branch information
bertmaher authored and facebook-github-bot committed May 6, 2024
1 parent ed52e01 commit bdb0204
Showing 1 changed file with 5 additions and 17 deletions.
22 changes: 5 additions & 17 deletions torchbenchmark/operators/fp8_gemm/fp8_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,15 @@
)


def triton_mm(
a, b, *, acc_dtype=None, allow_tf32=True, fp8_fast_accum=True, output_dtype=None
):
return triton.ops.matmul(a, b, acc_dtype, allow_tf32, fp8_fast_accum, output_dtype)


def parse_args(args):
parser = argparse.ArgumentParser(description="TritonBench fp8_gemm")
parser.add_argument("--llama", action="store_true")
return parser.parse_args(args)


class Operator(BenchmarkOperator):
DEFAULT_METRICS = ["tflops", "gbps", "latency"]

def __init__(self, mode, device, extra_args):
super().__init__(mode=mode, device=device, extra_args=extra_args)
self.extra_args = parse_args(extra_args)
Expand Down Expand Up @@ -80,18 +76,11 @@ def torch_fp8_gemm(self, a, b):
def triton_fp8_gemm(self, a, b):
a = reinterpret(a, tl.float8e4nv)
b = reinterpret(b, tl.float8e4nv)
return lambda: triton_mm(
a,
b,
acc_dtype=None,
allow_tf32=True,
fp8_fast_accum=True,
output_dtype=None,
)
return lambda: triton.ops.matmul(a, b)

@register_metric()
def gbps(
self, fn: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics
self, fn, example_inputs: Any, metrics: BenchmarkOperatorMetrics
) -> float:
def nbytes(t):
return t.numel() * t.element_size()
Expand All @@ -103,8 +92,7 @@ def nbytes(t):
m, k = a.shape
_, n = b.shape
gb = (nbytes(a) + nbytes(b) + nbytes(c)) / 1e9
gbps = list(map(lambda x: gb / x * 1e3, metrics.latency))
return statistics.median(gbps)
return list(map(lambda x: gb / x * 1e3, metrics.latency))

@register_metric()
def tflops(
Expand Down

0 comments on commit bdb0204

Please sign in to comment.