BertGCN

This repo contains code for BertGCN: Transductive Text Classification by Combining GCN and BERT.

Introduction

In this work, we propose BertGCN, a model that combines large-scale pretraining and transductive learning for text classification. BertGCN constructs a heterogeneous graph over the dataset and represents documents as nodes using BERT representations. By jointly training the BERT and GCN modules within BertGCN, the proposed model leverages the best of both worlds: large-scale pretraining, which takes advantage of massive amounts of raw data, and transductive learning, which jointly learns representations for both training data and unlabeled test data by propagating label influence through graph convolution. Experiments show that BertGCN achieves state-of-the-art performance on a wide range of text classification datasets.
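To make "propagating label influence through graph convolution" concrete, here is a minimal sketch of the standard GCN propagation step on a toy graph. It is an illustration only, not the repository's code: the three-node graph, the unit edge weights, the one-dimensional features, and the helper names `normalize_adjacency` and `propagate` are all assumptions made for this example. In BertGCN the document node features come from BERT and the graph is produced by build_graph.py.

```python
import math


def normalize_adjacency(adj):
    """Return D^{-1/2} (A + I) D^{-1/2} for a dense adjacency matrix (nested lists)."""
    n = len(adj)
    # Add self-loops so every node keeps part of its own representation.
    a_hat = [[adj[i][j] + (1.0 if i == j else 0.0) for j in range(n)] for i in range(n)]
    inv_sqrt_deg = [1.0 / math.sqrt(sum(row)) for row in a_hat]
    return [
        [inv_sqrt_deg[i] * a_hat[i][j] * inv_sqrt_deg[j] for j in range(n)]
        for i in range(n)
    ]


def propagate(adj_norm, features):
    """One propagation step: multiply the normalized adjacency by the feature matrix."""
    n = len(adj_norm)
    dim = len(features[0])
    return [
        [sum(adj_norm[i][k] * features[k][d] for k in range(n)) for d in range(dim)]
        for i in range(n)
    ]


# Toy heterogeneous graph (hypothetical): doc0 -- word -- doc1.
# Node order: 0 = training document, 1 = shared word, 2 = test document.
adj = [
    [0.0, 1.0, 0.0],
    [1.0, 0.0, 1.0],
    [0.0, 1.0, 0.0],
]
adj_norm = normalize_adjacency(adj)

# Only the training document carries a signal at the start.
features = [[1.0], [0.0], [0.0]]

one_step = propagate(adj_norm, features)
two_steps = propagate(adj_norm, one_step)

print("after 1 step:", one_step)
print("after 2 steps:", two_steps)
```

After one step the signal has reached only the shared word node; after two steps it reaches the test document as well. This is why, in this toy setting, two rounds of propagation are needed for a training document to influence a test document that it is linked to only through a word node.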

Main Results

Model        20NG   R8     R52    Ohsumed   MR
TextGCN      86.3   97.1   93.6   68.4      76.7
SGC          88.5   97.2   94.0   68.5      75.9
BERT         85.3   97.8   96.4   70.5      85.7
RoBERTa      83.8   97.8   96.2   70.7      89.4
BertGCN      89.3   98.1   96.6   72.8      86.0
RoBERTaGCN   89.5   98.2   96.1   72.8      89.7
BertGAT      87.4   97.8   96.5   71.2      86.5
RoBERTaGAT   86.5   98.0   96.1   71.2      89.2

Dependencies

Create environment and install required packages for BertGCN using conda:

conda create --name BertGCN --file requirements.txt -c default -c pytorch -c dglteam -c huggingface

If your NVIDIA driver version does not support CUDA 10.1, you may edit requirements.txt to use an older cudatoolkit and the corresponding dgl version instead.

Usage

  1. Run python build_graph.py [dataset] to build the text graph.

  2. Run python finetune_bert.py --dataset [dataset] to fine-tune the BERT model on the target dataset. The model and training logs will be saved to checkpoint/[bert_init]_[dataset]/ by default. Run python finetune_bert.py -h to see the full list of hyperparameters.

  3. Run python train_bert_gcn.py --dataset [dataset] --pretrained_bert_ckpt [pretrained_bert_ckpt] -m [m] to train BertGCN. [m] is the factor that balances the BERT and GCN predictions (lambda in the paper). The model and training logs will be saved to checkpoint/[bert_init]_[gcn_model]_[dataset]/ by default. Run python train_bert_gcn.py -h to see the full list of hyperparameters.
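The role of the balancing factor [m] in step 3 can be sketched as a linear interpolation between the two modules' class probabilities. This is a minimal illustration, not the code in train_bert_gcn.py. It assumes that m weights the GCN prediction and (1 - m) weights the BERT prediction, and that both predictions are softmax probabilities; the function names and the example logits are hypothetical.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    shift = max(logits)
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def interpolate_predictions(gcn_logits, bert_logits, m):
    """Combine class probabilities as m * GCN + (1 - m) * BERT.

    Assumption: m multiplies the GCN side, so m = 1 uses only the GCN
    prediction and m = 0 uses only the BERT prediction.
    """
    if not 0.0 <= m <= 1.0:
        raise ValueError("m must lie in [0, 1]")
    gcn_probs = softmax(gcn_logits)
    bert_probs = softmax(bert_logits)
    return [m * g + (1.0 - m) * b for g, b in zip(gcn_probs, bert_probs)]


# Hypothetical logits for a two-class problem where the modules disagree.
gcn_logits = [2.0, 0.0]
bert_logits = [0.0, 2.0]

for m in (0.0, 0.5, 1.0):
    print(m, interpolate_predictions(gcn_logits, bert_logits, m))
```

Because each input is a probability distribution and the weights sum to one, the combined prediction is also a valid distribution for any m in [0, 1].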

Trained BertGCN parameters can be downloaded here.
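If you train from scratch instead, the three usage steps can be chained from a small driver script. The sketch below only assembles the commands listed above and prints them by default. The function names, the `dry_run` switch, and the placeholder arguments are assumptions for illustration; substitute a real dataset name, the checkpoint file produced by step 2, and your chosen value of m.

```python
import subprocess
import sys


def build_commands(dataset, pretrained_bert_ckpt, m):
    """Return the three usage steps as argument lists, in execution order."""
    return [
        [sys.executable, "build_graph.py", dataset],
        [sys.executable, "finetune_bert.py", "--dataset", dataset],
        [
            sys.executable, "train_bert_gcn.py",
            "--dataset", dataset,
            "--pretrained_bert_ckpt", pretrained_bert_ckpt,
            "-m", str(m),
        ],
    ]


def run_pipeline(dataset, pretrained_bert_ckpt, m, dry_run=True):
    """Print each command; execute it only when dry_run is False."""
    for cmd in build_commands(dataset, pretrained_bert_ckpt, m):
        print(" ".join(cmd))
        if not dry_run:
            # check=True stops the pipeline as soon as one step fails.
            subprocess.run(cmd, check=True)


# Placeholders only: replace with real values before setting dry_run=False.
run_pipeline("[dataset]", "[pretrained_bert_ckpt]", 0.7)
```

Run it from the repository root so the script names resolve, and keep `dry_run=True` until the printed commands look right.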

Acknowledgement

The data preprocessing and graph construction code is from TextGCN.

Citation

To appear in Findings of ACL 2021

@article{lin2021bertgcn,
  title={BertGCN: Transductive Text Classification by Combining GCN and BERT},
  author={Lin, Yuxiao and Meng, Yuxian and Sun, Xiaofei and Han, Qinghong and Kuang, Kun and Li, Jiwei and Wu, Fei},
  journal={arXiv preprint arXiv:2105.05727},
  year={2021}
}
