diff --git a/lib/cusolver/CUSOLVER.jl b/lib/cusolver/CUSOLVER.jl index f923b91374..f4b222dba7 100644 --- a/lib/cusolver/CUSOLVER.jl +++ b/lib/cusolver/CUSOLVER.jl @@ -60,11 +60,21 @@ end const idle_dense_handles = HandleCache{CuContext,cusolverDnHandle_t}(dense_handle_ctor, dense_handle_dtor) +# fat handle, includes a cache +struct cusolverDnHandle + handle::cusolverDnHandle_t + workspace_gpu::CuVector{UInt8} + workspace_cpu::Vector{UInt8} + info::CuVector{Cint} +end +Base.unsafe_convert(::Type{Ptr{cusolverDnContext}}, handle::cusolverDnHandle) = + handle.handle + function dense_handle() cuda = CUDA.active_state() # every task maintains library state per device - LibraryState = @NamedTuple{handle::cusolverDnHandle_t, stream::CuStream} + LibraryState = @NamedTuple{handle::cusolverDnHandle, stream::CuStream} states = get!(task_local_storage(), :CUSOLVER_dense) do Dict{CuContext,LibraryState}() end::Dict{CuContext,LibraryState} @@ -72,13 +82,21 @@ function dense_handle() # get library state @noinline function new_state(cuda) new_handle = pop!(idle_dense_handles, cuda.context) + + workspace_gpu = CuVector{UInt8}(undef, 0) + workspace_cpu = Vector{UInt8}(undef, 0) + info = CuVector{Cint}(undef, 1) + fat_handle = cusolverDnHandle(new_handle, workspace_gpu, workspace_cpu, info) + finalizer(current_task()) do task + CUDA.unsafe_free!(workspace_gpu) + CUDA.unsafe_free!(info) push!(idle_dense_handles, cuda.context, new_handle) end cusolverDnSetStream(new_handle, cuda.stream) - (; handle=new_handle, cuda.stream) + (; handle=fat_handle, cuda.stream) end state = get!(states, cuda.context) do new_state(cuda) diff --git a/lib/cusolver/dense.jl b/lib/cusolver/dense.jl index cdff8c43f4..d2cbb0bc8b 100644 --- a/lib/cusolver/dense.jl +++ b/lib/cusolver/dense.jl @@ -23,21 +23,20 @@ for (bname, fname,elty) in ((:cusolverDnSpotrf_bufferSize, :cusolverDnSpotrf, :F chkuplo(uplo) n = checksquare(A) lda = max(1, stride(A, 2)) + dh = dense_handle() function bufferSize() out = Ref{Cint}(0) - 
$bname(dense_handle(), uplo, n, A, lda, out) + $bname(dh, uplo, n, A, lda, out) out[] * sizeof($elty) end - devinfo = CuArray{Cint}(undef, 1) - with_workspace(bufferSize) do buffer - $fname(dense_handle(), uplo, n, A, lda, - buffer, sizeof(buffer) ÷ sizeof($elty), devinfo) + with_workspace(dh.workspace_gpu, bufferSize) do buffer + $fname(dh, uplo, n, A, lda, + buffer, sizeof(buffer) ÷ sizeof($elty), dh.info) end - info = @allowscalar devinfo[1] - unsafe_free!(devinfo) + info = @allowscalar dh.info[1] chkargsok(BlasInt(info)) A, info @@ -62,12 +61,11 @@ for (fname,elty) in ((:cusolverDnSpotrs, :Float32), nrhs = size(B,2) lda = max(1, stride(A, 2)) ldb = max(1, stride(B, 2)) + dh = dense_handle() - devinfo = CuArray{Cint}(undef, 1) - $fname(dense_handle(), uplo, n, nrhs, A, lda, B, ldb, devinfo) + $fname(dh, uplo, n, nrhs, A, lda, B, ldb, dh.info) - info = @allowscalar devinfo[1] - unsafe_free!(devinfo) + info = @allowscalar dh.info[1] chkargsok(BlasInt(info)) B @@ -85,21 +83,20 @@ for (bname, fname,elty) in ((:cusolverDnSpotri_bufferSize, :cusolverDnSpotri, :F chkuplo(uplo) n = checksquare(A) lda = max(1, stride(A, 2)) + dh = dense_handle() function bufferSize() out = Ref{Cint}(0) - $bname(dense_handle(), uplo, n, A, lda, out) + $bname(dh, uplo, n, A, lda, out) out[] * sizeof($elty) end - devinfo = CuArray{Cint}(undef, 1) - with_workspace(bufferSize) do buffer - $fname(dense_handle(), uplo, n, A, lda, - buffer, sizeof(buffer) ÷ sizeof($elty), devinfo) + with_workspace(dh.workspace_gpu, bufferSize) do buffer + $fname(dh, uplo, n, A, lda, + buffer, sizeof(buffer) ÷ sizeof($elty), dh.info) end - info = @allowscalar devinfo[1] - unsafe_free!(devinfo) + info = @allowscalar dh.info[1] chkargsok(BlasInt(info)) A @@ -116,20 +113,19 @@ for (bname, fname,elty) in ((:cusolverDnSgetrf_bufferSize, :cusolverDnSgetrf, :F function getrf!(A::StridedCuMatrix{$elty}, ipiv::CuVector{Cint}) m,n = size(A) lda = max(1, stride(A, 2)) + dh = dense_handle() function bufferSize() out = 
Ref{Cint}(0) - $bname(dense_handle(), m, n, A, lda, out) + $bname(dh, m, n, A, lda, out) return out[] * sizeof($elty) end - devinfo = CuArray{Cint}(undef, 1) - with_workspace(bufferSize) do buffer - $fname(dense_handle(), m, n, A, lda, buffer, ipiv, devinfo) + with_workspace(dh.workspace_gpu, bufferSize) do buffer + $fname(dh, m, n, A, lda, buffer, ipiv, dh.info) end - info = @allowscalar devinfo[1] - unsafe_free!(devinfo) + info = @allowscalar dh.info[1] chkargsok(BlasInt(info)) A, ipiv, info @@ -152,21 +148,20 @@ for (bname, fname,elty) in ((:cusolverDnSgeqrf_bufferSize, :cusolverDnSgeqrf, :F function geqrf!(A::StridedCuMatrix{$elty}, tau::CuVector{$elty}) m, n = size(A) lda = max(1, stride(A, 2)) + dh = dense_handle() function bufferSize() out = Ref{Cint}(0) - $bname(dense_handle(), m, n, A, lda, out) + $bname(dh, m, n, A, lda, out) return out[] * sizeof($elty) end - devinfo = CuArray{Cint}(undef, 1) - with_workspace(bufferSize) do buffer - $fname(dense_handle(), m, n, A, lda, tau, - buffer, sizeof(buffer) ÷ sizeof($elty), devinfo) + with_workspace(dh.workspace_gpu, bufferSize) do buffer + $fname(dh, m, n, A, lda, tau, + buffer, sizeof(buffer) ÷ sizeof($elty), dh.info) end - info = @allowscalar devinfo[1] - unsafe_free!(devinfo) + info = @allowscalar dh.info[1] chkargsok(BlasInt(info)) A, tau @@ -192,21 +187,20 @@ for (bname, fname,elty) in ((:cusolverDnSsytrf_bufferSize, :cusolverDnSsytrf, :F chkuplo(uplo) n = checksquare(A) lda = max(1, stride(A, 2)) + dh = dense_handle() function bufferSize() out = Ref{Cint}(0) - $bname(dense_handle(), n, A, lda, out) + $bname(dh, n, A, lda, out) return out[] * sizeof($elty) end - devinfo = CuArray{Cint}(undef, 1) - with_workspace(bufferSize) do buffer - $fname(dense_handle(), uplo, n, A, lda, ipiv, - buffer, sizeof(buffer) ÷ sizeof($elty), devinfo) + with_workspace(dh.workspace_gpu, bufferSize) do buffer + $fname(dh, uplo, n, A, lda, ipiv, + buffer, sizeof(buffer) ÷ sizeof($elty), dh.info) end - info = @allowscalar 
devinfo[1] - unsafe_free!(devinfo) + info = @allowscalar dh.info[1] chkargsok(BlasInt(info)) A, ipiv, info @@ -245,12 +239,11 @@ for (fname,elty) in ((:cusolverDnSgetrs, :Float32), nrhs = size(B, 2) lda = max(1, stride(A, 2)) ldb = max(1, stride(B, 2)) + dh = dense_handle() - devinfo = CuArray{Cint}(undef, 1) - $fname(dense_handle(), trans, n, nrhs, A, lda, ipiv, B, ldb, devinfo) + $fname(dh, trans, n, nrhs, A, lda, ipiv, B, ldb, dh.info) - info = @allowscalar devinfo[1] - unsafe_free!(devinfo) + info = @allowscalar dh.info[1] chkargsok(BlasInt(info)) B @@ -293,21 +286,20 @@ for (bname, fname, elty) in ((:cusolverDnSormqr_bufferSize, :cusolverDnSormqr, : end lda = max(1, stride(A, 2)) ldc = max(1, stride(C, 2)) + dh = dense_handle() function bufferSize() out = Ref{Cint}(0) - $bname(dense_handle(), side, trans, m, n, k, A, lda, tau, C, ldc, out) + $bname(dh, side, trans, m, n, k, A, lda, tau, C, ldc, out) return out[] * sizeof($elty) end - devinfo = CuArray{Cint}(undef, 1) - with_workspace(bufferSize) do buffer - $fname(dense_handle(), side, trans, m, n, k, A, lda, tau, C, ldc, - buffer, sizeof(buffer) ÷ sizeof($elty), devinfo) + with_workspace(dh.workspace_gpu, bufferSize) do buffer + $fname(dh, side, trans, m, n, k, A, lda, tau, C, ldc, + buffer, sizeof(buffer) ÷ sizeof($elty), dh.info) end - info = @allowscalar devinfo[1] - unsafe_free!(devinfo) + info = @allowscalar dh.info[1] chkargsok(BlasInt(info)) C @@ -326,21 +318,20 @@ for (bname, fname, elty) in ((:cusolverDnSorgqr_bufferSize, :cusolverDnSorgqr, : n = min(m, size(A, 2)) lda = max(1, stride(A, 2)) k = length(tau) + dh = dense_handle() function bufferSize() out = Ref{Cint}(0) - $bname(dense_handle(), m, n, k, A, lda, tau, out) + $bname(dh, m, n, k, A, lda, tau, out) return out[] * sizeof($elty) end - devinfo = CuArray{Cint}(undef, 1) - with_workspace(bufferSize) do buffer - $fname(dense_handle(), m, n, k, A, lda, tau, - buffer, sizeof(buffer) ÷ sizeof($elty), devinfo) + with_workspace(dh.workspace_gpu, 
bufferSize) do buffer + $fname(dh, m, n, k, A, lda, tau, + buffer, sizeof(buffer) ÷ sizeof($elty), dh.info) end - info = @allowscalar devinfo[1] - unsafe_free!(devinfo) + info = @allowscalar dh.info[1] chkargsok(BlasInt(info)) if n < size(A, 2) @@ -361,27 +352,26 @@ for (bname, fname, elty, relty) in ((:cusolverDnSgebrd_bufferSize, :cusolverDnSg function gebrd!(A::StridedCuMatrix{$elty}) m, n = size(A) lda = max(1, stride(A, 2)) + dh = dense_handle() function bufferSize() out = Ref{Cint}(0) - $bname(dense_handle(), m, n, out) + $bname(dh, m, n, out) return out[] * sizeof($elty) end - devinfo = CuArray{Cint}(undef, 1) k = min(m, n) D = CuArray{$relty}(undef, k) E = CUDA.zeros($relty, k) TAUQ = CuArray{$elty}(undef, k) TAUP = CuArray{$elty}(undef, k) - with_workspace(bufferSize) do buffer - $fname(dense_handle(), m, n, A, lda, D, E, TAUQ, TAUP, - buffer, sizeof(buffer) ÷ sizeof($elty), devinfo) + with_workspace(dh.workspace_gpu, bufferSize) do buffer + $fname(dh, m, n, A, lda, D, E, TAUQ, TAUP, + buffer, sizeof(buffer) ÷ sizeof($elty), dh.info) end - info = @allowscalar devinfo[1] - unsafe_free!(devinfo) + info = @allowscalar dh.info[1] chkargsok(BlasInt(info)) A, D, E, TAUQ, TAUP @@ -424,23 +414,22 @@ for (bname, fname, elty, relty) in ((:cusolverDnSgesvd_bufferSize, :cusolverDnSg error("jobvt must be one of 'A', 'S', 'O', or 'N'") end ldvt = Vt == CU_NULL ? 
1 : max(1, stride(Vt, 2)) + dh = dense_handle() function bufferSize() out = Ref{Cint}(0) - $bname(dense_handle(), m, n, out) + $bname(dh, m, n, out) return out[] * sizeof($elty) end rwork = CuArray{$relty}(undef, min(m, n) - 1) - devinfo = CuArray{Cint}(undef, 1) - with_workspace(bufferSize) do buffer - $fname(dense_handle(), jobu, jobvt, m, n, A, lda, S, U, ldu, Vt, ldvt, - buffer, sizeof(buffer) ÷ sizeof($elty), rwork, devinfo) + with_workspace(dh.workspace_gpu, bufferSize) do buffer + $fname(dh, jobu, jobvt, m, n, A, lda, S, U, ldu, Vt, ldvt, + buffer, sizeof(buffer) ÷ sizeof($elty), rwork, dh.info) end unsafe_free!(rwork) - info = @allowscalar devinfo[1] - unsafe_free!(devinfo) + info = @allowscalar dh.info[1] chkargsok(BlasInt(info)) return U, S, Vt @@ -483,22 +472,21 @@ for (bname, fname, elty, relty) in ((:cusolverDnSgesvdj_bufferSize, :cusolverDnS cusolverDnCreateGesvdjInfo(params) cusolverDnXgesvdjSetTolerance(params[], tol) cusolverDnXgesvdjSetMaxSweeps(params[], max_sweeps) + dh = dense_handle() function bufferSize() out = Ref{Cint}(0) - $bname(dense_handle(), jobz, econ, m, n, A, lda, S, U, ldu, V, ldv, + $bname(dh, jobz, econ, m, n, A, lda, S, U, ldu, V, ldv, out, params[]) return out[] * sizeof($elty) end - devinfo = CuArray{Cint}(undef, 1) - with_workspace(bufferSize) do buffer - $fname(dense_handle(), jobz, econ, m, n, A, lda, S, U, ldu, V, ldv, - buffer, sizeof(buffer) ÷ sizeof($elty), devinfo, params[]) + with_workspace(dh.workspace_gpu, bufferSize) do buffer + $fname(dh, jobz, econ, m, n, A, lda, S, U, ldu, V, ldv, + buffer, sizeof(buffer) ÷ sizeof($elty), dh.info, params[]) end - info = @allowscalar devinfo[1] - unsafe_free!(devinfo) + info = @allowscalar dh.info[1] chkargsok(BlasInt(info)) cusolverDnDestroyGesvdjInfo(params[]) @@ -536,21 +524,22 @@ for (bname, fname, elty, relty) in ((:cusolverDnSgesvdjBatched_bufferSize, :cuso cusolverDnXgesvdjSetTolerance(params[], tol) cusolverDnXgesvdjSetMaxSweeps(params[], max_sweeps) + dh = dense_handle() 
+ resize!(dh.info, batchSize) + function bufferSize() out = Ref{Cint}(0) - $bname(dense_handle(), jobz, m, n, A, lda, S, U, ldu, V, ldv, + $bname(dh, jobz, m, n, A, lda, S, U, ldu, V, ldv, out, params[], batchSize) return out[] * sizeof($elty) end - devinfo = CuArray{Cint}(undef, batchSize) - with_workspace(bufferSize) do buffer - $fname(dense_handle(), jobz, m, n, A, lda, S, U, ldu, V, ldv, - buffer, sizeof(buffer) ÷ sizeof($elty), devinfo, params[], batchSize) + with_workspace(dh.workspace_gpu, bufferSize) do buffer + $fname(dh, jobz, m, n, A, lda, S, U, ldu, V, ldv, + buffer, sizeof(buffer) ÷ sizeof($elty), dh.info, params[], batchSize) end - info = @allowscalar collect(devinfo) - unsafe_free!(devinfo) + info = @allowscalar collect(dh.info) # Double check the solver's exit status for i = 1:batchSize @@ -591,26 +580,27 @@ for (bname, fname, elty, relty) in ((:cusolverDnSgesvdaStridedBatched_bufferSize ldv = max(1, stride(V, 2)) strideV = stride(V, 3) + dh = dense_handle() + resize!(dh.info, batchSize) + function bufferSize() out = Ref{Cint}(0) - $bname(dense_handle(), jobz, rank, m, n, A, lda, strideA, + $bname(dh, jobz, rank, m, n, A, lda, strideA, S, strideS, U, ldu, strideU, V, ldv, strideV, out, batchSize) return out[] * sizeof($elty) end - devinfo = CuArray{Cint}(undef, batchSize) # residual storage h_RnrmF = Array{Cdouble}(undef, batchSize) - with_workspace(bufferSize) do buffer - $fname(dense_handle(), jobz, rank, m, n, A, lda, strideA, + with_workspace(dh.workspace_gpu, bufferSize) do buffer + $fname(dh, jobz, rank, m, n, A, lda, strideA, S, strideS, U, ldu, strideU, V, ldv, strideV, - buffer, sizeof(buffer) ÷ sizeof($elty), devinfo, h_RnrmF, batchSize) + buffer, sizeof(buffer) ÷ sizeof($elty), dh.info, h_RnrmF, batchSize) end - info = @allowscalar collect(devinfo) - unsafe_free!(devinfo) + info = @allowscalar collect(dh.info) # Double check the solver's exit status for i = 1:batchSize @@ -631,24 +621,23 @@ for (jname, bname, fname, elty, relty) in 
((:syevd!, :cusolverDnSsyevd_bufferSiz uplo::Char, A::StridedCuMatrix{$elty}) chkuplo(uplo) - n = checksquare(A) - lda = max(1, stride(A, 2)) - W = CuArray{$relty}(undef, n) + n = checksquare(A) + lda = max(1, stride(A, 2)) + W = CuArray{$relty}(undef, n) + dh = dense_handle() function bufferSize() out = Ref{Cint}(0) - $bname(dense_handle(), jobz, uplo, n, A, lda, W, out) + $bname(dh, jobz, uplo, n, A, lda, W, out) return out[] * sizeof($elty) end - devinfo = CuArray{Cint}(undef, 1) - with_workspace(bufferSize) do buffer - $fname(dense_handle(), jobz, uplo, n, A, lda, W, - buffer, sizeof(buffer) ÷ sizeof($elty), devinfo) + with_workspace(dh.workspace_gpu, bufferSize) do buffer + $fname(dh, jobz, uplo, n, A, lda, W, + buffer, sizeof(buffer) ÷ sizeof($elty), dh.info) end - info = @allowscalar devinfo[1] - unsafe_free!(devinfo) + info = @allowscalar dh.info[1] chkargsok(BlasInt(info)) if jobz == 'N' @@ -675,25 +664,24 @@ for (jname, bname, fname, elty, relty) in ((:sygvd!, :cusolverDnSsygvd_bufferSiz if nB != nA throw(DimensionMismatch("Dimensions of A ($nA, $nA) and B ($nB, $nB) must match!")) end - n = nA - lda = max(1, stride(A, 2)) - ldb = max(1, stride(B, 2)) - W = CuArray{$relty}(undef, n) + n = nA + lda = max(1, stride(A, 2)) + ldb = max(1, stride(B, 2)) + W = CuArray{$relty}(undef, n) + dh = dense_handle() function bufferSize() out = Ref{Cint}(0) - $bname(dense_handle(), itype, jobz, uplo, n, A, lda, B, ldb, W, out) + $bname(dh, itype, jobz, uplo, n, A, lda, B, ldb, W, out) return out[] * sizeof($elty) end - devinfo = CuArray{Cint}(undef, 1) - with_workspace(bufferSize) do buffer - $fname(dense_handle(), itype, jobz, uplo, n, A, lda, B, ldb, W, - buffer, sizeof(buffer) ÷ sizeof($elty), devinfo) + with_workspace(dh.workspace_gpu, bufferSize) do buffer + $fname(dh, itype, jobz, uplo, n, A, lda, B, ldb, W, + buffer, sizeof(buffer) ÷ sizeof($elty), dh.info) end - info = @allowscalar devinfo[1] - unsafe_free!(devinfo) + info = @allowscalar dh.info[1] 
chkargsok(BlasInt(info)) if jobz == 'N' @@ -722,30 +710,29 @@ for (jname, bname, fname, elty, relty) in ((:sygvj!, :cusolverDnSsygvj_bufferSiz if nB != nA throw(DimensionMismatch("Dimensions of A ($nA, $nA) and B ($nB, $nB) must match!")) end - n = nA - lda = max(1, stride(A, 2)) - ldb = max(1, stride(B, 2)) - W = CuArray{$relty}(undef, n) - params = Ref{syevjInfo_t}(C_NULL) + n = nA + lda = max(1, stride(A, 2)) + ldb = max(1, stride(B, 2)) + W = CuArray{$relty}(undef, n) + params = Ref{syevjInfo_t}(C_NULL) cusolverDnCreateSyevjInfo(params) cusolverDnXsyevjSetTolerance(params[], tol) cusolverDnXsyevjSetMaxSweeps(params[], max_sweeps) + dh = dense_handle() function bufferSize() out = Ref{Cint}(0) - $bname(dense_handle(), itype, jobz, uplo, n, A, lda, B, ldb, W, + $bname(dh, itype, jobz, uplo, n, A, lda, B, ldb, W, out, params[]) return out[] * sizeof($elty) end - devinfo = CuArray{Cint}(undef, 1) - with_workspace(bufferSize) do buffer - $fname(dense_handle(), itype, jobz, uplo, n, A, lda, B, ldb, W, - buffer, sizeof(buffer) ÷ sizeof($elty), devinfo, params[]) + with_workspace(dh.workspace_gpu, bufferSize) do buffer + $fname(dh, itype, jobz, uplo, n, A, lda, B, ldb, W, + buffer, sizeof(buffer) ÷ sizeof($elty), dh.info, params[]) end - info = @allowscalar devinfo[1] - unsafe_free!(devinfo) + info = @allowscalar dh.info[1] chkargsok(BlasInt(info)) cusolverDnDestroySyevjInfo(params[]) @@ -772,12 +759,14 @@ for (jname, bname, fname, elty, relty) in ((:syevjBatched!, :cusolverDnSsyevjBat # Set up information for the solver arguments chkuplo(uplo) - n = checksquare(A) - lda = max(1, stride(A, 2)) + n = checksquare(A) + lda = max(1, stride(A, 2)) batchSize = size(A,3) - W = CuArray{$relty}(undef, n,batchSize) - params = Ref{syevjInfo_t}(C_NULL) - devinfo = CuArray{Cint}(undef, batchSize) + W = CuArray{$relty}(undef, n,batchSize) + params = Ref{syevjInfo_t}(C_NULL) + + dh = dense_handle() + resize!(dh.info, batchSize) # Initialize the solver parameters 
cusolverDnCreateSyevjInfo(params) @@ -787,19 +776,18 @@ for (jname, bname, fname, elty, relty) in ((:syevjBatched!, :cusolverDnSsyevjBat # Calculate the workspace size function bufferSize() out = Ref{Cint}(0) - $bname(dense_handle(), jobz, uplo, n, A, lda, W, out, params[], batchSize) + $bname(dh, jobz, uplo, n, A, lda, W, out, params[], batchSize) return out[] * sizeof($elty) end # Run the solver - with_workspace(bufferSize) do buffer - $fname(dense_handle(), jobz, uplo, n, A, lda, W, buffer, - sizeof(buffer) ÷ sizeof($elty), devinfo, params[], batchSize) + with_workspace(dh.workspace_gpu, bufferSize) do buffer + $fname(dh, jobz, uplo, n, A, lda, W, buffer, + sizeof(buffer) ÷ sizeof($elty), dh.info, params[], batchSize) end - # Copy the solver info and delete the device memory - info = @allowscalar collect(devinfo) - unsafe_free!(devinfo) + # Copy the solver info to the host; the device buffer stays cached in the handle + info = @allowscalar collect(dh.info) # Double check the solver's exit status for i = 1:batchSize @@ -843,17 +831,17 @@ for (fname, elty) in ((:cusolverDnSpotrsBatched, :Float32), lda = max(1, stride(A[1], 2)) ldb = max(1, stride(B[1], 2)) batchSize = length(A) - devinfo = CuArray{Cint}(undef, 1) Aptrs = unsafe_batch(A) Bptrs = unsafe_batch(B) + dh = dense_handle() + # Run the solver - $fname(dense_handle(), uplo, n, nrhs, Aptrs, lda, Bptrs, ldb, devinfo, batchSize) + $fname(dh, uplo, n, nrhs, Aptrs, lda, Bptrs, ldb, dh.info, batchSize) - # Copy the solver info and delete the device memory - info = @allowscalar devinfo[1] - unsafe_free!(devinfo) + # Copy the solver info to the host; the device buffer stays cached in the handle + info = @allowscalar dh.info[1] chklapackerror(BlasInt(info)) return B @@ -873,16 +861,17 @@ for (fname, elty) in ((:cusolverDnSpotrfBatched, :Float32), n = checksquare(A[1]) lda = max(1, stride(A[1], 2)) batchSize = length(A) - devinfo = CuArray{Cint}(undef, batchSize) Aptrs = unsafe_batch(A) + dh = dense_handle() + resize!(dh.info, batchSize) + # Run the solver - $fname(dense_handle(), uplo, n, Aptrs, lda, devinfo, batchSize) + $fname(dh, uplo, n, Aptrs, lda, dh.info, batchSize) - # Copy 
the solver info and delete the device memory - info = @allowscalar collect(devinfo) - unsafe_free!(devinfo) + # Copy the solver info to the host; the device buffer stays cached in the handle + info = @allowscalar collect(dh.info) # Double check the solver's exit status for i = 1:batchSize diff --git a/lib/cusolver/dense_generic.jl b/lib/cusolver/dense_generic.jl index f203d4b2b2..b1e67453d0 100644 --- a/lib/cusolver/dense_generic.jl +++ b/lib/cusolver/dense_generic.jl @@ -17,24 +17,23 @@ function Xpotrf!(uplo::Char, A::StridedCuMatrix{T}) where {T <: BlasFloat} chkuplo(uplo) n = checksquare(A) lda = max(1, stride(A, 2)) - info = CuVector{Cint}(undef, 1) params = CuSolverParameters() + dh = dense_handle() function bufferSize() out_cpu = Ref{Csize_t}(0) out_gpu = Ref{Csize_t}(0) - cusolverDnXpotrf_bufferSize(dense_handle(), params, uplo, n, + cusolverDnXpotrf_bufferSize(dh, params, uplo, n, T, A, lda, T, out_gpu, out_cpu) out_gpu[], out_cpu[] end - with_workspaces(bufferSize()...) do buffer_gpu, buffer_cpu - cusolverDnXpotrf(dense_handle(), params, uplo, n, T, A, lda, T, + with_workspaces(dh.workspace_gpu, dh.workspace_cpu, bufferSize()...) 
do buffer_gpu, buffer_cpu + cusolverDnXpotrf(dh, params, uplo, n, T, A, lda, T, buffer_gpu, sizeof(buffer_gpu), buffer_cpu, - sizeof(buffer_cpu), info) + sizeof(buffer_cpu), dh.info) end - flag = @allowscalar info[1] - unsafe_free!(info) + flag = @allowscalar dh.info[1] chkargsok(BlasInt(flag)) A, flag end @@ -47,13 +46,12 @@ function Xpotrs!(uplo::Char, A::StridedCuMatrix{T}, B::StridedCuVecOrMat{T}) whe (p ≠ n) && throw(DimensionMismatch("first dimension of B, $p, must match second dimension of A, $n")) lda = max(1, stride(A, 2)) ldb = max(1, stride(B, 2)) - info = CuVector{Cint}(undef, 1) params = CuSolverParameters() + dh = dense_handle() - cusolverDnXpotrs(dense_handle(), params, uplo, n, nrhs, T, A, lda, T, B, ldb, info) + cusolverDnXpotrs(dh, params, uplo, n, nrhs, T, A, lda, T, B, ldb, dh.info) - flag = @allowscalar info[1] - unsafe_free!(info) + flag = @allowscalar dh.info[1] chkargsok(BlasInt(flag)) B end @@ -62,24 +60,23 @@ end function Xgetrf!(A::StridedCuMatrix{T}, ipiv::CuVector{Int64}) where {T <: BlasFloat} m, n = size(A) lda = max(1, stride(A, 2)) - info = CuVector{Cint}(undef, 1) params = CuSolverParameters() + dh = dense_handle() function bufferSize() out_cpu = Ref{Csize_t}(0) out_gpu = Ref{Csize_t}(0) - cusolverDnXgetrf_bufferSize(dense_handle(), params, m, n, T, + cusolverDnXgetrf_bufferSize(dh, params, m, n, T, A, lda, T, out_gpu, out_cpu) out_gpu[], out_cpu[] end - with_workspaces(bufferSize()...) do buffer_gpu, buffer_cpu - cusolverDnXgetrf(dense_handle(), params, m, n, T, A, lda, ipiv, + with_workspaces(dh.workspace_gpu, dh.workspace_cpu, bufferSize()...) 
do buffer_gpu, buffer_cpu + cusolverDnXgetrf(dh, params, m, n, T, A, lda, ipiv, T, buffer_gpu, sizeof(buffer_gpu), buffer_cpu, - sizeof(buffer_cpu), info) + sizeof(buffer_cpu), dh.info) end - flag = @allowscalar info[1] - unsafe_free!(info) + flag = @allowscalar dh.info[1] chkargsok(BlasInt(flag)) A, ipiv, flag end @@ -97,13 +94,12 @@ function Xgetrs!(trans::Char, A::StridedCuMatrix{T}, ipiv::CuVector{Int64}, B::S nrhs = size(B, 2) lda = max(1, stride(A, 2)) ldb = max(1, stride(B, 2)) - info = CuVector{Cint}(undef, 1) params = CuSolverParameters() + dh = dense_handle() - cusolverDnXgetrs(dense_handle(), params, trans, n, nrhs, T, A, lda, ipiv, T, B, ldb, info) + cusolverDnXgetrs(dh, params, trans, n, nrhs, T, A, lda, ipiv, T, B, ldb, dh.info) - flag = @allowscalar info[1] - unsafe_free!(info) + flag = @allowscalar dh.info[1] chkargsok(BlasInt(flag)) B end @@ -112,24 +108,23 @@ end function Xgeqrf!(A::StridedCuMatrix{T}, tau::CuVector{T}) where {T <: BlasFloat} m, n = size(A) lda = max(1, stride(A, 2)) - info = CuVector{Cint}(undef, 1) params = CuSolverParameters() + dh = dense_handle() function bufferSize() out_cpu = Ref{Csize_t}(0) out_gpu = Ref{Csize_t}(0) - cusolverDnXgeqrf_bufferSize(dense_handle(), params, m, n, T, A, + cusolverDnXgeqrf_bufferSize(dh, params, m, n, T, A, lda, T, tau, T, out_gpu, out_cpu) out_gpu[], out_cpu[] end - with_workspaces(bufferSize()...) do buffer_gpu, buffer_cpu - cusolverDnXgeqrf(dense_handle(), params, m, n, T, A, + with_workspaces(dh.workspace_gpu, dh.workspace_cpu, bufferSize()...) 
do buffer_gpu, buffer_cpu + cusolverDnXgeqrf(dh, params, m, n, T, A, lda, T, tau, T, buffer_gpu, sizeof(buffer_gpu), - buffer_cpu, sizeof(buffer_cpu), info) + buffer_cpu, sizeof(buffer_cpu), dh.info) end - flag = @allowscalar info[1] - unsafe_free!(info) + flag = @allowscalar dh.info[1] chkargsok(BlasInt(flag)) A, tau end @@ -147,23 +142,23 @@ function sytrs!(uplo::Char, A::StridedCuMatrix{T}, p::CuVector{Int64}, B::Stride nrhs = size(B, 2) lda = max(1, stride(A, 2)) ldb = max(1, stride(B, 2)) - info = CuVector{Cint}(undef, 1) + dh = dense_handle() function bufferSize() out_cpu = Ref{Csize_t}(0) out_gpu = Ref{Csize_t}(0) - cusolverDnXsytrs_bufferSize(dense_handle(), uplo, n, nrhs, T, A, + cusolverDnXsytrs_bufferSize(dh, uplo, n, nrhs, T, A, lda, p, T, B, ldb, out_gpu, out_cpu) out_gpu[], out_cpu[] end - with_workspaces(bufferSize()...) do buffer_gpu, buffer_cpu - cusolverDnXsytrs(dense_handle(), uplo, n, nrhs, T, A, lda, p, + with_workspaces(dh.workspace_gpu, dh.workspace_cpu, + bufferSize()...) do buffer_gpu, buffer_cpu + cusolverDnXsytrs(dh, uplo, n, nrhs, T, A, lda, p, T, B, ldb, buffer_gpu, sizeof(buffer_gpu), - buffer_cpu, sizeof(buffer_cpu), info) + buffer_cpu, sizeof(buffer_cpu), dh.info) end - flag = @allowscalar info[1] - unsafe_free!(info) + flag = @allowscalar dh.info[1] chkargsok(BlasInt(flag)) B end @@ -174,20 +169,21 @@ function trtri!(uplo::Char, diag::Char, A::StridedCuMatrix{T}) where {T <: BlasF chkdiag(diag) n = checksquare(A) lda = max(1, stride(A, 2)) - info = CuVector{Cint}(undef, 1) + dh = dense_handle() function bufferSize() out_cpu = Ref{Csize_t}(0) out_gpu = Ref{Csize_t}(0) - cusolverDnXtrtri_bufferSize(dense_handle(), uplo, diag, n, T, A, lda, out_gpu, out_cpu) + cusolverDnXtrtri_bufferSize(dh, uplo, diag, n, T, A, lda, out_gpu, out_cpu) out_gpu[], out_cpu[] end - with_workspaces(bufferSize()...) 
do buffer_gpu, buffer_cpu - cusolverDnXtrtri(dense_handle(), uplo, diag, n, T, A, lda, buffer_gpu, sizeof(buffer_gpu), buffer_cpu, sizeof(buffer_cpu), info) + with_workspaces(dh.workspace_gpu, dh.workspace_cpu, bufferSize()...) do buffer_gpu, buffer_cpu + cusolverDnXtrtri(dh, uplo, diag, n, T, A, lda, + buffer_gpu, sizeof(buffer_gpu), buffer_cpu, sizeof(buffer_cpu), + dh.info) end - flag = @allowscalar info[1] - unsafe_free!(info) + flag = @allowscalar dh.info[1] chkargsok(BlasInt(flag)) A end @@ -204,16 +200,17 @@ function larft!(direct::Char, storev::Char, v::StridedCuMatrix{T}, tau::StridedC ldv = max(1, stride(v, 2)) ldt = max(1, stride(t, 2)) params = CuSolverParameters() + dh = dense_handle() function bufferSize() out_cpu = Ref{Csize_t}(0) out_gpu = Ref{Csize_t}(0) - cusolverDnXlarft_bufferSize(dense_handle(), params, direct, storev, n, k, T, + cusolverDnXlarft_bufferSize(dh, params, direct, storev, n, k, T, v, ldv, T, tau, T, t, ldt, T, out_gpu, out_cpu) out_gpu[], out_cpu[] end - with_workspaces(bufferSize()...) do buffer_gpu, buffer_cpu - cusolverDnXlarft(dense_handle(), params, direct, storev, n, k, T, v, ldv, T, tau, T, t, + with_workspaces(dh.workspace_gpu, dh.workspace_cpu, bufferSize()...) do buffer_gpu, buffer_cpu + cusolverDnXlarft(dh, params, direct, storev, n, k, T, v, ldv, T, tau, T, t, ldt, T, buffer_gpu, sizeof(buffer_gpu), buffer_cpu, sizeof(buffer_cpu)) end @@ -247,25 +244,24 @@ function Xgesvd!(jobu::Char, jobvt::Char, A::StridedCuMatrix{T}) where {T <: Bla lda = max(1, stride(A, 2)) ldu = U == CU_NULL ? 1 : max(1, stride(U, 2)) ldvt = Vt == CU_NULL ? 
1 : max(1, stride(Vt, 2)) - info = CuVector{Cint}(undef, 1) params = CuSolverParameters() + dh = dense_handle() function bufferSize() out_cpu = Ref{Csize_t}(0) out_gpu = Ref{Csize_t}(0) - cusolverDnXgesvd_bufferSize(dense_handle(), params, jobu, jobvt, + cusolverDnXgesvd_bufferSize(dh, params, jobu, jobvt, m, n, T, A, lda, R, Σ, T, U, ldu, T, Vt, ldvt, T, out_gpu, out_cpu) out_gpu[], out_cpu[] end - with_workspaces(bufferSize()...) do buffer_gpu, buffer_cpu - cusolverDnXgesvd(dense_handle(), params, jobu, jobvt, m, n, T, A, + with_workspaces(dh.workspace_gpu, dh.workspace_cpu, bufferSize()...) do buffer_gpu, buffer_cpu + cusolverDnXgesvd(dh, params, jobu, jobvt, m, n, T, A, lda, R, Σ, T, U, ldu, T, Vt, ldvt, T, buffer_gpu, - sizeof(buffer_gpu), buffer_cpu, sizeof(buffer_cpu), info) + sizeof(buffer_gpu), buffer_cpu, sizeof(buffer_cpu), dh.info) end - flag = @allowscalar info[1] - unsafe_free!(info) + flag = @allowscalar dh.info[1] chklapackerror(BlasInt(flag)) U, Σ, Vt end @@ -298,27 +294,26 @@ function Xgesvdp!(jobz::Char, econ::Int, A::StridedCuMatrix{T}) where {T <: Blas lda = max(1, stride(A, 2)) ldu = U == CU_NULL ? 1 : max(1, stride(U, 2)) ldv = V == CU_NULL ? 1 : max(1, stride(V, 2)) - info = CuVector{Cint}(undef, 1) h_err_sigma = Ref{Cdouble}(0) params = CuSolverParameters() + dh = dense_handle() function bufferSize() out_cpu = Ref{Csize_t}(0) out_gpu = Ref{Csize_t}(0) - cusolverDnXgesvdp_bufferSize(dense_handle(), params, jobz, econ, m, + cusolverDnXgesvdp_bufferSize(dh, params, jobz, econ, m, n, T, A, lda, R, Σ, T, U, ldu, T, V, ldv, T, out_gpu, out_cpu) out_gpu[], out_cpu[] end - with_workspaces(bufferSize()...) do buffer_gpu, buffer_cpu - cusolverDnXgesvdp(dense_handle(), params, jobz, econ, m, n, T, A, lda, R, + with_workspaces(dh.workspace_gpu, dh.workspace_cpu, bufferSize()...) 
do buffer_gpu, buffer_cpu + cusolverDnXgesvdp(dh, params, jobz, econ, m, n, T, A, lda, R, Σ, T, U, ldu, T, V, ldv, T, buffer_gpu, sizeof(buffer_gpu), - buffer_cpu, sizeof(buffer_cpu), info, h_err_sigma) + buffer_cpu, sizeof(buffer_cpu), dh.info, h_err_sigma) end - flag = @allowscalar info[1] - unsafe_free!(info) + flag = @allowscalar dh.info[1] chklapackerror(BlasInt(flag)) if jobz == 'N' return Σ, h_err_sigma[] @@ -353,26 +348,25 @@ function Xgesvdr!(jobu::Char, jobv::Char, A::StridedCuMatrix{T}, k::Integer; lda = max(1, stride(A, 2)) ldu = U == CU_NULL ? 1 : max(1, stride(U, 2)) ldv = V == CU_NULL ? 1 : max(1, stride(V, 2)) - info = CuVector{Cint}(undef, 1) params = CuSolverParameters() + dh = dense_handle() function bufferSize() out_cpu = Ref{Csize_t}(0) out_gpu = Ref{Csize_t}(0) - cusolverDnXgesvdr_bufferSize(dense_handle(), params, jobu, jobv, + cusolverDnXgesvdr_bufferSize(dh, params, jobu, jobv, m, n, k, p, niters, T, A, lda, R, Σ, T, U, ldu, T, V, ldv, T, out_gpu, out_cpu) out_gpu[], out_cpu[] end - with_workspaces(bufferSize()...) do buffer_gpu, buffer_cpu - cusolverDnXgesvdr(dense_handle(), params, jobu, jobv, m, n, + with_workspaces(dh.workspace_gpu, dh.workspace_cpu, bufferSize()...) 
do buffer_gpu, buffer_cpu + cusolverDnXgesvdr(dh, params, jobu, jobv, m, n, k, p, niters, T, A, lda, R, Σ, T, U, ldu, T, V, ldv, T, buffer_gpu, sizeof(buffer_gpu), - buffer_cpu, sizeof(buffer_cpu), info) + buffer_cpu, sizeof(buffer_cpu), dh.info) end - flag = @allowscalar info[1] - unsafe_free!(info) + flag = @allowscalar dh.info[1] chklapackerror(BlasInt(flag)) U, Σ, V end @@ -383,25 +377,24 @@ function Xsyevd!(jobz::Char, uplo::Char, A::StridedCuMatrix{T}) where {T <: Blas n = checksquare(A) R = real(T) lda = max(1, stride(A, 2)) - info = CuVector{Cint}(undef, 1) W = CuVector{R}(undef, n) params = CuSolverParameters() + dh = dense_handle() function bufferSize() out_cpu = Ref{Csize_t}(0) out_gpu = Ref{Csize_t}(0) - cusolverDnXsyevd_bufferSize(dense_handle(), params, jobz, uplo, n, + cusolverDnXsyevd_bufferSize(dh, params, jobz, uplo, n, T, A, lda, R, W, T, out_gpu, out_cpu) out_gpu[], out_cpu[] end - with_workspaces(bufferSize()...) do buffer_gpu, buffer_cpu - cusolverDnXsyevd(dense_handle(), params, jobz, uplo, n, T, A, + with_workspaces(dh.workspace_gpu, dh.workspace_cpu, bufferSize()...) 
do buffer_gpu, buffer_cpu + cusolverDnXsyevd(dh, params, jobz, uplo, n, T, A, lda, R, W, T, buffer_gpu, sizeof(buffer_gpu), - buffer_cpu, sizeof(buffer_cpu), info) + buffer_cpu, sizeof(buffer_cpu), dh.info) end - flag = @allowscalar info[1] - unsafe_free!(info) + flag = @allowscalar dh.info[1] chkargsok(BlasInt(flag)) if jobz == 'N' @@ -421,29 +414,28 @@ function Xsyevdx!(jobz::Char, range::Char, uplo::Char, A::StridedCuMatrix{T}; (range == 'I') && !(1 ≤ il ≤ iu ≤ n) && throw(ArgumentError("illegal choice of eigenvalue indices (il = $il, iu = $iu), which must be between 1 and n = $n")) (range == 'V') && (vl ≥ vu) && throw(ArgumentError("lower boundary, $vl, must be less than upper boundary, $vu")) lda = max(1, stride(A, 2)) - info = CuVector{Cint}(undef, 1) W = CuVector{R}(undef, n) vl = Ref{R}(vl) vu = Ref{R}(vu) h_meig = Ref{Int64}(0) params = CuSolverParameters() + dh = dense_handle() function bufferSize() out_cpu = Ref{Csize_t}(0) out_gpu = Ref{Csize_t}(0) - cusolverDnXsyevdx_bufferSize(dense_handle(), params, jobz, range, uplo, n, + cusolverDnXsyevdx_bufferSize(dh, params, jobz, range, uplo, n, T, A, lda, vl, vu, il, iu, h_meig, R, W, T, out_gpu, out_cpu) out_gpu[], out_cpu[] end - with_workspaces(bufferSize()...) do buffer_gpu, buffer_cpu - cusolverDnXsyevdx(dense_handle(), params, jobz, range, uplo, n, T, A, + with_workspaces(dh.workspace_gpu, dh.workspace_cpu, bufferSize()...) do buffer_gpu, buffer_cpu + cusolverDnXsyevdx(dh, params, jobz, range, uplo, n, T, A, lda, vl, vu, il, iu, h_meig, R, W, T, buffer_gpu, - sizeof(buffer_gpu), buffer_cpu, sizeof(buffer_cpu), info) + sizeof(buffer_gpu), buffer_cpu, sizeof(buffer_cpu), dh.info) end - flag = @allowscalar info[1] - unsafe_free!(info) + flag = @allowscalar dh.info[1] chkargsok(BlasInt(flag)) if jobz == 'N'