Commit
Merge pull request #1 from dinhno12313/dinhno12313-patch-1
Update keras_spell.py
dinhno12313 authored Dec 7, 2023
2 parents 8b4a5b5 + ed17e7b commit 724e6b9
Showing 1 changed file with 85 additions and 47 deletions.
132 changes: 85 additions & 47 deletions keras_spell.py
@@ -1,6 +1,5 @@
# encoding: utf-8


from __future__ import print_function, division, unicode_literals

import os
@@ -17,7 +16,7 @@
from numpy import zeros as np_zeros # pylint:disable=no-name-in-module

from keras.models import Sequential, load_model
from keras.layers import Activation, TimeDistributed, Dense, RepeatVector, Dropout, recurrent
from keras.layers import Activation, TimeDistributed, Dense, RepeatVector, Dropout, LSTM
from keras.callbacks import Callback

# Set a logger for the module
@@ -50,7 +49,9 @@ class Configuration(object):
CONFIG.number_of_iterations = 10
#pylint:enable=attribute-defined-outside-init

DIGEST = sha256(json.dumps(CONFIG.__dict__, sort_keys=True)).hexdigest()
#DIGEST = sha256(json.dumps(CONFIG.__dict__, sort_keys=True)).hexdigest()
DIGEST = sha256(json.dumps(CONFIG.__dict__, sort_keys=True).encode('utf-8')).hexdigest()


# Parameters for the dataset
MIN_INPUT_LEN = 5
@@ -75,9 +76,9 @@ class Configuration(object):
# Some cleanup:
NORMALIZE_WHITESPACE_REGEX = re.compile(r'[^\S\n]+', re.UNICODE) # match all whitespace except newlines
RE_DASH_FILTER = re.compile(r'[\-\˗\֊\‐\‑\‒\–\—\⁻\₋\−\﹣\-]', re.UNICODE)
RE_APOSTROPHE_FILTER = re.compile(r'&#39;|[ʼ՚&#39;‘’‛❛❜ߴߵ`‵´ˊˋ{}{}{}{}{}{}{}{}{}]'.format(unichr(768), unichr(769), unichr(832),
unichr(833), unichr(2387), unichr(5151),
unichr(5152), unichr(65344), unichr(8242)),
RE_APOSTROPHE_FILTER = re.compile(r'&#39;|[ʼ՚&#39;‘’‛❛❜ߴߵ`‵´ˊˋ{}{}{}{}{}{}{}{}{}]'.format( chr(768), chr(769), chr(832),
chr(833), chr(2387), chr(5151),
chr(5152), chr(65344), chr(8242)),
re.UNICODE)
RE_LEFT_PARENTH_FILTER = re.compile(r'[\(\[\{\⁽\₍\❨\❪\﹙\(]', re.UNICODE)
RE_RIGHT_PARENTH_FILTER = re.compile(r'[\)\]\}\⁾\₎\❩\❫\﹚\)]', re.UNICODE)
@@ -142,15 +143,15 @@ def _vectorize(questions, answers, ctable):
"""Vectorize the data as numpy arrays"""
len_of_questions = len(questions)
X = np_zeros((len_of_questions, CONFIG.max_input_len, ctable.size), dtype=np.bool)
for i in xrange(len(questions)):
for i in range(len(questions)):
sentence = questions.pop()
for j, c in enumerate(sentence):
try:
X[i, j, ctable.char_indices[c]] = 1
except KeyError:
pass # Padding
y = np_zeros((len_of_questions, CONFIG.max_input_len, ctable.size), dtype=np.bool)
for i in xrange(len(answers)):
for i in range(len(answers)):
sentence = answers.pop()
for j, c in enumerate(sentence):
try:
@@ -213,14 +214,14 @@ def generate_model(output_len, chars=None):
# note: in a situation where your input sequences have a variable length,
# use input_shape=(None, nb_feature).
for layer_number in range(CONFIG.input_layers):
model.add(recurrent.LSTM(CONFIG.hidden_size, input_shape=(None, len(chars)), kernel_initializer=CONFIG.initialization,
model.add(LSTM(CONFIG.hidden_size, input_shape=(None, len(chars)), kernel_initializer=CONFIG.initialization,
return_sequences=layer_number + 1 < CONFIG.input_layers))
model.add(Dropout(CONFIG.amount_of_dropout))
# For the decoder's input, we repeat the encoded input for each time step
model.add(RepeatVector(output_len))
# The decoder RNN could be multiple layers stacked or a single layer
for _ in range(CONFIG.output_layers):
model.add(recurrent.LSTM(CONFIG.hidden_size, return_sequences=True, kernel_initializer=CONFIG.initialization))
model.add(LSTM(CONFIG.hidden_size, return_sequences=True, kernel_initializer=CONFIG.initialization))
model.add(Dropout(CONFIG.amount_of_dropout))

# For each of step of the output sequence, decide which character should be chosen
@@ -360,49 +361,85 @@ def clean_text(text):
result = RE_BASIC_CLEANER.sub('', result)
return result

# def preprocesses_data_clean():
# """Pre-process the data - step 1 - cleanup"""
# with open(NEWS_FILE_NAME_CLEAN, "wb") as clean_data:
# for line in open(NEWS_FILE_NAME):
# decoded_line = line.decode('utf-8')
# cleaned_line = clean_text(decoded_line)
# encoded_line = cleaned_line.encode("utf-8")
# clean_data.write(encoded_line + b"\n")
def preprocesses_data_clean():
"""Pre-process the data - step 1 - cleanup"""
with open(NEWS_FILE_NAME_CLEAN, "wb") as clean_data:
for line in open(NEWS_FILE_NAME):
decoded_line = line.decode('utf-8')
cleaned_line = clean_text(decoded_line)
for line in open(NEWS_FILE_NAME, encoding='utf-8'):
cleaned_line = clean_text(line)
encoded_line = cleaned_line.encode("utf-8")
clean_data.write(encoded_line + b"\n")


# def preprocesses_data_analyze_chars():
# """Pre-process the data - step 2 - analyze the characters"""
# counter = Counter()
# LOGGER.info("Reading data:")
# for line in open(NEWS_FILE_NAME_CLEAN):
# decoded_line = line.decode('utf-8')
# counter.update(decoded_line)
# # data = open(NEWS_FILE_NAME_CLEAN).read().decode('utf-8')
# # LOGGER.info("Read.\nCounting characters:")
# # counter = Counter(data.replace("\n", ""))
# LOGGER.info("Done.\nWriting to file:")
# with open(CHAR_FREQUENCY_FILE_NAME, 'wb') as output_file:
# output_file.write(json.dumps(counter))
# most_popular_chars = {key for key, _value in counter.most_common(CONFIG.number_of_chars)}
# LOGGER.info("The top %s chars are:", CONFIG.number_of_chars)
# LOGGER.info("".join(sorted(most_popular_chars)))
def preprocesses_data_analyze_chars():
"""Pre-process the data - step 2 - analyze the characters"""
counter = Counter()
LOGGER.info("Reading data:")
for line in open(NEWS_FILE_NAME_CLEAN):
decoded_line = line.decode('utf-8')
counter.update(decoded_line)
# data = open(NEWS_FILE_NAME_CLEAN).read().decode('utf-8')
# LOGGER.info("Read.\nCounting characters:")
# counter = Counter(data.replace("\n", ""))
for line in open(NEWS_FILE_NAME_CLEAN, "r", encoding="utf-8"):
counter.update(line)

LOGGER.info("Done.\nWriting to file:")
with open(CHAR_FREQUENCY_FILE_NAME, 'wb') as output_file:
with open(CHAR_FREQUENCY_FILE_NAME, 'w', encoding="utf-8") as output_file:
output_file.write(json.dumps(counter))

most_popular_chars = {key for key, _value in counter.most_common(CONFIG.number_of_chars)}
LOGGER.info("The top %s chars are:", CONFIG.number_of_chars)
LOGGER.info("".join(sorted(most_popular_chars)))


def read_top_chars():
"""Read the top chars we saved to file"""
chars = json.loads(open(CHAR_FREQUENCY_FILE_NAME).read())
counter = Counter(chars)
most_popular_chars = {key for key, _value in counter.most_common(CONFIG.number_of_chars)}
return most_popular_chars

# def preprocesses_data_filter():
# """Pre-process the data - step 3 - filter only sentences with the right chars"""
# most_popular_chars = read_top_chars()
# LOGGER.info("Reading and filtering data:")
# with open(NEWS_FILE_NAME_FILTERED, "wb") as output_file:
# for line in open(NEWS_FILE_NAME_CLEAN):
# decoded_line = line.decode('utf-8')
# if decoded_line and not bool(set(decoded_line) - most_popular_chars):
# output_file.write(line)
# LOGGER.info("Done.")
def preprocesses_data_filter():
"""Pre-process the data - step 3 - filter only sentences with the right chars"""
most_popular_chars = read_top_chars()
LOGGER.info("Reading and filtering data:")
with open(NEWS_FILE_NAME_FILTERED, "wb") as output_file:
for line in open(NEWS_FILE_NAME_CLEAN):
decoded_line = line.decode('utf-8')
if decoded_line and not bool(set(decoded_line) - most_popular_chars):
output_file.write(line)
LOGGER.info("Done.")
"""Tiền xử lý dữ liệu - Bước 3 - Lọc bỏ các ký tự không mong muốn"""
LOGGER.info("Đọc và lọc dữ liệu:")
with open(NEWS_FILE_NAME_CLEAN, "r", encoding="utf-8") as clean_data:
lines = clean_data.readlines()

filtered_lines = [clean_text(line) for line in lines]

with open(NEWS_FILE_NAME_FILTERED, 'w', encoding="utf-8") as output_file:
output_file.writelines(filtered_lines)

LOGGER.info("Hoàn thành.")


def read_filtered_data():
"""Read the filtered data corpus"""
@@ -487,17 +524,17 @@ def preprocesses_split_lines4():
FILTERED_W2V = "fw2v.bin"
model = Word2Vec.load_word2vec_format(FILTERED_W2V, binary=True) # C text format
print(len(model.wv.index2word))
# answers = set()
# for encoded_line in open(NEWS_FILE_NAME_FILTERED):
# line = encoded_line.decode('utf-8')
# if line.count(" ") < 5:
# answers.add(line)
# LOGGER.info("There are %s 'answers' (sub-sentences)", len(answers))
# LOGGER.info("Here are some examples:")
# for answer in itertools.islice(answers, 10):
# LOGGER.info(answer)
# with open(NEWS_FILE_NAME_SPLIT, "wb") as output_file:
# output_file.write("".join(answers).encode('utf-8'))
answers = set()
for encoded_line in open(NEWS_FILE_NAME_FILTERED):
line = encoded_line.decode('utf-8')
if line.count(" ") < 5:
answers.add(line)
LOGGER.info("There are %s 'answers' (sub-sentences)", len(answers))
LOGGER.info("Here are some examples:")
for answer in itertools.islice(answers, 10):
LOGGER.info(answer)
with open(NEWS_FILE_NAME_SPLIT, "wb") as output_file:
output_file.write("".join(answers).encode('utf-8'))

def preprocess_partition_data():
"""Set asside data for validation"""
@@ -565,12 +602,13 @@ def train_speller(from_file=None):
if __name__ == '__main__':
# download_the_news_data()
# uncompress_data()
# preprocesses_data_clean()
# preprocesses_data_analyze_chars()
# preprocesses_data_filter()
preprocesses_data_clean()
preprocesses_data_analyze_chars()
preprocesses_data_filter()
# preprocesses_split_lines() --- Choose this step or:
# preprocesses_split_lines2()
preprocesses_split_lines2()
# preprocesses_split_lines4()
# preprocess_partition_data()
# train_speller(os.path.join(DATA_FILES_FULL_PATH, "keras_spell_e15.h5"))
train_speller()
preprocess_partition_data()
train_speller(os.path.join(DATA_FILES_FULL_PATH, "keras_spell_e15.h5"))
train_speller()
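
Taken together, the changes above amount to a Python 2 to Python 3 port of keras_spell.py: xrange and unichr become range and chr, recurrent.LSTM becomes the LSTM layer imported directly from keras.layers, files are opened with an explicit encoding='utf-8' instead of decoding each line by hand, and the configuration digest encodes the JSON string before hashing. A minimal, self-contained sketch of the hashing and file-handling idioms follows; it is illustrative only, with hypothetical values and file names, and is not code from this repository.

# Illustrative Python 3 idioms mirroring this commit; values and file names are hypothetical.
import json
from hashlib import sha256

config = {"hidden_size": 700, "input_layers": 2}

# hashlib.sha256 expects bytes in Python 3, so the JSON string is encoded first.
digest = sha256(json.dumps(config, sort_keys=True).encode("utf-8")).hexdigest()
print(digest)

# Opening text files with an explicit encoding replaces the old line.decode('utf-8') pattern.
with open("sample.txt", "w", encoding="utf-8") as handle:
    handle.write("first line\nsecond line\n")
with open("sample.txt", encoding="utf-8") as handle:
    for line in handle:
        print(line.strip())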
