Trace enter/exit of TorchFunctionModes (#135422)
Summary:
This PR implements tracing of with-statement contexts whose context manager is a TorchFunction mode with the default enter/exit behavior (i.e., pushing and popping the mode).

Typically the bytecode for a context manager looks like this during a graph break:
1. graph call
2. enter context
3. unsupported code
4. exit context
5. resume call

resume fn structure:
1. enter context
2. jump
...
3. exit context
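
As a rough illustration of the generic structure above, the following pure-Python sketch writes out by hand what the split looks like for an ordinary context manager. `Ctx`, `original_fn_after_break`, and `resume_fn` are hypothetical names invented for this example; Dynamo emits bytecode rather than Python source, and the real exit path also handles exceptions in more detail.

```python
events = []  # records the order in which things happen


class Ctx:
    """Toy context manager that logs when it is entered and exited."""

    def __enter__(self):
        events.append("enter")
        return self

    def __exit__(self, exc_type, exc, tb):
        events.append("exit")
        return False  # do not swallow exceptions


def resume_fn(ctx):
    # 1. enter context again, 2. jump into the rest of the with body,
    # 3. exit context when the body finishes.
    with ctx:
        events.append("rest of with body")


def original_fn_after_break(ctx):
    events.append("graph call")            # 1. graph call
    ctx.__enter__()                        # 2. enter context
    try:
        events.append("unsupported code")  # 3. unsupported code
    finally:
        ctx.__exit__(None, None, None)     # 4. exit context
    resume_fn(ctx)                         # 5. resume call


original_fn_after_break(Ctx())
```

Note that the context is entered and exited twice in this scheme: once around the unsupported code and once in the resume function.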

The issue with torch function modes is that the side effects bytecode will replay any mutations to the torch function mode stack performed during tracing. So we do not need to enter and exit around the unsupported code in the original function (doing so would result in a duplicate torch function mode entry during execution of the unsupported code), and we do not need to enter again in the resume function (the mode that was pushed by the side effects bytecode would still be on the stack).
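
The duplicate entry can be sketched with a toy model in which a plain Python list stands in for the torch function mode stack. Everything here (`mode_stack`, `ToyMode`, `naive_output_code`) is a hypothetical stand-in for illustration, not PyTorch API.

```python
mode_stack = []  # stand-in for the global torch function mode stack


class ToyMode:
    """Toy mode with the default behavior: push on enter, pop on exit."""

    def __enter__(self):
        mode_stack.append(self)
        return self

    def __exit__(self, exc_type, exc, tb):
        mode_stack.pop()
        return False


def naive_output_code(mode):
    # Side effects replay: the push seen during tracing is re-applied.
    mode_stack.append(mode)
    # Generic context handling would then enter again around the
    # unsupported code, pushing the same mode a second time.
    mode.__enter__()
    try:
        depth = len(mode_stack)  # what the unsupported code would observe
    finally:
        mode.__exit__(None, None, None)
    return depth


depth_during_unsupported = naive_output_code(ToyMode())
```

In this model the unsupported code sees the mode twice (`depth_during_unsupported` is 2), which is the duplication the paragraph above describes.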

So for torch function modes the structure of our output code is this:

1. graph call
2. mutate tf mode stack to replay mutations
3. unsupported code
4. on exception restore stack
5. resume function

Then our resume fn looks like this:

1. no-op enter torch function mode
2. jump
3. exit tf mode
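
A minimal sketch of this output structure, again using a plain list as a stand-in for the mode stack. The names `compiled_fn`, `resume_fn`, and `unsupported_code` are hypothetical, and the real implementation is generated bytecode rather than Python source.

```python
mode_stack = []  # stand-in for the global torch function mode stack


def unsupported_code():
    # Runs eagerly; returns a snapshot of the stack it observes.
    return list(mode_stack)


def resume_fn(mode):
    # No-op enter: the mode is already on the stack from the replay step,
    # so nothing is pushed here.
    try:
        seen = list(mode_stack)  # jump: rest of the with body
    finally:
        mode_stack.pop()         # exit tf mode, exactly once
    return seen


def compiled_fn(mode):
    saved = list(mode_stack)
    # (graph call omitted in this sketch)
    mode_stack.append(mode)      # mutate tf mode stack to replay mutations
    try:
        during = unsupported_code()
    except BaseException:
        mode_stack[:] = saved    # on exception, restore the stack
        raise
    return during, resume_fn(mode)


during_break, in_resume = compiled_fn("my_mode")
```

Under these assumptions the mode appears exactly once both during the unsupported code and inside the resume function, and the stack is empty again afterwards.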

To implement the no-op enter of the torch function mode, I added a torch function mode to polyfill whose enter is a no-op but whose exit behaves normally. This is needed because we still want to trace the with context in the resume function and exit it properly (the exit instructions will still be in the function, so we need to generate instructions to set up the context).
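
The idea of a context manager that skips enter but exits normally can be sketched as a small wrapper. `NoEnterMode` and `ToyMode` are hypothetical names for this illustration and are not the classes added to polyfill; `mode_stack` is a plain list standing in for the real mode stack.

```python
mode_stack = []  # stand-in for the global torch function mode stack


class ToyMode:
    """Toy mode with the default behavior: push on enter, pop on exit."""

    def __enter__(self):
        mode_stack.append(self)
        return self

    def __exit__(self, exc_type, exc, tb):
        mode_stack.pop()
        return False


class NoEnterMode:
    """Wraps a mode so that entering does nothing but exiting is normal."""

    def __init__(self, inner):
        self.inner = inner

    def __enter__(self):
        # The mode is assumed to be on the stack already, so do not push.
        return self.inner

    def __exit__(self, exc_type, exc, tb):
        # Delegate to the wrapped mode so the pop happens as usual.
        return self.inner.__exit__(exc_type, exc, tb)


mode = ToyMode()
mode_stack.append(mode)  # as if the side effects bytecode had pushed it

with NoEnterMode(mode):
    depth_inside = len(mode_stack)  # mode present once, not twice

depth_after = len(mode_stack)       # popped by the normal exit
```

Because the with statement is still present, the exit path runs as it would for any context manager; only the push is suppressed.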

Separately from the bytecode, Dynamo also tracks contexts on the block stack, which is how the SETUP_* instructions are implemented. Naturally, at a graph break we exit the contexts on this block stack to reset them entirely, so that we can re-enter around the unsupported code soundly. However, in the torch function mode case, in the event of a graph break we do not want to perform any exit side effects, because we want to preserve the state of the mode stack as is so that the bytecode mentioned in the first section updates the stack properly. If we exited here, Dynamo would pop the mode off of the symbolic stack and would not update the true Python torch function mode stack with the suffix bytecode.

All in all, for torch function modes we enter exactly once, update the global torch function mode stack with side effects bytecode, re-read this stack when compiling the resume function, and exit exactly once in the resume function. This matches the semantics of eager exactly.

X-link: pytorch/pytorch#135422
Approved by: https://github.com/williamwen42
ghstack dependencies: #134732, #133137, #135443, #135444

Reviewed By: jeanschmidt

Differential Revision: D62737292

Pulled By: mlazos

fbshipit-source-id: 9d9eb0221c33166f69082a70f69b4b82c1146a46
mlazos authored and facebook-github-bot committed Sep 16, 2024
1 parent 2536a65 commit 99ea924
Showing 1 changed file with 1 addition and 0 deletions.
1 change: 1 addition & 0 deletions userbenchmark/dynamo/dynamobench/_dynamo/testing.py
@@ -163,6 +163,7 @@ def insert_nops(instructions, code_options):
local_scope=locals(),
global_scope=globals(),
f_code=frame.f_code,
torch_function_mode_stack=[],
)

return GuardedCode(code, CheckFunctionManager(graph).check_fn, CompileId(0, 0))