- This paper revolves around the topic of mitigating the client-drift effect in federated learning (FL), which is a joint consequence of
- Heterogeneous data distributions across agents/clients/workers.
- Intermittent communication between the central server and the agents.
- In light of this, we develop a novel algorithm called divergence-based adaptive aggregation (DRAG) which
- Hinges on the metric introduced as "degree of divergence" that quantifies the angle between the local gradient direction of each agent and the reference direction (the one we desire to move toward).
- Dynamically "drags" the received local updates toward the reference direction in each round without extra communication overhead.
- Rigorous convergence analysis is provided for DRAG, proving a sublinear convergence rate.
- Experiments demonstrate the superior performance of DRAG compared to advanced algorithms.
- Additionally, DRAG exhibits remarkable resilience towards byzantine attack thanks to its "dragging" nature. Numerical results are provided to showcase this property.
- In the context of FL, the classical problem of interest is to minimize the average/sum of the local loss functions across agents/clients. With a predominant probability, different agents have distinct local datasets generated from different distributions, thus having distinct local loss functions, which is where the "heterogeneity" comes from.
- On a different note, since communication overhead is a bottleneck in FL, agents cannot afford to talk to the server each time they have finished computing a gradient. Ergo, the famous federated averaging (FedAvg) algorithm requires that each agent compute multiple local updates on its local model before communicating with the server to upload its latest model. This is known as the universal "intermittent communication" mechanism in FL.
This figure illustrates the client-drift effect (source: Karimireddy et al., 2021).
However, if these two features are combined together, the "client-drift" effect would then arise since each agent inevitably updates its local model towards its local optimum. Furthermore, simply adding these updated models up and averaging them will not necessarily perfectly cancels the heteregeneity out, but will impede the convergence rate, as both theoretically and experimentally demonstrated by previous work.
As FL operates in a distributed manner, it is susceptible to adversarial attacks launched by malicious clients, commonly known as Byzantine attacks. Classical Byzantine attacks include
- Reversing the direction of the local gradient.
- Scaling the local gradient by a factor to negatively influence the training process.
This figure illustrates the Byzantine attack.
- Problem of interest:
- Local update formula of each agent:
- Model difference uploaded by each agent
$m$ at round$t$ :
- Server aggregation:
- Server aggregation (w/ Byzantine attacks):
Here,
-
Reference direction: The objective of the reference direction is to offer a practical and sensible direction for modifying the local gradients, denoted as
$r^t$
where
Here,
- Degree of divergence: This is the fundamental concept of the DRAG algorithm, measuring the extent to which the local update direction diverges from the reference direction:
where the constant
- Vector manipulation: Using the reference direction and degree of divergence. we "drag" each local model difference towards the reference function via vector manipulation:
Note that we normalize the reference direction such that the modified difference consistently has a greater component aligned with
- Server broadcasts the global parameter to a selected set of agents at round
$t$ . - Each agent performs
$U$ local updates and sends the model difference to the server. - Server calculates the reference direction and the degree of divergence.
- Server modifies the model difference and aggregates the modified differences.
To deal with Byzantine attacks, we need to make adaptations to the DRAG algorithm. Since the malicious attacks can undermine the effectiveness of the reference direction, we update the definition of reference direction in this setup.
Reference direction for Byzantine attacks: The server maintains a small root dataset
The reference direction is then defined as
We also need to update the vector manipulation procedure in this case, since the malicious agents might scale the module of the local model differences.
Vector manipulation for Byzantine attacks: The modified model difference is defined as
Assuming each local loss function is smooth and lower-bounded, and assuming that the local gradient estimator is unbiased and has bounded variance, we show that DRAG converges to the first-order stationary point (FOSP) at a sublinear rate:
We compare our DRAG algorithm with advanced algorithms on the CIFAR-10 dataset:
We compare DRAG with other algorithms under Byzantine attacks with high data heterogeneity:
- The file
DRAG_cifar.py
conducts comparisons of ACAG with other algorithms on the CIFAR-10 dataset. - The file
DRAG_byzantine.py
compares ACAG with other algorithms on the CIFAR-10 dataset, accounting for the malicious agents.