Skip to content

Commit

Permalink
Merge pull request #23 from gdalle/speedup
Browse files Browse the repository at this point in the history
Miscellaneous improvements
  • Loading branch information
FlyingWorkshop authored Nov 18, 2024
2 parents cfcc98a + 4b9a415 commit a364bc7
Show file tree
Hide file tree
Showing 13 changed files with 104 additions and 87 deletions.
6 changes: 2 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
name = "ExpFamilyPCA"
uuid = "9c724b78-6801-4402-8a63-53f028696012"
authors = ["Logan-Mondal-Bhamidipaty"]
version = "1.1.0"
version = "2.0.0"

[deps]
CompressedBeliefMDPs = "0a809e47-b8eb-4578-b4e8-4c2c5f9f833c"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Sobol = "ed01d8cd-4d21-5b2a-85b4-cc3bdc58bad4"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Expand All @@ -21,11 +20,10 @@ Distances = "0.10"
FunctionWrappers = "1"
LogExpFunctions = "0.3"
Optim = "1"
Parameters = "0.12"
Sobol = "1"
Statistics = "1"
Symbolics = "6"
julia = "1"
julia = "1.10"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Expand Down
1 change: 0 additions & 1 deletion src/ExpFamilyPCA.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ using Distances
using FunctionWrappers: FunctionWrapper
using LogExpFunctions
using Optim
using Parameters
using Sobol
using Symbolics

Expand Down
8 changes: 4 additions & 4 deletions src/compressor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@ import CompressedBeliefMDPs # `import` rather than `using` to keep tidey namesp
"""
EPCACompressor(epca::EPCA)
Compressor for `CompressedBeliefMDPs.jl`.
Compressor for `CompressedBeliefMDPs.jl`.
"""
struct EPCACompressor <: CompressedBeliefMDPs.Compressor
epca::EPCA
struct EPCACompressor{E<:EPCA} <: CompressedBeliefMDPs.Compressor
epca::E
end

function (c::EPCACompressor)(beliefs)
Expand All @@ -20,4 +20,4 @@ end

function CompressedBeliefMDPs.fit!(c::EPCACompressor, beliefs)
ExpFamilyPCA.fit!(c.epca, beliefs)
end
end
21 changes: 13 additions & 8 deletions src/constructors/epca1.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
struct EPCA1 <: EPCA
F::Union{Function, FunctionWrapper} # Legendre dual of the log-partition
g::Union{Function, FunctionWrapper} # link function
V::AbstractMatrix{<:Real}
options::Options
struct EPCA1{
FT1<:Union{Function, FunctionWrapper},
FT2<:Union{Function, FunctionWrapper},
MT<:AbstractMatrix{<:Real},
OT<:Options
} <: EPCA
F::FT1 # Legendre dual of the log-partition
g::FT2 # link function
V::MT
options::OT
end

function _make_loss(epca::EPCA1, X)
@unpack F, g = epca
@unpack μ, ϵ = epca.options
(; F, g) = epca
(; μ, ϵ) = epca.options
@assert ϵ >= 0 "ϵ must be non-negative."

L(x, θ) = begin
Expand Down Expand Up @@ -44,7 +49,7 @@ function EPCA(
options = Options()
)
@assert isfinite(f(options.μ)) "μ must be in the range of g meaning f(μ) should be finite."
@unpack low, high, tol, maxiter = options
(; low, high, tol, maxiter) = options
g = _invert_legendre(f, options)
V = _initialize_V(indim, outdim, options)
epca = EPCA1(F, g, V, options)
Expand Down
19 changes: 12 additions & 7 deletions src/constructors/epca2.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
struct EPCA2 <: EPCA
G::Union{Function, FunctionWrapper} # log-parition function
g::Union{Function, FunctionWrapper} # link function
V::AbstractMatrix{<:Real}
options::Options
struct EPCA2{
FT1<:Union{Function, FunctionWrapper},
FT2<:Union{Function, FunctionWrapper},
MT<:AbstractMatrix{<:Real},
OT<:Options
} <: EPCA
G::FT1 # log-parition function
g::FT2 # link function
V::MT
options::OT
end

function _make_loss(epca::EPCA2, X)
@unpack G, g = epca
@unpack tol, μ, ϵ = epca.options
(; G, g) = epca
(; tol, μ, ϵ) = epca.options
@assert ϵ >= 0 "ϵ must be non-negative."

L(x, θ) = begin
Expand Down
19 changes: 12 additions & 7 deletions src/constructors/epca3.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
struct EPCA3 <: EPCA
B::Union{Function, FunctionWrapper, PreMetric} # Bregman divergence
g::Union{Function, FunctionWrapper} # link function
V::AbstractMatrix{<:Real}
options::Options
struct EPCA3{
FT1<:Union{Function, FunctionWrapper, PreMetric},
FT2<:Union{Function, FunctionWrapper},
MT<:AbstractMatrix{<:Real},
OT<:Options
} <: EPCA
B::FT1 # Bregman divergence
g::FT2 # link function
V::MT
options::OT
end

function _make_loss(epca::EPCA3, X)
@unpack B, g = epca
@unpack μ, ϵ = epca.options
(; B, g) = epca
(; μ, ϵ) = epca.options
@assert ϵ >= 0 "ϵ must be non-negative."

L(x, θ) = begin
Expand Down
19 changes: 12 additions & 7 deletions src/constructors/epca4.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
struct EPCA4 <: EPCA
Bg::Function # Bregman divergence composed with the link function in the 2nd slot, that is Bg(⋅, ⋅) = B_F(⋅, g(⋅)).
g::Function # link function
V::AbstractMatrix{<:Real}
options::Options
struct EPCA4{
FT1<:Function,
FT2<:Function,
MT<:AbstractMatrix{<:Real},
OT<:Options
} <: EPCA
Bg::FT1 # Bregman divergence composed with the link function in the 2nd slot, that is Bg(⋅, ⋅) = B_F(⋅, g(⋅)).
g::FT2 # link function
V::MT
options::OT
end

function _make_loss(epca::EPCA4, X)
Bg = epca.Bg
@unpack μ, ϵ = epca.options
(; μ, ϵ) = epca.options
@assert ϵ >= 0 "ϵ must be non-negative."

L(x, θ) = begin
Expand Down Expand Up @@ -51,4 +56,4 @@ function EPCA(
options = options
)
return epca
end
end
2 changes: 1 addition & 1 deletion src/epca.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ function fit!(
steps_per_print,
epca.options
)
epca.V[:] = V
epca.V[:] = V # TODO: delete this line?
return A
end

Expand Down
6 changes: 3 additions & 3 deletions src/family/gamma.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ function GammaEPCA(
indim::Integer,
outdim::Integer;
options::Options = Options(
A_init_value = -2,
A_init_value = -2.0,
A_upper = -eps(),
V_lower = eps()
)
Expand All @@ -47,11 +47,11 @@ function ItakuraSaitoEPCA(
indim::Integer,
outdim::Integer;
options::Options = Options(
A_init_value = -2,
A_init_value = -2.0,
A_upper = -eps(),
V_lower = eps()
)
)
epca = GammaEPCA(indim, outdim; options = options)
return epca
end
end
4 changes: 2 additions & 2 deletions src/family/negative_binomial.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ function NegativeBinomialEPCA(
outdim::Integer,
r::Integer;
options::Options = Options(
A_init_value = -1,
A_init_value = -1.0,
A_upper = -eps(),
V_lower = eps()
)
Expand All @@ -40,4 +40,4 @@ function NegativeBinomialEPCA(
options = options
)
return epca
end
end
10 changes: 5 additions & 5 deletions src/family/pareto.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@ function ParetoEPCA(
outdim::Integer,
m::Real;
options::Options = Options(
μ = 2,
A_init_value = 2,
μ = 2.0,
A_init_value = 2.0,
A_lower = 1 / outdim,
V_init_value = -2,
V_upper = -1,
V_init_value = -2.0,
V_upper = -1.0,
)
)
@assert m > 0 "Minimum value m must be positive."
Expand All @@ -43,4 +43,4 @@ function ParetoEPCA(
options = options
)
return epca
end
end
44 changes: 22 additions & 22 deletions src/options.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,29 +20,29 @@ Defines a struct `Options` for configuring various parameters used in optimizati
- `tol::Real`: Tolerance for stopping binary search. Default is `1e-10`.
- `maxiter::Real`: Maximum iterations for binary search. Default is `1e6`.
"""
@with_kw struct Options
@kwdef struct Options{T<:Real}
# symbolic calculus
metaprogramming::Bool = true

# loss hyperparameters
μ::Real = 1
ϵ::Real = eps()
μ::T = 1.0
ϵ::T = eps()

A_init_value::Real = 1.0
A_lower::Union{Real, Nothing} = nothing
A_upper::Union{Real, Nothing} = nothing
A_init_value::T = 1.0
A_lower::Union{T, Nothing} = nothing
A_upper::Union{T, Nothing} = nothing
A_use_sobol::Bool = false

V_init_value::Real = 1.0
V_lower::Union{Real, Nothing} = nothing
V_upper::Union{Real, Nothing} = nothing
V_init_value::T = 1.0
V_lower::Union{T, Nothing} = nothing
V_upper::Union{T, Nothing} = nothing
V_use_sobol::Bool = false

# binary search options
low = -1e10
high = 1e10
tol = 1e-10
maxiter = 1e6
low::Float64 = -1e10
high::Float64 = 1e10
tol::Float64 = 1e-10
maxiter::Int = 10^6
end

"""
Expand All @@ -60,12 +60,12 @@ Other fields inherit from the `Options` struct.
"""
function NegativeDomain(;
metaprogramming::Bool = true,
μ::Real = 1,
μ::Real = 1.0,
ϵ::Real = eps(),
low = -1e10,
high = 1e10,
tol = 1e-10,
maxiter = 1e6,
maxiter = 10^6,
)
options = Options(
metaprogramming = metaprogramming,
Expand All @@ -75,9 +75,9 @@ function NegativeDomain(;
high = high,
tol = tol,
maxiter = maxiter,
A_init_value = -1,
A_init_value = -1.0,
A_upper = -1e-4,
V_init_value = 1,
V_init_value = 1.0,
V_lower = 1e-4,
)
return options
Expand All @@ -98,12 +98,12 @@ Other fields inherit from the `Options` struct.
"""
function PositiveDomain(
metaprogramming::Bool = true,
μ::Real = 1,
μ::Real = 1.0,
ϵ::Real = eps(),
low = -1e10,
high = 1e10,
tol = 1e-10,
maxiter = 1e6,
maxiter = 10^6,
)
options = Options(
metaprogramming = metaprogramming,
Expand All @@ -113,10 +113,10 @@ function PositiveDomain(
high = high,
tol = tol,
maxiter = maxiter,
A_init_value = 1,
A_init_value = 1.0,
A_upper = 1e-4,
V_init_value = 1,
V_init_value = 1.0,
V_lower = 1e-4,
)
return options
end
end
Loading

0 comments on commit a364bc7

Please sign in to comment.