From 0d7b6afc7357cbc0dcd8c56336d4cff503ab1893 Mon Sep 17 00:00:00 2001 From: Xu Zhao Date: Fri, 13 Oct 2023 17:44:55 -0700 Subject: [PATCH] patch the pt2_compilation_time property Summary: As the title says Reviewed By: dshi7 Differential Revision: D50277869 fbshipit-source-id: 0998a70512e1862a3e6eae4e5e83dd8a9e635c76 --- torchbenchmark/util/model.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/torchbenchmark/util/model.py b/torchbenchmark/util/model.py index 2f4a008610..5bfff714a6 100644 --- a/torchbenchmark/util/model.py +++ b/torchbenchmark/util/model.py @@ -383,18 +383,20 @@ def enable_amp(self): self.forward_contexts.append(self.amp_context) @property - def pt2_compilation_time(self): + def pt2_compilation_time(self) -> Optional[float]: from torch._dynamo.utils import compile_times - compile_time = dict(zip(*compile_times(repr="csv", aggregate=True)))["_compile..compile_inner"] - return float(compile_time) + compile_time = dict(zip(*compile_times(repr="csv", aggregate=True))) + if "_compile..compile_inner" in compile_time: + return float(compile_time["_compile..compile_inner"]) + return None @property - def pt2_graph_breaks(self): + def pt2_graph_breaks(self) -> int: from torch._dynamo.utils import counters num_graph_breaks = len(counters["graph_break"].keys()) return num_graph_breaks @property - def ttfb(self): + def ttfb(self) -> float: """Return the time taken to the first batch in ms.""" return (self._end_init_time - self._start_init_time) / 1_000_000