Skip to content

Commit

Permalink
Add pending_fresh_unbacked_symbols, populate unbacked_bindings for Dy…
Browse files Browse the repository at this point in the history
…namo (#124290)

Summary:
The important comment:

```
        # Whenever we allocate a fresh unbacked Symbol, we add it to this
        # pending list.  Unbacked symbol allocation can occur at unpredictable
        # points during meta tensor propagation, but at some point, the we
        # have to know what the binding site for an unbacked symbol is, and
        # this is computed when we actually place the node in the graph.  The
        # important thing is that we always actually handle every unaccounted
        # for unbacked symbol, so this list helps us keep track of them and
        # then make sure they are all accounted for.
        #
        # We could potentially give rise to errors earlier by lexically
        # scoping when we do propagation, and only allowing unbacked symbols
        # to be allocated at this point in time.  However this is inconvenient
        # to do in Dynamo, because fake tensor propagation is far from when we
        # analyze binding sites (set_example_value), so we do it in a more
        # mutatey way.
        #
        # NB: fresh unbacked symbols NEVER get substitutions applied to them,
        # they are binding sites!
```

The compute_unbacked_bindings is the other half of the equation: the thing that actually consumes the pending_fresh_unbacked_symbols and does something with them. Important comment:

```
    After having run fake tensor propagation and producing example_value
    result, traverse example_value looking for freshly bound unbacked
    symbols and record their paths for later.  It is an error if
    we have allocated an unbacked SymInt but it cannot be found in
    example_value.  (NB: this means if you have a multi-output
    function, you must call this on the tuple of tensor output, you
    cannot wait!)
```

For example, if I return a tensor with size `[u0, u1]`, and u1 is a fresh unbacked SymInt, then I'll have `{u1: KeyPath(".size(1)")}`, telling me I can get u1 by running `size(1)` on the result of this node. u0 is not fresh (it probably flowed in as an argument), so I don't generate a binding for it.

I eventually intend to propagate this information all the way to Inductor lowering, where extra metadata about unbacked symbol binding will be canonically used for codegen, instead of trying to infer it from defs/uses.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

X-link: pytorch/pytorch#124290
Approved by: https://github.com/lezcano

Reviewed By: jeanschmidt

Differential Revision: D56521430

Pulled By: ezyang

fbshipit-source-id: 5630756ae11e1bf854e837798f702b706d118e43
  • Loading branch information
ezyang authored and facebook-github-bot committed Apr 25, 2024
1 parent c584443 commit c2f32d6
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion userbenchmark/dynamo/dynamobench/_dynamo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1153,7 +1153,11 @@ def set_example_value(node, example_value):
# this to accurately reflect what the state of the value was at the time
# the program was traced).
node.meta["example_value"] = example_value
assert TracingContext.try_get() is not None
shape_env = TracingContext.get().fake_mode.shape_env
if symbol_to_path := torch.fx.experimental.symbolic_shapes.compute_unbacked_bindings(
shape_env, example_value
):
node.meta["unbacked_bindings"] = symbol_to_path


def _get_fake_tensor(vt):
Expand Down

0 comments on commit c2f32d6

Please sign in to comment.