-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathrbm.m
98 lines (76 loc) · 3.4 KB
/
rbm.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
% Code provided by Ruslan Salakhutdinov and Geoff Hinton
% Permission is granted for anyone to copy, use, modify, or distribute this
% program and accompanying programs and documents for any purpose, provided
% this copyright notice is retained and prominently displayed, along with
% a note saying that the original programs are available from our
% web page.
% The programs and documents are distributed without any warranty, express or
% implied. As the programs were written for research purposes only, they have
% not been tested to the degree that would be advisable in any important
% application. All use of these programs is entirely at the user's own risk.
% RBM training script: one reconstruction step per mini-batch (positive
% phase on the data, one sampled hidden state, negative phase on the
% reconstruction), followed by a momentum update of weights and biases.
%
% Reads from the caller's workspace:
%   batchdata - numcases x numdims x numbatches array of training data
%   numhid    - number of hidden units
%   maxepoch  - last epoch to run
%   restart   - 1 to (re)initialise all parameters; any other value resumes
%               from the existing workspace variables (epoch, vishid,
%               hidbiases, visbiases and the *inc momentum terms must exist)
%
% Leaves in the workspace:
%   vishid           - numdims x numhid weight matrix
%   hidbiases        - numcases x numhid hidden biases (every row identical)
%   visbiases        - numcases x numdims visible biases (every row identical)
%   batchposhidprobs - numcases x numhid x numbatches hidden probabilities
%                      computed in the positive phase of the last epoch run
%   errsum1          - per-epoch summed squared reconstruction error; also
%                      written to the file rbmerr via SAVE

epsilonw        = 0.1;     % Learning rate for weights
epsilonvb       = 0.1;     % Learning rate for biases of visible units
epsilonhb       = 0.1;     % Learning rate for biases of hidden units
weightcost      = 0.0002;  % Weight-decay coefficient applied to vishid
initialmomentum = 0.5;     % Momentum used for epochs 1..5
finalmomentum   = 0.9;     % Momentum used from epoch 6 onwards

[numcases, numdims, numbatches] = size(batchdata);

if restart == 1
    restart = 0;
    epoch = 1;

    % Initializing symmetric weights and biases.
    vishid    = 0.1*randn(numdims, numhid);
    % Biases are stored replicated once per case so they can be added to
    % the numcases-row activation matrices directly.
    hidbiases = zeros(numcases, numhid);
    visbiases = zeros(numcases, numdims);

    poshidprobs = zeros(numcases, numhid);
    neghidprobs = zeros(numcases, numhid);
    posprods    = zeros(numdims, numhid);
    negprods    = zeros(numdims, numhid);
    vishidinc   = zeros(numdims, numhid);
    hidbiasinc  = zeros(numcases, numhid);
    visbiasinc  = zeros(numcases, numdims);
    batchposhidprobs = zeros(numcases, numhid, numbatches);

    % Start a fresh error history. The former "double errsum1=[];" was
    % parsed as command syntax (double('errsum1=[]')) and initialised
    % nothing, so a stale errsum1 from an earlier run could leak into the
    % saved file. Resetting only on restart keeps the history when resuming.
    errsum1 = zeros(1, maxepoch);
end

for epoch = epoch:maxepoch
    errsum = 0;

    % Momentum depends only on the epoch, so pick it once per epoch.
    if epoch > 5
        momentum = finalmomentum;
    else
        momentum = initialmomentum;
    end

    for batch = 1:numbatches
        %%%%%%%%% START POSITIVE PHASE %%%%%%%%%
        data = batchdata(:,:,batch);
        poshidprobs = 1./(1 + exp(-data*vishid - hidbiases));
        batchposhidprobs(:,:,batch) = poshidprobs;
        posprods  = data' * poshidprobs;   % numdims x numhid
        poshidact = sum(poshidprobs);      % 1 x numhid column sums
        posvisact = sum(data);             % 1 x numdims column sums
        %%%%%%%%% END OF POSITIVE PHASE %%%%%%%%%

        % Sample binary hidden states from the positive-phase probabilities.
        poshidstates = poshidprobs > rand(numcases, numhid);

        %%%%%%%%% START NEGATIVE PHASE %%%%%%%%%
        negdata     = 1./(1 + exp(-poshidstates*vishid' - visbiases));
        neghidprobs = 1./(1 + exp(-negdata*vishid - hidbiases));
        negprods    = negdata' * neghidprobs;
        neghidact   = sum(neghidprobs);
        negvisact   = sum(negdata);
        %%%%%%%%% END OF NEGATIVE PHASE %%%%%%%%%

        % Squared reconstruction error of this batch, accumulated per epoch.
        err = sum(sum( (data - negdata).^2 ));
        errsum = err + errsum;

        %%%%%%%%% UPDATE WEIGHTS AND BIASES %%%%%%%%%
        vishidinc = momentum*vishidinc + ...
            epsilonw*( (posprods - negprods)/numcases - weightcost*vishid );
        % The bias gradients are row vectors; replicate them over the cases
        % to match the replicated bias storage.
        visbiasinc = momentum*visbiasinc + ...
            (epsilonvb/numcases)*repmat(posvisact - negvisact, numcases, 1);
        hidbiasinc = momentum*hidbiasinc + ...
            (epsilonhb/numcases)*repmat(poshidact - neghidact, numcases, 1);

        vishid    = vishid + vishidinc;
        visbiases = visbiases + visbiasinc;
        hidbiases = hidbiases + hidbiasinc;
        %%%%%%%%% END OF UPDATES %%%%%%%%%
    end

    % Record the epoch total once, after all batches have been accumulated.
    errsum1(epoch) = errsum;
    fprintf(1, 'epoch %4i error %6.1f \n', epoch, errsum);
end

save rbmerr errsum1;