dynamo_timed: convert from a function decorator into a context manager.

The timing key is now passed explicitly as the first argument (`key: str`)
instead of being read from `func.__qualname__`. The two `@overload` stubs,
the `functools.wraps` import, and the `ParamSpec` import together with the
module-level `_P` are removed. The explanatory comment block above
`dynamo_timed` and the `compile_times` docstring are updated to match.

diff --git a/userbenchmark/dynamo/dynamobench/_dynamo/utils.py b/userbenchmark/dynamo/dynamobench/_dynamo/utils.py
index 44bb672d6..ef0f22cce 100644
--- a/userbenchmark/dynamo/dynamobench/_dynamo/utils.py
+++ b/userbenchmark/dynamo/dynamobench/_dynamo/utils.py
@@ -27,7 +27,7 @@
 import warnings
 import weakref
 from contextlib import contextmanager
-from functools import lru_cache, wraps
+from functools import lru_cache
 from types import MethodWrapperType
 from typing import (
     Any,
@@ -51,7 +51,7 @@
     Union,
     ValuesView,
 )
-from typing_extensions import Literal, ParamSpec, TypeGuard
+from typing_extensions import Literal, TypeGuard
 
 import torch
 import torch._functorch.config
@@ -107,7 +107,6 @@
 
 
 T = TypeVar("T")
-_P = ParamSpec("_P")
 
 unpatched_nn_module_getattr = torch.nn.Module.__getattr__
 
@@ -211,18 +210,24 @@ def _add_time_spent(key: str, phase_name: str, time_spent: float) -> None:
     frame_phase_timing[key][phase_name] += time_spent
 
 
-# dynamo_timed API works as a function decorator
+# dynamo_timed is a context manager
 # By wrapping a function in dynamo_timed, we can store a record in compilation_time_metrics
 # where the key is the functions name.
 # For example:
 #
-#  @dynamo_timed
-#  def _foo(...):
+#   def _foo(...):
+#       with dynamo_timed("_foo"):
+#           ...
 #
 # Would show up as an entry in our timing dict:
-# OrderedDict([('bar.._foo', [0.083690, 0.23949, 3.1425e-05])])
+# OrderedDict([('_foo', [0.083690, 0.23949, 3.1425e-05])])
 # This is extremely useful for granular debugging.
 #
+# Although it is tempting to use dynamo_timed as a decorator, please do not.
+# In its decorator form it makes cProfile traces less useful as dynamo_timed
+# suddenly becomes a bottleneck for lots of function calls (as only one parent
+# pointer is recorded).
+#
 # For a higher-level mode, pass a phase_name into dynamo_timed
 # phase_names record an extra record into a separate compilation timing structure,
 # one keyed on frame+name rather than function.
@@ -232,106 +237,76 @@ def _add_time_spent(key: str, phase_name: str, time_spent: float) -> None:
 # The other phases (`inductor_compile` and `code_gen`) are called for both fwd and bwd graphs.
 
 
-@overload
-def dynamo_timed(
-    original_function: Callable[_P, T],
-    phase_name: Optional[str] = None,
-    fwd_only: bool = True,
-) -> Callable[_P, T]:
-    ...
-
-
-@overload
-def dynamo_timed(
-    original_function: Literal[None] = None,
-    phase_name: Optional[str] = None,
-    fwd_only: bool = True,
-) -> Callable[[Callable[_P, T]], Callable[_P, T]]:
-    ...
-
-
+@contextmanager
 def dynamo_timed(
-    original_function: Optional[Callable[_P, T]] = None,
+    key: str,
     phase_name: Optional[str] = None,
     fwd_only: bool = True,
 ):
-    def dynamo_timed_inner(func: Callable[_P, T]) -> Callable[_P, T]:
-        @wraps(func)
-        def time_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> T:
-            key = func.__qualname__
-            if key not in compilation_time_metrics:
-                compilation_time_metrics[key] = []
-
-            fail_type: Optional[str] = None
-            fail_reason: Optional[str] = None
-            time_spent = float("-inf")
-            try:
-                with torch.profiler.record_function(f"{key} (dynamo_timed)"):
-                    t0 = time.time()
-                    r = func(*args, **kwargs)
-                    time_spent = time.time() - t0
-                compilation_time_metrics[key].append(time_spent)
-            except Exception as e:
-                fail_type = str(type(e))
-                fail_reason = str(e)
-                raise
-            finally:
-                # Only record backward compilation metrics if phase_name is not None!
-                if phase_name:
-                    frame_key = str(curr_frame)
-                    # fwd only compilation stages: entire_frame_compile, backend_compile.
-                    # use frame_key as time aggregation key.
-                    if fwd_only and fail_type is None:
+    if key not in compilation_time_metrics:
+        compilation_time_metrics[key] = []
+
+    fail_type: Optional[str] = None
+    fail_reason: Optional[str] = None
+    time_spent = float("-inf")
+    try:
+        with torch.profiler.record_function(f"{key} (dynamo_timed)"):
+            t0 = time.time()
+            yield
+            time_spent = time.time() - t0
+        compilation_time_metrics[key].append(time_spent)
+    except Exception as e:
+        fail_type = str(type(e))
+        fail_reason = str(e)
+        raise
+    finally:
+        # Only record backward compilation metrics if phase_name is not None!
+        if phase_name:
+            frame_key = str(curr_frame)
+            # fwd only compilation stages: entire_frame_compile, backend_compile.
+            # use frame_key as time aggregation key.
+            if fwd_only and fail_type is None:
+                _add_time_spent(frame_key, phase_name, time_spent)
+            else:
+                # fwd + bwd compilation stages: inductor_compile, code_gen.
+                # use frame_key as time aggregation key for fwd graphs;
+                # use compile_id as time aggregation key for bwd graphs.
+                if torch._guards.TracingContext.try_get() is not None:
+                    aot_graph_name = str(
+                        torch._guards.TracingContext.get().aot_graph_name
+                    )
+                    if (
+                        "forward" in aot_graph_name or "inference" in aot_graph_name
+                    ) and fail_type is None:
                         _add_time_spent(frame_key, phase_name, time_spent)
-                    else:
-                        # fwd + bwd compilation stages: inductor_compile, code_gen.
-                        # use frame_key as time aggregation key for fwd graphs;
-                        # use compile_id as time aggregation key for bwd graphs.
-                        if torch._guards.TracingContext.try_get() is not None:
-                            aot_graph_name = str(
-                                torch._guards.TracingContext.get().aot_graph_name
-                            )
-                            if (
-                                "forward" in aot_graph_name
-                                or "inference" in aot_graph_name
-                            ) and fail_type is None:
-                                _add_time_spent(frame_key, phase_name, time_spent)
-                            elif "backward" in aot_graph_name:
-                                compile_id = str(
-                                    torch._guards.CompileContext.current_compile_id()
+                    elif "backward" in aot_graph_name:
+                        compile_id = str(
+                            torch._guards.CompileContext.current_compile_id()
+                        )
+                        if fail_type is None:
+                            _add_time_spent(compile_id, phase_name, time_spent)
+
+                        # log backward compilation metrics at the end of `inductor_compile` of bwd graph,
+                        # one record for one bwd graph.
+                        if phase_name == "inductor_compile":
+                            if fail_type is None:
+                                inductor_compile_time = frame_phase_timing[
+                                    compile_id
+                                ].get("inductor_compile", None)
+                                code_gen_time = frame_phase_timing[compile_id].get(
+                                    "code_gen", None
                                 )
-                                if fail_type is None:
-                                    _add_time_spent(compile_id, phase_name, time_spent)
-
-                                # log backward compilation metrics at the end of `inductor_compile` of bwd graph,
-                                # one record for one bwd graph.
-                                if phase_name == "inductor_compile":
-                                    if fail_type is None:
-                                        inductor_compile_time = frame_phase_timing[
-                                            compile_id
-                                        ].get("inductor_compile", None)
-                                        code_gen_time = frame_phase_timing[
-                                            compile_id
-                                        ].get("code_gen", None)
-                                    else:
-                                        inductor_compile_time = None
-                                        code_gen_time = None
-                                    metrics = BwdCompilationMetrics(
-                                        compile_id,
-                                        inductor_compile_time,
-                                        code_gen_time,
-                                        fail_type,
-                                        fail_reason,
-                                    )
-                                    record_compilation_metrics(metrics)
-
-            return r
-
-        return time_wrapper
-
-    if original_function:
-        return dynamo_timed_inner(original_function)
-    return dynamo_timed_inner
+                            else:
+                                inductor_compile_time = None
+                                code_gen_time = None
+                            metrics = BwdCompilationMetrics(
+                                compile_id,
+                                inductor_compile_time,
+                                code_gen_time,
+                                fail_type,
+                                fail_reason,
+                            )
+                            record_compilation_metrics(metrics)
 
 
 @overload
@@ -350,7 +325,7 @@ def compile_times(repr="str", aggregate: bool = False):
     """
     Get metrics about torchdynamo frontend/backend compilation times.
 
-    Accumulates information from functions tagged with `@dynamo_timed`.
+    Accumulates information from functions tagged with `dynamo_timed`.
 
     repr='str' returns a printable string for user interaction, and 'csv'
     returns headers, rows which can be logged for output