-
Notifications
You must be signed in to change notification settings - Fork 2k
/
project_tests.py
103 lines (77 loc) · 3.94 KB
/
project_tests.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import numpy as np
from keras.losses import sparse_categorical_crossentropy
from keras.models import Sequential
from keras.preprocessing.text import Tokenizer
from keras.utils import to_categorical
def _test_model(model, input_shape, output_sequence_length, french_vocab_size):
    """Assert that a compiled model has the expected shapes and loss.

    Args:
        model: Model to inspect. If it is a ``Sequential`` instance, its
            ``model`` attribute is inspected instead.
        input_shape: Shape the model was built for; every dimension after
            the first must match ``model.input_shape`` (whose first entry
            is expected to be ``None``).
        output_sequence_length: Expected second dimension of the output.
        french_vocab_size: Expected last dimension of the output.

    Raises:
        AssertionError: If the input shape, output shape, or configured
            loss functions do not match the expectations above.
    """
    # A Sequential is swapped for its ``model`` attribute; anything else is
    # inspected directly.
    inspected = model.model if isinstance(model, Sequential) else model

    found_input = inspected.input_shape
    expected_input = (None,) + tuple(input_shape[1:])
    assert found_input == expected_input,\
        'Wrong input shape. Found input shape {} using parameter input_shape={}'.format(found_input, input_shape)

    found_output = inspected.output_shape
    expected_output = (None, output_sequence_length, french_vocab_size)
    assert found_output == expected_output,\
        'Wrong output shape. Found output shape {} using parameters output_sequence_length={} and french_vocab_size={}'\
        .format(found_output, output_sequence_length, french_vocab_size)

    losses = inspected.loss_functions
    assert len(losses) > 0,\
        'No loss function set. Apply the `compile` function to the model.'
    assert sparse_categorical_crossentropy in losses,\
        'Not using `sparse_categorical_crossentropy` function for loss.'
def test_tokenize(tokenize):
sentences = [
'The quick brown fox jumps over the lazy dog .',
'By Jove , my quick study of lexicography won a prize .',
'This is a short sentence .']
tokenized_sentences, tokenizer = tokenize(sentences)
assert tokenized_sentences == tokenizer.texts_to_sequences(sentences),\
'Tokenizer returned and doesn\'t generate the same sentences as the tokenized sentences returned. '
def test_pad(pad):
tokens = [
[i for i in range(4)],
[i for i in range(6)],
[i for i in range(3)]]
padded_tokens = pad(tokens)
padding_id = padded_tokens[0][-1]
true_padded_tokens = np.array([
[i for i in range(4)] + [padding_id]*2,
[i for i in range(6)],
[i for i in range(3)] + [padding_id]*3])
assert isinstance(padded_tokens, np.ndarray),\
'Pad returned the wrong type. Found {} type, expected numpy array type.'
assert np.all(padded_tokens == true_padded_tokens), 'Pad returned the wrong results.'
padded_tokens_using_length = pad(tokens, 9)
assert np.all(padded_tokens_using_length == np.concatenate((true_padded_tokens, np.full((3, 3), padding_id)), axis=1)),\
'Using length argument return incorrect results'
def test_simple_model(simple_model):
    """Build a model with ``simple_model`` and validate it via ``_test_model``.

    Args:
        simple_model: Model factory under test, called positionally with
            the input shape, output sequence length, English vocabulary
            size and French vocabulary size.
    """
    sequence_length = 21
    source_vocab, target_vocab = 199, 344
    shape = (137861, sequence_length, 1)

    built = simple_model(shape, sequence_length, source_vocab, target_vocab)
    _test_model(built, shape, sequence_length, target_vocab)
def test_embed_model(embed_model):
    """Build a model with ``embed_model`` and validate it via ``_test_model``.

    Args:
        embed_model: Model factory under test, called positionally with
            the input shape, output sequence length, English vocabulary
            size and French vocabulary size.
    """
    sequence_length = 21
    source_vocab, target_vocab = 199, 344
    # Two-dimensional input shape here: no trailing feature axis.
    shape = (137861, sequence_length)

    built = embed_model(shape, sequence_length, source_vocab, target_vocab)
    _test_model(built, shape, sequence_length, target_vocab)
def test_encdec_model(encdec_model):
    """Build a model with ``encdec_model`` and validate it via ``_test_model``.

    Args:
        encdec_model: Model factory under test, called positionally with
            the input shape, output sequence length, English vocabulary
            size and French vocabulary size.
    """
    # Input sequences (15 steps) are shorter than the output sequence (21).
    shape = (137861, 15, 1)
    target_length = 21
    source_vocab, target_vocab = 199, 344

    built = encdec_model(shape, target_length, source_vocab, target_vocab)
    _test_model(built, shape, target_length, target_vocab)
def test_bd_model(bd_model):
    """Build a model with ``bd_model`` and validate it via ``_test_model``.

    Args:
        bd_model: Model factory under test, called positionally with the
            input shape, output sequence length, English vocabulary size
            and French vocabulary size.
    """
    sequence_length = 21
    source_vocab, target_vocab = 199, 344
    shape = (137861, sequence_length, 1)

    built = bd_model(shape, sequence_length, source_vocab, target_vocab)
    _test_model(built, shape, sequence_length, target_vocab)
def test_model_final(model_final):
    """Build a model with ``model_final`` and validate it via ``_test_model``.

    Args:
        model_final: Model factory under test, called positionally with
            the input shape, output sequence length, English vocabulary
            size and French vocabulary size.
    """
    # Two-dimensional input of 15 steps; the output sequence has 21 steps.
    shape = (137861, 15)
    target_length = 21
    source_vocab, target_vocab = 199, 344

    built = model_final(shape, target_length, source_vocab, target_vocab)
    _test_model(built, shape, target_length, target_vocab)