This repository has been archived by the owner on Jan 15, 2024. It is now read-only.

Commit

[BUGFIX] update paths and imports in bert scripts (#634)
* update paths and imports in bert scripts

* add commit ref for bert

* add option for comparing mxnet parameter file
szha authored Mar 18, 2019
1 parent 25ce9d7 commit 386a45f
Showing 2 changed files with 31 additions and 20 deletions.
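
The commit also adds a --gluon_parameter_file option for comparing a locally converted MXNet parameter file against the original TensorFlow checkpoint. As a rough, hypothetical illustration (every path and the working directory below are placeholders, not taken from the commit), the updated script could be driven like this:

import subprocess

# Hypothetical invocation of the updated comparison script; all paths are
# placeholders. Flags other than --gluon_parameter_file already existed.
subprocess.run(
    ['python', 'compare_tf_gluon_model.py',
     '--input_file', 'input.txt',
     '--tf_bert_repo_dir', '~/bert/',                 # expanded by the script
     '--tf_model_dir', '~/uncased_L-12_H-768_A-12/',  # expanded by the script
     '--gluon_dataset', 'book_corpus_wiki_en_uncased',
     '--gluon_model', 'bert_12_768_12',
     # New in this commit: load a converted parameter file instead of the
     # published pretrained weights.
     '--gluon_parameter_file', '/path/to/converted.params'],
    check=True, cwd='scripts/bert')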
38 changes: 24 additions & 14 deletions scripts/bert/compare_tf_gluon_model.py
@@ -26,17 +26,16 @@
import numpy as np
import mxnet as mx
import gluonnlp as nlp
from gluonnlp.data import TSVDataset
from gluonnlp.data import BERTTokenizer
from gluon.data import BERTSentenceTransform

parser = argparse.ArgumentParser(description='Comparison script for BERT model in Tensorflow'
'and that in Gluon')
'and that in Gluon. This script works with '
'google/bert@f39e881b')
parser.add_argument('--input_file', type=str, default='input.txt',
help='sample input file for testing. Default is input.txt')
parser.add_argument('--tf_bert_repo_dir', type=str,
default='~/bert/',
help='path to the original Tensorflow bert repository. '
'The repo should be at f39e881b. '
'Default is ~/bert/')
parser.add_argument('--tf_model_dir', type=str,
default='~/uncased_L-12_H-768_A-12/',
@@ -48,15 +47,17 @@
help='gluon dataset name. Default is book_corpus_wiki_en_uncased')
parser.add_argument('--gluon_model', type=str, default='bert_12_768_12',
help='gluon model name. Default is bert_12_768_12')
parser.add_argument('--gluon_parameter_file', type=str, default=None,
help='gluon parameter file name.')

args = parser.parse_args()

input_file = os.path.expanduser(args.input_file)
tf_bert_repo_dir = os.path.expanduser(args.tf_bert_repo_dir)
tf_model_dir = os.path.expanduser(args.tf_model_dir)
vocab_file = tf_model_dir + 'vocab.txt'
bert_config_file = tf_model_dir + 'bert_config.json'
init_checkpoint = tf_model_dir + 'bert_model.ckpt'
vocab_file = os.path.join(tf_model_dir, 'vocab.txt')
bert_config_file = os.path.join(tf_model_dir, 'bert_config.json')
init_checkpoint = os.path.join(tf_model_dir, 'bert_model.ckpt')
do_lower_case = not args.cased
max_length = 128

@@ -130,13 +131,24 @@

bert, vocabulary = nlp.model.get_model(args.gluon_model,
dataset_name=args.gluon_dataset,
pretrained=True, use_pooler=False,
use_decoder=False, use_classifier=False)
pretrained=not args.gluon_parameter_file,
use_pooler=False,
use_decoder=False,
use_classifier=False)
if args.gluon_parameter_file:
try:
bert.cast('float16')
bert.load_parameters(args.gluon_parameter_file, ignore_extra=True)
bert.cast('float32')
except AssertionError:
bert.cast('float32')
bert.load_parameters(args.gluon_parameter_file, ignore_extra=True)

print(bert)
tokenizer = BERTTokenizer(vocabulary, lower=do_lower_case)
dataset = TSVDataset(input_file, field_separator=nlp.data.Splitter(' ||| '))
tokenizer = nlp.data.BERTTokenizer(vocabulary, lower=do_lower_case)
dataset = nlp.data.TSVDataset(input_file, field_separator=nlp.data.Splitter(' ||| '))

trans = BERTSentenceTransform(tokenizer, max_length)
trans = nlp.data.BERTSentenceTransform(tokenizer, max_length)
dataset = dataset.transform(trans)

bert_dataloader = mx.gluon.data.DataLoader(dataset, batch_size=1,
@@ -152,7 +164,5 @@
b = out[0][:length].asnumpy()

print('stdev = %s' % (np.std(a - b)))
mx.test_utils.assert_almost_equal(a, b, atol=1e-4, rtol=1e-4)
mx.test_utils.assert_almost_equal(a, b, atol=1e-5, rtol=1e-5)
mx.test_utils.assert_almost_equal(a, b, atol=5e-6, rtol=5e-6)
break
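
Taken together, the changes to this file drop the stale top-level imports in favor of the nlp.data namespace, build the checkpoint paths with os.path.join, and allow an optional local parameter file with a float16-then-float32 fallback load. A condensed sketch of the resulting Gluon-side flow, assuming the script's default model and dataset names and an input.txt in the working directory:

import mxnet as mx
import gluonnlp as nlp

param_file = None  # set to a converted .params file to skip the published weights
bert, vocabulary = nlp.model.get_model('bert_12_768_12',
                                       dataset_name='book_corpus_wiki_en_uncased',
                                       pretrained=not param_file,
                                       use_pooler=False,
                                       use_decoder=False,
                                       use_classifier=False)
if param_file:
    try:
        # Try a float16 load first (matching files converted in half precision),
        # then cast the network back to float32 for the comparison.
        bert.cast('float16')
        bert.load_parameters(param_file, ignore_extra=True)
        bert.cast('float32')
    except AssertionError:
        # The file was not float16; load it into a plain float32 network instead.
        bert.cast('float32')
        bert.load_parameters(param_file, ignore_extra=True)

tokenizer = nlp.data.BERTTokenizer(vocabulary, lower=True)
dataset = nlp.data.TSVDataset('input.txt',
                              field_separator=nlp.data.Splitter(' ||| '))
dataset = dataset.transform(nlp.data.BERTSentenceTransform(tokenizer, 128))
dataloader = mx.gluon.data.DataLoader(dataset, batch_size=1)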
13 changes: 7 additions & 6 deletions scripts/bert/convert_tf_model.py
@@ -32,11 +32,11 @@
help='BERT model name. options are bert_12_768_12 and bert_24_1024_16.'
'Default is bert_12_768_12')
parser.add_argument('--tf_checkpoint_dir', type=str,
default='/home/ubuntu/cased_L-12_H-768_A-12/',
default=os.path.join('~', 'cased_L-12_H-768_A-12/'),
help='Path to Tensorflow checkpoint folder. '
'Default is /home/ubuntu/cased_L-12_H-768_A-12/')
parser.add_argument('--out_dir', type=str,
default='/home/ubuntu/output/',
default=os.path.join('~', 'output'),
help='Path to output folder. The folder must exist. '
'Default is /home/ubuntu/output/')
parser.add_argument('--debug', action='store_true', help='debugging mode')
@@ -49,17 +49,18 @@
vocab, reserved_token_idx_map = convert_vocab(vocab_path)

# vocab serialization
tmp_file_path = os.path.join(args.out_dir, 'tmp')
tmp_file_path = os.path.expanduser(os.path.join(args.out_dir, 'tmp'))
with open(tmp_file_path, 'w') as f:
f.write(vocab.to_json())
hash_full, hash_short = get_hash(tmp_file_path)
gluon_vocab_path = os.path.join(args.out_dir, hash_short + '.vocab')
gluon_vocab_path = os.path.expanduser(os.path.join(args.out_dir, hash_short + '.vocab'))
with open(gluon_vocab_path, 'w') as f:
f.write(vocab.to_json())
logging.info('vocab file saved to %s. hash = %s', gluon_vocab_path, hash_full)

# load tf model
tf_checkpoint_file = os.path.join(args.tf_checkpoint_dir, 'bert_model.ckpt')
tf_checkpoint_file = os.path.expanduser(
os.path.join(args.tf_checkpoint_dir, 'bert_model.ckpt'))
logging.info('loading Tensorflow checkpoint %s ...', tf_checkpoint_file)
tf_tensors = read_tf_checkpoint(tf_checkpoint_file)
tf_names = sorted(tf_tensors.keys())
@@ -177,7 +178,7 @@
# param serialization
bert.save_parameters(tmp_file_path)
hash_full, hash_short = get_hash(tmp_file_path)
gluon_param_path = os.path.join(args.out_dir, hash_short + '.params')
gluon_param_path = os.path.expanduser(os.path.join(args.out_dir, hash_short + '.params'))
logging.info('param saved to %s. hash = %s', gluon_param_path, hash_full)
bert.save_parameters(gluon_param_path)
mx.nd.waitall()
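
For context on this second file, the expanduser fixes matter because every output path is derived from a short content hash of a scratch file written into --out_dir. A small illustration of that naming pattern, using a hypothetical sha1-based stand-in for the script's get_hash helper (whose real implementation is not shown in this diff):

import hashlib
import os

def get_hash(filename):
    # Hypothetical stand-in: sha1 of the file contents plus an 8-character
    # short form; the script's actual helper may differ in detail.
    sha1 = hashlib.sha1()
    with open(filename, 'rb') as f:
        for chunk in iter(lambda: f.read(1 << 20), b''):
            sha1.update(chunk)
    full = sha1.hexdigest()
    return full, full[:8]

out_dir = os.path.expanduser('~/output')      # the output folder must already exist
tmp_file_path = os.path.join(out_dir, 'tmp')  # scratch file the script writes first
with open(tmp_file_path, 'w') as f:
    f.write('{}')                             # placeholder for vocab.to_json()
hash_full, hash_short = get_hash(tmp_file_path)
gluon_vocab_path = os.path.join(out_dir, hash_short + '.vocab')
# The converted parameters are later written the same way, as <hash>.params.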
