From 0bc1baa8ff8e2ce5c35b0cad3363990360a96409 Mon Sep 17 00:00:00 2001 From: Yong Hoon Shin Date: Wed, 11 Sep 2024 11:16:21 -0700 Subject: [PATCH] Support augmented batch with model input as attr to semi-sync pipeline (#2374) Summary: Pull Request resolved: https://github.com/pytorch/torchrec/pull/2374 Some pipelines need to inherit from the torchrec pipeline, and need to work with dataloaders that return augmented batches where the actual model input is an attribute of the batch returned by a dataloader (e.g. `augmented_batch.input`). This adds support for pipelines inheriting the semi-sync pipeline to override `extract_model_input_from_batch` implementation and extract the model input for model rewriting. Differential Revision: D62058136 fbshipit-source-id: 4f4b4672919f90bdf139c6674a16a8ed2d24573d --- torchrec/distributed/train_pipeline/train_pipelines.py | 9 +++++++-- torchrec/distributed/train_pipeline/utils.py | 3 +-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/torchrec/distributed/train_pipeline/train_pipelines.py b/torchrec/distributed/train_pipeline/train_pipelines.py index bab055c94..038443432 100644 --- a/torchrec/distributed/train_pipeline/train_pipelines.py +++ b/torchrec/distributed/train_pipeline/train_pipelines.py @@ -63,6 +63,7 @@ from torchrec.pt2.checks import is_torchdynamo_compiling from torchrec.pt2.utils import default_pipeline_input_transformer from torchrec.sparse.jagged_tensor import KeyedJaggedTensor +from torchrec.streamable import Pipelineable logger: logging.Logger = logging.getLogger(__name__) @@ -930,6 +931,9 @@ def copy_batch_to_gpu( context.events.append(event) return batch, context + def extract_model_input_from_batch(self, batch: In) -> Pipelineable: + return batch + def start_sparse_data_dist( self, batch: Optional[In], @@ -949,7 +953,8 @@ def start_sparse_data_dist( with record_function(f"## start_sparse_data_dist {context.index} ##"): with self._stream_context(self._data_dist_stream): _wait_for_events(batch, context, self._data_dist_stream) - _start_data_dist(self._pipelined_modules, batch, context) + model_input = self.extract_model_input_from_batch(batch) + _start_data_dist(self._pipelined_modules, model_input, context) event = torch.get_device_module(self._device).Event() event.record() context.events.append(event) @@ -979,7 +984,7 @@ def start_embedding_lookup( else self._embedding_odd_streams[i] ) with self._stream_context(stream): - _start_embedding_lookup(module, batch, context, stream) + _start_embedding_lookup(module, context, stream) event = torch.get_device_module(self._device).Event() event.record() context.events.append(event) diff --git a/torchrec/distributed/train_pipeline/utils.py b/torchrec/distributed/train_pipeline/utils.py index 8aaa8a1bc..ecca28280 100644 --- a/torchrec/distributed/train_pipeline/utils.py +++ b/torchrec/distributed/train_pipeline/utils.py @@ -583,7 +583,7 @@ def _wait_for_events( def _start_data_dist( pipelined_modules: List[ShardedModule], - batch: In, + batch: Pipelineable, context: TrainPipelineContext, ) -> None: if context.version == 0: @@ -620,7 +620,6 @@ def _start_data_dist( def _start_embedding_lookup( module: ShardedModule, - batch: In, # not used in this function context: EmbeddingTrainPipelineContext, stream: Optional[torch.Stream], ) -> None: