Skip to content

Commit

Permalink
Fix a bug in collecting NCU replay file
Browse files Browse the repository at this point in the history
Summary:
davidberard98 discovered a bug where we are using a hack by renaming the `ncu_rep` metric to `ncu_trace` to skip adding a new metric name. This is incorrect when users are collecting ncu replay files for multiple inputs.

In this Diff, we fix this by adding a new field for the NCU replay file in the built-in metrics.

Reviewed By: davidberard98

Differential Revision: D58208171

fbshipit-source-id: 88568bb6e3536d3830474dbdb25453d215aab269
  • Loading branch information
xuzhao9 authored and facebook-github-bot committed Jun 5, 2024
1 parent 6a02e0c commit f7b4bcc
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions torchbenchmark/util/triton_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
"accuracy",
"compile_time",
"ncu_trace",
"ncu_rep",
"kineto_trace",
"cpu_peak_mem",
"gpu_peak_mem",
Expand Down Expand Up @@ -146,6 +147,8 @@ class BenchmarkOperatorMetrics:
compile_time: Optional[float]
# ncu trace file
ncu_trace: Optional[str]
# ncu replay file
ncu_rep: Optional[str]
# kineto trace file
kineto_trace: Optional[str]
# cpu peak memory
Expand Down Expand Up @@ -703,6 +706,7 @@ def _init_extra_metrics() -> Dict[str, Any]:
walltime=None,
compile_time=None,
ncu_trace=None,
ncu_rep=None,
hw_roofline=self.hw_roofline() if "hw_roofline" in self.required_metrics else None,
kineto_trace=None,
cpu_peak_mem=None,
Expand Down Expand Up @@ -761,8 +765,7 @@ def _init_extra_metrics() -> Dict[str, Any]:
if "ncu_trace" in self.required_metrics:
metrics.ncu_trace = self.ncu_trace(input_id, fn_name)
if "ncu_rep" in self.required_metrics:
metrics.ncu_trace = self.ncu_trace(input_id, fn_name, replay=True)
self.required_metrics = list(map(lambda x: x.replace('ncu_rep', 'ncu_trace'), self.required_metrics))
metrics.ncu_rep = self.ncu_trace(input_id, fn_name, replay=True)
if "kineto_trace" in self.required_metrics:
metrics.kineto_trace = self.kineto_trace(input_id, fn)
# run the hidden metric "_compile_time_in_task"
Expand Down

0 comments on commit f7b4bcc

Please sign in to comment.