Skip to content

Commit

Permalink
Switch from AbstractDifferentiation to DifferentiationInterface (#93)
Browse files Browse the repository at this point in the history
Following our discussion over email, this PR proposes a switch from
AbstractDifferentiation.jl to DifferentiationInterface.jl, which is
becoming the new standard in the ecosystem.

- [x] Modify `Project.toml` files and imports
- [x] Replace `SomethingBackend()` with `AutoSomething()`
- [x] Replace `value_and_gradient_closure` with `value_and_gradient`
(unclear how performance is affected)
- [x] Update documentation and README
- [ ] Add [preparation
mechanism](https://gdalle.github.io/DifferentiationInterface.jl/DifferentiationInterface/stable/operators/#Preparation):
available on another branch, but we are not sure we want it, because
preparation is not appropriate if the function contains value-dependent
control flow

---------

Co-authored-by: Lorenzo Stella <lorenzostella@gmail.com>
  • Loading branch information
gdalle and lostella authored Oct 1, 2024
1 parent b3e667e commit 27f8a96
Show file tree
Hide file tree
Showing 31 changed files with 126 additions and 138 deletions.
8 changes: 5 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
name = "ProximalAlgorithms"
uuid = "140ffc9f-1907-541a-a177-7475e0a401e9"
version = "0.6.0"
version = "0.7.0"

[deps]
AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
ProximalCore = "dc4f5ac2-75d1-4f31-931e-60435d74994b"

[compat]
AbstractDifferentiation = "0.6"
ADTypes = "1.5.3"
DifferentiationInterface = "0.5.8"
LinearAlgebra = "1.2"
Printf = "1.2"
ProximalCore = "0.1"
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ Implemented algorithms include:
Check out [this section](https://juliafirstorder.github.io/ProximalAlgorithms.jl/stable/guide/implemented_algorithms/) for an overview of the available algorithms.

Algorithms rely on:
- [AbstractDifferentiation.jl](https://github.com/JuliaDiff/AbstractDifferentiation.jl) for automatic differentiation (but you can easily bring your own gradients)
- [DifferentiationInterface.jl](https://github.com/gdalle/DifferentiationInterface.jl) for automatic differentiation (but you can easily bring your own gradients)
- the [ProximalCore API](https://github.com/JuliaFirstOrder/ProximalCore.jl) for proximal mappings, projections, etc, to handle non-differentiable terms (see for example [ProximalOperators](https://github.com/JuliaFirstOrder/ProximalOperators.jl) for an extensive collection of functions).

## Documentation
Expand Down
8 changes: 4 additions & 4 deletions benchmark/benchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@ using FileIO

const SUITE = BenchmarkGroup()

function ProximalAlgorithms.value_and_gradient_closure(
function ProximalAlgorithms.value_and_gradient(
f::ProximalOperators.LeastSquaresDirect,
x,
)
res = f.A * x - f.b
norm(res)^2 / 2, () -> f.A' * res
norm(res)^2 / 2, f.A' * res
end

struct SquaredDistance{Tb}
Expand All @@ -22,9 +22,9 @@ end

(f::SquaredDistance)(x) = norm(x - f.b)^2 / 2

function ProximalAlgorithms.value_and_gradient_closure(f::SquaredDistance, x)
function ProximalAlgorithms.value_and_gradient(f::SquaredDistance, x)
diff = x - f.b
norm(diff)^2 / 2, () -> diff
norm(diff)^2 / 2, diff
end

for (benchmark_name, file_name) in [
Expand Down
2 changes: 1 addition & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[deps]
AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244"
HTTP = "cd3eb016-35fb-5094-929b-558a96fad6f3"
Expand Down
4 changes: 2 additions & 2 deletions docs/src/examples/sparse_linear_regression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,12 @@ end
mean_squared_error(label, output) = mean((output .- label) .^ 2) / 2

using Zygote
using AbstractDifferentiation: ZygoteBackend
using DifferentiationInterface: AutoZygote
using ProximalAlgorithms

training_loss = ProximalAlgorithms.AutoDifferentiable(
wb -> mean_squared_error(training_label, standardized_linear_model(wb, training_input)),
ZygoteBackend(),
AutoZygote(),
)

# As regularization we will use the L1 norm, implemented in [ProximalOperators](https://github.com/JuliaFirstOrder/ProximalOperators.jl):
Expand Down
29 changes: 15 additions & 14 deletions docs/src/guide/custom_objectives.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,18 @@
#
# Defining the proximal mapping for a custom function type requires adding a method for [`ProximalCore.prox!`](@ref).
#
# To compute gradients, algorithms use [`value_and_gradient_closure`](@ref):
# this relies on [AbstractDifferentiation](https://github.com/JuliaDiff/AbstractDifferentiation.jl), for automatic differentiation
# To compute gradients, algorithms use [`value_and_gradient`](@ref):
# this relies on [DifferentiationInterface.jl](https://github.com/gdalle/DifferentiationInterface.jl), for automatic differentiation
# with any of its supported backends, when functions are wrapped in [`AutoDifferentiable`](@ref),
# as the examples below show.
#
# If however you would like to provide your own gradient implementation (e.g. for efficiency reasons),
# you can simply implement a method for [`value_and_gradient_closure`](@ref) on your own function type.
# you can simply implement a method for [`value_and_gradient`](@ref) on your own function type.
#
# ```@docs
# ProximalCore.prox
# ProximalCore.prox!
# ProximalAlgorithms.value_and_gradient_closure
# ProximalAlgorithms.value_and_gradient
# ProximalAlgorithms.AutoDifferentiable
# ```
#
Expand All @@ -32,12 +32,12 @@
# Let's try to minimize the celebrated Rosenbrock function, but constrained to the unit norm ball. The cost function is

using Zygote
using AbstractDifferentiation: ZygoteBackend
using DifferentiationInterface: AutoZygote
using ProximalAlgorithms

rosenbrock2D = ProximalAlgorithms.AutoDifferentiable(
x -> 100 * (x[2] - x[1]^2)^2 + (1 - x[1])^2,
ZygoteBackend(),
AutoZygote(),
)

# To enforce the constraint, we define the indicator of the unit ball, together with its proximal mapping:
Expand Down Expand Up @@ -105,16 +105,17 @@ end

Counting(f::T) where {T} = Counting{T}(f, 0, 0, 0)

# Now we only need to intercept any call to [`value_and_gradient_closure`](@ref) and [`prox!`](@ref) and increase counters there:
function (f::Counting)(x)
f.eval_count += 1
return f.f(x)
end

function ProximalAlgorithms.value_and_gradient_closure(f::Counting, x)
# Now we only need to intercept any call to [`value_and_gradient`](@ref) and [`prox!`](@ref) and increase counters there:

function ProximalAlgorithms.value_and_gradient(f::Counting, x)
f.eval_count += 1
fx, pb = ProximalAlgorithms.value_and_gradient_closure(f.f, x)
function counting_pullback()
f.gradient_count += 1
return pb()
end
return fx, counting_pullback
f.gradient_count += 1
return ProximalAlgorithms.value_and_gradient(f.f, x)
end

function ProximalCore.prox!(y, f::Counting, x, gamma)
Expand Down
11 changes: 5 additions & 6 deletions docs/src/guide/getting_started.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
# The literature on proximal operators and algorithms is vast: for an overview, one can refer to [Parikh2014](@cite), [Beck2017](@cite).
#
# To evaluate these first-order primitives, in ProximalAlgorithms:
# * ``\nabla f_i`` falls back to using automatic differentiation (as provided by [AbstractDifferentiation](https://github.com/JuliaDiff/AbstractDifferentiation.jl) and all of its backends).
# * ``\nabla f_i`` falls back to using automatic differentiation (as provided by [DifferentiationInterface.jl](https://github.com/gdalle/DifferentiationInterface.jl) and all of its backends).
# * ``\operatorname{prox}_{f_i}`` relies on the interface of [ProximalOperators](https://github.com/JuliaFirstOrder/ProximalOperators.jl) (>= 0.15).
# Both of the above can be implemented for custom function types, as [documented here](@ref custom_terms).
#
Expand Down Expand Up @@ -52,13 +52,13 @@

using LinearAlgebra
using Zygote
using AbstractDifferentiation: ZygoteBackend
using DifferentiationInterface: AutoZygote
using ProximalOperators
using ProximalAlgorithms

quadratic_cost = ProximalAlgorithms.AutoDifferentiable(
x -> dot([3.4 1.2; 1.2 4.5] * x, x) / 2 + dot([-2.3, 9.9], x),
ZygoteBackend(),
AutoZygote(),
)
box_indicator = ProximalOperators.IndBox(0, 1)

Expand All @@ -72,10 +72,9 @@ ffb = ProximalAlgorithms.FastForwardBackward(maxit = 1000, tol = 1e-5, verbose =
solution, iterations = ffb(x0 = ones(2), f = quadratic_cost, g = box_indicator)

# We can verify the correctness of the solution by checking that the negative gradient is orthogonal to the constraints, pointing outwards:
# for this, we just evaluate the closure `cl` returned as second output of [`value_and_gradient_closure`](@ref).
# for this, we just evaluate the second output of [`value_and_gradient`](@ref).

v, cl = ProximalAlgorithms.value_and_gradient_closure(quadratic_cost, solution)
-cl()
last(ProximalAlgorithms.value_and_gradient(quadratic_cost, solution))

# Or by plotting the solution against the cost function and constraint:

Expand Down
2 changes: 1 addition & 1 deletion docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ Implemented algorithms include:
Check out [this section](@ref problems_algorithms) for an overview of the available algorithms.

Algorithms rely on:
- [AbstractDifferentiation.jl](https://github.com/JuliaDiff/AbstractDifferentiation.jl) for automatic differentiation (but you can easily bring your own gradients),
- [DifferentiationInterface.jl](https://github.com/gdalle/DifferentiationInterface.jl) for automatic differentiation (but you can easily bring your own gradients),
- the [ProximalCore API](https://github.com/JuliaFirstOrder/ProximalCore.jl) for proximal mappings, projections, etc, to handle non-differentiable terms (see for example [ProximalOperators](https://github.com/JuliaFirstOrder/ProximalOperators.jl) for an extensive collection of functions).

!!! note
Expand Down
26 changes: 12 additions & 14 deletions src/ProximalAlgorithms.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module ProximalAlgorithms

using AbstractDifferentiation
using ADTypes: ADTypes
using DifferentiationInterface: DifferentiationInterface
using ProximalCore
using ProximalCore: prox, prox!

Expand All @@ -12,33 +13,30 @@ const Maybe{T} = Union{T,Nothing}
Callable struct wrapping function `f` to be auto-differentiated using `backend`.
When called, it evaluates the same as `f`, while [`value_and_gradient_closure`](@ref)
When called, it evaluates the same as `f`, while its gradient
is implemented using `backend` for automatic differentiation.
The backend can be any from [AbstractDifferentiation](https://github.com/JuliaDiff/AbstractDifferentiation.jl).
The backend can be any of those supported by [DifferentiationInterface.jl](https://github.com/gdalle/DifferentiationInterface.jl).
"""
struct AutoDifferentiable{F,B}
struct AutoDifferentiable{F,B<:ADTypes.AbstractADType}
f::F
backend::B
end

(f::AutoDifferentiable)(x) = f.f(x)

"""
value_and_gradient_closure(f, x)
value_and_gradient(f, x)
Return a tuple containing the value of `f` at `x`, and a closure `cl`.
Function `cl`, once called, yields the gradient of `f` at `x`.
Return a tuple containing the value of `f` at `x` and the gradient of `f` at `x`.
"""
value_and_gradient_closure
value_and_gradient

function value_and_gradient_closure(f::AutoDifferentiable, x)
fx, pb = AbstractDifferentiation.value_and_pullback_function(f.backend, f.f, x)
return fx, () -> pb(one(fx))[1]
function value_and_gradient(f::AutoDifferentiable, x)
return DifferentiationInterface.value_and_gradient(f.f, f.backend, x)
end

function value_and_gradient_closure(f::ProximalCore.Zero, x)
f(x), () -> zero(x)
function value_and_gradient(f::ProximalCore.Zero, x)
return f(x), zero(x)
end

# various utilities
Expand Down
7 changes: 3 additions & 4 deletions src/algorithms/davis_yin.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,7 @@ end
function Base.iterate(iter::DavisYinIteration)
z = copy(iter.x0)
xg, = prox(iter.g, z, iter.gamma)
f_xg, cl = value_and_gradient_closure(iter.f, xg)
grad_f_xg = cl()
f_xg, grad_f_xg = value_and_gradient(iter.f, xg)
z_half = 2 .* xg .- z .- iter.gamma .* grad_f_xg
xh, = prox(iter.h, z_half, iter.gamma)
res = xh - xg
Expand All @@ -68,8 +67,8 @@ end

function Base.iterate(iter::DavisYinIteration, state::DavisYinState)
prox!(state.xg, iter.g, state.z, iter.gamma)
f_xg, cl = value_and_gradient_closure(iter.f, state.xg)
state.grad_f_xg .= cl()
f_xg, grad_f_xg = value_and_gradient(iter.f, state.xg)
state.grad_f_xg .= grad_f_xg
state.z_half .= 2 .* state.xg .- state.z .- iter.gamma .* state.grad_f_xg
prox!(state.xh, iter.h, state.z_half, iter.gamma)
state.res .= state.xh .- state.xg
Expand Down
7 changes: 3 additions & 4 deletions src/algorithms/fast_forward_backward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,7 @@ end

function Base.iterate(iter::FastForwardBackwardIteration)
x = copy(iter.x0)
f_x, cl = value_and_gradient_closure(iter.f, x)
grad_f_x = cl()
f_x, grad_f_x = value_and_gradient(iter.f, x)
gamma =
iter.gamma === nothing ?
1 / lower_bound_smoothness_constant(iter.f, I, x, grad_f_x) : iter.gamma
Expand Down Expand Up @@ -136,8 +135,8 @@ function Base.iterate(
state.x .= state.z .+ beta .* (state.z .- state.z_prev)
state.z_prev, state.z = state.z, state.z_prev

state.f_x, cl = value_and_gradient_closure(iter.f, state.x)
state.grad_f_x .= cl()
state.f_x, grad_f_x = value_and_gradient(iter.f, state.x)
state.grad_f_x .= grad_f_x
state.y .= state.x .- state.gamma .* state.grad_f_x
state.g_z = prox!(state.z, iter.g, state.y, state.gamma)
state.res .= state.x .- state.z
Expand Down
7 changes: 3 additions & 4 deletions src/algorithms/forward_backward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,7 @@ end

function Base.iterate(iter::ForwardBackwardIteration)
x = copy(iter.x0)
f_x, cl = value_and_gradient_closure(iter.f, x)
grad_f_x = cl()
f_x, grad_f_x = value_and_gradient(iter.f, x)
gamma =
iter.gamma === nothing ?
1 / lower_bound_smoothness_constant(iter.f, I, x, grad_f_x) : iter.gamma
Expand Down Expand Up @@ -111,8 +110,8 @@ function Base.iterate(
state.grad_f_x, state.grad_f_z = state.grad_f_z, state.grad_f_x
else
state.x, state.z = state.z, state.x
state.f_x, cl = value_and_gradient_closure(iter.f, state.x)
state.grad_f_x .= cl()
state.f_x, grad_f_x = value_and_gradient(iter.f, state.x)
state.grad_f_x .= grad_f_x
end

state.y .= state.x .- state.gamma .* state.grad_f_x
Expand Down
10 changes: 4 additions & 6 deletions src/algorithms/li_lin.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,7 @@ end

function Base.iterate(iter::LiLinIteration{R}) where {R}
y = copy(iter.x0)
f_y, cl = value_and_gradient_closure(iter.f, y)
grad_f_y = cl()
f_y, grad_f_y = value_and_gradient(iter.f, y)

# TODO: initialize gamma if not provided
# TODO: authors suggest Barzilai-Borwein rule?
Expand Down Expand Up @@ -110,8 +109,7 @@ function Base.iterate(iter::LiLinIteration{R}, state::LiLinState{R,Tx}) where {R
else
# TODO: re-use available space in state?
# TODO: backtrack gamma at x
f_x, cl = value_and_gradient_closure(iter.f, x)
grad_f_x = cl()
f_x, grad_f_x = value_and_gradient(iter.f, x)
x_forward = state.x - state.gamma .* grad_f_x
v, g_v = prox(iter.g, x_forward, state.gamma)
Fv = iter.f(v) + g_v
Expand All @@ -130,8 +128,8 @@ function Base.iterate(iter::LiLinIteration{R}, state::LiLinState{R,Tx}) where {R
Fx = Fv
end

state.f_y, cl = value_and_gradient_closure(iter.f, state.y)
state.grad_f_y .= cl()
state.f_y, grad_f_y = value_and_gradient(iter.f, state.y)
state.grad_f_y .= grad_f_y
state.y_forward .= state.y .- state.gamma .* state.grad_f_y
state.g_z = prox!(state.z, iter.g, state.y_forward, state.gamma)

Expand Down
15 changes: 7 additions & 8 deletions src/algorithms/panoc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,7 @@ f_model(iter::PANOCIteration, state::PANOCState) =
function Base.iterate(iter::PANOCIteration{R}) where {R}
x = copy(iter.x0)
Ax = iter.A * x
f_Ax, cl = value_and_gradient_closure(iter.f, Ax)
grad_f_Ax = cl()
f_Ax, grad_f_Ax = value_and_gradient(iter.f, Ax)
gamma =
iter.gamma === nothing ?
iter.alpha / lower_bound_smoothness_constant(iter.f, iter.A, x, grad_f_Ax) :
Expand Down Expand Up @@ -182,8 +181,8 @@ function Base.iterate(iter::PANOCIteration{R,Tx,Tf}, state::PANOCState) where {R

state.x_d .= state.x .+ state.d
state.Ax_d .= state.Ax .+ state.Ad
state.f_Ax_d, cl = value_and_gradient_closure(iter.f, state.Ax_d)
state.grad_f_Ax_d .= cl()
state.f_Ax_d, grad_f_Ax_d = value_and_gradient(iter.f, state.Ax_d)
state.grad_f_Ax_d .= grad_f_Ax_d
mul!(state.At_grad_f_Ax_d, adjoint(iter.A), state.grad_f_Ax_d)

copyto!(state.x, state.x_d)
Expand Down Expand Up @@ -220,8 +219,8 @@ function Base.iterate(iter::PANOCIteration{R,Tx,Tf}, state::PANOCState) where {R
# along a line using interpolation and linear combinations
# this allows saving operations
if isinf(f_Az)
f_Az, cl = value_and_gradient_closure(iter.f, state.Az)
state.grad_f_Az .= cl()
f_Az, grad_f_Az = value_and_gradient(iter.f, state.Az)
state.grad_f_Az .= grad_f_Az
end
if isinf(c)
mul!(state.At_grad_f_Az, iter.A', state.grad_f_Az)
Expand All @@ -239,8 +238,8 @@ function Base.iterate(iter::PANOCIteration{R,Tx,Tf}, state::PANOCState) where {R
else
# otherwise, in the general case where f is only smooth, we compute
# one gradient and matvec per backtracking step
state.f_Ax, cl = value_and_gradient_closure(iter.f, state.Ax)
state.grad_f_Ax .= cl()
state.f_Ax, grad_f_Ax = value_and_gradient(iter.f, state.Ax)
state.grad_f_Ax .= grad_f_Ax
mul!(state.At_grad_f_Ax, adjoint(iter.A), state.grad_f_Ax)
end

Expand Down
Loading

0 comments on commit 27f8a96

Please sign in to comment.