Skip to content

Commit

Permalink
fix: extension for forward AD support
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Sep 17, 2024
1 parent 0b70d62 commit 8fadcc4
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 5 deletions.
2 changes: 1 addition & 1 deletion lib/BracketingNonlinearSolve/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"

[extensions]
NonlinearSolveBaseForwardDiffExt = "ForwardDiff"
BracketingNonlinearSolveForwardDiffExt = "ForwardDiff"

[compat]
CommonSolve = "0.2.4"
Expand Down
2 changes: 0 additions & 2 deletions lib/BracketingNonlinearSolve/test/rootfind_tests.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
@testsnippet RootfindingTestSnippet begin
using NonlinearSolveBase, BracketingNonlinearSolve

quadratic_f(u, p) = u .* u .- p
quadratic_f!(du, u, p) = (du .= u .* u .- p)
quadratic_f2(u, p) = @. p[1] * u * u - p[2]
Expand Down
2 changes: 2 additions & 0 deletions lib/NonlinearSolveBase/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ version = "1.0.0"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
Expand All @@ -24,6 +25,7 @@ NonlinearSolveBaseSparseArraysExt = "SparseArrays"

[compat]
ArrayInterface = "7.9"
CommonSolve = "0.2.4"
Compat = "4.15"
ConcreteStructs = "0.2.3"
FastClosures = "0.3"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ function NonlinearSolveBase.nonlinearsolve_forwarddiff_solve(
end

function nonlinearsolve_∂f_∂p(prob, f::F, u, p) where {F}
if isinplace(prob)
if SciMLBase.isinplace(prob)
f = @closure p -> begin
du = Utils.safe_similar(u, promote_type(eltype(u), eltype(p)))
f(du, u, p)
Expand All @@ -62,10 +62,11 @@ function nonlinearsolve_∂f_∂p(prob, f::F, u, p) where {F}
end

function nonlinearsolve_∂f_∂u(prob, f::F, u, p) where {F}
if isinplace(prob)
if SciMLBase.isinplace(prob)
return ForwardDiff.jacobian(
@closure((du, u)->f(du, u, p)), Utils.safe_similar(u), u)
end
u isa Number && return ForwardDiff.derivative(Base.Fix2(f, p), u)
return ForwardDiff.jacobian(Base.Fix2(f, p), u)
end

Expand Down

0 comments on commit 8fadcc4

Please sign in to comment.