Ban decorator usage of dynamo_timed (#132328)
Summary:
This is a more manual version of pytorch/pytorch#132073: instead of conjuring the wrapper with clone magic, it creates the new function explicitly at each call site. Review with whitespace diffs off.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

X-link: pytorch/pytorch#132328
Approved by: https://github.com/albanD

Reviewed By: ZainRizvi

Differential Revision: D60702888

Pulled By: ezyang

fbshipit-source-id: 811907dea1cf15ac2b59f7e198d2760f60f7a3d9
ezyang authored and facebook-github-bot committed Aug 3, 2024
1 parent ae2c9e8 commit 38880ca
Showing 1 changed file with 77 additions and 102 deletions.
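A minimal before/after sketch of the migration this commit performs (the call site below is hypothetical; `backend_compile` is one of the phase names referenced in the file's comments, and `backend` is a stand-in):

    from torch._dynamo.utils import dynamo_timed

    # Before: decorator form, now banned.
    #
    #   @dynamo_timed(phase_name="backend_compile")
    #   def call_backend(gm, inputs):
    #       return backend(gm, inputs)

    # After: a context manager with the timing key written out explicitly.
    def call_backend(gm, inputs):
        with dynamo_timed("call_backend", phase_name="backend_compile"):
            return backend(gm, inputs)  # hypothetical backend invocation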
userbenchmark/dynamo/dynamobench/_dynamo/utils.py (77 additions, 102 deletions)
@@ -27,7 +27,7 @@
 import warnings
 import weakref
 from contextlib import contextmanager
-from functools import lru_cache, wraps
+from functools import lru_cache
 from types import MethodWrapperType
 from typing import (
     Any,
@@ -51,7 +51,7 @@
     Union,
     ValuesView,
 )
-from typing_extensions import Literal, ParamSpec, TypeGuard
+from typing_extensions import Literal, TypeGuard

 import torch
 import torch._functorch.config
@@ -107,7 +107,6 @@


 T = TypeVar("T")
-_P = ParamSpec("_P")

 unpatched_nn_module_getattr = torch.nn.Module.__getattr__

@@ -211,18 +210,24 @@ def _add_time_spent(key: str, phase_name: str, time_spent: float) -> None:
     frame_phase_timing[key][phase_name] += time_spent


-# dynamo_timed API works as a function decorator
+# dynamo_timed is a context manager
 # By wrapping a function in dynamo_timed, we can store a record in compilation_time_metrics
 # where the key is the functions name.
 # For example:
 #
-# @dynamo_timed
-# def _foo(...):
+# def _foo(...):
+#     with dynamo_timed("_foo"):
+#         ...
 #
 # Would show up as an entry in our timing dict:
-# OrderedDict([('bar.<locals>._foo', [0.083690, 0.23949, 3.1425e-05])])
+# OrderedDict([('_foo', [0.083690, 0.23949, 3.1425e-05])])
 # This is extremely useful for granular debugging.
 #
+# Although it is tempting to use dynamo_timed as a decorator, please do not.
+# In its decorator form it makes cProfile traces less useful as dynamo_timed
+# suddenly becomes a bottleneck for lots of function calls (as only one parent
+# pointer is recorded).
+#
 # For a higher-level mode, pass a phase_name into dynamo_timed
 # phase_names record an extra record into a separate compilation timing structure,
 # one keyed on frame+name rather than function.
@@ -232,106 +237,76 @@ def _add_time_spent(key: str, phase_name: str, time_spent: float) -> None:
 # The other phases (`inductor_compile` and `code_gen`) are called for both fwd and bwd graphs.


-@overload
-def dynamo_timed(
-    original_function: Callable[_P, T],
-    phase_name: Optional[str] = None,
-    fwd_only: bool = True,
-) -> Callable[_P, T]:
-    ...
-
-
-@overload
-def dynamo_timed(
-    original_function: Literal[None] = None,
-    phase_name: Optional[str] = None,
-    fwd_only: bool = True,
-) -> Callable[[Callable[_P, T]], Callable[_P, T]]:
-    ...
-
-
+@contextmanager
 def dynamo_timed(
-    original_function: Optional[Callable[_P, T]] = None,
+    key: str,
     phase_name: Optional[str] = None,
     fwd_only: bool = True,
 ):
-    def dynamo_timed_inner(func: Callable[_P, T]) -> Callable[_P, T]:
-        @wraps(func)
-        def time_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> T:
-            key = func.__qualname__
-            if key not in compilation_time_metrics:
-                compilation_time_metrics[key] = []
-
-            fail_type: Optional[str] = None
-            fail_reason: Optional[str] = None
-            time_spent = float("-inf")
-            try:
-                with torch.profiler.record_function(f"{key} (dynamo_timed)"):
-                    t0 = time.time()
-                    r = func(*args, **kwargs)
-                    time_spent = time.time() - t0
-                compilation_time_metrics[key].append(time_spent)
-            except Exception as e:
-                fail_type = str(type(e))
-                fail_reason = str(e)
-                raise
-            finally:
-                # Only record backward compilation metrics if phase_name is not None!
-                if phase_name:
-                    frame_key = str(curr_frame)
-                    # fwd only compilation stages: entire_frame_compile, backend_compile.
-                    # use frame_key as time aggregation key.
-                    if fwd_only and fail_type is None:
-                        _add_time_spent(frame_key, phase_name, time_spent)
-                    else:
-                        # fwd + bwd compilation stages: inductor_compile, code_gen.
-                        # use frame_key as time aggregation key for fwd graphs;
-                        # use compile_id as time aggregation key for bwd graphs.
-                        if torch._guards.TracingContext.try_get() is not None:
-                            aot_graph_name = str(
-                                torch._guards.TracingContext.get().aot_graph_name
-                            )
-                            if (
-                                "forward" in aot_graph_name
-                                or "inference" in aot_graph_name
-                            ) and fail_type is None:
-                                _add_time_spent(frame_key, phase_name, time_spent)
-                            elif "backward" in aot_graph_name:
-                                compile_id = str(
-                                    torch._guards.CompileContext.current_compile_id()
-                                )
-                                if fail_type is None:
-                                    _add_time_spent(compile_id, phase_name, time_spent)
-
-                                # log backward compilation metrics at the end of `inductor_compile` of bwd graph,
-                                # one record for one bwd graph.
-                                if phase_name == "inductor_compile":
-                                    if fail_type is None:
-                                        inductor_compile_time = frame_phase_timing[
-                                            compile_id
-                                        ].get("inductor_compile", None)
-                                        code_gen_time = frame_phase_timing[
-                                            compile_id
-                                        ].get("code_gen", None)
-                                    else:
-                                        inductor_compile_time = None
-                                        code_gen_time = None
-                                    metrics = BwdCompilationMetrics(
-                                        compile_id,
-                                        inductor_compile_time,
-                                        code_gen_time,
-                                        fail_type,
-                                        fail_reason,
-                                    )
-                                    record_compilation_metrics(metrics)
-
-            return r
-
-        return time_wrapper
-
-    if original_function:
-        return dynamo_timed_inner(original_function)
-    return dynamo_timed_inner
+    if key not in compilation_time_metrics:
+        compilation_time_metrics[key] = []
+
+    fail_type: Optional[str] = None
+    fail_reason: Optional[str] = None
+    time_spent = float("-inf")
+    try:
+        with torch.profiler.record_function(f"{key} (dynamo_timed)"):
+            t0 = time.time()
+            yield
+            time_spent = time.time() - t0
+        compilation_time_metrics[key].append(time_spent)
+    except Exception as e:
+        fail_type = str(type(e))
+        fail_reason = str(e)
+        raise
+    finally:
+        # Only record backward compilation metrics if phase_name is not None!
+        if phase_name:
+            frame_key = str(curr_frame)
+            # fwd only compilation stages: entire_frame_compile, backend_compile.
+            # use frame_key as time aggregation key.
+            if fwd_only and fail_type is None:
+                _add_time_spent(frame_key, phase_name, time_spent)
+            else:
+                # fwd + bwd compilation stages: inductor_compile, code_gen.
+                # use frame_key as time aggregation key for fwd graphs;
+                # use compile_id as time aggregation key for bwd graphs.
+                if torch._guards.TracingContext.try_get() is not None:
+                    aot_graph_name = str(
+                        torch._guards.TracingContext.get().aot_graph_name
+                    )
+                    if (
+                        "forward" in aot_graph_name or "inference" in aot_graph_name
+                    ) and fail_type is None:
+                        _add_time_spent(frame_key, phase_name, time_spent)
+                    elif "backward" in aot_graph_name:
+                        compile_id = str(
+                            torch._guards.CompileContext.current_compile_id()
+                        )
+                        if fail_type is None:
+                            _add_time_spent(compile_id, phase_name, time_spent)
+
+                        # log backward compilation metrics at the end of `inductor_compile` of bwd graph,
+                        # one record for one bwd graph.
+                        if phase_name == "inductor_compile":
+                            if fail_type is None:
+                                inductor_compile_time = frame_phase_timing[
+                                    compile_id
+                                ].get("inductor_compile", None)
+                                code_gen_time = frame_phase_timing[compile_id].get(
+                                    "code_gen", None
+                                )
+                            else:
+                                inductor_compile_time = None
+                                code_gen_time = None
+                            metrics = BwdCompilationMetrics(
+                                compile_id,
+                                inductor_compile_time,
+                                code_gen_time,
+                                fail_type,
+                                fail_reason,
+                            )
+                            record_compilation_metrics(metrics)


 @overload
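Stripped of the profiler integration and phase bookkeeping, the rewritten function above is the standard `@contextmanager` timing pattern; a standalone sketch (not the PyTorch code):

    import time
    from collections import defaultdict
    from contextlib import contextmanager

    compilation_times = defaultdict(list)  # stand-in for compilation_time_metrics

    @contextmanager
    def timed(key):
        t0 = time.time()
        yield
        # Like dynamo_timed, record only on success: an exception raised under
        # the with-block propagates out of the yield and skips the append.
        compilation_times[key].append(time.time() - t0)

    with timed("_foo"):
        sum(range(1_000_000))  # stand-in for real compilation work

    print(dict(compilation_times))  # {'_foo': [<one duration>]}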
@@ -350,7 +325,7 @@ def compile_times(repr="str", aggregate: bool = False):
     """
     Get metrics about torchdynamo frontend/backend compilation times.

-    Accumulates information from functions tagged with `@dynamo_timed`.
+    Accumulates information from functions tagged with `dynamo_timed`.

     repr='str' returns a printable string for user interaction, and 'csv'
     returns headers, rows which can be logged for output
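The docstring hunk above belongs to `compile_times`, which aggregates the records that `dynamo_timed` accumulates. A usage sketch, assuming a PyTorch build from this era:

    import torch
    import torch._dynamo

    @torch.compile
    def f(x):
        return x.sin() + x.cos()

    f(torch.randn(8))  # triggers compilation, which runs the timed phases

    # Printable summary of frontend/backend compilation times.
    print(torch._dynamo.utils.compile_times(repr="str"))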
