Code for Adversarial Training Methods for Semi-Supervised Text Classification and Semi-Supervised Sequence Learning.
- Bazel (install)
- TensorFlow >= v1.1
$ wget http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz \
-O /tmp/imdb.tar.gz
$ tar -xf /tmp/imdb.tar.gz -C /tmp
The directory /tmp/aclImdb contains the raw IMDB data.
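To sanity-check the extraction before generating any data, the short sketch below counts the review files under each split. It assumes the archive unpacks into train/ and test/ directories with one subdirectory per label (such as pos, neg, and unsup), each holding .txt files. That layout and the helper name count_reviews are assumptions for illustration, not part of this codebase.

```python
from collections import Counter
from pathlib import Path


def count_reviews(root):
    """Count .txt files per (split, label) directory under root.

    Assumes a layout like root/train/pos/*.txt (an assumption about the
    extracted archive, not something this repository defines).
    """
    counts = Counter()
    for split in ("train", "test"):
        split_dir = Path(root) / split
        if not split_dir.is_dir():
            continue  # Missing split: nothing to count.
        for label_dir in sorted(split_dir.iterdir()):
            if label_dir.is_dir():
                counts[(split, label_dir.name)] = sum(
                    1 for _ in label_dir.glob("*.txt"))
    return counts


if __name__ == "__main__":
    for (split, label), n in sorted(count_reviews("/tmp/aclImdb").items()):
        print(f"{split}/{label}: {n}")
```

If the script prints nothing, the archive was probably extracted to a different location than /tmp/aclImdb.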
$ IMDB_DATA_DIR=/tmp/imdb
$ bazel run data:gen_vocab -- \
--output_dir=$IMDB_DATA_DIR \
--dataset=imdb \
--imdb_input_dir=/tmp/aclImdb \
--lowercase=False
Vocabulary and frequency files will be generated in $IMDB_DATA_DIR.
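As a rough illustration of what a vocabulary with frequencies amounts to, the sketch below counts tokens across documents and orders them by frequency. It is not the logic of gen_vocab.py: whitespace tokenization, the function name build_vocab, and the min_count parameter are all assumptions made for this example. Only the lowercase option mirrors a flag shown in the command above.

```python
from collections import Counter


def build_vocab(documents, lowercase=False, min_count=1):
    """Return (token, count) pairs, most frequent first, ties by token.

    Hypothetical helper: tokenizes on whitespace, which may differ from
    what the real vocabulary generator does.
    """
    counts = Counter()
    for doc in documents:
        text = doc.lower() if lowercase else doc
        counts.update(text.split())
    # Drop rare tokens, then sort by descending count and ascending token.
    kept = [(tok, n) for tok, n in counts.items() if n >= min_count]
    return sorted(kept, key=lambda item: (-item[1], item[0]))


if __name__ == "__main__":
    docs = ["A good film", "a good good plot"]
    for token, count in build_vocab(docs, lowercase=False):
        print(token, count)
```

With lowercase=False, "A" and "a" are counted as separate tokens, which is why the casing choice affects the final vocabulary size.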
$ bazel run data:gen_data -- \
--output_dir=$IMDB_DATA_DIR \
--dataset=imdb \
--imdb_input_dir=/tmp/aclImdb \
--lowercase=False \
--label_gain=False
$IMDB_DATA_DIR contains TFRecords files.
$ PRETRAIN_DIR=/tmp/models/imdb_pretrain
$ bazel run :pretrain -- \
--train_dir=$PRETRAIN_DIR \
--data_dir=$IMDB_DATA_DIR \
--vocab_size=86934 \
--embedding_dims=256 \
--rnn_cell_size=1024 \
--num_candidate_samples=1024 \
--batch_size=256 \
--learning_rate=0.001 \
--learning_rate_decay_factor=0.9999 \
--max_steps=100000 \
--max_grad_norm=1.0 \
--num_timesteps=400 \
--keep_prob_emb=0.5 \
--normalize_embeddings
$PRETRAIN_DIR contains checkpoints of the pretrained language model.
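To get a feel for the learning_rate and learning_rate_decay_factor values in the pretraining command, the sketch below evaluates a plain exponential schedule. It assumes the decay factor is applied multiplicatively once per training step. This README does not state the schedule, so treat that as an assumption; the function name decayed_learning_rate is hypothetical.

```python
def decayed_learning_rate(initial_rate, decay_factor, step):
    """Exponential decay: initial_rate * decay_factor ** step.

    Assumes one multiplicative decay per step (an assumption, not
    confirmed by this README).
    """
    return initial_rate * decay_factor ** step


if __name__ == "__main__":
    # Values taken from the pretraining command above.
    for step in (0, 10000, 50000, 100000):
        print(step, decayed_learning_rate(0.001, 0.9999, step))
```

Under this assumption, a factor very close to 1 still shrinks the rate substantially over 100000 steps, which is one reason to adjust the factor together with max_steps.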
Most flags stay the same as for pretraining. The candidate sampling flag is
removed, and pretrained_model_dir is added: the classifier loads the pretrained
embedding and LSTM variables from that directory. Flags related to adversarial
training and classification are also added.
$ TRAIN_DIR=/tmp/models/imdb_classify
$ bazel run :train_classifier -- \
--train_dir=$TRAIN_DIR \
--pretrained_model_dir=$PRETRAIN_DIR \
--data_dir=$IMDB_DATA_DIR \
--vocab_size=86934 \
--embedding_dims=256 \
--rnn_cell_size=1024 \
--cl_num_layers=1 \
--cl_hidden_size=30 \
--batch_size=64 \
--learning_rate=0.0005 \
--learning_rate_decay_factor=0.9998 \
--max_steps=15000 \
--max_grad_norm=1.0 \
--num_timesteps=400 \
--keep_prob_emb=0.5 \
--normalize_embeddings \
--adv_training_method=vat \
--perturb_norm_length=5.0
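The perturb_norm_length flag in the command above suggests that the adversarial perturbation is rescaled to a fixed length. A minimal sketch of that idea, assuming the length is an L2 norm and using a plain Python list in place of an embedding tensor, is shown below. The function name scale_l2 and the eps stabilizer are hypothetical, and the actual code in adversarial_losses.py may normalize differently.

```python
import math


def scale_l2(vector, norm_length, eps=1e-12):
    """Rescale vector so its L2 norm is (approximately) norm_length.

    eps guards against division by zero for an all-zero vector.
    Illustrative only; not the repository's implementation.
    """
    norm = math.sqrt(sum(v * v for v in vector))
    return [norm_length * v / (norm + eps) for v in vector]


if __name__ == "__main__":
    gradient = [0.3, -0.4, 1.2]  # Stand-in for a loss gradient.
    perturbation = scale_l2(gradient, norm_length=5.0)
    print(perturbation)
    print(math.sqrt(sum(p * p for p in perturbation)))
```

The direction of the input is preserved and only its magnitude changes, so a larger perturb_norm_length means a stronger perturbation in the same direction.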
$ EVAL_DIR=/tmp/models/imdb_eval
$ bazel run :evaluate -- \
--eval_dir=$EVAL_DIR \
--checkpoint_dir=$TRAIN_DIR \
--eval_data=test \
--run_once \
--num_examples=25000 \
--data_dir=$IMDB_DATA_DIR \
--vocab_size=86934 \
--embedding_dims=256 \
--rnn_cell_size=1024 \
--batch_size=256 \
--num_timesteps=400 \
--normalize_embeddings
The main entry points are the binaries listed below. Each training binary builds
a VatxtModel, defined in graphs.py, which in turn uses graph building blocks
defined in inputs.py (defines input data reading and parsing), layers.py
(defines core model components), and adversarial_losses.py (defines
adversarial training losses). The training loop itself is defined in
train_utils.py.
- Pretraining: pretrain.py
- Classifier Training: train_classifier.py
- Evaluation: evaluate.py
Flags related to distributed training and the training loop itself are defined
in train_utils.py.
Flags related to model hyperparameters are defined in graphs.py.
Flags related to adversarial training are defined in adversarial_losses.py.
Flags particular to each job are defined in the main binary files.
- Vocabulary generation: gen_vocab.py
- Data generation: gen_data.py
Command-line flags defined in document_generators.py control which dataset is
processed and how.
- Ryan Sepassi, @rsepassi