Allow pipelining of constant values for preproc + handle nested preprocs (pytorch#2342)

Summary:
Pull Request resolved: pytorch#2342

Ran into 2 issues while enabling pipelining for a model:
1) The current pipeline logic for finding and swapping a preproc module only works if the preproc module lives directly on the model. If the preproc is nested inside one of the model's child modules, e.g. `model._sparse_arch._preproc_module`, the logic breaks down: finding the module fails because it used `getattr` on the model, and swapping fails because it used `setattr` on the model. Solution:
   - Replaced `getattr` and `setattr` with `_find_preproc_module_recursive` and `_swap_preproc_module_recursive`, respectively (see the first sketch after this list).

2) The logic doesn't handle the case where an arg to a preproc module is a constant (e.g. `self.model.constant_value`), because we skip args that aren't `torch.fx.Node` values. However, we should be able to pipeline these cases. Solution:
    - Add a new field to `ArgInfo` called `constants` of type `List[Optional[object]]`. After fx tracing, constant collection args show up as fx immutable collections, e.g. `torch.fx.immutable_collections.immutable_dict` for an immutable `Dict`; creating a copy converts them back to their original mutable type, so we capture the copied value in `ArgInfo` (see the second sketch after this list). A potential downside is the extra memory overhead, but for this model in particular the constant was just a small string value.
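
For context, a minimal sketch of the nesting issue in 1): a flat `getattr` on the root model misses a dotted FQN, while walking `named_modules()` (the approach `_find_preproc_module_recursive` takes) finds it. The `SparseArch`/`Preproc` classes and the FQN below are made up for illustration:

```python
import torch.nn as nn


class Preproc(nn.Module):
    def forward(self, x):
        return x


class SparseArch(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self._preproc_module = Preproc()


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self._sparse_arch = SparseArch()


model = Model()
fqn = "_sparse_arch._preproc_module"

# getattr only sees direct attributes of the root module, so a dotted FQN misses it.
print(getattr(model, fqn, None))  # None

# Walking named_modules() compares against the full FQN and finds the nested module.
found = next((m for name, m in model.named_modules() if name == fqn), None)
print(type(found).__name__)  # Preproc
```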
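
And a small sketch for 2), showing why captured constants are copied: `.copy()` on the fx immutable containers is assumed here to follow the plain `dict`/`list` copy behavior (returning a regular mutable container), which is what the summary above and the diff below rely on. The `sampling_rate` key is a made-up example:

```python
from torch.fx.immutable_collections import immutable_dict, immutable_list

cfg = immutable_dict({"sampling_rate": 0.5})
# fx immutable containers disallow in-place mutation after tracing,
# e.g. cfg["sampling_rate"] = 1.0 would raise.

mutable_cfg = cfg.copy()  # dict.copy() on the subclass returns a plain, mutable dict
mutable_cfg["sampling_rate"] = 1.0  # fine

vals = immutable_list([1, 2, 3]).copy()  # plain, mutable list
vals.append(4)
```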

Reviewed By: xing-liu

Differential Revision: D61891459

fbshipit-source-id: fe5d074f8b8937a6154596221a57e2d5213ffc36
sarckk authored and facebook-github-bot committed Aug 28, 2024
1 parent 01cb500 commit 6c8b397
Showing 2 changed files with 130 additions and 7 deletions.
@@ -12,7 +12,7 @@
import unittest
from dataclasses import dataclass
from functools import partial
from typing import cast, List, Optional, Tuple, Type
from typing import cast, List, Optional, Tuple, Type, Union
from unittest.mock import MagicMock

import torch
@@ -1366,6 +1366,54 @@ def test_pipeline_preproc_fwd_values_cached(self) -> None:
["_preproc_module", "preproc_nonweighted", "preproc_weighted"],
)

# pyre-ignore
@unittest.skipIf(
not torch.cuda.is_available(),
"Not enough GPUs, this test requires at least one GPU",
)
def test_nested_preproc(self) -> None:
"""
If preproc module is nested, we should still be able to pipeline it
"""
extra_input = ModelInput.generate(
tables=self.tables,
weighted_tables=self.weighted_tables,
batch_size=self.batch_size,
world_size=1,
num_float_features=10,
randomize_indices=False,
)[0].to(self.device)

preproc_module = TestNegSamplingModule(
extra_input=extra_input,
)
model = self._setup_model(preproc_module=preproc_module)

class ParentModule(nn.Module):
def __init__(
self,
nested_model: nn.Module,
) -> None:
super().__init__()
self.nested_model = nested_model

def forward(
self,
input: ModelInput,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
return self.nested_model(input)

model = ParentModule(model)

pipelined_model, pipeline = self._check_output_equal(
model,
self.sharding_type,
)

# Check that both EC and EBC pipelined
self.assertEqual(len(pipeline._pipelined_modules), 2)
self.assertEqual(len(pipeline._pipelined_preprocs), 1)


class EmbeddingTrainPipelineTest(TrainPipelineSparseDistTestBase):
@unittest.skipIf(
87 changes: 81 additions & 6 deletions torchrec/distributed/train_pipeline/utils.py
@@ -33,6 +33,10 @@

from torch import distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.fx.immutable_collections import (
immutable_dict as fx_immutable_dict,
immutable_list as fx_immutable_list,
)
from torch.fx.node import Node
from torch.profiler import record_function
from torchrec.distributed.dist_data import KJTAllToAll, KJTAllToAllTensorsAwaitable
@@ -156,6 +160,9 @@ class ArgInfo:
input_attrs (List[str]): attributes of input batch,
e.g. `batch.attr1.attr2` will produce ["attr1", "attr2"].
is_getitems (List[bool]): `batch[attr1].attr2` will produce [True, False].
preproc_modules (List[Optional[PipelinedPreproc]]): list of torch.nn.Modules that
transform the input batch.
constants (List[Optional[object]]): constant arguments that are passed to preproc modules.
name (Optional[str]): name for kwarg of pipelined forward() call or None for a
positional arg.
"""
@@ -164,6 +171,7 @@
is_getitems: List[bool]
# recursive dataclass as preproc_modules.args -> arginfo.preproc_modules -> so on
preproc_modules: List[Optional["PipelinedPreproc"]]
constants: List[Optional[object]]
name: Optional[str]


@@ -178,10 +186,16 @@ def _build_args_kwargs(
for arg_info in fwd_args:
if arg_info.input_attrs:
arg = initial_input
for attr, is_getitem, preproc_mod in zip(
arg_info.input_attrs, arg_info.is_getitems, arg_info.preproc_modules
for attr, is_getitem, preproc_mod, obj in zip(
arg_info.input_attrs,
arg_info.is_getitems,
arg_info.preproc_modules,
arg_info.constants,
):
if preproc_mod is not None:
if obj is not None:
arg = obj
break
elif preproc_mod is not None:
# preproc will internally run the same logic recursively
# if its args are derived from other preproc modules
# we can get all inputs to preproc mod based on its recorded args_info + arg passed to it
@@ -682,6 +696,46 @@ def _check_preproc_pipelineable(
return True


def _find_preproc_module_recursive(
module: torch.nn.Module,
preproc_module_fqn: str,
) -> Optional[torch.nn.Module]:
"""
Finds the preproc module in the model.
"""
for name, child in module.named_modules():
if name == preproc_module_fqn:
return child
return None


def _swap_preproc_module_recursive(
module: torch.nn.Module,
to_swap_module: torch.nn.Module,
preproc_module_fqn: str,
path: str = "",
) -> torch.nn.Module:
"""
Swaps the preproc module in the model.
"""
if isinstance(module, PipelinedPreproc):
return module

if path == preproc_module_fqn:
return to_swap_module

for name, child in module.named_children():
child = _swap_preproc_module_recursive(
child,
to_swap_module,
preproc_module_fqn,
path + "." + name if path else name,
)
setattr(module, name, child)

return module


def _get_node_args_helper(
model: torch.nn.Module,
# pyre-ignore
@@ -695,13 +749,23 @@ def _get_node_args_helper(
Goes through the args/kwargs of a node and arranges them into a list of `ArgInfo`s.
It also counts the number of (args + kwargs) found.
"""
arg_info_list = [ArgInfo([], [], [], None) for _ in range(len(arguments))]
arg_info_list = [ArgInfo([], [], [], [], None) for _ in range(len(arguments))]
for arg, arg_info in zip(arguments, arg_info_list):
if arg is None:
num_found += 1
continue
while True:
if not isinstance(arg, torch.fx.Node):
if pipeline_preproc:
arg_info.input_attrs.insert(0, "")
arg_info.is_getitems.insert(0, False)
arg_info.preproc_modules.insert(0, None)
if isinstance(arg, (fx_immutable_dict, fx_immutable_list)):
# Make them mutable again, in case in-place updates are made
arg_info.constants.insert(0, arg.copy())
else:
arg_info.constants.insert(0, arg)
num_found += 1
break
child_node = arg

@@ -719,11 +783,13 @@
arg_info.input_attrs.append(key)
arg_info.is_getitems.append(False)
arg_info.preproc_modules.append(None)
arg_info.constants.append(None)
else:
# no-op
arg_info.input_attrs.insert(0, "")
arg_info.is_getitems.insert(0, False)
arg_info.preproc_modules.insert(0, None)
arg_info.constants.insert(0, None)

num_found += 1
break
@@ -740,6 +806,7 @@
arg_info.input_attrs.insert(0, child_node.args[1])
arg_info.is_getitems.insert(0, False)
arg_info.preproc_modules.insert(0, None)
arg_info.constants.insert(0, None)
arg = child_node.args[0]
elif (
child_node.op == "call_function"
@@ -754,6 +821,7 @@
arg_info.input_attrs.insert(0, child_node.args[1])
arg_info.is_getitems.insert(0, True)
arg_info.preproc_modules.insert(0, None)
arg_info.constants.insert(0, None)
arg = child_node.args[0]
elif (
child_node.op == "call_function"
@@ -795,10 +863,13 @@
arg_info.input_attrs.insert(0, child_node.args[1])
arg_info.is_getitems.insert(0, True)
arg_info.preproc_modules.insert(0, None)
arg_info.constants.insert(0, None)
arg = child_node.args[0]
elif child_node.op == "call_module":
preproc_module_fqn = str(child_node.target)
preproc_module = getattr(model, preproc_module_fqn, None)
preproc_module = _find_preproc_module_recursive(
model, preproc_module_fqn
)

if not pipeline_preproc:
logger.warning(
@@ -816,6 +887,7 @@
arg_info.is_getitems.insert(0, False)
pipelined_preprocs.add(preproc_module)
arg_info.preproc_modules.insert(0, preproc_module)
arg_info.constants.insert(0, None)
num_found += 1
break

@@ -858,12 +930,15 @@
)

# module swap
setattr(model, preproc_module_fqn, pipelined_preproc_module)
_swap_preproc_module_recursive(
model, pipelined_preproc_module, preproc_module_fqn
)

arg_info.input_attrs.insert(0, "") # dummy value
arg_info.is_getitems.insert(0, False)
pipelined_preprocs.add(pipelined_preproc_module)
arg_info.preproc_modules.insert(0, pipelined_preproc_module)
arg_info.constants.insert(0, None)

num_found += 1

