Skip to content

Commit

Permalink
Bandwidth fix (#46)
Browse files Browse the repository at this point in the history
* add test for float bandwidth

* add more tests for bandwith

* fix bug

* format

* fix docs

* refactor bandwidth conversion
  • Loading branch information
dufourc1 authored Sep 21, 2023
1 parent a763b8a commit ccbc6b3
Show file tree
Hide file tree
Showing 13 changed files with 119 additions and 121 deletions.
35 changes: 18 additions & 17 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,25 @@ using NetworkHistogram
using Documenter

DocMeta.setdocmeta!(NetworkHistogram, :DocTestSetup, :(using NetworkHistogram);
recursive = true)
recursive = true)

makedocs(;
modules = [NetworkHistogram],
authors = "Jake Grainger, Charles Dufour",
#repo = "github.com/SDS-EPFL/NetworkHistogram.jl.git",
sitename = "NetworkHistogram.jl",
#format = Documenter.HTML(;
# prettyurls = get(ENV, "CI", "false") == "true",
# canonical = "https://SDS-EPFL.github.io/NetworkHistogram.jl",
# edit_link = "main",
# assets = String[]),
pages = [
"Home" => "index.md",
"API Reference" => "api.md",
"Optimization hyperparameters" => "rules.md",
"Development" => "internals.md",
])
modules = [NetworkHistogram],
authors = "Jake Grainger, Charles Dufour",
#repo = "github.com/SDS-EPFL/NetworkHistogram.jl.git",
sitename = "NetworkHistogram.jl",
#format = Documenter.HTML(;
# prettyurls = get(ENV, "CI", "false") == "true",
# canonical = "https://SDS-EPFL.github.io/NetworkHistogram.jl",
# edit_link = "main",
# assets = String[]),
pages = [
"Home" => "index.md",
"API Reference" => "api.md",
"Optimization hyperparameters" => "rules.md",
"Development" => "internals.md",
],
checkdocs = :none)

deploydocs(;
repo = "github.com/SDS-EPFL/NetworkHistogram.jl.git")
repo = "github.com/SDS-EPFL/NetworkHistogram.jl.git")
18 changes: 9 additions & 9 deletions src/assignment.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,14 @@ mutable struct Assignment{T}

estimated_theta = realized ./ counts
likelihood = compute_log_likelihood(number_groups, estimated_theta, counts,
size(A, 1))
size(A, 1))

new{T}(group_size,
node_labels,
counts,
realized,
estimated_theta,
likelihood)
node_labels,
counts,
realized,
estimated_theta,
likelihood)
end
end

Expand Down Expand Up @@ -84,9 +84,9 @@ where ``\\hat{\\theta}_{ab}`` is the estimated probability of an edge between co
"""
function compute_log_likelihood(assignment::Assignment)
compute_log_likelihood(length(assignment.group_size),
assignment.estimated_theta,
assignment.counts,
sum(assignment.group_size))
assignment.estimated_theta,
assignment.counts,
sum(assignment.group_size))
end

function deepcopy!(a::Assignment, b::Assignment)
Expand Down
4 changes: 2 additions & 2 deletions src/config_rules/accept_rule.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ Return the updated `current` assignment based on the `accept_rule`.
accept_reject_update!

function accept_reject_update!(history::GraphOptimizationHistory, iteration::Int,
proposal::Assignment,
current::Assignment, accept_rule::Strict)
proposal::Assignment,
current::Assignment, accept_rule::Strict)
if proposal.likelihood > current.likelihood
deepcopy!(current, proposal)
end
Expand Down
8 changes: 4 additions & 4 deletions src/history.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ end
"""
function initialize_history(best, current, proposal, ::Val{true})
history = MVHistory(Dict([
:proposal_likelihood => QHistory(Float64),
:current_likelihood => QHistory(Float64),
:best_likelihood => QHistory(Float64),
]))
:proposal_likelihood => QHistory(Float64),
:current_likelihood => QHistory(Float64),
:best_likelihood => QHistory(Float64),
]))
push!(history, :proposal_likelihood, 0, proposal.likelihood)
push!(history, :current_likelihood, 0, current.likelihood)
push!(history, :best_likelihood, 0, best.likelihood)
Expand Down
48 changes: 16 additions & 32 deletions src/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ Computes the graph histogram approximation.
# Arguments
- `A`: adjacency matrix of a simple graph
- `h`: bandwidth of the graph histogram (number of nodes in a group or percentage of nodes in a
group)
- `h`: bandwidth of the graph histogram (number of nodes in a group or percentage (in [0,1]) of
nodes in a group)
- `record_trace` (optional): whether to record the trace of the optimization process and return
it as part of the output. Default is `true`.
Expand Down Expand Up @@ -55,18 +55,17 @@ julia> loglikelihood = out.likelihood
```
"""
function graphhist(A; h = select_bandwidth(A), maxitr = 1000,
swap_rule::NodeSwapRule = RandomNodeSwap(),
starting_assignment_rule::StartingAssignment = RandomStart(),
accept_rule::AcceptRule = Strict(),
stop_rule::StopRule = PreviousBestValue(3), record_trace = true)
swap_rule::NodeSwapRule = RandomNodeSwap(),
starting_assignment_rule::StartingAssignment = RandomStart(),
accept_rule::AcceptRule = Strict(),
stop_rule::StopRule = PreviousBestValue(3), record_trace = true)
checkadjacency(A)
h = sanitize_bandwidth(h, size(A, 1))
@assert maxitr > 0

return _graphhist(A, Val{record_trace}(), h = h, maxitr = maxitr, swap_rule = swap_rule,
starting_assignment_rule = starting_assignment_rule,
accept_rule = accept_rule,
stop_rule = stop_rule)
starting_assignment_rule = starting_assignment_rule,
accept_rule = accept_rule,
stop_rule = stop_rule)
end

"""
Expand All @@ -75,9 +74,9 @@ end
Internal version of `graphhist` which is type stable.
"""
function _graphhist(A, record_trace = Val{true}(); h, maxitr, swap_rule,
starting_assignment_rule, accept_rule, stop_rule)
starting_assignment_rule, accept_rule, stop_rule)
best, current, proposal, history = initialize(A, h, starting_assignment_rule,
record_trace)
record_trace)

for i in 1:maxitr
proposal = create_proposal!(history, i, proposal, current, A, swap_rule)
Expand All @@ -104,8 +103,8 @@ function graphhist_format_output(best, history::NoTraceHistory)
end

function update_best!(history::GraphOptimizationHistory, iteration::Int,
current::Assignment,
best::Assignment)
current::Assignment,
best::Assignment)
if current.likelihood > best.likelihood
update_best!(history, iteration, current.likelihood)
deepcopy!(best, current)
Expand All @@ -127,9 +126,9 @@ function initialize(A, h, starting_assignment_rule, record_trace)
return best, current, proposal, history
end

function select_bandwidth(A, type = "degs", alpha = 1, c = 1)
function select_bandwidth(A, type = "degs", alpha = 1, c = 1)::Int
h = oracle_bandwidth(A, type, alpha, c)
return sanitize_bandwidth(h, size(A, 1))
return max(2, min(size(A)[1], round(Int, h)))
end

"""
Expand All @@ -154,7 +153,7 @@ function oracle_bandwidth(A, type = "degs", alpha = 1, c = min(4, sqrt(size(A, 1

n = size(A, 1)
midPt = collect(max(1, round(Int, (n ÷ 2 - c * sqrt(n)))):round(Int,
(n ÷ 2 + c * sqrt(n))))
(n ÷ 2 + c * sqrt(n))))
rhoHat_inv = inv(sum(A) / (n * (n - 1)))

# Rank-1 graphon estimate via fhat(x,y) = mult*u(x)*u(y)*pinv(rhoHat);
Expand All @@ -180,18 +179,3 @@ function oracle_bandwidth(A, type = "degs", alpha = 1, c = min(4, sqrt(size(A, 1
#estMSqrd = 2*mult^2*(lmfit_coef[2]*length(uMid)/2+lmfit_coef[1])^2*lmfit_coef[2]^2*rhoHat_inv^2*(n+1)^2
return h[1]
end

function sanitize_bandwidth(h::Real, n::Int)::Int
h = max(2, min(n, round(Int, h)))
lastGroupSize = n % h

if lastGroupSize == 1
@warn "Correcting bandwidth to avoid singleton final group."
end
# step down h, to avoid singleton final group
while lastGroupSize == 1 && h > 2
h -= 1
lastGroupSize = n % h
end
return h
end
12 changes: 6 additions & 6 deletions src/proposal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ proposal is stored in the history.
The `proposal` assignment is modified in place to avoid unnecessary memory allocation.
"""
function create_proposal!(history::GraphOptimizationHistory, iteration::Int,
proposal::Assignment,
current::Assignment, A, swap_rule)
proposal::Assignment,
current::Assignment, A, swap_rule)
swap = select_swap(current, A, swap_rule)
make_proposal!(proposal, current, swap, A)
update_proposal!(history, iteration, proposal.likelihood)
Expand Down Expand Up @@ -78,20 +78,20 @@ function update_observed!(proposal::Assignment, swap::Tuple{Int, Int}, A)
if A[i, swap[1]] == 1
proposal.realized[group_node_1, group_i] -= 1
proposal.realized[group_i, group_node_1] = proposal.realized[group_node_1,
group_i]
group_i]

proposal.realized[group_node_2, group_i] += 1
proposal.realized[group_i, group_node_2] = proposal.realized[group_node_2,
group_i]
group_i]
end
if A[i, swap[2]] == 1
proposal.realized[group_node_2, group_i] -= 1
proposal.realized[group_i, group_node_2] = proposal.realized[group_node_2,
group_i]
group_i]

proposal.realized[group_node_1, group_i] += 1
proposal.realized[group_i, group_node_1] = proposal.realized[group_node_1,
group_i]
group_i]
end
end

Expand Down
10 changes: 5 additions & 5 deletions test/data_tests/utils.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
@testset "Data utils" begin
A = [0 0 0 1
0 0 0 0
1 0 0 1
0 0 1 0]
0 0 0 0
1 0 0 1
0 0 1 0]

@testset "drop isolated vertices" begin
B = NetworkHistogram.drop_isolated_vertices(A)
@test B == [0 1 1
1 0 1
0 1 0]
1 0 1
0 1 0]
end
end
20 changes: 11 additions & 9 deletions test/error_handling_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
@testset "Adjacency matrix" begin
As = [
[0 1
0 0], [1 1
1 0], [0 2
2 0], [0 1
1 0
0 1],
0 0], [1 1
1 0], [0 2
2 0], [0 1
1 0
0 1],
]
for A in As
@test_throws AssertionError graphhist(A, h = 2)
Expand All @@ -17,9 +17,11 @@
end
@testset "maxitr" begin
@test_throws AssertionError graphhist([0 1; 1 0], h = 2,
maxitr = -1)
maxitr = -1)
end
@testset "h" begin
for h in (3, -1, 1.1, -0.1)
@test_throws AssertionError graphhist([0 1; 1 0], h = h)
end
end
@testset "h" begin for h in (3, -1, 1.1, -0.1)
@test_throws AssertionError graphhist([0 1; 1 0], h = h)
end end
end
22 changes: 11 additions & 11 deletions test/oracle_bandwidth_test.jl
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
@testset "oracle bandwidth test" begin
A = [0 0 1 0 1 0 1 1 0 1
0 0 1 1 1 1 1 1 0 0
1 1 0 1 0 0 0 0 1 0
0 1 1 0 1 0 1 0 0 0
1 1 0 1 0 0 1 0 0 1
0 1 0 0 0 0 0 1 0 0
1 1 0 1 1 0 0 1 0 1
1 1 0 0 0 1 1 0 0 1
0 0 1 0 0 0 0 0 0 1
1 0 0 0 1 0 1 1 1 0]
0 0 1 1 1 1 1 1 0 0
1 1 0 1 0 0 0 0 1 0
0 1 1 0 1 0 1 0 0 0
1 1 0 1 0 0 1 0 0 1
0 1 0 0 0 0 0 1 0 0
1 1 0 1 1 0 0 1 0 1
1 1 0 0 0 1 1 0 0 1
0 0 1 0 0 0 0 0 0 1
1 0 0 0 1 0 1 1 1 0]
h = NetworkHistogram.oracle_bandwidth(A)
rho = sum(A) / (size(A, 1) * (size(A, 1) - 1))
h_true_nethist = 2.643731 # version 0.2.3 from nethist package
h_clean = 3
@test hh_true_nethist atol=1e-4
h_clean = NetworkHistogram.sanitize_bandwidth(h, size(A, 1))
@test h_clean == 2
@test NetworkHistogram.select_bandwidth(A) == h_clean
end
41 changes: 26 additions & 15 deletions test/pipeline_test.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,27 @@
@testset "Pipeline" begin
@testset "dummy run" begin
A = [0 0 1 0 1 0 1 1 0 1
0 0 1 1 1 1 1 1 0 0
1 1 0 1 0 0 0 0 1 0
0 1 1 0 1 0 1 0 0 0
1 1 0 1 0 0 1 0 0 1
0 1 0 0 0 0 0 1 0 0
1 1 0 1 1 0 0 1 0 1
1 1 0 0 0 1 1 0 0 1
0 0 1 0 0 0 0 0 0 1
1 0 0 0 1 0 1 1 1 0]
estimated = graphhist(A; h = 0.5)
0 0 1 1 1 1 1 1 0 0
1 1 0 1 0 0 0 0 1 0
0 1 1 0 1 0 1 0 0 0
1 1 0 1 0 0 1 0 0 1
0 1 0 0 0 0 0 1 0 0
1 1 0 1 1 0 0 1 0 1
1 1 0 0 0 1 1 0 0 1
0 0 1 0 0 0 0 0 0 1
1 0 0 0 1 0 1 1 1 0]
@testset "run bandwidth float" begin
estimated = graphhist(A; h = 0.5)
@test all(estimated.graphhist.θ .>= 0.0)
@test all(estimated.graphhist.θ .<= 1.0)
@test size(estimated.graphhist.θ) == (2, 2)
end
@testset "run bandwidth int" begin
estimated = graphhist(A; h = 5)
@test all(estimated.graphhist.θ .>= 0.0)
@test all(estimated.graphhist.θ .<= 1.0)
@test size(estimated.graphhist.θ) == (2, 2)
end
end

@testset "associative stochastic block model" begin
Expand All @@ -19,13 +30,13 @@
for (name, adjacency) in adjacencies
@testset "$name" begin
estimated, history = graphhist(adjacency; h = 0.3,
stop_rule = PreviousBestValue(100),
starting_assignment_rule = OrderedStart())
stop_rule = PreviousBestValue(100),
starting_assignment_rule = OrderedStart())
@test all(estimated.θ .>= 0.0)
estimated, history = graphhist(adjacency; h = 0.3,
stop_rule = PreviousBestValue(100),
starting_assignment_rule = OrderedStart(),
record_trace = false)
stop_rule = PreviousBestValue(100),
starting_assignment_rule = OrderedStart(),
record_trace = false)
@test all(estimated.θ .>= 0.0)
end
end
Expand Down
2 changes: 1 addition & 1 deletion test/proposal_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
proposal = deepcopy(assignment)
NetworkHistogram.make_proposal!(proposal, assignment, swap, A)
reference_proposal = NetworkHistogram.Assignment(A, [1, 2, 1, 1, 1, 2, 2, 2],
group_size)
group_size)

@testset "update labels" begin
@test proposal.node_labels[swap[1]] == reference_proposal.node_labels[swap[1]] == 2
Expand Down
Loading

0 comments on commit ccbc6b3

Please sign in to comment.