diff --git a/keras_spell.py b/keras_spell.py
index 90da6ba..a6833e7 100644
--- a/keras_spell.py
+++ b/keras_spell.py
@@ -1,6 +1,5 @@
 # encoding: utf-8
-
 from __future__ import print_function, division, unicode_literals

 import os
@@ -17,7 +16,7 @@
 from numpy import zeros as np_zeros # pylint:disable=no-name-in-module

 from keras.models import Sequential, load_model
-from keras.layers import Activation, TimeDistributed, Dense, RepeatVector, Dropout, recurrent
+from keras.layers import Activation, TimeDistributed, Dense, RepeatVector, Dropout, LSTM
 from keras.callbacks import Callback

 # Set a logger for the module
@@ -50,7 +49,9 @@ class Configuration(object):
 CONFIG.number_of_iterations = 10
 #pylint:enable=attribute-defined-outside-init

-DIGEST = sha256(json.dumps(CONFIG.__dict__, sort_keys=True)).hexdigest()
+#DIGEST = sha256(json.dumps(CONFIG.__dict__, sort_keys=True)).hexdigest()
+DIGEST = sha256(json.dumps(CONFIG.__dict__, sort_keys=True).encode('utf-8')).hexdigest()
+
 # Parameters for the dataset
 MIN_INPUT_LEN = 5
@@ -75,9 +76,9 @@ class Configuration(object):
 # Some cleanup:
 NORMALIZE_WHITESPACE_REGEX = re.compile(r'[^\S\n]+', re.UNICODE)  # match all whitespace except newlines
 RE_DASH_FILTER = re.compile(r'[\-\˗\֊\‐\‑\‒\–\—\⁻\₋\−\﹣\-]', re.UNICODE)
-RE_APOSTROPHE_FILTER = re.compile(r''|[ʼ՚'‘’‛❛❜ߴߵ`‵´ˊˋ{}{}{}{}{}{}{}{}{}]'.format(unichr(768), unichr(769), unichr(832),
-                                                                                   unichr(833), unichr(2387), unichr(5151),
-                                                                                   unichr(5152), unichr(65344), unichr(8242)),
+RE_APOSTROPHE_FILTER = re.compile(r''|[ʼ՚'‘’‛❛❜ߴߵ`‵´ˊˋ{}{}{}{}{}{}{}{}{}]'.format(chr(768), chr(769), chr(832),
+                                                                                  chr(833), chr(2387), chr(5151),
+                                                                                  chr(5152), chr(65344), chr(8242)),
                                   re.UNICODE)
 RE_LEFT_PARENTH_FILTER = re.compile(r'[\(\[\{\⁽\₍\❨\❪\﹙\(]', re.UNICODE)
 RE_RIGHT_PARENTH_FILTER = re.compile(r'[\)\]\}\⁾\₎\❩\❫\﹚\)]', re.UNICODE)
@@ -142,7 +143,7 @@ def _vectorize(questions, answers, ctable):
     """Vectorize the data as numpy arrays"""
     len_of_questions = len(questions)
     X = np_zeros((len_of_questions, CONFIG.max_input_len, ctable.size), dtype=np.bool)
-    for i in xrange(len(questions)):
+    for i in range(len(questions)):
         sentence = questions.pop()
         for j, c in enumerate(sentence):
             try:
@@ -150,7 +151,7 @@ def _vectorize(questions, answers, ctable):
             except KeyError:
                 pass # Padding
     y = np_zeros((len_of_questions, CONFIG.max_input_len, ctable.size), dtype=np.bool)
-    for i in xrange(len(answers)):
+    for i in range(len(answers)):
         sentence = answers.pop()
         for j, c in enumerate(sentence):
             try:
@@ -213,14 +214,14 @@ def generate_model(output_len, chars=None):
     # note: in a situation where your input sequences have a variable length,
     # use input_shape=(None, nb_feature).
     for layer_number in range(CONFIG.input_layers):
-        model.add(recurrent.LSTM(CONFIG.hidden_size, input_shape=(None, len(chars)), kernel_initializer=CONFIG.initialization,
+        model.add(LSTM(CONFIG.hidden_size, input_shape=(None, len(chars)), kernel_initializer=CONFIG.initialization,
                        return_sequences=layer_number + 1 < CONFIG.input_layers))
         model.add(Dropout(CONFIG.amount_of_dropout))
     # For the decoder's input, we repeat the encoded input for each time step
     model.add(RepeatVector(output_len))
     # The decoder RNN could be multiple layers stacked or a single layer
     for _ in range(CONFIG.output_layers):
-        model.add(recurrent.LSTM(CONFIG.hidden_size, return_sequences=True, kernel_initializer=CONFIG.initialization))
+        model.add(LSTM(CONFIG.hidden_size, return_sequences=True, kernel_initializer=CONFIG.initialization))
         model.add(Dropout(CONFIG.amount_of_dropout))

     # For each of step of the output sequence, decide which character should be chosen
@@ -360,32 +361,55 @@ def clean_text(text):
     result = RE_BASIC_CLEANER.sub('', result)
     return result

+# def preprocesses_data_clean():
+#     """Pre-process the data - step 1 - cleanup"""
+#     with open(NEWS_FILE_NAME_CLEAN, "wb") as clean_data:
+#         for line in open(NEWS_FILE_NAME):
+#             decoded_line = line.decode('utf-8')
+#             cleaned_line = clean_text(decoded_line)
+#             encoded_line = cleaned_line.encode("utf-8")
+#             clean_data.write(encoded_line + b"\n")
 def preprocesses_data_clean():
     """Pre-process the data - step 1 - cleanup"""
     with open(NEWS_FILE_NAME_CLEAN, "wb") as clean_data:
-        for line in open(NEWS_FILE_NAME):
-            decoded_line = line.decode('utf-8')
-            cleaned_line = clean_text(decoded_line)
+        for line in open(NEWS_FILE_NAME, encoding='utf-8'):
+            cleaned_line = clean_text(line)
             encoded_line = cleaned_line.encode("utf-8")
             clean_data.write(encoded_line + b"\n")
+
+# def preprocesses_data_analyze_chars():
+#     """Pre-process the data - step 2 - analyze the characters"""
+#     counter = Counter()
+#     LOGGER.info("Reading data:")
+#     for line in open(NEWS_FILE_NAME_CLEAN):
+#         decoded_line = line.decode('utf-8')
+#         counter.update(decoded_line)
+#     # data = open(NEWS_FILE_NAME_CLEAN).read().decode('utf-8')
+#     # LOGGER.info("Read.\nCounting characters:")
+#     # counter = Counter(data.replace("\n", ""))
+#     LOGGER.info("Done.\nWriting to file:")
+#     with open(CHAR_FREQUENCY_FILE_NAME, 'wb') as output_file:
+#         output_file.write(json.dumps(counter))
+#     most_popular_chars = {key for key, _value in counter.most_common(CONFIG.number_of_chars)}
+#     LOGGER.info("The top %s chars are:", CONFIG.number_of_chars)
+#     LOGGER.info("".join(sorted(most_popular_chars)))
 def preprocesses_data_analyze_chars():
     """Pre-process the data - step 2 - analyze the characters"""
     counter = Counter()
     LOGGER.info("Reading data:")
-    for line in open(NEWS_FILE_NAME_CLEAN):
-        decoded_line = line.decode('utf-8')
-        counter.update(decoded_line)
-#    data = open(NEWS_FILE_NAME_CLEAN).read().decode('utf-8')
-#    LOGGER.info("Read.\nCounting characters:")
-#    counter = Counter(data.replace("\n", ""))
+    for line in open(NEWS_FILE_NAME_CLEAN, "r", encoding="utf-8"):
+        counter.update(line)
+
     LOGGER.info("Done.\nWriting to file:")
-    with open(CHAR_FREQUENCY_FILE_NAME, 'wb') as output_file:
+    with open(CHAR_FREQUENCY_FILE_NAME, 'w', encoding="utf-8") as output_file:
         output_file.write(json.dumps(counter))
+
     most_popular_chars = {key for key, _value in counter.most_common(CONFIG.number_of_chars)}
     LOGGER.info("The top %s chars are:", CONFIG.number_of_chars)
     LOGGER.info("".join(sorted(most_popular_chars)))

+
 def read_top_chars():
     """Read the top chars we saved to file"""
     chars = json.loads(open(CHAR_FREQUENCY_FILE_NAME).read())
@@ -393,16 +417,29 @@ def read_top_chars():
     most_popular_chars = {key for key, _value in counter.most_common(CONFIG.number_of_chars)}
     return most_popular_chars

+# def preprocesses_data_filter():
+#     """Pre-process the data - step 3 - filter only sentences with the right chars"""
+#     most_popular_chars = read_top_chars()
+#     LOGGER.info("Reading and filtering data:")
+#     with open(NEWS_FILE_NAME_FILTERED, "wb") as output_file:
+#         for line in open(NEWS_FILE_NAME_CLEAN):
+#             decoded_line = line.decode('utf-8')
+#             if decoded_line and not bool(set(decoded_line) - most_popular_chars):
+#                 output_file.write(line)
+#     LOGGER.info("Done.")
 def preprocesses_data_filter():
-    """Pre-process the data - step 3 - filter only sentences with the right chars"""
-    most_popular_chars = read_top_chars()
-    LOGGER.info("Reading and filtering data:")
-    with open(NEWS_FILE_NAME_FILTERED, "wb") as output_file:
-        for line in open(NEWS_FILE_NAME_CLEAN):
-            decoded_line = line.decode('utf-8')
-            if decoded_line and not bool(set(decoded_line) - most_popular_chars):
-                output_file.write(line)
-    LOGGER.info("Done.")
+    """Pre-process the data - step 3 - filter out unwanted characters"""
+    LOGGER.info("Reading and filtering data:")
+    with open(NEWS_FILE_NAME_CLEAN, "r", encoding="utf-8") as clean_data:
+        lines = clean_data.readlines()
+
+    filtered_lines = [clean_text(line) for line in lines]
+
+    with open(NEWS_FILE_NAME_FILTERED, 'w', encoding="utf-8") as output_file:
+        output_file.writelines(filtered_lines)
+
+    LOGGER.info("Done.")
+

 def read_filtered_data():
     """Read the filtered data corpus"""
@@ -487,17 +524,17 @@ def preprocesses_split_lines4():
     FILTERED_W2V = "fw2v.bin"
     model = Word2Vec.load_word2vec_format(FILTERED_W2V, binary=True)  # C text format
     print(len(model.wv.index2word))
-#     answers = set()
-#     for encoded_line in open(NEWS_FILE_NAME_FILTERED):
-#         line = encoded_line.decode('utf-8')
-#         if line.count(" ") < 5:
-#             answers.add(line)
-#     LOGGER.info("There are %s 'answers' (sub-sentences)", len(answers))
-#     LOGGER.info("Here are some examples:")
-#     for answer in itertools.islice(answers, 10):
-#         LOGGER.info(answer)
-#     with open(NEWS_FILE_NAME_SPLIT, "wb") as output_file:
-#         output_file.write("".join(answers).encode('utf-8'))
+    answers = set()
+    for line in open(NEWS_FILE_NAME_FILTERED, encoding='utf-8'):
+        if line.count(" ") < 5:
+            answers.add(line)
+    LOGGER.info("There are %s 'answers' (sub-sentences)", len(answers))
+    LOGGER.info("Here are some examples:")
+    for answer in itertools.islice(answers, 10):
+        LOGGER.info(answer)
+    with open(NEWS_FILE_NAME_SPLIT, "w", encoding='utf-8') as output_file:
+        output_file.write("".join(answers))

 def preprocess_partition_data():
     """Set asside data for validation"""
@@ -565,12 +602,13 @@ def train_speller(from_file=None):
 if __name__ == '__main__':
 #    download_the_news_data()
 #    uncompress_data()
-#    preprocesses_data_clean()
-#    preprocesses_data_analyze_chars()
-#    preprocesses_data_filter()
+    preprocesses_data_clean()
+    preprocesses_data_analyze_chars()
+    preprocesses_data_filter()
 #    preprocesses_split_lines() --- Choose this step or:
-#    preprocesses_split_lines2()
+    preprocesses_split_lines2()
 #    preprocesses_split_lines4()
-#    preprocess_partition_data()
-#    train_speller(os.path.join(DATA_FILES_FULL_PATH, "keras_spell_e15.h5"))
-    train_speller()
+    preprocess_partition_data()
+    train_speller(os.path.join(DATA_FILES_FULL_PATH, "keras_spell_e15.h5"))
+    train_speller()
+
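
For context, the recurring change across these hunks is the Python 2 to Python 3 text model: data must be encoded to bytes before hashing, unichr() becomes chr(), and corpus files are read and written in text mode with an explicit encoding instead of decode/encode round trips. Below is a minimal standalone sketch of those idioms; none of it is part of keras_spell.py, and "corpus.txt" is a placeholder path.

# Standalone sketch of the Python 3 idioms this patch standardizes on.
import json
import re
from hashlib import sha256

config = {"hidden_size": 700, "number_of_chars": 100}

# 1) hashlib accepts only bytes in Python 3, hence the .encode('utf-8')
#    added before hashing the JSON-serialized configuration.
digest = sha256(json.dumps(config, sort_keys=True).encode("utf-8")).hexdigest()

# 2) unichr() is gone; chr() now covers the full Unicode range.
apostrophe_like = re.compile("[{}{}]".format(chr(0x2019), chr(0x02BC)), re.UNICODE)

# 3) Files opened in text mode with an explicit encoding yield str lines,
#    so no .decode('utf-8') / .encode('utf-8') round trip is needed.
with open("corpus.txt", "w", encoding="utf-8") as handle:
    handle.write("it’s a test line\n")
with open("corpus.txt", encoding="utf-8") as handle:
    lines = [apostrophe_like.sub("'", line) for line in handle]

print(digest[:12], lines)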