Skip to content

Commit

Permalink
code cleanup with utils
Browse files Browse the repository at this point in the history
  • Loading branch information
FlyingWorkshop committed May 4, 2024
1 parent d4efac6 commit e8f2254
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 74 deletions.
7 changes: 4 additions & 3 deletions src/ExpFamilyPCA.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,20 @@ module ExpFamilyPCA

using Infiltrator

using Symbolics
using Optim

export
EPCA,
fit!,
compress,
decompress
include("epca.jl")

export
ImplicitEPCA
include("utils.jl")
include("implicit.jl")

export
ExplicitEPCA,
NormalEPCA,
BernoulliEPCA,
PoissonEPCA
Expand Down
52 changes: 16 additions & 36 deletions src/explicit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,34 +45,28 @@ function NormalEPCA()
end


function _make_loss(epca::ExplicitEPCA, X, epsilon, mu)
B, g = epca.Bregman, epca.g
L(theta) = begin
X_hat = g.(theta)
sum(@. B(X, X_hat) + epsilon * B(mu, X_hat))
end
return L
end


function fit!(
epca::ExplicitEPCA,
X;
mu=epca.mu,
maxoutdim=1,
maxiter=1000,
verbose=false,
print_steps=10,
steps_per_print=10,
epsilon=eps(),
)
B, g = epca.Bregman, epca.g
L(theta) = begin
X_hat = g.(theta)
sum(B(X, X_hat) + epsilon * B(mu, X_hat))
end
n, d = size(X)
A = ones(n, maxoutdim)
V = ismissing(epca.V) ? ones(maxoutdim, d) : epca.V
for i in 1:maxiter
V = Optim.minimizer(optimize(V_hat->L(A * V_hat), V))
result = optimize(A_hat->L(A_hat * V), A)
A = Optim.minimizer(result)
if verbose && (i % print_steps == 0 || i == 1)
loss = Optim.minimum(result)
println("Iteration: $i/$maxiter | Loss: $loss")
end
end
epca.V = V
L = _make_loss(epca, X, epsilon, mu)
A = _fit!(epca, X, maxoutdim, L, verbose, steps_per_print, maxiter)
return A
end

Expand All @@ -83,25 +77,11 @@ function compress(
mu=epca.mu,
maxiter=100,
verbose=false,
print_steps=10,
steps_per_print=10,
epsilon=eps()
)
B, g, V = epca.Bregman, epca.g, epca.V
L(theta) = begin
X_hat = g.(theta)
sum(@. B(X, X_hat) + epsilon * B(mu, X_hat))
end
n, _ = size(X)
outdim = size(V)[1]
A = ones(n, outdim)
for _ in 1:maxiter
result = optimize(A_hat->L(A_hat * V), A)
A = Optim.minimizer(result)
if verbose && (i % print_steps == 0 || i == 1)
loss = Optim.minimum(result)
println("Iteration: $i/$maxiter | Loss: $loss")
end
end
L = _make_loss(epca, X, epsilon, mu)
A = _compress(epca, X, L, maxiter, verbose, steps_per_print)
return A
end

Expand Down
39 changes: 5 additions & 34 deletions src/implicit.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
using Symbolics
using Optim


mutable struct ImplicitEPCA <: EPCA
V
G::Function
Expand All @@ -10,6 +6,7 @@ mutable struct ImplicitEPCA <: EPCA
fg::Function
end


function EPCA(G::Function)
return ImplicitEPCA(G::Function)
end
Expand Down Expand Up @@ -73,24 +70,12 @@ function fit!(
maxoutdim=1,
maxiter=100,
verbose=false,
print_steps=10,
steps_per_print=10,
epsilon=eps(),
tol=eps()
)
L = _make_loss(epca, X, mu, epsilon; tol=tol)
n, d = size(X)
A = ones(n, maxoutdim)
V = ismissing(epca.V) ? ones(maxoutdim, d) : epca.V
for i in 1:maxiter
V = Optim.minimizer(optimize(V_hat->L(A * V_hat), V))
result = optimize(A_hat->L(A_hat * V), A)
A = Optim.minimizer(result)
if verbose && (i % print_steps == 0 || i == 1)
loss = Optim.minimum(result)
println("Iteration: $i/$maxiter | Loss: $loss")
end
end
epca.V = V
A = _fit!(epca, X, maxoutdim, L, verbose, steps_per_print, maxiter)
return A
end

Expand All @@ -100,26 +85,12 @@ function compress(
mu=1, # NOTE: mu = 1 may not be valid for all link functions.
maxiter=100,
verbose=false,
print_steps=10,
steps_per_print=10,
epsilon=eps(),
tol=eps()
)
L = _make_loss(epca, X, mu, epsilon; tol=tol)
n, _ = size(X)
V = epca.V
outdim = size(V)[1]
A = ones(n, outdim)
for i in 1:maxiter
result = optimize(A_hat->L(A_hat * V), A)
A = Optim.minimizer(result)
if verbose && (i % print_steps == 0 || i == 1)
loss = Optim.minimum(result)
println("Iteration: $i/$maxiter | Loss: $loss")
end
end
A = _compress(epca, X, L, maxiter, verbose, steps_per_print)
return A
end

decompress(epca::ImplicitEPCA, A) = epca.g(A * epca.V)


39 changes: 39 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
function _single_fit_iter(L::Function, V, A, verbose, i, steps_per_print)
V = Optim.minimizer(optimize(V_hat->L(A * V_hat), V))
A = _single_compress_iter(L, V, A, verbose, i, steps_per_print)
return V, A
end

function _single_compress_iter(L::Function, V, A, verbose, i, steps_per_print)
result = optimize(A_hat->L(A_hat * V), A)
A = Optim.minimizer(result)
if verbose && (i % steps_per_print == 0 || i == 1)
loss = Optim.minimum(result)
println("Iteration: $i/$maxiter | Loss: $loss")
end
return A
end

function _fit!(epca::EPCA, X, maxoutdim, L, verbose, steps_per_print, maxiter)
n, d = size(X)
A = ones(n, maxoutdim)
V = ismissing(epca.V) ? ones(maxoutdim, d) : epca.V
for i in 1:maxiter
V, A = _single_fit_iter(L, V, A, verbose, i, steps_per_print)
end
epca.V = V
return A
end

function _compress(epca::EPCA, X, L, maxiter, verbose, steps_per_print)
V = epca.V
n, _ = size(X)
outdim = size(V)[1]
A = ones(n, outdim)
for i in 1:maxiter
_single_compress_iter(L, V, A, verbose, i, steps_per_print)
end
return A
end

decompress(epca::EPCA, A) = epca.g(A * epca.V)
2 changes: 1 addition & 1 deletion test/explicit_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ function test_explicit(model::Function, X, rtol)
_, d = size(X)
epca = model()
Y1 = fit!(epca, X; maxoutdim=d)
Y2 = compress(epca, X; maxoutdim=d)
Y2 = compress(epca, X)
@test isapprox(Y1, Y2, rtol=rtol)
Z1 = decompress(epca, Y1)
Z2 = decompress(epca, Y2)
Expand Down

0 comments on commit e8f2254

Please sign in to comment.