diff --git a/src/find.jl b/src/find.jl index 2099fa5..11e99f6 100644 --- a/src/find.jl +++ b/src/find.jl @@ -27,26 +27,40 @@ function _find( combs = combinations(is, n) optimal_is = zeros(Int, n) - max_reliability = -Inf + max_reliability = Ref(-Inf) - prog = ProgressBar(transient = true) + if progress + prog = ProgressBar(transient = true) - Progress.with(prog) do - prog_job = - addjob!(prog, N = length(combs), description = "Finding optimal item subset...") + Progress.with(prog) do + prog_job = addjob!( + prog, + N = length(combs), + description = "Finding optimal item subset...", + ) - for c in combs - subtest = view(m, :, c) - reliability = method(subtest) - - if reliability > max_reliability - max_reliability = reliability - optimal_is = c + for c in combs + _update_reliability!(max_reliability, optimal_is, m, c, method) + update!(prog_job) end - - update!(prog_job) + end + else + for c in combs + _update_reliability!(max_reliability, optimal_is, m, c, method) end end return optimal_is end + +function _update_reliability!(max_reliability, optimal_is, m, c, method) + subtest = view(m, :, c) + reliability = method(subtest) + + if reliability > max_reliability[] + max_reliability[] = reliability + optimal_is .= c + end + + return nothing +end diff --git a/src/precompile.jl b/src/precompile.jl index 9503ef7..aca6c68 100644 --- a/src/precompile.jl +++ b/src/precompile.jl @@ -28,7 +28,7 @@ using PrecompileTools # find for method in methods - find(m, 2, method) + find(m, 2, method, progress = false) end end end diff --git a/test/find.jl b/test/find.jl index 99ac3b9..5b3ccb3 100644 --- a/test/find.jl +++ b/test/find.jl @@ -18,9 +18,10 @@ # test @test_throws ArgumentError find(m, n_items + 2) - @test size(find(m, 2)) == (n_persons, 2) - @test size(find(m, 1)) == (n_persons, 1) + @test size(find(m, 2, progress = false)) == (n_persons, 2) + @test size(find(m, 1, progress = false)) == (n_persons, 1) - @test find(m_extended, n_items) == m + @test find(m_extended, n_items, progress = false) == m + @test find(m, 2, progress = true) == find(m, 2, progress = false) end