diff --git a/userbenchmark/dynamo/dynamobench/_dynamo/utils.py b/userbenchmark/dynamo/dynamobench/_dynamo/utils.py index aa0719a3ab..ba876a0fbb 100644 --- a/userbenchmark/dynamo/dynamobench/_dynamo/utils.py +++ b/userbenchmark/dynamo/dynamobench/_dynamo/utils.py @@ -2295,3 +2295,17 @@ def has_torch_function(vt: "torch._dynamo.variables.base.VariableTracker") -> bo isinstance(vt, UserDefinedObjectVariable) and hasattr(vt.value, "__torch_function__") ) + + +# see note [Tensor Fakification and Symbol Caching] +def to_fake_tensor(t, fake_mode): + symbolic_context = None + source = None + if tracing_context := torch._guards.TracingContext.try_get(): + if t in tracing_context.tensor_to_context: + symbolic_context = tracing_context.tensor_to_context[t] + source = symbolic_context.tensor_source + + return fake_mode.from_tensor( + t, static_shapes=False, symbolic_context=symbolic_context, source=source + )