Skip to content

Commit

Permalink
Rollback of PR tensorflow#18287
Browse files Browse the repository at this point in the history
Revert the introduction of compile option overriding as it causes the JAX tests failure on cloud TPU with Python-3.11.

Reverts 1fec0f5

PiperOrigin-RevId: 692001892
  • Loading branch information
tensorflower-gardener committed Nov 1, 2024
1 parent 1c730cb commit 6ac2f0e
Show file tree
Hide file tree
Showing 6 changed files with 2 additions and 75 deletions.
6 changes: 0 additions & 6 deletions third_party/xla/xla/pjrt/c/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,11 +1,5 @@
# PJRT C API changelog

## 0.56

* Added `overridden_serialized_compile_options` and
`overridden_serialized_compile_options_size` fields to
`PJRT_Executable_DeserializeAndLoad_Args`.

## 0.55
* Added types F8E4M3 and F8E3M4.

Expand Down
7 changes: 1 addition & 6 deletions third_party/xla/xla/pjrt/c/pjrt_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_Extension_Base, next);
// Changes include:
// * Adding a new field to the PJRT_Api or argument structs
// * Renaming a method or argument (doesn't affect ABI)
#define PJRT_API_MINOR 56
#define PJRT_API_MINOR 55

// The plugin should set the major_version and minor_version of
// PJRT_Api.pjrt_api_version to be the `PJRT_API_MAJOR` and `PJRT_API_MINOR` in
Expand Down Expand Up @@ -1577,11 +1577,6 @@ struct PJRT_Executable_DeserializeAndLoad_Args {
const char* serialized_executable;
size_t serialized_executable_size;
PJRT_LoadedExecutable* loaded_executable; // out
// Serialized CompileOptionsProto or null (to use the options
// from the serialized executable).
// (https://github.com/openxla/xla/blob/main/xla/pjrt/compile_options.proto)
const char* overridden_serialized_compile_options;
size_t overridden_serialized_compile_options_size;
};
PJRT_DEFINE_STRUCT_TRAITS(PJRT_Executable_DeserializeAndLoad_Args,
loaded_executable);
Expand Down
13 changes: 1 addition & 12 deletions third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1570,20 +1570,9 @@ PJRT_Error* PJRT_Executable_DeserializeAndLoad(
absl::string_view serialized(args->serialized_executable,
args->serialized_executable_size);

std::optional<xla::CompileOptions> overriden_options;

if (args->overridden_serialized_compile_options &&
args->overridden_serialized_compile_options_size > 0) {
PJRT_ASSIGN_OR_RETURN(
overriden_options,
ParseCompileOptions(absl::string_view(
args->overridden_serialized_compile_options,
args->overridden_serialized_compile_options_size)));
}

PJRT_ASSIGN_OR_RETURN(std::unique_ptr<xla::PjRtLoadedExecutable> executable,
args->client->client->DeserializeExecutable(
serialized, overriden_options));
serialized, /*options=*/std::nullopt));

args->loaded_executable =
new PJRT_LoadedExecutable(std::move(executable), args->client);
Expand Down
11 changes: 0 additions & 11 deletions third_party/xla/xla/pjrt/pjrt_c_api_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -422,17 +422,6 @@ PjRtCApiClient::DeserializeExecutable(absl::string_view serialized,
des_args.client = c_client_.get();
des_args.serialized_executable = serialized.data();
des_args.serialized_executable_size = serialized.length();
des_args.overridden_serialized_compile_options = nullptr;
des_args.overridden_serialized_compile_options_size = 0;

std::string options_str;
if (options) {
TF_ASSIGN_OR_RETURN(const CompileOptionsProto options_proto,
options->ToProto());
options_str = options_proto.SerializeAsString();
des_args.overridden_serialized_compile_options = options_str.c_str();
des_args.overridden_serialized_compile_options_size = options_str.size();
}

const PJRT_Api* api = pjrt_c_api();

Expand Down
37 changes: 0 additions & 37 deletions third_party/xla/xla/pjrt/pjrt_c_api_client_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -199,42 +199,5 @@ TEST(PjRtClientTest, CompileUsesStableHloVersion) {
const_cast<PJRT_Api*>(c_api)->PJRT_Client_Compile = PJRT_Client_Compile_Orig;
}

TEST(PjRtClientTest, DeserializeExecutableWithDifferentDeviceAssignment) {
SetUpCpuPjRtApi();
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<PjRtClient> client,
GetCApiClient("cpu"));
ASSERT_GT(client->addressable_devices().size(), 1);

XlaBuilder builder("Identity");
Shape shape = ShapeUtil::MakeShape(S32, {2, 3});
auto input = Parameter(&builder, 0, shape, "input");
auto computation = builder.Build(input).value();

auto compile_options_for_device = [](int id) -> xla::CompileOptions {
xla::DeviceAssignment device_assignment(1, 1);
device_assignment(0, 0) = id;
xla::CompileOptions options;
options.executable_build_options.set_device_assignment(device_assignment);
return options;
};

// Compile the executable for device 0 and serialize it.
std::unique_ptr<PjRtLoadedExecutable> executable =
client->Compile(computation, compile_options_for_device(0)).value();
TF_ASSERT_OK_AND_ASSIGN(std::string serialized_executable,
executable->SerializeExecutable());

// Deserialize the executable for device 1.
TF_ASSERT_OK_AND_ASSIGN(
auto deserialized_executable,
client->DeserializeExecutable(serialized_executable,
compile_options_for_device(1)));

// Check that the executable's compile options were overridden
// with device id 1.
EXPECT_EQ(
deserialized_executable->addressable_devices()[0]->global_device_id(), 1);
}

} // namespace
} // namespace xla
3 changes: 0 additions & 3 deletions third_party/xla/xla/pjrt/pjrt_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -612,9 +612,6 @@ class PjRtClient {
// Pending completion of b/237720161, `options` is a mandatory argument in
// most implementations of this interface. They _are_ optional for
// implementations related to the PJRT C API.
//
// If `options` are provided, then they override the compile options
// from the serialized executable (`serialized`).
virtual absl::StatusOr<std::unique_ptr<PjRtLoadedExecutable>>
DeserializeExecutable(absl::string_view serialized,
std::optional<CompileOptions> options) {
Expand Down

0 comments on commit 6ac2f0e

Please sign in to comment.