Skip to content

Commit

Permalink
Add shared Hessian tracer à la Walther (#135)
Browse files Browse the repository at this point in the history
* Add shared Hessian tracer à la Walther

* Add identical objects test

* Fixes for patterns introduced in #139

* Clarify connection between mutation and sharing

---------

Co-authored-by: adrhill <adrian.hill@mailbox.org>
  • Loading branch information
gdalle and adrhill authored Jul 30, 2024
1 parent b447596 commit c0bf9d0
Show file tree
Hide file tree
Showing 12 changed files with 302 additions and 97 deletions.
16 changes: 14 additions & 2 deletions benchmark/bench_jogger.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@ suite["OptimizationProblems"] = optbench([:britgas])
for S1 in SET_TYPES
S2 = Set{Tuple{Int,Int}}

# Non-shared tracers
shared = false
PG = IndexSetGradientPattern{Int,S1}
PH = IndexSetHessianPattern{Int,S1,S2}

PH = IndexSetHessianPattern{Int,S1,S2,shared}
G = GradientTracer{PG}
H = HessianTracer{PH}

Expand All @@ -34,4 +35,15 @@ for S1 in SET_TYPES
suite["Hessian"]["Local"][(nameof(S1), nameof(S2))] = hessbench(
TracerLocalSparsityDetector(G, H)
)

# Shared tracers
shared = true
PG = IndexSetGradientPattern{Int,S1}
PH = IndexSetHessianPattern{Int,S1,S2,shared}
G = GradientTracer{PG}
H = HessianTracer{PH}

suite["Hessian"]["Global (shared)"][(nameof(S1), nameof(S2))] = hessbench(
TracerSparsityDetector(G, H)
)
end
17 changes: 8 additions & 9 deletions src/interface.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
const DEFAULT_GRADIENT_TRACER = GradientTracer{IndexSetGradientPattern{Int,BitSet}}
const DEFAULT_HESSIAN_TRACER = HessianTracer{
IndexSetHessianPattern{Int,BitSet,Set{Tuple{Int,Int}}}
IndexSetHessianPattern{Int,BitSet,Set{Tuple{Int,Int}},false}
}

#==================#
Expand All @@ -9,20 +9,19 @@ const DEFAULT_HESSIAN_TRACER = HessianTracer{

"""
trace_input(T, x)
trace_input(T, x)
trace_input(T, xs)
Enumerates input indices and constructs the specified type `T` of tracer.
Supports [`GradientTracer`](@ref), [`HessianTracer`](@ref) and [`Dual`](@ref).
"""
trace_input(::Type{T}, x) where {T<:Union{AbstractTracer,Dual}} = trace_input(T, x, 1)
trace_input(::Type{T}, xs) where {T<:Union{AbstractTracer,Dual}} = trace_input(T, xs, 1)

function trace_input(::Type{T}, x::Real, i::Integer) where {T<:Union{AbstractTracer,Dual}}
return create_tracer(T, x, i)
end
function trace_input(::Type{T}, xs::AbstractArray, i) where {T<:Union{AbstractTracer,Dual}}
indices = reshape(1:length(xs), size(xs)) .+ (i - 1)
return create_tracer.(T, xs, indices)
is = reshape(1:length(xs), size(xs)) .+ (i - 1)
return create_tracers(T, xs, is)
end
function trace_input(::Type{T}, x::Real, i::Integer) where {T<:Union{AbstractTracer,Dual}}
return only(create_tracers(T, [x], [i]))
end

#=========================#
Expand Down
49 changes: 46 additions & 3 deletions src/overloads/hessian_tracer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ end

function hessian_tracer_1_to_1_inner(
p::P, is_der1_zero::Bool, is_der2_zero::Bool
) where {I,SG,SH,P<:IndexSetHessianPattern{I,SG,SH}}
) where {I,SG,SH,P<:IndexSetHessianPattern{I,SG,SH,false}}
sg = gradient(p)
sh = hessian(p)
sg_out = gradient_tracer_1_to_1_inner(sg, is_der1_zero)
Expand All @@ -22,13 +22,32 @@ function hessian_tracer_1_to_1_inner(
elseif !is_der1_zero && is_der2_zero
sh
elseif is_der1_zero && !is_der2_zero
# TODO: this branch of the code currently isn't tested.
# Covering it would require a scalar 1-to-1 function with local overloads,
# such that ∂f/∂x == 0 and ∂²f/∂x² != 0.
union_product!(myempty(SH), sg, sg)
else
else # !is_der1_zero && !is_der2_zero
union_product!(copy(sh), sg, sg)
end
return P(sg_out, sh_out) # return pattern
end

# NOTE: mutates argument p and should arguably be called `hessian_tracer_1_to_1_inner!`
function hessian_tracer_1_to_1_inner(
p::P, is_der1_zero::Bool, is_der2_zero::Bool
) where {I,SG,SH,P<:IndexSetHessianPattern{I,SG,SH,true}}
sg = gradient(p)
sh = hessian(p)
sg_out = gradient_tracer_1_to_1_inner(sg, is_der1_zero)
# shared Hessian patterns can't remove second-order information, only add to it.
sh_out = if is_der2_zero
sh
else
union_product!(sh, sg, sg)
end
return P(sg_out, sh_out) # return pattern
end

function overload_hessian_1_to_1(M, op)
SCT = SparseConnectivityTracer
return quote
Expand Down Expand Up @@ -96,7 +115,7 @@ function hessian_tracer_2_to_1_inner(
is_der1_arg2_zero::Bool,
is_der2_arg2_zero::Bool,
is_der_cross_zero::Bool,
) where {I,SG,SH,P<:IndexSetHessianPattern{I,SG,SH}}
) where {I,SG,SH,P<:IndexSetHessianPattern{I,SG,SH,false}}
sgx, shx = gradient(px), hessian(px)
sgy, shy = gradient(py), hessian(py)
sg_out = gradient_tracer_2_to_1_inner(sgx, sgy, is_der1_arg1_zero, is_der1_arg2_zero)
Expand All @@ -110,6 +129,30 @@ function hessian_tracer_2_to_1_inner(
return P(sg_out, sh_out) # return pattern
end

# NOTE: mutates arguments px and py and should arguably be called `hessian_tracer_1_to_1_inner!`
function hessian_tracer_2_to_1_inner(
px::P,
py::P,
is_der1_arg1_zero::Bool,
is_der2_arg1_zero::Bool,
is_der1_arg2_zero::Bool,
is_der2_arg2_zero::Bool,
is_der_cross_zero::Bool,
) where {I,SG,SH,P<:IndexSetHessianPattern{I,SG,SH,true}}
sgx, shx = gradient(px), hessian(px)
sgy, shy = gradient(py), hessian(py)

shx !== shy && error("Expected shared Hessians, got $shx, $shy.")
sh_out = shx # union of shx and shy can be skipped since they are the same object
sg_out = gradient_tracer_2_to_1_inner(sgx, sgy, is_der1_arg1_zero, is_der1_arg2_zero)

!is_der2_arg1_zero && union_product!(sh_out, sgx, sgx) # product alpha
!is_der2_arg2_zero && union_product!(sh_out, sgy, sgy) # product beta
!is_der_cross_zero && union_product!(sh_out, sgx, sgy) # cross product 1
!is_der_cross_zero && union_product!(sh_out, sgy, sgx) # cross product 2
return P(sg_out, sh_out) # return pattern
end

function overload_hessian_2_to_1(M, op)
SCT = SparseConnectivityTracer
return quote
Expand Down
11 changes: 10 additions & 1 deletion src/overloads/ifelse_global.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,20 @@
function output_union(px::P, py::P) where {P<:IndexSetGradientPattern}
return P(union(set(px), set(py))) # return pattern
end
function output_union(px::P, py::P) where {P<:IndexSetHessianPattern}
function output_union(
px::P, py::P
) where {I,SG,SH,P<:IndexSetHessianPattern{I,SG,SH,false}} # non-mutating
g_out = union(gradient(px), gradient(py))
h_out = union(hessian(px), hessian(py))
return P(g_out, h_out) # return pattern
end
function output_union(
px::P, py::P
) where {I,SG,SH,P<:IndexSetHessianPattern{I,SG,SH,true}} # mutating
g_out = union(gradient(px), gradient(py))
h_out = union!(hessian(px), hessian(py))
return P(g_out, h_out) # return pattern
end

output_union(tx::AbstractTracer, y) = tx
output_union(x, ty::AbstractTracer) = ty
Expand Down
83 changes: 56 additions & 27 deletions src/patterns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,19 @@ AbstractPattern
"""
abstract type AbstractPattern end

"""
isshared(pattern)
Indicates whether patterns **always** share memory and whether operators are **allowed** to mutate their `AbstractTracer` arguments.
If `false`, patterns **can** share memory and operators are **prohibited** from mutating `AbstractTracer` arguments.
## Note
In practice, memory sharing is limited to second-order information in `AbstractHessianPattern`.
"""
isshared(::P) where {P<:AbstractPattern} = isshared(P)
isshared(::Type{P}) where {P<:AbstractPattern} = false

"""
myempty(T)
myempty(tracer)
Expand All @@ -25,13 +38,11 @@ Constructor for an empty tracer or pattern of type `T` representing a new number
myempty

"""
seed(T, i)
seed(tracer, i)
seed(pattern, i)
create_patterns(P, xs, is)
Constructor for a tracer or pattern of type `T` that only contains the given index `i`.
Convenience constructor for patterns of type `P` for multiple inputs `xs` and their indices `is`.
"""
seed
create_patterns

#==========================#
# Utilities on AbstractSet #
Expand All @@ -49,8 +60,8 @@ product(a::AbstractSet{I}, b::AbstractSet{I}) where {I<:Integer} =
Set((i, j) for i in a, j in b)

function union_product!(
hessian::SH, gradient_x::SG, gradient_y::SG
) where {I<:Integer,SG<:AbstractSet{I},SH<:AbstractSet{Tuple{I,I}}}
hessian::H, gradient_x::G, gradient_y::G
) where {I<:Integer,G<:AbstractSet{I},H<:AbstractSet{Tuple{I,I}}}
hxy = product(gradient_x, gradient_y)
return union!(hessian, hxy)
end
Expand All @@ -69,18 +80,17 @@ For use with [`GradientTracer`](@ref).
## Expected interface
* `myempty(::Type{MyPattern})`: return a pattern representing a new number (usually an empty pattern)
* `seed(::Type{MyPattern}, i::Integer)`: return an pattern that only contains the given index `i`
* `gradient(p::MyPattern)`: return non-zero indices `i` for use with `GradientTracer`
Note that besides their names, the last two functions are usually identical.
* [`myempty`](@ref)
* [`create_patterns`](@ref)
* `gradient(p::MyPattern)`: return non-zero indices `i` in the gradient representation
* [`isshared`](@ref) in case the pattern is shared (mutates). Defaults to false.
"""
abstract type AbstractGradientPattern <: AbstractPattern end

"""
$(TYPEDEF)
Vector sparsity pattern represented by an `AbstractSet` of indices ``{i}`` of non-zero values.
Gradient sparsity pattern represented by an `AbstractSet` of indices ``{i}`` of non-zero values.
## Fields
$(TYPEDFIELDS)
Expand All @@ -97,8 +107,9 @@ Base.show(io::IO, p::IndexSetGradientPattern) = Base.show(io, set(p))
function myempty(::Type{IndexSetGradientPattern{I,S}}) where {I,S}
return IndexSetGradientPattern{I,S}(myempty(S))
end
function seed(::Type{IndexSetGradientPattern{I,S}}, i) where {I,S}
return IndexSetGradientPattern{I,S}(seed(S, i))
function create_patterns(::Type{P}, xs, is) where {I,S,P<:IndexSetGradientPattern{I,S}}
sets = seed.(S, is)
return P.(sets)
end

# Tracer compatibility
Expand All @@ -118,29 +129,47 @@ For use with [`HessianTracer`](@ref).
## Expected interface
* `myempty(::Type{MyPattern})`: return a pattern representing a new number (usually an empty pattern)
* `seed(::Type{MyPattern}, i::Integer)`: return an pattern that only contains the given index `i` in the first-order representation
* [`myempty`](@ref)
* [`create_patterns`](@ref)
* `gradient(p::MyPattern)`: return non-zero indices `i` in the first-order representation
* `hessian(p::MyPattern)`: return non-zero indices `(i, j)` in the second-order representation
* [`isshared`](@ref) in case the pattern is shared (mutates). Defaults to false.
"""
abstract type AbstractHessianPattern <: AbstractPattern end

"""
IndexSetHessianPattern(vector::AbstractGradientPattern, mat::AbstractMatrixPattern)
$(TYPEDEF)
Hessian sparsity pattern represented by:
* an `AbstractSet` of indices ``i`` of non-zero values representing first-order sparsity
* an `AbstractSet` of index tuples ``(i,j)`` of non-zero values representing second-order sparsity
## Fields
$(TYPEDFIELDS)
## Internals
Gradient and Hessian sparsity patterns constructed by combining two AbstractSets.
The last type parameter `shared` is a `Bool` indicating whether the `hessian` field of this object should be shared among all intermediate scalar quantities involved in a function.
"""
struct IndexSetHessianPattern{I<:Integer,SG<:AbstractSet{I},SH<:AbstractSet{Tuple{I,I}}} <:
AbstractHessianPattern
gradient::SG
hessian::SH
struct IndexSetHessianPattern{
I<:Integer,G<:AbstractSet{I},H<:AbstractSet{Tuple{I,I}},shared
} <: AbstractHessianPattern
gradient::G
hessian::H
end
isshared(::Type{IndexSetHessianPattern{I,G,H,true}}) where {I,G,H} = true

function myempty(::Type{IndexSetHessianPattern{I,SG,SH}}) where {I,SG,SH}
return IndexSetHessianPattern{I,SG,SH}(myempty(SG), myempty(SH))
function myempty(::Type{P}) where {I,G,H,S,P<:IndexSetHessianPattern{I,G,H,S}}
return P(myempty(G), myempty(H))
end
function seed(::Type{IndexSetHessianPattern{I,SG,SH}}, index) where {I,SG,SH}
return IndexSetHessianPattern{I,SG,SH}(seed(SG, index), myempty(SH))
function create_patterns(
::Type{P}, xs, is
) where {I,G,H,S,P<:IndexSetHessianPattern{I,G,H,S}}
gradients = seed.(G, is)
hessian = myempty(H)
# Even if `shared=false`, sharing a single reference to `hessian` is allowed upon initialization,
# since mutation is prohibited when `isshared` is false.
return P.(gradients, Ref(hessian))
end

# Tracer compatibility
Expand Down
31 changes: 17 additions & 14 deletions src/tracers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,29 +131,32 @@ end
# Utilities #
#===========#

myempty(::T) where {T<:AbstractTracer} = myempty(T)
# isshared(::Type{T}) where {P,T<:GradientTracer{P}} = isshared(P) # no shared AbstractGradientPattern yet
isshared(::Type{T}) where {P,T<:HessianTracer{P}} = isshared(P)

# myempty(::Type{T}) where {P,T<:AbstractTracer{P}} = T(myempty(P), true) # JET complains about this
myempty(::T) where {T<:AbstractTracer} = myempty(T)
# myempty(::Type{T}) where {P,T<:AbstractTracer{P}} = T(myempty(P), true) # JET complains about this
myempty(::Type{T}) where {P,T<:GradientTracer{P}} = T(myempty(P), true)
myempty(::Type{T}) where {P,T<:HessianTracer{P}} = T(myempty(P), true)

seed(::T, i) where {T<:AbstractTracer} = seed(T, i)

# seed(::Type{T}, i) where {P,T<:AbstractTracer{P}} = T(seed(P, i)) # JET complains about this
seed(::Type{T}, i) where {P,T<:GradientTracer{P}} = T(seed(P, i))
seed(::Type{T}, i) where {P,T<:HessianTracer{P}} = T(seed(P, i))

"""
create_tracer(T, index) where {T<:AbstractTracer}
create_tracers(T, xs, indices)
Convenience constructor for [`GradientTracer`](@ref) and [`HessianTracer`](@ref) from input indices.
Convenience constructor for [`GradientTracer`](@ref), [`HessianTracer`](@ref) and [`Dual`](@ref)
from multiple inputs `xs` and their indices `is`.
"""
function create_tracer(::Type{T}, ::Real, index::Integer) where {P,T<:AbstractTracer{P}}
return T(seed(P, index))
function create_tracers(
::Type{T}, xs::AbstractArray{<:Real,N}, indices::AbstractArray{<:Integer,N}
) where {P<:AbstractPattern,T<:AbstractTracer{P},N}
patterns = create_patterns(P, xs, indices)
return T.(patterns)
end

function create_tracer(::Type{Dual{P,T}}, primal::Real, index::Integer) where {P,T}
return Dual(primal, create_tracer(T, primal, index))
function create_tracers(
::Type{D}, xs::AbstractArray{<:Real,N}, indices::AbstractArray{<:Integer,N}
) where {P,T,D<:Dual{P,T},N}
tracers = create_tracers(T, xs, indices)
return D.(xs, tracers)
end

# Pretty-printing of Dual tracers
Expand Down
2 changes: 1 addition & 1 deletion test/brusselator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ using SparseConnectivityTracer: DuplicateVector, RecursiveSet, SortedVector
using SparseConnectivityTracerBenchmarks.ODE: Brusselator!
using Test

# Load definitions of GRADIENT_TRACERS and HESSIAN_TRACERS
# Load definitions of GRADIENT_TRACERS, GRADIENT_PATTERNS, HESSIAN_TRACERS and HESSIAN_PATTERNS
include("tracers_definitions.jl")

function test_brusselator(method::AbstractSparsityDetector)
Expand Down
2 changes: 1 addition & 1 deletion test/flux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ using SparseConnectivityTracer
using SparseConnectivityTracer: DuplicateVector, RecursiveSet, SortedVector
using Test

# Load definitions of GRADIENT_TRACERS and HESSIAN_TRACERS
# Load definitions of GRADIENT_TRACERS, GRADIENT_PATTERNS, HESSIAN_TRACERS and HESSIAN_PATTERNS
include("tracers_definitions.jl")

const INPUT_FLUX = reshape(
Expand Down
2 changes: 1 addition & 1 deletion test/test_constructors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using SparseConnectivityTracer: primal, tracer, isemptytracer
using SparseConnectivityTracer: myempty, name
using Test

# Load definitions of GRADIENT_TRACERS and HESSIAN_TRACERS
# Load definitions of GRADIENT_TRACERS, GRADIENT_PATTERNS, HESSIAN_TRACERS and HESSIAN_PATTERNS
include("tracers_definitions.jl")

function test_nested_duals(::Type{T}) where {T<:AbstractTracer}
Expand Down
Loading

0 comments on commit c0bf9d0

Please sign in to comment.