Skip to content

Commit

Permalink
broadcasting for more eb targets
Browse files Browse the repository at this point in the history
  • Loading branch information
nignatiadis committed Sep 1, 2021
1 parent 87afd4a commit 9f91e53
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 26 deletions.
39 changes: 16 additions & 23 deletions src/amari.jl
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,14 @@ Base.@kwdef struct AMARI{N, G, M, EB}
n = nothing
end

function Base.show(io::IO, amari::AMARI)
print(io, "AMARI with")
print(io, " F-Localization: ")
show(io, amari.flocalization)
print(io, "\n")
print(io, " 𝒢: ")
show(io, amari.convexclass)
end

function initialize_modulus_model(method::AMARI, ::Type{ModulusModelWithF}, target::Empirikos.LinearEBayesTarget, δ)

Expand Down Expand Up @@ -504,9 +512,11 @@ end


function Base.broadcasted(::typeof(confint), amari::AMARI,
targets::AbstractArray{<:Empirikos.LinearEBayesTarget}, Zs)
targets::AbstractArray{<:Empirikos.EBayesTarget}, Zs)

init_target = isa(targets[1], LinearEBayesTarget) ? targets[1] : denominator(targets[1])
method = initialize_method(amari, init_target, Zs)

method = initialize_method(amari, targets[1], Zs)
_ci = confint(method, targets[1], Zs; initialize=false)
confint_vec = fill(_ci, axes(targets))
for (index, target) in enumerate(targets[2:end])
Expand All @@ -516,9 +526,11 @@ function Base.broadcasted(::typeof(confint), amari::AMARI,
end

function Base.broadcasted_kwsyntax(::typeof(confint), amari::AMARI,
targets::AbstractArray{<:Empirikos.LinearEBayesTarget}, Zs; level=0.95)
targets::AbstractArray{<:Empirikos.EBayesTarget}, Zs; level=0.95)

init_target = isa(targets[1], LinearEBayesTarget) ? targets[1] : denominator(targets[1])
method = initialize_method(amari, init_target, Zs)

method = initialize_method(amari, targets[1], Zs)
_ci = confint(method, targets[1], Zs; initialize=false, level=level)
confint_vec = fill(_ci, axes(targets))
for (index, target) in enumerate(targets[2:end])
Expand Down Expand Up @@ -609,27 +621,8 @@ function confint(method::AMARI, target::Empirikos.AbstractPosteriorTarget, Zs;
end


function Base.broadcasted(::typeof(confint), method::AMARI,
targets::AbstractVector{<:Empirikos.AbstractPosteriorTarget}, Zs, args...; kwargs...)

length(targets) >= 2 || throw(error("use non-broadcasting call to .fit"))
mid_idx = ceil(Int,median(Base.OneTo(length(targets))))

target = targets[mid_idx]
init_target = Empirikos.PosteriorTargetNullHypothesis(target, 0.0)
method = initialize_method(method, init_target, Zs; kwargs...)

init_target = Empirikos.PosteriorTargetNullHypothesis(target, target(method.plugin_G))
_fit = fit_initialized!(method, init_target, Zs)

confint_vec = Vector{LowerUpperConfidenceInterval}(undef, length(targets))

for (index, target) in enumerate(targets)
_ci = confint(method, target, Zs, args...; initialize=false, kwargs...)
confint_vec[index] = _ci
end
confint_vec
end


# used right now only for sanity check in tests
Expand Down
8 changes: 8 additions & 0 deletions src/flocalization_intervals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,14 @@ Base.@kwdef struct FLocalizationInterval{N,G}
n_bisection::Int = 100
end

function Base.show(io::IO, floc::FLocalizationInterval)
print(io, "EB intervals with F-Localization: ")
show(io, floc.flocalization)
print(io, "\n")
print(io, " 𝒢: ")
show(io, floc.convexclass)
end

function Empirikos.nominal_alpha(floc::FLocalizationInterval)
Empirikos.nominal_alpha(floc.flocalization)
end
Expand Down
5 changes: 2 additions & 3 deletions src/interval_discretizer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ _right_endpoint(x::Real) = x
function _discretize(sorted_intervals, x)
n = length(sorted_intervals)
left, right = 1, n

for i = 1:n
middle = div(left + right, 2)
middle_interval = sorted_intervals[middle]
Expand All @@ -44,7 +43,7 @@ function _discretize(sorted_intervals, x)
left = middle + 1
end
end
middle_interval
error("bisection failed")
end

function (discr::Discretizer)(x)
Expand Down Expand Up @@ -116,7 +115,7 @@ end


function integer_discretizer(grid::AbstractVector{Int}; unbounded = :both)
unbounded === :both || throw(ArgumentError("only positive boundary implemented"))
unbounded === :both || throw(ArgumentError("only unbounded == :both implemented"))

ints = Union{Int, Interval{Int64,Unbounded,Closed}, Interval{Int64,Closed,Unbounded}}[
Interval(nothing, first(grid))
Expand Down
53 changes: 53 additions & 0 deletions test/test_bernoulli.jl
Original file line number Diff line number Diff line change
Expand Up @@ -147,4 +147,57 @@ for amari_ in (amari_with_F, amari_without_F)

both_cis_at_α = confint.(amari_, [Empirikos.PriorMean(), Empirikos.PriorSecondMoment()], Zs; level=1-α)
@test both_cis_at_α == [ci_priormean; ci_second_mean]

# some work towards AMARI with posterior mean
c = 0.6
postmean_lin = Empirikos.PosteriorTargetNullHypothesis(PosteriorMean(BinomialSample(1,1)), c)

amari_fit_postmean_lin = fit(amari_, postmean_lin, Zs)

@test amari_fit_postmean_lin.max_bias dkw_lb*(1-dkw_lb)/2 rtol = 0.005
@test amari_fit_postmean_lin.Q(BinomialSample(1,1)) (1-c) - dkw_lb*(1-dkw_lb)/2 atol = 1e-3
@test amari_fit_postmean_lin.Q(BinomialSample(0,1)) - dkw_lb*(1-dkw_lb)/2 atol = 1e-3

ci_postmean_lin = confint(amari_fit_postmean_lin, postmean_lin, Zs; level=1-α)

@test ci_postmean_lin.estimate (1-c)*mean(response.(Zs)) - dkw_lb*(1-dkw_lb)/2 atol = 1e-3
@test ci_postmean_lin.lower < ci_postmean_lin.estimate - quantile(Normal(), 1-α/2)*ci_postmean_lin.se
@test ci_postmean_lin.lower > ci_postmean_lin.estimate - quantile(Normal(), 1-α/2)*ci_postmean_lin.se - ci_postmean_lin.maxbias
@test amari_fit_postmean_lin.unit_var_proxy (1-c)^2*mean(response.(Zs))*(1- mean(response.(Zs))) atol=1e-5


function tmp_f(c)
barf = mean(response, Zs)
_est = (1-c)*barf - dkw_lb*(1-dkw_lb)/2
_se = (1-c)*sqrt(barf*(1-barf)/ nobs(Zs))
_maxbias = dkw_lb*(1-dkw_lb)/2
_pm = Empirikos.gaussian_ci(_se; maxbias=_maxbias, α=α)
_est - _pm
end

cs = 0:0.0001:1
idx = findfirst( tmp_f.(cs) .<= 0)
c_left = cs[idx]

ci_postmean = confint(amari_, PosteriorMean(BinomialSample(1,1)), Zs; level=1-α)
@test ci_postmean.upper 1.0
@test ci_postmean.lower c_left atol = 0.0015

Zs_flip = BinomialSample.(1 .- response.(Zs), 1)
ci_postmean_0_flip = confint(amari_, PosteriorMean(BinomialSample(0,1)), Zs_flip; level=1-α)
@test ci_postmean_0_flip.lower 1 - ci_postmean.upper atol = 1e-5
@test ci_postmean_0_flip.upper 1 - ci_postmean.lower atol = 1e-5

ci_postmean_0 = confint(amari_, PosteriorMean(BinomialSample(0,1)), Zs; level=1-α)
@test ci_postmean_0.lower 0 atol=1e-8

both_cis_postmean = confint.(amari_, identity(PosteriorMean.(BinomialSample.([0,1],1))), Zs)
@test getfield.(both_cis_postmean, ) [0.05; 0.05]

both_cis_postmean_kw = confint.(amari_, identity(PosteriorMean.(BinomialSample.([0,1],1))), Zs; level=0.95)
@test both_cis_postmean_kw == both_cis_postmean

both_cis_postmean_at_α = confint.(amari_, identity(PosteriorMean.(BinomialSample.([0,1],1))), Zs; level=1-α)
@test both_cis_postmean_at_α == [ci_postmean_0; ci_postmean]

end

0 comments on commit 9f91e53

Please sign in to comment.