diff --git a/third_party/xla/xla/pjrt/pjrt_c_api_client.h b/third_party/xla/xla/pjrt/pjrt_c_api_client.h index 95a942cc9ce49f..0a62950c32859c 100644 --- a/third_party/xla/xla/pjrt/pjrt_c_api_client.h +++ b/third_party/xla/xla/pjrt/pjrt_c_api_client.h @@ -258,6 +258,11 @@ class PjRtCApiClient : public PjRtClient { absl::string_view platform_version() const override; + std::optional plugin_attributes() const override { + return PjRtPluginAttributes{c_api_->pjrt_api_version.major_version, + c_api_->pjrt_api_version.minor_version}; + } + // TODO(b/244756954): Rethink this function altogether PjRtRuntimeType runtime_type() const override { return PjRtRuntimeType::kTfrt; diff --git a/third_party/xla/xla/pjrt/pjrt_client.h b/third_party/xla/xla/pjrt/pjrt_client.h index 046f4b09c51f9d..69c8ebda0cbd07 100644 --- a/third_party/xla/xla/pjrt/pjrt_client.h +++ b/third_party/xla/xla/pjrt/pjrt_client.h @@ -401,6 +401,11 @@ class PjRtHostMemoryForDeviceManager { class PjRtLoadedExecutable; +struct PjRtPluginAttributes { + int64_t pjrt_c_api_major_version; + int64_t pjrt_c_api_minor_version; +}; + // Encapsulates the state of Python session with XLA. // // It is the responsibility of the client of this API to keep the PjRtClient @@ -515,6 +520,12 @@ class PjRtClient { // (e.g. the CUDA version on GPU or libtpu version on Cloud TPU). virtual absl::string_view platform_version() const = 0; + // Returns information about the underlying PJRT C API plugin if such a plugin + // is being used, otherwise returns nullopt. + virtual std::optional plugin_attributes() const { + return std::nullopt; + } + // TODO(b/244756954): Rethink this function altogether // Returns an enum that identifies the type of runtime being used under this // client. diff --git a/third_party/xla/xla/python/BUILD b/third_party/xla/xla/python/BUILD index 9d198475ea71b9..c0a5938032acd7 100644 --- a/third_party/xla/xla/python/BUILD +++ b/third_party/xla/xla/python/BUILD @@ -381,6 +381,7 @@ cc_library( "//xla/pjrt:lru_cache", "//xla/pjrt:mlir_to_hlo", "//xla/pjrt:pjrt_client", + "//xla/pjrt:pjrt_common", "//xla/pjrt:pjrt_compiler", "//xla/pjrt:pjrt_future", "//xla/pjrt:pjrt_stream_executor_client", @@ -1147,6 +1148,7 @@ cc_library( ":xla_compiler", # placeholder for index annotation deps "@com_google_absl//absl/base", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log:initialize", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", @@ -1163,6 +1165,7 @@ cc_library( "//xla/pjrt:pjrt_api", "//xla/pjrt:pjrt_c_api_client", "//xla/pjrt:pjrt_client", + "//xla/pjrt:pjrt_common", "//xla/pjrt:pjrt_compiler", "//xla/pjrt:tfrt_cpu_pjrt_client", "//xla/pjrt/c:pjrt_c_api_hdrs", diff --git a/third_party/xla/xla/python/ifrt/BUILD b/third_party/xla/xla/python/ifrt/BUILD index 13c22330ac9942..32eff2d0e3211a 100644 --- a/third_party/xla/xla/python/ifrt/BUILD +++ b/third_party/xla/xla/python/ifrt/BUILD @@ -73,9 +73,11 @@ cc_library( "//xla:statusor", "//xla:util", "//xla/pjrt:pjrt_client", + "//xla/pjrt:pjrt_common", "//xla/python/ifrt/ir", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/container:node_hash_set", "@com_google_absl//absl/functional:function_ref", diff --git a/third_party/xla/xla/python/ifrt/client.h b/third_party/xla/xla/python/ifrt/client.h index d0b961654938ba..418b08017be884 100644 --- a/third_party/xla/xla/python/ifrt/client.h +++ b/third_party/xla/xla/python/ifrt/client.h @@ -16,14 +16,20 @@ limitations under the License. #ifndef XLA_PYTHON_IFRT_CLIENT_H_ #define XLA_PYTHON_IFRT_CLIENT_H_ +#include #include #include #include +#include +#include +#include +#include "absl/container/flat_hash_map.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "llvm/Support/ExtensibleRTTI.h" #include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_common.h" #include "xla/python/ifrt/array.h" #include "xla/python/ifrt/compiler.h" #include "xla/python/ifrt/tuple.h" @@ -118,6 +124,10 @@ class Client : public llvm::RTTIExtends { virtual absl::string_view platform_version() const = 0; virtual PlatformId platform_id() const = 0; + using ClientAttribute = xla::PjRtValueType; + virtual absl::flat_hash_map attributes() + const = 0; + virtual int device_count() const = 0; virtual int addressable_device_count() const = 0; virtual absl::Span devices() const = 0; diff --git a/third_party/xla/xla/python/ifrt/mock.cc b/third_party/xla/xla/python/ifrt/mock.cc index b66f15c99901cf..57d9004813de19 100644 --- a/third_party/xla/xla/python/ifrt/mock.cc +++ b/third_party/xla/xla/python/ifrt/mock.cc @@ -119,6 +119,9 @@ MockClient::MockClient(std::unique_ptr delegated) ON_CALL(*this, platform_id).WillByDefault([this]() { return delegated_->platform_id(); }); + ON_CALL(*this, attributes).WillByDefault([this]() { + return delegated_->attributes(); + }); ON_CALL(*this, device_count).WillByDefault([this]() { return delegated_->device_count(); }); diff --git a/third_party/xla/xla/python/ifrt/mock.h b/third_party/xla/xla/python/ifrt/mock.h index f78d23cf17280e..9ba0583d0015dd 100644 --- a/third_party/xla/xla/python/ifrt/mock.h +++ b/third_party/xla/xla/python/ifrt/mock.h @@ -114,6 +114,8 @@ class MockClient final : public llvm::RTTIExtends { MOCK_METHOD(absl::string_view, runtime_type, (), (const, final)); MOCK_METHOD(absl::string_view, platform_name, (), (const, final)); MOCK_METHOD(absl::string_view, platform_version, (), (const, final)); + MOCK_METHOD((absl::flat_hash_map), + attributes, (), (const, final)); MOCK_METHOD(int, device_count, (), (const, final)); MOCK_METHOD(PlatformId, platform_id, (), (const, final)); MOCK_METHOD(int, addressable_device_count, (), (const, final)); diff --git a/third_party/xla/xla/python/pjrt_ifrt/pjrt_client.h b/third_party/xla/xla/python/pjrt_ifrt/pjrt_client.h index d80af00f4de028..4877e360981662 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/pjrt_client.h +++ b/third_party/xla/xla/python/pjrt_ifrt/pjrt_client.h @@ -19,8 +19,10 @@ limitations under the License. #include #include #include +#include #include +#include "absl/container/flat_hash_map.h" #include "absl/types/span.h" #include "llvm/Support/ExtensibleRTTI.h" #include "xla/pjrt/pjrt_client.h" @@ -109,6 +111,18 @@ class PjRtClient final DCHECK(this); return pjrt_client_->platform_id(); } + absl::flat_hash_map attributes() + const override { + std::optional attributes = + pjrt_client_->plugin_attributes(); + if (!attributes.has_value()) { + return {}; + } + return {{"pjrt_c_api_major_version", + ClientAttribute(attributes->pjrt_c_api_major_version)}, + {"pjrt_c_api_minor_version", + ClientAttribute(attributes->pjrt_c_api_minor_version)}}; + } int device_count() const override { DCHECK(this); diff --git a/third_party/xla/xla/python/py_client.cc b/third_party/xla/xla/python/py_client.cc index 9aa33e908874a4..752edd05cedb3f 100644 --- a/third_party/xla/xla/python/py_client.cc +++ b/third_party/xla/xla/python/py_client.cc @@ -57,7 +57,8 @@ namespace xla { namespace py = pybind11; PyClient::PyClient(std::shared_ptr ifrt_client) - : ifrt_client_(std::move(ifrt_client)) { + : ifrt_client_(std::move(ifrt_client)), + client_attributes_(ifrt_client_->attributes()) { CHECK(ifrt_client_); // TODO(phawkins): this is a temporary backwards compatibility shim. We // changed the name PJRT reports for GPU platforms to "cuda" or "rocm", but diff --git a/third_party/xla/xla/python/py_client.h b/third_party/xla/xla/python/py_client.h index 31397f89a7b60f..68b95c23affb4d 100644 --- a/third_party/xla/xla/python/py_client.h +++ b/third_party/xla/xla/python/py_client.h @@ -16,15 +16,18 @@ limitations under the License. #ifndef XLA_PYTHON_PY_CLIENT_H_ #define XLA_PYTHON_PY_CLIENT_H_ +#include #include #include #include #include #include +#include "absl/container/flat_hash_map.h" #include "pybind11/pybind11.h" // from @pybind11 #include "xla/client/xla_builder.h" #include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_common.h" #include "xla/python/exceptions.h" #include "xla/python/ifrt/client.h" #include "xla/python/pjrt_ifrt/pjrt_client.h" @@ -143,6 +146,14 @@ class PyClient : public std::enable_shared_from_this { absl::string_view runtime_type() const { return ifrt_client_->runtime_type(); } + + // Returns implementation-specific attributes about this client, e.g. the PJRT + // C API version if applicable. + absl::flat_hash_map + attributes() const { + return client_attributes_; + } + int addressable_device_count() const { return ifrt_client_->addressable_device_count(); } @@ -247,7 +258,8 @@ class PyClient : public std::enable_shared_from_this { std::shared_ptr ifrt_client_; std::string platform_name_; - + absl::flat_hash_map + client_attributes_; // Pointers to intrusive doubly-linked lists of arrays and executables, used // to iterate over all known objects when heap profiling. The list structure // is protected by the GIL. diff --git a/third_party/xla/xla/python/py_compile_only_client.cc b/third_party/xla/xla/python/py_compile_only_client.cc index dfb8848118b333..78908290c64b04 100644 --- a/third_party/xla/xla/python/py_compile_only_client.cc +++ b/third_party/xla/xla/python/py_compile_only_client.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" #include "absl/types/span.h" #include "pybind11/stl.h" // from @pybind11 #include "xla/pjrt/mlir_to_hlo.h" @@ -136,6 +137,10 @@ class CompileOnlyIfRtClient final ifrt::PlatformId platform_id() const override { return topology_->platform_id(); } + absl::flat_hash_map attributes() + const override { + return {}; + } int device_count() const override { return devices().size(); } int addressable_device_count() const override { return 0; } diff --git a/third_party/xla/xla/python/xla.cc b/third_party/xla/xla/python/xla.cc index a3fd6063d15e0c..8d09b08eb11491 100644 --- a/third_party/xla/xla/python/xla.cc +++ b/third_party/xla/xla/python/xla.cc @@ -480,7 +480,15 @@ static void Init(py::module_& m) { &PyClient::MakePythonCallbackUsingHostSendAndRecv), py::arg("callable"), py::arg("operand_shapes"), py::arg("result_shapes"), py::arg("send_channel_ids"), - py::arg("recv_channel_ids"), py::arg("serializer") = py::none()); + py::arg("recv_channel_ids"), py::arg("serializer") = py::none()) + .def("__getattr__", [](PyClient& client, std::string name) -> py::object { + const auto& attrs = client.attributes(); + auto it = attrs.find(name); + if (it != attrs.end()) { + return std::visit([](auto&& v) { return py::cast(v); }, it->second); + } + throw py::attribute_error(absl::StrCat("Unknown attribute ", name)); + }); m.def( "get_tfrt_cpu_client", diff --git a/third_party/xla/xla/python/xla_client.py b/third_party/xla/xla/python/xla_client.py index 0119614f974951..e2b201d2f7234f 100644 --- a/third_party/xla/xla/python/xla_client.py +++ b/third_party/xla/xla/python/xla_client.py @@ -48,7 +48,7 @@ # Just an internal arbitrary increasing number to help with backward-compatible # changes. In JAX, reference this via jax._src.lib.xla_extension_version. -_version = 212 +_version = 213 # Version number for MLIR:Python components. mlir_api_version = 54 diff --git a/third_party/xla/xla/python/xla_client_test.py b/third_party/xla/xla/python/xla_client_test.py index fecbe4d2b157c5..5d7bd289098ee0 100644 --- a/third_party/xla/xla/python/xla_client_test.py +++ b/third_party/xla/xla/python/xla_client_test.py @@ -2627,6 +2627,22 @@ def testPlatformVersion(self): self.assertIn("cl/", version) self.assertIn("Built on ", version) + @unittest.skipIf( + not cloud_tpu and not pjrt_c_api, "PJRT version only exist for plugins" + ) + def testPjRtCApiVersion(self): + self.assertGreaterEqual(self.backend.pjrt_c_api_major_version, 0) + self.assertGreaterEqual(self.backend.pjrt_c_api_minor_version, 0) + + @unittest.skipIf( + cloud_tpu or pjrt_c_api, "PJRT version only exist for plugins" + ) + def testNotExistPjRtCApiVersion(self): + with self.assertRaises(AttributeError): + self.backend.pjrt_c_api_major_version # pylint: disable=pointless-statement + with self.assertRaises(AttributeError): + self.backend.pjrt_c_api_minor_version # pylint: disable=pointless-statement + @unittest.skipIf(cloud_tpu or pathways or pathways_ifrt or tfrt_tpu, "not implemented") def testExecutableSerialization(self): diff --git a/third_party/xla/xla/python/xla_extension/__init__.pyi b/third_party/xla/xla/python/xla_extension/__init__.pyi index eb845069991eab..63df53d5ba3710 100644 --- a/third_party/xla/xla/python/xla_extension/__init__.pyi +++ b/third_party/xla/xla/python/xla_extension/__init__.pyi @@ -472,6 +472,7 @@ class Client: self, callable: Callable, operand_shapes: Sequence[Shape], result_shapes: Sequence[Shape], send_channel_ids: Sequence[int], recv_channel_ids: Sequence[int], serializer: Optional[Callable] = ...) -> Any: ... + def __getattr__(self, name: str) -> Any: ... def get_tfrt_cpu_client(asynchronous: bool = ...) -> Client: ... def get_gpu_client(