From 93c2136ea12fe4111fb7ea75d6f61febb3e02660 Mon Sep 17 00:00:00 2001
From: "Jason Ansel (Meta Employee)"
Date: Wed, 7 Feb 2024 09:13:05 -0800
Subject: [PATCH] Skip dynamo when inside a functorch context (#118901)

Summary:
X-link: https://github.com/pytorch/pytorch/pull/118901
Approved by: https://github.com/zou3519

Reviewed By: atalman

Differential Revision: D53498324

Pulled By: jansel

fbshipit-source-id: 56f6cfe9220be89cfb0197a0f18f0e904d5942e3
---
 userbenchmark/dynamo/dynamobench/_dynamo/utils.py | 15 +++++++++++++++
 1 file changed, 15 insertions(+)

diff --git a/userbenchmark/dynamo/dynamobench/_dynamo/utils.py b/userbenchmark/dynamo/dynamobench/_dynamo/utils.py
index 7d2e822f20..b0554f0b80 100644
--- a/userbenchmark/dynamo/dynamobench/_dynamo/utils.py
+++ b/userbenchmark/dynamo/dynamobench/_dynamo/utils.py
@@ -17,6 +17,7 @@
 import operator
 import os
 import pstats
+import re
 import subprocess
 import sys
 import textwrap
@@ -761,10 +762,24 @@ def clone_inputs(example_inputs):
     return res
 
 
+def skip_frame_if_in_functorch_mode(val: torch.Tensor):
+    try:
+        val.data_ptr()  # will throw for functorch tensors
+    except RuntimeError as e:
+        from .exc import SkipFrame
+
+        # This will be GradTrackingTensor/BatchedTensor/etc
+        functorch_subclass_name = re.sub(r"\(.*", "", repr(val))
+        raise SkipFrame(
+            f"torch.compile cannot be run in context: {functorch_subclass_name}"
+        ) from e
+
+
 @contextmanager
 def preserve_rng_state():
     with torch.utils._python_dispatch._disable_current_modes():
         rng_state = torch.clone(torch.random.get_rng_state())
+        skip_frame_if_in_functorch_mode(rng_state)
         if torch.cuda.is_available():
             cuda_rng_state = torch.clone(torch.cuda.get_rng_state())
         try:
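
Editorial note: below is a minimal standalone sketch (not part of the patch) of the behavior the new skip_frame_if_in_functorch_mode() helper relies on: calling .data_ptr() on a tensor wrapped by a functorch transform (e.g. the GradTrackingTensor seen inside torch.func.grad) raises RuntimeError, while a plain tensor does not. The is_functorch_wrapped() helper name is made up for this sketch, which assumes a recent PyTorch with torch.func available.

import torch


def is_functorch_wrapped(val: torch.Tensor) -> bool:
    # Mirrors the check added by the patch: functorch wrapper tensors
    # (GradTrackingTensor, BatchedTensor, ...) have no accessible storage,
    # so data_ptr() raises RuntimeError.
    try:
        val.data_ptr()
        return False
    except RuntimeError:
        return True


def f(x):
    # Inside torch.func.grad, x is a functorch GradTrackingTensor.
    print("inside grad, wrapped:", is_functorch_wrapped(x))  # expected: True
    return (x * x).sum()


x = torch.randn(3)
print("plain tensor, wrapped:", is_functorch_wrapped(x))  # expected: False
torch.func.grad(f)(x)

With the patched check in place, preserve_rng_state() raises SkipFrame as soon as it sees such a wrapper tensor, so Dynamo skips compiling the frame and runs it eagerly rather than failing deeper inside tracing.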