-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpredict.py
executable file
·62 lines (48 loc) · 2.04 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
#!/usr/bin/python3
from __future__ import print_function # for vim and jedi using py2
import argparse
import csv
import joblib
import sys
import chios.question as cq
import chios.feats
def _answer_letter(idx):
    """Return the letter label ('A', 'B', ...) for the answer at index *idx*."""
    return chr(ord('A') + int(idx))


def _format_ne(nes):
    """Render named entities as 'LABEL(score); LABEL(score)' for the report."""
    return '; '.join(['%s(%.3f)' % (ne.label, ne.score) for ne in nes])


def predict_and_dump(questions, featgen, cfier,
                     prediction_path='prediction.csv',
                     analysis_path='analysis.csv'):
    """Predict an answer for every question and write two CSV reports.

    For each question, ``featgen.score(q)`` yields one feature row per
    answer; column 1 of ``cfier.predict_proba`` on those rows is used as the
    per-answer score, and the highest-scoring answer is chosen.

    Files written:
      * ``prediction_path``: one row per question with ``id`` and the letter
        of the chosen answer (``correctAnswer``).
      * ``analysis_path``: one row per (question, answer) pair with question
        and answer tokens, their named entities, the gold marker (``c``:
        '*' if gold else '.'), the prediction marker (``i``: '+' if chosen
        else '-'), the score ``p`` and every feature named by
        ``featgen.labels()`` (formatted to 2 decimals).

    Progress is reported on stderr.

    :param questions: sized sequence of question objects exposing ``id``,
        ``correct`` (gold answer index), ``answers``, ``tokens()``, ``ne()``.
    :param featgen: object exposing ``score(q)`` (per-answer feature rows)
        and ``labels()`` (feature names, one per column).
    :param cfier: fitted classifier exposing ``predict_proba``.
    :param prediction_path: output path of the prediction CSV.
    :param analysis_path: output path of the per-answer analysis CSV.
    """
    # Feature names do not change between questions; fetch them once.
    labels = list(featgen.labels())
    analysis_fields = ['id', 'question', 'qNE', 'l', 'c', 'i', 'p',
                       'answer', 'aNE'] + labels
    n_questions = len(questions)

    # newline='' lets the csv module control line endings itself (otherwise
    # rows end in '\r\r\n' on Windows); `with` guarantees both files close.
    with open(prediction_path, 'w', newline='', encoding='utf-8') as prf, \
            open(analysis_path, 'w', newline='', encoding='utf-8') as anf:
        prcsv = csv.DictWriter(prf, fieldnames=['id', 'correctAnswer'])
        prcsv.writeheader()
        ancsv = csv.DictWriter(anf, fieldnames=analysis_fields)
        ancsv.writeheader()

        for qidx, q in enumerate(questions):
            print('\rQuestion %d/%d' % (qidx + 1, n_questions),
                  file=sys.stderr, end='')
            s = featgen.score(q)
            # Column 1 is taken as the "this answer is correct" probability;
            # presumably the model was trained with label 1 = correct.
            p = cfier.predict_proba(s)[:, 1]
            choice = p.argmax()
            prcsv.writerow({'id': q.id,
                            'correctAnswer': _answer_letter(choice)})

            # Question-level fields are identical for every answer row.
            question_text = ' '.join(q.tokens())
            qne_text = _format_ne(q.ne())
            for aidx, a in enumerate(q.answers):
                row = {
                    'id': q.id,
                    'question': question_text,
                    'qNE': qne_text,
                    'l': _answer_letter(aidx),
                    'c': '*' if aidx == q.correct else '.',
                    'i': '+' if aidx == choice else '-',
                    'p': p[aidx],
                    'answer': ' '.join(a.tokens()),
                    'aNE': _format_ne(a.ne()),
                }
                row.update(zip(labels, ['%.2f' % (f,) for f in s[aidx]]))
                ancsv.writerow(row)
    # Finish the '\r'-overwritten progress line.
    print('', file=sys.stderr)
def main(argv=None):
    """Command-line entry point: load questions and model, dump predictions.

    Reads the question set from TSVFILE, builds the feature generator with
    the requested GloVe dimension, loads the joblib-dumped classifier and
    calls predict_and_dump(), which writes prediction.csv and analysis.csv.

    :param argv: argument list to parse (defaults to ``sys.argv[1:]``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--glove-dim', type=int, default=50,
                        choices=[50, 100, 200, 300],
                        help='embedding size (50, 100, 200, 300 only)')
    parser.add_argument('--model', default='data/model',
                        help='path of the joblib-dumped classifier '
                             '(default: data/model)')
    parser.add_argument('TSVFILE', help='questions set')
    args = parser.parse_args(argv)

    questions = cq.load_questions(args.TSVFILE)
    featgen = chios.feats.FeatureGenerator(args.glove_dim)
    # joblib.load unpickles the file, which can execute arbitrary code:
    # only point --model at files you trust.
    cfier = joblib.load(args.model)
    print('Initialized.', file=sys.stderr)
    predict_and_dump(questions, featgen, cfier)


if __name__ == '__main__':
    main()