
AAAI 2024 accepted paper, FedTGP: Trainable Global Prototypes with Adaptive-Margin-Enhanced Contrastive Learning for Data and Model Heterogeneity in Federated Learning


TsingZ0/FedTGP


Introduction

This is the implementation of our paper FedTGP: Trainable Global Prototypes with Adaptive-Margin-Enhanced Contrastive Learning for Data and Model Heterogeneity in Federated Learning (accepted by AAAI 2024).

Keywords: federated learning, data heterogeneity, model heterogeneity, communication overhead, intellectual property (IP) protection

Takeaway: We enhance FedProto, a typical heterogeneous federated learning (HtFL) method, with Trainable Global Prototypes (TGP) and Adaptive-margin-enhanced Contrastive Learning (ACL), making it more versatile and resilient to various model heterogeneities.

Citation

@inproceedings{zhang2024fedtgp,
  title={FedTGP: Trainable Global Prototypes with Adaptive-Margin-Enhanced Contrastive Learning for Data and Model Heterogeneity in Federated Learning},
  author={Zhang, Jianqing and Liu, Yang and Hua, Yang and Cao, Jian},
  booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
  year={2024}
}

The global and client prototypes in FedProto and our FedTGP. Different colors and numbers represent classes and clients, respectively. Circles represent the client prototypes and triangles represent the global prototypes. The black and yellow dotted arrows show the inter-class separation among the client and global prototypes, respectively. Triangles with dotted borders represent our Trainable Global Prototypes (TGP). The red arrows show the inter-class intervals between TGP and the client prototypes of other classes in our Adaptive-margin-enhanced Contrastive Learning (ACL).
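The margin idea in the caption above can be illustrated with a small, dependency-free sketch. It is not the repository's implementation: the function names `adaptive_margin` and `margin_contrastive_loss`, the use of Euclidean distance, and the specific margin rule (largest nearest-neighbor gap between class centers, capped by a threshold) are assumptions made for illustration. The loss is a cross-entropy over negative distances in which the margin is added to the distance between a client prototype and the global prototype of its own class, so the global prototype has to move closer to its own class and farther from the others to compensate.

```python
import math

def adaptive_margin(class_centers, cap):
    """Hypothetical margin rule: for each class center, find the distance to
    its nearest other class center, take the largest such gap, and cap it."""
    nearest = []
    for i, ci in enumerate(class_centers):
        nearest.append(min(math.dist(ci, cj)
                           for j, cj in enumerate(class_centers) if j != i))
    return min(max(nearest), cap)

def margin_contrastive_loss(client_proto, label, global_protos, margin):
    """Cross-entropy over negative distances to every global prototype.
    The margin is added only to the distance for the true class."""
    logits = []
    for c, g in enumerate(global_protos):
        d = math.dist(client_proto, g)
        if c == label:
            d += margin  # makes the own-class prototype look farther away
        logits.append(-d)
    # Numerically stable log-sum-exp
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum - logits[label]

# Two classes in 2-D; the client prototype belongs to class 0.
global_protos = [[0.0, 0.0], [3.0, 4.0]]
margin = adaptive_margin(global_protos, cap=2.0)
loss = margin_contrastive_loss([0.0, 0.0], 0, global_protos, margin)
```

With a zero margin the loss for a well-separated prototype is already close to zero; a positive margin raises it, which keeps pushing the classes apart during server-side training.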

Dataset

Due to file size limits, we only upload the statistics (config.json) of the Cifar10 dataset in the practical setting ($\beta=0.1$). Please refer to our repositories PFLlib and HtFLlib to generate all the datasets and create the required Python environment.
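For readers unfamiliar with the setting, the sketch below shows one common way to build a label-skewed split, assuming the practical setting draws each class's per-client shares from a Dirichlet distribution with concentration $\beta$. The function names `dirichlet_sample` and `partition_by_label` are hypothetical, and the sketch omits details such as a minimum sample count per client; use PFLlib or HtFLlib to generate the actual datasets.

```python
import random

def dirichlet_sample(beta, k, rng):
    """Draw one sample from a symmetric Dirichlet(beta) over k entries
    by normalizing independent Gamma(beta, 1) draws."""
    draws = [rng.gammavariate(beta, 1.0) for _ in range(k)]
    total = sum(draws)
    if total == 0.0:  # guard against underflow for very small beta
        return [1.0 / k] * k
    return [d / total for d in draws]

def partition_by_label(labels, num_clients, beta, seed=0):
    """Split sample indices across clients, class by class, using
    Dirichlet-distributed shares. Smaller beta gives a more skewed split."""
    rng = random.Random(seed)
    by_class = {}
    for idx, y in enumerate(labels):
        by_class.setdefault(y, []).append(idx)
    clients = [[] for _ in range(num_clients)]
    for y in sorted(by_class):
        idxs = by_class[y]
        rng.shuffle(idxs)
        shares = dirichlet_sample(beta, num_clients, rng)
        start, cum = 0, 0.0
        for c in range(num_clients):
            cum += shares[c]
            # The last client takes whatever remains, so nothing is dropped.
            end = len(idxs) if c == num_clients - 1 else min(round(cum * len(idxs)), len(idxs))
            clients[c].extend(idxs[start:end])
            start = end
    return clients

# Toy labels: 1000 samples over 10 classes, split across 20 clients.
labels = [i % 10 for i in range(1000)]
parts = partition_by_label(labels, num_clients=20, beta=0.1, seed=1)
```

With $\beta=0.1$ most of a class's samples typically land on only a few clients, which is the kind of data heterogeneity the method targets.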

System

Learning reasonable global prototypes can be challenging in some cases, particularly because the number of client prototypes is limited and ACL introduces an adaptive margin. To address this, consider setting a larger top_cnt and running more than 1000 global communication iterations, which should bring the Server loss below 0.001. The best accuracy is typically achieved when the Server loss is minimal. In most of our experiments, the Server loss reached 0.0.
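To see why a larger top_cnt gives training more room, here is a minimal sketch that assumes top_cnt acts as a patience window: training stops once the best evaluated accuracy is more than top_cnt evaluations old. This reading and the helper name `should_stop` are assumptions for illustration; check main.py and the server code for the actual behavior.

```python
def should_stop(acc_history, top_cnt):
    """Return True when the best accuracy so far was recorded more than
    `top_cnt` evaluations ago (hypothetical patience-style rule)."""
    if not acc_history:
        return False
    # Index of the first occurrence of the best accuracy
    best_idx = max(range(len(acc_history)), key=acc_history.__getitem__)
    return len(acc_history) - best_idx > top_cnt

history = [0.10, 0.50, 0.40, 0.30]
stop_small = should_stop(history, top_cnt=2)  # best value is 3 evaluations old
stop_large = should_stop(history, top_cnt=5)  # still within the window
```

Under this assumption, a small top_cnt can end a run before the Server loss has had time to fall, while a larger one lets training continue through temporary accuracy plateaus.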

  • main.py: system configurations.
  • run_me.sh: command lines to run experiments. It is advisable to retune hyperparameters on new tasks.
  • flcore/:
    • clients/: the code on clients. See HtFL for baselines.
    • servers/: the code on servers. See HtFL for baselines.
    • trainmodel/: the code for some heterogeneous client models.
  • utils/:
    • data_utils.py: the code to read the dataset.
    • mem_utils.py: the code to record memory usage.
    • result_utils.py: the code to save results to files.
