diff --git a/lib/cusparse/generic.jl b/lib/cusparse/generic.jl
index 0aed52115d..e944645760 100644
--- a/lib/cusparse/generic.jl
+++ b/lib/cusparse/generic.jl
@@ -1,7 +1,7 @@
 # generic APIs
 
 export gather!, scatter!, axpby!, rot!
-export vv!, sv!, sm!, gemm, gemm!, sddmm!
+export vv!, sv!, sm!, gemv, gemm, gemm!, sddmm!
 export bmm!
 
 ## API functions
@@ -574,6 +574,55 @@ function gemm(transa::SparseChar, transb::SparseChar, alpha::Number, A::CuSparse
     return C
 end
 
+"""
+    y = gemv(transa, alpha, A, x, index, [algo])
+
+Compute the sparse product `y = alpha * op(A) * x` between a `CuSparseMatrix` `A` and a
+`CuSparseVector` `x`, returning a `CuSparseVector`.
+
+This function should only be used for highly sparse matrices and vectors: otherwise the
+result is expected to have many non-zeros, and storing it in a sparse format is wasteful.
+For this reason, high-level functions like `mul!` and `*` internally convert the sparse
+vector into a dense vector to use a more efficient CUSPARSE routine.
+
+Internally, `x` is converted to a single-column sparse matrix in the same format as `A`,
+and the product is delegated to `gemm`.
+
+Supported formats for the sparse matrix are `CuSparseMatrixCSC` and `CuSparseMatrixCSR`.
+
+# Arguments
+- `transa`: operation `op` applied to `A`, forwarded unchanged to `gemm`.
+- `alpha`: scalar multiplier.
+- `A`: sparse matrix; `A` and `x` must share the same element type.
+- `x`: sparse vector whose length matches the number of columns of `op(A)`.
+- `index`: index character forwarded unchanged to `gemm` (the tests use `'O'`).
+- `algo`: SpGEMM algorithm forwarded to `gemm`; defaults to `CUSPARSE_SPGEMM_DEFAULT`.
+
+# Throws
+- `DimensionMismatch` if the length of `x` is incompatible with `op(A)`.
+"""
+function gemv end
+
+for SparseMatrixType in (:CuSparseMatrixCSC, :CuSparseMatrixCSR)
+    @eval begin
+        function gemv(transa::SparseChar, alpha::Number, A::$SparseMatrixType{T},
+                      x::CuSparseVector{T}, index::SparseChar,
+                      algo::cusparseSpGEMMAlg_t=CUSPARSE_SPGEMM_DEFAULT) where {T}
+            m, n = size(A)
+            p = length(x)
+            # x must match the number of columns of op(A): n for 'N', m otherwise.
+            k = transa == 'N' ? n : m
+            p == k || throw(DimensionMismatch(
+                "dimensions must match: x has length $p, A has size $m × $n (transa = '$transa')"))
+            # Model x as a single-column sparse matrix stored in the same format as A,
+            # so that the sparse-sparse gemm can be reused.
+            B = $SparseMatrixType(x)
+            C = gemm(transa, 'N', alpha, A, B, index, algo)
+            return CuSparseVector(C)
+        end
+    end
+end
+
 for SparseMatrixType in (:CuSparseMatrixCSC, :CuSparseMatrixCSR)
     @eval begin
         function gemm(transa::SparseChar, transb::SparseChar, alpha::Number, A::$SparseMatrixType{T}, B::$SparseMatrixType{T},
diff --git a/lib/cusparse/interfaces.jl b/lib/cusparse/interfaces.jl
index 27b10b8ce4..f6a7437162 100644
--- a/lib/cusparse/interfaces.jl
+++ b/lib/cusparse/interfaces.jl
@@ -73,6 +73,7 @@ function LinearAlgebra.generic_matvecmul!(C::CuVector{T}, tA::AbstractChar, A::C
     tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
     mv_wrapper(tA, alpha, A, B, beta, C)
 end
+
 function LinearAlgebra.generic_matvecmul!(C::CuVector{T}, tA::AbstractChar, A::CuSparseMatrix{T}, B::CuSparseVector{T}, alpha::Number, beta::Number) where {T <: Union{Float16, ComplexF16, BlasFloat}}
     tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
     mv_wrapper(tA, alpha, A, CuVector{T}(B), beta, C)
diff --git a/test/libraries/cusparse/generic.jl b/test/libraries/cusparse/generic.jl
index 4be9183839..1c6c74c05e 100644
--- a/test/libraries/cusparse/generic.jl
+++ b/test/libraries/cusparse/generic.jl
@@ -358,6 +358,20 @@ for SparseMatrixType in keys(SPGEMM_ALGOS)
             end
         end
     end
+
+    @testset "gemv $T" for T in [Float32, Float64, ComplexF32, ComplexF64]
+        for (transa, opa) in [('N', identity)]
+            A = sprand(T,25,10,0.2)
+            b = sprand(T,10,0.3)
+            dA = SparseMatrixType(A)
+            db = CuSparseVector(b)
+            alpha = rand(T)
+            y = alpha * opa(A) * b
+            dy = gemv(transa, alpha, dA, db, 'O')
+            @test collect(dy) ≈ y
+            @test_throws DimensionMismatch gemv(transa, alpha, dA, CuSparseVector(sprand(T,11,0.3)), 'O')
+        end
+    end
 end
 
 if CUSPARSE.version() >= v"11.4.1"