-
Notifications
You must be signed in to change notification settings - Fork 1
/
NMF_ML.m
executable file
·187 lines (163 loc) · 4.83 KB
/
NMF_ML.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
% NMF_ML Blind unmixing based on multi-layering and the L1 norm
%
% [Ac, H, xts, im_out] = NMF_ML(Y,m,A0,H0,split_every,max_it,tolA,display,flatten)
%
% Inputs
% Y           LxM input mixed matrix: one row per channel, one column per
%             pixel/voxel where the mask m is true
% m           mask used to determine where the image was sampled
%             (converted to a binary dip_image)
% A0          LxN initial mixing matrix (optional, defaults to
%             exponential_matrix(L))
% H0          NxM initial unmixed matrix estimation (optional, defaults to Y)
% split_every creates a new layer every split_every iterations
%             (optional, default 20)
% max_it      maximum number of iterations (optional, default 1500)
% tolA        tolerance of the A matrix: at a layer split, iterations stop
%             once every entry of the current layer's A is within tolA of
%             0 or 1 (optional, default 1e-2)
% display     1: plot progress and print convergence values at every layer
%             split, 0: no display (optional, default 0)
% flatten     1: for 3D images, calculates MIP before unmixing
%             0: use full volume (optional, default 1)
%
% Outputs
% Ac          accumulated mixing matrix (product of the per-layer matrices)
% H           unmixed matrix, computed as pinv(Ac'*Ac + alphaX)*Ac'*Y0 from
%             the (possibly flattened) input Y0; not clipped to
%             non-negative values
% xts         one value per iteration of the cXT score computed from the
%             last two channels (presumably a cross-talk measure)
% im_out      H scattered back into one image per row through the mask
%             (only computed when requested)
%
% Requires DipImage (www.dipimage.org)
%
% Based on the algorithm described in:
%
% Pengo et al., EFFICIENT BLIND SPECTRAL UNMIXING OF FLUORESCENTLY LABELED
% SAMPLES USING MULTI-LAYER NON-NEGATIVE MATRIX FACTORIZATION, PLoS ONE, 2013
%
% June 18th 2013 v1.0 First release
% Feb 12th 2014 bugfix affecting flattening, non-rectangular, 3D
%
function [Ac, H, xts, im_out] = NMF_ML(Y,m,varargin)
nChannels = size(Y,1);
p = inputParser;
p.addOptional('A0',exponential_matrix(nChannels), @(A) size(A,1)==size(Y,1));
p.addOptional('H0',Y, @(H) size(H,2)==size(Y,2));
p.addOptional('split_every',20);
p.addOptional('max_it',1500);
p.addOptional('tolA',1e-2);
p.addOptional('display',0);
p.addOptional('flatten',1);
p.parse(varargin{:});
s = p.Results;
% A0*H0 must be a valid matrix product
if size(s.H0,1)~=size(s.A0,2)
    error('A0 must be matrix-multiplyable by H0. Rows of H0 are not equal to columns of A0.');
end
% Convert mask to a binary image
m = dip_image(m,'bin');
% Transform arguments to variables
A = s.A0;
H = s.H0;
split_every = s.split_every;
max_it = s.max_it;
tolA = s.tolA;
doDisplay = s.display;   % local name avoids shadowing the builtin DISPLAY
flatten = s.flatten;
% Fixed regularization constants (not exposed as arguments)
alphaA = .1;   % subtracted from the numerator of the multiplicative A update
alphaX = .1;   % added to EVERY entry of Ac'*Ac (scalar, not alphaX*eye) in the
               % least-squares H estimate -- NOTE(review): confirm intended
psi = .1;      % alphaX*psi is subtracted from the numerator of the H update
% For 3D data, optionally replace the volume by its maximum-intensity projection
if flatten && size(m,3)>1
    [Y, H, m] = flatten_mip(m, Y, H);
end
Y0 = Y;
% Mean of B over the entries selected by the logical index a, relative to the
% mean of B over all entries
cXT = @(a,B) mean(B(a))/mean(B);
xts = [];
% Accumulated mixing matrix. During the first layer it follows A (refreshed at
% the end of every iteration), so it is current when the loop ends before the
% first split and is defined even when max_it == 0.
Ac = A;
firstLayer = true;
for it = 1:max_it
    if mod(it,split_every)==0
        % Start a new layer
        if firstLayer
            Ac = A;
            A = exponential_matrix(size(Ac,2));
            firstLayer = false;
        else
            % NOTE(review): A is re-initialized only at the first split;
            % later layers continue from the previous layer's A. Confirm
            % this is intended.
            Ac = Ac*A;
        end
        % The previous layer's output becomes the new layer's input, and H is
        % re-estimated from the original data through the accumulated matrix
        Y = H;
        H = max(1E6*eps, pinv(Ac'*Ac + alphaX)*Ac'*Y0);
        % Largest distance of an entry of A from the nearer of 0 and 1 (for
        % entries in [0,1]): how close A is to a 0/1, identity-like matrix
        distA = max(.5-abs(A(:)-.5));
        if doDisplay
            fprintf('M: %02.4f Fro: %02.4f\n', distA, norm(A-eye(size(A)),'fro'))
        end
        if distA < tolA
            xt = cXT(Ac(end-1,end-1)*Y(end-1,:)>Ac(end,end-1)*Y(end,:),H(end,:));
            xts = cat(1,xts,xt);
            break
        end
    end
    % H step: multiplicative update, non-positive numerator entries floored
    Yx = A'*Y - alphaX*psi;
    Yx(Yx <= 0) = 100*eps;
    H = H.*(Yx./((A'*A)*H + eps));
    % A step, then rescale every column to (approximately) unit sum
    Ya = H*Y' - alphaA;
    Ya(Ya <= 0) = 100*eps;
    A = A.*(Ya./((A*(H*H'))' + eps))';
    A = A*diag(1./(sum(A,1)+eps));
    if firstLayer
        % Single layer so far: the accumulated matrix is the current A
        Ac = A;
    end
    xt = cXT(Ac(end-1,end-1)*Y(end-1,:)>Ac(end,end-1)*Y(end,:),H(end,:));
    xts = cat(1,xts,xt);
    if doDisplay && mod(it,split_every)==0
        plot_progress(it, xt, xts, Y, H, m, Ac);
    end
end
% Final estimate of H from the accumulated mixing matrix (no clipping to
% non-negative values, unlike inside the loop)
H = pinv(Ac'*Ac + alphaX)*Ac'*Y0;
if nargout>3
    im_out = reconstruct_image(m,H);
end

function [Y, H, m] = flatten_mip(m, Y, H)
% FLATTEN_MIP Replace 3D data by its maximum-intensity projection along the
% third dimension. Y and H are scattered into volumes through the mask m,
% projected with 'max', and re-sampled through the projected mask, which is
% returned as the new m.
mipY = iterate('max',dip_image(reconstruct_image(m,Y)),[],3);
mipH = iterate('max',dip_image(reconstruct_image(m,H)),[],3);
m = max(m,[],3);
Y = sample_channels(mipY, m);
H = sample_channels(mipH, m);

function X = sample_channels(im, m)
% SAMPLE_CHANNELS Concatenate, along dimension 1, the values of every image in
% im at the true pixels of the mask m.
X = [];
for i=1:length(im)
    X = cat(1, X, double(im{i}(m)));
end

function plot_progress(it, xt, xts, Y, H, m, Ac)
% PLOT_PROGRESS Draw the four progress panels shown at a layer split: scatter
% of the last two rows of Y and H, the xts history, an RGB composite of H
% (when a mask is available) and the columns of Ac.
subplot(1,4,1)
plot(Y(end-1,:),Y(end,:),'.',H(end-1,:),H(end,:),'k.')
subplot(1,4,2)
plot(xts)
text(it/4,mean(xts),sprintf('%5d %1.5f',it,xt))
if ~isempty(m)
    subplot(1,4,3)
    im = reconstruct_image(m,H);
    M = zeros(size(im{1},2),size(im{1},1),3);
    for i=1:length(im)
        if size(m,3)>1
            mp = max(im{i},[],3);
        else
            % Use channel i itself, not the mask
            mp = im{i};
        end
        mp = double(mp/max(mp));
        if i<=3
            % Channels 1-3 go to the red, green and blue planes
            M(:,:,i) = mp;
        else
            % Further channels are overlaid in white; writing them to a plane
            % beyond the third would make M unusable as truecolor CData
            M = min(1, M + repmat(mp,[1 1 3]));
        end
    end
    image(M)
end
subplot(1,4,4)
plot(Ac)
drawnow
function im=reconstruct_image(m,Y)
% RECONSTRUCT_IMAGE Scatter each row of Y into its own image of size(m).
%
% im = reconstruct_image(m,Y)
%
% m   binary mask (dip_image); row k of Y is written to the pixels where m
%     is true
% Y   matrix with one row per output image
% im  image array created with newimar, holding size(Y,1) images; pixels
%     outside the mask keep newim's initial value
nRows = size(Y,1);
im = newimar(nRows);
for k = 1:nRows
    plane = newim(size(m));
    plane(m) = Y(k,:);
    im{k} = plane;
end