DRAG: Divergence-Based Adaptive Aggregation in Federated Learning on Non-IID Data

About this paper

  • This paper focuses on mitigating the client-drift effect in federated learning (FL), which is a joint consequence of
    • Heterogeneous data distributions across agents/clients/workers.
    • Intermittent communication between the central server and the agents.
  • In light of this, we develop a novel algorithm called divergence-based adaptive aggregation (DRAG) which
    • Hinges on a metric introduced as the "degree of divergence," which quantifies the angle between each agent's local gradient direction and the reference direction (the direction we desire to move toward).
    • Dynamically "drags" the received local updates toward the reference direction in each round without extra communication overhead.
  • Rigorous convergence analysis is provided for DRAG, proving a sublinear convergence rate.
  • Experiments demonstrate the superior performance of DRAG compared to advanced algorithms.
  • Additionally, DRAG exhibits remarkable resilience against Byzantine attacks thanks to its "dragging" nature. Numerical results are provided to showcase this property.

Motivations

Client-Drift Effect

  • In the context of FL, the classical problem of interest is to minimize the average (or sum) of the local loss functions across agents/clients. In most practical settings, different agents hold distinct local datasets drawn from different distributions and thus have distinct local loss functions, which is where the "heterogeneity" comes from.
  • On a different note, since communication overhead is a bottleneck in FL, agents cannot afford to talk to the server every time they finish computing a gradient. Hence, the well-known federated averaging (FedAvg) algorithm has each agent perform multiple local updates on its local model before communicating with the server to upload its latest model. This is the ubiquitous "intermittent communication" mechanism in FL.

[Figure: illustration of the client-drift effect (source: Karimireddy et al., 2021).]

However, when these two features are combined, the "client-drift" effect arises, since each agent inevitably updates its local model toward its own local optimum. Moreover, simply summing these updated models and averaging them does not necessarily cancel out the heterogeneity; instead, it impedes the convergence rate, as demonstrated both theoretically and experimentally in previous work.

Byzantine Attacks

As FL operates in a distributed manner, it is susceptible to adversarial attacks launched by malicious clients, commonly known as Byzantine attacks. Classical Byzantine attacks include

  • Reversing the direction of the local gradient.
  • Scaling the local gradient by a factor to negatively influence the training process.
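
As a minimal illustration of these two attack types, the sketch below applies a sign-flip or scaling perturbation to a local gradient; the function and variable names are ours, not taken from the repository code.

```python
import numpy as np

def byzantine_perturb(local_grad: np.ndarray, attack: str = "sign_flip", scale: float = 10.0) -> np.ndarray:
    """Apply a classical Byzantine perturbation to a local gradient/update.

    Corresponds to a perturbed quantity g_hat = p * g, where the factor p
    can be negative (direction reversal) or a large constant (scaling).
    """
    if attack == "sign_flip":
        return -local_grad          # reverse the direction of the local gradient
    if attack == "scaling":
        return scale * local_grad   # scale the local gradient to derail training
    return local_grad               # benign agent: no perturbation
```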

[Figure: illustration of a Byzantine attack.]

Overview of the DRAG Algorithm

Problem Formulation

  • Problem of interest:

[Equation]

  • Local update formula of each agent:

[Equation]

  • Model difference uploaded by each agent $m$ at round $t$:

[Equation]

  • Server aggregation:

[Equation]

  • Server aggregation (w/ Byzantine attacks):

[Equation]

Here, $\hat{\mathbf{g}}_m^t = p_m^t \mathbf{g}_m^t$ is the perturbed gradient, where $p_m^t$ can be either positive or negative, and $A^t$ denotes the set of malicious agents.
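
For reference, one conventional way to write these quantities is sketched below; the notation (step size $\eta$, local iteration count $U$, participating set $\mathcal{M}^t$) is ours, and the paper's exact definitions and weights may differ.

$$
\min_{\mathbf{x}}\; f(\mathbf{x})=\frac{1}{M}\sum_{m=1}^{M} f_m(\mathbf{x}),
\qquad
\mathbf{x}_m^{t,u+1}=\mathbf{x}_m^{t,u}-\eta\,\nabla f_m\big(\mathbf{x}_m^{t,u};\xi_m^{t,u}\big),\quad u=0,\dots,U-1
$$

$$
\mathbf{g}_m^{t}=\mathbf{x}^{t}-\mathbf{x}_m^{t,U},
\qquad
\mathbf{x}^{t+1}=\mathbf{x}^{t}-\frac{1}{|\mathcal{M}^t|}\sum_{m\in\mathcal{M}^t}\mathbf{g}_m^{t},
\qquad
\mathbf{x}^{t+1}=\mathbf{x}^{t}-\frac{1}{|\mathcal{M}^t|}\Big(\sum_{m\in\mathcal{M}^t\setminus A^t}\mathbf{g}_m^{t}+\sum_{m\in A^t}\hat{\mathbf{g}}_m^{t}\Big)
$$

The last two expressions correspond to the benign and Byzantine aggregation steps, respectively.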

Key Definitions

  1. Reference direction: The reference direction, denoted $r^t$, is meant to offer a practical and sensible direction for modifying the local gradients:

[Equation]

where

[Equation]

Here, $v_m^t$ is the modified model difference for agent $m$ at round $t$ (defined later). The reference direction takes a bootstrapping form and is a weighted sum of all the historical global update directions.

  2. Degree of divergence: This is the fundamental concept of the DRAG algorithm, measuring the extent to which the local update direction diverges from the reference direction:

[Equation]

where the constant $c$ helps provide more flexibility.

  3. Vector manipulation: Using the reference direction and the degree of divergence, we "drag" each local model difference toward the reference direction via vector manipulation:

[Equation]

Note that we normalize the reference direction such that the modified difference consistently has a greater component aligned with $r^t$ compared to $g_m^t$.

[Figure: illustration of the vector manipulation.]
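
The sketch below gives one plausible instantiation of the three definitions above, consistent with their verbal descriptions (a bootstrapped reference direction, an angle-based divergence, and a convex-combination "drag"); the exact weights and functional forms used in the paper may differ.

$$
\mathbf{r}^{t}=\sum_{s<t}\alpha_s\,\bar{\mathbf{v}}^{\,s},
\qquad
\bar{\mathbf{v}}^{\,s}=\frac{1}{|\mathcal{M}^s|}\sum_{m\in\mathcal{M}^s}\mathbf{v}_m^{s}
$$

$$
d_m^{t}=c\left(1-\frac{\langle \mathbf{g}_m^{t},\mathbf{r}^{t}\rangle}{\|\mathbf{g}_m^{t}\|\,\|\mathbf{r}^{t}\|}\right),
\qquad
\mathbf{v}_m^{t}=(1-\lambda_m^{t})\,\mathbf{g}_m^{t}+\lambda_m^{t}\,\|\mathbf{g}_m^{t}\|\,\frac{\mathbf{r}^{t}}{\|\mathbf{r}^{t}\|},
\quad \lambda_m^{t}\in[0,1]\ \text{increasing in}\ d_m^{t}
$$

Here $\bar{\mathbf{v}}^{\,s}$ is the aggregated modified difference (global update direction) at round $s$, so $\mathbf{r}^t$ is a weighted sum of historical global updates; $d_m^t$ grows as the local update diverges from $\mathbf{r}^t$, and $\mathbf{v}_m^t$ is pulled toward the normalized reference direction accordingly.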

Algorithm Description

  1. Server broadcasts the global parameter to a selected set of agents at round $t$.
  2. Each agent performs $U$ local updates and sends the model difference to the server.
  3. Server calculates the reference direction and the degree of divergence.
  4. Server modifies the model difference and aggregates the modified differences.
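
For concreteness, the following NumPy sketch walks through one such server round. The cosine-based divergence and the convex-combination "drag" used here are plausible stand-ins consistent with the descriptions above, not the repository's exact implementation, and all names are ours.

```python
import numpy as np

def drag_server_round(global_model, local_diffs, ref_dir, c=1.0, lr=1.0):
    """One DRAG-style server aggregation round (sketch).

    global_model : current global parameter vector x^t
    local_diffs  : list of model differences g_m^t uploaded by the agents
    ref_dir      : reference direction r^t (e.g., a weighted sum of past global updates)
    """
    ref_unit = ref_dir / (np.linalg.norm(ref_dir) + 1e-12)      # normalized reference direction
    modified = []
    for g in local_diffs:
        g_norm = np.linalg.norm(g) + 1e-12
        cos_sim = float(np.dot(g, ref_dir)) / (g_norm * np.linalg.norm(ref_dir) + 1e-12)
        d = c * (1.0 - cos_sim)                   # degree of divergence (assumed cosine-based form)
        lam = min(1.0, max(0.0, d / (2.0 * c)))   # map divergence to a dragging weight in [0, 1]
        v = (1.0 - lam) * g + lam * g_norm * ref_unit   # "drag" g_m^t toward r^t
        modified.append(v)
    agg = np.mean(modified, axis=0)               # aggregate the modified differences
    new_global = global_model - lr * agg          # apply the global update
    return new_global, agg                        # agg can feed the next round's reference direction
```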

Defending Against Byzantine Attacks

To deal with Byzantine attacks, we make adaptations to the DRAG algorithm. Since malicious attacks can undermine the effectiveness of the reference direction, we redefine the reference direction in this setting.

Reference direction for Byzantine attacks: The server maintains a small root dataset $D_{root}$. At each round $t$, the server also updates a copy of the current global model for $U$ iterations using the root dataset:

[Equation]

The reference direction is then defined as

[Equation]

We also need to update the vector manipulation procedure in this case, since the malicious agents might scale the magnitude of the local model differences.

Vector manipulation for Byzantine attacks: The modified model difference is defined as

[Equation]
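
One plausible way to write these two modified definitions, consistent with the description above (the server's own model difference on $D_{root}$ serves as the reference, and the magnitude of each modified difference is tied to $\|r^t\|$ so that gradient-scaling attacks are neutralized), is sketched below; the paper's exact expressions may differ.

$$
\mathbf{x}_0^{t,0}=\mathbf{x}^{t},\qquad
\mathbf{x}_0^{t,u+1}=\mathbf{x}_0^{t,u}-\eta\,\nabla f_{D_{root}}\big(\mathbf{x}_0^{t,u}\big),\qquad
\mathbf{r}^{t}=\mathbf{x}^{t}-\mathbf{x}_0^{t,U}
$$

$$
\mathbf{v}_m^{t}=(1-\lambda_m^{t})\,\|\mathbf{r}^{t}\|\,\frac{\mathbf{g}_m^{t}}{\|\mathbf{g}_m^{t}\|}+\lambda_m^{t}\,\mathbf{r}^{t}
$$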

Convergence Analysis

Assuming that each local loss function is smooth and lower-bounded, and that the local gradient estimator is unbiased with bounded variance, we show that DRAG converges to a first-order stationary point (FOSP) at a sublinear rate:

[Equation]
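
An illustrative sublinear FOSP bound of the kind typically obtained under these assumptions (an example form only, not necessarily the paper's exact statement) is:

$$
\frac{1}{T}\sum_{t=0}^{T-1}\mathbb{E}\big\|\nabla f(\mathbf{x}^{t})\big\|^{2}\;\le\;\mathcal{O}\!\left(\frac{1}{\sqrt{T}}\right)
$$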

Numerical Results

Client-Drift Mitigation

We compare our DRAG algorithm with advanced algorithms on the CIFAR-10 dataset:

[Figure: comparison results on CIFAR-10.]

Byzantine Attack

We compare DRAG with other algorithms under Byzantine attacks with high data heterogeneity:

[Figure: comparison results under Byzantine attacks.]

Codes

  • The file DRAG_cifar.py compares DRAG with other algorithms on the CIFAR-10 dataset.
  • The file DRAG_byzantine.py compares DRAG with other algorithms on the CIFAR-10 dataset in the presence of malicious agents.

About

In this work, we develop a novel algorithm named divergence-based adaptive aggregation (DRAG) to mitigate the client-drift effect. Additionally, DRAG exhibits resilience against Byzantine attacks, as demonstrated through experiments.
