From efcf520c4050ca0342070e7777cd6274f9841b9d Mon Sep 17 00:00:00 2001 From: voznesenskym Date: Fri, 24 Nov 2023 22:06:32 -0800 Subject: [PATCH] Add Stateful/Stateless symbolic contexts, use fresh fake mode for dynamo backends (#2058) Summary: X-link: https://github.com/pytorch/pytorch/pull/114526 X-link: https://github.com/pytorch/pytorch/pull/113926 The primary problem we are setting out to solve here is fake tensor freshness. Before this PR, fake tensors after dynamo represented fake tensors *at the end* of trace, so subsequent retraces like aot_autograd would start off with fake tensors in the wrong (end result) state, rather than their expected fresh state. The solution here is to start a fresh fake mode, and re-fakify the tensors. The nuance comes from ensuring that symbols are uniformly created for the symbolic sizes and strides of the tensor. This PR is the result of *a lot* of back and forth with ezyang and eellison. Initially, the first pass at this was not super different from what we have in the PR - the broad strokes were the same: 1) We cache source->symbol in shape_env 2) We pass policy objects around, stored at dynamo fakificaiton time, and reused for later fakification 3) We create a new fake mode for backends (from https://github.com/pytorch/pytorch/pull/113605/files) This is ugly, and has some layering violations. We detoured our decision making through a few other alternatives. Immutable/mutable fake tensor mode was the most interesting alternative, https://github.com/pytorch/pytorch/pull/113653, and was struck down on concerns of complexity in fake mode combined with it not covering all edge cases. We also detoured on what to do about tensor memoization returning back potentially different tensors than requested, and if that was an anti pattern (it is) we want to hack in with the symbol cache (we don't). We went back to the drawing board here, but with a few concessions: 1) the cache for source->symbol must live outside of shape_env, for both lifecycle, and layering reasons 2) A good amount of work needs to be done to pipe policy around fake_mode and meta_utils correctly, to cover all the cases (ezyang did this) cc penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 aakhundov kadeng imported-using-ghimport Test Plan: Imported from OSS Reviewed By: huydhn, Chillee Differential Revision: D51566250 Pulled By: voznesenskym --- userbenchmark/dynamo/dynamobench/_dynamo/utils.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/userbenchmark/dynamo/dynamobench/_dynamo/utils.py b/userbenchmark/dynamo/dynamobench/_dynamo/utils.py index aa0719a3ab..ba876a0fbb 100644 --- a/userbenchmark/dynamo/dynamobench/_dynamo/utils.py +++ b/userbenchmark/dynamo/dynamobench/_dynamo/utils.py @@ -2295,3 +2295,17 @@ def has_torch_function(vt: "torch._dynamo.variables.base.VariableTracker") -> bo isinstance(vt, UserDefinedObjectVariable) and hasattr(vt.value, "__torch_function__") ) + + +# see note [Tensor Fakification and Symbol Caching] +def to_fake_tensor(t, fake_mode): + symbolic_context = None + source = None + if tracing_context := torch._guards.TracingContext.try_get(): + if t in tracing_context.tensor_to_context: + symbolic_context = tracing_context.tensor_to_context[t] + source = symbolic_context.tensor_source + + return fake_mode.from_tensor( + t, static_shapes=False, symbolic_context=symbolic_context, source=source + )