diff --git a/fedscale/utils/optimizer/yogi.py b/fedscale/utils/optimizer/yogi.py index 0302b974..2d73d0ad 100755 --- a/fedscale/utils/optimizer/yogi.py +++ b/fedscale/utils/optimizer/yogi.py @@ -15,7 +15,7 @@ def __init__(self, eta=1e-2, tau=1e-3, beta=0.9, beta2=0.99): def update(self, gradients): update_gradients = [] if not self.v_t: - self.v_t = [g**2 for g in gradients] + self.v_t = [torch.full_like(g, self.tau) for g in gradients] self.m_t = [torch.full_like(g, 0.0) for g in gradients] for idx, gradient in enumerate(gradients):