diff --git a/test/callback.jl b/test/callback.jl index 91313d5..f4786d9 100644 --- a/test/callback.jl +++ b/test/callback.jl @@ -3,11 +3,15 @@ game = Kuhn() sol = CFRSolver(game) cb = CFR.ExploitabilityCallback(sol, 100) - train!(sol, 100_000, cb = cb) + cb_nashconv = NashConvCallback(sol, 100) + train!(sol, 100_000, cb = CallbackChain(cb, cb_nashconv)) @test RecipesBase.apply_recipe(Dict{Symbol,Any}(), cb) ≠ nothing + @test RecipesBase.apply_recipe(Dict{Symbol,Any}(), cb_nashconv) ≠ nothing @test length(cb.hist.y) == length(cb.hist.x) == 1_000 + @test length(cb_nashconv.hist.y) == length(cb_nashconv.hist.x) == 1_000 @test 0.0 < last(cb.hist.y) < 1e-2 + @test 0.0 < last(cb_nashconv.hist.y) < 1e-2 game = MatrixGame([(randn(), randn()) for i in 1:5, j in 1:5]) sol = CFRSolver(game) diff --git a/test/is-mcts.jl b/test/is-mcts.jl index a087969..7bd1365 100644 --- a/test/is-mcts.jl +++ b/test/is-mcts.jl @@ -4,18 +4,24 @@ ## MaxUCB sol = CFRSolver(game) true_exploit_cb = ExploitabilityCallback(sol, 10) - max_ucb_cb = ExploitabilityCallback(ISMCTS(sol; max_iter=100_000, criterion=CFR.MaxUCB()), 10) - poly_ucb_cb = ExploitabilityCallback(ISMCTS(sol; max_iter=100_000, criterion=CFR.PolyUCB()), 10) - max_q_cb = ExploitabilityCallback(ISMCTS(sol; max_iter=100_000, criterion=CFR.MaxQ()), 10) + true_nashconv_cb= NashConvCallback(sol, 10) + max_ucb_cb = MCTSExploitabilityCallback(sol, 10; max_iter=100_000, criterion=CFR.MaxUCB()) + poly_ucb_cb = MCTSExploitabilityCallback(sol, 10; max_iter=100_000, criterion=CFR.PolyUCB()) + max_q_cb = MCTSExploitabilityCallback(sol, 10; max_iter=100_000, criterion=CFR.MaxQ()) + nashconv_cb = MCTSNashConvCallback(sol, 10; max_iter=100_000) - cb_chain = CFR.CallbackChain(true_exploit_cb, max_ucb_cb, poly_ucb_cb, max_q_cb) + + cb_chain = CFR.CallbackChain(true_exploit_cb, true_nashconv_cb, max_ucb_cb, poly_ucb_cb, max_q_cb, nashconv_cb) train!(sol, 100; cb=cb_chain) err_ucb = abs.(max_ucb_cb.hist.y .- true_exploit_cb.hist.y) err_poly = abs.(poly_ucb_cb.hist.y .- true_exploit_cb.hist.y) err_q = abs.(max_q_cb.hist.y .- true_exploit_cb.hist.y) + err_nashconv = abs.(nashconv_cb.hist.y .- true_nashconv_cb.hist.y) + @test all(err_ucb[2:end] .< 0.1) @test all(err_poly[2:end] .< 0.1) @test all(err_q[2:end] .< 0.2) + @test all(err_nashconv[2:end] .< 0.2) end