Skip to content

Commit

Permalink
Try to implement schur_complement! in a way to reduce monomorphizatio…
Browse files Browse the repository at this point in the history
…n, in particular when coupled with rflu
  • Loading branch information
chriselrod committed Apr 25, 2024
1 parent 089bbc9 commit 3f5b155
Showing 1 changed file with 112 additions and 8 deletions.
120 changes: 112 additions & 8 deletions src/TriangularSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ if isdefined(Base, :Experimental) &&
@eval Base.Experimental.@max_methods 1
end

using LayoutPointers: stridedpointer_preserve, StrideIndex
using LayoutPointers: stridedpointer_preserve
using VectorizationBase, LinearAlgebra #LoopVectorization
using VectorizationBase:
vfnmadd_fast,
Expand Down Expand Up @@ -701,6 +701,116 @@ function block_size(::Val{T}) where {T}
Static.floortostaticint(sqrt(elements_l2))
end

struct Mat{T,ColMajor} <: AbstractMatrix{T}
p::Ptr{T}
x::Int
M::Int
N::Int
end
Base.size(A::Mat) = (A.M, A.N)
Base.axes(A::Mat) = (CloseOpen(A.M), CloseOpen(A.N))
Base.transpose(A::Mat{T,true}) where {T} = Mat{T,false}(A.p, A.x, A.N, A.M)
Base.transpose(A::Mat{T,false}) where {T} = Mat{T,true}(A.p, A.x, A.N, A.M)
@inline function Base.getindex(
A::Mat{T,ColMajor},
i::Int,
j::Int
) where {T,ColMajor}
(; p, x) = A
offset = ColMajor ? i + j * x : i * x + j
unsafe_load(p, offset + 1)
end
@inline function Base.setindex!(
A::Mat{T,ColMajor},
v::T,
i::Int,
j::Int
) where {T,ColMajor}
(; p, x) = A
offset = ColMajor ? i + j * x : i * x + j
unsafe_store!(p, v, offset + 1)
v
end
@inline function Mat(A::AbstractMatrix{T}) where {T}
r, c = LoopVectorization.ArrayInterface.stride_rank(A)
M, N = size(A)
if r === static(1)
Mat{T,true}(pointer(A), stride(A,2), M, N)
else
@assert c === static(1)
Mat{T,false}(pointer(A), stride(A,1), M, N)
end
end

# C -= A * B
@inline function _schur_complement!(
C::Mat,
A::Mat,
B::Mat,
::Val{false}
)
# _turbo_! will not be inlined
@turbo warn_check_args = false for n in indices((C, B), 2),
m in indices((C, A), 1)

Cmn = C[m, n]
for k in indices((A, B), (2, 1))
Cmn -= A[m, k] * B[k, n]
end
C[m, n] = Cmn
end
end
@inline function _schur_complement!(
C::Mat,
A::Mat,
B::Mat,
::Val{true}
)
# _turbo_! will not be inlined
@tturbo warn_check_args = false for n in indices((C, B), 2),
m in indices((C, A), 1)

Cmn = C[m, n]
for k in indices((A, B), (2, 1))
Cmn -= A[m, k] * B[k, n]
end
C[m, n] = Cmn
end
end
@inline function schur_complement!(
C::Mat,
A::Mat{<:Any,false},
B::Mat{<:Any,false},
::Val{THREAD}
) where {THREAD}
# C - A * B == (C' - B' * A')'
_schur_complement!(transpose(C), transpose(B), transpose(A), Val(THREAD))
end
@inline function schur_complement!(
C::Mat,
A::Mat,
B::Mat,
::Val{THREAD}
) where {THREAD}
_schur_complement!(C, A, B, Val(THREAD))
end
@inline function schur_complement!(C, A, B, ::Val{THREAD}) where {THREAD}
schur_complement!(Mat(C), Mat(A), Mat(B), Val(THREAD))
end

@inline function Mat(sp::StridedPointer{T,2,1}, M, N) where {T}
x, y = strides(stridedpointer(sp))
st = sizeof(T)
@assert x == st
Mat{T,true}(pointer(sp), y >>> trailing_zeros(st), M, N)
end
@inline function Mat(sp::StridedPointer{T,2,2}, M, N) where {T}
x, y = strides(stridedpointer(sp))
st = sizeof(T)
@assert y == st
Mat{T,false}(pointer(sp), x >>> trailing_zeros(st), M, N)
end

function rdiv_block_N!(
spa::AbstractStridedPointer{T},
spu,
Expand Down Expand Up @@ -729,13 +839,7 @@ function rdiv_block_N!(
n += B_normalized
repeat = n + B_normalized < N
N_temp = repeat ? N_temp : N - n
@turbo for c CloseOpen(N_temp), m CloseOpen(M)
Cmn = spa_base[m, n+c]
for k CloseOpen(n)
Cmn -= spa_base[m, k] * spu[k, c]
end
spa_base[m, n+c] = Cmn
end
schur_complement!(Mat(spa, M, N_temp), Mat(spa_base, M, n), Mat(spu, n, N_temp), Val(false))
end
end
function rdiv_block_MandN!(
Expand Down

0 comments on commit 3f5b155

Please sign in to comment.