Skip to content

Commit

Permalink
nashconv tests
Browse files Browse the repository at this point in the history
  • Loading branch information
WhiffleFish committed Jul 10, 2023
1 parent 8d18247 commit 872c172
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 5 deletions.
6 changes: 5 additions & 1 deletion test/callback.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,15 @@
game = Kuhn()
sol = CFRSolver(game)
cb = CFR.ExploitabilityCallback(sol, 100)
train!(sol, 100_000, cb = cb)
cb_nashconv = NashConvCallback(sol, 100)
train!(sol, 100_000, cb = CallbackChain(cb, cb_nashconv))

@test RecipesBase.apply_recipe(Dict{Symbol,Any}(), cb) !== nothing
@test RecipesBase.apply_recipe(Dict{Symbol,Any}(), cb_nashconv) !== nothing
@test length(cb.hist.y) == length(cb.hist.x) == 1_000
@test length(cb_nashconv.hist.y) == length(cb_nashconv.hist.x) == 1_000
@test 0.0 < last(cb.hist.y) < 1e-2
@test 0.0 < last(cb_nashconv.hist.y) < 1e-2

game = MatrixGame([(randn(), randn()) for i in 1:5, j in 1:5])
sol = CFRSolver(game)
Expand Down
14 changes: 10 additions & 4 deletions test/is-mcts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,24 @@
## MaxUCB
sol = CFRSolver(game)
true_exploit_cb = ExploitabilityCallback(sol, 10)
max_ucb_cb = ExploitabilityCallback(ISMCTS(sol; max_iter=100_000, criterion=CFR.MaxUCB()), 10)
poly_ucb_cb = ExploitabilityCallback(ISMCTS(sol; max_iter=100_000, criterion=CFR.PolyUCB()), 10)
max_q_cb = ExploitabilityCallback(ISMCTS(sol; max_iter=100_000, criterion=CFR.MaxQ()), 10)
true_nashconv_cb= NashConvCallback(sol, 10)
max_ucb_cb = MCTSExploitabilityCallback(sol, 10; max_iter=100_000, criterion=CFR.MaxUCB())
poly_ucb_cb = MCTSExploitabilityCallback(sol, 10; max_iter=100_000, criterion=CFR.PolyUCB())
max_q_cb = MCTSExploitabilityCallback(sol, 10; max_iter=100_000, criterion=CFR.MaxQ())
nashconv_cb = MCTSNashConvCallback(sol, 10; max_iter=100_000)

cb_chain = CFR.CallbackChain(true_exploit_cb, max_ucb_cb, poly_ucb_cb, max_q_cb)

cb_chain = CFR.CallbackChain(true_exploit_cb, true_nashconv_cb, max_ucb_cb, poly_ucb_cb, max_q_cb, nashconv_cb)
train!(sol, 100; cb=cb_chain)

err_ucb = abs.(max_ucb_cb.hist.y .- true_exploit_cb.hist.y)
err_poly = abs.(poly_ucb_cb.hist.y .- true_exploit_cb.hist.y)
err_q = abs.(max_q_cb.hist.y .- true_exploit_cb.hist.y)

err_nashconv = abs.(nashconv_cb.hist.y .- true_nashconv_cb.hist.y)

@test all(err_ucb[2:end] .< 0.1)
@test all(err_poly[2:end] .< 0.1)
@test all(err_q[2:end] .< 0.2)
@test all(err_nashconv[2:end] .< 0.2)
end

0 comments on commit 872c172

Please sign in to comment.