
Commit 0bc1baa

Support augmented batch with model input as attr to semi-sync pipeline (pytorch#2374)

Summary:
Pull Request resolved: pytorch#2374

Some pipelines need to inherit from the torchrec pipeline and work with dataloaders that return augmented batches, where the actual model input is an attribute of the batch returned by the dataloader (e.g. `augmented_batch.input`).

This adds support for pipelines that inherit from the semi-sync pipeline to override the `extract_model_input_from_batch` implementation and extract the model input used for model rewriting.
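
A minimal sketch of how a downstream pipeline might use this hook. The `AugmentedBatch` and `MyAugmentedSemiSyncPipeline` names below are hypothetical and not part of this diff, and the sketch assumes the semi-sync pipeline class `TrainPipelineSemiSync` from `train_pipelines.py`:

```python
# Hypothetical sketch: an augmented batch that carries the real model input
# as an attribute, plus a semi-sync pipeline subclass that overrides the new
# extract_model_input_from_batch hook. Names here are illustrative only.
from dataclasses import dataclass, field
from typing import Dict

import torch
from torchrec.distributed.train_pipeline.train_pipelines import TrainPipelineSemiSync
from torchrec.streamable import Pipelineable


@dataclass
class AugmentedBatch(Pipelineable):
    input: Pipelineable  # the actual model input
    metadata: Dict[str, float] = field(default_factory=dict)  # not consumed by the model

    def to(self, device: torch.device, non_blocking: bool = False) -> "AugmentedBatch":
        # Only the wrapped model input needs to move to the target device.
        return AugmentedBatch(
            input=self.input.to(device, non_blocking=non_blocking),
            metadata=self.metadata,
        )

    def record_stream(self, stream: torch.Stream) -> None:
        self.input.record_stream(stream)


class MyAugmentedSemiSyncPipeline(TrainPipelineSemiSync):
    def extract_model_input_from_batch(self, batch: AugmentedBatch) -> Pipelineable:
        # Sparse data dist and embedding lookup should only see the model input.
        return batch.input
```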

Differential Revision: D62058136

fbshipit-source-id: 4f4b4672919f90bdf139c6674a16a8ed2d24573d
sarckk authored and facebook-github-bot committed Sep 11, 2024
1 parent 30b88f1 commit 0bc1baa
Showing 2 changed files with 8 additions and 4 deletions.
9 changes: 7 additions & 2 deletions torchrec/distributed/train_pipeline/train_pipelines.py
@@ -63,6 +63,7 @@
 from torchrec.pt2.checks import is_torchdynamo_compiling
 from torchrec.pt2.utils import default_pipeline_input_transformer
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
+from torchrec.streamable import Pipelineable
 
 logger: logging.Logger = logging.getLogger(__name__)
 
@@ -930,6 +931,9 @@ def copy_batch_to_gpu(
                 context.events.append(event)
         return batch, context
 
+    def extract_model_input_from_batch(self, batch: In) -> Pipelineable:
+        return batch
+
     def start_sparse_data_dist(
         self,
         batch: Optional[In],
@@ -949,7 +953,8 @@ def start_sparse_data_dist(
         with record_function(f"## start_sparse_data_dist {context.index} ##"):
             with self._stream_context(self._data_dist_stream):
                 _wait_for_events(batch, context, self._data_dist_stream)
-                _start_data_dist(self._pipelined_modules, batch, context)
+                model_input = self.extract_model_input_from_batch(batch)
+                _start_data_dist(self._pipelined_modules, model_input, context)
                 event = torch.get_device_module(self._device).Event()
                 event.record()
                 context.events.append(event)
@@ -979,7 +984,7 @@ def start_embedding_lookup(
                     else self._embedding_odd_streams[i]
                 )
                 with self._stream_context(stream):
-                    _start_embedding_lookup(module, batch, context, stream)
+                    _start_embedding_lookup(module, context, stream)
                     event = torch.get_device_module(self._device).Event()
                     event.record()
                     context.events.append(event)
3 changes: 1 addition & 2 deletions torchrec/distributed/train_pipeline/utils.py
@@ -583,7 +583,7 @@ def _wait_for_events(
 
 def _start_data_dist(
     pipelined_modules: List[ShardedModule],
-    batch: In,
+    batch: Pipelineable,
     context: TrainPipelineContext,
 ) -> None:
     if context.version == 0:
@@ -620,7 +620,6 @@ def _start_data_dist(
 
 def _start_embedding_lookup(
     module: ShardedModule,
-    batch: In,  # not used in this function
     context: EmbeddingTrainPipelineContext,
     stream: Optional[torch.Stream],
 ) -> None:
