diff --git a/Project.toml b/Project.toml index 55c217d5..84538c95 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Ripserer" uuid = "aa79e827-bd0b-42a8-9f10-2b302677a641" authors = ["mtsch "] -version = "0.14.1" +version = "0.14.2" [deps] Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" diff --git a/src/abstractfiltration.jl b/src/abstractfiltration.jl index 7c726efa..d14929c8 100644 --- a/src/abstractfiltration.jl +++ b/src/abstractfiltration.jl @@ -157,17 +157,20 @@ computed. Defaults to sorting the diagram. """ postprocess_diagram(::AbstractFiltration, diagram) = sort!(diagram) -function interval_meta_type(flt::AbstractFiltration, dim, reps, field) +# Vals everywhere so compiler computes this at compile time. +@inline function interval_type( + flt::AbstractFiltration, ::Val{dim}, ::Val{reps}, ::Type{F} +) where {dim, reps, F} if reps - return @NamedTuple begin + return PersistenceInterval{@NamedTuple begin birth_simplex::simplex_type(flt, dim) death_simplex::Union{simplex_type(flt, dim + 1), Nothing} - representative::Vector{chain_element_type(simplex_type(flt, dim), field)} - end + representative::Vector{chain_element_type(simplex_type(flt, dim), F)} + end} else - return @NamedTuple begin + return PersistenceInterval{@NamedTuple begin birth_simplex::simplex_type(flt, dim) death_simplex::Union{simplex_type(flt, dim + 1), Nothing} - end + end} end end diff --git a/src/reductionmatrix.jl b/src/reductionmatrix.jl index 3bb9d89b..59f7e93f 100644 --- a/src/reductionmatrix.jl +++ b/src/reductionmatrix.jl @@ -309,8 +309,8 @@ function birth_death(::ReductionMatrix{false}, column, pivot) end function add_interval!( - intervals, matrix::ReductionMatrix, column, pivot, cutoff, reps -) + intervals, matrix::ReductionMatrix, column, pivot, cutoff, ::Val{reps} +) where reps birth_time, birth_sx, death_time, death_sx = birth_death(matrix, column, pivot) if death_time - birth_time > cutoff if reps && is_cohomology(matrix) @@ -325,13 +325,14 @@ function add_interval!( else rep = NamedTuple() end - int = PersistenceInterval( - birth_time, death_time; + meta = (; birth_simplex=birth_sx, death_simplex=death_sx, rep..., ) - !isnothing(int) && push!(intervals, int) + push!(intervals, PersistenceInterval( + birth_time, death_time, meta + )) end end @@ -344,13 +345,13 @@ function compute_intervals!( desc="Computing $(dim(matrix))d intervals... ", ) end - intervals = PersistenceInterval{ - interval_meta_type(matrix.filtration, dim(matrix), reps, field_type(matrix)) - }[] + intervals = interval_type( + matrix.filtration, Val(dim(matrix)), Val(reps), field_type(matrix) + )[] for column in matrix.columns_to_reduce pivot = reduce_column!(matrix, column) - add_interval!(intervals, matrix, column, pivot, cutoff, reps) + add_interval!(intervals, matrix, column, pivot, cutoff, Val(reps)) progress && next!(progbar; showvalues=((:n_intervals, length(intervals)),)) end thresh=Float64(threshold(matrix.filtration)) diff --git a/src/zerodimensional.jl b/src/zerodimensional.jl index a6833a66..536126f1 100644 --- a/src/zerodimensional.jl +++ b/src/zerodimensional.jl @@ -78,8 +78,8 @@ end birth(dset::DisjointSetsWithBirth, i) = dset.births[i] function add_interval!( - intervals, dset::DisjointSetsWithBirth, filtration, vertex, edge, cutoff, reps -) + intervals, dset::DisjointSetsWithBirth, filtration, vertex, edge, cutoff, ::Val{reps} +) where reps birth_time, birth_vertex = birth(dset, vertex) death_time = isnothing(edge) ? Inf : birth(edge) if death_time - birth_time > cutoff @@ -116,7 +116,7 @@ function zeroth_intervals( CE = chain_element_type(V, F) dset = DisjointSetsWithBirth(vertices(filtration), birth(filtration)) - intervals = PersistenceInterval{interval_meta_type(filtration, 0, reps, F)}[] + intervals = interval_type(filtration, Val(0), Val(reps), F)[] to_skip = edge_type(filtration)[] to_reduce = edge_type(filtration)[] @@ -134,7 +134,7 @@ function zeroth_intervals( if i ≠ j # According to the elder rule, the vertex with the higer birth will die first. last_vertex = birth(dset, i) > birth(dset, j) ? i : j - add_interval!(intervals, dset, filtration, last_vertex, edge, cutoff, reps) + add_interval!(intervals, dset, filtration, last_vertex, edge, cutoff, Val(reps)) union!(dset, i, j) push!(to_skip, edge) @@ -145,7 +145,7 @@ function zeroth_intervals( end for v in vertices(filtration) if find_root!(dset, v) == v && !isnothing(simplex(filtration, Val(0), (v,), 1)) - add_interval!(intervals, dset, filtration, v, nothing, cutoff, reps) + add_interval!(intervals, dset, filtration, v, nothing, cutoff, Val(reps)) end progress && next!(progbar; showvalues=((:n_intervals, length(intervals)),)) end