-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathpredict.py
77 lines (55 loc) · 2.12 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
# -*- coding: utf-8 -*-
"""
Created on Thu May 30 17:12:37 2019
@author: cm
"""
import os
import sys
pwd = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
import numpy as np
import tensorflow as tf
from classifier_multi_label_denses.networks import NetworkAlbert
from classifier_multi_label_denses.classifier_utils import get_feature_test,id2label
from classifier_multi_label_denses.hyperparameters import Hyperparamters as hp
class ModelAlbertDenses(object):
    """
    Load NetworkAlbertDenses for inference.

    On construction, builds ``NetworkAlbert(is_training=False)`` in a fresh
    graph, restores the checkpoint found under
    ``<pwd>/model/<hp.inference_model>`` and keeps both the network and its
    session so callers can run ``self.sess.run(self.albert.<tensor>, ...)``.
    """

    def __init__(self):
        # (network, session) pair produced by load_model().
        self.albert, self.sess = self.load_model()

    @staticmethod
    def load_model():
        """
        Build the inference graph and restore its weights from a checkpoint.

        Returns:
            tuple: ``(albert, sess)``. ``albert`` is the
            ``NetworkAlbert(is_training=False)`` instance. ``sess`` is the
            ``tf.Session`` bound to that network's graph, with the restored
            variables.

        Raises:
            FileNotFoundError: if no checkpoint state is found in the
                checkpoint directory.
        """
        with tf.Graph().as_default():
            # Created inside the graph context, so the session is bound to
            # this new graph rather than the global default one.
            sess = tf.Session()
            try:
                out_dir = os.path.join(pwd, "model")
                with sess.as_default():
                    albert = NetworkAlbert(is_training=False)
                    saver = tf.train.Saver()
                    # Initialize everything first; restore() then overwrites
                    # the variables that are stored in the checkpoint.
                    sess.run(tf.global_variables_initializer())
                    checkpoint_dir = os.path.abspath(
                        os.path.join(out_dir, hp.inference_model))
                    print(checkpoint_dir)
                    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
                    # get_checkpoint_state returns None when the directory
                    # has no checkpoint index; fail with a clear message
                    # instead of an AttributeError on None.
                    if ckpt is None or not ckpt.model_checkpoint_path:
                        raise FileNotFoundError(
                            "No TensorFlow checkpoint found in %s"
                            % checkpoint_dir)
                    saver.restore(sess, ckpt.model_checkpoint_path)
            except Exception:
                # Do not leak the session if building or restoring fails.
                sess.close()
                raise
        return albert, sess
# The network and session are built once, when this module is imported;
# get_label() below reuses this single instance for every call.
MODEL = ModelAlbertDenses()
print("Load model finished!")
def get_label(sentence):
    """
    Predict the labels of a single sentence using the module-level MODEL.

    The sentence is converted with ``get_feature_test``. The first three
    entries of that feature are fed, as a batch of one, to the network's
    ``input_ids``, ``input_masks`` and ``segment_ids`` input tensors.

    Returns:
        list: ``id2label(idx)`` for every label index ``idx`` whose
        prediction equals 1, skipping index 0.
    """
    network = MODEL.albert
    feature = get_feature_test(sentence)
    feed = {
        network.input_ids: [feature[0]],
        network.input_masks: [feature[1]],
        network.segment_ids: [feature[2]],
    }
    batch_predictions = MODEL.sess.run(network.predictions, feed_dict=feed)
    # Only one sentence was fed, so take the first (and only) row.
    single_prediction = batch_predictions[0]
    hit_indices = np.where(single_prediction == 1)[0]
    # NOTE(review): index 0 is presumably a "no label"/placeholder class and
    # is deliberately excluded; confirm against id2label's mapping.
    return [id2label(idx) for idx in hit_indices if idx != 0]
if __name__ == '__main__':
    # Smoke test: print the predicted labels for a few sample Chinese
    # review sentences.
    demo_sentences = [
        '外形外观:好看',
        '有股难闻的怪味',
        '制热效果很好',
        '开一晚上也感觉不到热',
    ]
    for text in demo_sentences:
        labels = get_label(text)
        print(labels)