Skip to content

Commit

Permalink
examples + minor bug fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
FlyingWorkshop committed Feb 26, 2024
1 parent 310369a commit 16d3110
Show file tree
Hide file tree
Showing 12 changed files with 193 additions and 9 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ version = "1.0.0-DEV"

[deps]
CompressedBeliefMDPs = "0a809e47-b8eb-4578-b4e8-4c2c5f9f833c"
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"

[compat]
Expand Down
13 changes: 13 additions & 0 deletions demos/basic_usage.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
using ExpFamilyPCA


# Build a small random binary data matrix with one row per sample and one
# column per dimension.
n_samples = 5
n_dims = 10
X = rand(0:1, n_samples, n_dims)  # generate random binary data

# Construct a Bernoulli EPCA model with `n_components` components and fit it to X.
# `verbose=true` makes `fit!` print the loss at every iteration.
n_components = 2
epca = BernoulliPCA(n_components, n_dims)
fit!(epca, X; verbose=true, maxiter=10)

# Compress the data, then map the compressed representation back.
# (The assignment target `X̃` was missing, which left `X̃` undefined below.)
X̃ = compress(epca, X)
X_recon = decompress(epca, X̃)
140 changes: 140 additions & 0 deletions demos/benchmark.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loss: 346.57359027997245\n",
"Loss: "
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"339.4651260162266\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loss: 326.83851096508596\n",
"Loss: "
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"313.16798822607956\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loss: 300.47024994515317\n",
"Loss: "
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"287.68813149707574\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loss: 278.5198879020468\n",
"Loss: "
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"272.07194894059666\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loss: 267.2032539001495\n",
"Loss: "
]
}
],
"source": [
"using ExpFamilyPCA\n",
"\n",
"n_samples = 10\n",
"n_dims = 5\n",
"X = rand(0:1, n_samples, n_dims) # generate random binary data\n",
"\n",
"n_components = 2\n",
"epca = BernoulliPCA(n_components, n_dims)\n",
"fit!(epca, X; verbose=true, maxiter=10)\n",
"\n",
"X̃1 = compress(epca, X)\n",
"recon1 = decompress(epca, X̃1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"using MultivariateStats\n",
"\n",
"M = fit(PCA, X; maxoutdim=n_components)\n",
"X̃2 = predict(M, X)\n",
"recon2 = reconstruct(M, X̃2)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"using LinearAlgebra\n",
"\n",
"dist1 = X .- recon1\n",
"dist2 = X .- recon2\n",
"\n",
"# From theory, we expect the EPCA reconstruction to do better on the L1 norm but worse on L2\n",
"println(\"EPCA L1:\", norm(dist1, 1))\n",
"println(\"PCA L1:\", norm(dist2, 1))\n",
"println()\n",
"println(\"EPCA L2:\", norm(dist1, 2))\n",
"println(\"PCA L2:\", norm(dist2, 2))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Julia 1.10.1",
"language": "julia",
"name": "julia-1.10"
},
"language_info": {
"file_extension": ".jl",
"mimetype": "application/julia",
"name": "julia",
"version": "1.10.1"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
2 changes: 1 addition & 1 deletion src/bernoulli.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# TODO: add caveat that this works best w/ data in [0, 1] and might fail otherwise
function BernoulliPCA(l::Integer, d::Integer, μ0::Real)
function BernoulliPCA(l::Integer, d::Integer; μ0::Real=0.5)
ϵ = eps()
@. begin
g(θ) = exp(θ) / (1 + exp(θ))
Expand Down
16 changes: 10 additions & 6 deletions src/epca.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using CompressedBeliefMDPs: Compressor, fit!, compress, decompress
using CompressedBeliefMDPs
using Optim
using Symbolics
using NonlinearSolve

Expand Down Expand Up @@ -61,20 +62,24 @@ function make_loss(epca::EPCA, X)
return L
end


# TODO: maybe add type hinting for X from compressor.jl in BeliefCompression
# TODO: perhaps add early exit for some ϵ
# TODO: make sure printing happens on 1 line
function CompressedBeliefMDPs.fit!(epca::EPCA, X; verbose=false, maxiter::Integer=50)
@assert maxiter > 0
L(A, V) = make_loss(epca, X)
n, _ = size(X)
l, _ = size(epca.V)
Â = zeros(n, l)
V̂ = epca.V
L = make_loss(epca, X)
for _ in 1:maxiter
if verbose println("Loss: ", L(Â, V̂)) end
V̂ = Optim.minimizer(optimize(V->epca.L(Â, V), V̂))
V̂ = Optim.minimizer(optimize(V->L(Â, V), V̂))
Â = Optim.minimizer(optimize(A->L(A, V̂), Â))
end
copyto!(epca.V, V̂)
return
end


Expand All @@ -87,9 +92,8 @@ function CompressedBeliefMDPs.compress(epca::EPCA, X; verbose=false, maxiter::In
if verbose println("Loss: ", L(Â, epca.V)) end
Â = Optim.minimizer(optimize(A->L(A, epca.V), Â))
end
X̃ = Â * epca.V
return X̃
return Â
end


CompressedBeliefMDPs.decompress(epca::EPCA, X̃) = epca.g(X̃)
CompressedBeliefMDPs.decompress(epca::EPCA, A) = epca.g(A * epca.V)
6 changes: 5 additions & 1 deletion src/poisson.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
function PoissonPCA(l::Integer, d::Integer, μ0::Real)
function PoissonPCA(l::Integer, d::Integer; μ0::Real=1)
ϵ = eps()
@. begin
g(θ) = exp(θ)
Expand All @@ -10,3 +10,7 @@ end


# TODO: include a normalized Poisson w/ link function in footnote 5 of long paper
"""
    NormalizedPCA(l::Integer, d::Integer; μ0::Real=1)

Placeholder constructor that is not implemented yet. The arguments are accepted
but ignored, and the function returns `nothing`.
"""
function NormalizedPCA(l::Integer, d::Integer; μ0::Real=1)
    # TODO: handle possible problem w/ dimension of theta
    return nothing
end
File renamed without changes.
File renamed without changes.
3 changes: 3 additions & 0 deletions test/poisson_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Fit/compress/decompress round-trip test for the `PoissonPCA` constructor.
# NOTE(review): `analytic_test` in runtests.jl is declared with four arguments
# (constructor, l, n, d) but only three are passed here — confirm the intended
# values for l, n and d.
@testset "PoissonPCA" begin
analytic_test(PoissonPCA, 500, 200)
end
20 changes: 19 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,24 @@
using ExpFamilyPCA
using Test

using Symbolics
using Random
# Seed the global RNG so the `rand` data drawn by the tests is reproducible.
Random.seed!(1)


"""
    analytic_test(epca_constructor::Function, l::Integer, n::Integer, d::Integer)

Round-trip check for an EPCA model: fit it on random data, compress the data,
decompress it again, and test that the reconstruction approximately equals the
input.

# Arguments
- `epca_constructor`: called as `epca_constructor(l, d)` to build the model.
- `l`: number of components passed to the constructor.
- `n`: number of rows of the random data matrix.
- `d`: number of columns of the random data matrix (also passed to the constructor).
"""
function analytic_test(epca_constructor::Function, l::Integer, n::Integer, d::Integer)
    X = rand(n, d)
    epca = epca_constructor(l, d)
    fit!(epca, X; maxiter=50)
    # `compress` and `decompress` take the fitted model as their first argument.
    X̃ = compress(epca, X)
    X_recon = decompress(epca, X̃)
    # NOTE(review): uses the default `isapprox` tolerance — confirm this is
    # achievable when l < d.
    @test X ≈ X_recon
end


# TODO add numeric tests


# Root test set for the package; test files are pulled in with `include`.
@testset "ExpFamilyPCA.jl" begin
# poisson_tests.jl defines its own nested @testset and calls `analytic_test`.
include("poisson_tests.jl")
end
Empty file removed test/test_normalized_poisson.jl
Empty file.
Empty file removed test/test_poisson.jl
Empty file.

0 comments on commit 16d3110

Please sign in to comment.