From 49c4a5fbbd62a50a003fd125d70ba68de3470040 Mon Sep 17 00:00:00 2001 From: "laith sakka (Meta Employee)" Date: Sat, 17 Feb 2024 04:00:49 -0800 Subject: [PATCH] Do to convert UnsupportedFakeTensorException into RuntimeError in runNode for proper graph breaking. (#120026) Summary: Fix: https://github.com/pytorch/pytorch/issues/119779 by properly graph breaking a proper fix is to handle quantized tensors for full complete solution. if when generating a fake tensor, UnsupportedFakeTensorException is thrown, then its handled and converted into a Unimplemented in inside wrap_fake_exception which is then translated to a graph break. However run_node used to convert UnsupportedFakeTensorException into a runtime error, creating runtime errors instead of graph breaks whenever generating a fake tensor for a quantized tensor fails. X-link: https://github.com/pytorch/pytorch/pull/120026 Approved by: https://github.com/jansel Reviewed By: huydhn Differential Revision: D53879832 Pulled By: laithsakka fbshipit-source-id: 1d2b7af0c9882d74fc6456474a79ad0709ba8ff7 --- .../dynamo/dynamobench/_dynamo/utils.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/userbenchmark/dynamo/dynamobench/_dynamo/utils.py b/userbenchmark/dynamo/dynamobench/_dynamo/utils.py index 3eaba129af..089b25a1bc 100644 --- a/userbenchmark/dynamo/dynamobench/_dynamo/utils.py +++ b/userbenchmark/dynamo/dynamobench/_dynamo/utils.py @@ -1722,6 +1722,10 @@ def run_node(tracer, node, args, kwargs, nnmodule): op = node.op with set_current_node(node): + + def make_error_message(e): + return f"Failed running {op} {node.target}(*{args}, **{kwargs}):\n" + str(e) + try: if op == "call_function": return node.target(*args, **kwargs) @@ -1735,17 +1739,16 @@ def run_node(tracer, node, args, kwargs, nnmodule): elif op == "placeholder": assert "example_value" in node.meta return node.meta["example_value"] - except NotImplementedError as e: + + except (NotImplementedError, UnsupportedFakeTensorException) as e: # NB: mimic how wrap_fake_exception does it from .exc import unimplemented - raise unimplemented( - f"running {op} {node.target}(*{args}, **{kwargs})" - ) from e - + raise unimplemented(make_error_message(e)) from e except Exception as e: - fn_str = f"Failed running {op} {node.target}(*{args}, **{kwargs}):\n" - raise RuntimeError(fn_str + str(e)).with_traceback(e.__traceback__) from e + raise RuntimeError(make_error_message(e)).with_traceback( + e.__traceback__ + ) from e raise AssertionError(op)