Skip to content

Commit

Permalink
Remove unnecessary tree_map_only (#121052)
Browse files Browse the repository at this point in the history
Summary:
Reduces the torch.compile(backend="eager") for this code by 1-2 seconds.

~~~
def fn(x):
    for _ in range(10000):
        # x = torch.sin(x)
        x = torch.ops.aten.sin(x)
        # x = sin(x)

    return x
~~~

X-link: pytorch/pytorch#121052
Approved by: https://github.com/jansel
ghstack dependencies: #121053

Reviewed By: izaitsevfb

Differential Revision: D54472072

Pulled By: anijain2305

fbshipit-source-id: fe0a1b018c8dc23160f77dd1d54be7914a3f9df5
  • Loading branch information
anijain2305 authored and facebook-github-bot committed Mar 4, 2024
1 parent 4920ad1 commit bbac79b
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions userbenchmark/dynamo/dynamobench/_dynamo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1593,14 +1593,16 @@ def ensure_graph_fake(e, tx):
def get_fake_values_from_nodes(tx, nodes, allow_non_graph_fake):
def visit(n: torch.fx.Node):
if n.op == "call_function" and "example_value" not in n.meta:
# fake tensor validity is checked inside get_fake_value using
# ensure_graph_fake
return get_fake_value(n, tx, allow_non_graph_fake)

return n.meta["example_value"]
out = n.meta["example_value"]
if not allow_non_graph_fake and isinstance(out, torch.Tensor):
return ensure_graph_fake(out, tx)
return out

args_kwargs = torch.fx.node.map_arg(nodes, visit)
return tree_map_only(
torch.Tensor, functools.partial(ensure_graph_fake, tx=tx), args_kwargs
)
return torch.fx.node.map_arg(nodes, visit)


def get_fake_value(node, tx, allow_non_graph_fake=False):
Expand Down

0 comments on commit bbac79b

Please sign in to comment.