-
Notifications
You must be signed in to change notification settings - Fork 4
/
plot_training_history.py
139 lines (122 loc) · 4.7 KB
/
plot_training_history.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
#!/usr/bin/env python
u"""
plot_training_history.py
by Yara Mohajerani (Last Update 10/2018)
Plot training history generated by unet_train.py
Update History
10/2018 Written
"""
import os
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
#from scipy.interpolate import spline
#-- read the training history of a run and plot its accuracy and loss curves
def _is_yes(flag):
    """Return True when a parameter-file flag is 'Y' or 'y'."""
    return flag in ('Y', 'y')


def plot_history(parameters):
    """
    Plot the training history of one run configuration.

    Reads the whitespace-delimited table
    'training_history_<run tag>.txt' from the glacier directory and saves
    one PDF per metric ('acc' and 'loss') next to it, each showing the
    training and validation curves against the epoch number.

    Arguments
    ---------
    parameters: dictionary mapping parameter names to string values.
        Keys read here: GLACIER_NAME, BATCHES, EPOCHS, LAYERS_DOWN, N_INIT,
        SUFFIX, DROPOUT, imb_str, AUGMENT (plus AUG_CONFIG when AUGMENT is
        'Y' or 'y'), CROP, NORMALIZE, LINEAR and LABEL_WIDTH.

    Exits through sys.exit when NORMALIZE is set together with a nonzero
    DROPOUT. A missing key raises KeyError.
    """
    glacier = parameters['GLACIER_NAME']
    n_batch = int(parameters['BATCHES'])
    n_epochs = int(parameters['EPOCHS'])
    n_layers = int(parameters['LAYERS_DOWN'])
    n_init = int(parameters['N_INIT'])
    suffix = parameters['SUFFIX']
    drop = float(parameters['DROPOUT'])
    imb_str = '_%.2fweight'%(float(parameters['imb_str']))

    #-- set up the file name tags based on parameters
    if _is_yes(parameters['AUGMENT']):
        #-- builtin int is used because the np.int alias was removed from NumPy
        aug_str = '_augment-x%i'%int(parameters['AUG_CONFIG'])
    else:
        aug_str = ''
    crop_str = '_cropped' if _is_yes(parameters['CROP']) else ''
    normalize = _is_yes(parameters['NORMALIZE'])
    norm_str = '_normalized' if normalize else ''
    #-- both 'Y' and 'y' enable the linear tag, like every other flag
    lin_str = '_linear' if _is_yes(parameters['LINEAR']) else ''
    drop_str = '_w%.1fdrop'%drop if drop > 0 else ''
    #-- a label width of 3 adds no tag to the file name
    if parameters['LABEL_WIDTH'] == '3':
        lbl_width = ''
    else:
        lbl_width = '_%ipx'%int(parameters['LABEL_WIDTH'])

    if normalize and drop != 0:
        sys.exit('Both batch normalization and dropout are selected. Choose one.')

    #-- directory setup (data directory sits next to the code directory)
    current_dir = os.path.dirname(os.path.realpath(__file__))
    main_dir = os.path.join(current_dir,'..','FrontLearning_data')
    glacier_ddir = os.path.join(main_dir,'%s.dir'%glacier)

    #-- run tag shared by the input table and the output figures
    run_tag = '%ibatches_%iepochs_%ilayers_%iinit%s%s%s%s%s%s%s%s'\
        %(n_batch,n_epochs,n_layers,n_init,lin_str,imb_str,drop_str,norm_str,\
        aug_str,suffix,crop_str,lbl_width)

    #-- input file
    infile = os.path.join(glacier_ddir,'training_history_%s.txt'%run_tag)
    #-- sep=r'\s+' is the documented replacement for delim_whitespace=True
    history = pd.read_table(infile,sep=r'\s+')
    epochs = history['Epoch'] + 1 #- start from 1 instead of 0

    #-- one figure per metric with training (blue) and validation (red) curves
    for item,name in zip(['acc','loss'],['Accuracy','Loss']):
        fig = plt.figure(1,figsize=(8,6))
        plt.plot(epochs,history[item],'b-')
        plt.plot(epochs,history['val_%s'%item],'r-')
        plt.title('Model %s'%name)
        plt.ylabel(name)
        plt.xlabel('Epochs')
        plt.legend(['Training', 'Validation'], loc='upper left')
        outfile = os.path.join(glacier_ddir,\
            'training_history_%s_%s.pdf'%(item,run_tag))
        plt.savefig(outfile,format='pdf')
        #-- close so the next metric starts from an empty figure
        plt.close(fig)
#-- main function to get parameters and pass them along to plotting function
def main():
    """
    Read every parameter file given on the command line and plot the
    training history of the run it describes.

    Each non-blank line of a parameter file must hold a parameter name
    followed by its value, separated by whitespace; any further fields on
    the line are ignored. Exits through sys.exit when no parameter file is
    given.
    """
    if (len(sys.argv) == 1):
        sys.exit('You need to input at least one parameter file to set run configurations.')

    #-- Input Parameter Files (sys.argv[0] is the python code)
    input_files = sys.argv[1:]
    #-- for each input parameter file
    for parameter_file in input_files:
        #-- keep track of progress
        print(os.path.basename(parameter_file))
        #-- variable with parameter definitions
        parameters = {}
        #-- the with-block closes the file even if a line fails to parse
        with open(parameter_file, 'r') as fid:
            #-- for each line in the file will extract the parameter (name and value)
            for fileline in fid:
                #-- Splitting the input line between parameter name and value
                part = fileline.split()
                #-- skip blank lines instead of failing on them
                if not part:
                    continue
                #-- filling the parameter definition variable
                parameters[part[0]] = part[1]
        #-- pass parameters to plotting function
        plot_history(parameters)
#-- run the program only when this file is executed as a script
if __name__ == "__main__":
    main()