-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathevaluate.py
96 lines (79 loc) · 3.03 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import string
import re
import numpy as np
import json
def get_exact_match(answers1, answers2):
    """Return True if any answer in ``answers1`` matches any in ``answers2``.

    Each argument may be a single string or a (possibly nested) list/tuple
    of strings. Containers are expanded recursively and every leaf string of
    one side is compared with every leaf string of the other after both have
    been passed through ``normalize_answer``.

    :answers1: a string, or a list/tuple of (nested) strings
    :answers2: a string, or a list/tuple of (nested) strings
    :returns: a plain ``bool``; an empty container on either side yields
        ``False`` (which still compares equal to ``0`` and adds as ``0``)
    """
    # isinstance (rather than ``type(x) == list``) also accepts list
    # subclasses and tuples, which would otherwise fall through to
    # normalize_answer and fail there.
    if isinstance(answers1, (list, tuple)):
        # any() of an empty iterable is False, covering the empty-list case.
        return any(get_exact_match(a, answers2) for a in answers1)
    if isinstance(answers2, (list, tuple)):
        return any(get_exact_match(answers1, a) for a in answers2)
    return normalize_answer(answers1) == normalize_answer(answers2)
def normalize_answer(s):
    """Canonicalise an answer string for exact-match comparison.

    Steps, applied in this order:
      1. lower-case the text;
      2. drop every character found in ``string.punctuation``;
      3. replace the standalone words "a", "an" and "the" with a space;
      4. collapse runs of whitespace to one space and trim both ends.

    :s: the answer string to normalise
    :returns: the normalised string
    """
    punctuation = set(string.punctuation)
    lowered = s.lower()
    without_punct = ''.join(ch for ch in lowered if ch not in punctuation)
    # Articles are removed after punctuation so that e.g. "the," is caught.
    without_articles = re.sub(r'\b(a|an|the)\b', ' ', without_punct)
    return ' '.join(without_articles.split())
def get_f1(answers, predictions, is_equal=get_exact_match, return_p_and_r=False):
    """Compute an F1 score from a greedy one-to-one matching.

    Answers are visited in order; each one is paired with the first
    prediction that is still unpaired and for which ``is_equal`` is truthy.
    A prediction can be paired with at most one answer.

    :answers: a list of list of strings
    :predictions: a list of strings
    :is_equal: callable ``(answer, prediction) -> truthy`` deciding a match
    :return_p_and_r: when true, return a 3-tuple instead of the score alone
    :returns: the harmonic mean of (fraction of answers paired) and
        (fraction of predictions paired), or ``0.`` if nothing was paired.
        With ``return_p_and_r`` the result is
        ``(f1, fraction_of_answers_paired, fraction_of_predictions_paired)``
        -- note that the answer-side fraction comes first.
    :raises AssertionError: if ``answers`` or ``predictions`` is empty
    """
    assert len(answers)>0 and len(predictions)>0, (answers, predictions)
    answer_paired = [False] * len(answers)
    prediction_paired = [False] * len(predictions)
    for ans_idx, answer in enumerate(answers):
        for pred_idx, prediction in enumerate(predictions):
            if prediction_paired[pred_idx]:
                continue
            if is_equal(answer, prediction):
                answer_paired[ans_idx] = True
                prediction_paired[pred_idx] = True
                # This answer is now taken; no later prediction can pair
                # with it, so stop scanning.
                break
    # Pairs are one-to-one, so both sides must have the same count.
    assert np.sum(answer_paired) == np.sum(prediction_paired)
    answer_rate = np.mean(answer_paired)
    prediction_rate = np.mean(prediction_paired)
    rate_sum = answer_rate + prediction_rate
    if rate_sum == 0:
        # Nothing paired: avoid dividing by zero.
        return (0., 0., 0.) if return_p_and_r else 0.
    f1 = 2 * answer_rate * prediction_rate / rate_sum
    if return_p_and_r:
        return f1, float(answer_rate), float(prediction_rate)
    return f1
if __name__ == "__main__":
    # Score a saved predictions file: print how many examples have a
    # prediction that exactly matches (after normalisation) the first gold
    # answer.
    with open("predictions/TriviaQA_test_PromptRetrieve_preds.json", 'r',
              encoding="utf-8") as f:
        test = json.load(f)

    # Matches one bracketed span, "(...)" or "[...]", non-greedily.
    # Several ways of increasing accuracy were tested and this change was
    # kept: only the text between two parentheses is substituted out,
    # instead of removing everything after the opening parenthesis, which
    # made the match score quite bad for a few results.
    bracketed = re.compile(r"[\(\[].*?[\)\]]")

    answers = []
    predictions = []
    for eg in test:
        # Only the first entry of "gold_answer" is used for scoring.
        ans = eg["gold_answer"][0].strip()
        pred = eg["prediction"]
        answers.append(bracketed.sub("", ans))
        predictions.append(bracketed.sub("", pred))

    # Each match counts as 1, so EM is the number of exactly-matched examples.
    EM = sum(get_exact_match(ans, pred) for ans, pred in zip(answers, predictions))
    print (EM)