Skip to content

Commit

Permalink
Fix list assumption about permute arguments
Browse files Browse the repository at this point in the history
Differential Revision: D67765363

Pull Request resolved: #7472
  • Loading branch information
dulinriley authored Jan 2, 2025
1 parent a861294 commit 2600cc8
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
6 changes: 3 additions & 3 deletions backends/cadence/aot/compiler_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,16 +129,16 @@ def get_transposed_dims(node: torch.fx.Node, dims: List[int]) -> List[int]:


# Capture the effect of permute op on incoming dimension order
def get_permuted_dims(node: torch.fx.Node, dims: Optional[List[int]]) -> List[int]:
def get_permuted_dims(node: torch.fx.Node, dims: Optional[Sequence[int]]) -> List[int]:
"""
Given a permute node, and the incoming dimension ordering of the input
tensor to the permute node, return the net effect of permute op on the
dimension order.
"""
assert node.target == exir_ops.edge.aten.permute_copy.default
# Permute each index of the dimension ordering (dims)
permute_dims = node.args[1]
assert isinstance(permute_dims, List)
# pyre-fixme[6]: This combined typecheck isn't supported yet.
permute_dims: List[int] = list(node.args[1])
assert all(isinstance(x, int) for x in permute_dims)
# If the dims is empty, we can simply return the permute order
if not dims:
Expand Down
6 changes: 3 additions & 3 deletions backends/cadence/aot/reorder_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,9 +438,9 @@ def postpone_dequantize_op(self, graph_module: torch.fx.GraphModule) -> bool:
args=(user, *node.args[1:]),
)
dequant_node.meta = user.meta.copy()
# Remove meta["debug_handle"] on new node. Reassign it at the
# caller level by calling generate_missing_debug_handles
dequant_node.meta.pop("debug_handle")
# Remove meta["debug_handle"] on new node if it exists.
# Reassign it at the caller level by calling generate_missing_debug_handles
dequant_node.meta.pop("debug_handle", None)
user.replace_all_uses_with(dequant_node)
dequant_node.args = (user, *node.args[1:])

Expand Down

0 comments on commit 2600cc8

Please sign in to comment.