NN_001.js
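// NN_001.js - a minimal fully connected neural network (784-25-25-25-10) for the MNIST
// training set, written in plain Node.js with no dependencies: ReLU hidden layers,
// a softmax / cross-entropy output, and batched weight updates with a simple
// "bounce restriction" on oversized deltas.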
function main()
{
console.log("\nrunning\n");
let u = [784, 25, 25, 25, 10]; // layer sizes: 784 input pixels, three hidden layers of 25, 10 output classes
let learningRate = 0.0067;
let bounceResRate = 50.0; // cap on the squared accumulated delta before an update is skipped (step 6)
let weightInitRange = 0.35; // controls the initial weight magnitude (it divides, so larger means smaller weights)
let runs = 1000; // number of training images processed
let miniBatch = 8; // weights are updated every miniBatch images
let networkInfoCheck = 10; // number of progress reports printed during training
let dnn = u.length - 1, nns = 0, wnn = 0, inputs = u[0], output = u[dnn], correct = 0;
let ce = 0, ce2 = 0;
for (let n = 0; n < dnn + 1; n++) nns += u[n]; // num of neurons
for (let n = 1; n < dnn + 1; n++) wnn += u[n - 1] * u[n]; // num of weights
let neuron = new Array(nns).fill(0); // activations of every neuron, inputs included
let gradient = new Array(nns - inputs).fill(0); // per-neuron error terms (non-input layers only)
let weight = new Array(wnn).fill(0);
let delta = new Array(wnn).fill(0); // accumulated weight updates for the current batch
let target = new Array(output).fill(0); // one-hot target vector
let FS = require('fs');
let MNISTimage = FS.readFileSync("train-images.idx3-ubyte"); // read as Buffers so indexing yields raw bytes
let MNISTlabel = FS.readFileSync("train-labels.idx1-ubyte");
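// IDX file layout: the image file carries a 16-byte header (magic, count, rows, cols)
// and the label file an 8-byte header (magic, count), hence the +16 and +8 offsets
// used when indexing into the Buffers below.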
//--- pseudo-random weight initialisation (simple multiplicative congruential sequence)
for (let n = 0, p = 314; n < wnn; n++)
weight[n] = ((p = p * 2718 % 2718281) / (2718281.0 * Math.E * Math.PI * weightInitRange));
//--- start training
for (let x = 1; x < runs + 1; x++)
{
//+----------- 1. MNIST as Inputs --------------------------------------+
// feed input data...
for (let n = 0; n < inputs; n++)
{
neuron[n] = MNISTimage[((x - 1) * 784) + n + 16] / 255; // pixel value scaled to 0..1
}
let targetNum = MNISTlabel[(x - 1) + 8]; // label (0..9) of the current image
//+----------- 2. Feed Forward -----------------------------------------+
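// Layout note, read off the index arithmetic below: the weights of the transition
// layer i -> layer i+1 form one contiguous block, with
// weight index = blockStart + src * u[i + 1] + dst,
// which is why m starts at w + k and steps by u[i + 1]. Hidden layers use ReLU;
// the last layer stays linear because softmax is applied in step 4.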
for (let i = 0, j = inputs, t = 0, w = 0; i < dnn; i++, t += u[i - 1], w += u[i] * u[i - 1])
for (let k = 0; k < u[i + 1]; k++, j++){
let net = gradient[j - inputs] = 0;
for (let n = t, m = w + k; n < t + u[i]; n++, m += u[i + 1])
net += neuron[n] * weight[m];
neuron[j] = i == dnn - 1 ? net : net > 0 ? net : 0;
}//--- k ends
//+------------ 3. NN prediction ---------------------------------------+
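//--- argmax over the raw output activations; outMaxVal is reused in step 4 to stabilise the softmax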
let outMaxPos = nns - output;
let outMaxVal = neuron[nns - output], scale = 0;
for (let i = nns - output + 1; i < nns; i++)
if (neuron[i] > outMaxVal){
outMaxPos = i; outMaxVal = neuron[i];
}
if (targetNum + nns - output == outMaxPos) correct++;
//+----------- 4. Loss / Error with Softmax and Cross Entropy ----------+
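// Numerically stable softmax: subtracting the largest activation (outMaxVal) before
// exponentiating avoids overflow without changing the result. ce accumulates
// -log p(target) and ce2 keeps its running mean over the x samples seen so far.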
for (let n = nns - output; n != nns; n++)
scale += Math.exp(neuron[n] - outMaxVal);
for (let n = nns - output, m = 0; n != nns; m++, n++)
neuron[n] = Math.exp(neuron[n] - outMaxVal) / scale;
ce2 = (ce -= Math.log(neuron[targetNum + nns - output])) / x; // cross-entropy of the target class, not of the predicted max
//+----------- 5. Backpropagation --------------------------------------+
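// For softmax + cross-entropy the gradient of the loss w.r.t. an output logit is
// p_k - t_k, so gra = target - p below is the negative gradient, and
// weight += learningRate * delta in step 6 moves the loss downhill. Hidden layers
// use ReLU, so their gradient only flows where the activation is > 0; the final
// else branch just advances the weight pointer ws past the skipped weights.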
target[targetNum] = 1.0;
for (let i = dnn, j = nns - 1, ls = output, wd = wnn - 1, ws = wd, us = nns - output - 1, gs = nns - inputs - 1;
i != 0; i--, wd -= u[i + 1] * u[i + 0], us -= u[i], gs -= u[i + 1])
for (let k = 0; k != u[i]; k++, j--){
let gra = 0;
//--- first check if output or hidden, calc delta for both
if (i == dnn)
gra = target[--ls] - neuron[j];
else if(neuron[j] > 0)
for (let n = gs + u[i + 1]; n > gs; n--, ws--)
gra += weight[ws] * gradient[n];
else ws -= u[i + 1];
for (let n = us, w = wd - k; n > us - u[i - 1]; w -= u[i], n--)
delta[w] += gra * neuron[n];
gradient[j - inputs] = gra;
}
target[targetNum] = 0;
//+----------- 6. update Weights ---------------------------------------+
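// Interpretation (not spelled out in the source): the "bounce restriction" skips any
// weight whose accumulated delta has grown too large (delta^2 > bounceResRate), and
// the 0.67 factor decays the accumulator instead of clearing it, which acts like a
// crude momentum term across batches.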
if ((x % miniBatch == 0) || (x == runs)){ // also flush the final, possibly partial, batch
for (let m = 0; m < wnn; m++){
//--- bounce restriction
if (delta[m] * delta[m] > bounceResRate) continue;
//--- update weights
weight[m] += learningRate * delta[m];
delta[m] *= 0.67;
}
} //--- batch end
if (x % (runs / networkInfoCheck) == 0)
console.log("runs: " + x + " accuracy: " + (correct * 100.0 / x));
} // x end
console.log("\nneurons: " + nns + " weights: " + wnn + " batch: " + miniBatch);
console.log("accuracy: " + (correct * 100.0 / (runs * 1.0)) + " cross entropy: " + ce2);
console.log("correct: "+(correct) + " incorrect: " + (runs - correct));
}
main();
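// To run (assuming Node.js and the two MNIST training files in the working directory):
//   node NN_001.js
// The files are the standard MNIST training set, train-images.idx3-ubyte and
// train-labels.idx1-ubyte, exactly as named in the readFileSync calls above.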