From 4c20ab40658a698ddeb4fd03a73f63b38c6f18f2 Mon Sep 17 00:00:00 2001 From: Sizhi Tan Date: Wed, 30 Oct 2024 17:05:25 -0700 Subject: [PATCH] Add IFRT se gpu client and tests accordingly. PiperOrigin-RevId: 691602092 --- third_party/xla/xla/python/pjrt_ifrt/BUILD | 31 ++++++++++++++ .../pjrt_ifrt/se_gpu_client_test_lib.cc | 42 +++++++++++++++++++ 2 files changed, 73 insertions(+) create mode 100644 third_party/xla/xla/python/pjrt_ifrt/se_gpu_client_test_lib.cc diff --git a/third_party/xla/xla/python/pjrt_ifrt/BUILD b/third_party/xla/xla/python/pjrt_ifrt/BUILD index 676fa32b1c581a..68155580a91204 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/BUILD +++ b/third_party/xla/xla/python/pjrt_ifrt/BUILD @@ -277,6 +277,22 @@ cc_library( alwayslink = True, ) +cc_library( + name = "se_gpu_client_test_lib", + testonly = True, + srcs = ["se_gpu_client_test_lib.cc"], + deps = [ + ":pjrt_ifrt", + "//xla/pjrt/gpu:se_gpu_pjrt_client", + "//xla/python/ifrt", + "//xla/python/ifrt:test_util", + "//xla/service:gpu_plugin", + "@com_google_absl//absl/status:statusor", + "@local_tsl//tsl/platform:statusor", + ], + alwayslink = True, +) + cc_library( name = "pjrt_attribute_map_util", srcs = ["pjrt_attribute_map_util.cc"], @@ -393,6 +409,21 @@ xla_cc_test( ], ) +xla_cc_test( + name = "pjrt_client_impl_test_se_gpu", + size = "small", + srcs = [], + tags = [ + "no_oss", + "requires-gpu-nvidia:2", + ], + deps = [ + ":se_gpu_client_test_lib", + "//xla/python/ifrt:client_impl_test_lib", + "@com_google_googletest//:gtest_main", + ], +) + xla_cc_test( name = "pjrt_executable_impl_test_tfrt_cpu", size = "small", diff --git a/third_party/xla/xla/python/pjrt_ifrt/se_gpu_client_test_lib.cc b/third_party/xla/xla/python/pjrt_ifrt/se_gpu_client_test_lib.cc new file mode 100644 index 00000000000000..d7c4a96c88f3e6 --- /dev/null +++ b/third_party/xla/xla/python/pjrt_ifrt/se_gpu_client_test_lib.cc @@ -0,0 +1,42 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "absl/status/statusor.h" +#include "xla/pjrt/gpu/se_gpu_pjrt_client.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/test_util.h" +#include "xla/python/pjrt_ifrt/pjrt_client.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace ifrt { +namespace { + +const bool kUnused = + (test_util::RegisterClientFactory( + []() -> absl::StatusOr> { + TF_ASSIGN_OR_RETURN( + auto pjrt_client, + xla::GetStreamExecutorGpuClient(xla::GpuClientOptions())); + return PjRtClient::Create(std::move(pjrt_client)); + }), + true); + +} // namespace +} // namespace ifrt +} // namespace xla