Skip to content

Commit

Permalink
[PJRT C API] Plumb PJRT C API version to JAX python client.
Browse files Browse the repository at this point in the history
Added a struct PjRtPluginAttributes to PJRT client. IFRT and JAX python client will use a key-value map for plugin attributes.

PjRtPluginAttributes can be extended to include other plugin attributes (e.g. whether a certain feature is supported).

PiperOrigin-RevId: 580988396
  • Loading branch information
Jieying Luo authored and tensorflower-gardener committed Nov 9, 2023
1 parent 345c804 commit ce05e8d
Show file tree
Hide file tree
Showing 15 changed files with 97 additions and 4 deletions.
5 changes: 5 additions & 0 deletions third_party/xla/xla/pjrt/pjrt_c_api_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,11 @@ class PjRtCApiClient : public PjRtClient {

absl::string_view platform_version() const override;

std::optional<PjRtPluginAttributes> plugin_attributes() const override {
return PjRtPluginAttributes{c_api_->pjrt_api_version.major_version,
c_api_->pjrt_api_version.minor_version};
}

// TODO(b/244756954): Rethink this function altogether
PjRtRuntimeType runtime_type() const override {
return PjRtRuntimeType::kTfrt;
Expand Down
11 changes: 11 additions & 0 deletions third_party/xla/xla/pjrt/pjrt_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,11 @@ class PjRtHostMemoryForDeviceManager {

class PjRtLoadedExecutable;

struct PjRtPluginAttributes {
int64_t pjrt_c_api_major_version;
int64_t pjrt_c_api_minor_version;
};

// Encapsulates the state of Python session with XLA.
//
// It is the responsibility of the client of this API to keep the PjRtClient
Expand Down Expand Up @@ -515,6 +520,12 @@ class PjRtClient {
// (e.g. the CUDA version on GPU or libtpu version on Cloud TPU).
virtual absl::string_view platform_version() const = 0;

// Returns information about the underlying PJRT C API plugin if such a plugin
// is being used, otherwise returns nullopt.
virtual std::optional<PjRtPluginAttributes> plugin_attributes() const {
return std::nullopt;
}

// TODO(b/244756954): Rethink this function altogether
// Returns an enum that identifies the type of runtime being used under this
// client.
Expand Down
3 changes: 3 additions & 0 deletions third_party/xla/xla/python/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,7 @@ cc_library(
"//xla/pjrt:lru_cache",
"//xla/pjrt:mlir_to_hlo",
"//xla/pjrt:pjrt_client",
"//xla/pjrt:pjrt_common",
"//xla/pjrt:pjrt_compiler",
"//xla/pjrt:pjrt_future",
"//xla/pjrt:pjrt_stream_executor_client",
Expand Down Expand Up @@ -1147,6 +1148,7 @@ cc_library(
":xla_compiler",
# placeholder for index annotation deps
"@com_google_absl//absl/base",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/log:initialize",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
Expand All @@ -1163,6 +1165,7 @@ cc_library(
"//xla/pjrt:pjrt_api",
"//xla/pjrt:pjrt_c_api_client",
"//xla/pjrt:pjrt_client",
"//xla/pjrt:pjrt_common",
"//xla/pjrt:pjrt_compiler",
"//xla/pjrt:tfrt_cpu_pjrt_client",
"//xla/pjrt/c:pjrt_c_api_hdrs",
Expand Down
2 changes: 2 additions & 0 deletions third_party/xla/xla/python/ifrt/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,11 @@ cc_library(
"//xla:statusor",
"//xla:util",
"//xla/pjrt:pjrt_client",
"//xla/pjrt:pjrt_common",
"//xla/python/ifrt/ir",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/container:node_hash_set",
"@com_google_absl//absl/functional:function_ref",
Expand Down
10 changes: 10 additions & 0 deletions third_party/xla/xla/python/ifrt/client.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,20 @@ limitations under the License.
#ifndef XLA_PYTHON_IFRT_CLIENT_H_
#define XLA_PYTHON_IFRT_CLIENT_H_

#include <cstdint>
#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <variant>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "llvm/Support/ExtensibleRTTI.h"
#include "xla/pjrt/pjrt_client.h"
#include "xla/pjrt/pjrt_common.h"
#include "xla/python/ifrt/array.h"
#include "xla/python/ifrt/compiler.h"
#include "xla/python/ifrt/tuple.h"
Expand Down Expand Up @@ -118,6 +124,10 @@ class Client : public llvm::RTTIExtends<Client, llvm::RTTIRoot> {
virtual absl::string_view platform_version() const = 0;
virtual PlatformId platform_id() const = 0;

using ClientAttribute = xla::PjRtValueType;
virtual absl::flat_hash_map<std::string, ClientAttribute> attributes()
const = 0;

virtual int device_count() const = 0;
virtual int addressable_device_count() const = 0;
virtual absl::Span<Device* const> devices() const = 0;
Expand Down
3 changes: 3 additions & 0 deletions third_party/xla/xla/python/ifrt/mock.cc
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,9 @@ MockClient::MockClient(std::unique_ptr<xla::ifrt::Client> delegated)
ON_CALL(*this, platform_id).WillByDefault([this]() {
return delegated_->platform_id();
});
ON_CALL(*this, attributes).WillByDefault([this]() {
return delegated_->attributes();
});
ON_CALL(*this, device_count).WillByDefault([this]() {
return delegated_->device_count();
});
Expand Down
2 changes: 2 additions & 0 deletions third_party/xla/xla/python/ifrt/mock.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@ class MockClient final : public llvm::RTTIExtends<MockClient, Client> {
MOCK_METHOD(absl::string_view, runtime_type, (), (const, final));
MOCK_METHOD(absl::string_view, platform_name, (), (const, final));
MOCK_METHOD(absl::string_view, platform_version, (), (const, final));
MOCK_METHOD((absl::flat_hash_map<std::string, Client::ClientAttribute>),
attributes, (), (const, final));
MOCK_METHOD(int, device_count, (), (const, final));
MOCK_METHOD(PlatformId, platform_id, (), (const, final));
MOCK_METHOD(int, addressable_device_count, (), (const, final));
Expand Down
14 changes: 14 additions & 0 deletions third_party/xla/xla/python/pjrt_ifrt/pjrt_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@ limitations under the License.
#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <utility>

#include "absl/container/flat_hash_map.h"
#include "absl/types/span.h"
#include "llvm/Support/ExtensibleRTTI.h"
#include "xla/pjrt/pjrt_client.h"
Expand Down Expand Up @@ -109,6 +111,18 @@ class PjRtClient final
DCHECK(this);
return pjrt_client_->platform_id();
}
absl::flat_hash_map<std::string, ClientAttribute> attributes()
const override {
std::optional<PjRtPluginAttributes> attributes =
pjrt_client_->plugin_attributes();
if (!attributes.has_value()) {
return {};
}
return {{"pjrt_c_api_major_version",
ClientAttribute(attributes->pjrt_c_api_major_version)},
{"pjrt_c_api_minor_version",
ClientAttribute(attributes->pjrt_c_api_minor_version)}};
}

int device_count() const override {
DCHECK(this);
Expand Down
3 changes: 2 additions & 1 deletion third_party/xla/xla/python/py_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ namespace xla {
namespace py = pybind11;

PyClient::PyClient(std::shared_ptr<ifrt::Client> ifrt_client)
: ifrt_client_(std::move(ifrt_client)) {
: ifrt_client_(std::move(ifrt_client)),
client_attributes_(ifrt_client_->attributes()) {
CHECK(ifrt_client_);
// TODO(phawkins): this is a temporary backwards compatibility shim. We
// changed the name PJRT reports for GPU platforms to "cuda" or "rocm", but
Expand Down
14 changes: 13 additions & 1 deletion third_party/xla/xla/python/py_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,18 @@ limitations under the License.
#ifndef XLA_PYTHON_PY_CLIENT_H_
#define XLA_PYTHON_PY_CLIENT_H_

#include <cstdint>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "pybind11/pybind11.h" // from @pybind11
#include "xla/client/xla_builder.h"
#include "xla/pjrt/pjrt_client.h"
#include "xla/pjrt/pjrt_common.h"
#include "xla/python/exceptions.h"
#include "xla/python/ifrt/client.h"
#include "xla/python/pjrt_ifrt/pjrt_client.h"
Expand Down Expand Up @@ -143,6 +146,14 @@ class PyClient : public std::enable_shared_from_this<PyClient> {
absl::string_view runtime_type() const {
return ifrt_client_->runtime_type();
}

// Returns implementation-specific attributes about this client, e.g. the PJRT
// C API version if applicable.
absl::flat_hash_map<std::string, xla::ifrt::Client::ClientAttribute>
attributes() const {
return client_attributes_;
}

int addressable_device_count() const {
return ifrt_client_->addressable_device_count();
}
Expand Down Expand Up @@ -247,7 +258,8 @@ class PyClient : public std::enable_shared_from_this<PyClient> {

std::shared_ptr<ifrt::Client> ifrt_client_;
std::string platform_name_;

absl::flat_hash_map<std::string, xla::ifrt::Client::ClientAttribute>
client_attributes_;
// Pointers to intrusive doubly-linked lists of arrays and executables, used
// to iterate over all known objects when heap profiling. The list structure
// is protected by the GIL.
Expand Down
5 changes: 5 additions & 0 deletions third_party/xla/xla/python/py_compile_only_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ limitations under the License.
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/types/span.h"
#include "pybind11/stl.h" // from @pybind11
#include "xla/pjrt/mlir_to_hlo.h"
Expand Down Expand Up @@ -136,6 +137,10 @@ class CompileOnlyIfRtClient final
ifrt::PlatformId platform_id() const override {
return topology_->platform_id();
}
absl::flat_hash_map<std::string, ClientAttribute> attributes()
const override {
return {};
}

int device_count() const override { return devices().size(); }
int addressable_device_count() const override { return 0; }
Expand Down
10 changes: 9 additions & 1 deletion third_party/xla/xla/python/xla.cc
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,15 @@ static void Init(py::module_& m) {
&PyClient::MakePythonCallbackUsingHostSendAndRecv),
py::arg("callable"), py::arg("operand_shapes"),
py::arg("result_shapes"), py::arg("send_channel_ids"),
py::arg("recv_channel_ids"), py::arg("serializer") = py::none());
py::arg("recv_channel_ids"), py::arg("serializer") = py::none())
.def("__getattr__", [](PyClient& client, std::string name) -> py::object {
const auto& attrs = client.attributes();
auto it = attrs.find(name);
if (it != attrs.end()) {
return std::visit([](auto&& v) { return py::cast(v); }, it->second);
}
throw py::attribute_error(absl::StrCat("Unknown attribute ", name));
});

m.def(
"get_tfrt_cpu_client",
Expand Down
2 changes: 1 addition & 1 deletion third_party/xla/xla/python/xla_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@

# Just an internal arbitrary increasing number to help with backward-compatible
# changes. In JAX, reference this via jax._src.lib.xla_extension_version.
_version = 212
_version = 213

# Version number for MLIR:Python components.
mlir_api_version = 54
Expand Down
16 changes: 16 additions & 0 deletions third_party/xla/xla/python/xla_client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2627,6 +2627,22 @@ def testPlatformVersion(self):
self.assertIn("cl/", version)
self.assertIn("Built on ", version)

@unittest.skipIf(
not cloud_tpu and not pjrt_c_api, "PJRT version only exist for plugins"
)
def testPjRtCApiVersion(self):
self.assertGreaterEqual(self.backend.pjrt_c_api_major_version, 0)
self.assertGreaterEqual(self.backend.pjrt_c_api_minor_version, 0)

@unittest.skipIf(
cloud_tpu or pjrt_c_api, "PJRT version only exist for plugins"
)
def testNotExistPjRtCApiVersion(self):
with self.assertRaises(AttributeError):
self.backend.pjrt_c_api_major_version # pylint: disable=pointless-statement
with self.assertRaises(AttributeError):
self.backend.pjrt_c_api_minor_version # pylint: disable=pointless-statement

@unittest.skipIf(cloud_tpu or pathways or pathways_ifrt or tfrt_tpu,
"not implemented")
def testExecutableSerialization(self):
Expand Down
1 change: 1 addition & 0 deletions third_party/xla/xla/python/xla_extension/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,7 @@ class Client:
self, callable: Callable, operand_shapes: Sequence[Shape],
result_shapes: Sequence[Shape], send_channel_ids: Sequence[int],
recv_channel_ids: Sequence[int], serializer: Optional[Callable] = ...) -> Any: ...
def __getattr__(self, name: str) -> Any: ...

def get_tfrt_cpu_client(asynchronous: bool = ...) -> Client: ...
def get_gpu_client(
Expand Down

0 comments on commit ce05e8d

Please sign in to comment.