test.py
import tensorflow as tf
import numpy as np
import sys
import os
import data_input
import librosa
from tqdm import tqdm
import argparse
import audio

def test(model, config, prompts):
    sr = 24000 if 'blizzard' in config.data_path else 16000
    meta = data_input.load_meta(config.data_path)
    config.r = audio.r
    ivocab = meta['vocab']
    config.vocab_size = len(ivocab)

    # build the prompt input pipeline on the CPU
    with tf.device('/cpu:0'):
        batch_inputs = data_input.load_prompts(prompts, ivocab)
        config.num_prompts = len(prompts)

    with tf.Session() as sess:
        # spectrogram normalization statistics, restored from the checkpoint below
        stft_mean = tf.get_variable('stft_mean', shape=(1025*audio.r,), dtype=tf.float32)
        stft_std = tf.get_variable('stft_std', shape=(1025*audio.r,), dtype=tf.float32)

        # initialize model
        model = model(config, batch_inputs, train=False)

        train_writer = tf.summary.FileWriter('log/' + config.save_path + '/test', sess.graph)

        tf.global_variables_initializer().run()
        tf.local_variables_initializer().run()

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        saver = tf.train.Saver()

        print('restoring weights')
        latest_ckpt = tf.train.latest_checkpoint(
            'weights/' + config.save_path[:config.save_path.rfind('/')]
        )
        saver.restore(sess, latest_ckpt)

        # pull the normalization statistics out of the graph as numpy arrays
        stft_mean, stft_std = sess.run([stft_mean, stft_std])

        try:
            while True:
                out = sess.run([
                    model.output,
                    model.alignments,
                    batch_inputs
                ])
                outputs, alignments, inputs = out

                print('saving samples')
                for out, words, align in zip(outputs, inputs['text'], alignments):
                    # store a sample to listen to
                    text = ''.join([ivocab[w] for w in words])
                    attention_plot = data_input.generate_attention_plot(align)
                    # denormalize the predicted spectrogram and invert it to a waveform
                    sample = audio.invert_spectrogram(out*stft_std + stft_mean)
                    merged = sess.run(tf.summary.merge(
                        [tf.summary.audio(text, sample[None, :], sr),
                         tf.summary.image(text, attention_plot)]
                    ))
                    train_writer.add_summary(merged, 0)

        except tf.errors.OutOfRangeError:
            # the prompt queue is exhausted once every prompt has been synthesized
            coord.request_stop()
            coord.join(threads)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-t', '--train-set', default='nancy')
    args = parser.parse_args()

    # prompts to synthesize are read from stdin, one per line; drop blank lines
    prompts = sys.stdin.readlines()
    prompts = [p for p in prompts if p.strip()]

    from models.tacotron import Tacotron, Config
    model = Tacotron
    config = Config()
    config.data_path = 'data/%s/' % args.train_set
    config.save_path = args.train_set + '/tacotron'

    print('Building Tacotron')
    test(model, config, prompts)
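
# Usage sketch (an assumption inferred from the argument parser and the stdin reads
# above, not documented in the file itself): prompts are piped in on stdin and
# --train-set selects the data/ and weights/ subdirectories, e.g.
#   echo "hello world" | python test.py --train-set nancy
# The synthesized audio and attention plots are written as TensorBoard summaries
# under log/<train-set>/tacotron/test.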