From bb2d62167003e73b18c1f918928c0037a8b2c78b Mon Sep 17 00:00:00 2001 From: Harsha HS Date: Thu, 21 Nov 2024 00:22:39 +0000 Subject: [PATCH] [ROCm] Fix kernel launch dimension Launch dimension should be of the form ((block.x, 1, 1), (thread.x, thready, 1)) to accommodate checks in (parallel_loop_emitter.cc)[https://github.com/openxla/xla/blob/main/xla/service/gpu/parallel_loop_emitter.cc#L169-L171] --- xla/service/gpu/BUILD | 1 + xla/service/gpu/launch_dimensions.cc | 15 ++++++++++++--- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/xla/service/gpu/BUILD b/xla/service/gpu/BUILD index 269fa70ddf27c..5dc80604ab7c8 100644 --- a/xla/service/gpu/BUILD +++ b/xla/service/gpu/BUILD @@ -135,6 +135,7 @@ cc_library( deps = [ "//xla:shape_util", "//xla:util", + "//xla/service:platform_util", "//xla/stream_executor:device_description", "//xla/stream_executor:launch_dim", "@com_google_absl//absl/log", diff --git a/xla/service/gpu/launch_dimensions.cc b/xla/service/gpu/launch_dimensions.cc index f9e28995d0996..21cc4758dbc56 100644 --- a/xla/service/gpu/launch_dimensions.cc +++ b/xla/service/gpu/launch_dimensions.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "xla/service/platform_util.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/stream_executor/device_description.h" @@ -37,11 +38,19 @@ LaunchDimensions CalculateLaunchDimensions( num_elements = CeilOfRatio(num_elements, int64_t{dim_config.unroll_factor}); const int kWarpSchedulers = 4; - int64_t threads_per_block = std::min( + int64_t threads_per_block_x = std::min( gpu_device_info.threads_per_warp() * kWarpSchedulers, num_elements); - int64_t num_blocks = CeilOfRatio(num_elements, threads_per_block); + int64_t num_blocks = CeilOfRatio(num_elements, threads_per_block_x); + CHECK(num_blocks < gpu_device_info.block_dim_limit().x); + int threads_per_block_y = 1; + if (xla::PlatformUtil::CanonicalPlatformName("gpu").value() == "rocm") { + while ((num_blocks * threads_per_block_x) > std::numeric_limits::max()) { + threads_per_block_x /= 2; + threads_per_block_y *= 2; + } + } return LaunchDimensions(se::BlockDim(num_blocks, 1, 1), - se::ThreadDim(threads_per_block, 1, 1)); + se::ThreadDim(threads_per_block_x, threads_per_block_y, 1)); } } // namespace gpu