Enhancing BERT training with Semi-supervised Generative Adversarial Networks in PyTorch/HuggingFace

crux82/ganbert-pytorch

GAN-BERT (in PyTorch and compatible with HuggingFace)

This is a PyTorch (and HuggingFace) implementation of the GAN-BERT method from https://github.com/crux82/ganbert, which is available in TensorFlow. While the original GAN-BERT was an extension of BERT, this implementation can be adapted to several architectures, ranging from RoBERTa to ALBERT.

IMPORTANT: Since this implementation differs slightly from the original TensorFlow one, some results may vary. Any feedback or suggestions for improving this first version would be appreciated.

GANBERT

This is the code for the paper "GAN-BERT: Generative Adversarial Learning for Robust Text Classification with a Bunch of Labeled Examples", published as a short paper at ACL 2020 by Danilo Croce (Tor Vergata, University of Rome), Giuseppe Castellucci (Amazon) and Roberto Basili (Tor Vergata, University of Rome).

GAN-BERT is an extension of BERT that uses a Generative Adversarial setting to implement an effective semi-supervised learning scheme. It allows BERT to be trained on datasets composed of a limited number of labeled examples and larger sets of unlabeled material. GAN-BERT can be used in sequence classification tasks (including tasks involving text pairs).

As in the original TensorFlow implementation, this code runs the GAN-BERT experiment on the TREC dataset for the fine-grained Question Classification task. This package provides both the code and the data to run an experiment that uses 2% of the labeled material (109 examples) together with 5,343 unlabeled examples. The test set consists of 500 annotated examples.
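Because labeled and unlabeled examples are used together during training, one convenient way to batch them is to merge both sets into a single list and keep a boolean mask marking which examples carry a gold label. The sketch below assumes that representation; the `build_mixed_dataset` helper, the `"UNK"` placeholder label and the toy questions are hypothetical and not taken from this repository.

```python
from typing import List, Tuple

UNLABELED = "UNK"  # hypothetical placeholder label for unlabeled examples


def build_mixed_dataset(
    labeled: List[Tuple[str, str]],
    unlabeled: List[str],
) -> List[Tuple[str, str, bool]]:
    """Merge labeled (text, label) pairs and unlabeled texts into one list.

    Each item is (text, label, label_mask); label_mask is True only for
    examples whose label may be used by the supervised loss.
    """
    mixed = [(text, label, True) for text, label in labeled]
    mixed += [(text, UNLABELED, False) for text in unlabeled]
    return mixed


labeled = [("What is the capital of Italy ?", "LOC:city")]
unlabeled = ["Who wrote Hamlet ?", "How far is the Moon ?"]
data = build_mixed_dataset(labeled, unlabeled)
n_labeled = sum(mask for _, _, mask in data)  # True counts as 1
print(len(data), n_labeled)  # prints "3 1"
```

With the TREC split described above, such a mask would mark 109 of the 5,452 training examples as labeled.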

The Model

GAN-BERT is an extension of the BERT model within the Generative Adversarial Network (GAN) framework (Goodfellow et al., 2014). In particular, the Semi-Supervised GAN (Salimans et al., 2016) is used to make BERT fine-tuning robust in training scenarios where annotated material is scarce or hard to obtain. When fine-tuned on very few labeled examples, BERT does not achieve satisfactory performance. GAN-BERT extends the fine-tuning stage by introducing a Discriminator-Generator setting, where:

  • the Generator G is devoted to producing "fake" vector representations of sentences;
  • the Discriminator D is a BERT-based classifier over k+1 categories.
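The two components can be sketched as small feed-forward networks operating on sentence vectors. This is a minimal sketch under assumed sizes (a 100-dimensional noise vector, a 768-dimensional sentence representation as in BERT-base, 50 task classes as in fine-grained TREC, one hidden layer each); the layer choices are illustrative and not necessarily those used in this repository.

```python
import torch
import torch.nn as nn


class Generator(nn.Module):
    """Maps random noise to a fake sentence representation."""

    def __init__(self, noise_size: int = 100, hidden_size: int = 768, dropout: float = 0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(noise_size, hidden_size),
            nn.LeakyReLU(0.2),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, hidden_size),  # output lives in the encoder's space
        )

    def forward(self, noise: torch.Tensor) -> torch.Tensor:
        return self.net(noise)


class Discriminator(nn.Module):
    """Classifies a sentence representation into k real classes plus 1 fake class."""

    def __init__(self, hidden_size: int = 768, num_labels: int = 50, dropout: float = 0.1):
        super().__init__()
        self.body = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(hidden_size, hidden_size),
            nn.LeakyReLU(0.2),
            nn.Dropout(dropout),
        )
        self.head = nn.Linear(hidden_size, num_labels + 1)  # last logit = "fake"

    def forward(self, rep: torch.Tensor):
        features = self.body(rep)  # intermediate features, usable for feature matching
        logits = self.head(features)
        probs = torch.softmax(logits, dim=-1)
        return features, logits, probs


G, D = Generator(), Discriminator(num_labels=50)
fake = G(torch.rand(4, 100))
features, logits, probs = D(fake)
print(fake.shape, logits.shape)  # torch.Size([4, 768]) torch.Size([4, 51])
```

In GAN-BERT the real representations given to D come from the BERT-based encoder (for example, its [CLS] vector), while the fake ones come from G, so both must share the same dimensionality.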

Figure: GAN-BERT model

D classifies each example into one of the k categories of the task of interest, and it should also recognize the examples generated by G (the (k+1)-th category). G, instead, must produce representations that are as similar as possible to those the model produces for "real" examples. G is penalized when D correctly classifies an example as fake.
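In the Semi-Supervised GAN formulation, G's adversarial loss grows when D assigns a high probability to the fake class on generated examples. The hypothetical `generator_loss` helper below assumes D returns softmax probabilities whose last column is the fake class; it also adds a feature-matching term (Salimans et al., 2016) that pulls the mean of D's intermediate features on fake examples towards their mean on real ones. Treat this exact combination of terms as an assumption rather than this repository's code.

```python
import torch


def generator_loss(
    fake_probs: torch.Tensor,     # (batch, k+1): D's probabilities on generated examples
    real_features: torch.Tensor,  # (batch, h): D's intermediate features on real examples
    fake_features: torch.Tensor,  # (batch, h): D's intermediate features on fake examples
    eps: float = 1e-8,
) -> torch.Tensor:
    # Adversarial term: large when D puts mass on the fake class (last column).
    adversarial = -torch.mean(torch.log(1.0 - fake_probs[:, -1] + eps))
    # Feature matching: squared gap between mean real and mean fake features.
    feature_matching = torch.mean(
        (real_features.mean(dim=0) - fake_features.mean(dim=0)) ** 2
    )
    return adversarial + feature_matching


feats = torch.ones(1, 4)  # identical real and fake features: no feature-matching penalty
fooled = generator_loss(torch.tensor([[0.5, 0.5, 0.0]]), feats, feats)
caught = generator_loss(torch.tensor([[0.05, 0.05, 0.9]]), feats, feats)
print(fooled.item() < caught.item())  # True: G pays more when D spots the fake
```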

In this setting, the model is trained on both labeled and unlabeled examples. Labeled examples contribute to the supervised loss over the k task categories. Unlabeled examples contribute to the unsupervised loss, since they should not be classified as belonging to the (k+1)-th category (i.e., the fake category).
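A discriminator objective matching this description can be sketched as follows. The hypothetical `discriminator_loss` helper assumes a boolean `label_mask` marking which real examples have a gold label, and D outputs whose last column is the fake class. The supervised term is a cross-entropy over the k task categories computed on labeled examples only; the two unsupervised terms push real (labeled or unlabeled) examples away from the fake class and generated examples towards it.

```python
import torch
import torch.nn.functional as F


def discriminator_loss(
    real_logits: torch.Tensor,  # (batch, k+1): D's logits on real examples
    real_probs: torch.Tensor,   # (batch, k+1): D's probabilities on real examples
    fake_probs: torch.Tensor,   # (batch, k+1): D's probabilities on generated examples
    labels: torch.Tensor,       # (batch,): class ids in [0, k); ignored where unlabeled
    label_mask: torch.Tensor,   # (batch,): True for labeled examples
    eps: float = 1e-8,
) -> torch.Tensor:
    # Supervised term: cross-entropy over the k task classes, labeled examples only.
    if label_mask.any():
        task_logits = real_logits[label_mask, :-1]  # drop the fake-class column
        supervised = F.cross_entropy(task_logits, labels[label_mask])
    else:
        supervised = real_logits.new_zeros(())  # batch without labeled examples
    # Unsupervised terms: real examples are not fake, generated examples are fake.
    unsup_real = -torch.mean(torch.log(1.0 - real_probs[:, -1] + eps))
    unsup_fake = -torch.mean(torch.log(fake_probs[:, -1] + eps))
    return supervised + unsup_real + unsup_fake


torch.manual_seed(0)
k = 3
real_logits = torch.randn(4, k + 1)
real_probs = torch.softmax(real_logits, dim=-1)
fake_probs = torch.softmax(torch.randn(4, k + 1), dim=-1)
labels = torch.tensor([2, 0, 0, 0])  # only the first entry is a real label
label_mask = torch.tensor([True, False, False, False])
loss = discriminator_loss(real_logits, real_probs, fake_probs, labels, label_mask)
print(loss.item() > 0)  # True
```

Masking inside the loss, rather than filtering the batch, lets labeled and unlabeled examples share the same forward pass through the encoder.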

The resulting model has been shown to learn text classification tasks from very few labeled examples (50-60) and to outperform classically fine-tuned BERT models by a large margin in this setting.

More details are available at https://github.com/crux82/ganbert

Citation

If this software is useful for your research, please cite the following paper:

@inproceedings{croce-etal-2020-gan,
    title = "{GAN}-{BERT}: Generative Adversarial Learning for Robust Text Classification with a Bunch of Labeled Examples",
    author = "Croce, Danilo  and
      Castellucci, Giuseppe  and
      Basili, Roberto",
    booktitle = "Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics",
    month = jul,
    year = "2020",
    address = "Online",
    publisher = "Association for Computational Linguistics",
    url = "https://www.aclweb.org/anthology/2020.acl-main.191",
    pages = "2114--2119"
}

Acknowledgments

We would like to thank Osman Mutlu and Ali Hürriyetoğlu for their PyTorch implementation of GAN-BERT, which inspired our port. You can find their initial repository at this link. We would also like to thank Claudia Breazzano (Tor Vergata, University of Rome), who supported this port.
