From 09f29c50ff375d5cc958732176897a7f76c5e278 Mon Sep 17 00:00:00 2001 From: Minhua Chen Date: Tue, 1 Oct 2024 18:36:13 -0700 Subject: [PATCH] add step_mode for ensemble_rowwise_adagrad (#3203) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3203 X-link: https://github.com/facebookresearch/FBGEMM/pull/300 add step_mode for ensemble_rowwise_adagrad Reviewed By: q10, minddrummer Differential Revision: D63681997 --- .../split_table_batched_embeddings_ops_training.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py index f41a49d19..1730dbedc 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py @@ -1944,10 +1944,17 @@ def ensemble_and_swap(self, ensemble_mode: Dict[str, float]) -> None: states = self.split_optimizer_states() for i in range(len(self.embedding_specs)): if should_ema: + step_start = int(ensemble_mode["step_start"]) + if int(ensemble_mode["step_mode"]) == 1: + should_ema_reset = self.iter.item() % step_start == 0 + elif int(ensemble_mode["step_mode"]) == 2: + should_ema_reset = self.iter.item() <= step_start + else: + should_ema_reset = (self.iter.item() <= step_start) or ( + self.iter.item() % step_start == 0 + ) coef_ema = ( - ensemble_mode["step_ema_coef"] - if self.iter.item() > int(ensemble_mode["step_start"]) - else 0.0 + 0.0 if should_ema_reset else ensemble_mode["step_ema_coef"] ) weights_cpu = weights[i].to( dtype=states[i][1].dtype, device=states[i][1].device