import itertools
import os
import random
import re
import unicodedata
import codecs
import csv

import torch

from config import *

# Print a few lines of the data file to inspect the raw data format
def printLines(file, n=10):
with open(file, 'rb') as datafile:
lines = datafile.readlines()
for line in lines[:n]:
print(line)
printLines(os.path.join(corpus, "movie_lines.txt"))
######################################################################
# Data processing: create a well-formatted data file in which every line holds
# one query/response sentence pair separated by a tab.
#
# Three helper functions:
#   loadLines            splits each line of movie_lines.txt into
#                        (lineID, characterID, movieID, character, text)
#   loadConversations    groups the lines above into multi-turn conversations
#   extractSentencePairs extracts sentence pairs from each conversation

# Parse every line into a dict whose keys are lineID, characterID, movieID,
# character and text, i.e. the line ID, character ID, movie ID, character name
# and the text of the line.
# Returns a dict keyed by lineID; each value is one of these per-line dicts.
def loadLines(fileName, fields):
lines = {}
with open(fileName, 'r', encoding='iso-8859-1') as f:
for line in f:
values = line.split(" +++$+++ ")
# Extract fields
lineObj = {}
for i, field in enumerate(fields):
lineObj[field] = values[i]
lines[lineObj['lineID']] = lineObj
return lines
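# Illustration (not from the original file): with the standard Cornell
# Movie-Dialogs format, a raw line of movie_lines.txt looks roughly like
#   L1045 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ They do not!
# and loadLines would map it to
#   lines["L1045"] == {"lineID": "L1045", "characterID": "u0", "movieID": "m0",
#                      "character": "BIANCA", "text": "They do not!\n"}
# (the trailing newline survives the split, which is why extractSentencePairs
# strips the text later).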
# Using movie_conversations.txt and the `lines` dict produced above, group the
# utterances into conversations.
# The result is a list whose elements are dicts with keys character1ID,
# character2ID, movieID and utteranceIDs: the IDs of the two speakers, the
# movie ID, and the IDs of the utterances that make up the conversation.
# Finally, using `lines`, each conversation dict also gets a "lines" key whose
# value is the list of all its utterance dicts (the values stored in `lines`).
def loadConversations(fileName, lines, fields):
conversations = []
with open(fileName, 'r', encoding='iso-8859-1') as f:
for line in f:
values = line.split(" +++$+++ ")
# Extract fields
convObj = {}
for i, field in enumerate(fields):
convObj[field] = values[i]
# Convert string to list (convObj["utteranceIDs"] == "['L598485', 'L598486', ...]")
lineIds = eval(convObj["utteranceIDs"])
# Reassemble lines
convObj["lines"] = []
for lineId in lineIds:
convObj["lines"].append(lines[lineId])
conversations.append(convObj)
return conversations
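# Illustration (not from the original file): a raw line of
# movie_conversations.txt looks roughly like
#   u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L194', 'L195', 'L196', 'L197']
# The eval() above turns the last field into the Python list
# ['L194', 'L195', 'L196', 'L197'], and each ID is looked up in `lines`
# to fill convObj["lines"] with the corresponding utterance dicts.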
# Extract sentence pairs from the conversations.
# If a conversation consists of the 4 utterances s1, s2, s3, s4,
# this returns 3 pairs: (s1, s2), (s2, s3) and (s3, s4).
def extractSentencePairs(conversations):
qa_pairs = []
for conversation in conversations:
# Iterate over all the lines of the conversation
for i in range(len(conversation["lines"]) - 1): # We ignore the last line (no answer for it)
inputLine = conversation["lines"][i]["text"].strip()
targetLine = conversation["lines"][i+1]["text"].strip()
# Filter wrong samples (if one of the lists is empty)
if inputLine and targetLine:
qa_pairs.append([inputLine, targetLine])
return qa_pairs
######################################################################
# Use the three helper functions above to process the raw data and produce
# formatted_movie_lines.txt.
# Define path to new file
datafile = os.path.join(corpus, "formatted_movie_lines.txt")
delimiter = '\t'
# Unescape the delimiter
delimiter = str(codecs.decode(delimiter, "unicode_escape"))
# Initialize lines dict, conversations list, and field ids
lines = {}
conversations = []
MOVIE_LINES_FIELDS = ["lineID", "characterID", "movieID", "character", "text"]
MOVIE_CONVERSATIONS_FIELDS = ["character1ID", "character2ID", "movieID", "utteranceIDs"]
# Load lines and process conversations
print("\nProcessing corpus...")
lines = loadLines(os.path.join(corpus, "movie_lines.txt"), MOVIE_LINES_FIELDS)
print("\nLoading conversations...")
conversations = loadConversations(os.path.join(corpus, "movie_conversations.txt"),
lines, MOVIE_CONVERSATIONS_FIELDS)
# Write new csv file
print("\nWriting newly formatted file...")
with open(datafile, 'w', encoding='utf-8') as outputfile:
writer = csv.writer(outputfile, delimiter=delimiter, lineterminator='\n')
for pair in extractSentencePairs(conversations):
writer.writerow(pair)
# Print a sample of lines
print("\nSample lines from file:")
printLines(datafile)
######################################################################
# Load and trim the data
#
# Build the vocabulary.
# Next we build the vocabulary and load the question/answer pairs into memory.
# Each input is a sentence pair, and each sentence is a sequence of words, but
# machine learning models only work with numbers, so we need a mapping from
# words to numeric IDs.
# For that we define a Voc class. It stores the word-to-ID mapping and the
# reverse ID-to-word mapping, and it also counts how often each word occurs and
# how many distinct words there are in total. The class provides addWord to add
# a single word, addSentence to add a whole sentence, and trim to drop
# low-frequency words.
class Voc:
def __init__(self, name):
self.name = name
self.trimmed = False
self.word2index = {}
self.word2count = {}
self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"}
self.num_words = 3 # Count SOS, EOS, PAD
def addSentence(self, sentence):
for word in sentence.split(' '):
self.addWord(word)
def addWord(self, word):
if word not in self.word2index:
self.word2index[word] = self.num_words
self.word2count[word] = 1
self.index2word[self.num_words] = word
self.num_words += 1
else:
self.word2count[word] += 1
# Remove words below a certain count threshold
def trim(self, min_count):
if self.trimmed:
return
self.trimmed = True
keep_words = []
for k, v in self.word2count.items():
if v >= min_count:
keep_words.append(k)
print('keep_words {} / {} = {:.4f}'.format(
len(keep_words), len(self.word2index), len(keep_words) / len(self.word2index)
))
# Reinitialize dictionaries
self.word2index = {}
self.word2count = {}
self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"}
self.num_words = 3 # Count default tokens
for word in keep_words:
self.addWord(word)
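# Quick illustration (not part of the original pipeline), assuming PAD_token,
# SOS_token and EOS_token are 0, 1 and 2 as in the standard tutorial config:
#   v = Voc("demo")
#   v.addSentence("hello world hello")
#   v.word2index  -> {"hello": 3, "world": 4}
#   v.word2count  -> {"hello": 2, "world": 1}
#   v.num_words   -> 5   (3 special tokens + 2 real words)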
######################################################################
# Now we can assemble our vocabulary and query/response sentence pairs.
#
# With the Voc class above we can build the vocabulary from the question/answer
# pairs, but some preprocessing is needed first.
#
# First, unicodeToAscii converts Unicode characters to ASCII, e.g. à becomes a.
# Note that this only handles Western scripts; characters such as Chinese are
# simply dropped by this function.
# Next we lowercase everything and discard every character that is not a letter
# or one of the common punctuation marks (.!?).
# Finally, to help training converge, filterPairs removes sentence pairs whose
# sentences are longer than MAX_LENGTH.
#
# Turn a Unicode string to plain ASCII, thanks to
# https://stackoverflow.com/a/518232/2809427
def unicodeToAscii(s):
return ''.join(
c for c in unicodedata.normalize('NFD', s)
if unicodedata.category(c) != 'Mn'
)
# Lowercase, trim, and remove non-letter characters
def normalizeString(s):
    # Lowercase, strip leading/trailing whitespace, then convert Unicode to ASCII
    s = unicodeToAscii(s.lower().strip())
    # Put a space before punctuation so each punctuation mark becomes its own token
    s = re.sub(r"([.!?])", r" \1", s)
    # Replace everything that is not a letter or .!? with a space
    s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
    # Since unwanted characters were turned into spaces, there may be runs of
    # consecutive spaces; collapse them into a single space and strip the ends
    s = re.sub(r"\s+", r" ", s).strip()
return s
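# Illustration (not from the original file): normalizeString("Aren't you coming?!")
# returns "aren t you coming ? !": the apostrophe becomes a space, each
# punctuation mark is separated, and whitespace is collapsed.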
# Read the question/answer pairs and return a Voc vocabulary object
def readVocs(datafile, corpus_name):
print("Reading lines...")
    # Read the file and split it into a list of lines
    lines = open(datafile, encoding='utf-8').\
        read().strip().split('\n')
    # Split each line on tab into a query and a response and normalize both
    pairs = [[normalizeString(s) for s in l.split('\t')] for l in lines]
voc = Voc(corpus_name)
return voc, pairs
def filterPair(p):
return len(p[0].split(' ')) < MAX_LENGTH and len(p[1].split(' ')) < MAX_LENGTH
# Filter out pairs that are too long
def filterPairs(pairs):
return [pair for pair in pairs if filterPair(pair)]
# Run the whole pipeline above; returns the Voc object and the list of sentence pairs
def loadPrepareData(corpus, corpus_name, datafile, save_dir):
print("Start preparing training data ...")
voc, pairs = readVocs(datafile, corpus_name)
print("Read {!s} sentence pairs".format(len(pairs)))
pairs = filterPairs(pairs)
print("Trimmed to {!s} sentence pairs".format(len(pairs)))
print("Counting words...")
for pair in pairs:
voc.addSentence(pair[0])
voc.addSentence(pair[1])
print("Counted words:", voc.num_words)
return voc, pairs
# Load/Assemble voc and pairs
voc, pairs = loadPrepareData(corpus, corpus_name, datafile, save_dir)
# Print some pairs to validate
print("\npairs:")
for pair in pairs[:10]:
print(pair)
######################################################################
# 1) Use voc.trim to remove words that occur fewer than MIN_COUNT times.
#
# 2) Drop every pair that contains a trimmed word (i.e. keep only pairs in
#    which every word is frequent enough to remain in the voc).
#
MIN_COUNT = 3 # Minimum word count threshold for trimming
def trimRareWords(voc, pairs, MIN_COUNT):
# Trim words used under the MIN_COUNT from the voc
voc.trim(MIN_COUNT)
# Filter out pairs with trimmed words
keep_pairs = []
for pair in pairs:
input_sentence = pair[0]
output_sentence = pair[1]
keep_input = True
keep_output = True
# Check input sentence
for word in input_sentence.split(' '):
if word not in voc.word2index:
keep_input = False
break
# Check output sentence
for word in output_sentence.split(' '):
if word not in voc.word2index:
keep_output = False
break
# Only keep pairs that do not contain trimmed word(s) in their input or output sentence
if keep_input and keep_output:
keep_pairs.append(pair)
print("Trimmed from {} pairs to {}, {:.4f} of total".format(len(pairs), len(keep_pairs), len(keep_pairs) / len(pairs)))
return keep_pairs
# Trim voc and pairs
pairs = trimRareWords(voc, pairs, MIN_COUNT)  # keep only pairs whose words all survived trimming
######################################################################
# Prepare data for the models
# Convert a sentence into a list of word IDs, appending EOS_token at the end
def indexesFromSentence(voc, sentence):
    return [voc.word2index[word] for word in sentence.split(' ')] + [EOS_token]
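# Illustration (not from the original file), reusing the demo vocabulary above
# and assuming EOS_token == 2:
#   indexesFromSentence(v, "hello world")  ->  [3, 4, 2]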
# l is a list of sentences of different lengths (each already converted to IDs).
# zip_longest pads them all to the length of the longest sentence and, as a
# side effect, transposes the batch from (batch, max_length) to (max_length, batch).
def zeroPadding(l, fillvalue=PAD_token):
return list(itertools.zip_longest(*l, fillvalue=fillvalue))
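# Illustration (not from the original file), assuming PAD_token == 0:
#   zeroPadding([[5, 6, 7], [8, 9]])  ->  [(5, 8), (6, 9), (7, 0)]
# i.e. one tuple per time step, each holding the tokens of the whole batch.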
# l is the 2-D padded list from zeroPadding.
# Returns a matrix m with the same shape as l: 0 where the position holds
# padding, 1 otherwise (a binary mask).
def binaryMatrix(l, value=PAD_token):
m = []
for i, seq in enumerate(l):
m.append([])
for token in seq:
if token == PAD_token:
m[i].append(0)
else:
m[i].append(1)
return m
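# Illustration (not from the original file), continuing the example above with
# PAD_token == 0:
#   binaryMatrix([(5, 8), (6, 9), (7, 0)])  ->  [[1, 1], [1, 1], [1, 0]]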
# Convert the input sentences to IDs and pad them; also return `lengths`, which
# records the true (unpadded) length of each sentence.
# The returned padVar is a LongTensor of shape (max_length, batch) and
# lengths is a tensor of shape (batch,) holding the actual sentence lengths.
def inputVar(l, voc):
    # Convert the sentences to ID lists; zeroPadding below then changes the
    # shape from (batch, max_length) to (max_length, batch)
indexes_batch = [indexesFromSentence(voc, sentence) for sentence in l]
lengths = torch.tensor([len(indexes) for indexes in indexes_batch])
padList = zeroPadding(indexes_batch)
padVar = torch.LongTensor(padList)
return padVar, lengths
# Pad the output sentences and use binaryMatrix to mark every position as
# padding (0) or non-padding (1).
# Also return the length of the longest sentence (i.e. the padded length).
# The returned padVar is a LongTensor of shape (max_target_length, batch) and
# mask is a ByteTensor of the same shape (max_target_length, batch).
def outputVar(l, voc):
indexes_batch = [indexesFromSentence(voc, sentence) for sentence in l]
max_target_len = max([len(indexes) for indexes in indexes_batch])
padList = zeroPadding(indexes_batch)
mask = binaryMatrix(padList)
mask = torch.ByteTensor(mask)
padVar = torch.LongTensor(padList)
return padVar, mask, max_target_len
# Process one batch of sentence pairs and return all the tensors the model needs
def batch2TrainData(voc, pair_batch):
pair_batch.sort(key=lambda x: len(x[0].split(" ")), reverse=True)
input_batch, output_batch = [], []
for pair in pair_batch:
input_batch.append(pair[0])
output_batch.append(pair[1])
    inp, lengths = inputVar(input_batch, voc)  # input_batch goes from (batch, max_length) to (max_length, batch)
    output, mask, max_target_len = outputVar(output_batch, voc)  # output_batch likewise transposed to (max_length, batch)
return inp, lengths, output, mask, max_target_len
# Example for validation
small_batch_size = 5
batches = batch2TrainData(voc, [random.choice(pairs) for _ in range(small_batch_size)])
input_variable, lengths, target_variable, mask, max_target_len = batches
print("input_variable:", input_variable)
print("lengths:", lengths)
print("target_variable:", target_variable)
print("mask:", mask)
print("max_target_len:", max_target_len)