-
Notifications
You must be signed in to change notification settings - Fork 5
/
ted2.py
137 lines (116 loc) · 4.43 KB
/
ted2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import argparse
import rouge
from eval_utils import Meteor, stanford_parsetree_extractor, compute_tree_edit_distance
from tqdm import tqdm
import subprocess
from nltk.tree import Tree
def trim_tree_nltk(root, height):
    """Collect the labels of the top ``height`` levels of a tree as nested lists.

    Args:
        root: A node exposing ``label()`` and iteration over its children
            (the callers in this file pass ``nltk.tree.Tree`` objects).
            Anything without a ``label()`` method is treated as a leaf.
        height: Number of levels to keep, counting ``root`` as level 1.

    Returns:
        ``[label, child_state, ...]`` where each ``child_state`` is the same
        structure for a child that is still within ``height``; ``None`` when
        ``root`` has no ``label()`` method or ``height`` is below 1.
    """
    try:
        label = root.label()
    except AttributeError:
        # Leaves (e.g. plain token strings) carry no label and are dropped.
        return None

    if height < 1:
        return None

    all_child_state = [label]
    for child in root:
        # Recurse exactly once per child: calling twice (once to test, once
        # to append) doubles the work at every level of the tree.
        child_state = trim_tree_nltk(child, height - 1)
        if child_state:
            all_child_state.append(child_state)
    return all_child_state
def string_comma(string):
    """Rewrite the commas of ``string``, keeping only commas that follow "(".

    The string is scanned comma by comma. Normally the text up to the next
    comma is copied and the comma itself is replaced by a single space.
    When the *previously* handled comma sat directly after "(", the space
    written for that comma is replaced by ", " instead, and the text up to
    the next comma is not copied.

    Args:
        string: Text to rewrite.

    Returns:
        The rewritten text; ``string`` unchanged when it contains no comma.
    """
    start = 0
    new_string = ''
    while start < len(string):
        index = string.find(",", start)
        if index == -1:
            # No comma left: copy the remainder verbatim.
            new_string += string[start:]
            break

        # The previous comma lives at start - 1, so start - 2 is the character
        # in front of it. Requiring start >= 2 stops the lookup from wrapping
        # around to the end of the string before any comma has been consumed.
        previous_comma_follows_paren = start >= 2 and string[start - 2] == "("
        if previous_comma_follows_paren:
            # Drop the space emitted for the previous comma and restore it as
            # a literal comma. This slices the output with an input offset,
            # which lines up while each earlier comma was swapped for exactly
            # one character.
            new_string = new_string[:start - 1] + ", "
        else:
            new_string += string[start:index] + " "
        start = index + 1
    return new_string
def clean_tuple_str(tuple_str):
    """Flatten a (nested) tuple into one string and pass it to ``string_comma``.

    A one-element tuple contributes its single element as-is. Otherwise the
    tuple's ``str()`` form is split on ", " and each fragment is kept only if
    it contains exactly two single quotes (the single quotes are removed) or
    exactly one single quote (double quotes are removed); every other
    fragment is discarded. The kept pieces are joined with spaces.

    Args:
        tuple_str: The tuple to render (despite the name, not a string).

    Returns:
        The result of ``string_comma`` applied to the joined pieces.
    """
    if len(tuple_str) == 1:
        pieces = [tuple_str[0]]
    else:
        pieces = []
        for fragment in str(tuple_str).split(", "):
            quote_count = fragment.count("'")
            if quote_count == 2:
                pieces.append(fragment.replace("'", ""))
            elif quote_count == 1:
                pieces.append(fragment.replace("\"", ""))
    return string_comma(' '.join(pieces))
def to_tuple(lst):
    """Return ``lst`` as a tuple, converting nested lists to tuples recursively.

    Only ``list`` instances are descended into; every other element is
    carried over unchanged.
    """
    converted = []
    for item in lst:
        if isinstance(item, list):
            converted.append(to_tuple(item))
        else:
            converted.append(item)
    return tuple(converted)
def get_syntax_templates(template_file):
    """Load one syntax template per line of ``template_file``.

    For every line, the text after the last "<sep>" (or the whole line when
    there is none) is stripped, parsed with ``Tree.fromstring``, trimmed to
    its top 3 levels and rendered through ``to_tuple`` / ``clean_tuple_str``.

    Args:
        template_file: Path of the text file to read.

    Returns:
        A list with one rendered template string per input line.
    """
    # Use a context manager so the file handle is closed deterministically.
    with open(template_file) as template_handle:
        raw_parses = [test_str.split("<sep>")[-1].strip() for test_str in template_handle]
    return [
        clean_tuple_str(to_tuple(trim_tree_nltk(Tree.fromstring(parse_str), 3)))
        for parse_str in raw_parses
    ]
# ---------------------------------------------------------------------------
# Script body: pair every line of --select_file with a syntax template
# (looked up by the line's position in --input_file) and report the mean
# value of compute_tree_edit_distance between the selected sentence's
# trimmed parse and that template.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument('--input_file', '-i', type=str, help="full generated file, ")
parser.add_argument('--select_file', '-s', type=str)
parser.add_argument('--temp_file', '-t', type=str)
args = parser.parse_args()

# Read each file once and close it right away.
with open(args.input_file, "r") as input_handle:
    input_lines = [line.strip("\n").strip() for line in input_handle]
with open(args.select_file, "r") as select_handle:
    select_lines = [line.strip("\n").strip() for line in select_handle]
n_select_line = len(select_lines)

# Map each distinct input line to its first position so every lookup is O(1);
# this matches list.index(), which also returns the first occurrence.
first_index = {}
for position, text in enumerate(input_lines):
    first_index.setdefault(text, position)

indices = []
for new_line in select_lines:
    if new_line not in first_index:
        # Same exception type list.index() raises for a missing element.
        raise ValueError("selected line not found in input file: {!r}".format(new_line))
    indices.append(first_index[new_line])

temp_parses = ""
if "scpn" in args.input_file.lower():
    templates = [
        '( ROOT ( S ( NP ) ( VP ) ( . ) ) )',
        '( ROOT ( S ( VP ) ( . ) ) )',
        '( ROOT ( NP ( NP ) ( . ) ) )',
        '( ROOT ( FRAG ( SBAR ) ( . ) ) )',
        '( ROOT ( S ( S ) ( , ) ( CC ) ( S ) ( . ) ) )',
        '( ROOT ( S ( LST ) ( VP ) ( . ) ) )',
        '( ROOT ( SBARQ ( WHADVP ) ( SQ ) ( . ) ) )',
        '( ROOT ( S ( PP ) ( , ) ( NP ) ( VP ) ( . ) ) )',
        '( ROOT ( S ( ADVP ) ( NP ) ( VP ) ( . ) ) )',
        '( ROOT ( S ( SBAR ) ( , ) ( NP ) ( VP ) ( . ) ) )',
    ]
    # The fixed template list is tiled so it can be indexed by input position.
    if "qqpp" in args.input_file.lower():
        temp_parses = templates * 3000
    elif "paranmt" in args.input_file.lower():
        temp_parses = templates * 800
else:
    temp_parses = get_syntax_templates(args.temp_file)

# Validate before touching the templates: a comprehension over the "" sentinel
# would silently produce an empty list and hide the missing-template case.
if not isinstance(temp_parses, list):
    raise Exception("template parses are not a items of a list!")

# Pick the template belonging to each selected line, then bring it into the
# same trimmed, rendered form used for the selected sentences below.
# NOTE(review): normalisation is applied to templates from both branches
# (for file templates it repeats what get_syntax_templates already did);
# confirm this matches the intended pipeline.
temp_parses = [temp_parses[i] for i in indices]
temp_parses = [
    clean_tuple_str(to_tuple(trim_tree_nltk(Tree.fromstring(parse_str), 3)))
    for parse_str in temp_parses
]

print("#lines - select: {}, temp: {}".format(n_select_line, len(temp_parses)))
assert n_select_line == len(temp_parses), \
    "#select {} != #templates {}".format(n_select_line, len(temp_parses))

spe = stanford_parsetree_extractor()
try:
    select_parses = spe.run(args.select_file)
    select_parses = [
        clean_tuple_str(to_tuple(trim_tree_nltk(Tree.fromstring(parse_str), 3)))
        for parse_str in select_parses
    ]
finally:
    # Release the extractor's resources even when parsing fails.
    spe.cleanup()

all_ted = []
all_ted_t = []

# Score each (selected parse, template) pair and show the running mean.
pbar = tqdm(zip(select_parses, temp_parses))
for select_parse, temp_parse in pbar:
    ted_t = compute_tree_edit_distance(select_parse, temp_parse)
    all_ted_t.append(ted_t)
    pbar.set_description(
        "ted-e: {:.3f}".format(
            sum(all_ted_t) / len(all_ted_t)
        ))

# With nothing scored there is no mean; report NaN instead of dividing by zero.
mean_ted_t = sum(all_ted_t) / len(all_ted_t) if all_ted_t else float("nan")
print("ted-e: {:.3f}".format(mean_ted_t))