Skip to content

Commit

Permalink
feat: ForwardDiff support in NonlinearSolveBase
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Sep 17, 2024
1 parent d669bb8 commit 0b70d62
Show file tree
Hide file tree
Showing 8 changed files with 136 additions and 6 deletions.
10 changes: 9 additions & 1 deletion lib/BracketingNonlinearSolve/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,26 @@ NonlinearSolveBase = "be0214bd-f91f-a760-ac4e-3421ce2b2da0"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"

[weakdeps]
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"

[extensions]
NonlinearSolveBaseForwardDiffExt = "ForwardDiff"

[compat]
CommonSolve = "0.2.4"
ConcreteStructs = "0.2.3"
ForwardDiff = "0.10.36"
NonlinearSolveBase = "1"
PrecompileTools = "1.2.1"
SciMLBase = "2.50"
julia = "1.10"

[extras]
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TestItemRunner = "f8b46487-2199-4994-9208-9a1283c18c0a"

[targets]
test = ["InteractiveUtils", "Test", "TestItemRunner"]
test = ["InteractiveUtils", "ForwardDiff", "Test", "TestItemRunner"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
module BracketingNonlinearSolveForwardDiffExt

using CommonSolve: CommonSolve
using ForwardDiff: ForwardDiff, Dual
using NonlinearSolveBase: nonlinearsolve_forwarddiff_solve, nonlinearsolve_dual_solution
using SciMLBase: SciMLBase, IntervalNonlinearProblem

using BracketingNonlinearSolve: Bisection, Brent, Alefeld, Falsi, ITP, Ridder

for algT in (Bisection, Brent, Alefeld, Falsi, ITP, Ridder)
@eval function CommonSolve.solve(
prob::IntervalNonlinearProblem{
uType, iip, <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}},
alg::$(algT),
args...;
kwargs...) where {uType, iip, T, V, P}
sol, partials = nonlinearsolve_forwarddiff_solve(prob, alg, args...; kwargs...)
dual_soln = nonlinearsolve_dual_solution(sol.u, partials, prob.p)
return SciMLBase.build_solution(
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats,
sol.original, left = Dual{T, V, P}(sol.left, partials),
right = Dual{T, V, P}(sol.right, partials))
end
end

end
6 changes: 4 additions & 2 deletions lib/BracketingNonlinearSolve/test/rootfind_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
end

@testitem "Interval Nonlinear Problems" setup=[RootfindingTestSnippet] tags=[:core] begin
using ForwardDiff

@testset for alg in (Bisection(), Falsi(), Ridder(), Brent(), ITP(), Alefeld())
tspan = (1.0, 20.0)

Expand All @@ -17,7 +19,7 @@ end

@testset for p in 1.1:0.1:100.0
@test g(p)sqrt(p) atol=1e-3 rtol=1e-3
# @test ForwardDiff.derivative(g, p)≈1 / (2 * sqrt(p)) atol=1e-3 rtol=1e-3
@test ForwardDiff.derivative(g, p)1 / (2 * sqrt(p)) atol=1e-3 rtol=1e-3
end

t = (p) -> [sqrt(p[2] / p[1])]
Expand All @@ -30,7 +32,7 @@ end
end

@test g2(p)[sqrt(p[2] / p[1])] atol=1e-3 rtol=1e-3
# @test ForwardDiff.jacobian(g2, p)≈ForwardDiff.jacobian(t, p) atol=1e-3 rtol=1e-3
@test ForwardDiff.jacobian(g2, p)ForwardDiff.jacobian(t, p) atol=1e-3 rtol=1e-3

probB = IntervalNonlinearProblem{false}(quadratic_f, (1.0, 2.0), 2.0)
sol = solve(probB, alg; abstol = 1e-9)
Expand Down
75 changes: 74 additions & 1 deletion lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,82 @@
module NonlinearSolveBaseForwardDiffExt

using CommonSolve: solve
using FastClosures: @closure
using ForwardDiff: ForwardDiff, Dual
using NonlinearSolveBase: Utils
using SciMLBase: SciMLBase, IntervalNonlinearProblem, NonlinearProblem,
NonlinearLeastSquaresProblem, remake

using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem, Utils

Utils.value(::Type{Dual{T, V, N}}) where {T, V, N} = V
Utils.value(x::Dual) = Utils.value(ForwardDiff.value(x))

function NonlinearSolveBase.nonlinearsolve_forwarddiff_solve(
prob::Union{IntervalNonlinearProblem, NonlinearProblem, ImmutableNonlinearProblem},
alg, args...; kwargs...)
p = Utils.value(prob.p)
if prob isa IntervalNonlinearProblem
tspan = Utils.value.(prob.tspan)
newprob = IntervalNonlinearProblem(prob.f, tspan, p; prob.kwargs...)
else
newprob = remake(prob; p, u0 = Utils.value(prob.u0))
end

sol = solve(newprob, alg, args...; kwargs...)

uu = sol.u
Jₚ = nonlinearsolve_∂f_∂p(prob, prob.f, uu, p)
Jᵤ = nonlinearsolve_∂f_∂u(prob, prob.f, uu, p)
z = -Jᵤ \ Jₚ
pp = prob.p
sumfun = ((z, p),) -> map(Base.Fix2(*, ForwardDiff.partials(p)), z)

if uu isa Number
partials = sum(sumfun, zip(z, pp))
elseif p isa Number
partials = sumfun((z, pp))
else
partials = sum(sumfun, zip(eachcol(z), pp))
end

return sol, partials
end

function nonlinearsolve_∂f_∂p(prob, f::F, u, p) where {F}
if isinplace(prob)
f = @closure p -> begin
du = Utils.safe_similar(u, promote_type(eltype(u), eltype(p)))
f(du, u, p)
return du
end
else
f = Base.Fix1(f, u)
end
if p isa Number
return Utils.safe_reshape(ForwardDiff.derivative(f, p), :, 1)
elseif u isa Number
return Utils.safe_reshape(ForwardDiff.gradient(f, p), 1, :)
else
return ForwardDiff.jacobian(f, p)
end
end

function nonlinearsolve_∂f_∂u(prob, f::F, u, p) where {F}
if isinplace(prob)
return ForwardDiff.jacobian(
@closure((du, u)->f(du, u, p)), Utils.safe_similar(u), u)
end
return ForwardDiff.jacobian(Base.Fix2(f, p), u)
end

function NonlinearSolveBase.nonlinearsolve_dual_solution(u::Number, partials,
::Union{<:AbstractArray{<:Dual{T, V, P}}, Dual{T, V, P}}) where {T, V, P}
return Dual{T, V, P}(u, partials)
end

function NonlinearSolveBase.nonlinearsolve_dual_solution(u::AbstractArray, partials,
::Union{<:AbstractArray{<:Dual{T, V, P}}, Dual{T, V, P}}) where {T, V, P}
return map(((uᵢ, pᵢ),) -> Dual{T, V, P}(uᵢ, pᵢ), zip(u, Utils.restructure(u, partials)))
end

end
2 changes: 1 addition & 1 deletion lib/NonlinearSolveBase/src/NonlinearSolveBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@ include("utils.jl")

include("common_defaults.jl")
include("termination_conditions.jl")
include("autodiff.jl")
include("immutable_problem.jl")

# Unexported Public API
@compat(public, (L2_NORM, Linf_NORM, NAN_CHECK, UNITLESS_ABS2, get_tolerance))
@compat(public, (nonlinearsolve_forwarddiff_solve, nonlinearsolve_dual_solution))

export RelTerminationMode, AbsTerminationMode, NormTerminationMode, RelNormTerminationMode,
AbsNormTerminationMode, RelNormSafeTerminationMode, AbsNormSafeTerminationMode,
Expand Down
1 change: 0 additions & 1 deletion lib/NonlinearSolveBase/src/autodiff.jl

This file was deleted.

4 changes: 4 additions & 0 deletions lib/NonlinearSolveBase/src/public.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@ function L2_NORM end
function Linf_NORM end
function get_tolerance end

# Forward declarations of functions for forward mode AD
function nonlinearsolve_forwarddiff_solve end
function nonlinearsolve_dual_solution end

# Nonlinear Solve Termination Conditions
abstract type AbstractNonlinearTerminationMode end
abstract type AbstractSafeNonlinearTerminationMode <: AbstractNonlinearTerminationMode end
Expand Down
18 changes: 18 additions & 0 deletions lib/NonlinearSolveBase/src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,4 +72,22 @@ apply_norm(f::F, x, y) where {F} = norm_op(standardize_norm(f), +, x, y)
convert_real(::Type{T}, ::Nothing) where {T} = nothing
convert_real(::Type{T}, x) where {T} = real(T(x))

restructure(::Number, x::Number) = x
restructure(y, x) = ArrayInterface.restructure(y, x)

function safe_similar(x, args...; kwargs...)
y = similar(x, args...; kwargs...)
return init_bigfloat_array!!(y)
end

init_bigfloat_array!!(x) = x

function init_bigfloat_array!!(x::AbstractArray{<:BigFloat})
ArrayInterface.can_setindex(x) && fill!(x, BigFloat(0))
return x
end

safe_reshape(x::Number, args...) = x
safe_reshape(x, args...) = reshape(x, args...)

end

0 comments on commit 0b70d62

Please sign in to comment.