From c2f32d68f35c85eae18a9ba3f79aeed428bc7caa Mon Sep 17 00:00:00 2001
From: "Edward Z. Yang (Meta Employee)"
Date: Wed, 24 Apr 2024 21:50:15 -0700
Subject: [PATCH] Add pending_fresh_unbacked_symbols, populate unbacked_bindings for Dynamo (#124290)

Summary:
The important comment:

```
# Whenever we allocate a fresh unbacked Symbol, we add it to this
# pending list. Unbacked symbol allocation can occur at unpredictable
# points during meta tensor propagation, but at some point, we
# have to know what the binding site for an unbacked symbol is, and
# this is computed when we actually place the node in the graph. The
# important thing is that we always actually handle every unaccounted
# for unbacked symbol, so this list helps us keep track of them and
# then make sure they are all accounted for.
#
# We could potentially give rise to errors earlier by lexically
# scoping when we do propagation, and only allowing unbacked symbols
# to be allocated at this point in time. However this is inconvenient
# to do in Dynamo, because fake tensor propagation is far from when we
# analyze binding sites (set_example_value), so we do it in a more
# mutatey way.
#
# NB: fresh unbacked symbols NEVER get substitutions applied to them,
# they are binding sites!
```

The compute_unbacked_bindings is the other half of the equation: the thing
that actually consumes the pending_fresh_unbacked_symbols and does something
with them. Important comment:

```
After having run fake tensor propagation and producing example_value
result, traverse example_value looking for freshly bound unbacked
symbols and record their paths for later. It is an error if
we have allocated an unbacked SymInt but it cannot be found in
example_value. (NB: this means if you have a multi-output
function, you must call this on the tuple of tensor output, you
cannot wait!)
```

For example, if I return a tensor with size `[u0, u1]`, and u1 is a fresh
unbacked SymInt, then I'll have `{u1: KeyPath(".size(1)")}`, telling me I can
get u1 by running `size(1)` on the result of this node. u0 is not fresh (it
probably flowed in as an argument), so I don't generate a binding for it.

I eventually intend to propagate this information all the way to Inductor
lowering, where extra metadata about unbacked symbol binding will be
canonically used for codegen, instead of trying to infer it from defs/uses.

Signed-off-by: Edward Z. Yang

X-link: https://github.com/pytorch/pytorch/pull/124290
Approved by: https://github.com/lezcano

Reviewed By: jeanschmidt

Differential Revision: D56521430

Pulled By: ezyang

fbshipit-source-id: 5630756ae11e1bf854e837798f702b706d118e43
---
 userbenchmark/dynamo/dynamobench/_dynamo/utils.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/userbenchmark/dynamo/dynamobench/_dynamo/utils.py b/userbenchmark/dynamo/dynamobench/_dynamo/utils.py
index 2b9be139ec..f4f9bc93bb 100644
--- a/userbenchmark/dynamo/dynamobench/_dynamo/utils.py
+++ b/userbenchmark/dynamo/dynamobench/_dynamo/utils.py
@@ -1153,7 +1153,11 @@ def set_example_value(node, example_value):
     # this to accurately reflect what the state of the value was at the time
     # the program was traced).
     node.meta["example_value"] = example_value
-    assert TracingContext.try_get() is not None
+    shape_env = TracingContext.get().fake_mode.shape_env
+    if symbol_to_path := torch.fx.experimental.symbolic_shapes.compute_unbacked_bindings(
+        shape_env, example_value
+    ):
+        node.meta["unbacked_bindings"] = symbol_to_path
 
 
 def _get_fake_tensor(vt):