-
Notifications
You must be signed in to change notification settings - Fork 1
/
FCM_RDpSGDM.m
126 lines (119 loc) · 5.08 KB
/
FCM_RDpSGDM.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
function [RMSEtrain,RMSEtest,CB,SigmaB,WB,fB]=FCM_RDpSGDM(XTrain,yTrain,XTest,yTest,alpha,rr,P,gammaP,nRules,nIt,Nbs,C0,Sigma0,W0)
% FCM_RDpSGDM  Train a fuzzy rule-based regression model (Gaussian MFs with
% product firing, linear rule consequents, normalized weighted-average output)
% by minibatch gradient descent with DropRule (RD), a Powerball transform of
% the gradients (p) and SGD with momentum (SGDM). Unless C0, Sigma0 and W0 are
% supplied, the antecedents are initialized with fuzzy c-means (FCM).
%
% %% Inputs:
% XTrain: N*M matrix of the training inputs. N is the number of samples, and M the feature dimensionality.
% yTrain: N*1 vector of the labels for XTrain
% XTest: {NValidation*M, NTest*M} matrix cell of the validation and test inputs.
%        A plain matrix is also accepted and treated as a single set.
% yTest: {NValidation*1, NTest*1} vector cell of the labels for XTest
% alpha: scalar learning rate (kept constant across iterations)
% rr: scalar, L2 regularization coefficient on the consequent parameters
%     (the bias column W(:,1) is not regularized)
% P: scalar, DropRule preservation rate; each rule is kept for a training
%    sample independently with probability P
% gammaP: scalar, Powerball power exponent: g -> sign(g).*|g|.^gammaP
% nRules: scalar, total number of rules
% nIt: scalar, number of iterations (one minibatch per iteration)
% Nbs: batch size, typically 32 or 64; clipped to N. Each minibatch is drawn
%      without replacement.
% C0: nRules*M initialization matrix of the centers of the Gaussian MFs
% Sigma0: nRules*M initialization matrix of the standard deviations of the Gaussian MFs
% W0: nRules*(M+1) initialization matrix of the consequent parameters for the nRules rules
%   C0, Sigma0 and W0 are optional as a group. With fewer than 12 inputs,
%   all three are initialized from fcm on XTrain. Otherwise all three must
%   be given.
%
% %% Outputs:
% RMSEtrain: 1*nIt vector of the training RMSE on each minibatch. It uses the
%            predictions made before that iteration's update and excludes
%            samples whose prediction was invalid.
% RMSEtest: {1*nIt, 1*nIt} vector cell of the validation and test RMSE at
%           different iterations. It is a plain vector when only one set is
%           given. A NaN RMSE is replaced by the previous iteration's value.
% CB: nRules*M matrix of the centers of the Gaussian MFs
% SigmaB: nRules*M matrix of the standard deviations of the Gaussian MFs
% WB: nRules*(M+1) matrix of the consequent parameters for the nRules rules
% fB: 1*nRules vector of the mean firing levels on the validation inputs
%   CB, SigmaB, WB and fB are taken from the iteration with the lowest RMSE
%   on the first XTest set. They are tracked only when more than two outputs
%   are requested.
beta1=0.9;  % momentum coefficient for SGDM
thre=inf;   % lowest validation RMSE seen so far
if ~iscell(XTest)
    XTest={XTest};
    yTest={yTest};
end
[N,M]=size(XTrain);
Nbs=min(N,Nbs);
if nargin<12
    W0=zeros(nRules,M+1); % Rule consequents
    % FCM initialization. Centers come from fcm. Each rule's spreads are the
    % membership-weighted std of XTrain, and its bias is the
    % membership-weighted mean of yTrain.
    [C0,U] = fcm(XTrain,nRules,[2 100 0.001 0]);
    Sigma0=C0;
    for r=1:nRules
        Sigma0(r,:)=std(XTrain,U(r,:));
        W0(r,1)=U(r,:)*yTrain/sum(U(r,:));
    end
    Sigma0(Sigma0==0)=mean(Sigma0(:)); % replace zero-width MFs with the mean width
end
C=C0; Sigma=Sigma0; W=W0;
minSigma=.1*min(Sigma0(:)); % lower bound on Sigma after each update
[CB,SigmaB,WB,fB]=deal(C,Sigma,W,zeros(1,nRules));
%% Iterative update
RMSEtrain=zeros(1,nIt); RMSEtest=cellfun(@(u)RMSEtrain,XTest,'UniformOutput',false);
vC=0; vSigma=0; vW=0; % momentum buffers
yPred=nan(Nbs,1);
for it=1:nIt
    deltaC=zeros(nRules,M); deltaSigma=deltaC;
    deltaW=rr*W; deltaW(:,1)=0; % L2 gradient on the consequents, bias excluded
    f=zeros(Nbs,nRules); % firing level of rules
    idsTrain=datasample(1:N,Nbs,'replace',false);
    idsGoodTrain=true(Nbs,1); % samples whose prediction counts toward RMSEtrain
    for n=1:Nbs
        x=XTrain(idsTrain(n),:);
        yTrue=yTrain(idsTrain(n));
        % DropRule: keep each rule independently with probability P
        idsKeep=rand(1,nRules)<=P;
        f(n,idsKeep)=prod(exp(-(x-C(idsKeep,:)).^2./(2*Sigma(idsKeep,:).^2)),2);
        if any(~isfinite(f(n,idsKeep)))
            % Skip this sample. Also exclude it from RMSEtrain; otherwise a
            % stale yPred(n) from an earlier iteration would be used.
            idsGoodTrain(n)=false;
            continue;
        end
        if ~sum(f(n,idsKeep)) % special case: all kept f(n,:)=0, so fall back to all rules (no DropRule)
            idsKeep=~idsKeep;
            f(n,idsKeep)=prod(exp(-(x-C(idsKeep,:)).^2./(2*Sigma(idsKeep,:).^2)),2);
            idsKeep=true(1,nRules);
        end
        % Partial derivatives of the Gaussian exponent w.r.t. centers and widths
        deltamuC=(x-C(idsKeep,:))./(Sigma(idsKeep,:).^2);
        deltamuSigma=(x-C(idsKeep,:)).^2./(Sigma(idsKeep,:).^3);
        fBar=f(n,idsKeep)/sum(f(n,idsKeep)); % normalized firing levels
        yR=[1 x]*W(idsKeep,:)';              % rule outputs
        yPred(n)=fBar*yR';                   % prediction
        if isnan(yPred(n))
            idsGoodTrain(n)=false;
            continue;
        end
        % Compute delta
        err=yPred(n)-yTrue;
        deltaYmu=err*(yR*sum(f(n,idsKeep))-f(n,idsKeep)*yR')/sum(f(n,idsKeep))^2.*f(n,idsKeep);
        if all(isfinite(deltaYmu(:)))
            deltaC(idsKeep,:)=deltaC(idsKeep,:)+deltaYmu'.*deltamuC;
            deltaSigma(idsKeep,:)=deltaSigma(idsKeep,:)+deltaYmu'.*deltamuSigma;
            deltaW(idsKeep,:)=deltaW(idsKeep,:)+err*fBar'*[1 x];
        end
    end
    % Powerball: element-wise sign(g).*|g|.^gammaP
    deltaC=sign(deltaC).*(abs(deltaC).^gammaP);
    deltaSigma=sign(deltaSigma).*(abs(deltaSigma).^gammaP);
    deltaW=sign(deltaW).*(abs(deltaW).^gammaP);
    % SGDM
    vC=beta1*vC+deltaC;
    C=C-alpha*vC;
    vSigma=beta1*vSigma+deltaSigma;
    Sigma=max(minSigma,Sigma-alpha*vSigma);
    vW=beta1*vW+deltaW;
    W=W-alpha*vW;
    % Training RMSE on the minibatch
    RMSEtrain(it)=sqrt(sum((yTrain(idsTrain(idsGoodTrain))-yPred(idsGoodTrain)).^2)/sum(idsGoodTrain));
    % Test RMSE (all rules used, no DropRule)
    for i=1:length(XTest)
        NTest=size(XTest{i},1);
        f=zeros(NTest,nRules); % firing level of rules
        for n=1:NTest
            f(n,:)=prod(exp(-(XTest{i}(n,:)-C).^2./(2*Sigma.^2)),2);
        end
        % NOTE(review): P is documented as a scalar, so this zeroes at most
        % column 1 (only when P==0). It looks like a leftover from a per-rule
        % P vector. Confirm the intent.
        f(:,P==0)=0;
        yR=[ones(NTest,1) XTest{i}]*W';
        yPredTest=sum(f.*yR,2)./sum(f,2); % prediction
        % Samples with zero total firing (0/0) get the mean of the valid predictions
        yPredTest(isnan(yPredTest))=mean(yPredTest,'omitnan');
        RMSEtest{i}(it)=sqrt((yTest{i}-yPredTest)'*(yTest{i}-yPredTest)/NTest);
        if isnan(RMSEtest{i}(it)) && it>1
            RMSEtest{i}(it)=RMSEtest{i}(it-1);
        end
        if nargout>2&&i==1&&RMSEtest{i}(it)<thre
            thre=RMSEtest{i}(it);
            % mean along dim 1 keeps fB 1*nRules even for a one-sample validation set
            [CB,SigmaB,WB,fB]=deal(C,Sigma,W,mean(f,1));
        end
    end
end
if length(XTest)==1
    RMSEtest=RMSEtest{1};
end