-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathksupport_FISTA_logistic.m
69 lines (48 loc) · 1.23 KB
/
ksupport_FISTA_logistic.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
function [beta_opt,fun] = ksupport_FISTA_logistic(X,y,lambda,opts)
%KSUPPORT_FISTA_LOGISTIC  Accelerated proximal gradient (FISTA) for a
%   logistic loss with a squared k-support-style penalty.
%
%   [beta_opt, fun] = ksupport_FISTA_logistic(X, y, lambda, opts) iterates
%   on the objective
%       Loss_logistic(X,y,w) + lambda * norm_overlap(w,k)^2
%   using prox_ksupport as the proximal step.
%
%   Inputs
%     X, y     data, forwarded unchanged to Loss_logistic. The weight
%              vector has size(X,2) entries.
%     lambda   weight of the penalty term.
%     opts     struct with fields
%                k               (required) forwarded to prox_ksupport and
%                                norm_overlap.
%                max_iter_fista  (required) maximum number of iterations.
%                rel_tol         (required) relative-decrease threshold of
%                                the stopping rule.
%                init_w          (optional) starting point; defaults to
%                                zeros(size(X,2),1).
%                L               (optional) fixed step-size constant. When
%                                present it is used as-is for every
%                                iteration; when absent the constant is
%                                re-estimated each iteration from the
%                                Hessian returned by Loss_logistic.
%
%   Outputs
%     beta_opt last iterate (zeros(size(X,2),1) when X is entirely zero).
%     fun      column vector of objective values, one per iteration
%              (scalar 0 when X is entirely zero).
%
%   Loss_logistic is assumed to return [value, gradient, Hessian] in that
%   order (that is how its outputs are consumed here) -- TODO confirm.
%
% Last modified on May 25, 2018

    %% initialization
    n_feat = size(X,2);
    if isfield(opts, 'init_w')
        w_current = opts.init_w;
    else
        w_current = zeros(n_feat,1);
    end
    w_old = w_current;
    k     = opts.k;
    t     = 1;
    t_old = 0;
    iter  = 0;
    fun   = [];

    %% degenerate design: nothing to fit
    % Checked before any Hessian/eigs work, and returned immediately so the
    % zero solution is not overwritten by the initial point further down.
    if max(abs(X(:))) == 0
        beta_opt = zeros(n_feat,1);
        fun = 0;
        return;
    end

    %% step-size constant
    % A caller-supplied opts.L is honoured for the whole run. Otherwise L is
    % estimated inside the loop.
    has_fixed_L = isfield(opts, 'L');
    if has_fixed_L
        L = opts.L;
    end

    %% main loop
    while iter < opts.max_iter_fista
        % Momentum / extrapolation point. On the first pass alpha = -1, but
        % w_old == w_current so beta_s is simply the starting point.
        alpha  = (t_old-1)/t;
        beta_s = (1+alpha)*w_current - alpha*w_old;
        [~,grad,~] = Loss_logistic(X,y,beta_s);

        if ~has_fixed_L
            % Largest-magnitude Hessian eigenvalue at the current iterate,
            % scaled by 0.96 (factor kept from the original code; it makes
            % the step slightly longer than 1/eigenvalue).
            [~,~,H] = Loss_logistic(X,y,w_current);
            L = eigs(H,1)*0.96;
        end

        w_old     = w_current;
        w_current = prox_ksupport(beta_s - grad/L,k,2*lambda/L);

        fun = cat(1,fun, Loss_logistic(X,y,w_current) + lambda*norm_overlap(w_current,k)^2);

        % Stop once the objective decrease falls below rel_tol relative to
        % the previous value (also triggers if the objective went up).
        if iter >= 2 && fun(end-1) - fun(end) <= opts.rel_tol * fun(end-1)
            break;
        end

        iter  = iter+1;
        t_old = t;
        t     = 0.5 * (1+(1+4*t^2)^0.5);
    end

    beta_opt = w_current;
end