Skip to content

Commit

Permalink
Fix failing benchmarks (#112)
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle authored May 31, 2024
1 parent 683b5a6 commit 7181de4
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 27 deletions.
16 changes: 10 additions & 6 deletions src/overload_hessian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ function hessian_tracer_1_to_1(
elseif !is_firstder_zero && is_secondder_zero
sh
elseif is_firstder_zero && !is_secondder_zero
product(sg, sg)
union_product!(myempty(SH), sg, sg)
else
union_product(sh, sg, sg)
union_product!(copy(sh), sg, sg)
end
return (sg_out, sh_out)
end
Expand Down Expand Up @@ -254,22 +254,26 @@ end
## Exponent (requires extra types)
for S in (Integer, Rational, Irrational{:ℯ})
function Base.:^(tx::T, y::S) where {T<:HessianTracer}
return T(gradient(tx), union_product(hessian(tx), gradient(tx), gradient(tx)))
return T(
gradient(tx), union_product!(copy(hessian(tx)), gradient(tx), gradient(tx))
)
end
function Base.:^(x::S, ty::T) where {T<:HessianTracer}
return T(gradient(ty), union_product(hessian(ty), gradient(ty), gradient(ty)))
return T(
gradient(ty), union_product!(copy(hessian(ty)), gradient(ty), gradient(ty))
)
end

function Base.:^(dx::D, y::S) where {P,T<:HessianTracer,D<:Dual{P,T}}
return Dual(
primal(dx)^y,
T(gradient(dx), union_product(hessian(dx), gradient(dx), gradient(dx))),
T(gradient(dx), union_product!(copy(hessian(dx)), gradient(dx), gradient(dx))),
)
end
function Base.:^(x::S, dy::D) where {P,T<:HessianTracer,D<:Dual{P,T}}
return Dual(
x^primal(dy),
T(gradient(dy), union_product(hessian(dy), gradient(dy), gradient(dy))),
T(gradient(dy), union_product!(copy(hessian(dy)), gradient(dy), gradient(dy))),
)
end
end
Expand Down
10 changes: 7 additions & 3 deletions src/settypes/duplicatevector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@ struct DuplicateVector{T} <: AbstractSet{T}
DuplicateVector{T}() where {T} = new{T}(T[])
end

function Base.show(io::IO, dv::DuplicateVector)
Base.show(io::IO, dv::DuplicateVector) = print(io, "DuplicateVector($(dv.data))")

function Base.show(io::IO, ::MIME"text/plain", dv::DuplicateVector)
return print(io, "DuplicateVector($(dv.data))")
end

Base.eltype(::Type{DuplicateVector{T}}) where {T} = T

Base.collect(dv::DuplicateVector) = unique!(dv.data)
Base.length(dv::DuplicateVector) = length(collect(dv)) # TODO: slow
Base.copy(dv::DuplicateVector{T}) where {T} = DuplicateVector{T}(dv.data)

function Base.union!(a::S, b::S) where {S<:DuplicateVector}
append!(a.data, b.data)
Expand All @@ -28,6 +30,8 @@ function Base.union(a::S, b::S) where {S<:DuplicateVector}
return S(vcat(a.data, b.data))
end

Base.collect(dv::DuplicateVector) = unique!(dv.data)

Base.iterate(dv::DuplicateVector) = iterate(collect(dv))
Base.iterate(dv::DuplicateVector, i::Integer) = iterate(collect(dv), i)

Expand Down
25 changes: 19 additions & 6 deletions src/settypes/recursiveset.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,20 @@ function Base.show(io::IO, rs::RecursiveSet)
return print_recursiveset(io, rs; offset=0)
end

function Base.show(io::IO, ::MIME"text/plain", rs::RecursiveSet)
return print_recursiveset(io, rs; offset=0)
end

Base.eltype(::Type{RecursiveSet{T}}) where {T} = T
Base.length(rs::RecursiveSet) = length(collect(rs)) # TODO: slow

function Base.copy(rs::RecursiveSet{T}) where {T}
if !isnothing(rs.s)
return RecursiveSet{T}(copy(rs.s))
else
return RecursiveSet{T}(rs.child1, rs.child2)
end
end

function Base.union(rs1::RecursiveSet{T}, rs2::RecursiveSet{T}) where {T}
return RecursiveSet{T}(rs1, rs2)
Expand All @@ -58,12 +71,6 @@ function Base.union!(rs1::RecursiveSet{T}, rs2::RecursiveSet{T}) where {T}
return rs1
end

function Base.collect(rs::RecursiveSet{T}) where {T}
accumulator = Set{T}()
collect_aux!(accumulator, rs)
return collect(accumulator)
end

function collect_aux!(accumulator::Set{T}, rs::RecursiveSet{T})::Nothing where {T}
if !isnothing(rs.s)
union!(accumulator, rs.s::Set{T})
Expand All @@ -74,6 +81,12 @@ function collect_aux!(accumulator::Set{T}, rs::RecursiveSet{T})::Nothing where {
return nothing
end

function Base.collect(rs::RecursiveSet{T}) where {T}
accumulator = Set{T}()
collect_aux!(accumulator, rs)
return collect(accumulator)
end

Base.iterate(rs::RecursiveSet) = iterate(collect(rs))
Base.iterate(rs::RecursiveSet, i::Integer) = iterate(collect(rs), i)

Expand Down
14 changes: 10 additions & 4 deletions src/settypes/sortedvector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,16 @@ function Base.convert(::Type{SortedVector{T}}, v::Vector{T}) where {T}
return SortedVector{T}(v; sorted=false)
end

Base.length(v::SortedVector) = length(v.data)
Base.size(v::SortedVector) = size(v.data)
Base.getindex(v::SortedVector, i) = v.data[i]
Base.IndexStyle(::Type{SortedVector{T}}) where {T} = IndexStyle(Vector{T})
Base.show(io::IO, v::SortedVector) = print(io, "SortedVector($(v.data))")

function Base.show(io::IO, ::MIME"text/plain", dv::SortedVector)
return print(io, "SortedVector($(dv.data))")
end

Base.eltype(::Type{SortedVector{T}}) where {T} = T
Base.length(v::SortedVector) = length(v.data)
Base.copy(v::SortedVector{T}) where {T} = SortedVector{T}(copy(v.data); sorted=true)

function merge_sorted!(result::Vector{T}, left::Vector{T}, right::Vector{T}) where {T}
resize!(result, length(left) + length(right))
left_index, right_index, result_index = 1, 1, 1
Expand Down Expand Up @@ -83,6 +87,8 @@ function Base.union!(v1::SortedVector{T}, v2::SortedVector{T}) where {T}
return v1
end

Base.collect(v::SortedVector) = v.data

Base.iterate(v::SortedVector) = iterate(v.data)
Base.iterate(v::SortedVector, i::Integer) = iterate(v.data, i)

Expand Down
6 changes: 0 additions & 6 deletions src/tracers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,6 @@ end

product(a::AbstractSet{I}, b::AbstractSet{I}) where {I} = Set((i, j) for i in a, j in b)

function union_product(
sh::SH, sgx::SG, sgy::SG
) where {I,SG<:AbstractSet{I},SH<:AbstractSet{Tuple{I,I}}}
return clever_union(sh, product(sgx, sgy))
end

function union_product!(
sh::SH, sgx::SG, sgy::SG
) where {I,SG<:AbstractSet{I},SH<:AbstractSet{Tuple{I,I}}}
Expand Down
8 changes: 6 additions & 2 deletions test/settypes/correctness.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,16 @@ using SparseConnectivityTracer
using SparseConnectivityTracer: DuplicateVector, RecursiveSet, SortedVector, product
using Test

@testset "$(typeof(S))" for S in (
BitSet, Set{Int64}, DuplicateVector{Int64}, RecursiveSet{Int64}, SortedVector{Int64}
@testset "$S" for S in (
BitSet, Set{Int}, DuplicateVector{Int}, RecursiveSet{Int}, SortedVector{Int}
)
x = S.(1:10)
y = (x[1] x[3]) (x[3] ((x[5] x[7]) x[1]))

@test length(string(x)) > 0
@test eltype(y) == Int
@test length(y) == 4
@test sort(collect(y)) == [1, 3, 5, 7]
@test sort(collect(copy(y))) == [1, 3, 5, 7]
@test length(collect(product(y, y))) == 16
end

0 comments on commit 7181de4

Please sign in to comment.