Skip to content

Commit

Permalink
Move GpuDriver HostRegister & Unregister into CudaExecutor.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 685439655
  • Loading branch information
klucke authored and tensorflower-gardener committed Oct 13, 2024
1 parent 7fd2492 commit a2cf898
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 43 deletions.
25 changes: 0 additions & 25 deletions third_party/xla/xla/stream_executor/cuda/cuda_driver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -808,31 +808,6 @@ void GpuDriver::HostDeallocate(Context* context, void* location) {
}
}

bool GpuDriver::HostRegister(Context* context, void* location, uint64_t bytes) {
ScopedActivateContext activation(context);
// "Portable" memory is visible to all CUDA contexts. Safe for our use model.
auto status = cuda::ToStatus(
cuMemHostRegister(location, bytes, CU_MEMHOSTREGISTER_PORTABLE));
if (!status.ok()) {
LOG(ERROR) << "error registering host memory at " << location << ": "
<< status;
return false;
}
return true;
}

bool GpuDriver::HostUnregister(Context* context, void* location) {
ScopedActivateContext activation(context);
auto status = cuda::ToStatus(cuMemHostUnregister(location));
if (!status.ok()) {
LOG(ERROR) << "error unregistering host memory at " << location << ": "
<< status;
return false;
}
return true;
}


absl::Status GpuDriver::SynchronizeStream(Context* context, CUstream stream) {
ScopedActivateContext activated{context};
CHECK(stream != nullptr);
Expand Down
20 changes: 18 additions & 2 deletions third_party/xla/xla/stream_executor/cuda/cuda_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -906,13 +906,29 @@ bool CudaExecutor::HostMemoryRegister(void* location, uint64_t size) {
VLOG(1) << "Called StreamExecutor::HostMemoryRegister(data=" << location
<< ")";

return GpuDriver::HostRegister(gpu_context(), location, size);
ScopedActivateContext activation(gpu_context());
// "Portable" memory is visible to all CUDA contexts. Safe for our use model.
auto status = cuda::ToStatus(
cuMemHostRegister(location, size, CU_MEMHOSTREGISTER_PORTABLE));
if (!status.ok()) {
LOG(ERROR) << "error registering host memory at " << location << ": "
<< status;
return false;
}
return true;
}

bool CudaExecutor::HostMemoryUnregister(void* location) {
VLOG(1) << "Called StreamExecutor::HostUnregister(data=" << location << ")";

return GpuDriver::HostUnregister(gpu_context(), location);
ScopedActivateContext activation(gpu_context());
auto status = cuda::ToStatus(cuMemHostUnregister(location));
if (!status.ok()) {
LOG(ERROR) << "error unregistering host memory at " << location << ": "
<< status;
return false;
}
return true;
}

absl::Status CudaExecutor::SynchronousMemZero(DeviceMemoryBase* location,
Expand Down
16 changes: 0 additions & 16 deletions third_party/xla/xla/stream_executor/gpu/gpu_driver.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,22 +81,6 @@ class GpuDriver {
// https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html#memory-management
static void HostDeallocate(Context* context, void* location);

// Registers a memory region at location of size bytes via
// cuMemHostRegister/hipHostRegister.
// http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1gf0a9fe11544326dabd743b7aa6b54223
// https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html#memory-management
static bool HostRegister(Context* context, void* location, uint64_t bytes);

// Unregisters a memory region that was previously registered at location via
// cuMemHostUnregister/hipHostUnregister.
//
// http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1g63f450c8125359be87b7623b1c0b2a14
// https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html#memory-management
//
// TODO(leary) verify an error will be returned if the location wasn't
// previously registered.
static bool HostUnregister(Context* context, void* location);

// Launches a CUDA/ROCm kernel via cuLaunchKernel/hipModuleLaunchKernel.
// http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EXEC.html#group__CUDA__EXEC_1gb8f3dc3031b40da29d5f9a7139e52e15
// https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html#execution-control
Expand Down

0 comments on commit a2cf898

Please sign in to comment.