-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
79 lines (62 loc) · 1.96 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import numpy as np
import random
from rnn import RNN
from data import train_data, test_data
# Build the vocabulary: every distinct space-separated word in the training
# texts. The words are sorted so the word -> index mapping is identical from
# run to run; iterating a raw ``set`` of strings yields an order that depends
# on Python's per-process string hash randomization (PYTHONHASHSEED).
vocab = sorted({w for text in train_data.keys() for w in text.split(' ')})
vocab_size = len(vocab)
print('%d unique words found' % vocab_size)

# Map each word to its row index in the one-hot input vectors, and back.
word_to_idx = {w: i for i, w in enumerate(vocab)}
idx_to_word = {i: w for i, w in enumerate(vocab)}
def createInputs(text):
    '''
    Turn a text string into a list of one-hot column vectors, one per word.

    - text is split on single spaces; each piece must be a key of word_to_idx
      (a KeyError is raised otherwise).
    - Every vector is a float array of shape (vocab_size, 1) holding a single
      1 at row word_to_idx[word] and 0 everywhere else.
    '''
    def one_hot(word):
        vec = np.zeros((vocab_size, 1))
        vec[word_to_idx[word]] = 1
        return vec

    return [one_hot(word) for word in text.split(' ')]
def softmax(xs):
    '''
    Return the softmax of xs taken along axis 0.

    For a column vector of logits (shape (n, 1)) or a 1-D array this is the
    ordinary softmax. For a 2-D array each column is normalized independently,
    which is the same axis-0 reduction Python's builtin sum() performs on an
    ndarray.

    The per-column maximum is subtracted before exponentiating. The shift
    cancels in the ratio, so the result is mathematically unchanged, but it
    stops np.exp from overflowing to inf (and the ratio becoming nan) when
    the logits are large.
    '''
    xs = np.asarray(xs)
    exps = np.exp(xs - np.max(xs, axis=0))
    return exps / np.sum(exps, axis=0)
# Initialize the RNN. The first argument is the vocabulary size, matching the
# (vocab_size, 1) one-hot vectors produced by createInputs. The second is 2,
# one output per label, since processData uses targets int(y) for a
# True/False y.
# NOTE(review): assumes the signature is RNN(input_size, output_size) —
# confirm against rnn.py.
rnn = RNN(vocab_size, 2)
def processData(data, backprop=True):
    '''
    Run the RNN over every example in data and return (mean loss, accuracy).

    - data is a dictionary mapping text to True or False.
    - backprop determines if the backward phase (rnn.backprop) should be run
      after each example's forward pass; pass False for evaluation only.

    Examples are visited in a freshly shuffled order on every call. The
    returned loss is the mean cross-entropy of the true class, and the
    accuracy is the fraction of examples whose arg-max prediction equals the
    label. Both are Python floats.

    Raises ValueError if data is empty.
    '''
    items = list(data.items())
    if not items:
        raise ValueError('processData() needs at least one example')
    random.shuffle(items)

    loss = 0.0
    num_correct = 0

    for x, y in items:
        inputs = createInputs(x)
        target = int(y)

        # Forward pass; the second value returned by rnn.forward is unused here.
        out, _ = rnn.forward(inputs)
        probs = softmax(out)

        # Cross-entropy for the true class. .item() collapses the size-1
        # selection to a Python float so `loss` stays a scalar rather than
        # becoming an ndarray (whose %-formatting relies on a deprecated
        # array-to-float conversion).
        loss -= np.log(probs[target]).item()
        num_correct += int(np.argmax(probs) == target)

        if backprop:
            # dL/dy for softmax + cross-entropy: probs minus one-hot(target).
            # Copy so `probs` itself is not mutated.
            d_L_d_y = probs.copy()
            d_L_d_y[target] -= 1
            rnn.backprop(d_L_d_y)

    return loss / len(items), num_correct / len(items)
# Training loop: 1000 passes over train_data with backprop. After every 100th
# pass, report the training metrics and evaluate on test_data without updating
# the network.
for epoch in range(1000):
    train_loss, train_acc = processData(train_data)
    if (epoch + 1) % 100:
        continue  # not a reporting epoch
    print('--- Epoch %d' % (epoch + 1))
    print('Train:\tLoss %.3f | Accuracy: %.3f' % (train_loss, train_acc))
    test_loss, test_acc = processData(test_data, backprop=False)
    print('Test:\tLoss %.3f | Accuracy: %.3f' % (test_loss, test_acc))