Skip to content

Commit

Permalink
patch the pt2_compilation_time property
Browse files Browse the repository at this point in the history
Summary: As the title says

Reviewed By: dshi7

Differential Revision: D50277869

fbshipit-source-id: 0998a70512e1862a3e6eae4e5e83dd8a9e635c76
  • Loading branch information
xuzhao9 authored and facebook-github-bot committed Oct 14, 2023
1 parent 7dc8e81 commit 0d7b6af
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions torchbenchmark/util/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,18 +383,20 @@ def enable_amp(self):
self.forward_contexts.append(self.amp_context)

@property
def pt2_compilation_time(self):
def pt2_compilation_time(self) -> Optional[float]:
from torch._dynamo.utils import compile_times
compile_time = dict(zip(*compile_times(repr="csv", aggregate=True)))["_compile.<locals>.compile_inner"]
return float(compile_time)
compile_time = dict(zip(*compile_times(repr="csv", aggregate=True)))
if "_compile.<locals>.compile_inner" in compile_time:
return float(compile_time["_compile.<locals>.compile_inner"])
return None

@property
def pt2_graph_breaks(self):
def pt2_graph_breaks(self) -> int:
from torch._dynamo.utils import counters
num_graph_breaks = len(counters["graph_break"].keys())
return num_graph_breaks

@property
def ttfb(self):
def ttfb(self) -> float:
"""Return the time taken to the first batch in ms."""
return (self._end_init_time - self._start_init_time) / 1_000_000

0 comments on commit 0d7b6af

Please sign in to comment.