-
Notifications
You must be signed in to change notification settings - Fork 1
/
predict.py
52 lines (38 loc) · 1.19 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
from model import get_Model
from config import *
import glob
import cv2
import numpy as np
def remove_dup(list_idx):
    """Greedy (best-path) decode of raw network output, collapsing repeats.

    Takes the argmax class per step of ``list_idx[0, 2:]`` (first batch item,
    first two steps discarded), prints that raw index path, then keeps an index
    only when it differs from the immediately preceding one.

    Args:
        list_idx: 3-D array-like model output (indexing ``[0, 2:]`` followed by
            ``argmax(axis=1)`` requires rank 3). Presumably laid out as
            ``(batch, timesteps, classes)`` — confirm against the model.

    Returns:
        A list beginning with the sentinel ``0``, followed by the collapsed
        class indices. Because ``0`` is the sentinel, a leading run of class 0
        in the path is also dropped.
    """
    # The first two steps are skipped; presumably the first RNN outputs are
    # unreliable (as in the Keras image_ocr example) — confirm with the model author.
    best_path = list(np.argmax(list_idx[0, 2:], axis=1))
    print(best_path)
    text = [0]
    # Iterate over *every* step: the previous `range(len(...) - 1)` dropped the
    # last timestep, losing any character emitted only at the final frame.
    for idx in best_path:
        if idx == text[-1]:
            # Consecutive duplicate: text[-1] always equals the previous step's
            # value, so this is the standard collapse-repeats rule.
            continue
        text.append(idx)
    return text
def decode(list_idx, vocab=i2c, drop_idx=(0, 87)):
    """Convert raw network output into a text string.

    Args:
        list_idx: Raw model output, passed unchanged to ``remove_dup``.
        vocab: Lookup from class index to character (indexed with ``[]``);
            defaults to ``i2c`` from ``config``, bound at definition time.
        drop_idx: Class indices removed after duplicate collapsing. The default
            ``(0, 87)`` reproduces the previous hard-coded behaviour: ``0`` is
            the sentinel that ``remove_dup`` prepends, and ``87`` is presumably
            the CTC blank class — TODO confirm against ``config``. Note that any
            real character mapped to index 0 is dropped as well.

    Returns:
        The decoded string, one character per surviving class index.
    """
    drop = set(drop_idx)
    labels = remove_dup(list_idx)
    return ''.join(vocab[idx] for idx in labels if int(idx) not in drop)
def _preprocess(img):
    """Turn a 2-D grayscale image into a single-item model input batch.

    The image is converted to float32 and resized with ``cv2.resize`` to
    dsize ``(img_w, img_h)``. OpenCV takes ``(width, height)``, so the result
    is ``(img_h, img_w)``. Pixel values are mapped from [0, 255] onto [-1, 1],
    the array is transposed to ``(img_w, img_h)``, and channel and batch axes
    are added, giving shape ``(1, img_w, img_h, 1)``.
    """
    img_pred = img.astype(np.float32)
    img_pred = cv2.resize(img_pred, (img_w, img_h))
    # Map the 8-bit pixel range [0, 255] onto [-1, 1].
    img_pred = (img_pred / 255.0) * 2.0 - 1.0
    # Width-major layout; presumably the model reads columns as its time
    # axis — confirm against get_Model.
    img_pred = img_pred.T
    img_pred = np.expand_dims(img_pred, axis=-1)
    return np.expand_dims(img_pred, axis=0)


if __name__ == '__main__':
    model = get_Model(img_w, img_h, is_training=False)
    try:
        model.load_weights('model.hdf5')
    except Exception as err:
        # Keep the original message, but chain the real cause (missing file,
        # shape mismatch, ...) instead of discarding it with a bare except.
        raise Exception("No weight file!") from err
    print("...Previous weight data...")

    # Sorted so that predictions are printed in a stable, reproducible order.
    for test_img in sorted(glob.glob('test_images/*.png')):
        img = cv2.imread(test_img, cv2.IMREAD_GRAYSCALE)
        if img is None:
            # cv2.imread returns None (rather than raising) on unreadable files.
            print("Could not read image: {}".format(test_img))
            continue
        net_out_value = model.predict(_preprocess(img))
        print(decode(net_out_value))