diff --git a/backends/apple/mps/mps_preprocess.py b/backends/apple/mps/mps_preprocess.py index 8362774fa97..b798f63a6f4 100644 --- a/backends/apple/mps/mps_preprocess.py +++ b/backends/apple/mps/mps_preprocess.py @@ -32,6 +32,8 @@ CompileSpec, PreprocessResult, ) + +from executorch.exir.passes.memory_format_ops_pass import DimOrderOpsRevertPass from torch.export.exported_program import ExportedProgram FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" @@ -83,6 +85,9 @@ def preprocess( # FlatBuffer graph, process the `output` nodes and add their id to # the `output_ids` array in the schema. + # TODO: Remove this once we have a better support for the dim-order ops. + edge_program = DimOrderOpsRevertPass()(edge_program) + mps_graph = MPSGraph( version="0", mps_nodes=[], diff --git a/backends/apple/mps/operators/constant_ops.py b/backends/apple/mps/operators/constant_ops.py index dacb09215cb..9b3e3effb94 100644 --- a/backends/apple/mps/operators/constant_ops.py +++ b/backends/apple/mps/operators/constant_ops.py @@ -78,6 +78,15 @@ def define_node( ) ) +@register_node_visitor +class ToDimOrderEmptyVisitor(NodeVisitor): + target = ["exir_ops.edge.dim_order_ops._to_dim_order_copy.default"] + + def __init__(self, *args) -> None: + # We should never get here, because DimOrderOpsRevertPass replaces this with an aten.empty.memory_format op + # But if we do, we can't handle it ATM, so raise an exception + raise NotImplementedError("exir_ops.edge.dim_order_ops._to_dim_order_copy.default is not supported yet") + @register_node_visitor class FullLikeVisitor(NodeVisitor): diff --git a/backends/apple/mps/operators/op_clone.py b/backends/apple/mps/operators/op_clone.py index 2310ae02da7..3803fdf180b 100644 --- a/backends/apple/mps/operators/op_clone.py +++ b/backends/apple/mps/operators/op_clone.py @@ -33,3 +33,12 @@ def define_node( ) input_id = self.define_tensor(get_input_node(node, 0), mps_graph) self.tensor_to_id[node] = input_id + +@register_node_visitor +class ToDimOrderCopyVisitor(NodeVisitor): + target = ["exir_ops.edge.dim_order_ops._to_dim_order_copy.default"] + + def __init__(self, *args) -> None: + # We should never get here, because DimOrderOpsRevertPass replaces this with an aten._to_copy op + # But if we do, we can't handle it ATM, so raise an exception + raise NotImplementedError("exir_ops.edge.dim_order_ops._to_dim_order_copy.default is not supported yet")