ensemble_mode consolidation (fbgemm) (#3197)
Summary:
Pull Request resolved: #3197

X-link: facebookresearch/FBGEMM#295

ensemble_mode consolidation (fbgemm)

Reviewed By: q10, csmiler

Differential Revision: D63634421
minhua-chen authored and facebook-github-bot committed Sep 30, 2024
1 parent f2a0156 commit 488a0e4
Showing 2 changed files with 11 additions and 20 deletions.
16 changes: 7 additions & 9 deletions fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
@@ -150,6 +150,7 @@ class EnsembleModeDefinition:
     step_ema: float = 10000
     step_swap: float = 10000
     step_start: float = 0
+    step_ema_coef: float = 0.6
     step_mode: StepMode = StepMode.USE_ITER
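For readers skimming the hunk, a short sketch of the consolidated config surface. The stand-in class below mirrors only the fields shown above; StepMode as an int-valued enum is an assumption, made so that the float() cast introduced later in this diff works.

    # Minimal stand-ins mirroring the fields in this hunk; StepMode as an
    # IntEnum is an assumption, not confirmed by the diff.
    import enum
    from dataclasses import dataclass

    class StepMode(enum.IntEnum):
        USE_ITER = 1  # value assumed for illustration

    @dataclass
    class EnsembleModeDefinition:
        step_ema: float = 10000
        step_swap: float = 10000
        step_start: float = 0
        step_ema_coef: float = 0.6  # the new field
        step_mode: StepMode = StepMode.USE_ITER

    # e.g. override the EMA coefficient while keeping the other defaults
    mode = EnsembleModeDefinition(step_ema_coef=0.8)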


@@ -457,8 +458,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
             Adam. Note that default is different from torch.nn.optim.Adagrad
             default of 1e-10
-        momentum (float = 0.9): Momentum used by LARS-SGD and
-            ENSEMBLE_ROWWISE_ADAGRAD
+        momentum (float = 0.9): Momentum used by LARS-SGD
         weight_decay (float = 0.0): Weight decay used by LARS-SGD, LAMB, ADAM,
             and rowwise-Adagrad.
@@ -924,11 +924,8 @@ def __init__( # noqa C901
 
         if ensemble_mode is None:
             ensemble_mode = EnsembleModeDefinition()
-        self._ensemble_mode: Dict[str, int] = {
-            "step_ema": int(ensemble_mode.step_ema),
-            "step_swap": int(ensemble_mode.step_swap),
-            "step_start": int(ensemble_mode.step_start),
-            "step_mode": int(ensemble_mode.step_mode.value),
+        self._ensemble_mode: Dict[str, float] = {
+            key: float(fval) for key, fval in ensemble_mode.__dict__.items()
         }
 
         if counter_based_regularization is None:
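To make the consolidation concrete, here is what the comprehension yields for the default definition, reusing the stand-ins sketched above. Note that float(step_mode) only works if StepMode is int-convertible (e.g. an IntEnum), which this rewrite implicitly relies on.

    # What self._ensemble_mode now holds for the defaults; the 1.0 for
    # step_mode assumes StepMode.USE_ITER == 1, as in the stand-in above.
    mode = EnsembleModeDefinition()
    consolidated = {key: float(fval) for key, fval in mode.__dict__.items()}
    assert consolidated == {
        "step_ema": 10000.0,
        "step_swap": 10000.0,
        "step_start": 0.0,
        "step_ema_coef": 0.6,
        "step_mode": 1.0,
    }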
@@ -1002,6 +999,7 @@ def __init__( # noqa C901
             if (
                 optimizer_state_dtypes is None
                 or "momentum1" not in optimizer_state_dtypes
+                or optimizer == OptimType.ENSEMBLE_ROWWISE_ADAGRAD
             )
             else optimizer_state_dtypes["momentum1"].as_dtype()
         )
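The fragment above is the tail of a larger conditional expression whose head lies outside the hunk. A hedged reconstruction of its likely shape, with stub inputs and with torch.float32 assumed as the fallback dtype (only the three conditions come from the diff):

    # Hedged reconstruction; the assignment target, the fallback dtype, and
    # the stub inputs are assumptions for illustration.
    import torch

    optimizer_state_dtypes = None        # stub: caller passed nothing
    is_ensemble_rowwise_adagrad = True   # stub for the OptimType check

    momentum1_dtype = (
        torch.float32
        if (
            optimizer_state_dtypes is None
            or "momentum1" not in optimizer_state_dtypes
            or is_ensemble_rowwise_adagrad
        )
        else optimizer_state_dtypes["momentum1"].as_dtype()
    )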
@@ -1938,7 +1936,7 @@ def forward( # noqa: C901
 
         raise ValueError(f"Invalid OptimType: {self.optimizer}")
 
-    def ensemble_and_swap(self, ensemble_mode: Dict[str, int]) -> None:
+    def ensemble_and_swap(self, ensemble_mode: Dict[str, float]) -> None:
         should_ema = self.iter.item() % int(ensemble_mode["step_ema"]) == 0
         should_swap = self.iter.item() % int(ensemble_mode["step_swap"]) == 0
         if should_ema or should_swap:
@@ -1947,7 +1945,7 @@ def ensemble_and_swap(self, ensemble_mode: Dict[str, int]) -> None:
             for i in range(len(self.embedding_specs)):
                 if should_ema:
                     coef_ema = (
-                        self.optimizer_args.momentum
+                        ensemble_mode["step_ema_coef"]
                         if self.iter.item() > int(ensemble_mode["step_start"])
                         else 0.0
                     )
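The hunk above only selects the EMA coefficient; the weight update it feeds is outside this diff. As a hedged illustration of the schedule it implements, coef_ema stays 0.0 until step_start, which makes the EMA copy track the fast weights exactly under the usual convex-combination update:

    # Hedged illustration; the convex-combination update is the conventional
    # EMA form and is assumed here, not shown in this diff.
    from typing import Dict, List

    def ema_coef(iteration: int, ensemble_mode: Dict[str, float]) -> float:
        # warm-up: returning 0.0 means the EMA is overwritten by the fast weights
        if iteration > int(ensemble_mode["step_start"]):
            return ensemble_mode["step_ema_coef"]
        return 0.0

    def ema_update(w_ema: List[float], w_fast: List[float], coef: float) -> List[float]:
        # w_ema <- coef * w_ema + (1 - coef) * w_fast
        return [coef * e + (1.0 - coef) * f for e, f in zip(w_ema, w_fast)]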
15 changes: 4 additions & 11 deletions fbgemm_gpu/test/tbe/training/backward_optimizers_test.py
@@ -26,6 +26,7 @@
     CounterBasedRegularizationDefinition,
     CounterWeightDecayMode,
     CowClipDefinition,
+    EnsembleModeDefinition,
     GradSumDecay,
     LearningRateMode,
     SplitTableBatchedEmbeddingBagsCodegen,
@@ -34,13 +35,6 @@
     WeightDecayMode,
 )
 
-try:
-    from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
-        EnsembleModeDefinition,
-    )
-except ImportError:
-    EnsembleModeDefinition = None
-
 from fbgemm_gpu.tbe.utils import (
     b_indices,
     get_table_batched_offsets_from_dense,
@@ -315,23 +309,22 @@ def execute_backward_optimizers_( # noqa C901
             optimizer_kwargs["eta"] = eta
 
         if optimizer == OptimType.ENSEMBLE_ROWWISE_ADAGRAD:
-            (eps, step_ema, step_swap, step_start, step_mode, momentum) = (
+            (eps, step_ema, step_swap, step_start, step_mode) = (
                 1e-4,
                 1.0,
                 1.0,
                 0.0,
                 StepMode.USE_ITER,
-                0.8,
             )
             optimizer_kwargs["eps"] = eps
+            optimizer_kwargs["optimizer_state_dtypes"] = optimizer_state_dtypes
             optimizer_kwargs["ensemble_mode"] = EnsembleModeDefinition(
                 step_ema=step_ema,
                 step_swap=step_swap,
                 step_start=step_start,
+                step_ema_coef=momentum,
                 step_mode=step_mode,
             )
-            optimizer_kwargs["momentum"] = momentum
-            optimizer_kwargs["optimizer_state_dtypes"] = optimizer_state_dtypes
 
         cc = emb_op(
             embedding_specs=[
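The net effect on callers, sketched with the stand-ins from earlier: the EMA coefficient now rides inside EnsembleModeDefinition rather than as a separate momentum kwarg (the 0.8 mirrors the value the old test passed as momentum; the embedding-op call itself is omitted):

    # Hedged sketch of the new calling convention exercised by the test.
    optimizer_kwargs = {
        "eps": 1e-4,
        "ensemble_mode": EnsembleModeDefinition(
            step_ema=1.0,
            step_swap=1.0,
            step_start=0.0,
            step_ema_coef=0.8,  # was optimizer_kwargs["momentum"] = 0.8
            step_mode=StepMode.USE_ITER,
        ),
    }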
