Skip to content

Commit

Permalink
Add matrix alloc to ext
Browse files Browse the repository at this point in the history
  • Loading branch information
termi-official committed Sep 30, 2024
1 parent 896fee1 commit 3976701
Showing 1 changed file with 31 additions and 7 deletions.
38 changes: 31 additions & 7 deletions ext/FerriteSparseMatrixCSR.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,16 +70,40 @@ function Ferrite.zero_out_columns!(K::SparseMatrixCSR, ch::ConstraintHandler)
end
end


function Ferrite.allocate_matrix(::Type{SparseMatrixCSR}, sp::AbstractSparsityPattern)
return allocate_matrix(SparseMatrixCSR{1,Float64,Int}, sp)
_allocate_matrix(SparseMatrixCSR{1, Float64, Int64}, sp, false)
end

function Ferrite.allocate_matrix(::Type{SparseMatrixCSR{1, Tv, Ti}}, sp::AbstractSparsityPattern) where {Tv, Ti}
_allocate_matrix(SparseMatrixCSR{1, Tv, Ti}, sp, false)

Check warning on line 78 in ext/FerriteSparseMatrixCSR.jl

View check run for this annotation

Codecov / codecov/patch

ext/FerriteSparseMatrixCSR.jl#L77-L78

Added lines #L77 - L78 were not covered by tests
end

function Ferrite.allocate_matrix(::Type{MatrixType}, sp::AbstractSparsityPattern) where {Bi, Tv, Ti, MatrixType <: SparseMatrixCSR{Bi, Tv, Ti}}
# Allocate CSC first ...
K = allocate_matrix(SparseMatrixCSC{Tv, Ti}, sp)
# ... and transform to SparseMatrixCSR
return SparseMatrixCSR{Bi}(transpose(sparse(transpose(K))))
function _allocate_matrix(::Type{SparseMatrixCSR{1, Tv, Ti}}, sp::AbstractSparsityPattern, sym::Bool) where {Tv, Ti}
# 1. Setup rowptr
rowptr = zeros(Ti, Ferrite.getnrows(sp) + 1)
rowptr[1] = 1
for (row, colidxs) in enumerate(Ferrite.eachrow(sp))
for col in colidxs
sym && row > col && continue
rowptr[row+1] += 1
end
end
cumsum!(rowptr, rowptr)
nnz = rowptr[end] - 1
# 2. Allocate colval and nzval now that nnz is known
colval = Vector{Ti}(undef, nnz)
nzval = zeros(Tv, nnz)
# 3. Populate colval.
k = 1
for (row, colidxs) in zip(1:Ferrite.getnrows(sp), Ferrite.eachrow(sp)) # pairs(eachrow(sp))
for col in colidxs
sym && row > col && continue
colval[k] = col
k += 1
end
end
S = SparseMatrixCSR{1}(Ferrite.getnrows(sp), Ferrite.getncols(sp), rowptr, colval, nzval)
return S
end

end

0 comments on commit 3976701

Please sign in to comment.