diff --git a/torchbenchmark/util/triton_op.py b/torchbenchmark/util/triton_op.py index 1b120d60b1..41b60b5a98 100644 --- a/torchbenchmark/util/triton_op.py +++ b/torchbenchmark/util/triton_op.py @@ -394,7 +394,7 @@ def _do_bench(self, @register_metric() - def tflops(self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics) -> List[float]: + def tflops(self, fn: Callable, example_inputs: Any, metrics: BenchmarkOperatorMetrics) -> List[float]: def _get_flops(self, func: Callable) -> float: """By default, use the torch.__dispatch__ based flops counter.""" from torch.utils.flop_counter import FlopCounterMode @@ -410,7 +410,6 @@ def work_func(): work_func() total_flops = sum([v for _, v in flop_counter.flop_counts["Global"].items()]) return total_flops - fn = self._get_bm_func(fn_name) if not fn in self._op_flops: self._op_flops[fn] = _get_flops(self, fn) op_flops = self._op_flops[fn]