Skip to content

Commit

Permalink
TGIF check untraced ShardedQuantEmbeddingCollection and ShardedQuantE…
Browse files Browse the repository at this point in the history
…mbeddingBagCollection before torchscript (pytorch#2237)

Summary:
Pull Request resolved: pytorch#2237

We saw several cases of user support request in TorchRec Users group where SQEBC or SQEC is not symbolic traced before torchscript. Here we check the module in TGIF before torchscript it. Give a more user friendly error message.

Reviewed By: jingsh

Differential Revision: D59863101

fbshipit-source-id: b434d2f49017689c4fdf65da7d467c3bff5a29e8
  • Loading branch information
gnahzg authored and facebook-github-bot committed Jul 19, 2024
1 parent 4c98f7b commit 7a7790b
Showing 1 changed file with 17 additions and 0 deletions.
17 changes: 17 additions & 0 deletions torchrec/distributed/infer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,3 +175,20 @@ def _recursive_get_module(
)

return trec_modules


def get_non_scriptable_trec_module(
model: torch.nn.Module,
) -> Dict[str, torch.nn.Module]:
"""
Get all targeted TorchRec modules in that model that is not torchsciptable before trace.
Args:
model (torch.nn.Module): The input module to search for TREC modules.
"""
return get_all_torchrec_modules(
model,
trec_module_class_types=[
ShardedQuantEmbeddingBagCollection,
ShardedQuantEmbeddingCollection,
],
)

0 comments on commit 7a7790b

Please sign in to comment.