Skip to content

Commit

Permalink
Remove error anti-pattern when dealing with dynamic shape output (#12…
Browse files Browse the repository at this point in the history
…1681)

Summary:
There are cases where capture_dynamic_output_shape_ops=True and we will still see DynamicOutputShapeException. For example, when an op doesn't have a meta kernel implemented to return the correct dynamic shape output. If we blindly give users instructions to set capture_dynamic_output_shape_ops to True, users would try it and see no change. As witnessed in this issue:
pytorch/pytorch#121036 (comment)

X-link: pytorch/pytorch#121681
Approved by: https://github.com/tugsbayasgalan

Reviewed By: osalpekar

Differential Revision: D54919382

Pulled By: gmagogsfm

fbshipit-source-id: 9d022c13a22a4201f26afec844a822467a63d71e
  • Loading branch information
gmagogsfm authored and facebook-github-bot committed Mar 15, 2024
1 parent 8f9a226 commit a762495
Showing 1 changed file with 11 additions and 4 deletions.
15 changes: 11 additions & 4 deletions userbenchmark/dynamo/dynamobench/_dynamo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1684,10 +1684,17 @@ def get_fake_value(node, tx, allow_non_graph_fake=False):
elif isinstance(
cause, torch._subclasses.fake_tensor.DynamicOutputShapeException
):
unimplemented(
f"dynamic shape operator: {cause.func}; "
"to enable, set torch._dynamo.config.capture_dynamic_output_shape_ops = True"
)
if not torch._dynamo.config.capture_dynamic_output_shape_ops:
unimplemented(
f"dynamic shape operator: {cause.func}; "
"to enable, set torch._dynamo.config.capture_dynamic_output_shape_ops = True"
)
else:
unimplemented(
f"dynamic shape operator: {cause.func}; "
"Operator does not have a meta kernel that supports dynamic output shapes, "
"please report an issue to PyTorch"
)
elif isinstance(
cause, torch._subclasses.fake_tensor.UnsupportedOperatorException
):
Expand Down

0 comments on commit a762495

Please sign in to comment.