# train_attention.py
import math

from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, LearningRateScheduler

from metrics_losses.metrics_and_losses import Final_loss, correlation_coefficient_loss, total_score
from model_attention import get_model_attention
from data_generators.DataGenerator_pipal import DataGenerator
from data_generators.DataGenerator_h5 import DataGeneratorH5, DataGeneratorValH5
'''
This script trains the attention model (attention can be toggled via the 'attention_flag' setting).
'''
# ********* Settings (replace these with suitable settings from configs.txt) *********
configs = {'MODEL_NAME': "./models/model_attention_without_pretraining.hdf5", 'VALIDATION_RATIO': 0.2,
           'TOTAL_NUM_TRAINING': 37120, 'TOTAL_NUM_VALIDATION': 9280,
           'learning_rate': 1e-5, 'epochs': 20, 'batch_size': 16, 'alpha': 0.5,
           'nfs': [8, 16, 32, 64, 64, 64, 32, 16, 8, 32], 'kss': [5, 5, 3, 3],
           'dense_num': 32, 'l2_reg': [1e-6, 1e-4], 'mse_factor': 1.0, 'spearman_factor': 0.01,
           'num_channels': 3, 'learning_rate_drop': 0.1, 'drop_out': 0.0, 'drop_learning_rate_after_epochs': 15.0,
           'num_resblocks': 3, 'attention_flag': True}
training_generator = DataGenerator(False, configs['batch_size'])
validation_generator = DataGenerator(True, configs['batch_size'])
pretrained_weights = './pretrained_models/model_attention_pretrained_h5.hdf5'
use_pretrained_weights = False
# ********* End of settings *********
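# Optional (sketch, not part of the original script): if configs.txt stores the same
# dictionary as a Python-literal, it could be loaded instead of hardcoding the dict
# above. The filename and file format are assumptions here, e.g.:
#   import ast
#   with open('configs.txt') as f:
#       configs = ast.literal_eval(f.read())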


def step_decay(epoch):
    """Step-decay schedule: multiply the initial learning rate by `drop` every `epochs_drop` epochs."""
    initial_lrate = configs['learning_rate']
    drop = configs['learning_rate_drop']
    epochs_drop = configs['drop_learning_rate_after_epochs']
    lrate = initial_lrate * math.pow(drop, math.floor((1 + epoch) / epochs_drop))
    return lrate
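# Worked example with the settings above (learning_rate=1e-5, learning_rate_drop=0.1,
# drop_learning_rate_after_epochs=15): floor((1 + epoch) / 15) is 0 for epochs 0-13 and
# 1 from epoch 14 onward, so the rate stays at 1e-5 for the first 14 epochs and then
# drops to 1e-6 for the remaining epochs of the 20-epoch run.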

if __name__ == "__main__":
    # Build the model from the configured channel count, filter/kernel sizes, regularization, and attention flag.
    model = get_model_attention(configs['num_channels'], configs['nfs'], configs['kss'], configs['l2_reg'],
                                configs['alpha'], configs['drop_out'], configs['dense_num'],
                                configs['num_resblocks'], configs['attention_flag'])
    if use_pretrained_weights:
        model.load_weights(pretrained_weights)
    optimizer = Adam(learning_rate=configs['learning_rate'])
    # Keep only the model with the lowest validation loss at MODEL_NAME.
    checkpoint = ModelCheckpoint(configs['MODEL_NAME'], monitor='val_loss', verbose=1, save_best_only=True, mode='min')
    lrate = LearningRateScheduler(step_decay, verbose=1)
    model.compile(loss=Final_loss(configs['mse_factor'], configs['spearman_factor']), optimizer=optimizer,
                  metrics=[correlation_coefficient_loss, total_score])
    model.fit(training_generator, validation_data=validation_generator, epochs=configs['epochs'],
              callbacks=[checkpoint, lrate],
              steps_per_epoch=configs['TOTAL_NUM_TRAINING'] // configs['batch_size'],
              validation_steps=configs['TOTAL_NUM_VALIDATION'] // configs['batch_size'],
              workers=1)
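    # Optional follow-up (sketch, assumes training completed and the checkpoint file exists):
    # reload the best weights saved by ModelCheckpoint and evaluate on the validation split.
    #   model.load_weights(configs['MODEL_NAME'])
    #   model.evaluate(validation_generator,
    #                  steps=configs['TOTAL_NUM_VALIDATION'] // configs['batch_size'])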