Skip to content

Commit

Permalink
[HigherOrderOp] Move map_impl to torch.ops.higher_order (pytorch#111404)
Browse files Browse the repository at this point in the history
The purpose of this pr is as titled. Because of some misusage of ghstack, ghimport, and export to github from internal, the stack of pytorch#111092 is a mess. I'll try to land them one by one. This is a replacement for pytorch#111092 and pytorch#111400.

Pull Request resolved: pytorch#111404
Approved by: https://github.com/tugsbayasgalan, https://github.com/zou3519
  • Loading branch information
ydwu4 authored and pytorchmergebot committed Oct 26, 2023
1 parent f6f81a5 commit 8bc0b38
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 5 deletions.
2 changes: 1 addition & 1 deletion functorch/experimental/_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __call__(self, xs, *args):


map = MapWrapper("map", _deprecated_global_ns=True)
map_impl = HigherOrderOperator("map_impl", _deprecated_global_ns=True)
map_impl = HigherOrderOperator("map_impl")

dummy_aot_config = AOTConfig(
fw_compiler=None,
Expand Down
2 changes: 1 addition & 1 deletion test/export/test_serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ def _check_graph_nodes(gm1, gm2, _check_meta=True):
false_graph1 = getattr(gm1, node1.args[2].target)
false_graph2 = getattr(gm2, node2.args[2].target)
_check_graph_nodes(false_graph1, false_graph2)
elif node1.target == torch.ops.map_impl:
elif node1.target == torch.ops.higher_order.map_impl:
map_graph1 = getattr(gm1, node1.args[0].target)
map_graph2 = getattr(gm2, node2.args[0].target)
_check_graph_nodes(map_graph1, map_graph2, False)
Expand Down
7 changes: 4 additions & 3 deletions test/functorch/test_control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -885,7 +885,7 @@ def check_map_count(self, gm, op_count):
i = 0
for m in gm.modules():
for node in m.graph.nodes:
if node.op == "call_function" and node.target == torch.ops.map_impl:
if node.op == "call_function" and node.target == torch.ops.higher_order.map_impl:
i += 1
self.assertEqual(i, op_count)

Expand Down Expand Up @@ -1041,7 +1041,7 @@ def count_mutable(gm):
c = 0
for node in gm.graph.nodes:
if node.op == "call_function":
if node.target == torch.ops.map_impl:
if node.target == torch.ops.higher_order.map_impl:
c += count_mutable(getattr(gm, str(node.args[0])))
elif schema := getattr(node.target, "_schema", None):
c += int(schema.is_mutable)
Expand Down Expand Up @@ -1446,7 +1446,8 @@ def wrapper(*args, **kwargs):
self.assertExpectedInline(gm.code.strip(), """\
def forward(self, pred_1, x_1):
body_graph_0 = self.body_graph_0
map_impl = torch.ops.map_impl(body_graph_0, 1, x_1, pred_1); body_graph_0 = x_1 = pred_1 = None
map_impl = torch.ops.higher_order.map_impl(body_graph_0, 1, x_1, pred_1);\
body_graph_0 = x_1 = pred_1 = None
getitem = map_impl[0]; map_impl = None
return getitem""")
self.assertExpectedInline(gm.body_graph_0.code.strip(), """\
Expand Down

0 comments on commit 8bc0b38

Please sign in to comment.