
Commit

Merge pull request #446 from tansongchen/fwd-ls
Add forward mode to line search
avik-pal authored Jun 14, 2024
2 parents f3b2e1f + b7c54f3 commit 6740074
Showing 2 changed files with 64 additions and 26 deletions.
82 changes: 60 additions & 22 deletions src/globalization/line_search.jl
@@ -90,7 +90,7 @@ end
ϕdϕ
method
alpha
grad_op
deriv_op
u_cache
fu_cache
stats::NLStats
@@ -110,25 +110,59 @@ function __internal_init(
@warn "Scalar AD is supported only for AutoForwardDiff and AutoFiniteDiff. \
Detected $(autodiff). Falling back to AutoFiniteDiff."
end
grad_op = @closure (u, fu, p) -> last(__value_derivative(
autodiff, Base.Fix2(f, p), u)) * fu
deriv_op = @closure (du, u, fu, p) -> last(__value_derivative(
autodiff, Base.Fix2(f, p), u)) *
fu *
du
else
if SciMLBase.has_jvp(f)
# Both forward and reverse AD can be used for line-search.
# We prefer forward AD for better performance; reverse AD is also supported if the user explicitly requests it.
# 1. If jvp is available, we use forward AD;
# 2. If vjp is available, we use reverse AD;
# 3. If reverse type is requested, we use reverse AD;
# 4. Finally, we use forward AD.
if alg.autodiff isa AutoFiniteDiff
deriv_op = nothing
elseif SciMLBase.has_jvp(f)
if isinplace(prob)
g_cache = __similar(u)
grad_op = @closure (u, fu, p) -> f.vjp(g_cache, fu, u, p)
jvp_cache = __similar(fu)
deriv_op = @closure (du, u, fu, p) -> begin
f.jvp(jvp_cache, du, u, p)
dot(fu, jvp_cache)
end
else
grad_op = @closure (u, fu, p) -> f.vjp(fu, u, p)
deriv_op = @closure (du, u, fu, p) -> dot(fu, f.jvp(du, u, p))
end
else
elseif SciMLBase.has_vjp(f)
if isinplace(prob)
vjp_cache = __similar(u)
deriv_op = @closure (du, u, fu, p) -> begin
f.vjp(vjp_cache, fu, u, p)
dot(du, vjp_cache)
end
else
deriv_op = @closure (du, u, fu, p) -> dot(du, f.vjp(fu, u, p))
end
elseif alg.autodiff !== nothing &&
ADTypes.mode(alg.autodiff) isa ADTypes.ReverseMode
autodiff = get_concrete_reverse_ad(
alg.autodiff, prob; check_reverse_mode = true)
vjp_op = VecJacOperator(prob, fu, u; autodiff)
if isinplace(prob)
g_cache = __similar(u)
grad_op = @closure (u, fu, p) -> vjp_op(g_cache, fu, u, p)
vjp_cache = __similar(u)
deriv_op = @closure (du, u, fu, p) -> dot(du, vjp_op(vjp_cache, fu, u, p))
else
deriv_op = @closure (du, u, fu, p) -> dot(du, vjp_op(fu, u, p))
end
else
autodiff = get_concrete_forward_ad(
alg.autodiff, prob; check_forward_mode = true)
jvp_op = JacVecOperator(prob, fu, u; autodiff)
if isinplace(prob)
jvp_cache = __similar(fu)
deriv_op = @closure (du, u, fu, p) -> dot(fu, jvp_op(jvp_cache, du, u, p))
else
grad_op = @closure (u, fu, p) -> vjp_op(fu, u, p)
deriv_op = @closure (du, u, fu, p) -> dot(fu, jvp_op(du, u, p))
end
end
end
@@ -143,33 +177,37 @@ function __internal_init(
return @fastmath internalnorm(fu_cache)^2 / 2
end

dϕ = @closure (f, p, u, du, α, u_cache, fu_cache, grad_op) -> begin
dϕ = @closure (f, p, u, du, α, u_cache, fu_cache, deriv_op) -> begin
@bb @. u_cache = u + α * du
fu_cache = evaluate_f!!(f, fu_cache, u_cache, p)
stats.nf += 1
g₀ = grad_op(u_cache, fu_cache, p)
return dot(g₀, du)
return deriv_op(du, u_cache, fu_cache, p)
end

ϕdϕ = @closure (f, p, u, du, α, u_cache, fu_cache, grad_op) -> begin
ϕdϕ = @closure (f, p, u, du, α, u_cache, fu_cache, deriv_op) -> begin
@bb @. u_cache = u + α * du
fu_cache = evaluate_f!!(f, fu_cache, u_cache, p)
stats.nf += 1
g₀ = grad_op(u_cache, fu_cache, p)
deriv = deriv_op(du, u_cache, fu_cache, p)
obj = @fastmath internalnorm(fu_cache)^2 / 2
return obj, dot(g₀, du)
return obj, deriv
end

return LineSearchesJLCache(f, p, ϕ, dϕ, ϕdϕ, alg.method, T(alg.initial_alpha),
grad_op, u_cache, fu_cache, stats)
deriv_op, u_cache, fu_cache, stats)
end

function __internal_solve!(cache::LineSearchesJLCache, u, du; kwargs...)
ϕ = @closure α -> cache.ϕ(cache.f, cache.p, u, du, α, cache.u_cache, cache.fu_cache)
dϕ = @closure α -> cache.dϕ(
cache.f, cache.p, u, du, α, cache.u_cache, cache.fu_cache, cache.grad_op)
ϕdϕ = @closure α -> cache.ϕdϕ(
cache.f, cache.p, u, du, α, cache.u_cache, cache.fu_cache, cache.grad_op)
if cache.deriv_op !== nothing
dϕ = @closure α -> cache.dϕ(
cache.f, cache.p, u, du, α, cache.u_cache, cache.fu_cache, cache.deriv_op)
ϕdϕ = @closure α -> cache.ϕdϕ(
cache.f, cache.p, u, du, α, cache.u_cache, cache.fu_cache, cache.deriv_op)
else
dϕ = @closure α -> FiniteDiff.finite_difference_derivative(ϕ, α)
ϕdϕ = @closure α -> (ϕ(α), FiniteDiff.finite_difference_derivative(ϕ, α))
end

ϕ₀, dϕ₀ = ϕdϕ(zero(eltype(u)))

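For context on what the new deriv_op closures compute: the line search only needs the derivative of the merit function φ(α) = ‖f(u + α·du)‖²/2 along the step direction, and dφ/dα = ⟨f(u + α·du), J(u + α·du)·du⟩ is a single Jacobian-vector product. That is why a jvp/forward-mode path suffices and no gradient or VJP is required. A minimal self-contained sketch of that identity with ForwardDiff follows; the residual f and the helper name dϕ_forward are made up for illustration and are not part of the package:

using ForwardDiff, LinearAlgebra

# Toy residual of a nonlinear system f(u) = 0 (illustrative only).
f(u) = [u[1]^2 + u[2] - 2.0, u[1] - u[2]^3]

# Merit function the line search minimizes along du: φ(α) = ‖f(u + α du)‖² / 2.
ϕ(u, du, α) = sum(abs2, f(u .+ α .* du)) / 2

# dφ/dα = ⟨f(u + α du), J(u + α du) * du⟩: one forward-mode JVP, no full Jacobian.
function dϕ_forward(u, du, α)
    v = u .+ α .* du
    jvp = ForwardDiff.derivative(t -> f(v .+ t .* du), 0.0)  # J(v) * du
    return dot(f(v), jvp)
end

u, du = [1.0, 1.0], [-0.1, 0.2]
dϕ_forward(u, du, 0.5) ≈ ForwardDiff.derivative(α -> ϕ(u, du, α), 0.5)  # true

The dot with f(v) mirrors dot(fu, jvp_cache) in the in-place branch above; the same quantity is built from f.jvp or a JacVecOperator when those are available.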
8 changes: 4 additions & 4 deletions test/core/rootfind_tests.jl
@@ -55,7 +55,7 @@ end
@testitem "NewtonRaphson" setup=[CoreRootfindTesting] tags=[:core] timeout=3600 begin
@testset "LineSearch: $(_nameof(lsmethod)) LineSearch AD: $(_nameof(ad))" for lsmethod in (
Static(), StrongWolfe(), BackTracking(), HagerZhang(), MoreThuente()),
ad in (AutoFiniteDiff(), AutoZygote())
ad in (AutoForwardDiff(), AutoZygote(), AutoFiniteDiff())

linesearch = LineSearchesJL(; method = lsmethod, autodiff = ad)
u0s = ([1.0, 1.0], @SVector[1.0, 1.0], 1.0)
@@ -466,7 +466,7 @@ end
@testset "LineSearch: $(_nameof(lsmethod)) LineSearch AD: $(_nameof(ad)) Init Jacobian: $(init_jacobian) Update Rule: $(update_rule)" for lsmethod in (
Static(), StrongWolfe(), BackTracking(),
HagerZhang(), MoreThuente(), LiFukushimaLineSearch()),
ad in (AutoFiniteDiff(), AutoZygote()),
ad in (AutoForwardDiff(), AutoZygote(), AutoFiniteDiff()),
init_jacobian in (Val(:identity), Val(:true_jacobian)),
update_rule in (Val(:good_broyden), Val(:bad_broyden), Val(:diagonal))

@@ -515,7 +515,7 @@ end
@testitem "Klement" setup=[CoreRootfindTesting] tags=[:core] skip=:(Sys.isapple()) timeout=3600 begin
@testset "LineSearch: $(_nameof(lsmethod)) LineSearch AD: $(_nameof(ad)) Init Jacobian: $(init_jacobian)" for lsmethod in (
Static(), StrongWolfe(), BackTracking(), HagerZhang(), MoreThuente()),
ad in (AutoFiniteDiff(), AutoZygote()),
ad in (AutoForwardDiff(), AutoZygote(), AutoFiniteDiff()),
init_jacobian in (Val(:identity), Val(:true_jacobian), Val(:true_jacobian_diagonal))

linesearch = LineSearchesJL(; method = lsmethod, autodiff = ad)
@@ -565,7 +565,7 @@ end
@testset "LineSearch: $(_nameof(lsmethod)) LineSearch AD: $(_nameof(ad))" for lsmethod in (
Static(), StrongWolfe(), BackTracking(),
HagerZhang(), MoreThuente(), LiFukushimaLineSearch()),
ad in (AutoFiniteDiff(), AutoZygote())
ad in (AutoForwardDiff(), AutoZygote(), AutoFiniteDiff())

linesearch = LineSearchesJL(; method = lsmethod, autodiff = ad)
u0s = ([1.0, 1.0], @SVector[1.0, 1.0], 1.0)
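A hedged end-to-end sketch of what the updated tests exercise, assuming the constructor shown in the diff above (LineSearchesJL(; method, autodiff)) and the usual NewtonRaphson(; linesearch) keyword; the problem itself is made up for illustration:

using NonlinearSolve, LineSearches

f(u, p) = u .* u .- p
prob = NonlinearProblem(f, [1.0, 1.0], 2.0)

# Forward-mode AD for the line-search derivative is now accepted
# alongside AutoZygote and AutoFiniteDiff.
linesearch = LineSearchesJL(; method = BackTracking(), autodiff = AutoForwardDiff())
sol = solve(prob, NewtonRaphson(; linesearch))  # converges to [√2, √2]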

2 comments on commit 6740074

@avik-pal (Member, Author)


@JuliaRegistrator


Registration pull request created: JuliaRegistries/General/109035

Tip: Release Notes

Did you know you can add release notes too? Just add markdown-formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. For example:

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed; otherwise it can be done manually through the GitHub interface, or via:

git tag -a v3.13.0 -m "<description of version>" 6740074b450d05ed5296cbc83d483f6288781fa2
git push origin v3.13.0
