Skip to content

Commit

Permalink
PR tensorflow#14950: [XLA:GPU] Register python callback on sycl platform
Browse files Browse the repository at this point in the history
Imported from GitHub PR openxla/xla#14950

Copybara import of the project:

--
9ea4dd7cd70737862f124f77ebe66854247aa06a by Sheng, Yang <yang.sheng@intel.com>:

[XLA:GPU] Register python callback on sycl platform

Merging this change closes tensorflow#14950

PiperOrigin-RevId: 653153315
  • Loading branch information
ShengYang1 authored and tensorflower-gardener committed Jul 17, 2024
1 parent c1653b2 commit edde422
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 4 deletions.
6 changes: 3 additions & 3 deletions third_party/xla/xla/python/py_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,9 @@ limitations under the License.
#include "tsl/platform/status.h"
#include "tsl/platform/statusor.h"

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM || TENSORFLOW_USE_SYCL
#include "xla/python/py_client_gpu.h"
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM || TENSORFLOW_USE_SYCL

namespace xla {

Expand Down Expand Up @@ -627,7 +627,7 @@ PyClient::GetEmitPythonCallbackDescriptor(nb::callable callable,
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("xla_python_cpu_callback",
&XlaPythonCpuCallback);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM || TENSORFLOW_USE_SYCL
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(
"xla_python_gpu_callback", &XlaPythonGpuCallback,
absl::AsciiStrToUpper(PlatformUtil::CanonicalPlatformName("gpu").value()));
Expand Down
2 changes: 1 addition & 1 deletion third_party/xla/xla/python/py_host_callback.cc
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ PyCpuLoadedHostCallback::Create(ifrt::Client* ifrt_client,
absl::Span<const Shape> result_shapes) {
ifrt::PlatformId platform_id = ifrt_client->platform_id();
if (platform_id != CpuId() && platform_id != CudaId() &&
platform_id != RocmId()) {
platform_id != RocmId() && platform_id != SyclId()) {
return Unimplemented("CpuCallback supports CPU and GPU only");
}

Expand Down
3 changes: 3 additions & 0 deletions third_party/xla/xla/service/platform_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,12 @@ std::string CanonicalPlatformName(const std::string& platform_name) {
}
// When configured on CUDA, "gpu" and "cuda" mean the same thing.
// When configured on ROCm, "gpu" and "rocm" mean the same thing.
// When configured on SYCL, "gpu" and "sycl" mean the same thing.
if (lowercase_platform_name == "gpu") {
#if TENSORFLOW_USE_ROCM
return "rocm";
#elif TENSORFLOW_USE_SYCL
return "sycl";
#else
return "cuda";
#endif
Expand Down

0 comments on commit edde422

Please sign in to comment.