diff --git a/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py b/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py index ad46d9d69..8ff170992 100644 --- a/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py +++ b/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py @@ -12,7 +12,7 @@ import unittest from dataclasses import dataclass from functools import partial -from typing import cast, List, Optional, Tuple, Type +from typing import cast, List, Optional, Tuple, Type, Union from unittest.mock import MagicMock import torch @@ -1366,6 +1366,54 @@ def test_pipeline_preproc_fwd_values_cached(self) -> None: ["_preproc_module", "preproc_nonweighted", "preproc_weighted"], ) + # pyre-ignore + @unittest.skipIf( + not torch.cuda.is_available(), + "Not enough GPUs, this test requires at least one GPU", + ) + def test_nested_preproc(self) -> None: + """ + If preproc module is nested, we should still be able to pipeline it + """ + extra_input = ModelInput.generate( + tables=self.tables, + weighted_tables=self.weighted_tables, + batch_size=self.batch_size, + world_size=1, + num_float_features=10, + randomize_indices=False, + )[0].to(self.device) + + preproc_module = TestNegSamplingModule( + extra_input=extra_input, + ) + model = self._setup_model(preproc_module=preproc_module) + + class ParentModule(nn.Module): + def __init__( + self, + nested_model: nn.Module, + ) -> None: + super().__init__() + self.nested_model = nested_model + + def forward( + self, + input: ModelInput, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + return self.nested_model(input) + + model = ParentModule(model) + + pipelined_model, pipeline = self._check_output_equal( + model, + self.sharding_type, + ) + + # Check that both EC and EBC pipelined + self.assertEqual(len(pipeline._pipelined_modules), 2) + self.assertEqual(len(pipeline._pipelined_preprocs), 1) + class EmbeddingTrainPipelineTest(TrainPipelineSparseDistTestBase): @unittest.skipIf( diff --git a/torchrec/distributed/train_pipeline/utils.py b/torchrec/distributed/train_pipeline/utils.py index 50c08eedb..4e28a12ba 100644 --- a/torchrec/distributed/train_pipeline/utils.py +++ b/torchrec/distributed/train_pipeline/utils.py @@ -33,6 +33,10 @@ from torch import distributed as dist from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.fx.immutable_collections import ( + immutable_dict as fx_immutable_dict, + immutable_list as fx_immutable_list, +) from torch.fx.node import Node from torch.profiler import record_function from torchrec.distributed.dist_data import KJTAllToAll, KJTAllToAllTensorsAwaitable @@ -156,6 +160,9 @@ class ArgInfo: input_attrs (List[str]): attributes of input batch, e.g. `batch.attr1.attr2` will produce ["attr1", "attr2"]. is_getitems (List[bool]): `batch[attr1].attr2` will produce [True, False]. + preproc_modules (List[Optional[PipelinedPreproc]]): list of torch.nn.Modules that + transform the input batch. + constants: constant arguments that are passed to preproc modules. name (Optional[str]): name for kwarg of pipelined forward() call or None for a positional arg. """ @@ -164,6 +171,7 @@ class ArgInfo: is_getitems: List[bool] # recursive dataclass as preproc_modules.args -> arginfo.preproc_modules -> so on preproc_modules: List[Optional["PipelinedPreproc"]] + constants: List[Optional[object]] name: Optional[str] @@ -178,10 +186,16 @@ def _build_args_kwargs( for arg_info in fwd_args: if arg_info.input_attrs: arg = initial_input - for attr, is_getitem, preproc_mod in zip( - arg_info.input_attrs, arg_info.is_getitems, arg_info.preproc_modules + for attr, is_getitem, preproc_mod, obj in zip( + arg_info.input_attrs, + arg_info.is_getitems, + arg_info.preproc_modules, + arg_info.constants, ): - if preproc_mod is not None: + if obj is not None: + arg = obj + break + elif preproc_mod is not None: # preproc will internally run the same logic recursively # if its args are derived from other preproc modules # we can get all inputs to preproc mod based on its recorded args_info + arg passed to it @@ -682,6 +696,46 @@ def _check_preproc_pipelineable( return True +def _find_preproc_module_recursive( + module: torch.nn.Module, + preproc_module_fqn: str, +) -> Optional[torch.nn.Module]: + """ + Finds the preproc module in the model. + """ + for name, child in module.named_modules(): + if name == preproc_module_fqn: + return child + return None + + +def _swap_preproc_module_recursive( + module: torch.nn.Module, + to_swap_module: torch.nn.Module, + preproc_module_fqn: str, + path: str = "", +) -> torch.nn.Module: + """ + Swaps the preproc module in the model. + """ + if isinstance(module, PipelinedPreproc): + return module + + if path == preproc_module_fqn: + return to_swap_module + + for name, child in module.named_children(): + child = _swap_preproc_module_recursive( + child, + to_swap_module, + preproc_module_fqn, + path + "." + name if path else name, + ) + setattr(module, name, child) + + return module + + def _get_node_args_helper( model: torch.nn.Module, # pyre-ignore @@ -695,13 +749,23 @@ def _get_node_args_helper( Goes through the args/kwargs of a node and arranges them into a list of `ArgInfo`s. It also counts the number of (args + kwargs) found. """ - arg_info_list = [ArgInfo([], [], [], None) for _ in range(len(arguments))] + arg_info_list = [ArgInfo([], [], [], [], None) for _ in range(len(arguments))] for arg, arg_info in zip(arguments, arg_info_list): if arg is None: num_found += 1 continue while True: if not isinstance(arg, torch.fx.Node): + if pipeline_preproc: + arg_info.input_attrs.insert(0, "") + arg_info.is_getitems.insert(0, False) + arg_info.preproc_modules.insert(0, None) + if isinstance(arg, (fx_immutable_dict, fx_immutable_list)): + # Make them mutable again, in case in-place updates are made + arg_info.constants.insert(0, arg.copy()) + else: + arg_info.constants.insert(0, arg) + num_found += 1 break child_node = arg @@ -719,11 +783,13 @@ def _get_node_args_helper( arg_info.input_attrs.append(key) arg_info.is_getitems.append(False) arg_info.preproc_modules.append(None) + arg_info.constants.append(None) else: # no-op arg_info.input_attrs.insert(0, "") arg_info.is_getitems.insert(0, False) arg_info.preproc_modules.insert(0, None) + arg_info.constants.insert(0, None) num_found += 1 break @@ -740,6 +806,7 @@ def _get_node_args_helper( arg_info.input_attrs.insert(0, child_node.args[1]) arg_info.is_getitems.insert(0, False) arg_info.preproc_modules.insert(0, None) + arg_info.constants.insert(0, None) arg = child_node.args[0] elif ( child_node.op == "call_function" @@ -754,6 +821,7 @@ def _get_node_args_helper( arg_info.input_attrs.insert(0, child_node.args[1]) arg_info.is_getitems.insert(0, True) arg_info.preproc_modules.insert(0, None) + arg_info.constants.insert(0, None) arg = child_node.args[0] elif ( child_node.op == "call_function" @@ -795,10 +863,13 @@ def _get_node_args_helper( arg_info.input_attrs.insert(0, child_node.args[1]) arg_info.is_getitems.insert(0, True) arg_info.preproc_modules.insert(0, None) + arg_info.constants.insert(0, None) arg = child_node.args[0] elif child_node.op == "call_module": preproc_module_fqn = str(child_node.target) - preproc_module = getattr(model, preproc_module_fqn, None) + preproc_module = _find_preproc_module_recursive( + model, preproc_module_fqn + ) if not pipeline_preproc: logger.warning( @@ -816,6 +887,7 @@ def _get_node_args_helper( arg_info.is_getitems.insert(0, False) pipelined_preprocs.add(preproc_module) arg_info.preproc_modules.insert(0, preproc_module) + arg_info.constants.insert(0, None) num_found += 1 break @@ -858,12 +930,15 @@ def _get_node_args_helper( ) # module swap - setattr(model, preproc_module_fqn, pipelined_preproc_module) + _swap_preproc_module_recursive( + model, pipelined_preproc_module, preproc_module_fqn + ) arg_info.input_attrs.insert(0, "") # dummy value arg_info.is_getitems.insert(0, False) pipelined_preprocs.add(pipelined_preproc_module) arg_info.preproc_modules.insert(0, pipelined_preproc_module) + arg_info.constants.insert(0, None) num_found += 1