data_analysis.py
# Get class count
import argparse
import json
import os
import collections
import csv
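
# Usage (sketch; assumes the SUMBT-style data layout with per-split TSV files):
#   python data_analysis.py --data_dir data/woz
#   python data_analysis.py --data_dir data/multiwoz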

def main():
    parser = argparse.ArgumentParser()
    ## Required parameters
    parser.add_argument("--data_dir", default='data/woz', type=str)
    args = parser.parse_args()

    ## Parsing ontology
    if args.data_dir == "data/woz":
        fp_ontology = open(os.path.join(args.data_dir, "ontology_dstc2_en.json"), "r")
        ontology = json.load(fp_ontology)
        ontology = ontology["informable"]
        del ontology["request"]
        for slot in ontology.keys():
            ontology[slot].append("do not care")
            ontology[slot].append("none")
        fp_ontology.close()

        # sorting the ontology according to the alphabetic order of the slots
        ontology = collections.OrderedDict(sorted(ontology.items()))
        target_slot = list(ontology.keys())
        for i, slot in enumerate(target_slot):
            if slot == "pricerange":
                target_slot[i] = "price range"

    elif args.data_dir == "data/multiwoz":
        fp_ontology = open(os.path.join(args.data_dir, "ontology.json"), "r")
        ontology = json.load(fp_ontology)
        for slot in ontology.keys():
            ontology[slot].append("none")
        fp_ontology.close()

        ontology = collections.OrderedDict(sorted(ontology.items()))
        target_slot = list(ontology.keys())
    else:
        raise NotImplementedError()

    nslot = len(ontology)

    # count data of each slot-value
    count = []
    count_not_none = []
    for slot in ontology.values():
        slot_count = {}
        for val in slot:
            slot_count[val] = [0, 0, 0]  # train, valid, test
        count.append(slot_count)
        count_not_none.append([0, 0, 0])

    for d, data in enumerate(['train', 'dev', 'test']):
        with open(os.path.join(args.data_dir, "%s.tsv" % data), "r", encoding='utf-8') as f:
            reader = csv.reader(f, delimiter="\t")
            for l, line in enumerate(reader):
                # the multiwoz TSVs start with a header row, which is skipped here
                if (args.data_dir == "data/woz") or (args.data_dir == 'data/multiwoz' and l > 0):
                    for s in range(nslot):
                        # slot values start at column 4; 'dontcare' is normalised to 'do not care'
                        val = line[4 + s]
                        if val == 'dontcare':
                            val = 'do not care'
                        count[s][val][d] += 1
                        if val != 'none':
                            count_not_none[s][d] += 1
with open(os.path.join(args.data_dir, "train_analysis.txt"), "w", encoding='utf-8') as writer:
for i, c in enumerate(count):
writer.write('--- %s --- \n'% target_slot[i])
for k, v in c.items():
writer.write('%s\t%d\t%d\t%d\n' % (k, v[0], v[1], v[2]))
writer.close()
with open(os.path.join(args.data_dir, "data_not_none_analysis.txt"), "w", encoding='utf-8') as writer:
domain_data = {}
for i, slot in enumerate(ontology.keys()):
domain = slot.split('-')[0]
v = count_not_none[i]
if domain not in domain_data:
domain_data[domain] = [0, 0, 0]
for j, val in enumerate(v):
domain_data[domain][j] += val
writer.write('%s\t%d\t%d\t%d\n' % (slot, v[0], v[1], v[2]))
writer.write('----- total ----- \n')
for domain, v in domain_data.items():
writer.write('%s\t%d\t%d\t%d\n' % (domain, v[0], v[1], v[2]))
writer.close()
with open(os.path.join(args.data_dir, "none_ratio.txt"), "w", encoding='utf-8') as writer:
for i, slot in enumerate(ontology.keys()):
val = count_not_none[i]
none = count[i]['none']
ratio = [ n/(v+n) for v, n in zip(val, none) ]
writer.write('%s\t:\t%.6e\t%.6e\t%.6e\n' % (slot, ratio[0], ratio[1], ratio[2]))
writer.close()

    # find common and different slots among domains
    if args.data_dir == "data/multiwoz":
        slot_dict = {}
        for slot in target_slot:
            domain = slot.split('-')[0]
            slot_name = slot.split('-')[1]
            if slot_name not in slot_dict:
                slot_dict[slot_name] = []
            slot_dict[slot_name].append(domain)

        with open(os.path.join(args.data_dir, "domain_slot_analysis.txt"), "w", encoding='utf-8') as writer:
            for slot, domains in slot_dict.items():
                writer.write('%s\t%s\n' % (slot, ' '.join(domains)))

        print(slot_dict)
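

# Standard script entry point (assumed; the listing above ends without showing one).
if __name__ == "__main__":
    main()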