Skip to content

Commit

Permalink
Split GpuTimer into CUDA and ROCm specific implementations
Browse files Browse the repository at this point in the history
This requires the following changes:

- Move GpuEvent::Record as `RecordEvent` into `CudaStream` and `RocmStream`
- Move `GpuStream::WaitFor` into `CudaExecutor` and `RocmExecutor`
- `CudaStream` and `RocmStream` get a factory function instead of having an init function.
- The corresponding GpuDriver functions move into the .cc files where they get called.

PiperOrigin-RevId: 683837514
  • Loading branch information
beckerhe authored and tensorflower-gardener committed Oct 9, 2024
1 parent 7de0fbe commit f4a4639
Show file tree
Hide file tree
Showing 27 changed files with 702 additions and 410 deletions.
1 change: 0 additions & 1 deletion third_party/xla/xla/service/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -2892,7 +2892,6 @@ xla_test(
"//xla/stream_executor:device_description",
"//xla/stream_executor:platform",
"//xla/stream_executor:stream_executor_h",
"//xla/stream_executor/gpu:gpu_timer",
"//xla/stream_executor/gpu:mock_gpu_executor",
"//xla/tests:filecheck",
"//xla/tests:hlo_test_base",
Expand Down
1 change: 0 additions & 1 deletion third_party/xla/xla/service/gpu/determinism_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ limitations under the License.
#include "xla/service/gpu/tests/gpu_codegen_test.h"
#include "xla/service/platform_util.h"
#include "xla/stream_executor/device_description.h"
#include "xla/stream_executor/gpu/gpu_timer.h"
#include "xla/stream_executor/gpu/mock_gpu_executor.h"
#include "xla/stream_executor/platform.h"
#include "xla/stream_executor/stream_executor.h"
Expand Down
55 changes: 52 additions & 3 deletions third_party/xla/xla/stream_executor/cuda/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -647,10 +647,15 @@ cc_library(
"gpu",
],
deps = [
":cuda_driver",
":cuda_status",
"//xla/stream_executor:event",
"//xla/stream_executor/gpu:context",
"//xla/stream_executor/gpu:gpu_driver_header",
"//xla/stream_executor/gpu:gpu_event",
"//xla/stream_executor/gpu:gpu_types_header",
"//xla/stream_executor/gpu:scoped_activate_context",
"@com_google_absl//absl/base",
"@com_google_absl//absl/status",
"@local_config_cuda//cuda:cuda_headers",
],
)
Expand Down Expand Up @@ -964,12 +969,14 @@ cc_library(
],
deps = [
":cuda_collectives",
":cuda_driver",
":cuda_event", # buildcleaner: keep
":cuda_kernel", # buildcleaner: keep
":cuda_platform_id",
":cuda_runtime",
":cuda_status",
":cuda_stream",
":cuda_timer",
":cuda_version_parser",
":delay_kernel_cuda",
"//xla/stream_executor",
Expand All @@ -994,7 +1001,6 @@ cc_library(
"//xla/stream_executor/gpu:gpu_kernel_header",
"//xla/stream_executor/gpu:gpu_semaphore",
"//xla/stream_executor/gpu:gpu_stream_header",
"//xla/stream_executor/gpu:gpu_timer",
"//xla/stream_executor/gpu:gpu_types_header",
"//xla/stream_executor/gpu:read_numa_node",
"//xla/stream_executor/gpu:scoped_activate_context",
Expand Down Expand Up @@ -1146,5 +1152,48 @@ cc_library(
"cuda-only",
"gpu",
],
deps = ["//xla/stream_executor/gpu:gpu_stream"],
deps = [
":cuda_event",
":cuda_status",
"//xla/stream_executor:event",
"//xla/stream_executor:platform",
"//xla/stream_executor:stream",
"//xla/stream_executor/gpu:context",
"//xla/stream_executor/gpu:gpu_event",
"//xla/stream_executor/gpu:gpu_executor_header",
"//xla/stream_executor/gpu:gpu_stream",
"//xla/stream_executor/gpu:scoped_activate_context",
"@com_google_absl//absl/log",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@local_config_cuda//cuda:cuda_headers",
"@local_tsl//tsl/platform:errors",
"@local_tsl//tsl/platform:statusor",
],
)

cc_library(
name = "cuda_timer",
srcs = ["cuda_timer.cc"],
hdrs = ["cuda_timer.h"],
tags = [
"cuda-only",
"gpu",
],
deps = [
":cuda_status",
"//xla/stream_executor:event_based_timer",
"//xla/stream_executor/gpu:context",
"//xla/stream_executor/gpu:gpu_event",
"//xla/stream_executor/gpu:gpu_semaphore",
"//xla/stream_executor/gpu:gpu_stream",
"//xla/stream_executor/gpu:scoped_activate_context",
"@com_google_absl//absl/log",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/time",
"@local_config_cuda//cuda:cuda_headers",
"@local_tsl//tsl/platform:errors",
"@local_tsl//tsl/platform:statusor",
],
)
69 changes: 0 additions & 69 deletions third_party/xla/xla/stream_executor/cuda/cuda_driver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1005,25 +1005,6 @@ absl::Status GpuDriver::AddStreamCallback(Context* context, CUstream stream,
return cuda::ToStatus(cuLaunchHostFunc(stream, callback, data));
}

// Creates a new non-blocking CUDA stream in `context`.
//
// A `priority` of 0 requests the default priority through the legacy
// cuStreamCreate API (kept for backward compatibility); any other value is
// forwarded to cuStreamCreateWithPriority.
absl::StatusOr<GpuStreamHandle> GpuDriver::CreateStream(Context* context,
                                                        int priority) {
  ScopedActivateContext scoped_activation(context);
  GpuStreamHandle new_stream;
  const bool use_legacy_api = (priority == 0);
  if (use_legacy_api) {
    // Probably equivalent to the priority-aware call for priority 0, but the
    // old API is kept as-is for backward compatibility.
    TF_RETURN_IF_ERROR(
        cuda::ToStatus(cuStreamCreate(&new_stream, CU_STREAM_NON_BLOCKING)));
  } else {
    TF_RETURN_IF_ERROR(cuda::ToStatus(cuStreamCreateWithPriority(
        &new_stream, CU_STREAM_NON_BLOCKING, priority)));
  }
  VLOG(2) << "successfully created stream " << new_stream << " for context "
          << context << " on thread";
  return new_stream;
}

void GpuDriver::DestroyStream(Context* context, GpuStreamHandle stream) {
if (stream == nullptr) {
Expand Down Expand Up @@ -1155,23 +1136,6 @@ bool GpuDriver::HostUnregister(Context* context, void* location) {
return true;
}

// Maps an abstract StreamPriority onto the CUDA numeric priority range for
// `context`. Returns 0 (CUDA's default priority) for StreamPriority::Default,
// and also as a best-effort fallback when the range query fails.
int GpuDriver::GetGpuStreamPriority(
    Context* context, stream_executor::StreamPriority stream_priority) {
  ScopedActivateContext scoped_activation(context);
  if (stream_priority == stream_executor::StreamPriority::Default) {
    return 0;
  }
  int least_priority = 0;
  int greatest_priority = 0;
  const auto range_status = cuda::ToStatus(
      cuCtxGetStreamPriorityRange(&least_priority, &greatest_priority));
  if (!range_status.ok()) {
    // Degrade gracefully rather than failing stream creation outright.
    LOG(ERROR)
        << "Could not query stream priority range. Returning default priority.";
    return 0;
  }
  if (stream_priority == stream_executor::StreamPriority::Highest) {
    return greatest_priority;
  }
  return least_priority;
}

absl::Status GpuDriver::DestroyEvent(Context* context, CUevent* event) {
if (*event == nullptr) {
return absl::InvalidArgumentError("input event cannot be null");
Expand All @@ -1181,39 +1145,6 @@ absl::Status GpuDriver::DestroyEvent(Context* context, CUevent* event) {
return cuda::ToStatus(cuEventDestroy(*event), "Error destroying CUDA event");
}

// Enqueues `event` on `stream` (within `context`) via cuEventRecord.
absl::Status GpuDriver::RecordEvent(Context* context, CUevent event,
                                    CUstream stream) {
  ScopedActivateContext scoped_activation(context);
  const CUresult record_result = cuEventRecord(event, stream);
  return cuda::ToStatus(record_result, "Error recording CUDA event");
}

// Returns the elapsed time in milliseconds between `start` and `stop`.
//
// Blocks on `stop` first, since cuEventElapsedTime requires the stop event to
// have completed.
absl::StatusOr<float> GpuDriver::GetEventElapsedTime(Context* context,
                                                     CUevent start,
                                                     CUevent stop) {
  ScopedActivateContext activated{context};
  // The stop event must have completed in order for cuEventElapsedTime to
  // work.
  auto status = cuda::ToStatus(cuEventSynchronize(stop));
  if (!status.ok()) {
    LOG(ERROR) << "failed to synchronize the stop event: " << status;
    // BUGFIX: this previously did `return false;`, which implicitly converted
    // to a *successful* StatusOr<float> holding 0.0f and silently masked the
    // synchronization failure. Propagate the real error instead.
    return status;
  }

  float elapsed_milliseconds;
  TF_RETURN_IF_ERROR(
      cuda::ToStatus(cuEventElapsedTime(&elapsed_milliseconds, start, stop)));
  return elapsed_milliseconds;
}

// Makes all subsequently enqueued work on `stream` wait until `event` has
// been reached.
absl::Status GpuDriver::WaitStreamOnEvent(Context* context, CUstream stream,
                                          CUevent event) {
  ScopedActivateContext scoped_activation(context);
  constexpr unsigned int kNoFlags = 0;  // cuStreamWaitEvent takes no flags.
  return cuda::ToStatus(cuStreamWaitEvent(stream, event, kNoFlags));
}

absl::Status GpuDriver::SynchronizeStream(Context* context, CUstream stream) {
ScopedActivateContext activated{context};
CHECK(stream != nullptr);
Expand Down
19 changes: 18 additions & 1 deletion third_party/xla/xla/stream_executor/cuda/cuda_event.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,25 @@ limitations under the License.

#include "xla/stream_executor/cuda/cuda_event.h"

#include <cstdint>

#include "absl/base/casts.h"
#include "absl/status/status.h"
#include "third_party/gpus/cuda/include/cuda.h"
#include "xla/stream_executor/cuda/cuda_driver.h"
#include "xla/stream_executor/cuda/cuda_status.h"
#include "xla/stream_executor/event.h"
#include "xla/stream_executor/gpu/context.h"
#include "xla/stream_executor/gpu/scoped_activate_context.h"

namespace stream_executor {
namespace gpu {
namespace {
// File-local helper: enqueues a wait on `stream` (in `context`) so that work
// submitted to the stream afterwards does not run until `event` is reached.
absl::Status WaitStreamOnEvent(Context* context, CUstream stream,
                               CUevent event) {
  ScopedActivateContext scoped_activation(context);
  constexpr unsigned int kNoFlags = 0;  // cuStreamWaitEvent takes no flags.
  return cuda::ToStatus(cuStreamWaitEvent(stream, event, kNoFlags));
}
}  // namespace

Event::Status CudaEvent::PollForStatus() {
ScopedActivateContext activated(context());
Expand All @@ -34,5 +46,10 @@ Event::Status CudaEvent::PollForStatus() {
return Event::Status::kError;
}

// Blocks an externally owned stream (passed as an opaque integer handle) on
// this event. The handle is reinterpreted as a CUstream via bit_cast.
absl::Status CudaEvent::WaitForEventOnExternalStream(std::intptr_t stream) {
  const CUstream external_stream = absl::bit_cast<CUstream>(stream);
  return WaitStreamOnEvent(context(), external_stream, gpu_event());
}

} // namespace gpu
} // namespace stream_executor
6 changes: 6 additions & 0 deletions third_party/xla/xla/stream_executor/cuda/cuda_event.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@ limitations under the License.
#ifndef XLA_STREAM_EXECUTOR_CUDA_CUDA_EVENT_H_
#define XLA_STREAM_EXECUTOR_CUDA_CUDA_EVENT_H_

#include <cstdint>

#include "absl/status/status.h"
#include "xla/stream_executor/event.h"
#include "xla/stream_executor/gpu/context.h"
#include "xla/stream_executor/gpu/gpu_event.h"

namespace stream_executor::gpu {
Expand All @@ -29,6 +33,8 @@ class CudaEvent : public GpuEvent {
explicit CudaEvent(Context *context) : GpuEvent(context) {}

Event::Status PollForStatus() override;

absl::Status WaitForEventOnExternalStream(std::intptr_t stream) override;
};

} // namespace stream_executor::gpu
Expand Down
14 changes: 7 additions & 7 deletions third_party/xla/xla/stream_executor/cuda/cuda_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ limitations under the License.
#include "xla/stream_executor/cuda/cuda_runtime.h"
#include "xla/stream_executor/cuda/cuda_status.h"
#include "xla/stream_executor/cuda/cuda_stream.h"
#include "xla/stream_executor/cuda/cuda_timer.h"
#include "xla/stream_executor/cuda/cuda_version_parser.h"
#include "xla/stream_executor/cuda/delay_kernel.h"
#include "xla/stream_executor/device_description.h"
Expand All @@ -67,7 +68,6 @@ limitations under the License.
#include "xla/stream_executor/gpu/gpu_kernel.h"
#include "xla/stream_executor/gpu/gpu_semaphore.h"
#include "xla/stream_executor/gpu/gpu_stream.h"
#include "xla/stream_executor/gpu/gpu_timer.h"
#include "xla/stream_executor/gpu/gpu_types.h"
#include "xla/stream_executor/gpu/read_numa_node.h"
#include "xla/stream_executor/gpu/scoped_activate_context.h"
Expand Down Expand Up @@ -427,10 +427,10 @@ CudaExecutor::CreateEventBasedTimer(GpuStream* stream, bool use_delay_kernel) {
}
TF_ASSIGN_OR_RETURN(auto start_event, CreateGpuEvent(/*allow_timing=*/true));
TF_ASSIGN_OR_RETURN(auto stop_event, CreateGpuEvent(/*allow_timing=*/true));
TF_RETURN_IF_ERROR(start_event->Record(stream->gpu_stream()));
return std::make_unique<GpuTimer>(gpu_context(), std::move(start_event),
std::move(stop_event), stream,
std::move(semaphore));
TF_RETURN_IF_ERROR(stream->RecordEvent(start_event.get()));
return std::make_unique<CudaTimer>(gpu_context(), std::move(start_event),
std::move(stop_event), stream,
std::move(semaphore));
}

bool CudaExecutor::UnloadGpuBinary(const void* gpu_binary) {
Expand Down Expand Up @@ -811,9 +811,9 @@ absl::StatusOr<std::unique_ptr<Event>> CudaExecutor::CreateEvent() {
absl::StatusOr<std::unique_ptr<Stream>> CudaExecutor::CreateStream(
std::optional<std::variant<StreamPriority, int>> priority) {
TF_ASSIGN_OR_RETURN(auto event, CreateGpuEvent(/*allow_timing=*/false));
auto stream = std::make_unique<CudaStream>(this, std::move(event), priority);
TF_ASSIGN_OR_RETURN(auto stream,
CudaStream::Create(this, std::move(event), priority));
absl::MutexLock l(&alive_gpu_streams_mu_);
TF_RETURN_IF_ERROR(stream->Init());
auto gpu_stream = stream->gpu_stream();
alive_gpu_streams_[gpu_stream] = stream.get();
return std::move(stream);
Expand Down
Loading

0 comments on commit f4a4639

Please sign in to comment.