Skip to content

Commit

Permalink
Rework conv interface and rework its internals
Browse files Browse the repository at this point in the history
* add `:algorithm` kwarg
* add `conv!`
  • Loading branch information
martinholters committed Mar 1, 2024
1 parent 873abb8 commit 969a8a1
Show file tree
Hide file tree
Showing 3 changed files with 131 additions and 113 deletions.
2 changes: 1 addition & 1 deletion src/DSP.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ using LinearAlgebra: mul!, rmul!
using IterTools: subsets
using Compat: Compat

export conv, deconv, filt, filt!, xcorr
export conv, conv!, deconv, filt, filt!, xcorr

# This function has methods added in `periodograms` but is not exported,
# so we define it here so one can do `DSP.allocate_output` instead of
Expand Down
199 changes: 92 additions & 107 deletions src/dspbase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -483,16 +483,16 @@ end
# Assumes u is larger than, or the same size as, v
# nfft should be greater than or equal to 2*sv-1
function unsafe_conv_kern_os!(out,
output_indices,
u::AbstractArray{<:Any, N},
v,
su,
sv,
sout,
nffts) where N
sout = size(out)
su = size(u)
sv = size(v)
u_start = first.(axes(u))
out_axes = axes(out)
out_start = first.(out_axes)
out_stop = last.(out_axes)
out_start = Tuple(first(output_indices))
out_stop = Tuple(last(output_indices))
ideal_save_blocksize = nffts .- sv .+ 1
# Number of samples that are "missing" if the output is smaller than the
# valid portion of the convolution
Expand All @@ -502,7 +502,7 @@ function unsafe_conv_kern_os!(out,
nblocks = cld.(sout, save_blocksize)

# Pre-allocation
tdbuff, fdbuff, p, ip = os_prepare_conv(u, nffts)
tdbuff, fdbuff, p, ip = os_prepare_conv(out, nffts)
tdbuff_axes = axes(tdbuff)

# Transform the smaller filter
Expand Down Expand Up @@ -603,148 +603,133 @@ function unsafe_conv_kern_os!(out,
out
end

function _conv_kern_fft!(out,
u::AbstractArray{T, N},
v::AbstractArray{T, N},
su,
sv,
outsize,
nffts) where {T<:Real, N}
padded = _zeropad(u, nffts)
function _conv_kern_fft!(out::AbstractArray{T, N},
output_indices,
u::AbstractArray{<:Real, N},
v::AbstractArray{<:Real, N}) where {T<:Real, N}
outsize = size(output_indices)
nffts = nextfastfft(outsize)
padded = _zeropad!(similar(u, T, nffts), u)
p = plan_rfft(padded)
uf = p * padded
_zeropad!(padded, v)
vf = p * padded
uf .*= vf
raw_out = irfft(uf, nffts[1])
copyto!(out,
CartesianIndices(out),
output_indices,
raw_out,
CartesianIndices(UnitRange.(1, outsize)))
end
function _conv_kern_fft!(out, u, v, su, sv, outsize, nffts)
upad = _zeropad(u, nffts)
vpad = _zeropad(v, nffts)
function _conv_kern_fft!(out::AbstractArray{T}, output_indices, u, v) where {T}
outsize = size(output_indices)
nffts = nextfastfft(outsize)
upad = _zeropad!(similar(u, T, nffts), u)
vpad = _zeropad!(similar(v, T, nffts), v)
p! = plan_fft!(upad)
ip! = inv(p!)
p! * upad # Operates in place on upad
p! * vpad
upad .*= vpad
ip! * upad
copyto!(out,
CartesianIndices(out),
output_indices,
upad,
CartesianIndices(UnitRange.(1, outsize)))
end

# v should be smaller than u for good performance
function _conv_fft!(out, u, v, su, sv, outsize)
os_nffts = map(optimalfftfiltlength, sv, su)
if any(os_nffts .< outsize)
unsafe_conv_kern_os!(out, u, v, su, sv, outsize, os_nffts)
else
nffts = nextfastfft(outsize)
_conv_kern_fft!(out, u, v, su, sv, outsize, nffts)
function _conv_td!(out, output_indices, u::AbstractArray{<:Number, N}, v::AbstractArray{<:Number, N}) where {N}
index_offset = first(CartesianIndices(u)) + first(CartesianIndices(v)) - first(output_indices)
checkbounds(out, output_indices)
fill!(out, zero(eltype(out)))
for m in CartesianIndices(u), n in CartesianIndices(v)
@inbounds out[n+m - index_offset] = muladd(u[m], v[n], out[n+m - index_offset])
end
return out
end

const FFTTypes = Union{Float32, Float64, ComplexF32, ComplexF64}

function conv!(
out::AbstractArray{T, N},
u::AbstractArray{<:Number, N},
v::AbstractArray{<:Number, N};
algorithm=:auto
) where {T<:Number, N}
calc_index_offset(ao::Base.OneTo, au::Base.OneTo, av::Base.OneTo) = 1
calc_index_offset(ao::Base.OneTo, au, av) = # first(au) + first(av) - 1

Check warning on line 660 in src/dspbase.jl

View check run for this annotation

Codecov / codecov/patch

src/dspbase.jl#L660

Added line #L660 was not covered by tests
throw(ArgumentError("output must have offset axes if the input has"))
calc_index_offset(ao, au::Base.OneTo, av::Base.OneTo) = # 2

Check warning on line 662 in src/dspbase.jl

View check run for this annotation

Codecov / codecov/patch

src/dspbase.jl#L662

Added line #L662 was not covered by tests
throw(ArgumentError("output must not have offset axes if none of the inputs has"))
calc_index_offset(ao, au, av) = 0
output_indices = CartesianIndices(map(axes(out), axes(u), axes(v)) do ao, au, av
return (first(au)+first(av) : last(au)+last(av)) .- calc_index_offset(ao, au, av)
end)

# For arrays with weird offsets
function _conv_similar(u, outsize, axesu, axesv)
out_offsets = first.(axesu) .+ first.(axesv)
out_axes = UnitRange.(out_offsets, out_offsets .+ outsize .- 1)
similar(u, out_axes)
end
function _conv_similar(
u, outsize, ::NTuple{<:Any, Base.OneTo{Int}}, ::NTuple{<:Any, Base.OneTo{Int}}
)
similar(u, outsize)
end
_conv_similar(u, v, outsize) = _conv_similar(u, outsize, axes(u), axes(v))

# Does convolution, will not switch argument order
function _conv!(out, u, v, su, sv, outsize)
# TODO: Add spatial / time domain algorithm
_conv_fft!(out, u, v, su, sv, outsize)
end

# Does convolution, will not switch argument order
function _conv(u, v, su, sv)
outsize = su .+ sv .- 1
out = _conv_similar(u, v, outsize)
_conv!(out, u, v, su, sv, outsize)
end

function _conv_td(u::AbstractArray{<:Number, N}, v::AbstractArray{<:Number, N}) where {N}
output_indices = CartesianIndices(map(axes(u), axes(v)) do au, av
r = (first(au)+first(av)):(last(au)+last(av))
if au isa Base.OneTo && av isa Base.OneTo
return r
if algorithm===:auto
algorithm = T <: FFTTypes ? :fast : :direct
end
if algorithm===:fast
if length(u) * length(v) < 2^16 # TODO: better heuristic
algorithm = :direct
else
return Base.IdentityUnitRange(r)
algorithm = :fft
end
end)
return [
sum(u[m] * v[n-m]
for m in CartesianIndices(ntuple(Val(N)) do d
max(firstindex(u,d),n[d]-lastindex(v,d)):min(lastindex(u,d), n[d]-firstindex(v,d))
end)
)
for n in output_indices
]
end
if algorithm===:direct
return _conv_td!(out, output_indices, u, v)
else
if output_indices != CartesianIndices(out)
fill!(out, zero(eltype(out)))
end
os_nffts = length(u) >= length(v) ? map(optimalfftfiltlength, size(v), size(u)) : map(optimalfftfiltlength, size(u), size(v))
if algorithm===:fft
if any(os_nffts .< size(output_indices))
algorithm = :fft_overlapsave

Check warning on line 688 in src/dspbase.jl

View check run for this annotation

Codecov / codecov/patch

src/dspbase.jl#L688

Added line #L688 was not covered by tests
else
algorithm = :fft_simple
end
end
if algorithm === :fft_overlapsave
# v should be smaller than u for good performance
if length(u) >= length(v)
return unsafe_conv_kern_os!(out, output_indices, u, v, os_nffts)
else
return unsafe_conv_kern_os!(out, output_indices, v, u, os_nffts)
end
elseif algorithm === :fft_simple
return _conv_kern_fft!(out, output_indices, u, v)
else
throw(ArgumentError("algorithm must be :auto, :fast, :direct, :fft, :fft_simple, or :fft_overlapsave"))

Check warning on line 703 in src/dspbase.jl

View check run for this annotation

Codecov / codecov/patch

src/dspbase.jl#L703

Added line #L703 was not covered by tests
end
end
end

# We use this type definition for clarity
const RealOrComplexFloat = Union{AbstractFloat, Complex{T} where T<:AbstractFloat}

# May switch argument order
"""
conv(u,v)
Convolution of two arrays. Uses either FFT convolution or overlap-save,
depending on the size of the input. `u` and `v` can be N-dimensional arrays,
with arbitrary indexing offsets, but their axes must be a `UnitRange`.
"""
function conv(u::AbstractArray{T, N},
v::AbstractArray{T, N}) where {T<:RealOrComplexFloat, N}
su = size(u)
sv = size(v)
if length(u) >= length(v)
_conv(u, v, su, sv)
else
_conv(v, u, sv, su)
end
end

function conv(u::AbstractArray{<:RealOrComplexFloat, N},
v::AbstractArray{<:RealOrComplexFloat, N}) where N
fu, fv = promote(u, v)
conv(fu, fv)
end

conv(u::AbstractArray{<:Integer, N}, v::AbstractArray{<:Integer, N}) where {N} =
_conv_td(u, v)

conv(u::AbstractArray{<:Number, N}, v::AbstractArray{<:Number, N}) where {N} =
conv(float(u), float(v))

function conv(u::AbstractArray{<:Number, N},
v::AbstractArray{<:RealOrComplexFloat, N}) where N
conv(float(u), v)
end

function conv(u::AbstractArray{<:RealOrComplexFloat, N},
v::AbstractArray{<:Number, N}) where N
conv(u, float(v))
function conv(
u::AbstractArray{Tu, N}, v::AbstractArray{Tv, N}; kwargs...
) where {Tu<:Number, Tv<:Number, N}
T = promote_type(Tu, Tv)
out_axis(au, av) = (first(au)+first(av)):(last(au)+last(av))
out_axis(au::Base.OneTo, av::Base.OneTo) = Base.OneTo(last(au) + last(av) - 1)
out_axes = map(out_axis, axes(u), axes(v))
out = similar(u, T, out_axes)
return conv!(out, u, v; kwargs...)
end

function conv(A::AbstractArray{<:Number, M},
B::AbstractArray{<:Number, N}) where {M, N}
B::AbstractArray{<:Number, N}; kwargs...) where {M, N}
if (M < N)
conv(cat(A, dims=N)::AbstractArray{eltype(A), N}, B)
conv(cat(A, dims=N)::AbstractArray{eltype(A), N}, B; kwargs...)

Check warning on line 729 in src/dspbase.jl

View check run for this annotation

Codecov / codecov/patch

src/dspbase.jl#L729

Added line #L729 was not covered by tests
else
@assert M > N
conv(A, cat(B, dims=M)::AbstractArray{eltype(B), M})
conv(A, cat(B, dims=M)::AbstractArray{eltype(B), M}; kwargs...)
end
end

Expand Down
43 changes: 38 additions & 5 deletions test/dsp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# TODO: parameterize conv tests
using Test, OffsetArrays
using DSP: filt, filt!, deconv, conv, xcorr,
optimalfftfiltlength, unsafe_conv_kern_os!, _conv_kern_fft!, _conv_similar,
optimalfftfiltlength, unsafe_conv_kern_os!, _conv_kern_fft!
nextfastfft


Expand Down Expand Up @@ -59,6 +59,10 @@ end
@test conv(f32a, b) fexp
@test conv(fb, a) fexp

u = rand(190)
v = rand(200)
@test conv(u, v; algorithm=:direct) conv(u, v; algorithm=:fft_simple) conv(u, v; algorithm=:fft_overlapsave)

# issue #410
n = 314159265
@test conv([n], [n]) == [n^2]
Expand All @@ -69,6 +73,32 @@ end
offset_arr_f = OffsetVector{Float64}(undef, -1:2)
offset_arr_f[:] = fa
@test conv(offset_arr_f, 1:3) OffsetVector(fexp, 0:5)

for M in [10, 200], N in [10, 200], T in [Float64, ComplexF64]
u = rand(T, M)
v = rand(T, N)
u_off = OffsetVector(u, 23)
v_off = OffsetVector(v, -42)
@test conv(u, v; algorithm=:direct) conv(u, v; algorithm=:fft_simple) conv(u, v; algorithm=:fft_overlapsave)
@test conv(u_off, v_off; algorithm=:direct) conv(u_off, v_off; algorithm=:fft_simple) conv(u_off, v_off; algorithm=:fft_overlapsave)
@test conv(u, v) == conv(u_off, v_off)[23-42+2:23-42+N+M]

for algorithm in [:direct, :fft_simple, :fft_overlapsave]
# pre-allocated non-offset output larger than necessary
out = ones(T, M+N+10)
conv!(out, u, v; algorithm)
@test out[1:M+N-1] == conv(u, v; algorithm)
@test all(iszero, out[M+N:end])

# pre-allocated output with offset larger than necessary
out = OffsetVector(ones(T, M+N+10), 23-42-5)
conv!(out, u_off, v_off; algorithm)
@test out[23-42+2:23-42+N+M] == conv(u, v; algorithm)
@test all(iszero, out[begin:23-42+1])
@test all(iszero, out[23-42+N+M+1:end])
end
end

# Issue #352
@test conv([1//2, 1//3, 1//4], [1, 2]) [1//2, 4//3, 11//12, 1//2]
# Non-numerical arrays should not be convolved
Expand Down Expand Up @@ -113,6 +143,10 @@ end
@test conv(f32a, b) fexp
@test conv(fb, a) fexp

u = rand(10, 20)
v = rand(10, 10)
@test conv(u, v; algorithm=:direct) conv(u, v; algorithm=:fft)

offset_arr = OffsetMatrix{Int}(undef, -1:1, -1:1)
offset_arr[:] = a
@test conv(offset_arr, b) == OffsetArray(expectation, 0:3, 0:3)
Expand Down Expand Up @@ -197,11 +231,10 @@ end
su, u = os_test_data(T, nu, N)
sv, v = os_test_data(T, nv, N)
sout = su .+ sv .- 1
out = _conv_similar(u, sout, axes(u), axes(v))
unsafe_conv_kern_os!(out, u, v, su, sv, sout, nffts)
out = similar(u, T, sout)
unsafe_conv_kern_os!(out, CartesianIndices(out), u, v, nffts)
os_out = copy(out)
fft_nfft = nextfastfft(sout)
_conv_kern_fft!(out, u, v, su, sv, sout, fft_nfft)
_conv_kern_fft!(out, CartesianIndices(out), u, v)
@test out os_out
end
Ns = [1, 2, 3]
Expand Down

0 comments on commit 969a8a1

Please sign in to comment.