Skip to content

Commit

Permalink
Skip dynamo when inside a functorch context (#118901)
Browse files Browse the repository at this point in the history
Summary:
X-link: pytorch/pytorch#118901
Approved by: https://github.com/zou3519

Reviewed By: atalman

Differential Revision: D53498324

Pulled By: jansel

fbshipit-source-id: 56f6cfe9220be89cfb0197a0f18f0e904d5942e3
  • Loading branch information
jansel authored and facebook-github-bot committed Feb 7, 2024
1 parent 46188bf commit 93c2136
Showing 1 changed file with 15 additions and 0 deletions.
15 changes: 15 additions & 0 deletions userbenchmark/dynamo/dynamobench/_dynamo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import operator
import os
import pstats
import re
import subprocess
import sys
import textwrap
Expand Down Expand Up @@ -761,10 +762,24 @@ def clone_inputs(example_inputs):
return res


def skip_frame_if_in_functorch_mode(val: torch.Tensor):
try:
val.data_ptr() # will throw for functorch tensors
except RuntimeError as e:
from .exc import SkipFrame

# This will be GradTrackingTensor/BatchedTensor/etc
functorch_subclass_name = re.sub(r"\(.*", "", repr(val))
raise SkipFrame(
f"torch.compile cannot be run in context: {functorch_subclass_name}"
) from e


@contextmanager
def preserve_rng_state():
with torch.utils._python_dispatch._disable_current_modes():
rng_state = torch.clone(torch.random.get_rng_state())
skip_frame_if_in_functorch_mode(rng_state)
if torch.cuda.is_available():
cuda_rng_state = torch.clone(torch.cuda.get_rng_state())
try:
Expand Down

0 comments on commit 93c2136

Please sign in to comment.