Skip to content

Commit

Permalink
test: add tests for the bracketing methods
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Sep 16, 2024
1 parent 56c6d4d commit d669bb8
Show file tree
Hide file tree
Showing 7 changed files with 115 additions and 7 deletions.
8 changes: 8 additions & 0 deletions lib/BracketingNonlinearSolve/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,11 @@ NonlinearSolveBase = "1"
PrecompileTools = "1.2.1"
SciMLBase = "2.50"
julia = "1.10"

[extras]
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TestItemRunner = "f8b46487-2199-4994-9208-9a1283c18c0a"

[targets]
test = ["InteractiveUtils", "Test", "TestItemRunner"]
5 changes: 4 additions & 1 deletion lib/BracketingNonlinearSolve/src/BracketingNonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module BracketingNonlinearSolve

using ConcreteStructs: @concrete

using CommonSolve: CommonSolve
using CommonSolve: CommonSolve, solve
using NonlinearSolveBase: NonlinearSolveBase
using SciMLBase: SciMLBase, AbstractNonlinearAlgorithm, IntervalNonlinearProblem, ReturnCode

Expand Down Expand Up @@ -33,6 +33,9 @@ include("ridder.jl")
end
end

export IntervalNonlinearProblem
export solve

export Alefeld, Bisection, Brent, Falsi, ITP, Ridder

end
6 changes: 4 additions & 2 deletions lib/BracketingNonlinearSolve/src/brent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,10 @@ function CommonSolve.solve(prob::IntervalNonlinearProblem, alg::Brent, args...;

if abs(fl) < abs(fr)
d = c
c, right, left = right, left, c
fc, fr, fl = fr, fl, fc
c, right = right, left
left = c
fc, fr = fr, fl
fl = fc
end
i += 1
end
Expand Down
94 changes: 94 additions & 0 deletions lib/BracketingNonlinearSolve/test/rootfind_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
@testsnippet RootfindingTestSnippet begin
using NonlinearSolveBase, BracketingNonlinearSolve

quadratic_f(u, p) = u .* u .- p
quadratic_f!(du, u, p) = (du .= u .* u .- p)
quadratic_f2(u, p) = @. p[1] * u * u - p[2]
end

@testitem "Interval Nonlinear Problems" setup=[RootfindingTestSnippet] tags=[:core] begin
@testset for alg in (Bisection(), Falsi(), Ridder(), Brent(), ITP(), Alefeld())
tspan = (1.0, 20.0)

function g(p)
probN = IntervalNonlinearProblem{false}(quadratic_f, typeof(p).(tspan), p)
return solve(probN, alg; abstol = 1e-9).left
end

@testset for p in 1.1:0.1:100.0
@test g(p)sqrt(p) atol=1e-3 rtol=1e-3
# @test ForwardDiff.derivative(g, p)≈1 / (2 * sqrt(p)) atol=1e-3 rtol=1e-3
end

t = (p) -> [sqrt(p[2] / p[1])]
p = [0.9, 50.0]

function g2(p)
probN = IntervalNonlinearProblem{false}(quadratic_f2, tspan, p)
sol = solve(probN, alg; abstol = 1e-9)
return [sol.u]
end

@test g2(p)[sqrt(p[2] / p[1])] atol=1e-3 rtol=1e-3
# @test ForwardDiff.jacobian(g2, p)≈ForwardDiff.jacobian(t, p) atol=1e-3 rtol=1e-3

probB = IntervalNonlinearProblem{false}(quadratic_f, (1.0, 2.0), 2.0)
sol = solve(probB, alg; abstol = 1e-9)
@test sol.leftsqrt(2.0) atol=1e-3 rtol=1e-3

if !(alg isa Bisection || alg isa Falsi)
probB = IntervalNonlinearProblem{false}(quadratic_f, (sqrt(2.0), 10.0), 2.0)
sol = solve(probB, alg; abstol = 1e-9)
@test sol.leftsqrt(2.0) atol=1e-3 rtol=1e-3

probB = IntervalNonlinearProblem{false}(quadratic_f, (0.0, sqrt(2.0)), 2.0)
sol = solve(probB, alg; abstol = 1e-9)
@test sol.leftsqrt(2.0) atol=1e-3 rtol=1e-3
end
end
end

@testitem "Tolerance Tests Interval Methods" setup=[RootfindingTestSnippet] tags=[:core] begin
prob = IntervalNonlinearProblem(quadratic_f, (1.0, 20.0), 2.0)
ϵ = eps(Float64) # least possible tol for all methods

@testset for alg in (Bisection(), Falsi(), ITP())
@testset for abstol in [0.1, 0.01, 0.001, 0.0001, 1e-5, 1e-6, 1e-7]
sol = solve(prob, alg; abstol)
result_tol = abs(sol.u - sqrt(2))
@test result_tol < abstol
# test that the solution is not calculated upto max precision
@test result_tol > ϵ
end
end

@testset for alg in (Ridder(), Brent())
# Ridder and Brent converge rapidly so as we lower tolerance below 0.01, it
# converges with max precision to the solution
@testset for abstol in [0.1]
sol = solve(prob, alg; abstol)
result_tol = abs(sol.u - sqrt(2))
@test result_tol < abstol
# test that the solution is not calculated upto max precision
@test result_tol > ϵ
end
end
end

@testitem "Flipped Signs and Reversed Tspan" setup=[RootfindingTestSnippet] tags=[:core] begin
@testset for alg in (Alefeld(), Bisection(), Falsi(), Brent(), ITP(), Ridder())
f1(u, p) = u * u - p
f2(u, p) = p - u * u

for p in 1:4
inp1 = IntervalNonlinearProblem(f1, (1.0, 2.0), p)
inp2 = IntervalNonlinearProblem(f2, (1.0, 2.0), p)
inp3 = IntervalNonlinearProblem(f1, (2.0, 1.0), p)
inp4 = IntervalNonlinearProblem(f2, (2.0, 1.0), p)
@test abs.(solve(inp1, alg).u) sqrt.(p)
@test abs.(solve(inp2, alg).u) sqrt.(p)
@test abs.(solve(inp3, alg).u) sqrt.(p)
@test abs.(solve(inp4, alg).u) sqrt.(p)
end
end
end
4 changes: 4 additions & 0 deletions lib/BracketingNonlinearSolve/test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1 +1,5 @@
using TestItemRunner, InteractiveUtils

@info sprint(InteractiveUtils.versioninfo)

@run_package_tests
2 changes: 0 additions & 2 deletions lib/NonlinearSolveBase/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
UnrolledUtilities = "0fe1646c-419e-43be-ac14-22321958931b"

[weakdeps]
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Expand All @@ -35,5 +34,4 @@ RecursiveArrayTools = "3"
SciMLBase = "2.50"
SparseArrays = "1.10"
StaticArraysCore = "1.4"
UnrolledUtilities = "0.1"
julia = "1.10"
3 changes: 1 addition & 2 deletions lib/NonlinearSolveBase/src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,10 @@ using ArrayInterface: ArrayInterface
using FastClosures: @closure
using LinearAlgebra: norm
using RecursiveArrayTools: AbstractVectorOfArray, ArrayPartition
using UnrolledUtilities: unrolled_all

using ..NonlinearSolveBase: L2_NORM, Linf_NORM

fast_scalar_indexing(xs...) = unrolled_all(ArrayInterface.fast_scalar_indexing, xs)
fast_scalar_indexing(xs...) = all(ArrayInterface.fast_scalar_indexing, xs)

function nonallocating_isapprox(x::Number, y::Number; atol = false,
rtol = atol > 0 ? false : sqrt(eps(promote_type(typeof(x), typeof(y)))))
Expand Down

0 comments on commit d669bb8

Please sign in to comment.