Skip to content

Commit

Permalink
Reverts 30b2ecd
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 689151499
  • Loading branch information
klucke authored and tensorflower-gardener committed Oct 23, 2024
1 parent 25e1991 commit 0101689
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 8 deletions.
7 changes: 7 additions & 0 deletions third_party/xla/xla/stream_executor/cuda/cuda_driver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -632,6 +632,13 @@ int GpuDriver::GetDeviceCount() {
return device_count;
}

absl::StatusOr<int32_t> GpuDriver::GetDriverVersion() {
int32_t version;
TF_RETURN_IF_ERROR(cuda::ToStatus(cuDriverGetVersion(&version),
"Could not get driver version"));
return version;
}

absl::StatusOr<size_t> GpuDriver::GraphGetNodeCount(GpuGraphHandle graph) {
size_t num_nodes;
TF_RETURN_IF_ERROR(
Expand Down
6 changes: 2 additions & 4 deletions third_party/xla/xla/stream_executor/cuda/cuda_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1166,11 +1166,9 @@ CudaExecutor::CreateDeviceDescription(int device_ordinal) {

DeviceDescription desc;

int32_t driver_version = 0;
TF_RETURN_IF_ERROR(cuda::ToStatus(cuDriverGetVersion(&driver_version),
"Could not get driver version"));
desc.set_driver_version(
ParseCudaVersion(driver_version).value_or(SemanticVersion{0, 0, 0}));
ParseCudaVersion(GpuDriver::GetDriverVersion().value_or(0))
.value_or(SemanticVersion{0, 0, 0}));
desc.set_runtime_version(
ParseCudaVersion(CudaRuntime::GetRuntimeVersion().value_or(0))
.value_or(SemanticVersion{0, 0, 0}));
Expand Down
31 changes: 31 additions & 0 deletions third_party/xla/xla/stream_executor/gpu/gpu_driver.h
Original file line number Diff line number Diff line change
Expand Up @@ -282,12 +282,43 @@ class GpuDriver {
GpuGraphNodeHandle node,
GpuGraphHandle child);

// The CUDA stream callback type signature.
// The data passed to AddStreamCallback is subsequently passed to this
// callback when it fires.
//
// Some notable things:
// * Callbacks must not make any CUDA API calls.
// * Callbacks from independent streams execute in an undefined order and may
// be serialized.
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EXEC.html#group__CUDA__EXEC_1gab95a78143bae7f21eebb978f91e7f3f
typedef void (*StreamCallback)(void* data);

// Blocks the calling thread until the operations enqueued onto stream have
// been completed, via cuStreamSynchronize.
//
// TODO(leary) if a pathological thread enqueues operations onto the stream
// while another thread blocks like this, can you wind up waiting an unbounded
// amount of time?
//
// http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__STREAM.html#group__CUDA__STREAM_1g15e49dd91ec15991eb7c0a741beb7dad
static absl::Status SynchronizeStream(Context* context,
GpuStreamHandle stream);

// -- Context- and device-independent calls.

// Returns the number of visible CUDA device via cuDeviceGetCount.
// This should correspond to the set of device ordinals available.
// http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__DEVICE.html#group__CUDA__DEVICE_1g52b5ce05cb8c5fb6831b2c0ff2887c74
static int GetDeviceCount();

// Returns the driver version number via cuDriverGetVersion.
// This is, surprisingly, NOT the actual driver version (e.g. 331.79) but,
// instead, the CUDA toolkit release number that this driver is compatible
// with; e.g. 6000 (for a CUDA 6.0 compatible driver) or 6050 (for a CUDA 6.5
// compatible driver).
//
// http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__VERSION.html#group__CUDA__VERSION_1g8b7a10395392e049006e61bcdc8ebe71
static absl::StatusOr<int32_t> GetDriverVersion();
};

} // namespace gpu
Expand Down
7 changes: 7 additions & 0 deletions third_party/xla/xla/stream_executor/rocm/rocm_driver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -494,4 +494,11 @@ int GpuDriver::GetDeviceCount() {
return device_count;
}

absl::StatusOr<int32_t> GpuDriver::GetDriverVersion() {
int32_t version;
TF_RETURN_IF_ERROR(ToStatus(wrap::hipDriverGetVersion(&version),
"Could not get driver version"));
return version;
}

} // namespace stream_executor::gpu
6 changes: 2 additions & 4 deletions third_party/xla/xla/stream_executor/rocm/rocm_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1054,11 +1054,9 @@ RocmExecutor::CreateDeviceDescription(int device_ordinal) {
desc.set_runtime_version(
ParseRocmVersion(RocmRuntime::GetRuntimeVersion().value_or(0))
.value_or(SemanticVersion{0, 0, 0}));
int32_t driver_version = 0;
TF_RETURN_IF_ERROR(ToStatus(wrap::hipDriverGetVersion(&driver_version),
"Could not get driver version"));
desc.set_driver_version(
ParseRocmVersion(driver_version).value_or(SemanticVersion{0, 0, 0}));
ParseRocmVersion(GpuDriver::GetDriverVersion().value_or(0))
.value_or(SemanticVersion{0, 0, 0}));

// It would be better to use the PCI device ID or some other truly unique
// identifier for the GPU model. But getting this requires using NVML or
Expand Down

0 comments on commit 0101689

Please sign in to comment.