train-detector.py
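
# Overview (added note): this script fine-tunes a fully convolutional Keras
# detector on annotated images, saving periodic checkpoints and a final model.
# The invocation below is illustrative only; substitute your own paths.
#
#   python train-detector.py --model models/my-model --name my-trained-model \
#       --train-dir samples/train-detector --output-dir models/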
import numpy as np
import cv2
import argparse
import keras

from os.path import isfile, isdir, basename, splitext
from os import makedirs

from src.keras_utils import save_model, load_model
from src.label import readShapes
from src.loss import loss
from src.utils import image_files_from_folder
from src.sampler import augment_sample, labels2output_map
from src.data_generator import DataGenerator
def load_network(modelpath, input_dim):

	model = load_model(modelpath)
	input_shape = (input_dim, input_dim, 3)

	# Fixed input size for training
	inputs = keras.layers.Input(shape=input_shape)
	outputs = model(inputs)

	output_shape = tuple([s.value for s in outputs.shape[1:]])
	output_dim = output_shape[1]
	model_stride = input_dim // output_dim

	assert input_dim % output_dim == 0, \
		'The input resolution must be divisible by the output resolution'

	assert model_stride == 2**4, \
		'Make sure your model generates a feature map with resolution ' \
		'16x smaller than the input'

	return model, model_stride, input_shape, output_shape
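
# Added note: augment_sample applies random augmentation to an image and its
# annotated points, and labels2output_map rasterizes the augmented labels into
# the dense target map the detector is trained against, with one output cell
# per model_stride x model_stride patch of the input.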
def process_data_item(data_item, dim, model_stride):
	XX, llp, pts = augment_sample(data_item[0], data_item[1].pts, dim)
	YY = labels2output_map(llp, pts, dim, model_stride)
	return XX, YY
if __name__ == '__main__':

	parser = argparse.ArgumentParser()
	parser.add_argument('-m'   ,'--model'         ,type=str   , required=True  ,help='Path to previous model')
	parser.add_argument('-n'   ,'--name'          ,type=str   , required=True  ,help='Model name')
	parser.add_argument('-tr'  ,'--train-dir'     ,type=str   , required=True  ,help='Input data directory for training')
	parser.add_argument('-its' ,'--iterations'    ,type=int   , default=300000 ,help='Number of mini-batch iterations (default = 300,000)')
	parser.add_argument('-bs'  ,'--batch-size'    ,type=int   , default=32     ,help='Mini-batch size (default = 32)')
	parser.add_argument('-od'  ,'--output-dir'    ,type=str   , default='./'   ,help='Output directory (default = ./)')
	parser.add_argument('-op'  ,'--optimizer'     ,type=str   , default='Adam' ,help='Optimizer (default = Adam)')
	parser.add_argument('-lr'  ,'--learning-rate' ,type=float , default=.01    ,help='Learning rate (default = 0.01)')
	args = parser.parse_args()
	netname    = basename(args.name)
	train_dir  = args.train_dir
	outdir     = args.output_dir

	iterations = args.iterations
	batch_size = args.batch_size
	dim        = 208  # network input resolution; must be a multiple of the model stride

	if not isdir(outdir):
		makedirs(outdir)

	model, model_stride, xshape, yshape = load_network(args.model, dim)

	opt = getattr(keras.optimizers, args.optimizer)(lr=args.learning_rate)
	model.compile(loss=loss, optimizer=opt)
	print('Checking input directory...')
	Files = image_files_from_folder(train_dir)

	Data = []
	for file in Files:
		labfile = splitext(file)[0] + '.txt'
		if isfile(labfile):
			L = readShapes(labfile)
			I = cv2.imread(file)
			Data.append([I, L[0]])

	print('%d images with labels found' % len(Data))
	dg = DataGenerator(	data=Data,
				process_data_item_func=lambda x: process_data_item(x, dim, model_stride),
				xshape=xshape,
				yshape=(yshape[0], yshape[1], yshape[2] + 1),
				nthreads=2,
				pool_size=1000,
				min_nsamples=100 )
	dg.start()
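
	# Added note: DataGenerator keeps a pool of up to pool_size augmented
	# samples filled by nthreads background threads, so augmentation runs
	# concurrently with training; get_batch draws mini-batches from this
	# pool once at least min_nsamples samples are available.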
	# Periodic checkpoints and the final model are written under outdir
	model_path_backup = '%s/%s_backup' % (outdir, netname)
	model_path_final  = '%s/%s_final'  % (outdir, netname)
	for it in range(iterations):

		print('Iter. %d (of %d)' % (it + 1, iterations))

		Xtrain, Ytrain = dg.get_batch(batch_size)
		train_loss = model.train_on_batch(Xtrain, Ytrain)

		print('\tLoss: %f' % train_loss)

		# Save model every 1000 iterations
		if (it + 1) % 1000 == 0:
			print('Saving model (%s)' % model_path_backup)
			save_model(model, model_path_backup)

	print('Stopping data generator')
	dg.stop()

	print('Saving model (%s)' % model_path_final)
	save_model(model, model_path_final)
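
	# Added note: the '<name>_backup' checkpoint written above can be passed
	# back through --model to resume an interrupted training run.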