Skip to content

Commit

Permalink
Merge pull request #7 from heltonmc/cfmadd
Browse files Browse the repository at this point in the history
Add more general support for arithmetic with complex vectors and complex or real scalars
  • Loading branch information
heltonmc authored Apr 10, 2023
2 parents b29e44a + e329036 commit c953bdd
Show file tree
Hide file tree
Showing 4 changed files with 132 additions and 167 deletions.
48 changes: 34 additions & 14 deletions src/complex.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,41 +21,61 @@ end
i = fmul(x.im, y.data)
return ComplexVec(r, i)
end
# Real vector times complex vector. Multiplication commutes, so the call is
# forwarded to the method that takes the complex vector first.
# Note: `FloatTypes` in this signature is a method type parameter declared in
# the `where` clause.
@inline function fmul(x::Vec{N, FloatTypes}, y::ComplexVec{N, FloatTypes}) where {N, FloatTypes}
    return fmul(y, x)
end

# Complex add / subtract
#
# Generates `fadd` and `fsub` for two operand layouts:
#   * ComplexVec (op) ComplexVec - the operation is applied to the `re` fields and
#     to the `im` fields separately.
#   * ComplexVec (op) Vec        - the operation is applied to `x.re` and `y.data`
#     only; `x.im` is passed through unchanged.
for f in (:fadd, :fsub)
    @eval begin
        @inline function $f(x::ComplexVec{N, FloatTypes}, y::ComplexVec{N, FloatTypes}) where {N, FloatTypes}
            # locals are named `r`/`i` so that `Base.im` is not shadowed
            r = $f(x.re, y.re)
            i = $f(x.im, y.im)
            return ComplexVec(r, i)
        end
        @inline function $f(x::ComplexVec{N, FloatTypes}, y::Vec{N, FloatTypes}) where {N, FloatTypes}
            r = $f(x.re, y.data)
            return ComplexVec(r, x.im)
        end
    end
end

# Argument symmetry: addition commutes, so a real vector plus a complex vector is
# forwarded to the method that takes the complex vector first.
@inline function fadd(x::Vec{N, FloatTypes}, y::ComplexVec{N, FloatTypes}) where {N, FloatTypes}
    return fadd(y, x)
end

# Real vector minus complex vector. Subtraction does not commute, so the operands
# are not swapped: the real part is `x.data - y.re`, and the imaginary part is the
# negation of `y.im`.
@inline fsub(x::Vec{N, FloatTypes}, y::ComplexVec{N, FloatTypes}) where {N, FloatTypes} =
    ComplexVec(fsub(x.data, y.re), fneg(y.im))

# Scalar operands for fmul / fadd / fsub: each scalar is widened to a constant
# vector and the call is re-dispatched with the operands in their original order.
#
# Real-vector-first (Vec, ComplexVec) methods are intentionally NOT generated here
# by delegating to the swapped call `f(y, x)`: swapping is only valid for the
# commutative `fmul`/`fadd`; for `fsub` it would compute `y - x`. Those methods
# are defined individually next to the other vector methods in this file.
for f in (:fmul, :fadd, :fsub)
    # promote complex numbers to constant complex vectors
    @eval @inline $f(x::Complex{T}, y::ComplexVec{N, T}) where {N, T <: FloatTypes} = $f(promote(x, y)...)
    @eval @inline $f(x::ComplexVec{N, T}, y::Complex{T}) where {N, T <: FloatTypes} = $f(promote(x, y)...)
    @eval @inline $f(x::Complex{T}, y::Vec{N, T}) where {N, T <: FloatTypes} = $f(convert(ComplexVec{N, T}, x), y)
    @eval @inline $f(x::Vec{N, T}, y::Complex{T}) where {N, T <: FloatTypes} = $f(x, convert(ComplexVec{N, T}, y))

    # promote real numbers to constant real vectors
    @eval @inline $f(x::T, y::ComplexVec{N, T}) where {N, T <: FloatTypes} = $f(convert(Vec{N, T}, x), y)
    @eval @inline $f(x::ComplexVec{N, T}, y::T) where {N, T <: FloatTypes} = $f(x, convert(Vec{N, T}, y))
end

# Negate a complex vector by negating its `re` and `im` fields independently.
@inline function fneg(x::ComplexVec{N, T}) where {N, T <: FloatTypes}
    neg_re = fneg(x.re)
    neg_im = fneg(x.im)
    return ComplexVec{N, T}(neg_re, neg_im)
end

# complex multiply-add
# x*y + z, composed from `fmul` followed by `fadd`
@inline function fmadd(x::ComplexorRealVec{N, T}, y::ComplexorRealVec{N, T}, z::ComplexorRealVec{N, T}) where {N, T <: FloatTypes}
    xy = fmul(x, y)
    return fadd(xy, z)
end

# untyped fallback: same composition for any other operand combination
@inline function fmadd(x, y, z)
    xy = fmul(x, y)
    return fadd(xy, z)
end

# complex multiply-subtract
# x*y - z, composed from `fmul` followed by `fsub`
@inline function fmsub(x::ComplexorRealVec{N, T}, y::ComplexorRealVec{N, T}, z::ComplexorRealVec{N, T}) where {N, T <: FloatTypes}
    xy = fmul(x, y)
    return fsub(xy, z)
end

# untyped fallback: same composition for any other operand combination
@inline function fmsub(x, y, z)
    xy = fmul(x, y)
    return fsub(xy, z)
end

# complex negated multiply-add
# -x*y + z, evaluated as z - (x*y)
@inline function fnmadd(x::ComplexorRealVec{N, T}, y::ComplexorRealVec{N, T}, z::ComplexorRealVec{N, T}) where {N, T <: FloatTypes}
    xy = fmul(x, y)
    return fsub(z, xy)
end

# untyped fallback: same composition for any other operand combination
@inline function fnmadd(x, y, z)
    xy = fmul(x, y)
    return fsub(z, xy)
end

# complex negated multiply-subtract
# -x*y - z, evaluated as -(x*y + z)
@inline function fnmsub(x::ComplexorRealVec{N, T}, y::ComplexorRealVec{N, T}, z::ComplexorRealVec{N, T}) where {N, T <: FloatTypes}
    acc = fmadd(x, y, z)
    neg_re, neg_im = fneg(acc.re), fneg(acc.im)
    return ComplexVec{N, T}(neg_re, neg_im)
end

# untyped fallback: negate the result of `fmadd`
@inline function fnmsub(x, y, z)
    return fneg(fmadd(x, y, z))
end

# scalar fallbacks
# When both operands are plain numbers (a `T` or a `Complex{T}` sharing the same
# `T`), each function forwards to the corresponding Base operator.
for (f, op) in ((:fmul, :*), (:fadd, :+), (:fsub, :-))
    @eval @inline $f(x::Union{T, Complex{T}}, y::Union{T, Complex{T}}) where {T} = $op(x, y)
end

# unary negation of a plain number
@inline function fneg(x::Union{T, Complex{T}}) where {T}
    return -x
end
18 changes: 9 additions & 9 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,24 @@
# Three-operand fused operations on `Vec`.
# Every mix of scalar and `Vec` operands is reduced to the all-`Vec` method, which
# applies the operation of the same name to the wrapped `.data` fields.
for f in (:fmadd, :fmsub, :fnmadd, :fnmsub)
@eval begin
# all-Vec method: unwrap `.data`, apply `$f`, wrap the result in a `Vec`
@inline $f(x::Vec{N, T}, y::Vec{N, T}, z::Vec{N, T}) where {N, T <: FloatTypes} = Vec($f(x.data, y.data, z.data))
# `ScalarTypes` operands: each scalar is turned into a `Vec{N, T}` with `constantvector`
@inline $f(x::ScalarTypes, y::Vec{N, T}, z::Vec{N, T}) where {N, T <: FloatTypes} = $f(constantvector(x, Vec{N, T}), y, z)
@inline $f(x::Vec{N, T}, y::ScalarTypes, z::Vec{N, T}) where {N, T <: FloatTypes} = $f(x, constantvector(y, Vec{N, T}), z)
@inline $f(x::ScalarTypes, y::ScalarTypes, z::Vec{N, T}) where {N, T <: FloatTypes} = $f(constantvector(x, Vec{N, T}), constantvector(y, Vec{N, T}), z)
@inline $f(x::Vec{N, T}, y::Vec{N, T}, z::ScalarTypes) where {N, T <: FloatTypes} = $f(x, y, constantvector(z, Vec{N, T}))
@inline $f(x::ScalarTypes, y::Vec{N, T}, z::ScalarTypes) where {N, T <: FloatTypes} = $f(constantvector(x, Vec{N, T}), y, constantvector(z, Vec{N, T}))
@inline $f(x::Vec{N, T}, y::ScalarTypes, z::ScalarTypes) where {N, T <: FloatTypes} = $f(x, constantvector(y, Vec{N, T}), constantvector(z, Vec{N, T}))
# scalar operands of exactly the element type `T`: all three arguments go through
# `promote`, which relies on the `promote_rule`/`convert` methods defined for `Vec`
@inline $f(x::Vec{N, T}, y::Vec{N, T}, z::T) where {N, T <: FloatTypes} = $f(promote(x, y, z)...)
@inline $f(x::Vec{N, T}, y::T, z::T) where {N, T <: FloatTypes} = $f(promote(x, y, z)...)
@inline $f(x::T, y::Vec{N, T}, z::Vec{N, T}) where {N, T <: FloatTypes} = $f(promote(x, y, z)...)
@inline $f(x::T, y::Vec{N, T}, z::T) where {N, T <: FloatTypes} = $f(promote(x, y, z)...)
@inline $f(x::Vec{N, T}, y::T, z::Vec{N, T}) where {N, T <: FloatTypes} = $f(promote(x, y, z)...)
@inline $f(x::T, y::T, z::Vec{N, T}) where {N, T <: FloatTypes} = $f(promote(x, y, z)...)
end
end

# Two-operand arithmetic on `Vec`.
# Scalar operands are widened to a `Vec{N, T}` and the call is re-dispatched to the
# all-`Vec` method, keeping the operand order.
for f in (:fadd, :fsub, :fmul, :fdiv)
@eval begin
# all-Vec method: unwrap `.data`, apply `$f`, wrap the result in a `Vec`
@inline $f(x::Vec{N, T}, y::Vec{N, T}) where {N, T <: FloatTypes} = Vec($f(x.data, y.data))
# `ScalarTypes` operand: widened with `constantvector`
@inline $f(x::Vec{N, T}, y::ScalarTypes) where {N, T <: FloatTypes} = $f(x, constantvector(y, Vec{N, T}))
@inline $f(x::ScalarTypes, y::Vec{N, T}) where {N, T <: FloatTypes} = $f(constantvector(x, Vec{N, T}), y)
# scalar of exactly the element type `T`: widened through `promote`, which relies
# on the `promote_rule`/`convert` methods defined for `Vec`
@inline $f(x::Vec{N, T}, y::T) where {N, T <: FloatTypes} = $f(promote(x, y)...)
@inline $f(x::T, y::Vec{N, T}) where {N, T <: FloatTypes} = $f(promote(x, y)...)
end
end

# Bounds check for indexing a real or complex vector: valid indices are 1 through N.
# An out-of-range `i` is reported through `Base.throw_boundserror`; for an in-range
# `i` the short-circuit expression evaluates to `false`.
# One method is defined per spelling of the union alias
# (`ComplexorRealVec` and `ComplexOrRealVec`).
@inline Base.checkbounds(v::ComplexorRealVec{N, T}, i::IntegerTypes) where {N, T} = (i < 1 || i > N) && Base.throw_boundserror(v, i)
@inline Base.checkbounds(v::ComplexOrRealVec{N, T}, i::IntegerTypes) where {N, T} = (i < 1 || i > N) && Base.throw_boundserror(v, i)

Base.@propagate_inbounds function Base.getindex(v::Vec{N, T}, i::IntegerTypes) where {N, T}
@boundscheck checkbounds(v, i)
Expand Down
16 changes: 14 additions & 2 deletions src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,27 @@ const LLVMType = Dict{DataType, String}(
Float64 => "double",
)

const ScalarOrVec{N, T} = Union{ScalarTypes, Vec{N, T}}

# Converting a `Vec` to its own type returns the argument unchanged.
function Base.convert(::Type{Vec{N, T}}, x::Vec{N, T}) where {N, T <: FloatTypes}
    return x
end

# A scalar is converted to a `Vec{N, T}` by passing it to `constantvector`.
function Base.convert(::Type{Vec{N, T}}, x::T) where {N, T <: ScalarTypes}
    return constantvector(x, Vec{N, T})
end

# Promoting a `Vec{N, T}` together with a scalar `T` yields `Vec{N, T}`.
function Base.promote_rule(::Type{Vec{N, T}}, ::Type{T}) where {N, T <: FloatTypes}
    return Vec{N, T}
end

# Complex Types

# Complex vector with the real and imaginary parts stored in two separate
# `LVec{N, T}` fields rather than as interleaved (re, im) pairs.
# `T` is restricted to `FloatTypes`.
struct ComplexVec{N, T<:FloatTypes}
re::LVec{N, T}  # real parts
im::LVec{N, T}  # imaginary parts
end

const ComplexorRealVec{N, T} = Union{Vec{N, T}, ComplexVec{N, T}}
const ComplexOrRealVec{N, T} = Union{Vec{N, T}, ComplexVec{N, T}}

# Build a `ComplexVec` from two plain tuples (real parts, imaginary parts) by
# converting each tuple with `LVec{N, T}(...)`.
function ComplexVec(x::NTuple{N, T}, y::NTuple{N, T}) where {N, T <: FloatTypes}
    re_part = LVec{N, T}(x)
    im_part = LVec{N, T}(y)
    return ComplexVec(re_part, im_part)
end

# Build a `ComplexVec` from a tuple of complex numbers by splitting it into a tuple
# of real parts and a tuple of imaginary parts.
function ComplexVec(z::NTuple{N, Complex{T}}) where {N, T <: FloatTypes}
    re_parts = map(real, z)
    im_parts = map(imag, z)
    return ComplexVec(re_parts, im_parts)
end

# Converting a `ComplexVec` to its own type returns the argument unchanged.
function Base.convert(::Type{ComplexVec{N, T}}, z::ComplexVec{N, T}) where {N, T <: FloatTypes}
    return z
end

# A complex scalar is converted by passing it to `constantvector`.
function Base.convert(::Type{ComplexVec{N, T}}, z::Complex{T}) where {N, T <: FloatTypes}
    return constantvector(z, ComplexVec{N, T})
end

# A real scalar becomes a complex vector whose real field is built from `x` and
# whose imaginary field is built from `zero(T)`.
function Base.convert(::Type{ComplexVec{N, T}}, x::T) where {N, T <: ScalarTypes}
    re_part = constantvector(x, LVec{N, T})
    im_part = constantvector(zero(T), LVec{N, T})
    return ComplexVec{N, T}(re_part, im_part)
end

# Promoting a `ComplexVec{N, T}` together with a `Complex{T}` yields `ComplexVec{N, T}`.
function Base.promote_rule(::Type{ComplexVec{N, T}}, ::Type{Complex{T}}) where {N, T <: FloatTypes}
    return ComplexVec{N, T}
end
217 changes: 75 additions & 142 deletions test/complex_test.jl
Original file line number Diff line number Diff line change
@@ -1,144 +1,77 @@
# test complex

let
p = complex.(ntuple(i->rand(), 2), ntuple(i->rand(), 2))
p2 = complex.(ntuple(i->rand(), 2), ntuple(i->rand(), 2))
pr = ntuple(i->rand(), 2)

pc = SIMDMath.ComplexVec(p)
pc2 = SIMDMath.ComplexVec(p2)
pr1 = SIMDMath.Vec(pr)

# multiply

pcmul = SIMDMath.fmul(pc, pc2)
pmul = p .* p2
@test pcmul.re[1].value pmul[1].re
@test pcmul.im[1].value pmul[1].im
@test pcmul.re[2].value pmul[2].re
@test pcmul.im[2].value pmul[2].im

pcmul = SIMDMath.fmul(pc, pr1)
@test pcmul == SIMDMath.fmul(pr1, pc)
pmul = p .* pr
@test pcmul.re[1].value pmul[1].re
@test pcmul.im[1].value pmul[1].im
@test pcmul.re[2].value pmul[2].re
@test pcmul.im[2].value pmul[2].im

# add

pcmul = SIMDMath.fadd(pc, pc2)
pmul = p .+ p2
@test pcmul.re[1].value pmul[1].re
@test pcmul.im[1].value pmul[1].im
@test pcmul.re[2].value pmul[2].re
@test pcmul.im[2].value pmul[2].im

pcmul = SIMDMath.fadd(pc, pr1)
@test pcmul == SIMDMath.fadd(pr1, pc)
pmul = p .+ pr
@test pcmul.re[1].value pmul[1].re
@test pcmul.im[1].value pmul[1].im
@test pcmul.re[2].value pmul[2].re
@test pcmul.im[2].value pmul[2].im

# subtract

pcmul = SIMDMath.fsub(pc, pc2)
pmul = p .- p2
@test pcmul.re[1].value pmul[1].re
@test pcmul.im[1].value pmul[1].im
@test pcmul.re[2].value pmul[2].re
@test pcmul.im[2].value pmul[2].im

pcmul = SIMDMath.fsub(pc, pr1)
@test pcmul == SIMDMath.fsub(pr1, pc)
pmul = p .- pr
@test pcmul.re[1].value pmul[1].re
@test pcmul.im[1].value pmul[1].im
@test pcmul.re[2].value pmul[2].re
@test pcmul.im[2].value pmul[2].im

# multiply add

pcmul = SIMDMath.fmadd(pc, pc2, pc)
pmul = muladd.(p, p2, p)
@test pcmul.re[1].value pmul[1].re
@test pcmul.im[1].value pmul[1].im
@test pcmul.re[2].value pmul[2].re
@test pcmul.im[2].value pmul[2].im

pcmul = SIMDMath.fmadd(pc, pr1, pc)
@test pcmul == SIMDMath.fmadd(pr1, pc, pc)
pmul = muladd.(p, pr, p)
@test pcmul.re[1].value pmul[1].re
@test pcmul.im[1].value pmul[1].im
@test pcmul.re[2].value pmul[2].re
@test pcmul.im[2].value pmul[2].im

pcmul = SIMDMath.fmadd(pc, pr1, pr1)
pmul = muladd.(p, pr, pr)
@test pcmul.re[1].value pmul[1].re
@test pcmul.im[1].value pmul[1].im
@test pcmul.re[2].value pmul[2].re
@test pcmul.im[2].value pmul[2].im

# multiply subtract

pcmul = SIMDMath.fmsub(pc, pc2, pc)
pmul = @. p*p2 - p
@test pcmul.re[1].value pmul[1].re
@test pcmul.im[1].value pmul[1].im
@test pcmul.re[2].value pmul[2].re
@test pcmul.im[2].value pmul[2].im

pcmul = SIMDMath.fmsub(pc, pr1, pc)
@test pcmul == SIMDMath.fmsub(pr1, pc, pc)
pmul = @. p*pr - p
@test pcmul.re[1].value pmul[1].re
@test pcmul.im[1].value pmul[1].im
@test pcmul.re[2].value pmul[2].re
@test pcmul.im[2].value pmul[2].im

pcmul = SIMDMath.fmsub(pc, pr1, pr1)
pmul = @. p*pr - pr
@test pcmul.re[1].value pmul[1].re
@test pcmul.im[1].value pmul[1].im
@test pcmul.re[2].value pmul[2].re
@test pcmul.im[2].value pmul[2].im

# complex negated multiply-add
# -a*b + c
pcmul = SIMDMath.fnmadd(pc, pc2, pc)
pmul = @. -p*p2 + p
@test pcmul.re[1].value pmul[1].re
@test pcmul.im[1].value pmul[1].im
@test pcmul.re[2].value pmul[2].re
@test pcmul.im[2].value pmul[2].im

# -a*b - c
pcmul = SIMDMath.fnmsub(pc, pc2, pc)
pmul = @. -p*p2 - p
@test pcmul.re[1].value pmul[1].re
@test pcmul.im[1].value pmul[1].im
@test pcmul.re[2].value pmul[2].re
@test pcmul.im[2].value pmul[2].im


P1 = (1.1, 1.2, 1.4, 1.5, 1.3, 1.4, 1.5, 1.6, 1.7, 1.2, 1.2, 2.1, 3.1, 1.4, 1.5)
P2 = (1.1, 1.2, 1.4, 1.53, 1.32, 1.41, 1.52, 1.64, 1.4, 1.0, 1.6, 2.5, 3.1, 1.9, 1.2)
pp3 = pack_poly((P1, P2))
z = 1.2 + 1.1im
s = horner_simd(z, pp3)
e = evalpoly(z, P1)

@test s.re[1].value == e.re
@test s.im[1].value == e.im

e = evalpoly(z, P2)
@test s.re[2].value == e.re
@test s.im[2].value == e.im


end
using SIMDMath: fmul, fadd, fsub
using SIMDMath: fmadd, fmsub, fnmadd, fnmsub
using SIMDMath: ComplexVec, Vec

# Scalar reference implementations that the SIMD results are compared against.

# a*b - c
function mulsub(a, b, c)
    return a * b - c
end

# -a*b + c
function nmuladd(a, b, c)
    return -a * b + c
end

# -a*b - c
function nmulsub(a, b, c)
    return -a * b - c
end

# Test operands: 2-element tuples of complex numbers, 2-element tuples of reals,
# and fixed complex / real scalars. The tuples are random on every run.
# The factor (-1)^i flips the sign of the first element (i = 1) only.
# NOTE(review): (-1)^(2i) is always +1, so the components using it are never
# negative - confirm whether alternating signs were intended there.
cvec1 = complex.(ntuple(i->rand(), 2), ntuple(i->rand(), 2))
cvec2 = complex.(ntuple(i->rand(), 2), ntuple(i->rand()*(-1)^i, 2))
cvec3 = complex.(ntuple(i->rand()*(-1)^i, 2), ntuple(i->rand()*(-1)^(2i), 2))

rvec1 = ntuple(i->rand(), 2)
rvec2 = ntuple(i->rand()*(-1)^(i), 2)
rvec3 = ntuple(i->rand()*(-1)^(2i), 2)

# complex scalars
cscal1 = 1.2 + 1.3im
cscal2 = 2.1 - 1.9im
cscal3 = -3.1 - 3.4im

# real scalars
rscal1 = 4.5
rscal2 = -1.2
rscal3 = 6.5

# Binary operations: for every ordered pair of operands, compare the SIMD result
# with the Base operator broadcast over the plain tuple / scalar reference values.
# Each operand is a `(reference, simd)` pair; scalars are passed to both sides as-is.
let operands = (
        (cvec1, ComplexVec(cvec1)), (cvec2, ComplexVec(cvec2)), (cvec3, ComplexVec(cvec3)),
        (rvec1, Vec(rvec1)), (rvec2, Vec(rvec2)), (rvec3, Vec(rvec3)),
        (cscal1, cscal1), (cscal2, cscal2), (cscal3, cscal3),
        (rscal1, rscal1), (rscal2, rscal2), (rscal3, rscal3),
    )
    for (f, f2) in ((fmul, *), (fadd, +), (fsub, -)), a in operands, b in operands
        simd = f(a[2], b[2])
        ref = f2.(a[1], b[1])
        # approximate comparison: the two code paths are not required to round identically
        @test simd[1] ≈ ref[1]
        # `ref` has two elements only when at least one operand was a tuple
        if length(ref) == 2
            @test simd[2] ≈ ref[2]
        end
    end
end

# Three-operand operations: for every ordered triple of operands, compare the SIMD
# result with the matching scalar reference function broadcast over the plain
# tuple / scalar reference values.
# Each operand is a `(reference, simd)` pair; scalars are passed to both sides as-is.
let operands = (
        (cvec1, ComplexVec(cvec1)), (cvec2, ComplexVec(cvec2)), (cvec3, ComplexVec(cvec3)),
        (rvec1, Vec(rvec1)), (rvec2, Vec(rvec2)), (rvec3, Vec(rvec3)),
        (cscal1, cscal1), (cscal2, cscal2), (cscal3, cscal3),
        (rscal1, rscal1), (rscal2, rscal2), (rscal3, rscal3),
    )
    for (f, f2) in ((fmadd, muladd), (fmsub, mulsub), (fnmadd, nmuladd), (fnmsub, nmulsub)),
        a in operands, b in operands, c in operands

        simd = f(a[2], b[2], c[2])
        ref = f2.(a[1], b[1], c[1])
        # approximate comparison: the two code paths are not required to round identically
        @test simd[1] ≈ ref[1]
        # `ref` has two elements only when at least one operand was a tuple
        if length(ref) == 2
            @test simd[2] ≈ ref[2]
        end
    end
end

# A real scalar converts to a ComplexVec holding that value in every real slot and
# zero in every imaginary slot.
@test convert(ComplexVec{4, Float64}, 1.2) == ComplexVec{4, Float64}((1.2, 1.2, 1.2, 1.2), (0.0, 0.0, 0.0, 0.0))

# Evaluate two polynomials (coefficient tuples P1 and P2) at one complex point with
# `horner_simd` and compare element 1 / element 2 of the result with `evalpoly`
# applied to P1 / P2 respectively.
P1 = (1.1, 1.2, 1.4, 1.5, 1.3, 1.4, 1.5, 1.6, 1.7, 1.2, 1.2, 2.1, 3.1, 1.4, 1.5)
P2 = (1.1, 1.2, 1.4, 1.53, 1.32, 1.41, 1.52, 1.64, 1.4, 1.0, 1.6, 2.5, 3.1, 1.9, 1.2)
pp3 = pack_poly((P1, P2))
z = 1.2 + 1.1im
s = horner_simd(z, pp3)
e = evalpoly(z, P1)

# NOTE(review): exact `==` on floating-point results presumably relies on
# `horner_simd` performing the same operations in the same order as `evalpoly` - confirm.
@test s[1].re == e.re
@test s[1].im == e.im

e = evalpoly(z, P2)
@test s[2].re == e.re
@test s[2].im == e.im

0 comments on commit c953bdd

Please sign in to comment.