fixed fn_name parameter for BenchmarkOperator.tflops
Summary:
There is a mismatch between the argument passed to tflops and its declared parameter: callers pass a Callable fn, while the signature expects a str fn_name.

Reviewed By: sijiac

Differential Revision: D55736424

fbshipit-source-id: 870e515f07e988fd1eadb7c7165e075c56fa21b9
chenyang78 authored and facebook-github-bot committed Apr 4, 2024
1 parent 2acbd13 commit 8d4aa68
Showing 1 changed file with 1 addition and 2 deletions: torchbenchmark/util/triton_op.py
@@ -394,7 +394,7 @@ def _do_bench(self,
 
 
     @register_metric()
-    def tflops(self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics) -> List[float]:
+    def tflops(self, fn: Callable, example_inputs: Any, metrics: BenchmarkOperatorMetrics) -> List[float]:
         def _get_flops(self, func: Callable) -> float:
             """By default, use the torch.__dispatch__ based flops counter."""
             from torch.utils.flop_counter import FlopCounterMode
@@ -410,7 +410,6 @@ def work_func():
             work_func()
             total_flops = sum([v for _, v in flop_counter.flop_counts["Global"].items()])
             return total_flops
-        fn = self._get_bm_func(fn_name)
         if not fn in self._op_flops:
             self._op_flops[fn] = _get_flops(self, fn)
         op_flops = self._op_flops[fn]
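For context, a minimal sketch of the FlopCounterMode-based counting that _get_flops builds on; the count_flops helper and the matmul sizes here are illustrative, not part of the repository.

import torch
from torch.utils.flop_counter import FlopCounterMode

def count_flops(fn) -> int:
    # Record dispatched ops while fn runs, then aggregate the global counts,
    # mirroring the sum over flop_counts["Global"] in triton_op.py.
    flop_counter = FlopCounterMode(display=False)
    with flop_counter:
        fn()
    return sum(v for _, v in flop_counter.flop_counts["Global"].items())

a = torch.randn(512, 512)
b = torch.randn(512, 512)
print(count_flops(lambda: a @ b))  # a 512x512 matmul counts ~2 * 512**3 FLOPs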
