diff --git a/torchbenchmark/util/model.py b/torchbenchmark/util/model.py index 2f4a008610..5bfff714a6 100644 --- a/torchbenchmark/util/model.py +++ b/torchbenchmark/util/model.py @@ -383,18 +383,20 @@ def enable_amp(self): self.forward_contexts.append(self.amp_context) @property - def pt2_compilation_time(self): + def pt2_compilation_time(self) -> Optional[float]: from torch._dynamo.utils import compile_times - compile_time = dict(zip(*compile_times(repr="csv", aggregate=True)))["_compile..compile_inner"] - return float(compile_time) + compile_time = dict(zip(*compile_times(repr="csv", aggregate=True))) + if "_compile..compile_inner" in compile_time: + return float(compile_time["_compile..compile_inner"]) + return None @property - def pt2_graph_breaks(self): + def pt2_graph_breaks(self) -> int: from torch._dynamo.utils import counters num_graph_breaks = len(counters["graph_break"].keys()) return num_graph_breaks @property - def ttfb(self): + def ttfb(self) -> float: """Return the time taken to the first batch in ms.""" return (self._end_init_time - self._start_init_time) / 1_000_000