Skip to content

Commit

Permalink
Add remote cache time saved to compilation metrics (#2449)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2449

X-link: pytorch/pytorch#135490

Record remote cache time saved via frame_phase_timing

We add to the "phase" when remote cache hits and saves us time, so that we have a 1:1 correspondence between a frame and time saved.

Reviewed By: aorenste

Differential Revision: D62106921

fbshipit-source-id: 57f84c189fea7a40ad836c7f59f6801d22973c4f
  • Loading branch information
jamesjwu authored and facebook-github-bot committed Sep 13, 2024
1 parent 195d119 commit 47413ed
Showing 1 changed file with 22 additions and 0 deletions.
22 changes: 22 additions & 0 deletions userbenchmark/dynamo/dynamobench/_dynamo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,21 @@ def _add_time_spent(key: str, phase_name: str, time_spent: float) -> None:
frame_phase_timing[key][phase_name] += time_spent


# Use frame_phase_timing to record remote_cache_time_saved
# This follows the same principles of key as the other frame phase timings,
# but is incremented by FxGraphCache (and later AOTAutogradCache) directly
def add_remote_cache_time_saved(time_saved_ns: int, is_backward: bool = False) -> None:
key = None
if is_backward:
# Use compile id as the frame key for backwards compilation
key = str(torch._guards.CompileContext.current_compile_id())
else:
key = str(curr_frame)
# Convert to seconds (as a float)
time_saved = time_saved_ns / 1e9
_add_time_spent(key, "remote_cache_time_saved", time_saved)


def get_cache_stats() -> Dict[str, Any]:
"""Get a bunch of metadata about cache hits and misses to use in chromium events"""
cache_stats = {
Expand Down Expand Up @@ -332,15 +347,20 @@ def dynamo_timed(
code_gen_time = frame_phase_timing[compile_id].get(
"code_gen", None
)
remote_cache_time_saved = frame_phase_timing[
compile_id
].get("remote_cache_time_saved", None)
else:
inductor_compile_time = None
code_gen_time = None
remote_cache_time_saved = None
metrics = BwdCompilationMetrics(
compile_id,
inductor_compile_time,
code_gen_time,
fail_type,
fail_reason,
remote_cache_time_saved,
)
record_compilation_metrics(metrics)

Expand Down Expand Up @@ -779,6 +799,7 @@ class CompilationMetrics:
# a compiled frame
has_guarded_code: bool
possibly_missed_reinplacing_opportunities: Optional[int]
remote_cache_time_saved_s: Optional[float]


@dataclasses.dataclass
Expand All @@ -788,6 +809,7 @@ class BwdCompilationMetrics:
code_gen_time_s: Optional[float]
fail_type: Optional[str]
fail_reason: Optional[str]
remote_cache_time_saved_s: Optional[float]


DEFAULT_COMPILATION_METRICS_LIMIT = 64
Expand Down

0 comments on commit 47413ed

Please sign in to comment.