-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_irisAttention.py
62 lines (51 loc) · 2.25 KB
/
train_irisAttention.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import os
import numpy as np
import imageio
import keras
import cv2
import matplotlib.pyplot as plt
from keras.optimizers import *
from keras.callbacks import ModelCheckpoint,EarlyStopping, ReduceLROnPlateau
from utilities.datareader import datareader
from model.iris_detetction_model import irisAttention
# Data readers: train_state=True presumably selects the training split and
# train_state=False the validation split (TODO confirm in utilities.datareader).
dtreader, valdtreader = (
    datareader(train_state=True),
    datareader(train_state=False),
)

# Build the network and print its layer summary.
model = irisAttention()
model.summary()

# Checkpoint file: loaded below if it already exists, and written by
# ModelCheckpoint during training.
ckpt_path = 'ckpt/irisAttention_CASIA-iris-distance.h5'
class loss_history(keras.callbacks.Callback):
    """Show the model's box prediction for one validation image each epoch.

    Despite its name, this callback records no loss values. At the start of
    every epoch it runs the model on validation sample ``x`` and draws two
    boxes on that sample's source image: the predicted box in (0, 0, 255)
    and the ground-truth box in (0, 255, 0). The image is read with
    ``pilmode="RGB"``, so these appear blue and green respectively. The
    result is shown with matplotlib.

    Args:
        x: Index of the sample in the module-level ``valdtreader``.
        scale: Multiplier converting predicted / stored box coordinates to
            pixels. Defaults to 256, the previously hard-coded value.
            Presumably the coordinates are normalised for a 256-pixel
            image; TODO confirm against the data reader.
    """

    def __init__(self, x=0, scale=256):
        # Let keras.callbacks.Callback initialise its own attributes first.
        super().__init__()
        self.x = x
        self.scale = scale

    def _to_pixels(self, box):
        """Scale the first four coordinates of ``box`` and truncate to int.

        cv2.rectangle requires integer point coordinates, so the predicted
        and ground-truth boxes both go through this same conversion.
        """
        return tuple(int(coord * self.scale) for coord in box[:4])

    def on_epoch_begin(self, epoch, logs=None):
        """Predict, draw and display the boxes for sample ``self.x``.

        ``logs`` defaults to ``None`` rather than a shared mutable dict; it
        is not used here.
        """
        batch = np.expand_dims(valdtreader.images[self.x], axis=0)
        # One np.squeeze removes every size-1 axis. Assumes the prediction
        # squeezes to a flat vector of (x1, y1, x2, y2); TODO confirm the
        # model's output shape.
        pred_box = np.squeeze(self.model.predict(batch))
        px1, py1, px2, py2 = self._to_pixels(pred_box)
        # Ground truth is read at [x, 0, 0, 0:4].
        gx1, gy1, gx2, gy2 = self._to_pixels(valdtreader.bboxes[self.x, 0, 0])

        # The first space-separated token of the image-list entry is the file stem.
        image_name = valdtreader.imagelist[self.x].split(" ")[0]
        image = imageio.imread(
            os.path.join('dataset/images/', image_name + '.jpeg'),
            as_gray=False,
            pilmode="RGB",
        )
        cv2.rectangle(image, (px1, py1), (px2, py2), (0, 0, 255), 2)  # prediction
        cv2.rectangle(image, (gx1, gy1), (gx2, gy2), (0, 255, 0), 2)  # ground truth
        plt.imshow(image)
        # With an interactive backend this blocks until the window is
        # closed, pausing training at the start of each epoch.
        plt.show()
# Bounding-box regression trained with MSE on the box coordinates.
# 'accuracy' is not meaningful for continuous targets, so report mean
# absolute error instead. `learning_rate` replaces the deprecated `lr`
# keyword, which recent Keras releases reject.
model.compile(optimizer=Adam(learning_rate=0.0001), loss=['mse'], metrics=['mae'])

# Resume from the previously saved checkpoint if one exists.
if os.path.exists(ckpt_path):
    model.load_weights(ckpt_path)
    print('the checkpoint is loaded successfully.')

# Ensure the checkpoint directory exists before ModelCheckpoint writes to it.
ckpt_dir = os.path.dirname(ckpt_path)
if ckpt_dir:
    os.makedirs(ckpt_dir, exist_ok=True)

# Multiply the LR by 0.1 after 3 epochs without val_loss improvement (floor 1e-5).
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=3, verbose=1, min_lr=0.00001)
# Stop training after 7 epochs without val_loss improvement.
earlystopper = EarlyStopping(monitor='val_loss', patience=7, verbose=1)
# Overwrite ckpt_path only when val_loss improves on the best seen so far.
checkpointer = ModelCheckpoint(ckpt_path, monitor='val_loss', verbose=1, save_best_only=True)

r = model.fit(x=dtreader.images,
              y=[dtreader.bboxes],
              validation_data=(valdtreader.images, [valdtreader.bboxes]),
              callbacks=[loss_history(), checkpointer, earlystopper, reduce_lr],
              epochs=40,
              verbose=1,
              batch_size=2,
              shuffle=True)