Skip to content

Commit

Permalink
[CUSPARSE] Implement a sparse GEMV for CuSparseMatrixCSC * CuSparseVe…
Browse files Browse the repository at this point in the history
…ctor (#2488)

This operation is not used by default by `*`.

Co-authored-by: Tim Besard <tim.besard@gmail.com>
  • Loading branch information
amontoison and maleadt authored Sep 18, 2024
1 parent f78a857 commit a56682e
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 1 deletion.
39 changes: 38 additions & 1 deletion lib/cusparse/generic.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# generic APIs

export gather!, scatter!, axpby!, rot!
export vv!, sv!, sm!, gemm, gemm!, sddmm!
export vv!, sv!, sm!, gemv, gemm, gemm!, sddmm!
export bmm!

## API functions
Expand Down Expand Up @@ -574,6 +574,43 @@ function gemm(transa::SparseChar, transb::SparseChar, alpha::Number, A::CuSparse
return C
end

"""
y = gemv(transa, alpha, A, x, index, [algo])
Perform a product between a `CuSparseMatrix` and a `CuSparseVector`, returning a `CuSparseVector`.
This function should only be used for highly sparse matrices and vectors, as the result is expected
to have many non-zeros in practice.
For this reason, high-level functions like `mul!` and `*` internally convert the sparse vector into a
dense vector to use a more efficient CUSPARSE routine.
Supported formats for the sparse matrix are `CuSparseMatrixCSC` and `CuSparseMatrixCSR`.
"""
function gemv end

function gemv(transa::SparseChar, alpha::Number, A::CuSparseMatrixCSC{T},
x::CuSparseVector{T}, index::SparseChar, algo::cusparseSpGEMMAlg_t=CUSPARSE_SPGEMM_DEFAULT) where {T}
m, n = size(A)
p = length(x)
p == n || throw(DimensionMismatch("dimensions must match: x has length $p, A has length $m × $n"))
# we model x as a CuSparseMatrixCSC with one column.
B = CuSparseMatrixCSC(x)
C = gemm(transa, 'N', alpha, A, B, index, algo)
y = CuSparseVector(C)
return y
end

function gemv(transa::SparseChar, alpha::Number, A::CuSparseMatrixCSR{T},
x::CuSparseVector{T}, index::SparseChar, algo::cusparseSpGEMMAlg_t=CUSPARSE_SPGEMM_DEFAULT) where {T}
m, n = size(A)
p = length(x)
p == n || throw(DimensionMismatch("dimensions must match: x has length $p, A has length $m × $n"))
# we model x as a CuSparseMatrixCSR with one column.
B = CuSparseMatrixCSR(x)
C = gemm(transa, 'N', alpha, A, B, index, algo)
y = CuSparseVector(C)
return y
end

for SparseMatrixType in (:CuSparseMatrixCSC, :CuSparseMatrixCSR)
@eval begin
function gemm(transa::SparseChar, transb::SparseChar, alpha::Number, A::$SparseMatrixType{T}, B::$SparseMatrixType{T},
Expand Down
1 change: 1 addition & 0 deletions lib/cusparse/interfaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ function LinearAlgebra.generic_matvecmul!(C::CuVector{T}, tA::AbstractChar, A::C
tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
mv_wrapper(tA, alpha, A, B, beta, C)
end

function LinearAlgebra.generic_matvecmul!(C::CuVector{T}, tA::AbstractChar, A::CuSparseMatrix{T}, B::CuSparseVector{T}, alpha::Number, beta::Number) where {T <: Union{Float16, ComplexF16, BlasFloat}}
tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
mv_wrapper(tA, alpha, A, CuVector{T}(B), beta, C)
Expand Down
13 changes: 13 additions & 0 deletions test/libraries/cusparse/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,19 @@ for SparseMatrixType in keys(SPGEMM_ALGOS)
end
end
end

@testset "gemv $T" for T in [Float32, Float64, ComplexF32, ComplexF64]
for (transa, opa) in [('N', identity)]
A = sprand(T,25,10,0.2)
b = sprand(T,10,0.3)
dA = SparseMatrixType(A)
db = CuSparseVector(b)
alpha = rand(T)
y = alpha * opa(A) * b
dy = gemv(transa, alpha, dA, db, 'O')
@test collect(dy) y
end
end
end

if CUSPARSE.version() >= v"11.4.1"
Expand Down

0 comments on commit a56682e

Please sign in to comment.