add step_mode for ensemble_rowwise_adagrad (#3203)
Summary:
Pull Request resolved: #3203

X-link: facebookresearch/FBGEMM#300

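Add step_mode for ensemble_rowwise_adagrad. The mode controls when the EMA coefficient is reset to 0.0: step_mode 1 resets on every iteration that is a multiple of step_start, step_mode 2 resets while the iteration count is at most step_start, and any other value resets in either case.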

Reviewed By: q10, minddrummer

Differential Revision: D63681997
minhua-chen authored and facebook-github-bot committed Oct 2, 2024
1 parent c24a72d commit 09f29c5
Showing 1 changed file with 10 additions and 3 deletions.
@@ -1944,10 +1944,17 @@ def ensemble_and_swap(self, ensemble_mode: Dict[str, float]) -> None:
             states = self.split_optimizer_states()
             for i in range(len(self.embedding_specs)):
                 if should_ema:
+                    step_start = int(ensemble_mode["step_start"])
+                    if int(ensemble_mode["step_mode"]) == 1:
+                        should_ema_reset = self.iter.item() % step_start == 0
+                    elif int(ensemble_mode["step_mode"]) == 2:
+                        should_ema_reset = self.iter.item() <= step_start
+                    else:
+                        should_ema_reset = (self.iter.item() <= step_start) or (
+                            self.iter.item() % step_start == 0
+                        )
                     coef_ema = (
-                        ensemble_mode["step_ema_coef"]
-                        if self.iter.item() > int(ensemble_mode["step_start"])
-                        else 0.0
+                        0.0 if should_ema_reset else ensemble_mode["step_ema_coef"]
                     )
                     weights_cpu = weights[i].to(
                         dtype=states[i][1].dtype, device=states[i][1].device
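
The reset logic in this hunk can be read in isolation as a small pure function. The sketch below is illustrative only and is not part of the commit: the function names `should_reset_ema` and `ema_coefficient` are hypothetical, the iteration count is passed as a plain `int` rather than read from `self.iter.item()`, and `step_start` is assumed to be positive (a value of 0 would raise `ZeroDivisionError` in the modulo branches).

```python
from typing import Dict


def should_reset_ema(iteration: int, ensemble_mode: Dict[str, float]) -> bool:
    """Hypothetical helper mirroring the three step_mode branches in the diff."""
    step_start = int(ensemble_mode["step_start"])  # assumed > 0
    step_mode = int(ensemble_mode["step_mode"])
    if step_mode == 1:
        # Reset only on iterations that are a multiple of step_start.
        return iteration % step_start == 0
    if step_mode == 2:
        # Reset only while still inside the initial step_start window.
        return iteration <= step_start
    # Any other mode: reset in either of the two cases above.
    return iteration <= step_start or iteration % step_start == 0


def ema_coefficient(iteration: int, ensemble_mode: Dict[str, float]) -> float:
    """Return 0.0 on a reset iteration, otherwise the configured coefficient."""
    if should_reset_ema(iteration, ensemble_mode):
        return 0.0
    return ensemble_mode["step_ema_coef"]


if __name__ == "__main__":
    base = {"step_start": 100.0, "step_ema_coef": 0.9}
    for mode in (0.0, 1.0, 2.0):
        cfg = {**base, "step_mode": mode}
        coefs = [ema_coefficient(it, cfg) for it in (50, 100, 150, 200)]
        print(int(mode), coefs)
```

Under this reading, step_mode 2 reproduces the condition that the removed lines implemented (the coefficient is `step_ema_coef` only when the iteration exceeds `step_start`), while step_mode 1 and the fallback branch add the periodic reset.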