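"""Data preprocessing for the en-ja translator (forked from cheripai/en-ja-translator).

Reads parallel English/Japanese corpora, normalizes and tokenizes the sentences,
builds a vocabulary for each language, loads pretrained word vectors, and pickles
the results for training. File paths and hyperparameters (MAX_LENGTH,
MIN_VOCAB_WORDS, the corpus and output paths) come from the constants module.
"""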
import multiprocessing as mp
import pickle
import re

import numpy as np
import torch
from janome.tokenizer import Tokenizer

import constants as c
from Lang import Lang


def normalize_en(s):
    """ Normalizes an English string: lowercases it, inserts a space before
    sentence-ending punctuation (.!?), and collapses any other run of non-word
    characters into a single space.
    """
    s = s.lower().strip()
    s = re.sub(r"([.!?])", r" \1", s)
    s = re.sub(r"[^\w.!?]+", r" ", s, flags=re.UNICODE)
    return s


def normalize_ja(s, segmenter):
    """ Normalizes a Japanese string: collapses non-word characters (other than
    .!?。) into spaces, segments it with Janome, and joins the base form of each
    token with single spaces.
    """
    s = s.strip()
    s = re.sub(r"[^\w.!?。]+", r" ", s, flags=re.UNICODE)
    s = " ".join(token.base_form for token in segmenter.tokenize(s))
    s = re.sub(r"\s+", " ", s).strip()
    return s


def normalize(en_lines, ja_lines):
    """ Normalizes parallel lists of English and Japanese strings into sentence pairs.
    """
    # The Janome tokenizer is built here so that each worker process gets its own instance.
    tokenizer = Tokenizer()
    return [[normalize_en(l1), normalize_ja(l2, tokenizer)]
            for l1, l2 in zip(en_lines, ja_lines)]


def read_langs(en_file, ja_file, n_processes=4):
    """ Reads both corpora and returns a Lang object for each language along with
    all normalized sentence pairs.
    """
    with open(en_file, encoding="utf8", errors="ignore") as f:
        en_lines = f.read().split("\n")
    with open(ja_file, encoding="utf8", errors="ignore") as f:
        ja_lines = f.read().split("\n")
    interval = len(en_lines) // n_processes
    with mp.Pool(processes=n_processes) as pool:
        results = []
        for i in range(n_processes):
            # The last chunk runs to the end of the corpus so that remainder lines are not dropped.
            end = (i + 1) * interval if i < n_processes - 1 else len(en_lines)
            results.append(pool.apply_async(normalize, args=(en_lines[i * interval:end], ja_lines[i * interval:end])))
        pairs = []
        for p in results:
            pairs += p.get()
    en = Lang("en")
    ja = Lang("ja")
    return en, ja, pairs


def filter_pair_by_vocab(p, lang1, lang2):
    """ Returns True only if every word in both sentences is in its language's vocabulary.
    """
    return (all(word in lang1.word2index for word in p[0].split(" "))
            and all(word in lang2.word2index for word in p[1].split(" ")))


def filter_pair_by_len(p, maxlen=c.MAX_LENGTH):
    """ Returns True only if both sentences are shorter than maxlen tokens.
    """
    return len(p[0].split(" ")) < maxlen and len(p[1].split(" ")) < maxlen


def filter_vocab(lang, min_words=2):
    """ Removes words with counts below min_words from the Lang object in place.
    """
    remove_words = [word for word, count in lang.word2count.items() if count < min_words]
    for word in remove_words:
        lang.remove_word(word)


def load_en_w2v(fname):
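    """ Loads English word vectors from a plain-text embedding file where each
    line holds a word followed by its vector components (a GloVe-style text
    format is assumed here).
    """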
    w2v_en = {}
    with open(fname) as f:
        for line in f:
            parts = line.split()
            w = parts[0]
            v = np.array([float(x) for x in parts[1:]])
            w2v_en[w] = v
    return w2v_en


def load_ja_w2v(fname):
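    """ Loads Japanese word vectors from a text file where each entry spans
    several lines: the first line holds a leading field (ignored), the word,
    and the opening "[" of its vector, continuation lines hold further
    components, and the line containing "]" closes the vector. This format is
    inferred from the parsing logic below.
    """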
    w2v_ja = {}
    with open(fname) as f:
        for line in f:
            if "[" in line:
                # First line of an entry: skip the leading field, keep the word and the start of the vector.
                line = line.strip().replace("[", "").split()
                _, w, v = line[0], line[1], line[2:]
            elif "]" in line:
                # Last line of an entry: append the final components and store the finished vector.
                line = line.strip().replace("]", "").split()
                v += line
                v = np.array([float(val) for val in v])
                w2v_ja[w] = v
            else:
                # Continuation line: append more vector components.
                line = line.strip().split()
                v += line
    return w2v_ja


def build_vecs(vocab, w2v, vec_dim=300):
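    """ Builds an embedding matrix for vocab from the w2v dictionary; words with
    no pretrained vector are given a random normal vector instead.
    """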
    vecs = np.zeros((len(vocab), vec_dim))
    not_found = 0
    for i, w in enumerate(vocab):
        try:
            vecs[i] = w2v[w]
        except KeyError:
            not_found += 1
            vecs[i] = np.random.normal(scale=0.6, size=vec_dim)
    print("Did not find {} words".format(not_found))
    return torch.FloatTensor(vecs)


if __name__ == "__main__":
    en, ja, pairs = read_langs(c.EN_PATH, c.JA_PATH)
    print("Number of sentences:", len(pairs))
    pairs = [pair for pair in pairs if filter_pair_by_len(pair, c.MAX_LENGTH)]
    for pair in pairs:
        en.add_sentence(pair[0])
        ja.add_sentence(pair[1])
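    # Remove rare words from both vocabularies, then drop pairs that now contain out-of-vocabulary words.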
    filter_vocab(en, c.MIN_VOCAB_WORDS)
    filter_vocab(ja, c.MIN_VOCAB_WORDS)
    pairs = [pair for pair in pairs if filter_pair_by_vocab(pair, en, ja)]
    print("Number of trimmed sentences:", len(pairs))
    print("Number of {} words: {}".format(en.name, en.n_words))
    print("Number of {} words: {}".format(ja.name, ja.n_words))
    w2v_en = load_en_w2v(c.EN_W2V_PATH)
    w2v_ja = load_ja_w2v(c.JA_W2V_PATH)
    en_vecs = build_vecs(list(en.word2index.keys()), w2v_en, vec_dim=300)
    ja_vecs = build_vecs(list(ja.word2index.keys()), w2v_ja, vec_dim=300)
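    # Pickle the vocabularies, sentence pairs, and embedding matrices for training.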
    with open(c.EN_LANG_PATH, "wb") as f:
        pickle.dump(en, f)
    with open(c.JA_LANG_PATH, "wb") as f:
        pickle.dump(ja, f)
    with open(c.PAIRS_PATH, "wb") as f:
        pickle.dump(pairs, f)
    with open(c.EN_VECS_PATH, "wb") as f:
        pickle.dump(en_vecs, f)
    with open(c.JA_VECS_PATH, "wb") as f:
        pickle.dump(ja_vecs, f)