CNN_Main.m
%A Strategic Weight Refinement Maneuver for Convolutional Neural Networks - GSGD for Deep Learning Networks
clear all;
clc;
close all;
test_accuracy = [];
validation_accuracy = [];
training_accuracy = [];
for k = 1:1 % number of independent training runs; increase the upper bound to repeat training
close all;
% Relative Path of Data Folder
filepath = 'MNIST_DATASET/MNIST_DATA';
imdsTrain = imageDatastore(filepath, ...
'IncludeSubfolders',true, ...
'LabelSource','foldernames');
test_filepath = 'MNIST_DATASET/MNIST_DATA_TEST';
imdsTest = imageDatastore(test_filepath, ...
'IncludeSubfolders',true, ...
'LabelSource','foldernames');
% Split the MNIST_DATA images into training and test sets (this replaces the
% test datastore loaded from MNIST_DATA_TEST above), then split the training
% portion again to carve out a validation set
[imdsTrain,imdsTest] = splitEachLabel(imdsTrain,0.8,'randomize');
[imdsTrain,ValidationSet] = splitEachLabel(imdsTrain,0.8,'randomize');
%Get X and Y Validation Data
XValidation = ValidationSet;
YValidation = ValidationSet.Labels;
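% Note: after the two 80/20 splits above, roughly 64% of the MNIST_DATA images
% are used for training, 16% for validation and 20% for testing (0.8*0.8 = 0.64).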
% Get the image size from the first training image
img = readimage(imdsTrain,1); % 'img' avoids shadowing the built-in image function
imagesize = size(img);
% Hyperparameter values received from Bayesian optimization
bestVars.InitialLearnRate = 0.00084469; % 1e-4;
bestVars.Momentum = 0.92358 ; %0.97544;
bestVars.L2Regularization = 3.1505e-07 ; %4.7344e-05
bestVars.SectionDepth = 1;%2
bestVars.filterSz = [3 3]; %[3 3];
bestVars.Rho = 7; %4
% bestVars.SquaredGradientDecayFactor = 0.98976;
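% The bestVars values above come from a Bayesian optimization run. As a
% hypothetical sketch of how such values can be searched with bayesopt (the
% ranges and the trainAndGetError helper below are illustrative assumptions,
% not the actual setup used for this repository):
%
%   optimVars = [
%       optimizableVariable('InitialLearnRate',[1e-5 1e-2],'Transform','log')
%       optimizableVariable('Momentum',[0.8 0.99])
%       optimizableVariable('L2Regularization',[1e-8 1e-3],'Transform','log')
%       optimizableVariable('SectionDepth',[1 3],'Type','integer')
%       optimizableVariable('Rho',[2 10],'Type','integer')];
%   ObjFcn = @(vars) trainAndGetError(vars); % hypothetical helper: trains with 'vars', returns validation error
%   results = bayesopt(ObjFcn,optimVars,'MaxObjectiveEvaluations',30);
%   bestVars = bestPoint(results);           % table containing the best point found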
%Get number of classes
numClasses = size(imdsTrain.countEachLabel.Count,1);
imgType = 1; % number of channels: 1 => grayscale
if size(imagesize,2) == 3
    imgType = 3; % 3 => RGB
end
% Number of filters in the first convolutional block, scaled by the section
% depth found by the Bayesian optimization
numF = round(0.5*imagesize(1)/sqrt(bestVars.SectionDepth));
imageSize = [imagesize(1:2) imgType];
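% Example: for 28x28 MNIST images with SectionDepth = 1, numF = round(0.5*28/sqrt(1)) = 14,
% so the three convolutional blocks below use 1*14, 2*14 and 4*14 filters.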
%CNN architecture
layers = [
    imageInputLayer(imageSize)          % 1 => grayscale images, 3 => RGB
    convBlock(bestVars.filterSz,1*numF) %x4
    convBlock(bestVars.filterSz,2*numF) %x4
    convBlock(bestVars.filterSz,4*numF) %x4
    fullyConnectedLayer(numClasses)     % total classes/labels
    softmaxLayer
    classificationLayer
    ];
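% Note: convBlock is a helper defined in a separate file in this repository.
% A rough, hypothetical sketch of it, assuming a standard convolution /
% batch-normalization / ReLU / max-pooling block (consistent with the "%x4"
% comments above); the actual helper may differ:
%
%   function block = convBlock(filterSize,numFilters)
%       block = [
%           convolution2dLayer(filterSize,numFilters,'Padding','same')
%           batchNormalizationLayer
%           reluLayer
%           maxPooling2dLayer(2,'Stride',2)
%           ];
%   end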
% Set Training Options
% Set the 'isGuided' parameter to true for GSGD and supply 'Rho',
% 'RevisitBatchNum' and 'VerificationSetNum' values. To run without GSGD,
% simply remove these parameters or set 'isGuided' to false (a commented-out
% plain-SGDM sketch is given after the options below).
% 'Rho'                - number of iterations over which consistent data are
%                        collected and checked before the guided approach is
%                        activated to update the weights with the consistent data
% 'RevisitBatchNum'    - number of previous batches to revisit and check how
%                        they perform on the present batch's weights
% 'VerificationSetNum' - number of batches set aside at the beginning of each
%                        epoch; a batch is picked at random from this set to
%                        obtain the true error of the weights updated by each
%                        batch during training
options = trainingOptions('sgdm', ...
'Momentum',bestVars.Momentum, ...
'MaxEpochs',45,...
'ExecutionEnvironment','gpu', ...
'MiniBatchSize',256, ...
'InitialLearnRate',bestVars.InitialLearnRate, ...
'Verbose',false, ...
'L2Regularization',bestVars.L2Regularization, ...
'ValidationData',{XValidation,YValidation}, ...
'ValidationPatience', inf, ...
'ValidationFrequency', 10, ...
'isGuided', true, ...
'Rho', bestVars.Rho, ...
'RevisitBatchNum', 2, ...
'VerificationSetNum', 4, ...
'Plots','training-progress');
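% For comparison, the equivalent configuration without the guided approach
% (plain SGDM using only the stock trainingOptions parameters) would look
% roughly like the sketch below (not used in this script):
%
%   plainOptions = trainingOptions('sgdm', ...
%       'Momentum',bestVars.Momentum, ...
%       'MaxEpochs',45, ...
%       'ExecutionEnvironment','gpu', ...
%       'MiniBatchSize',256, ...
%       'InitialLearnRate',bestVars.InitialLearnRate, ...
%       'L2Regularization',bestVars.L2Regularization, ...
%       'ValidationData',{XValidation,YValidation}, ...
%       'ValidationPatience',inf, ...
%       'ValidationFrequency',10, ...
%       'Verbose',false, ...
%       'Plots','training-progress');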
% Suppress the GPU device-library recompilation warning and warm up the GPU
% so the first training iteration is not slowed down; any errors raised by
% these warm-up calls are deliberately ignored
warning off parallel:gpu:device:DeviceLibsNeedsRecompiling
try
    gpuArray.eye(2)^2;
catch ME
end
try
    nnet.internal.cnngpu.reluForward(1);
catch ME
end
% Train the network
[net,info] = trainNetwork(imdsTrain,layers,options);
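% info contains per-iteration TrainingAccuracy/TrainingLoss and, at validation
% points, ValidationAccuracy/ValidationLoss; these are accumulated across runs below.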
% Collect accuracy results
% Test accuracy on the held-out test split
YPred = classify(net,imdsTest);
YTest = imdsTest.Labels;
acc = (sum(YPred == YTest)/numel(YTest)) * 100;
test_accuracy = [test_accuracy acc];
fprintf('Highest training SR: %.2f\n',max(info.TrainingAccuracy));
training_accuracy = [training_accuracy info.TrainingAccuracy];
% Track validation accuracy to monitor overfitting
validation_accuracy = [validation_accuracy info.ValidationAccuracy];
fprintf('Highest validation SR: %.2f\n',max(info.ValidationAccuracy));
fprintf('Run %i Test Accuracy is: %.2f\n',k,acc);
end
fprintf('Test accuracy at the end of the run loop: %.2f\n',test_accuracy);
% Save the model. If the model may be needed later, rename it; otherwise it
% will be overwritten after every training run
% save('Models/trainedmodel.mat','net');