-
Notifications
You must be signed in to change notification settings - Fork 1
/
data_parser.py
executable file
·163 lines (132 loc) · 5.36 KB
/
data_parser.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
"""
data_parser.py
Functions for text processing for trump speech generator.
"""
import logging
import random
import numpy as np
from basic_config import Config
import re
def read_input():
    """
    Read the training corpus and normalize it into one line of text.

    Removes the "SPEECH <n>" headers, collapses repeated newlines, strips
    backslash-separated dates, squeezes runs of spaces, and finally joins
    all remaining lines with single spaces.

    :return: The cleaned corpus as a single space-separated string
    :rtype: str
    """
    with open(Config.Train.training_file, "r") as f:
        input_text = f.read()
    input_text = re.sub(r"SPEECH\s+\d+", "", input_text)  # Remove the speech headers
    input_text = re.sub(r"\n+", "\n", input_text)
    # Remove the dates.
    # BUG FIX: the original non-raw pattern "\d+\\\d+\\\d+" evaluated to the
    # regex \d+\\d+\\d+ — digits, a literal backslash, then the LETTER "d" —
    # so dates were never removed.  This raw string matches
    # digits\digits\digits as apparently intended.
    # NOTE(review): confirm the corpus separator — it may actually be "/".
    input_text = re.sub(r"\d+\\\d+\\\d+", "", input_text)
    input_text = re.sub(r" +", " ", input_text)  # Remove double spaces.
    return ' '.join(input_text.splitlines())
def create_examples(input_string):
    """
    Build training/validation examples from the corpus and store on Config.

    Each example's input is a list of integers encoding a window of
    ``Config.sequence_length`` characters; its target is the integer code of
    the character immediately following the window.  Examples are shuffled
    and, if validation is enabled, split into training and validation sets.

    :param input_string: Full corpus text to slice into examples
    :type input_string: str
    """
    sequences = []
    targets = []
    depths = []
    # Map each distinct character to a stable integer id (sorted for determinism).
    Config.char2int = {c: i for i, c in enumerate(sorted(set(input_string)))}
    # ToDo Discuss with Ben how we want to train on text shorter than the window size?
    if Config.dataset_size == -1:
        # Use every window: slide one character at a time across the text.
        i = 0
        while i + Config.sequence_length + 1 < len(input_string):
            sequences.append([Config.char2int[c] for c in input_string[i: i + Config.sequence_length]])
            depths.append(Config.sequence_length)
            targets.append(Config.char2int[input_string[i + Config.sequence_length]])
            i += 1
    else:
        # Sample dataset_size windows at uniformly random starting points.
        for _ in range(Config.dataset_size):
            # randrange avoids materializing a full range just to pick one value.
            r = random.randrange(len(input_string) - Config.sequence_length - 1)
            sequences.append([Config.char2int[c] for c in input_string[r: r + Config.sequence_length]])
            depths.append(Config.sequence_length)
            targets.append(Config.char2int[input_string[r + Config.sequence_length]])
    assert len(sequences) == len(targets)
    # Shuffle indices so the train/validation split is random.
    shuffled_list = list(range(len(sequences)))
    random.shuffle(shuffled_list)
    # Everything before split_point is training data; the remainder is validation.
    if Config.perform_validation():
        split_point = int(Config.training_split_ratio * len(sequences))
    else:
        split_point = len(sequences)
    Config.Train.x = [sequences[idx] for idx in shuffled_list[:split_point]]
    Config.Train.depth = [depths[idx] for idx in shuffled_list[:split_point]]
    Config.Train.t = [_build_target_vector(targets[idx]) for idx in shuffled_list[:split_point]]
    if Config.perform_validation():
        Config.Validation.x = [sequences[idx] for idx in shuffled_list[split_point:]]
        Config.Validation.depth = [depths[idx] for idx in shuffled_list[split_point:]]
        Config.Validation.t = [_build_target_vector(targets[idx]) for idx in shuffled_list[split_point:]]
def _build_input_sequence(int_sequence):
    """
    One-Hot Sequence Builder

    Converts a list of character indices into a matrix whose rows are one-hot
    vectors, padded out to ``Config.sequence_length`` rows.  (The original
    docstring incorrectly said it produced "a sequence of integers".)

    :param int_sequence: List of the character indices
    :type int_sequence: List[int]
    :return: Input sequence converted into a matrix of one hot rows
    :rtype: np.ndarray
    """
    assert 0 < len(int_sequence) <= Config.sequence_length
    one_hots = []
    for idx in range(Config.sequence_length):
        # Positions past the end of int_sequence are padded with character
        # id 0.  NOTE(review): that is indistinguishable from a real
        # character 0 — confirm downstream code masks padding via the
        # depth field.
        char_id = int_sequence[idx] if idx < len(int_sequence) else 0
        vec = np.zeros([Config.vocab_size()])
        vec[char_id] = 1
        one_hots.append(vec)
    return np.vstack(one_hots)
def _build_target_vector(idx):
    """
    Target One-Hot Builder

    Produces a one-hot vector with a single "1" at the position of the
    expected character and zeros everywhere else.

    :param idx: Integer corresponding to the expected character
    :type idx: int
    :return: One hot vector for the target character
    :rtype: np.array
    """
    vocab_size = Config.vocab_size()
    assert (0 <= idx < vocab_size)
    target = np.zeros([vocab_size])
    target[idx] = 1
    return target
def build_training_and_verification_sets():
    """
    Training and Verification Set Builder

    Builds the training and verification datasets. Depending on the
    configuration, this may be from the source files or from pickled
    files.
    """
    if not Config.Train.restore:
        # Fresh run: parse the corpus and build the examples from scratch.
        input_str = read_input()
        create_examples(input_str)
        # Character to integer map required during text generation
        Config.export_character_to_integer_map()
        # Export the training and verification data in case
        # the previous setup will be trained on again
        Config.export_train_and_verification_data()
        # Word count used only for the statistics report below.
        Config.word_count = len(input_str.split(" "))
    else:
        # Restore a previous run's data from the exported files.
        Config.import_character_to_integer_map()
        Config.import_train_and_verification_data()
    # Recompute the total example count from the final train/validation split.
    Config.dataset_size = Config.Train.size() + Config.Validation.size()
    _print_basic_text_statistics()
def _print_basic_text_statistics():
    """Log basic statistics (size, vocabulary, splits) for the loaded dataset."""
    # Pass values as lazy %-args instead of eager "%"-formatting so the
    # string is only built when the INFO level is actually enabled.
    logging.info("Total Number of Characters: %d", Config.dataset_size)
    if Config.word_count > 0:
        # Word count is only populated when the corpus was parsed this run.
        logging.info("Total Word Count: \t%d", Config.word_count)
    logging.info("Vocabulary Size: \t%d", Config.vocab_size())
    logging.info("Training Set Size: \t%d", Config.Train.size())
    logging.info("Validation Set Size: \t%d", Config.Validation.size())
# Manual test entry point: build the datasets when run as a script.
if __name__ == '__main__':
    build_training_and_verification_sets()