From 488a0e4dbbb17037979d8ea184e5416032ef35e4 Mon Sep 17 00:00:00 2001
From: Minhua Chen
Date: Mon, 30 Sep 2024 11:56:58 -0700
Subject: [PATCH] ensemble_mode consolidation (fbgemm) (#3197)

Summary:
Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3197

X-link: https://github.com/facebookresearch/FBGEMM/pull/295

ensemble_mode consolidation (fbgemm)

Reviewed By: q10, csmiler

Differential Revision: D63634421
---
 ...plit_table_batched_embeddings_ops_training.py | 16 +++++++---------
 .../tbe/training/backward_optimizers_test.py     | 15 ++++-----------
 2 files changed, 11 insertions(+), 20 deletions(-)

diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
index 71c9325cf..f41a49d19 100644
--- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
+++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
@@ -150,6 +150,7 @@ class EnsembleModeDefinition:
     step_ema: float = 10000
     step_swap: float = 10000
     step_start: float = 0
+    step_ema_coef: float = 0.6
     step_mode: StepMode = StepMode.USE_ITER
 
 
@@ -457,8 +458,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
             Adam. Note that default is different from torch.nn.optim.Adagrad
             default of 1e-10
 
-        momentum (float = 0.9): Momentum used by LARS-SGD and
-            ENSEMBLE_ROWWISE_ADAGRAD
+        momentum (float = 0.9): Momentum used by LARS-SGD
 
         weight_decay (float = 0.0): Weight decay used by LARS-SGD, LAMB,
             ADAM, and rowwise-Adagrad.
@@ -924,11 +924,8 @@ def __init__(  # noqa C901
         if ensemble_mode is None:
             ensemble_mode = EnsembleModeDefinition()
 
-        self._ensemble_mode: Dict[str, int] = {
-            "step_ema": int(ensemble_mode.step_ema),
-            "step_swap": int(ensemble_mode.step_swap),
-            "step_start": int(ensemble_mode.step_start),
-            "step_mode": int(ensemble_mode.step_mode.value),
+        self._ensemble_mode: Dict[str, float] = {
+            key: float(fval) for key, fval in ensemble_mode.__dict__.items()
         }
 
         if counter_based_regularization is None:
@@ -1002,6 +999,7 @@ def __init__(  # noqa C901
             if (
                 optimizer_state_dtypes is None
                 or "momentum1" not in optimizer_state_dtypes
+                or optimizer == OptimType.ENSEMBLE_ROWWISE_ADAGRAD
             )
             else optimizer_state_dtypes["momentum1"].as_dtype()
         )
@@ -1938,7 +1936,7 @@ def forward(  # noqa: C901
 
         raise ValueError(f"Invalid OptimType: {self.optimizer}")
 
-    def ensemble_and_swap(self, ensemble_mode: Dict[str, int]) -> None:
+    def ensemble_and_swap(self, ensemble_mode: Dict[str, float]) -> None:
         should_ema = self.iter.item() % int(ensemble_mode["step_ema"]) == 0
         should_swap = self.iter.item() % int(ensemble_mode["step_swap"]) == 0
         if should_ema or should_swap:
@@ -1947,7 +1945,7 @@ def ensemble_and_swap(self, ensemble_mode: Dict[str, int]) -> None:
             for i in range(len(self.embedding_specs)):
                 if should_ema:
                     coef_ema = (
-                        self.optimizer_args.momentum
+                        ensemble_mode["step_ema_coef"]
                         if self.iter.item() > int(ensemble_mode["step_start"])
                         else 0.0
                     )
diff --git a/fbgemm_gpu/test/tbe/training/backward_optimizers_test.py b/fbgemm_gpu/test/tbe/training/backward_optimizers_test.py
index b1809dc1b..2db48594d 100644
--- a/fbgemm_gpu/test/tbe/training/backward_optimizers_test.py
+++ b/fbgemm_gpu/test/tbe/training/backward_optimizers_test.py
@@ -26,6 +26,7 @@
     CounterBasedRegularizationDefinition,
     CounterWeightDecayMode,
     CowClipDefinition,
+    EnsembleModeDefinition,
     GradSumDecay,
     LearningRateMode,
     SplitTableBatchedEmbeddingBagsCodegen,
@@ -34,13 +35,6 @@
     WeightDecayMode,
 )
 
-try:
-    from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
-        EnsembleModeDefinition,
-    )
-except ImportError:
-    EnsembleModeDefinition = None
-
 from fbgemm_gpu.tbe.utils import (
     b_indices,
     get_table_batched_offsets_from_dense,
@@ -315,23 +309,22 @@ def execute_backward_optimizers_(  # noqa C901
             optimizer_kwargs["eta"] = eta
 
         if optimizer == OptimType.ENSEMBLE_ROWWISE_ADAGRAD:
-            (eps, step_ema, step_swap, step_start, step_mode, momentum) = (
+            (eps, step_ema, step_swap, step_start, step_mode) = (
                 1e-4,
                 1.0,
                 1.0,
                 0.0,
                 StepMode.USE_ITER,
-                0.8,
             )
             optimizer_kwargs["eps"] = eps
+            optimizer_kwargs["optimizer_state_dtypes"] = optimizer_state_dtypes
             optimizer_kwargs["ensemble_mode"] = EnsembleModeDefinition(
                 step_ema=step_ema,
                 step_swap=step_swap,
                 step_start=step_start,
+                step_ema_coef=momentum,
                 step_mode=step_mode,
             )
-            optimizer_kwargs["momentum"] = momentum
-            optimizer_kwargs["optimizer_state_dtypes"] = optimizer_state_dtypes
 
             cc = emb_op(
                 embedding_specs=[
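
Note: for reference, below is a minimal standalone sketch of the consolidation this patch performs. EnsembleModeDefinition is flattened into a plain Dict[str, float] via __dict__, so a newly added field such as step_ema_coef is picked up without editing the dictionary construction, and the ensemble/swap logic reads the EMA coefficient from that dictionary instead of optimizer_args.momentum. The StepMode and EnsembleModeDefinition declarations here are illustrative stand-ins for the ones defined in split_table_batched_embeddings_ops_training.py (not the library code itself), and the sketch assumes StepMode is an IntEnum so its members coerce cleanly to float.

# consolidation_sketch.py -- illustrative only; mirrors the idea in this patch
import enum
from dataclasses import dataclass
from typing import Dict


@enum.unique
class StepMode(enum.IntEnum):
    # Assumed IntEnum so float() works directly on members, as the patched
    # dict comprehension requires.
    USE_ITER = 1
    USE_COUNTER = 2


@dataclass(frozen=True)
class EnsembleModeDefinition:
    # Stand-in mirroring the fields and defaults added/kept by this patch.
    step_ema: float = 10000
    step_swap: float = 10000
    step_start: float = 0
    step_ema_coef: float = 0.6
    step_mode: StepMode = StepMode.USE_ITER


def flatten_ensemble_mode(mode: EnsembleModeDefinition) -> Dict[str, float]:
    # Same consolidation as the patched __init__: every dataclass field,
    # including future additions, is coerced to float in one place.
    return {key: float(fval) for key, fval in mode.__dict__.items()}


if __name__ == "__main__":
    ensemble_mode = flatten_ensemble_mode(
        EnsembleModeDefinition(step_ema=1.0, step_swap=1.0, step_ema_coef=0.8)
    )
    # Mirrors the decisions made per iteration in ensemble_and_swap().
    it = 10
    should_ema = it % int(ensemble_mode["step_ema"]) == 0
    should_swap = it % int(ensemble_mode["step_swap"]) == 0
    coef_ema = (
        ensemble_mode["step_ema_coef"]
        if it > int(ensemble_mode["step_start"])
        else 0.0
    )
    print(ensemble_mode, should_ema, should_swap, coef_ema)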