Skip to content

Commit

Permalink
LPtr
Browse files Browse the repository at this point in the history
  • Loading branch information
chriselrod committed Apr 25, 2024
1 parent 3f5b155 commit 168d962
Showing 1 changed file with 35 additions and 31 deletions.
66 changes: 35 additions & 31 deletions src/TriangularSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,15 @@ using IfElse: ifelse
using LoopVectorization
using Polyester

# `LPtr{T}` abbreviates `Core.LLVMPtr{T,0}` (an LLVM pointer whose second
# type parameter is 0).
const LPtr{T} = Core.LLVMPtr{T,0}

# `_lptr`: reinterpret a `Ptr{T}` as an `LPtr{T}`; any other value is
# returned untouched.
function _lptr(p::Ptr{T}) where {T}
    return reinterpret(LPtr{T}, p)
end
function _lptr(other)
    return other
end

# `_ptr`: the inverse of `_lptr` -- reinterpret an `LPtr{T}` as a `Ptr{T}`;
# any other value is returned untouched.
function _ptr(p::LPtr{T}) where {T}
    return reinterpret(Ptr{T}, p)
end
function _ptr(other)
    return other
end
# Apply `_ptr` to every element of the flat tuple `t`, then hand the result
# to `LoopVectorization.reassemble_tuple` together with the target type `T`.
@inline function reassemble_tup(::Type{T}, t) where {T}
    converted = map(_ptr, t)
    return LoopVectorization.reassemble_tuple(T, converted)
end
# Flatten `t` with `LoopVectorization.flatten_to_tuple`, then apply `_lptr`
# to every element of the flattened tuple.
@inline function flatten_to_tup(t)
    flat = LoopVectorization.flatten_to_tuple(t)
    return map(_lptr, flat)
end

@generated function solve_AU(
A::VecUnroll{Nm1},
spu::AbstractStridedPointer,
Expand Down Expand Up @@ -65,7 +74,7 @@ end
quote
$(Expr(:meta, :inline))
mask = $(VectorizationBase.Mask{W})(_mask)
spa, spu = LoopVectorization.reassemble_tuple($Args, args)
spa, spu = reassemble_tup($Args, args)
vstore!(spa, $Amn, $i, mask)
end
else
Expand All @@ -74,7 +83,7 @@ end
scale = UNIT ? nothing : :(Amn_n /= vload(spu, (n - 1, n - 1)))
quote
$(Expr(:meta, :inline))
spa, spu = LoopVectorization.reassemble_tuple($Args, args)
spa, spu = reassemble_tup($Args, args)
mask = $(VectorizationBase.Mask{W})(_mask)
Amn = getfield(vload(spa, $unroll, mask), :data)
Base.Cartesian.@nexprs $N n -> begin
Expand Down Expand Up @@ -120,7 +129,7 @@ end
end
quote
$(Expr(:meta, :inline))
spa, spu = LoopVectorization.reassemble_tuple($Args, args)
spa, spu = reassemble_tup($Args, args)
vstore!(spa, $Amn, $unroll)
end
else
Expand All @@ -130,7 +139,7 @@ end
scale = UNIT ? nothing : :(Amn_n /= vload(spu, (n - 1, n - 1)))
quote
$(Expr(:meta, :inline))
spa, spu = LoopVectorization.reassemble_tuple($Args, args)
spa, spu = reassemble_tup($Args, args)
Amn = getfield(vload(spa, $double_unroll), :data)
Base.Cartesian.@nexprs $N n -> begin
Amn_n = getfield(Amn, n)
Expand Down Expand Up @@ -468,7 +477,7 @@ end
while m < M - WU + 1
n = Nr
if n > 0
let t = (spa, spu), ft = LoopVectorization.flatten_to_tuple(t)
let t = (spa, spu), ft = flatten_to_tup(t)
BdivU_small_kern_u!(n, UF, Val(UNIT), WS, typeof(t), ft...)
end
end
Expand All @@ -488,7 +497,7 @@ end
n = Nr
if n > 0
let t = (spa, spu),
ft = LoopVectorization.flatten_to_tuple(t),
ft = flatten_to_tup(t),
mask = getfield(mask, :u) % UInt32

BdivU_small_kern!(n, mask, WS, Val(UNIT), typeof(t), ft...)
Expand Down Expand Up @@ -733,22 +742,17 @@ end
end
# Build a `Mat` from a strided matrix `A`.
#
# `stride_rank(A)` yields the stride rank of each of the two dimensions.
# - If the first dimension has rank `static(1)`, a `Mat{T,true}` is built that
#   stores `pointer(A)`, the stride of the second dimension, and the size.
# - Otherwise the second dimension must have rank `static(1)`, and a
#   `Mat{T,false}` is built that stores the stride of the first dimension.
#
# NOTE(review): the `Bool` type parameter presumably flags which dimension is
# contiguous -- confirm against the `Mat` struct definition.
@inline function Mat(A::AbstractMatrix{T}) where {T}
    r, c = LoopVectorization.ArrayInterface.stride_rank(A)
    M, N = size(A)
    if r === static(1)
        return Mat{T,true}(pointer(A), stride(A, 2), M, N)
    else
        # Internal invariant: if dim 1 is not rank 1, dim 2 has to be.
        @assert c === static(1)
        return Mat{T,false}(pointer(A), stride(A, 1), M, N)
    end
end

# C -= A * B
@inline function _schur_complement!(
C::Mat,
A::Mat,
B::Mat,
::Val{false}
)
@inline function _schur_complement!(C::Mat, A::Mat, B::Mat, ::Val{false})
# _turbo_! will not be inlined
@turbo warn_check_args = false for n in indices((C, B), 2),
m in indices((C, A), 1)
Expand All @@ -760,12 +764,7 @@ end
C[m, n] = Cmn
end
end
@inline function _schur_complement!(
C::Mat,
A::Mat,
B::Mat,
::Val{true}
)
@inline function _schur_complement!(C::Mat, A::Mat, B::Mat, ::Val{true})
# _turbo_! will not be inlined
@tturbo warn_check_args = false for n in indices((C, B), 2),
m in indices((C, A), 1)
Expand Down Expand Up @@ -839,7 +838,12 @@ function rdiv_block_N!(
n += B_normalized
repeat = n + B_normalized < N
N_temp = repeat ? N_temp : N - n
schur_complement!(Mat(spa, M, N_temp), Mat(spa_base, M, n), Mat(spu, n, N_temp), Val(false))
schur_complement!(
Mat(spa, M, N_temp),
Mat(spa_base, M, n),
Mat(spu, n, N_temp),
Val(false)
)
end
end
function rdiv_block_MandN!(
Expand Down Expand Up @@ -974,7 +978,7 @@ end
n = Nr # non factor of W remainder
if n > 0
let t = (spa, spu),
ft = LoopVectorization.flatten_to_tuple(t),
ft = flatten_to_tup(t),
mask = $(getfield(_mask(WS, r), :u) % UInt32)

BdivU_small_kern!(n, mask, $WS, $(Val(UNIT)), typeof(t), ft...)
Expand Down Expand Up @@ -1007,14 +1011,14 @@ end
if W == 2
quote
$(Expr(:meta, :inline))
spa, spu = LoopVectorization.reassemble_tuple(Args, args)
spa, spu = reassemble_tup(Args, args)
_ldiv_remainder!(spa, spu, N, Nr, $WS, $(Val(UNIT)), $(static(1)))
nothing
end
elseif W == 8
quote
# $(Expr(:meta, :inline))
spa, spu = LoopVectorization.reassemble_tuple(Args, args)
spa, spu = reassemble_tup(Args, args)
if m == M - 1
_ldiv_remainder!(spa, spu, N, Nr, static(8), $(Val(UNIT)), StaticInt(1))
else
Expand Down Expand Up @@ -1093,7 +1097,7 @@ end
else
quote
# $(Expr(:meta, :inline))
spa, spu = LoopVectorization.reassemble_tuple(Args, args)
spa, spu = reassemble_tup(Args, args)
Base.Cartesian.@nif $(W - 1) w -> m == M - w w ->
_ldiv_remainder!(spa, spu, N, Nr, $WS, $(Val(UNIT)), static(w))
nothing
Expand All @@ -1108,7 +1112,7 @@ end
::Val{UNIT}
) where {T,UNIT}
tup = (spa, spu)
ftup = LoopVectorization.flatten_to_tuple(tup)
ftup = flatten_to_tup(tup)
_ldiv_L!(M, N, Val(UNIT), typeof(tup), ftup...)
end

Expand All @@ -1122,7 +1126,7 @@ function _ldiv_L!(
::Type{Args},
args::Vararg{Any,K}
) where {UNIT,Args,K}
spa, spu = LoopVectorization.reassemble_tuple(Args, args)
spa, spu = reassemble_tup(Args, args)
T = eltype(spa)
WS = pick_vector_width(T)
W = Int(WS)
Expand All @@ -1134,7 +1138,7 @@ function _ldiv_L!(
while m < M - WS + 1
n = Nr # non factor of W remainder
if n > 0
let t = (spa, spu), ft = LoopVectorization.flatten_to_tuple(t)
let t = (spa, spu), ft = flatten_to_tup(t)
BdivU_small_kern_u!(n, StaticInt(1), Val(UNIT), WS, typeof(t), ft...)
end
end
Expand All @@ -1151,7 +1155,7 @@ function _ldiv_L!(
end
# remainder on `m`
if m < M
let tup = (spa, spu), ftup = LoopVectorization.flatten_to_tuple(tup)
let tup = (spa, spu), ftup = flatten_to_tup(tup)
ldiv_remainder!(M, N, m, Nr, WS, Val(UNIT), typeof(tup), ftup...)
end
end
Expand Down

0 comments on commit 168d962

Please sign in to comment.