From b34bb95fb7e60a8d994ef041aa21c59a665baa4b Mon Sep 17 00:00:00 2001 From: Max Ren Date: Mon, 13 Jan 2025 11:40:38 -0800 Subject: [PATCH] update dynamic shape detection (#7605) Summary: Updating Dynamic Shape Detection re: https://github.com/pytorch/executorch/issues/5794 Reviewed By: digantdesai Differential Revision: D68036835 --- backends/xnnpack/operators/op_squeeze.py | 6 ++++-- exir/backend/utils.py | 6 ++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/backends/xnnpack/operators/op_squeeze.py b/backends/xnnpack/operators/op_squeeze.py index 8ed5aa36ae6..7a21fe9e551 100644 --- a/backends/xnnpack/operators/op_squeeze.py +++ b/backends/xnnpack/operators/op_squeeze.py @@ -16,7 +16,9 @@ XNNStaticReshape, XNode, ) + from executorch.backends.xnnpack.utils.utils import check_or_raise, get_input_node +from torch.fx.experimental.symbolic_shapes import free_symbols @register_node_visitor @@ -57,7 +59,7 @@ def define_node( num_dynamic_dims = 0 for dim in dynamic_shape: - if isinstance(dim, torch.SymInt): + if free_symbols(dim): num_dynamic_dims += 1 new_shape.append(0) else: @@ -119,7 +121,7 @@ def define_node( num_dynamic_dims = 0 for dim in dynamic_shape: - if isinstance(dim, torch.SymInt): + if free_symbols(dim): num_dynamic_dims += 1 new_shape.append(0) else: diff --git a/exir/backend/utils.py b/exir/backend/utils.py index 50d1e73fd7b..9487c59a848 100644 --- a/exir/backend/utils.py +++ b/exir/backend/utils.py @@ -23,6 +23,7 @@ from executorch.exir.lowered_backend_module import create_submodule_from_nodes from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param +from torch.fx.experimental.symbolic_shapes import has_free_symbols from torch.fx.node import Node from torch.fx.passes.utils.source_matcher_utils import SourcePartition @@ -424,10 +425,7 @@ def is_shape_dynamic(node: torch.fx.Node) -> bool: Check if the node shape is dynamic. """ - # Shape is dynamic if any of the dimensions don't evaluate to a static value - return "val" in node.meta and any( - isinstance(d, torch.SymInt) for d in node.meta["val"].shape - ) + return has_free_symbols(node.meta["val"].shape) # TODO - style: use templated types