Skip to content

Commit

Permalink
Do to convert UnsupportedFakeTensorException into RuntimeError in run…
Browse files Browse the repository at this point in the history
…Node for proper graph breaking. (#120026)

Summary:
Fix: pytorch/pytorch#119779 by properly graph breaking  a proper fix is to handle quantized tensors for full complete solution.

if when generating  a fake tensor, UnsupportedFakeTensorException is thrown, then its handled and converted into a
Unimplemented in inside wrap_fake_exception which is then translated to a graph break.

However run_node used to convert  UnsupportedFakeTensorException into a runtime error, creating runtime
errors instead of graph breaks whenever generating a fake tensor for a quantized tensor fails.

X-link: pytorch/pytorch#120026
Approved by: https://github.com/jansel

Reviewed By: huydhn

Differential Revision: D53879832

Pulled By: laithsakka

fbshipit-source-id: 1d2b7af0c9882d74fc6456474a79ad0709ba8ff7
  • Loading branch information
laithsakka authored and facebook-github-bot committed Feb 17, 2024
1 parent e194d21 commit 49c4a5f
Showing 1 changed file with 10 additions and 7 deletions.
17 changes: 10 additions & 7 deletions userbenchmark/dynamo/dynamobench/_dynamo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1722,6 +1722,10 @@ def run_node(tracer, node, args, kwargs, nnmodule):
op = node.op

with set_current_node(node):

def make_error_message(e):
return f"Failed running {op} {node.target}(*{args}, **{kwargs}):\n" + str(e)

try:
if op == "call_function":
return node.target(*args, **kwargs)
Expand All @@ -1735,17 +1739,16 @@ def run_node(tracer, node, args, kwargs, nnmodule):
elif op == "placeholder":
assert "example_value" in node.meta
return node.meta["example_value"]
except NotImplementedError as e:

except (NotImplementedError, UnsupportedFakeTensorException) as e:
# NB: mimic how wrap_fake_exception does it
from .exc import unimplemented

raise unimplemented(
f"running {op} {node.target}(*{args}, **{kwargs})"
) from e

raise unimplemented(make_error_message(e)) from e
except Exception as e:
fn_str = f"Failed running {op} {node.target}(*{args}, **{kwargs}):\n"
raise RuntimeError(fn_str + str(e)).with_traceback(e.__traceback__) from e
raise RuntimeError(make_error_message(e)).with_traceback(
e.__traceback__
) from e

raise AssertionError(op)

Expand Down

0 comments on commit 49c4a5f

Please sign in to comment.