From 4f114bc2652b4166f09ac8a7f39656aad654e1d5 Mon Sep 17 00:00:00 2001
From: Jie You
Date: Fri, 19 Jul 2024 11:36:20 -0700
Subject: [PATCH] Improve Composability of ITEP (#2236)

Summary:
Pull Request resolved: https://github.com/pytorch/torchrec/pull/2236

Mirror the change in D56261218, making the OSS version of ITEP-EBC composable.

Reviewed By: sarckk

Differential Revision: D59617001

fbshipit-source-id: a5063941f8a2cd938167b62232639b8d3de1aab9
---
 torchrec/modules/itep_modules.py | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/torchrec/modules/itep_modules.py b/torchrec/modules/itep_modules.py
index 3a626887f..14a479b2e 100644
--- a/torchrec/modules/itep_modules.py
+++ b/torchrec/modules/itep_modules.py
@@ -12,6 +12,7 @@

 import torch
 from torch import nn
+from torch.nn.parallel import DistributedDataParallel
 from torchrec.distributed.embedding_types import ShardedEmbeddingTable
 from torchrec.modules.embedding_modules import reorder_inverse_indices
 from torchrec.sparse.jagged_tensor import _pin_and_move, _to_offsets, KeyedJaggedTensor
@@ -200,6 +201,8 @@ def init_itep_state(self) -> None:
         # Iterate over all tables
         # pyre-ignore
         for lookup in self.lookups:
+            while isinstance(lookup, DistributedDataParallel):
+                lookup = lookup.module
             for emb in lookup._emb_modules:
                 emb_tables: List[ShardedEmbeddingTable] = emb._config.embedding_tables

@@ -283,6 +286,8 @@ def reset_weight_momentum(
         if self.lookups is not None:
             # pyre-ignore
             for lookup in self.lookups:
+                while isinstance(lookup, DistributedDataParallel):
+                    lookup = lookup.module
                 for emb in lookup._emb_modules:
                     emb_tables: List[ShardedEmbeddingTable] = (
                         emb._config.embedding_tables
@@ -322,6 +327,8 @@ def flush_uvm_cache(self) -> None:
         if self.lookups is not None:
             # pyre-ignore
             for lookup in self.lookups:
+                while isinstance(lookup, DistributedDataParallel):
+                    lookup = lookup.module
                 for emb in lookup._emb_modules:
                     emb.emb_module.flush()
                     emb.emb_module.reset_cache_states()
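
All seven inserted lines apply one pattern: once ITEP is composed with torch's
DistributedDataParallel, an entry in self.lookups may be a DDP wrapper rather
than the lookup module itself, so attributes such as _emb_modules live on the
inner .module and the wrapper must be peeled off first. Below is a minimal,
self-contained sketch of that unwrap pattern; the helper name unwrap_ddp and
the toy Lookup class are hypothetical illustrations, not part of torchrec or
this patch:

    # Hypothetical sketch of the DDP-unwrapping pattern used in the patch.
    # unwrap_ddp and Lookup are illustrative names, not torchrec APIs.
    import torch
    from torch import nn
    from torch.nn.parallel import DistributedDataParallel


    def unwrap_ddp(module: nn.Module) -> nn.Module:
        # DDP wrappers can nest in principle, so peel them off in a loop,
        # mirroring the patched `while isinstance(...)` blocks.
        while isinstance(module, DistributedDataParallel):
            module = module.module
        return module


    class Lookup(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self._emb_modules = nn.ModuleList([nn.Embedding(10, 4)])


    # A plain (unwrapped) lookup passes through unchanged; constructing an
    # actual DDP instance is skipped here since it requires an initialized
    # process group.
    lookup = unwrap_ddp(Lookup())
    for emb in lookup._emb_modules:
        print(emb)

Unwrapping at the point of use (rather than storing unwrapped references)
keeps the fix local to the three methods that introspect the lookup modules,
which is why the same two-line loop appears in init_itep_state,
reset_weight_momentum, and flush_uvm_cache.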