diff --git a/torchbenchmark/operators/fp8_gemm/fp8_gemm.py b/torchbenchmark/operators/fp8_gemm/fp8_gemm.py
index 3a37cb2167..5348e3b5bb 100644
--- a/torchbenchmark/operators/fp8_gemm/fp8_gemm.py
+++ b/torchbenchmark/operators/fp8_gemm/fp8_gemm.py
@@ -17,12 +17,6 @@
 )
 
 
-def triton_mm(
-    a, b, *, acc_dtype=None, allow_tf32=True, fp8_fast_accum=True, output_dtype=None
-):
-    return triton.ops.matmul(a, b, acc_dtype, allow_tf32, fp8_fast_accum, output_dtype)
-
-
 def parse_args(args):
     parser = argparse.ArgumentParser(description="TritonBench fp8_gemm")
     parser.add_argument("--llama", action="store_true")
@@ -30,6 +24,8 @@ def parse_args(args):
 
 
 class Operator(BenchmarkOperator):
+    DEFAULT_METRICS = ["tflops", "gbps", "latency"]
+
     def __init__(self, mode, device, extra_args):
         super().__init__(mode=mode, device=device, extra_args=extra_args)
         self.extra_args = parse_args(extra_args)
@@ -80,18 +76,11 @@ def torch_fp8_gemm(self, a, b):
     def triton_fp8_gemm(self, a, b):
         a = reinterpret(a, tl.float8e4nv)
         b = reinterpret(b, tl.float8e4nv)
-        return lambda: triton_mm(
-            a,
-            b,
-            acc_dtype=None,
-            allow_tf32=True,
-            fp8_fast_accum=True,
-            output_dtype=None,
-        )
+        return lambda: triton.ops.matmul(a, b)
 
     @register_metric()
     def gbps(
-        self, fn: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics
+        self, fn, example_inputs: Any, metrics: BenchmarkOperatorMetrics
     ) -> float:
         def nbytes(t):
             return t.numel() * t.element_size()
@@ -103,8 +92,7 @@ def nbytes(t):
         m, k = a.shape
         _, n = b.shape
         gb = (nbytes(a) + nbytes(b) + nbytes(c)) / 1e9
-        gbps = list(map(lambda x: gb / x * 1e3, metrics.latency))
-        return statistics.median(gbps)
+        return list(map(lambda x: gb / x * 1e3, metrics.latency))
 
     @register_metric()
     def tflops(