Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Miscellaneous improvements #23

Merged
merged 3 commits into from
Nov 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
name = "ExpFamilyPCA"
uuid = "9c724b78-6801-4402-8a63-53f028696012"
authors = ["Logan-Mondal-Bhamidipaty"]
version = "1.1.0"
version = "2.0.0"

[deps]
CompressedBeliefMDPs = "0a809e47-b8eb-4578-b4e8-4c2c5f9f833c"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Sobol = "ed01d8cd-4d21-5b2a-85b4-cc3bdc58bad4"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Expand All @@ -21,11 +20,10 @@ Distances = "0.10"
FunctionWrappers = "1"
LogExpFunctions = "0.3"
Optim = "1"
Parameters = "0.12"
Sobol = "1"
Statistics = "1"
Symbolics = "6"
julia = "1"
julia = "1.10"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Expand Down
1 change: 0 additions & 1 deletion src/ExpFamilyPCA.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ using Distances
using FunctionWrappers: FunctionWrapper
using LogExpFunctions
using Optim
using Parameters
using Sobol
using Symbolics

Expand Down
8 changes: 4 additions & 4 deletions src/compressor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@ import CompressedBeliefMDPs # `import` rather than `using` to keep tidey namesp
"""
EPCACompressor(epca::EPCA)

Compressor for `CompressedBeliefMDPs.jl`.
Compressor for `CompressedBeliefMDPs.jl`.
"""
struct EPCACompressor <: CompressedBeliefMDPs.Compressor
epca::EPCA
struct EPCACompressor{E<:EPCA} <: CompressedBeliefMDPs.Compressor
epca::E
end

function (c::EPCACompressor)(beliefs)
Expand All @@ -20,4 +20,4 @@ end

function CompressedBeliefMDPs.fit!(c::EPCACompressor, beliefs)
ExpFamilyPCA.fit!(c.epca, beliefs)
end
end
21 changes: 13 additions & 8 deletions src/constructors/epca1.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
struct EPCA1 <: EPCA
F::Union{Function, FunctionWrapper} # Legendre dual of the log-partition
g::Union{Function, FunctionWrapper} # link function
V::AbstractMatrix{<:Real}
options::Options
struct EPCA1{
FT1<:Union{Function, FunctionWrapper},
FT2<:Union{Function, FunctionWrapper},
MT<:AbstractMatrix{<:Real},
OT<:Options
} <: EPCA
F::FT1 # Legendre dual of the log-partition
g::FT2 # link function
V::MT
options::OT
end

function _make_loss(epca::EPCA1, X)
@unpack F, g = epca
@unpack μ, ϵ = epca.options
(; F, g) = epca
(; μ, ϵ) = epca.options
@assert ϵ >= 0 "ϵ must be non-negative."

L(x, θ) = begin
Expand Down Expand Up @@ -44,7 +49,7 @@ function EPCA(
options = Options()
)
@assert isfinite(f(options.μ)) "μ must be in the range of g meaning f(μ) should be finite."
@unpack low, high, tol, maxiter = options
(; low, high, tol, maxiter) = options
g = _invert_legendre(f, options)
V = _initialize_V(indim, outdim, options)
epca = EPCA1(F, g, V, options)
Expand Down
19 changes: 12 additions & 7 deletions src/constructors/epca2.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
struct EPCA2 <: EPCA
G::Union{Function, FunctionWrapper} # log-parition function
g::Union{Function, FunctionWrapper} # link function
V::AbstractMatrix{<:Real}
options::Options
struct EPCA2{
FT1<:Union{Function, FunctionWrapper},
FT2<:Union{Function, FunctionWrapper},
MT<:AbstractMatrix{<:Real},
OT<:Options
} <: EPCA
G::FT1 # log-parition function
g::FT2 # link function
V::MT
options::OT
end

function _make_loss(epca::EPCA2, X)
@unpack G, g = epca
@unpack tol, μ, ϵ = epca.options
(; G, g) = epca
(; tol, μ, ϵ) = epca.options
@assert ϵ >= 0 "ϵ must be non-negative."

L(x, θ) = begin
Expand Down
19 changes: 12 additions & 7 deletions src/constructors/epca3.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
struct EPCA3 <: EPCA
B::Union{Function, FunctionWrapper, PreMetric} # Bregman divergence
g::Union{Function, FunctionWrapper} # link function
V::AbstractMatrix{<:Real}
options::Options
struct EPCA3{
FT1<:Union{Function, FunctionWrapper, PreMetric},
FT2<:Union{Function, FunctionWrapper},
MT<:AbstractMatrix{<:Real},
OT<:Options
} <: EPCA
B::FT1 # Bregman divergence
g::FT2 # link function
V::MT
options::OT
end

function _make_loss(epca::EPCA3, X)
@unpack B, g = epca
@unpack μ, ϵ = epca.options
(; B, g) = epca
(; μ, ϵ) = epca.options
@assert ϵ >= 0 "ϵ must be non-negative."

L(x, θ) = begin
Expand Down
19 changes: 12 additions & 7 deletions src/constructors/epca4.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
struct EPCA4 <: EPCA
Bg::Function # Bregman divergence composed with the link function in the 2nd slot, that is Bg(⋅, ⋅) = B_F(⋅, g(⋅)).
g::Function # link function
V::AbstractMatrix{<:Real}
options::Options
struct EPCA4{
FT1<:Function,
FT2<:Function,
MT<:AbstractMatrix{<:Real},
OT<:Options
} <: EPCA
Bg::FT1 # Bregman divergence composed with the link function in the 2nd slot, that is Bg(⋅, ⋅) = B_F(⋅, g(⋅)).
g::FT2 # link function
V::MT
options::OT
end

function _make_loss(epca::EPCA4, X)
Bg = epca.Bg
@unpack μ, ϵ = epca.options
(; μ, ϵ) = epca.options
@assert ϵ >= 0 "ϵ must be non-negative."

L(x, θ) = begin
Expand Down Expand Up @@ -51,4 +56,4 @@ function EPCA(
options = options
)
return epca
end
end
2 changes: 1 addition & 1 deletion src/epca.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ function fit!(
steps_per_print,
epca.options
)
epca.V[:] = V
epca.V[:] = V # TODO: delete this line?
return A
end

Expand Down
6 changes: 3 additions & 3 deletions src/family/gamma.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ function GammaEPCA(
indim::Integer,
outdim::Integer;
options::Options = Options(
A_init_value = -2,
A_init_value = -2.0,
A_upper = -eps(),
V_lower = eps()
)
Expand All @@ -47,11 +47,11 @@ function ItakuraSaitoEPCA(
indim::Integer,
outdim::Integer;
options::Options = Options(
A_init_value = -2,
A_init_value = -2.0,
A_upper = -eps(),
V_lower = eps()
)
)
epca = GammaEPCA(indim, outdim; options = options)
return epca
end
end
4 changes: 2 additions & 2 deletions src/family/negative_binomial.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ function NegativeBinomialEPCA(
outdim::Integer,
r::Integer;
options::Options = Options(
A_init_value = -1,
A_init_value = -1.0,
A_upper = -eps(),
V_lower = eps()
)
Expand All @@ -40,4 +40,4 @@ function NegativeBinomialEPCA(
options = options
)
return epca
end
end
10 changes: 5 additions & 5 deletions src/family/pareto.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@ function ParetoEPCA(
outdim::Integer,
m::Real;
options::Options = Options(
μ = 2,
A_init_value = 2,
μ = 2.0,
A_init_value = 2.0,
A_lower = 1 / outdim,
V_init_value = -2,
V_upper = -1,
V_init_value = -2.0,
V_upper = -1.0,
)
)
@assert m > 0 "Minimum value m must be positive."
Expand All @@ -43,4 +43,4 @@ function ParetoEPCA(
options = options
)
return epca
end
end
44 changes: 22 additions & 22 deletions src/options.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,29 +20,29 @@ Defines a struct `Options` for configuring various parameters used in optimizati
- `tol::Real`: Tolerance for stopping binary search. Default is `1e-10`.
- `maxiter::Real`: Maximum iterations for binary search. Default is `1e6`.
"""
@with_kw struct Options
@kwdef struct Options{T<:Real}
# symbolic calculus
metaprogramming::Bool = true

# loss hyperparameters
μ::Real = 1
ϵ::Real = eps()
μ::T = 1.0
ϵ::T = eps()

A_init_value::Real = 1.0
A_lower::Union{Real, Nothing} = nothing
A_upper::Union{Real, Nothing} = nothing
A_init_value::T = 1.0
A_lower::Union{T, Nothing} = nothing
A_upper::Union{T, Nothing} = nothing
A_use_sobol::Bool = false

V_init_value::Real = 1.0
V_lower::Union{Real, Nothing} = nothing
V_upper::Union{Real, Nothing} = nothing
V_init_value::T = 1.0
V_lower::Union{T, Nothing} = nothing
V_upper::Union{T, Nothing} = nothing
V_use_sobol::Bool = false

# binary search options
low = -1e10
high = 1e10
tol = 1e-10
maxiter = 1e6
low::Float64 = -1e10
high::Float64 = 1e10
tol::Float64 = 1e-10
maxiter::Int = 10^6
end

"""
Expand All @@ -60,12 +60,12 @@ Other fields inherit from the `Options` struct.
"""
function NegativeDomain(;
metaprogramming::Bool = true,
μ::Real = 1,
μ::Real = 1.0,
ϵ::Real = eps(),
low = -1e10,
high = 1e10,
tol = 1e-10,
maxiter = 1e6,
maxiter = 10^6,
)
options = Options(
metaprogramming = metaprogramming,
Expand All @@ -75,9 +75,9 @@ function NegativeDomain(;
high = high,
tol = tol,
maxiter = maxiter,
A_init_value = -1,
A_init_value = -1.0,
A_upper = -1e-4,
V_init_value = 1,
V_init_value = 1.0,
V_lower = 1e-4,
)
return options
Expand All @@ -98,12 +98,12 @@ Other fields inherit from the `Options` struct.
"""
function PositiveDomain(
metaprogramming::Bool = true,
μ::Real = 1,
μ::Real = 1.0,
ϵ::Real = eps(),
low = -1e10,
high = 1e10,
tol = 1e-10,
maxiter = 1e6,
maxiter = 10^6,
)
options = Options(
metaprogramming = metaprogramming,
Expand All @@ -113,10 +113,10 @@ function PositiveDomain(
high = high,
tol = tol,
maxiter = maxiter,
A_init_value = 1,
A_init_value = 1.0,
A_upper = 1e-4,
V_init_value = 1,
V_init_value = 1.0,
V_lower = 1e-4,
)
return options
end
end
Loading
Loading