diff --git a/torchbenchmark/operators/gemm/operator.py b/torchbenchmark/operators/gemm/operator.py
index d125a136ae..f161f5e052 100644
--- a/torchbenchmark/operators/gemm/operator.py
+++ b/torchbenchmark/operators/gemm/operator.py
@@ -6,6 +6,7 @@ import numpy
 
 import torch
 import triton
+import triton.ops
 
 from torchbenchmark.util.triton_op import (
     BenchmarkOperator,
@@ -97,12 +98,18 @@ def __init__(self, mode: str, device: str, extra_args: List[str] = []):
         self.DEFAULT_NUM_BATCH = len(self.shapes)
 
     @register_benchmark()
-    def triton_matmul(self, a, b, bias) -> Callable:
+    def triton_tutorial_matmul(self, a, b, bias) -> Callable:
         if not bias == None:
             return lambda: triton_matmul(a, b) + bias
         else:
             return lambda: triton_matmul(a, b)
 
+    @register_benchmark(enabled=torch.version.cuda is not None)
+    def triton_ops_matmul(self, a, b, bias) -> Callable:
+        if bias is None:
+            return lambda: triton.ops.matmul(a, b)
+        return lambda: triton.ops.matmul(a, b, bias)
+
     @register_benchmark(baseline=True)
     def aten_matmul(self, a, b, bias) -> Callable:
         if not bias == None:
@@ -194,15 +201,17 @@ def plot(self):
             line_arg="provider",  # argument name whose value corresponds to a different line in the plot
             line_vals=[
                 "aten_matmul",
-                "triton_matmul",
+                "triton_tutorial_matmul",
+                "triton_ops_matmul",
                 "hstu_triton_matmul",
             ],  # possible values for `line_arg``
             line_names=[
                 "ATen GEMM",
-                "Triton GEMM",
+                "Triton Tutorial GEMM",
+                "triton.ops.matmul",
                 "HSTU Triton GEMM",
            ],  # label name for the lines
-            styles=[("blue", "-"), ("green", "-"), ("red", "-")],  # line styles
+            styles=[("blue", "-"), ("green", "-"), ("red", "-"), ("yellow", "-")],  # line styles
             ylabel="tflops",  # label name for the y-axis
             plot_name="gemm-performance",  # name for the plot. Used also as a file name for saving the plot.
             args={},  # values for function arguments not in `x_names` and `y_name`