diff --git a/functorch/experimental/_map.py b/functorch/experimental/_map.py index 1f4ff6b35f729..bc6fbe5ff4a9c 100644 --- a/functorch/experimental/_map.py +++ b/functorch/experimental/_map.py @@ -35,7 +35,7 @@ def __call__(self, xs, *args): map = MapWrapper("map", _deprecated_global_ns=True) -map_impl = HigherOrderOperator("map_impl", _deprecated_global_ns=True) +map_impl = HigherOrderOperator("map_impl") dummy_aot_config = AOTConfig( fw_compiler=None, diff --git a/test/export/test_serialize.py b/test/export/test_serialize.py index 5a6018f2a0a78..d3c775ef255ec 100644 --- a/test/export/test_serialize.py +++ b/test/export/test_serialize.py @@ -257,7 +257,7 @@ def _check_graph_nodes(gm1, gm2, _check_meta=True): false_graph1 = getattr(gm1, node1.args[2].target) false_graph2 = getattr(gm2, node2.args[2].target) _check_graph_nodes(false_graph1, false_graph2) - elif node1.target == torch.ops.map_impl: + elif node1.target == torch.ops.higher_order.map_impl: map_graph1 = getattr(gm1, node1.args[0].target) map_graph2 = getattr(gm2, node2.args[0].target) _check_graph_nodes(map_graph1, map_graph2, False) diff --git a/test/functorch/test_control_flow.py b/test/functorch/test_control_flow.py index 9fd4e0025a9eb..2407404672ec1 100644 --- a/test/functorch/test_control_flow.py +++ b/test/functorch/test_control_flow.py @@ -885,7 +885,7 @@ def check_map_count(self, gm, op_count): i = 0 for m in gm.modules(): for node in m.graph.nodes: - if node.op == "call_function" and node.target == torch.ops.map_impl: + if node.op == "call_function" and node.target == torch.ops.higher_order.map_impl: i += 1 self.assertEqual(i, op_count) @@ -1041,7 +1041,7 @@ def count_mutable(gm): c = 0 for node in gm.graph.nodes: if node.op == "call_function": - if node.target == torch.ops.map_impl: + if node.target == torch.ops.higher_order.map_impl: c += count_mutable(getattr(gm, str(node.args[0]))) elif schema := getattr(node.target, "_schema", None): c += int(schema.is_mutable) @@ -1446,7 +1446,8 @@ def wrapper(*args, **kwargs): self.assertExpectedInline(gm.code.strip(), """\ def forward(self, pred_1, x_1): body_graph_0 = self.body_graph_0 - map_impl = torch.ops.map_impl(body_graph_0, 1, x_1, pred_1); body_graph_0 = x_1 = pred_1 = None + map_impl = torch.ops.higher_order.map_impl(body_graph_0, 1, x_1, pred_1);\ + body_graph_0 = x_1 = pred_1 = None getitem = map_impl[0]; map_impl = None return getitem""") self.assertExpectedInline(gm.body_graph_0.code.strip(), """\