Skip to content

Commit

Permalink
Add fake mode in verifier (#6132)
Browse files Browse the repository at this point in the history
Add fake mode in verifier (#5805)

Summary:
Pull Request resolved: #5805

Hopefully fixes https://fb.workplace.com/groups/pytorch.edge.users/permalink/1605630670307220/

Reviewed By: larryliu0820

Differential Revision: D63734251

fbshipit-source-id: 854750227b64125a3609245c6cfcbff26b71f26a
(cherry picked from commit c10c96a)

Co-authored-by: Angela Yi <angelayi@meta.com>
  • Loading branch information
pytorchbot and angelayi authored Oct 11, 2024
1 parent b077801 commit 6e788c7
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 1 deletion.
11 changes: 11 additions & 0 deletions exir/tests/test_tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,17 @@ def f(x: torch.Tensor) -> torch.Tensor:
any(node.meta.get("stack_trace", None) for node in traced_f.graph.nodes)
)

def test_ones(self) -> None:
class M(torch.nn.Module):
def forward(self, x):
y = torch.ones(x.shape[0])
return x + y

ep = torch.export.export(
M(), (torch.ones(3),), dynamic_shapes={"x": {0: torch.export.Dim("x")}}
)
exir.to_edge(ep)

def test_possible_input_mutation(self) -> None:
def f(x: torch.Tensor) -> torch.Tensor:
return torch.add(torch.ones(5), torch.ones(5), out=x)
Expand Down
5 changes: 4 additions & 1 deletion exir/verification/verifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import itertools
import operator
import types
from contextlib import nullcontext
from typing import Any, List, Optional, Tuple, Type

import torch
Expand All @@ -19,6 +20,7 @@
RunHigherOrderOperatorError,
)
from torch._dispatch.python import enable_python_dispatcher
from torch._export.utils import _detect_fake_mode_from_gm

from torch._export.verifier import SpecViolationError, Verifier
from torch._ops import OpOverload
Expand Down Expand Up @@ -161,8 +163,9 @@ def extract_input(node: torch.fx.Node) -> Optional[FakeTensor]:
def _check_tensor_args_matching_op_allowed_dtype(gm: GraphModule) -> None:
validator = EdgeOpArgValidator(gm)
inputs = _get_inputs(gm)
fake_mode = _detect_fake_mode_from_gm(gm) or nullcontext()
try:
with enable_python_dispatcher():
with enable_python_dispatcher(), fake_mode:
validator.run(*inputs)
except RunHigherOrderOperatorError:
# NB: ignore higher order operator in the graph.
Expand Down

0 comments on commit 6e788c7

Please sign in to comment.