-
Notifications
You must be signed in to change notification settings - Fork 3
/
save_samples.py
82 lines (61 loc) · 2.09 KB
/
save_samples.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import tensorflow as tf
from os.path import join
import getopt
import sys
from LSTMModel import LSTMModel
from data_reader import DataReader
import constants as c
def process_sample(string):
    """Turn a raw generated sample into plain multi-line text.

    The sample is tokenized on whitespace, and ``*break*`` tokens mark line
    breaks. Processing happens in four steps:

    1. Everything before the first ``*break*`` is discarded.
    2. The run of consecutive ``*break*`` tokens at the start is skipped.
    3. The remaining tokens are re-joined with single spaces.
    4. Each ``*break*`` is replaced with a newline.

    If the sample contains no ``*break*`` at all, nothing is discarded.

    Args:
        string: Raw sample text, e.g. the output of ``model.generate()``.

    Returns:
        The cleaned text. Because tokens are re-joined with single spaces,
        each inserted newline keeps a space on either side.
    """
    break_token = '*break*'
    words = string.split()

    # Drop everything before the first line break (presumably an incomplete
    # opening line). A sample with no break is kept whole instead of raising.
    try:
        start = words.index(break_token)
    except ValueError:
        start = 0

    # Skip the leading run of break tokens. The bounds check stops a sample
    # that ends in nothing but breaks from indexing past the end of the list.
    while start < len(words) and words[start] == break_token:
        start += 1

    new_string = ' '.join(words[start:])
    return new_string.replace(break_token, '\n')


def save(artist, model_path, num_save):
    """Generate ``num_save`` samples for ``artist`` and write each to a file.

    Restores the model weights from ``model_path`` into a fresh session.
    Sample ``i`` is generated, cleaned with ``process_sample`` and written
    to ``<artist_save_dir>/<i>.txt``. Here ``artist_save_dir`` is
    ``c.get_dir(join(c.get_dir('../save/samples/'), artist))``.

    Args:
        artist: Artist name. It is passed to ``DataReader`` to obtain the
            vocabulary and used as the output subdirectory name.
        model_path: Checkpoint path passed to ``tf.train.Saver.restore``.
        num_save: Number of samples to generate and write.
    """
    sample_save_dir = c.get_dir('../save/samples/')

    # Use the session as a context manager so it is always closed, even if
    # restoring the checkpoint or generating a sample raises.
    with tf.Session() as sess:
        print(artist)
        data_reader = DataReader(artist)
        vocab = data_reader.get_vocab()

        print('Init model...')
        model = LSTMModel(sess,
                          vocab,
                          c.BATCH_SIZE,
                          c.SEQ_LEN,
                          c.CELL_SIZE,
                          c.NUM_LAYERS,
                          test=True)

        saver = tf.train.Saver()
        # Initialise every variable first; restore then overwrites the ones
        # stored in the checkpoint.
        sess.run(tf.initialize_all_variables())
        saver.restore(sess, model_path)
        print('Model restored from ' + model_path)

        artist_save_dir = c.get_dir(join(sample_save_dir, artist))
        for i in range(num_save):
            print(i)
            path = join(artist_save_dir, str(i) + '.txt')
            # Generate before opening the file so a failed generation does
            # not leave an empty sample file behind.
            processed_sample = process_sample(model.generate())
            with open(path, 'w') as f:
                f.write(processed_sample)


def main(argv=None):
    """Parse command-line options and export generated samples via ``save``.

    Options (all optional; defaults in brackets):
        -l, --load_path PATH     checkpoint to restore
                                 [../save/models/kanye_west/kanye_west.ckpt-30000]
        -a, --artist_name NAME   artist to sample for [kanye_west]
        -n, -N, --num_save N     number of samples to write [1000]

    Args:
        argv: Arguments to parse, excluding the program name. Defaults to
            ``sys.argv[1:]``.

    Exits with status 2, after writing the problem and a usage line to
    stderr, when the options cannot be parsed or ``num_save`` is not an
    integer.
    """
    usage = ('usage: save_samples.py [-l load_path] [-a artist_name] '
             '[-n num_save]\n')
    if argv is None:
        argv = sys.argv[1:]

    artist = 'kanye_west'
    model_path = '../save/models/kanye_west/kanye_west.ckpt-30000'
    num_save = 1000

    try:
        # Accept both '-n' and '-N' for the sample count. Earlier versions
        # declared 'N' in the option string but only checked for '-n'.
        opts, _ = getopt.getopt(argv, 'l:a:n:N:',
                                ['load_path=', 'artist_name=', 'num_save='])
    except getopt.GetoptError as err:
        sys.stderr.write('%s\n%s' % (err, usage))
        sys.exit(2)

    for opt, arg in opts:
        if opt in ('-l', '--load_path'):
            model_path = arg
        elif opt in ('-a', '--artist_name'):
            artist = arg
        elif opt in ('-n', '-N', '--num_save'):
            try:
                num_save = int(arg)
            except ValueError:
                sys.stderr.write('num_save must be an integer, got %r\n%s'
                                 % (arg, usage))
                sys.exit(2)

    save(artist, model_path, num_save)


# Only export samples when run as a script, not when imported as a module.
if __name__ == '__main__':
    main()