From 810f532dfab75faddbd5fdcb3119ad2b77d3f98a Mon Sep 17 00:00:00 2001
From: Logan Mondal Bhamidipaty <76822456+FlyingWorkshop@users.noreply.github.com>
Date: Tue, 7 May 2024 17:51:40 -0700
Subject: [PATCH] return of math symbols!

---
 src/explicit.jl | 51 +++++++++++++++++-----------------
 src/implicit.jl | 74 ++++++++++++++++++++++++-------------------------
 src/utils.jl    |  4 +--
 3 files changed, 64 insertions(+), 65 deletions(-)

diff --git a/src/explicit.jl b/src/explicit.jl
index 059102b..ee82fc2 100644
--- a/src/explicit.jl
+++ b/src/explicit.jl
@@ -4,56 +4,55 @@ mutable struct ExplicitEPCA <: EPCA
     g::Function # link function
 
     # hyperparameters
-    mu
-    epsilon
+    μ
+    ϵ
 end
 
-function EPCA(Bregman::Function, g::Function, mu, epsilon=eps())
-    ExplicitEPCA(missing, Bregman, g, mu, epsilon)
+function EPCA(Bregman::Function, g::Function, μ, ϵ=eps())
+    ExplicitEPCA(missing, Bregman, g, μ, ϵ)
 end
 
-function PoissonEPCA(; epsilon=eps())
-    # assumes X = {integers}
+function PoissonEPCA(; ϵ=eps())
+    # assumes χ = ℤ
     @. begin
-        Bregman(p, q) = p * (log(p + epsilon) - log(q + epsilon)) + q - p
-        g(theta) = exp(theta)
+        Bregman(p, q) = p * (log(p + ϵ) - log(q + ϵ)) + q - p
+        g(θ) = exp(θ)
     end
-    mu = g(0)
-    EPCA(Bregman, g, mu, epsilon)
+    μ = g(0)
+    EPCA(Bregman, g, μ, ϵ)
 end
 
-function BernoulliEPCA(; epsilon=eps())
-    # assumes X = {0, 1}
+function BernoulliEPCA(; ϵ=eps())
+    # assumes χ = {0, 1}
     @. begin
-        Bregman(p, q) = p * (log(p + epsilon) - log(q + epsilon)) + (1 - p) * (log(1 - p + epsilon) - log(1 - q + epsilon))
-        g(theta) = exp(theta) / (1 + exp(theta))
+        Bregman(p, q) = p * (log(p + ϵ) - log(q + ϵ)) + (1 - p) * (log(1 - p + ϵ) - log(1 - q + ϵ))
+        g(θ) = exp(θ) / (1 + exp(θ))
     end
-    mu = g(1)
-    EPCA(Bregman, g, mu, epsilon)
+    μ = g(1)
+    EPCA(Bregman, g, μ, ϵ)
 end
 
-function NormalEPCA(; epsilon=eps())
+function NormalEPCA(; ϵ=eps())
     # NOTE: equivalent to generic PCA
-    # assume X = {reals}
+    # assume χ = ℝ
     @. begin
         Bregman(p, q) = (p - q)^2 / 2
-        g(theta) = theta
+        g(θ) = θ
     end
-    mu = g(0)
-    EPCA(Bregman, g, mu, epsilon)
+    μ = g(0)
+    EPCA(Bregman, g, μ, ϵ)
 end
 
 function _make_loss(epca::ExplicitEPCA, X)
-    B, g, mu, epsilon = epca.Bregman, epca.g, epca.mu, epca.epsilon
-    L(theta) = begin
-        X_hat = g.(theta)
-        divergence = @. B(X, X_hat) + epsilon * B(mu, X_hat)
-        # @show sum(divergence)
+    B, g, μ, ϵ = epca.Bregman, epca.g, epca.μ, epca.ϵ
+    L(θ) = begin
+        X̂ = g.(θ)
+        divergence = @. B(X, X̂) + ϵ * B(μ, X̂)
         return sum(divergence)
     end
     return L
diff --git a/src/implicit.jl b/src/implicit.jl
index 421edd0..b0a3996 100644
--- a/src/implicit.jl
+++ b/src/implicit.jl
@@ -7,31 +7,31 @@ mutable struct ImplicitEPCA <: EPCA
 
     # hyperparameters
     tol
-    mu
-    epsilon
+    μ
+    ϵ
 end
 
-function EPCA(G::Function; tol=eps(), mu=1, epsilon=eps())
-    return ImplicitEPCA(G::Function; tol=tol, mu=mu, epsilon=epsilon)
+function EPCA(G::Function; tol=eps(), μ=1, ϵ=eps())
+    return ImplicitEPCA(G::Function; tol=tol, μ=μ, ϵ=ϵ)
 end
 
-function ImplicitEPCA(G::Function; tol=eps(), mu=1, epsilon=eps())
-    # NOTE: mu must be in the range of g, so g_inv(mu) is finite. It is up to the user to enforce this.
-    # G induces g, Fg = F(g(theta)), and fg = f(g(theta))
-    @variables theta
-    D = Differential(theta)
-    _g = expand_derivatives(D(G(theta)))
-    _Fg = _g * theta - G(theta)
+function ImplicitEPCA(G::Function; tol=eps(), μ=1, ϵ=eps())
+    # NOTE: μ must be in the range of g, so g_inv(μ) is finite. It is up to the user to enforce this.
+    # G induces g, Fg = F(g(θ)), and fg = f(g(θ))
+    @variables θ
+    D = Differential(θ)
+    _g = expand_derivatives(D(G(θ)))
+    _Fg = _g * θ - G(θ)
     _fg = expand_derivatives(D(_Fg) / D(_g))
     ex = quote
-        g(theta) = $(Symbolics.toexpr(_g))
-        Fg(theta) = $(Symbolics.toexpr(_Fg))
-        fg(theta) = $(Symbolics.toexpr(_fg))
+        g(θ) = $(Symbolics.toexpr(_g))
+        Fg(θ) = $(Symbolics.toexpr(_Fg))
+        fg(θ) = $(Symbolics.toexpr(_fg))
     end
     eval(ex)
-    ImplicitEPCA(missing, G, g, Fg, fg, tol, mu, epsilon)
+    ImplicitEPCA(missing, G, g, Fg, fg, tol, μ, ϵ)
 end
 
 
@@ -51,18 +51,18 @@ end
 
 function _make_loss(epca::ImplicitEPCA, X)
-    G, g, Fg, fg, tol, mu, epsilon = epca.G, epca.g, epca.Fg, epca.fg, epca.tol, epca.mu, epca.epsilon
-    g_inv_X = map(x->_binary_search_monotone(g, x; tol=tol), X)
-    g_inv_mu = _binary_search_monotone(g, mu; tol=0) # NOTE: mu is scalar, so we can have very low tol
-    F_X = @. g_inv_X * X - G(g_inv_X)
-    F_mu = g_inv_mu * mu - G(g_inv_mu)
-    L(theta) = begin
-        X_hat = g.(theta)
-        Fg_theta = Fg.(theta)
-        fg_theta = fg.(theta)
-        BF_X = @. F_X - Fg_theta - fg_theta * (X - X_hat)
-        BF_mu = @. F_mu - Fg_theta - fg_theta * (mu - X_hat)
-        divergence = @. BF_X + epsilon * BF_mu
+    G, g, Fg, fg, tol, μ, ϵ = epca.G, epca.g, epca.Fg, epca.fg, epca.tol, epca.μ, epca.ϵ
+    g⁻¹X = map(x->_binary_search_monotone(g, x; tol=tol), X)
+    g⁻¹μ = _binary_search_monotone(g, μ; tol=0) # NOTE: μ is scalar, so we can have very low tol
+    FX = @. g⁻¹X * X - G(g⁻¹X)
+    Fμ = g⁻¹μ * μ - G(g⁻¹μ)
+    L(θ) = begin
+        X̂ = g.(θ)
+        Fgθ = Fg.(θ) # Recall this is F(g(θ))
+        fgθ = fg.(θ)
+        BF_X = @. FX - Fgθ - fgθ * (X - X̂)
+        BF_μ = @. Fμ - Fgθ - fgθ * (μ - X̂)
+        divergence = @. BF_X + ϵ * BF_μ
         return sum(divergence)
     end
     return L
@@ -70,17 +70,17 @@ end
 
 function _make_loss_old(epca::ImplicitEPCA, X)
-    G, g, Fg, fg, tol, mu, epsilon = epca.G, epca.g, epca.Fg, epca.fg, epca.tol, epca.mu, epca.epsilon
-    g_inv_X = map(x->_binary_search_monotone(g, x; tol=tol), X)
-    g_inv_mu = _binary_search_monotone(g, mu; tol=0) # NOTE: mu is scalar, so we can have very low tol
-    F_X = @. X * g_inv_X - G(g_inv_X)
-    F_mu = mu * g_inv_mu - G(g_inv_mu)
-    L(theta) = begin
+    G, g, Fg, fg, tol, μ, ϵ = epca.G, epca.g, epca.Fg, epca.fg, epca.tol, epca.μ, epca.ϵ
+    g⁻¹X = map(x->_binary_search_monotone(g, x; tol=tol), X)
+    g⁻¹μ = _binary_search_monotone(g, μ; tol=0) # NOTE: μ is scalar, so we can have very low tol
+    FX = @. X * g⁻¹X - G(g⁻¹X)
+    Fμ = μ * g⁻¹μ - G(g⁻¹μ)
+    L(θ) = begin
         @infiltrate
-        X_hat = g.(theta)
-        B1 = @. F_X - Fg(theta) - fg(theta) * (X - X_hat)
-        B2 = @. F_mu - Fg(theta) - fg(theta) * (mu - X_hat)
-        divergence = @. B1 + epsilon * B2
+        X̂ = g.(θ)
+        B1 = @. FX - Fg(θ) - fg(θ) * (X - X̂)
+        B2 = @. Fμ - Fg(θ) - fg(θ) * (μ - X̂)
+        divergence = @. B1 + ϵ * B2
         return sum(divergence)
     end
     return L
diff --git a/src/utils.jl b/src/utils.jl
index 7722e6c..a4f2923 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -1,12 +1,12 @@
 function _single_fit_iter(L::Function, V, A, verbose::Bool, i::Integer, steps_per_print::Integer, maxiter::Integer)
-    V = Optim.minimizer(optimize(V_hat->L(A * V_hat), V))
+    V = Optim.minimizer(optimize(V̂->L(A * V̂), V))
     A = _single_compress_iter(L, V, A, verbose, i, steps_per_print, maxiter)
     return V, A
 end
 
 function _single_compress_iter(L::Function, V, A, verbose::Bool, i::Integer, steps_per_print::Integer, maxiter::Integer)
-    result = optimize(A_hat->L(A_hat * V), A)
+    result = optimize(Â->L(Â * V), A)
     A = Optim.minimizer(result)
     if verbose && (i % steps_per_print == 0 || i == 1)
         loss = Optim.minimum(result)
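
Usage note (reviewer sketch, not part of the patch): the renamed keywords μ and ϵ appear in public constructor signatures, so this change is breaking for call sites that previously passed mu= or epsilon= by keyword. The sketch below exercises the new spellings. The module name ExpFamilyPCA, the sample data, and the direct call to the internal _make_loss are illustrative assumptions; only the constructor signatures are taken from the diff.

using ExpFamilyPCA  # assumed module name; adjust to the actual package

# Explicit variant: the Poisson constructor now takes ϵ instead of epsilon.
epca = PoissonEPCA(; ϵ=1e-8)

# Small integer-valued data matrix (the Poisson model assumes χ = ℤ).
X = [0 1 3;
     2 0 1]

# _make_loss returns the closure L(θ) built in the diff above.
L = ExpFamilyPCA._make_loss(epca, X)
θ = zeros(size(X))  # natural parameters; for Poisson, g(θ) = exp(θ)
L(θ)                # total regularized Bregman divergence at θ = 0, i.e. X̂ = ones(size(X))

# Implicit variant: pass only the log-partition G; g, Fg, and fg are then
# derived symbolically, so G must be differentiable by Symbolics.jl.
epca2 = EPCA(θ -> exp(θ); μ=1, ϵ=1e-8)

All of the new identifiers tab-complete in the Julia REPL and most Julia-aware editors: \mu → μ, \epsilon → ϵ, \theta → θ, X\hat → X̂, and the superscripts in g⁻¹ come from \^- and \^1.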