train.py (forked from philipperemy/deep-speaker)
109 lines (91 loc) · 5.21 KB
import logging
import os

from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping, ModelCheckpoint
from tensorflow.keras.optimizers import SGD
from tqdm import tqdm

from batcher import KerasFormatConverter, LazyTripletBatcher
from constants import BATCH_SIZE, CHECKPOINTS_SOFTMAX_DIR, CHECKPOINTS_TRIPLET_DIR, NUM_FRAMES, NUM_FBANKS
from conv_models import DeepSpeakerModel
from triplet_loss import deep_speaker_loss
from utils import load_best_checkpoint, ensures_dir

logger = logging.getLogger(__name__)

# Otherwise it's just too much logging from Tensorflow...
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

def fit_model(dsm: DeepSpeakerModel, working_dir: str, max_length: int = NUM_FRAMES, batch_size=BATCH_SIZE):
    batcher = LazyTripletBatcher(working_dir, max_length, dsm)

    # Build a small, fixed test set so validation uses the same batches every epoch.
    test_batches = []
    for _ in tqdm(range(200), desc='Build test set'):
        test_batches.append(batcher.get_batch_test(batch_size))

    # Both generators loop forever; steps_per_epoch and validation_steps below tell
    # Keras how many batches make up one epoch and one validation pass.
    def test_generator():
        while True:
            for bb in test_batches:
                yield bb

    def train_generator():
        while True:
            yield batcher.get_random_batch(batch_size, is_test=False)

    checkpoint_name = dsm.m.name + '_checkpoint'
    checkpoint_filename = os.path.join(CHECKPOINTS_TRIPLET_DIR, checkpoint_name + '_{epoch}.h5')
    checkpoint = ModelCheckpoint(monitor='val_loss', filepath=checkpoint_filename, save_best_only=True)
    dsm.m.fit(x=train_generator(), y=None, steps_per_epoch=2000, shuffle=False,
              epochs=1000, validation_data=test_generator(), validation_steps=len(test_batches),
              callbacks=[checkpoint])

def fit_model_softmax(dsm: DeepSpeakerModel, kx_train, ky_train, kx_test, ky_test,
                      batch_size=BATCH_SIZE, max_epochs=1000, initial_epoch=0):
    checkpoint_name = dsm.m.name + '_checkpoint'
    checkpoint_filename = os.path.join(CHECKPOINTS_SOFTMAX_DIR, checkpoint_name + '_{epoch}.h5')
    checkpoint = ModelCheckpoint(monitor='val_accuracy', filepath=checkpoint_filename, save_best_only=True)

    # If the validation accuracy does not improve by at least 0.001 over 20 epochs, stop training.
    early_stopping = EarlyStopping(monitor='val_accuracy', min_delta=0.001, patience=20, verbose=1, mode='max')

    # If the validation accuracy does not improve over 10 epochs, halve the learning rate.
    reduce_lr = ReduceLROnPlateau(monitor='val_accuracy', factor=0.5, patience=10, min_lr=0.0001, verbose=1)

    # Keep only a whole number of batches (drop the remainder).
    max_len_train = len(kx_train) - len(kx_train) % batch_size
    kx_train = kx_train[0:max_len_train]
    ky_train = ky_train[0:max_len_train]
    max_len_test = len(kx_test) - len(kx_test) % batch_size
    kx_test = kx_test[0:max_len_test]
    ky_test = ky_test[0:max_len_test]

    dsm.m.fit(x=kx_train,
              y=ky_train,
              batch_size=batch_size,
              epochs=initial_epoch + max_epochs,
              initial_epoch=initial_epoch,
              verbose=1,
              shuffle=True,
              validation_data=(kx_test, ky_test),
              callbacks=[early_stopping, reduce_lr, checkpoint])
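
# Worked example (illustrative numbers, not from the repository): with 1,003 training
# samples and BATCH_SIZE = 32, 1003 % 32 == 11, so fit_model_softmax keeps
# 1003 - 11 = 992 samples, i.e. exactly 31 full batches, and silently drops the last 11.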

def start_training(working_dir, pre_training_phase=True):
    ensures_dir(CHECKPOINTS_SOFTMAX_DIR)
    ensures_dir(CHECKPOINTS_TRIPLET_DIR)
    batch_input_shape = [None, NUM_FRAMES, NUM_FBANKS, 1]
    if pre_training_phase:
        logger.info('Softmax pre-training.')
        kc = KerasFormatConverter(working_dir)
        num_speakers_softmax = len(kc.categorical_speakers.speaker_ids)
        dsm = DeepSpeakerModel(batch_input_shape, include_softmax=True, num_speakers_softmax=num_speakers_softmax)
        dsm.m.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
        pre_training_checkpoint = load_best_checkpoint(CHECKPOINTS_SOFTMAX_DIR)
        if pre_training_checkpoint is not None:
            initial_epoch = int(pre_training_checkpoint.split('/')[-1].split('.')[0].split('_')[-1])
            logger.info(f'Initial epoch is {initial_epoch}.')
            logger.info(f'Loading softmax checkpoint: {pre_training_checkpoint}.')
            dsm.m.load_weights(pre_training_checkpoint)  # latest one.
        else:
            initial_epoch = 0
        fit_model_softmax(dsm, kc.kx_train, kc.ky_train, kc.kx_test, kc.ky_test, initial_epoch=initial_epoch)
    else:
        logger.info('Training with the triplet loss.')
        dsm = DeepSpeakerModel(batch_input_shape, include_softmax=False)
        triplet_checkpoint = load_best_checkpoint(CHECKPOINTS_TRIPLET_DIR)
        pre_training_checkpoint = load_best_checkpoint(CHECKPOINTS_SOFTMAX_DIR)
        if triplet_checkpoint is not None:
            logger.info(f'Loading triplet checkpoint: {triplet_checkpoint}.')
            dsm.m.load_weights(triplet_checkpoint)
        elif pre_training_checkpoint is not None:
            logger.info(f'Loading pre-training checkpoint: {pre_training_checkpoint}.')
            # If `by_name` is True, weights are loaded into layers only if they share the
            # same name. This is useful for fine-tuning or transfer-learning models where
            # some of the layers have changed.
            dsm.m.load_weights(pre_training_checkpoint, by_name=True)
        dsm.m.compile(optimizer=SGD(), loss=deep_speaker_loss)
        fit_model(dsm, working_dir, NUM_FRAMES)
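
# Usage sketch (assumption: not part of the original file; in the upstream project the
# code that calls start_training() lives elsewhere, e.g. a CLI module). The working
# directory below is a hypothetical example path.
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    example_working_dir = os.path.expanduser('~/.deep-speaker-wd')  # hypothetical path
    # Phase 1: softmax pre-training over the known speakers.
    start_training(example_working_dir, pre_training_phase=True)
    # Phase 2: fine-tune the embeddings with the triplet loss, reusing the softmax weights.
    start_training(example_working_dir, pre_training_phase=False)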