From 7a7790be63935f3b3fbddf21dcd8e13e3d09a053 Mon Sep 17 00:00:00 2001 From: Qiang Zhang Date: Fri, 19 Jul 2024 10:53:50 -0700 Subject: [PATCH] TGIF check untraced ShardedQuantEmbeddingCollection and ShardedQuantEmbeddingBagCollection before torchscript (#2237) Summary: Pull Request resolved: https://github.com/pytorch/torchrec/pull/2237 We saw several cases of user support request in TorchRec Users group where SQEBC or SQEC is not symbolic traced before torchscript. Here we check the module in TGIF before torchscript it. Give a more user friendly error message. Reviewed By: jingsh Differential Revision: D59863101 fbshipit-source-id: b434d2f49017689c4fdf65da7d467c3bff5a29e8 --- torchrec/distributed/infer_utils.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/torchrec/distributed/infer_utils.py b/torchrec/distributed/infer_utils.py index 5cf6a7890..8171570fb 100644 --- a/torchrec/distributed/infer_utils.py +++ b/torchrec/distributed/infer_utils.py @@ -175,3 +175,20 @@ def _recursive_get_module( ) return trec_modules + + +def get_non_scriptable_trec_module( + model: torch.nn.Module, +) -> Dict[str, torch.nn.Module]: + """ + Get all targeted TorchRec modules in that model that is not torchsciptable before trace. + Args: + model (torch.nn.Module): The input module to search for TREC modules. + """ + return get_all_torchrec_modules( + model, + trec_module_class_types=[ + ShardedQuantEmbeddingBagCollection, + ShardedQuantEmbeddingCollection, + ], + )