From 168d962680897a7855f2f8a063fc05e2fdd4ceb5 Mon Sep 17 00:00:00 2001 From: Chris Elrod Date: Thu, 25 Apr 2024 17:33:54 -0400 Subject: [PATCH] LPtr --- src/TriangularSolve.jl | 66 ++++++++++++++++++++++-------------------- 1 file changed, 35 insertions(+), 31 deletions(-) diff --git a/src/TriangularSolve.jl b/src/TriangularSolve.jl index 98f6344..a16cbb3 100644 --- a/src/TriangularSolve.jl +++ b/src/TriangularSolve.jl @@ -20,6 +20,15 @@ using IfElse: ifelse using LoopVectorization using Polyester +const LPtr{T} = Core.LLVMPtr{T,0} +_lptr(x::Ptr{T}) where {T} = reinterpret(LPtr{T}, x) +_lptr(x) = x +_ptr(x::LPtr{T}) where {T} = reinterpret(Ptr{T}, x) +_ptr(x) = x +@inline reassemble_tup(::Type{T}, t) where {T} = + LoopVectorization.reassemble_tuple(T, map(_ptr, t)) +@inline flatten_to_tup(t) = map(_lptr, LoopVectorization.flatten_to_tuple(t)) + @generated function solve_AU( A::VecUnroll{Nm1}, spu::AbstractStridedPointer, @@ -65,7 +74,7 @@ end quote $(Expr(:meta, :inline)) mask = $(VectorizationBase.Mask{W})(_mask) - spa, spu = LoopVectorization.reassemble_tuple($Args, args) + spa, spu = reassemble_tup($Args, args) vstore!(spa, $Amn, $i, mask) end else @@ -74,7 +83,7 @@ end scale = UNIT ? nothing : :(Amn_n /= vload(spu, (n - 1, n - 1))) quote $(Expr(:meta, :inline)) - spa, spu = LoopVectorization.reassemble_tuple($Args, args) + spa, spu = reassemble_tup($Args, args) mask = $(VectorizationBase.Mask{W})(_mask) Amn = getfield(vload(spa, $unroll, mask), :data) Base.Cartesian.@nexprs $N n -> begin @@ -120,7 +129,7 @@ end end quote $(Expr(:meta, :inline)) - spa, spu = LoopVectorization.reassemble_tuple($Args, args) + spa, spu = reassemble_tup($Args, args) vstore!(spa, $Amn, $unroll) end else @@ -130,7 +139,7 @@ end scale = UNIT ? nothing : :(Amn_n /= vload(spu, (n - 1, n - 1))) quote $(Expr(:meta, :inline)) - spa, spu = LoopVectorization.reassemble_tuple($Args, args) + spa, spu = reassemble_tup($Args, args) Amn = getfield(vload(spa, $double_unroll), :data) Base.Cartesian.@nexprs $N n -> begin Amn_n = getfield(Amn, n) @@ -468,7 +477,7 @@ end while m < M - WU + 1 n = Nr if n > 0 - let t = (spa, spu), ft = LoopVectorization.flatten_to_tuple(t) + let t = (spa, spu), ft = flatten_to_tup(t) BdivU_small_kern_u!(n, UF, Val(UNIT), WS, typeof(t), ft...) end end @@ -488,7 +497,7 @@ end n = Nr if n > 0 let t = (spa, spu), - ft = LoopVectorization.flatten_to_tuple(t), + ft = flatten_to_tup(t), mask = getfield(mask, :u) % UInt32 BdivU_small_kern!(n, mask, WS, Val(UNIT), typeof(t), ft...) @@ -733,22 +742,17 @@ end end @inline function Mat(A::AbstractMatrix{T}) where {T} r, c = LoopVectorization.ArrayInterface.stride_rank(A) - M, N = size(A) + M, N = size(A) if r === static(1) - Mat{T,true}(pointer(A), stride(A,2), M, N) + Mat{T,true}(pointer(A), stride(A, 2), M, N) else - @assert c === static(1) - Mat{T,false}(pointer(A), stride(A,1), M, N) + @assert c === static(1) + Mat{T,false}(pointer(A), stride(A, 1), M, N) end end # C -= A * B -@inline function _schur_complement!( - C::Mat, - A::Mat, - B::Mat, - ::Val{false} -) +@inline function _schur_complement!(C::Mat, A::Mat, B::Mat, ::Val{false}) # _turbo_! will not be inlined @turbo warn_check_args = false for n in indices((C, B), 2), m in indices((C, A), 1) @@ -760,12 +764,7 @@ end C[m, n] = Cmn end end -@inline function _schur_complement!( - C::Mat, - A::Mat, - B::Mat, - ::Val{true} -) +@inline function _schur_complement!(C::Mat, A::Mat, B::Mat, ::Val{true}) # _turbo_! will not be inlined @tturbo warn_check_args = false for n in indices((C, B), 2), m in indices((C, A), 1) @@ -839,7 +838,12 @@ function rdiv_block_N!( n += B_normalized repeat = n + B_normalized < N N_temp = repeat ? N_temp : N - n - schur_complement!(Mat(spa, M, N_temp), Mat(spa_base, M, n), Mat(spu, n, N_temp), Val(false)) + schur_complement!( + Mat(spa, M, N_temp), + Mat(spa_base, M, n), + Mat(spu, n, N_temp), + Val(false) + ) end end function rdiv_block_MandN!( @@ -974,7 +978,7 @@ end n = Nr # non factor of W remainder if n > 0 let t = (spa, spu), - ft = LoopVectorization.flatten_to_tuple(t), + ft = flatten_to_tup(t), mask = $(getfield(_mask(WS, r), :u) % UInt32) BdivU_small_kern!(n, mask, $WS, $(Val(UNIT)), typeof(t), ft...) @@ -1007,14 +1011,14 @@ end if W == 2 quote $(Expr(:meta, :inline)) - spa, spu = LoopVectorization.reassemble_tuple(Args, args) + spa, spu = reassemble_tup(Args, args) _ldiv_remainder!(spa, spu, N, Nr, $WS, $(Val(UNIT)), $(static(1))) nothing end elseif W == 8 quote # $(Expr(:meta, :inline)) - spa, spu = LoopVectorization.reassemble_tuple(Args, args) + spa, spu = reassemble_tup(Args, args) if m == M - 1 _ldiv_remainder!(spa, spu, N, Nr, static(8), $(Val(UNIT)), StaticInt(1)) else @@ -1093,7 +1097,7 @@ end else quote # $(Expr(:meta, :inline)) - spa, spu = LoopVectorization.reassemble_tuple(Args, args) + spa, spu = reassemble_tup(Args, args) Base.Cartesian.@nif $(W - 1) w -> m == M - w w -> _ldiv_remainder!(spa, spu, N, Nr, $WS, $(Val(UNIT)), static(w)) nothing @@ -1108,7 +1112,7 @@ end ::Val{UNIT} ) where {T,UNIT} tup = (spa, spu) - ftup = LoopVectorization.flatten_to_tuple(tup) + ftup = flatten_to_tup(tup) _ldiv_L!(M, N, Val(UNIT), typeof(tup), ftup...) end @@ -1122,7 +1126,7 @@ function _ldiv_L!( ::Type{Args}, args::Vararg{Any,K} ) where {UNIT,Args,K} - spa, spu = LoopVectorization.reassemble_tuple(Args, args) + spa, spu = reassemble_tup(Args, args) T = eltype(spa) WS = pick_vector_width(T) W = Int(WS) @@ -1134,7 +1138,7 @@ function _ldiv_L!( while m < M - WS + 1 n = Nr # non factor of W remainder if n > 0 - let t = (spa, spu), ft = LoopVectorization.flatten_to_tuple(t) + let t = (spa, spu), ft = flatten_to_tup(t) BdivU_small_kern_u!(n, StaticInt(1), Val(UNIT), WS, typeof(t), ft...) end end @@ -1151,7 +1155,7 @@ function _ldiv_L!( end # remainder on `m` if m < M - let tup = (spa, spu), ftup = LoopVectorization.flatten_to_tuple(tup) + let tup = (spa, spu), ftup = flatten_to_tup(tup) ldiv_remainder!(M, N, m, Nr, WS, Val(UNIT), typeof(tup), ftup...) end end