Skip to content

Commit

Permalink
Trace torch function modes entered outside of torch.compile (#133137)
Browse files Browse the repository at this point in the history
Summary:
This PR adds initial tracing for torch function modes.

Details:
In essence, this adds tracing into the torch function of modes entered outside of the torch.compile call.
This does not yet support tracing enter/exit of a torch function mode/ tracing set_default_device properly using the new mode infra (this will be a very good stress test for modes). I am adding more PRs to this stack to support these. The overall plan is to support tracing enter/exit and handling graph breaks like we do other torch.* context managers.

Previously landed:
pytorch/pytorch#133135
pytorch/pytorch#133136
pytorch/pytorch#133134
pytorch/pytorch#133133
pytorch/pytorch#133132
pytorch/pytorch#133131
pytorch/pytorch#133729
pytorch/pytorch#133130

X-link: pytorch/pytorch#133137
Approved by: https://github.com/jansel, https://github.com/zou3519
ghstack dependencies: #134732

Reviewed By: jeanschmidt

Differential Revision: D62737267

Pulled By: mlazos

fbshipit-source-id: a913a5f89b409e38bc7f940f75f4510fe09fdde3
  • Loading branch information
mlazos authored and facebook-github-bot committed Sep 16, 2024
1 parent c709128 commit 2536a65
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions userbenchmark/dynamo/dynamobench/_dynamo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@
import torch.utils._pytree as pytree
from torch import fx
from torch._C import (
_get_function_stack_at,
_instruction_counter,
_len_torch_function_stack,
_pop_torch_function_stack,
Expand Down Expand Up @@ -3087,7 +3086,9 @@ def is_parameter_freezing():
def get_torch_function_mode_stack(filter_ignored=True):
from .variables.torch_function import IGNORED_MODES

stack = [_get_function_stack_at(i) for i in range(_len_torch_function_stack())]
stack = [
get_torch_function_mode_stack_at(i) for i in range(_len_torch_function_stack())
]
if filter_ignored:
stack = [mode for mode in stack if type(mode) not in IGNORED_MODES]

Expand All @@ -3107,6 +3108,11 @@ def set_torch_function_mode_stack(stack):
_push_on_torch_function_stack(mode)


def clear_torch_function_mode_stack():
for i in range(_len_torch_function_stack()):
_pop_torch_function_stack()


def verify_guard_fn_signature(value):
fn = value.__metadata_guard__
sig = inspect.signature(fn)
Expand Down

0 comments on commit 2536a65

Please sign in to comment.