From edde4227f0819b24bab550fde309132ff945b32c Mon Sep 17 00:00:00 2001 From: Sheng Yang Date: Wed, 17 Jul 2024 02:34:06 -0700 Subject: [PATCH] PR #14950: [XLA:GPU] Register python callback on sycl platform Imported from GitHub PR https://github.com/openxla/xla/pull/14950 Copybara import of the project: -- 9ea4dd7cd70737862f124f77ebe66854247aa06a by Sheng, Yang : [XLA:GPU] Register python callback on sycl platform Merging this change closes #14950 PiperOrigin-RevId: 653153315 --- third_party/xla/xla/python/py_client.cc | 6 +++--- third_party/xla/xla/python/py_host_callback.cc | 2 +- third_party/xla/xla/service/platform_util.cc | 3 +++ 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/third_party/xla/xla/python/py_client.cc b/third_party/xla/xla/python/py_client.cc index 984374bb79d68f..3b4ebcd9901d09 100644 --- a/third_party/xla/xla/python/py_client.cc +++ b/third_party/xla/xla/python/py_client.cc @@ -96,9 +96,9 @@ limitations under the License. #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM || TENSORFLOW_USE_SYCL #include "xla/python/py_client_gpu.h" -#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM || TENSORFLOW_USE_SYCL namespace xla { @@ -627,7 +627,7 @@ PyClient::GetEmitPythonCallbackDescriptor(nb::callable callable, XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("xla_python_cpu_callback", &XlaPythonCpuCallback); -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM || TENSORFLOW_USE_SYCL XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM( "xla_python_gpu_callback", &XlaPythonGpuCallback, absl::AsciiStrToUpper(PlatformUtil::CanonicalPlatformName("gpu").value())); diff --git a/third_party/xla/xla/python/py_host_callback.cc b/third_party/xla/xla/python/py_host_callback.cc index 2754e567350601..2c402bc09e1094 100644 --- a/third_party/xla/xla/python/py_host_callback.cc +++ b/third_party/xla/xla/python/py_host_callback.cc @@ -127,7 +127,7 @@ PyCpuLoadedHostCallback::Create(ifrt::Client* ifrt_client, absl::Span result_shapes) { ifrt::PlatformId platform_id = ifrt_client->platform_id(); if (platform_id != CpuId() && platform_id != CudaId() && - platform_id != RocmId()) { + platform_id != RocmId() && platform_id != SyclId()) { return Unimplemented("CpuCallback supports CPU and GPU only"); } diff --git a/third_party/xla/xla/service/platform_util.cc b/third_party/xla/xla/service/platform_util.cc index a4e26feb524acf..0861cdfdb35e13 100644 --- a/third_party/xla/xla/service/platform_util.cc +++ b/third_party/xla/xla/service/platform_util.cc @@ -54,9 +54,12 @@ std::string CanonicalPlatformName(const std::string& platform_name) { } // When configured on CUDA, "gpu" and "cuda" mean the same thing. // When configured on ROCm, "gpu" and "rocm" mean the same thing. + // When configured on SYCL, "gpu" and "sycl" mean the same thing. if (lowercase_platform_name == "gpu") { #if TENSORFLOW_USE_ROCM return "rocm"; +#elif TENSORFLOW_USE_SYCL + return "sycl"; #else return "cuda"; #endif