short circuit the flatten/unflatten between EBC and KTRegroupAsDict modules (pytorch#2393)

Summary:
X-link: pytorch/pytorch#136045

Pull Request resolved: pytorch#2393

# context
* for the root cause and background please refer to this [post](https://fb.workplace.com/groups/1028545332188949/permalink/1042204770823005/)
* the basic idea of this diff is to **short circuit the pytree flatten-unflatten function pairs** between two preserved modules, i.e., EBC/fpEBC and KTRegroupAsDict.
NOTE: There can be multiple EBCs but only a single KTRegroupAsDict, as shown in the [pic](https://fburl.com/gslide/lcyt8eh3) {F1864810545}
* short-circuiting the EBC-KTRegroupAsDict pairs is a special case, and a must in most cases, due to the EBC key-order issue with distributed table lookup.
* hide all the operations behind a control flag `short_circuit_pytree_ebc_regroup` on the torchrec main API `decapsulate_ir_modules`; the flag should only be visible to the infra layer, not to users (see the sketch after this list).
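
A minimal sketch of the intended call pattern (it mirrors the updated test below; `model`, `id_list_features`, and `sparse_fqns` are assumed to come from the usual torchrec IR serialization setup and are elided here):

```python
import torch

from torchrec.ir.serializer import JsonSerializer
from torchrec.ir.utils import decapsulate_ir_modules

# export and unflatten the model with the sparse-module call signatures preserved
ep = torch.export.export(
    model,
    (id_list_features,),
    preserve_module_call_signature=tuple(sparse_fqns),
)
unflatten_ep = torch.export.unflatten(ep)

# decapsulate with the new infra-only flag; finalize_interpreter_modules must
# also be True when the short circuit is enabled
deserialized_model = decapsulate_ir_modules(
    unflatten_ep,
    JsonSerializer,
    short_circuit_pytree_ebc_regroup=True,
    finalize_interpreter_modules=True,
)
```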

# details
* The `_short_circuit_pytree_ebc_regroup` function finds all the EBC/fpEBC and KTRegroupAsDict modules in an unflattened module, retrieves their FQNs, and sorts them into in_fqns (regroup_fqns) and out_fqns (ebc_fqns). Because the fpEBC is currently swapped as a whole, extra FQN logic filters out any EBC that belongs to an enclosing fpEBC.
* A util function `prune_pytree_flatten_unflatten` removes the incoming and outgoing pytree flatten/unflatten calls in the graph module, based on the given FQNs (see the sketch after this list).
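
Schematically, the in-coming side of the short circuit does something like the following; this is a simplified sketch of the pattern matched by `prune_pytree_flatten_unflatten` in the diff below (the helper name is made up for illustration, and the FQN lookup across nested submodules is omitted), not the production code:

```python
import operator

import torch
from torch.fx import GraphModule


def _bypass_unflatten_into_regroup(gm: GraphModule, regroup_fqn: str) -> None:
    """Sketch: feed the flat input tensors directly into the KTRegroupAsDict
    call_module node, skipping getitem -> getitem -> pytree.tree_unflatten."""
    for node in gm.graph.nodes:
        if node.op != "call_module" or node.target != regroup_fqn:
            continue
        getitem_outer = node.args[0]
        getitem_inner = getitem_outer.args[0]
        unflatten = getitem_inner.args[0]
        assert getitem_outer.target == operator.getitem
        assert getitem_inner.target == operator.getitem
        assert unflatten.target == torch.utils._pytree.tree_unflatten
        # rewire: the regroup module now consumes the flat inputs directly
        node.args = (unflatten.args[0],)
    # drop the now-unused unflatten/getitem nodes
    gm.graph.eliminate_dead_code()
```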

WARNING: The flag `short_circuit_pytree_ebc_regroup` should be turned on if EBCs are used and EBC sharding is needed. Assertions are also added for the cases where no `KTRegroupAsDict` module can be found or `finalize_interpreter_modules` is not `True`.

# additional changes
* absorb the `finalize_interpreter_modules` process into the torchrec main API `decapsulate_ir_modules`.
* set `graph.owning_module` in export.unflatten, as required by the graph modification.
* add one more layer of `sparse_module` to closely mimic the APF model structure.

Differential Revision: D62606738
TroyGarden authored and facebook-github-bot committed Sep 13, 2024
1 parent 15c912e commit 0ce7346
Showing 2 changed files with 128 additions and 3 deletions.
23 changes: 20 additions & 3 deletions torchrec/ir/tests/test_serializer.py
@@ -557,7 +557,7 @@ def test_key_order_with_ebc_and_regroup(self) -> None:
ebc2.load_state_dict(ebc1.state_dict())
regroup = KTRegroupAsDict([["f1", "f3"], ["f2"]], ["odd", "even"])

class myModel(nn.Module):
class mySparse(nn.Module):
def __init__(self, ebc, regroup):
super().__init__()
self.ebc = ebc
@@ -569,6 +569,17 @@ def forward(
) -> Dict[str, torch.Tensor]:
return self.regroup([self.ebc(features)])

class myModel(nn.Module):
def __init__(self, ebc, regroup):
super().__init__()
self.sparse = mySparse(ebc, regroup)

def forward(
self,
features: KeyedJaggedTensor,
) -> Dict[str, torch.Tensor]:
return self.sparse(features)

model = myModel(ebc1, regroup)
eager_out = model(id_list_features)

@@ -582,11 +593,17 @@ def forward(
preserve_module_call_signature=(tuple(sparse_fqns)),
)
unflatten_ep = torch.export.unflatten(ep)
deserialized_model = decapsulate_ir_modules(unflatten_ep, JsonSerializer)
deserialized_model = decapsulate_ir_modules(
unflatten_ep,
JsonSerializer,
short_circuit_pytree_ebc_regroup=True,
finalize_interpreter_modules=True,
)

# we export the model with ebc1 and unflatten the model,
# and then swap with ebc2 (you can think of this as the sharding process
# resulting in a shardedEBC), so that we can mimic the key-order change
deserialized_model.ebc = ebc2
deserialized_model.sparse.ebc = ebc2

deserialized_out = deserialized_model(id_list_features)
for key in eager_out.keys():
108 changes: 108 additions & 0 deletions torchrec/ir/utils.py
@@ -10,6 +10,7 @@
#!/usr/bin/env python3

import logging
import operator
from collections import defaultdict
from typing import Dict, List, Optional, Tuple, Type

@@ -18,7 +19,12 @@
from torch import nn
from torch.export import Dim, ShapesCollection
from torch.export.dynamic_shapes import _Dim as DIM
from torch.export.unflatten import InterpreterModule
from torch.fx import Node
from torchrec.ir.types import SerializerInterface
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.modules.fp_embedding_modules import FeatureProcessedEmbeddingBagCollection
from torchrec.modules.regroup import KTRegroupAsDict
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor


@@ -129,6 +135,8 @@ def decapsulate_ir_modules(
module: nn.Module,
serializer: Type[SerializerInterface] = DEFAULT_SERIALIZER_CLS,
device: Optional[torch.device] = None,
finalize_interpreter_modules: bool = False,
short_circuit_pytree_ebc_regroup: bool = False,
) -> nn.Module:
"""
Takes a module and decapsulate its embedding modules by retrieving the buffer.
@@ -147,6 +155,16 @@
# we use "ir_metadata" as a convention to identify the deserializable module
if "ir_metadata" in dict(module.named_buffers()):
module = serializer.decapsulate_module(module, device)

if short_circuit_pytree_ebc_regroup:
module = _short_circuit_pytree_ebc_regroup(module)
assert finalize_interpreter_modules, "need finalize_interpreter_modules=True"

if finalize_interpreter_modules:
for mod in module.modules():
if isinstance(mod, InterpreterModule):
mod.finalize()

return module


@@ -233,3 +251,93 @@ def move_to_copy_nodes_to_device(
nodes.kwargs = new_kwargs

return unflattened_module


def _short_circuit_pytree_ebc_regroup(module: nn.Module) -> nn.Module:
"""
Bypass pytree flatten and unflatten function between EBC and KTRegroupAsDict to avoid key-order issue.
https://fb.workplace.com/groups/1028545332188949/permalink/1042204770823005/
EBC ==> (out-going) pytree.flatten ==> tensors and specs ==> (in-coming) pytree.unflatten ==> KTRegroupAsDict
"""
ebc_fqns: List[str] = []
regroup_fqns: List[str] = []
for fqn, m in module.named_modules():
if isinstance(m, FeatureProcessedEmbeddingBagCollection):
ebc_fqns.append(fqn)
elif isinstance(m, EmbeddingBagCollection):
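# skip an EBC nested under an fpEBC recorded above (the fpEBC is swapped as a whole)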
if len(ebc_fqns) > 0 and fqn.startswith(ebc_fqns[-1]):
continue
ebc_fqns.append(fqn)
elif isinstance(m, KTRegroupAsDict):
regroup_fqns.append(fqn)
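# if exactly one of EBC / KTRegroupAsDict is present, leave the graph unchanged and warn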
if (len(ebc_fqns) == 0) != (len(regroup_fqns) == 0):
logger.warning("Perf impact if EBC and KTRegroupAsDict are not used together.")
return module
else:
return prune_pytree_flatten_unflatten(
module, in_fqns=regroup_fqns, out_fqns=ebc_fqns
)


def prune_pytree_flatten_unflatten(
module: nn.Module, in_fqns: List[str], out_fqns: List[str]
) -> nn.Module:
"""
Remove pytree flatten and unflatten function between the given in_fqns and out_fqns.
"preserved module" ==> (out-going) pytree.flatten ==> [tensors and specs]
[tensors and specs] ==> (in-coming) pytree.unflatten ==> "preserved module"
"""

def _get_graph_node(mod: nn.Module, fqn: str) -> Tuple[nn.Module, Node]:
for node in mod.graph.nodes:
if node.op == "call_module" and node.target == fqn:
return mod, node
assert "." in fqn, f"can't find {fqn} in the graph of {mod}"
curr, fqn = fqn.split(".", maxsplit=1)
mod = getattr(mod, curr)
return _get_graph_node(mod, fqn)

# remove tree_unflatten from the in_fqns (in-coming nodes)
for fqn in in_fqns:
submodule, node = _get_graph_node(module, fqn)
assert len(node.args) == 1
getitem_getitem: Node = node.args[0] # pyre-ignore[9]
assert (
getitem_getitem.op == "call_function"
and getitem_getitem.target == operator.getitem
)
tree_unflatten_getitem = node.args[0].args[0] # pyre-ignore[16]
assert (
tree_unflatten_getitem.op == "call_function"
and tree_unflatten_getitem.target == operator.getitem
)
tree_unflatten = tree_unflatten_getitem.args[0]
assert (
tree_unflatten.op == "call_function"
and tree_unflatten.target == torch.utils._pytree.tree_unflatten
)
logger.info(f"Removing tree_unflatten from {fqn}")
input_nodes = tree_unflatten.args[0]
node.args = (input_nodes,)
submodule.graph.eliminate_dead_code()

# remove tree_flatten_spec from the out_fqns (out-going nodes)
for fqn in out_fqns:
submodule, node = _get_graph_node(module, fqn)
users = list(node.users.keys())
assert (
len(users) == 1
and users[0].op == "call_function"
and users[0].target == torch.fx._pytree.tree_flatten_spec
)
tree_flatten_users = list(users[0].users.keys())
assert (
len(tree_flatten_users) == 1
and tree_flatten_users[0].op == "call_function"
and tree_flatten_users[0].target == operator.getitem
)
logger.info(f"Removing tree_flatten_spec from {fqn}")
getitem_node = tree_flatten_users[0]
getitem_node.replace_all_uses_with(node)
submodule.graph.eliminate_dead_code()
return module
