diff --git a/torchrec/ir/tests/test_serializer.py b/torchrec/ir/tests/test_serializer.py
index 8b6926553..805eea0bc 100644
--- a/torchrec/ir/tests/test_serializer.py
+++ b/torchrec/ir/tests/test_serializer.py
@@ -557,7 +557,7 @@ def test_key_order_with_ebc_and_regroup(self) -> None:
         ebc2.load_state_dict(ebc1.state_dict())
         regroup = KTRegroupAsDict([["f1", "f3"], ["f2"]], ["odd", "even"])
 
-        class myModel(nn.Module):
+        class mySparse(nn.Module):
             def __init__(self, ebc, regroup):
                 super().__init__()
                 self.ebc = ebc
@@ -569,6 +569,17 @@ def forward(
             ) -> Dict[str, torch.Tensor]:
                 return self.regroup([self.ebc(features)])
 
+        class myModel(nn.Module):
+            def __init__(self, ebc, regroup):
+                super().__init__()
+                self.sparse = mySparse(ebc, regroup)
+
+            def forward(
+                self,
+                features: KeyedJaggedTensor,
+            ) -> Dict[str, torch.Tensor]:
+                return self.sparse(features)
+
         model = myModel(ebc1, regroup)
         eager_out = model(id_list_features)
 
@@ -582,11 +593,17 @@ def forward(
             preserve_module_call_signature=(tuple(sparse_fqns)),
         )
         unflatten_ep = torch.export.unflatten(ep)
-        deserialized_model = decapsulate_ir_modules(unflatten_ep, JsonSerializer)
+        deserialized_model = decapsulate_ir_modules(
+            unflatten_ep,
+            JsonSerializer,
+            short_circuit_pytree_ebc_regroup=True,
+            finalize_interpreter_modules=True,
+        )
+
         # we export the model with ebc1 and unflatten the model,
         # and then swap with ebc2 (you can think of this as the sharding process
         # resulting in a shardedEBC), so that we can mimic the key-order change
-        deserialized_model.ebc = ebc2
+        deserialized_model.sparse.ebc = ebc2
         deserialized_out = deserialized_model(id_list_features)
 
         for key in eager_out.keys():
diff --git a/torchrec/ir/utils.py b/torchrec/ir/utils.py
index dc4fc8f1f..19f537092 100644
--- a/torchrec/ir/utils.py
+++ b/torchrec/ir/utils.py
@@ -10,6 +10,7 @@
 #!/usr/bin/env python3
 
 import logging
+import operator
 from collections import defaultdict
 from typing import Dict, List, Optional, Tuple, Type
 
@@ -18,7 +19,12 @@
 from torch import nn
 from torch.export import Dim, ShapesCollection
 from torch.export.dynamic_shapes import _Dim as DIM
+from torch.export.unflatten import InterpreterModule
+from torch.fx import Node
 from torchrec.ir.types import SerializerInterface
+from torchrec.modules.embedding_modules import EmbeddingBagCollection
+from torchrec.modules.fp_embedding_modules import FeatureProcessedEmbeddingBagCollection
+from torchrec.modules.regroup import KTRegroupAsDict
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
 
 
@@ -129,6 +135,8 @@ def decapsulate_ir_modules(
     module: nn.Module,
     serializer: Type[SerializerInterface] = DEFAULT_SERIALIZER_CLS,
     device: Optional[torch.device] = None,
+    finalize_interpreter_modules: bool = False,
+    short_circuit_pytree_ebc_regroup: bool = False,
 ) -> nn.Module:
     """
     Takes a module and decapsulate its embedding modules by retrieving the buffer.
@@ -147,6 +155,16 @@ def decapsulate_ir_modules(
     # we use "ir_metadata" as a convention to identify the deserializable module
     if "ir_metadata" in dict(module.named_buffers()):
         module = serializer.decapsulate_module(module, device)
+
+    if short_circuit_pytree_ebc_regroup:
+        module = _short_circuit_pytree_ebc_regroup(module)
+        assert finalize_interpreter_modules, "need finalize_interpreter_modules=True"
+
+    if finalize_interpreter_modules:
+        for mod in module.modules():
+            if isinstance(mod, InterpreterModule):
+                mod.finalize()
+
     return module
 
 
@@ -233,3 +251,93 @@ def move_to_copy_nodes_to_device(
         nodes.kwargs = new_kwargs
 
     return unflattened_module
+
+
+def _short_circuit_pytree_ebc_regroup(module: nn.Module) -> nn.Module:
+    """
+    Bypass the pytree flatten and unflatten functions between EBC and KTRegroupAsDict to avoid the key-order issue.
+    https://fb.workplace.com/groups/1028545332188949/permalink/1042204770823005/
+    EBC ==> (out-going) pytree.flatten ==> tensors and specs ==> (in-coming) pytree.unflatten ==> KTRegroupAsDict
+    """
+    ebc_fqns: List[str] = []
+    regroup_fqns: List[str] = []
+    for fqn, m in module.named_modules():
+        if isinstance(m, FeatureProcessedEmbeddingBagCollection):
+            ebc_fqns.append(fqn)
+        elif isinstance(m, EmbeddingBagCollection):
+            if len(ebc_fqns) > 0 and fqn.startswith(ebc_fqns[-1]):
+                continue
+            ebc_fqns.append(fqn)
+        elif isinstance(m, KTRegroupAsDict):
+            regroup_fqns.append(fqn)
+    if (len(ebc_fqns) == 0) != (len(regroup_fqns) == 0):
+        logger.warning("EBC and KTRegroupAsDict are not used together; skipping the short-circuit (possible perf impact).")
+        return module
+    else:
+        return prune_pytree_flatten_unflatten(
+            module, in_fqns=regroup_fqns, out_fqns=ebc_fqns
+        )
+
+
+def prune_pytree_flatten_unflatten(
+    module: nn.Module, in_fqns: List[str], out_fqns: List[str]
+) -> nn.Module:
+    """
+    Remove the pytree flatten and unflatten functions between the given in_fqns and out_fqns.
+    "preserved module" ==> (out-going) pytree.flatten ==> [tensors and specs]
+    [tensors and specs] ==> (in-coming) pytree.unflatten ==> "preserved module"
+    """
+
+    def _get_graph_node(mod: nn.Module, fqn: str) -> Tuple[nn.Module, Node]:
+        for node in mod.graph.nodes:
+            if node.op == "call_module" and node.target == fqn:
+                return mod, node
+        assert "." in fqn, f"can't find {fqn} in the graph of {mod}"
+        curr, fqn = fqn.split(".", maxsplit=1)
+        mod = getattr(mod, curr)
+        return _get_graph_node(mod, fqn)
+
+    # remove tree_unflatten from the in_fqns (in-coming nodes)
+    for fqn in in_fqns:
+        submodule, node = _get_graph_node(module, fqn)
+        assert len(node.args) == 1
+        getitem_getitem: Node = node.args[0]  # pyre-ignore[9]
+        assert (
+            getitem_getitem.op == "call_function"
+            and getitem_getitem.target == operator.getitem
+        )
+        tree_unflatten_getitem = node.args[0].args[0]  # pyre-ignore[16]
+        assert (
+            tree_unflatten_getitem.op == "call_function"
+            and tree_unflatten_getitem.target == operator.getitem
+        )
+        tree_unflatten = tree_unflatten_getitem.args[0]
+        assert (
+            tree_unflatten.op == "call_function"
+            and tree_unflatten.target == torch.utils._pytree.tree_unflatten
+        )
+        logger.info(f"Removing tree_unflatten from {fqn}")
+        input_nodes = tree_unflatten.args[0]
+        node.args = (input_nodes,)
+        submodule.graph.eliminate_dead_code()
+
+    # remove tree_flatten_spec from the out_fqns (out-going nodes)
+    for fqn in out_fqns:
+        submodule, node = _get_graph_node(module, fqn)
+        users = list(node.users.keys())
+        assert (
+            len(users) == 1
+            and users[0].op == "call_function"
+            and users[0].target == torch.fx._pytree.tree_flatten_spec
+        )
+        tree_flatten_users = list(users[0].users.keys())
+        assert (
+            len(tree_flatten_users) == 1
+            and tree_flatten_users[0].op == "call_function"
+            and tree_flatten_users[0].target == operator.getitem
+        )
+        logger.info(f"Removing tree_flatten_spec from {fqn}")
+        getitem_node = tree_flatten_users[0]
+        getitem_node.replace_all_uses_with(node)
+        submodule.graph.eliminate_dead_code()
+    return module
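
For reference, a minimal end-to-end sketch of how the two new `decapsulate_ir_modules` flags are meant to be used, following the flow of `test_key_order_with_ebc_and_regroup` above. The wrapper modules, table configs, and feature batch are illustrative placeholders rather than part of the patch; the import paths follow the existing `torchrec.ir` tests.

```python
import torch
from typing import Dict

from torch import nn
from torchrec.ir.serializer import JsonSerializer
from torchrec.ir.utils import decapsulate_ir_modules, encapsulate_ir_modules
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.modules.regroup import KTRegroupAsDict
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor


class SparseArch(nn.Module):
    """Illustrative EBC -> KTRegroupAsDict pipeline, like mySparse in the test."""

    def __init__(self) -> None:
        super().__init__()
        tables = [
            EmbeddingBagConfig(name="t1", embedding_dim=4, num_embeddings=10, feature_names=["f1"]),
            EmbeddingBagConfig(name="t2", embedding_dim=4, num_embeddings=10, feature_names=["f2"]),
        ]
        self.ebc = EmbeddingBagCollection(tables=tables, is_weighted=False)
        self.regroup = KTRegroupAsDict([["f1"], ["f2"]], ["odd", "even"])

    def forward(self, features: KeyedJaggedTensor) -> Dict[str, torch.Tensor]:
        return self.regroup([self.ebc(features)])


class Model(nn.Module):
    """Illustrative wrapper, like myModel in the test: the pipeline lives under .sparse."""

    def __init__(self) -> None:
        super().__init__()
        self.sparse = SparseArch()

    def forward(self, features: KeyedJaggedTensor) -> Dict[str, torch.Tensor]:
        return self.sparse(features)


features = KeyedJaggedTensor.from_lengths_sync(
    keys=["f1", "f2"],
    values=torch.tensor([0, 1, 2, 3]),
    lengths=torch.tensor([1, 1, 1, 1]),
)

model = Model()
# replace the EBC/regroup modules with serializable IR stand-ins and collect their FQNs
model, sparse_fqns = encapsulate_ir_modules(model, JsonSerializer)
ep = torch.export.export(
    model,
    (features,),
    strict=False,
    preserve_module_call_signature=tuple(sparse_fqns),
)
unflattened = torch.export.unflatten(ep)

# restore the real modules; the new flags bypass the pytree round trip between
# the EBC and KTRegroupAsDict and then finalize the mutated InterpreterModules
deserialized = decapsulate_ir_modules(
    unflattened,
    JsonSerializer,
    short_circuit_pytree_ebc_regroup=True,
    finalize_interpreter_modules=True,
)
out = deserialized(features)  # Dict[str, torch.Tensor] keyed by "odd" / "even"
```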
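
The graph surgery itself is factored into `prune_pytree_flatten_unflatten`; `_short_circuit_pytree_ebc_regroup` only discovers the relevant FQNs by walking `named_modules()` (preferring an enclosing `FeatureProcessedEmbeddingBagCollection` over its inner EBC) and then delegates. Continuing from the sketch above, the pass could presumably also be invoked directly when the FQNs are already known; the FQNs here are illustrative and match the nested model.

```python
from torchrec.ir.utils import prune_pytree_flatten_unflatten

# in_fqns:  preserved modules whose in-coming pytree.tree_unflatten chain is bypassed
#           (the KTRegroupAsDict call sites)
# out_fqns: preserved modules whose out-going tree_flatten_spec call is bypassed
#           (the EmbeddingBagCollection call sites)
pruned = prune_pytree_flatten_unflatten(
    unflattened,  # the torch.export.unflatten(ep) module from the sketch above
    in_fqns=["sparse.regroup"],
    out_fqns=["sparse.ebc"],
)
```

Because the FX graphs are mutated in place, the enclosing `InterpreterModule`s still have to be finalized before the module is called again, which is what `finalize_interpreter_modules=True` (and the new assert) enforces.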
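
Background on what is being short-circuited: `KTRegroupAsDict` consumes the `KeyedTensor` produced by the EBC and regroups its columns into named dense tensors, so its behavior is tied to the key order of that incoming `KeyedTensor`; this is the key-order sensitivity the test mimics by swapping in `ebc2` after unflattening. A standalone sketch (4-dimensional embeddings assumed, on a standard torchrec install):

```python
import torch

from torchrec.modules.regroup import KTRegroupAsDict
from torchrec.sparse.jagged_tensor import KeyedTensor

# same grouping as the test: "odd" <- f1 + f3, "even" <- f2
regroup = KTRegroupAsDict([["f1", "f3"], ["f2"]], ["odd", "even"])

# a KeyedTensor shaped like an EBC output: batch of 2, three 4-dim keys
kt = KeyedTensor(
    keys=["f1", "f2", "f3"],
    length_per_key=[4, 4, 4],
    values=torch.randn(2, 12),
)
out = regroup([kt])
print(out["odd"].shape, out["even"].shape)  # torch.Size([2, 8]) torch.Size([2, 4])
```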