forked from pytorch/torchrec
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
short circuit the flatten/unflatten between EBC and KTRegroupAsDict modules (pytorch#2393)

Summary:
X-link: pytorch/pytorch#136045
Pull Request resolved: pytorch#2393

# context
* For the root cause and background, please refer to this [post](https://fb.workplace.com/groups/1028545332188949/permalink/1042204770823005/).
* The basic idea of this diff is to **short circuit the pytree flatten-unflatten function pairs** between two preserved modules, i.e., EBC/fpEBC and KTRegroupAsDict. NOTE: There could be multiple EBCs and one single KTRegroupAsDict, as shown in the [pic](https://fburl.com/gslide/lcyt8eh3) {F1864810545}
* Short-circuiting the EBC-KTRegroupAsDict pairs is a special case, and it is required in most cases because of the EBC key-order issue with distributed table lookup.
* All of the operations are hidden behind a control flag, `short_circuit_pytree_ebc_regroup`, on the torchrec main API call `decapsulate_ir_modules`, which should only be visible to the infra layer, not to users.

# details
* The `_short_circuit_pytree_ebc_regroup` function finds all the EBC/fpEBC and KTRegroupAsDict modules in an unflattened module. It retrieves their fqns and sorts them into in_fqns (regroup_fqns) and out_fqns (ebc_fqns). Because the fpEBC is currently swapped as a whole, some extra fqn logic filters out any EBC that belongs to a parent fpEBC.
* A util function, `prune_pytree_flatten_unflatten`, removes the incoming and outgoing pytree flatten/unflatten function calls in the graph module, based on the given fqns.

WARNING: The flag `short_circuit_pytree_ebc_regroup` should be turned on if EBCs are used and EBC sharding is needed. Assertions are also added that fail if no `KTRegroupAsDict` module can be found, or if `finalize_interpreter_modules` is not `True`.

# additional changes
* Absorb the `finalize_interpreter_modules` process inside the torchrec main API `decapsulate_ir_modules`.
* Set `graph.owning_module` in export.unflatten, as required by the graph modification.
* Add one more layer of `sparse_module` to more closely mimic the APF model structure.

Differential Revision: D62606738
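
The fqn handling described under "details" can be illustrated with a small standalone sketch. This is not the code from this commit: the function name `split_fqns`, the string labels standing in for module type checks, and the example fqns are all hypothetical, and the sketch only assumes that an EBC nested under an fpEBC has the fpEBC's fqn as a dotted prefix.

```python
from typing import Dict, List, Tuple


def split_fqns(modules: Dict[str, str]) -> Tuple[List[str], List[str]]:
    """Classify module fqns into (regroup_fqns, ebc_fqns).

    `modules` maps an fqn to a hypothetical kind label: "ebc", "fpebc",
    "regroup", or anything else. A real implementation would inspect the
    module types instead of string labels.
    """
    fp_fqns = [fqn for fqn, kind in modules.items() if kind == "fpebc"]
    ebc_fqns = list(fp_fqns)  # an fpEBC is treated as a whole unit
    for fqn, kind in modules.items():
        if kind != "ebc":
            continue
        # Skip an EBC that lives inside an fpEBC (dotted-prefix match).
        if any(fqn.startswith(fp + ".") for fp in fp_fqns):
            continue
        ebc_fqns.append(fqn)
    regroup_fqns = [fqn for fqn, kind in modules.items() if kind == "regroup"]
    return sorted(regroup_fqns), sorted(ebc_fqns)


# Hypothetical module layout, loosely following the extra `sparse` layer.
example = {
    "sparse.ebc": "ebc",
    "sparse.fpebc": "fpebc",
    "sparse.fpebc._embedding_bag_collection": "ebc",
    "sparse.regroup": "regroup",
    "dense": "other",
}
in_fqns, out_fqns = split_fqns(example)
print(in_fqns, out_fqns)
```

In this sketch the nested EBC is dropped from the result because its parent fpEBC already represents it, which mirrors the filtering step the commit message describes.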
1 parent 15c912e, commit 0ce7346
Showing 2 changed files with 128 additions and 3 deletions.