Skip to content

Commit

Permalink
return of math symbols!
Browse files Browse the repository at this point in the history
  • Loading branch information
FlyingWorkshop committed May 8, 2024
1 parent fdc054e commit 810f532
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 65 deletions.
51 changes: 25 additions & 26 deletions src/explicit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,56 +4,55 @@ mutable struct ExplicitEPCA <: EPCA
g::Function # link function

# hyperparameters
mu
epsilon
μ
ϵ
end


function EPCA(Bregman::Function, g::Function, mu, epsilon=eps())
ExplicitEPCA(missing, Bregman, g, mu, epsilon)
function EPCA(Bregman::Function, g::Function, μ, ϵ=eps())
ExplicitEPCA(missing, Bregman, g, μ, ϵ)
end


function PoissonEPCA(; epsilon=eps())
# assumes X = {integers}
function PoissonEPCA(; ϵ=eps())
# assumes χ = ℤ
@. begin
Bregman(p, q) = p * (log(p + epsilon) - log(q + epsilon)) + q - p
g(theta) = exp(theta)
Bregman(p, q) = p * (log(p + ϵ) - log(q + ϵ)) + q - p
g(θ) = exp(θ)
end
mu = g(0)
EPCA(Bregman, g, mu, epsilon)
μ = g(0)
EPCA(Bregman, g, μ, ϵ)
end


function BernoulliEPCA(; epsilon=eps())
# assumes X = {0, 1}
function BernoulliEPCA(; ϵ=eps())
# assumes χ = {0, 1}
@. begin
Bregman(p, q) = p * (log(p + epsilon) - log(q + epsilon)) + (1 - p) * (log(1 - p + epsilon) - log(1 - q + epsilon))
g(theta) = exp(theta) / (1 + exp(theta))
Bregman(p, q) = p * (log(p + ϵ) - log(q + ϵ)) + (1 - p) * (log(1 - p + ϵ) - log(1 - q + ϵ))
g(θ) = exp(θ) / (1 + exp(θ))
end
mu = g(1)
EPCA(Bregman, g, mu, epsilon)
μ = g(1)
EPCA(Bregman, g, μ, ϵ)
end


function NormalEPCA(; epsilon=eps())
function NormalEPCA(; ϵ=eps())
# NOTE: equivalent to generic PCA
# assume X = {reals}
# assume χ = ℝ
@. begin
Bregman(p, q) = (p - q)^2 / 2
g(theta) = theta
g(θ) = θ
end
mu = g(0)
EPCA(Bregman, g, mu, epsilon)
μ = g(0)
EPCA(Bregman, g, μ, ϵ)
end


function _make_loss(epca::ExplicitEPCA, X)
B, g, mu, epsilon = epca.Bregman, epca.g, epca.mu, epca.epsilon
L(theta) = begin
X_hat = g.(theta)
divergence = @. B(X, X_hat) + epsilon * B(mu, X_hat)
# @show sum(divergence)
B, g, μ, ϵ = epca.Bregman, epca.g, epca.μ, epca.ϵ
L(θ) = begin
X̂ = g.(θ)
divergence = @. B(X, X̂) + ϵ * B(μ, X̂)
return sum(divergence)
end
return L
Expand Down
74 changes: 37 additions & 37 deletions src/implicit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,31 +7,31 @@ mutable struct ImplicitEPCA <: EPCA

# hyperparameters
tol
mu
epsilon
μ
ϵ
end


function EPCA(G::Function; tol=eps(), mu=1, epsilon=eps())
return ImplicitEPCA(G::Function; tol=tol, mu=mu, epsilon=epsilon)
function EPCA(G::Function; tol=eps(), μ=1, ϵ=eps())
return ImplicitEPCA(G::Function; tol=tol, μ=μ, ϵ=ϵ)
end


function ImplicitEPCA(G::Function; tol=eps(), mu=1, epsilon=eps())
# NOTE: mu must be in the range of g, so g_inv(mu) is finite. It is up to the user to enforce this.
# G induces g, Fg = F(g(theta)), and fg = f(g(theta))
@variables theta
D = Differential(theta)
_g = expand_derivatives(D(G(theta)))
_Fg = _g * theta - G(theta)
function ImplicitEPCA(G::Function; tol=eps(), μ=1, ϵ=eps())
# NOTE: μ must be in the range of g, so g_inv(μ) is finite. It is up to the user to enforce this.
# G induces g, Fg = F(g(θ)), and fg = f(g(θ))
@variables θ
D = Differential(θ)
_g = expand_derivatives(D(G(θ)))
_Fg = _g * θ - G(θ)
_fg = expand_derivatives(D(_Fg) / D(_g))
ex = quote
g(theta) = $(Symbolics.toexpr(_g))
Fg(theta) = $(Symbolics.toexpr(_Fg))
fg(theta) = $(Symbolics.toexpr(_fg))
g(θ) = $(Symbolics.toexpr(_g))
Fg(θ) = $(Symbolics.toexpr(_Fg))
fg(θ) = $(Symbolics.toexpr(_fg))
end
eval(ex)
ImplicitEPCA(missing, G, g, Fg, fg, tol, mu, epsilon)
ImplicitEPCA(missing, G, g, Fg, fg, tol, μ, ϵ)
end


Expand All @@ -51,36 +51,36 @@ end


function _make_loss(epca::ImplicitEPCA, X)
G, g, Fg, fg, tol, mu, epsilon = epca.G, epca.g, epca.Fg, epca.fg, epca.tol, epca.mu, epca.epsilon
g_inv_X = map(x->_binary_search_monotone(g, x; tol=tol), X)
g_inv_mu = _binary_search_monotone(g, mu; tol=0) # NOTE: mu is scalar, so we can have very low tol
F_X = @. g_inv_X * X - G(g_inv_X)
F_mu = g_inv_mu * mu - G(g_inv_mu)
L(theta) = begin
X_hat = g.(theta)
Fg_theta = Fg.(theta)
fg_theta = fg.(theta)
BF_X = @. F_X - Fg_theta - fg_theta * (X - X_hat)
BF_mu = @. F_mu - Fg_theta - fg_theta * (mu - X_hat)
divergence = @. BF_X + epsilon * BF_mu
G, g, Fg, fg, tol, μ, ϵ = epca.G, epca.g, epca.Fg, epca.fg, epca.tol, epca.μ, epca.ϵ
g⁻¹X = map(x->_binary_search_monotone(g, x; tol=tol), X)
g⁻¹μ = _binary_search_monotone(g, μ; tol=0) # NOTE: μ is scalar, so we can have very low tol
FX = @. g⁻¹X * X - G(g⁻¹X)
Fμ = g⁻¹μ * μ - G(g⁻¹μ)
L(θ) = begin
X̂ = g.(θ)
Fgθ = Fg.(θ) # Recall this is F(g(θ))
fgθ = fg.(θ)
BF_X = @. FX - Fgθ - fgθ * (X - X̂)
BF_μ = @. Fμ - Fgθ - fgθ * (μ - X̂)
divergence = @. BF_X + ϵ * BF_μ
return sum(divergence)
end
return L
end


function _make_loss_old(epca::ImplicitEPCA, X)
G, g, Fg, fg, tol, mu, epsilon = epca.G, epca.g, epca.Fg, epca.fg, epca.tol, epca.mu, epca.epsilon
g_inv_X = map(x->_binary_search_monotone(g, x; tol=tol), X)
g_inv_mu = _binary_search_monotone(g, mu; tol=0) # NOTE: mu is scalar, so we can have very low tol
F_X = @. X * g_inv_X - G(g_inv_X)
F_mu = mu * g_inv_mu - G(g_inv_mu)
L(theta) = begin
G, g, Fg, fg, tol, μ, ϵ = epca.G, epca.g, epca.Fg, epca.fg, epca.tol, epca.μ, epca.ϵ
g⁻¹X = map(x->_binary_search_monotone(g, x; tol=tol), X)
g⁻¹μ = _binary_search_monotone(g, μ; tol=0) # NOTE: μ is scalar, so we can have very low tol
FX = @. X * g⁻¹X - G(g⁻¹X)
Fμ = μ * g⁻¹μ - G(g⁻¹μ)
L(θ) = begin
@infiltrate
X_hat = g.(theta)
B1 = @. F_X - Fg(theta) - fg(theta) * (X - X_hat)
B2 = @. F_mu - Fg(theta) - fg(theta) * (mu - X_hat)
divergence = @. B1 + epsilon * B2
X̂ = g.(θ)
B1 = @. FX - Fg(θ) - fg(θ) * (X - X̂)
B2 = @. Fμ - Fg(θ) - fg(θ) * (μ - X̂)
divergence = @. B1 + ϵ * B2
return sum(divergence)
end
return L
Expand Down
4 changes: 2 additions & 2 deletions src/utils.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
function _single_fit_iter(L::Function, V, A, verbose::Bool, i::Integer, steps_per_print::Integer, maxiter::Integer)
V = Optim.minimizer(optimize(V_hat->L(A * V_hat), V))
V = Optim.minimizer(optimize(V̂->L(A * V̂), V))
A = _single_compress_iter(L, V, A, verbose, i, steps_per_print, maxiter)
return V, A
end


function _single_compress_iter(L::Function, V, A, verbose::Bool, i::Integer, steps_per_print::Integer, maxiter::Integer)
result = optimize(A_hat->L(A_hat * V), A)
result = optimize(Â->L(Â * V), A)
A = Optim.minimizer(result)
if verbose && (i % steps_per_print == 0 || i == 1)
loss = Optim.minimum(result)
Expand Down

0 comments on commit 810f532

Please sign in to comment.