Skip to content

Commit

Permalink
cached workspaces via a fat handle for CUSOLVER/dense{_generic}
Browse files Browse the repository at this point in the history
  • Loading branch information
bjarthur committed Aug 16, 2024
1 parent 69043ee commit 8865fb6
Show file tree
Hide file tree
Showing 3 changed files with 224 additions and 225 deletions.
22 changes: 20 additions & 2 deletions lib/cusolver/CUSOLVER.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,25 +60,43 @@ end
const idle_dense_handles =
HandleCache{CuContext,cusolverDnHandle_t}(dense_handle_ctor, dense_handle_dtor)

# fat handle, includes a cache
struct cusolverDnHandle
handle::cusolverDnHandle_t
workspace_gpu::CuVector{UInt8}
workspace_cpu::Vector{UInt8}
info::CuVector{Cint}
end
Base.unsafe_convert(::Type{Ptr{cusolverDnContext}}, handle::cusolverDnHandle) =
handle.handle

function dense_handle()
cuda = CUDA.active_state()

# every task maintains library state per device
LibraryState = @NamedTuple{handle::cusolverDnHandle_t, stream::CuStream}
LibraryState = @NamedTuple{handle::cusolverDnHandle, stream::CuStream}
states = get!(task_local_storage(), :CUSOLVER_dense) do
Dict{CuContext,LibraryState}()
end::Dict{CuContext,LibraryState}

# get library state
@noinline function new_state(cuda)
new_handle = pop!(idle_dense_handles, cuda.context)

workspace_gpu = CuVector{UInt8}(undef, 0)
workspace_cpu = Vector{UInt8}(undef, 0)
info = CuVector{Cint}(undef, 1)
fat_handle = cusolverDnHandle(new_handle, workspace_gpu, workspace_cpu, info)

finalizer(current_task()) do task
CUDA.unsafe_free!(workspace_gpu)
CUDA.unsafe_free!(info)
push!(idle_dense_handles, cuda.context, new_handle)
end

cusolverDnSetStream(new_handle, cuda.stream)

(; handle=new_handle, cuda.stream)
(; handle=fat_handle, cuda.stream)
end
state = get!(states, cuda.context) do
new_state(cuda)
Expand Down
Loading

0 comments on commit 8865fb6

Please sign in to comment.