# Patch summary. Everything below the first "diff --git" header is the
# original patch text, unchanged; these "#" lines are leading commentary only.
#
# 1. crosstool_wrapper_driver_rocm.tpl
#    - Declares five new template placeholders: RCCL_RUNTIME_PATH,
#      MIOPEN_RUNTIME_PATH, ROCBLAS_RUNTIME_PATH, HIPFFT_RUNTIME_PATH and
#      ROCRAND_RUNTIME_PATH.
#    - In main(), appends '-L<path>' and '-Wl,-rpath=<path>' to
#      gpu_linker_flags for each of those five paths. Unlike the HIP entry
#      next to them, no '-l<library>' flag is appended for the new paths.
# 2. rocm_configure.bzl (_create_local_rocm_repository)
#    - Fills each new placeholder with rocm_config.rocm_paths[<KEY>] + "/lib"
#      for the keys RCCL, ROCBLAS, MIOPEN, HIPFFT and ROCRAND.
# 3. rocm_rocdl_path.cc
#    - RocdlRoot() returns the value of HIP_DEVICE_LIB_PATH when that
#      environment variable is set; otherwise it returns the previous
#      io::JoinPath(RocmRoot(), "amdgcn/bitcode").
#    - NOTE(review): a set-but-empty HIP_DEVICE_LIB_PATH is returned as "";
#      confirm whether that case should fall back to the default path.
# 4. gpu_backend_lib.cc (EmitModuleToHsaco)
#    - The lld lookup checks LLVM_PATH first (joined with "bin") and only
#      then ROCM_PATH (joined with "llvm/bin").
#    - NOTE(review): std::getenv("LLVM_PATH") is evaluated twice here; the
#      RocdlRoot() change above captures its lookup once instead.
#    - NOTE(review): the hunk header reads "absl::StatusOr>"; the template
#      argument appears to have been lost when this text was extracted.
#    - NOTE(review): the visible text stops at "} else {", which looks like
#      less trailing context than the "-854,7 +854,10" header declares, so
#      the end of this hunk appears to be cut off.
diff --git a/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl b/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl index 6cc91e191bf8fe..17d0ad83ecf6b1 100755 --- a/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl +++ b/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl @@ -31,6 +31,11 @@ HIP_RUNTIME_PATH = '%{hip_runtime_path}' HIP_RUNTIME_LIBRARY = '%{hip_runtime_library}' ROCR_RUNTIME_PATH = '%{rocr_runtime_path}' ROCR_RUNTIME_LIBRARY = '%{rocr_runtime_library}' +RCCL_RUNTIME_PATH = '%{rccl_runtime_path}' +MIOPEN_RUNTIME_PATH = '%{miopen_runtime_path}' +ROCBLAS_RUNTIME_PATH = '%{rocblas_runtime_path}' +HIPFFT_RUNTIME_PATH = '%{hipfft_runtime_path}' +ROCRAND_RUNTIME_PATH = '%{rocrand_runtime_path}' VERBOSE = '%{crosstool_verbose}'=='1' CPU_COMPILER_IS_CLANG = '%{crosstool_clang}'=='1' @@ -258,6 +263,16 @@ def main(): gpu_linker_flags.append('-L' + HIP_RUNTIME_PATH) gpu_linker_flags.append('-Wl,-rpath=' + HIP_RUNTIME_PATH) gpu_linker_flags.append('-l' + HIP_RUNTIME_LIBRARY) + gpu_linker_flags.append('-L' + RCCL_RUNTIME_PATH) + gpu_linker_flags.append('-Wl,-rpath=' + RCCL_RUNTIME_PATH) + gpu_linker_flags.append('-L' + ROCBLAS_RUNTIME_PATH) + gpu_linker_flags.append('-Wl,-rpath=' + ROCBLAS_RUNTIME_PATH) + gpu_linker_flags.append('-L' + MIOPEN_RUNTIME_PATH) + gpu_linker_flags.append('-Wl,-rpath=' + MIOPEN_RUNTIME_PATH) + gpu_linker_flags.append('-L' + ROCRAND_RUNTIME_PATH) + gpu_linker_flags.append('-Wl,-rpath=' + ROCRAND_RUNTIME_PATH) + gpu_linker_flags.append('-L' + HIPFFT_RUNTIME_PATH) + gpu_linker_flags.append('-Wl,-rpath=' + HIPFFT_RUNTIME_PATH) gpu_linker_flags.append("-lrt") gpu_linker_flags.append("-lstdc++") diff --git a/third_party/gpus/rocm_configure.bzl b/third_party/gpus/rocm_configure.bzl index 7395f73ff4c7a8..234b38e41e0dbc 100644 --- a/third_party/gpus/rocm_configure.bzl +++ b/third_party/gpus/rocm_configure.bzl @@ -778,6 +778,11 @@ def 
_create_local_rocm_repository(repository_ctx): "%{rocr_runtime_library}": "hsa-runtime64", "%{hip_runtime_path}": rocm_config.rocm_toolkit_path + "/lib", "%{hip_runtime_library}": "amdhip64", + "%{rccl_runtime_path}": rocm_config.rocm_paths["RCCL"] + "/lib", + "%{rocblas_runtime_path}": rocm_config.rocm_paths["ROCBLAS"] + "/lib", + "%{miopen_runtime_path}": rocm_config.rocm_paths["MIOPEN"] + "/lib", + "%{hipfft_runtime_path}": rocm_config.rocm_paths["HIPFFT"] + "/lib", + "%{rocrand_runtime_path}": rocm_config.rocm_paths["ROCRAND"] + "/lib", "%{crosstool_verbose}": _crosstool_verbose(repository_ctx), "%{gcc_host_compiler_path}": str(cc), "%{crosstool_clang}": "1" if _is_clang_enabled(repository_ctx) else "0", diff --git a/third_party/xla/third_party/tsl/tsl/platform/default/rocm_rocdl_path.cc b/third_party/xla/third_party/tsl/tsl/platform/default/rocm_rocdl_path.cc index a1f3eba243afa4..c5373b9fb2e859 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/default/rocm_rocdl_path.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/default/rocm_rocdl_path.cc @@ -40,6 +40,13 @@ string RocmRoot() { #endif } -string RocdlRoot() { return io::JoinPath(RocmRoot(), "amdgcn/bitcode"); } +string RocdlRoot() { + if (const char* device_lib_path_env = std::getenv("HIP_DEVICE_LIB_PATH")) { + return device_lib_path_env; + } + else{ + return io::JoinPath(RocmRoot(), "amdgcn/bitcode"); + } +} } // namespace tsl diff --git a/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc b/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc index bbdaaa6317181f..cac5372c161d80 100644 --- a/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc +++ b/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc @@ -854,7 +854,10 @@ absl::StatusOr> EmitModuleToHsaco( } // Locate lld. 
std::string lld_path; - if (std::getenv("ROCM_PATH")) { + if (std::getenv("LLVM_PATH")) { + lld_path = tsl::io::JoinPath(std::getenv("LLVM_PATH"), "bin"); + } + else if (std::getenv("ROCM_PATH")) { lld_path = tsl::io::JoinPath(std::getenv("ROCM_PATH"), "llvm/bin"); } else {