Skip to content

Commit

Permalink
[PJRT:C] Add PJRT_Executable_Fingerprint to support AOT compilation.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 575543476
  • Loading branch information
tensorflower-gardener committed Oct 22, 2023
1 parent bb107e1 commit b27b5d3
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 6 deletions.
5 changes: 5 additions & 0 deletions third_party/xla/xla/pjrt/c/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# PJRT C API changelog

## 0.35 (Oct 20, 2023)

* Added PJRT_Executable_Fingerprint method
* Deprecated PJRT_LoadedExecutable_Fingerprint

## 0.34 (Oct 9, 2023)

* Added PJRT_Structure_Type::PJRT_Structure_Type_Profiler.
Expand Down
31 changes: 26 additions & 5 deletions third_party/xla/xla/pjrt/c/pjrt_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ extern "C" {
// Changes include:
// * Adding a new field to the PJRT_Api or argument structs
// * Renaming a method or argument (doesn't affect ABI)
#define PJRT_API_MINOR 34
#define PJRT_API_MINOR 35

// The plugin should set the major_version and minor_version of
// PJRT_Api.pjrt_api_version to be the `PJRT_API_MAJOR` and `PJRT_API_MINOR` in
Expand Down Expand Up @@ -1315,6 +1315,24 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_Executable_SizeOfGeneratedCodeInBytes_Args,
typedef PJRT_Error* PJRT_Executable_SizeOfGeneratedCodeInBytes(
PJRT_Executable_SizeOfGeneratedCodeInBytes_Args* args);

struct PJRT_Executable_Fingerprint_Args {
size_t struct_size;
void* priv;
PJRT_Executable* executable;
// Has the lifetime of `executable`
const char* executable_fingerprint; // out
size_t executable_fingerprint_size; // out
};
PJRT_DEFINE_STRUCT_TRAITS(PJRT_Executable_Fingerprint_Args,
executable_fingerprint_size);

// A unique fingerprint for `executable`. Two executables that were produced by
// compiling with identical inputs (same program, compile options, compiler
// version, etc.) should have the same fingerprint. May not be implemented by
// all platforms.
typedef PJRT_Error* PJRT_Executable_Fingerprint(
PJRT_Executable_Fingerprint_Args* args);

struct PJRT_Executable_GetCostAnalysis_Args {
size_t struct_size;
void* priv;
Expand Down Expand Up @@ -1434,10 +1452,11 @@ struct PJRT_LoadedExecutable_Fingerprint_Args {
};
PJRT_DEFINE_STRUCT_TRAITS(PJRT_LoadedExecutable_Fingerprint_Args,
executable_fingerprint_size);
// A unique fingerprint for `executable`. Two executables that were produced by
// compiling with identical inputs (same program, compile options, compiler
// version, etc.) should have the same fingerprint. May not be implemented by
// all platforms.
// DEPRECATED. Will be removed in PJRT version 2.0. Please use
// PJRT_Executable_Fingerprint instead. A unique fingerprint for `executable`.
// Two executables that were produced by compiling with identical inputs (same
// program, compile options, compiler version, etc.) should have the same
// fingerprint. May not be implemented by all platforms.
typedef PJRT_Error* PJRT_LoadedExecutable_Fingerprint(
PJRT_LoadedExecutable_Fingerprint_Args* args);

Expand Down Expand Up @@ -2090,6 +2109,8 @@ typedef struct {
_PJRT_API_STRUCT_FIELD(PJRT_Buffer_CopyToMemory);

_PJRT_API_STRUCT_FIELD(PJRT_Client_CreateViewOfDeviceBuffer);

_PJRT_API_STRUCT_FIELD(PJRT_Executable_Fingerprint);
} PJRT_Api;

const size_t PJRT_Api_STRUCT_SIZE =
Expand Down
17 changes: 16 additions & 1 deletion third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,8 @@ PJRT_Error* PJRT_Client_LookupAddressableDevice(
return nullptr;
}

// TODO: b/306669267 - this method is deprecated. When can we return
// unimplemented?
PJRT_Error* PJRT_LoadedExecutable_Fingerprint(
PJRT_LoadedExecutable_Fingerprint_Args* args) {
PJRT_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual(
Expand Down Expand Up @@ -1115,6 +1117,18 @@ PJRT_Error* PJRT_Executable_OptimizedProgram(
}
}

PJRT_Error* PJRT_Executable_Fingerprint(
PJRT_Executable_Fingerprint_Args* args) {
PJRT_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual(
"PJRT_Executable_Fingerprint_Args",
PJRT_Executable_Fingerprint_Args_STRUCT_SIZE, args->struct_size));
PJRT_RETURN_IF_ERROR(args->executable->fingerprint.status());
args->executable_fingerprint = args->executable->fingerprint.value().c_str();
args->executable_fingerprint_size =
args->executable->fingerprint.value().size();
return nullptr;
}

PJRT_Error* PJRT_Executable_GetCostAnalysis(
PJRT_Executable_GetCostAnalysis_Args* args) {
PJRT_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual(
Expand Down Expand Up @@ -2154,7 +2168,8 @@ PJRT_TopologyDescription* CreateWrapperDeviceTopology(

PJRT_Executable::PJRT_Executable(
std::shared_ptr<xla::PjRtExecutable> executable)
: executable(std::move(executable)) {}
: executable(std::move(executable)),
fingerprint(this->executable->FingerprintExecutable()) {}

PJRT_LoadedExecutable::PJRT_LoadedExecutable(
std::shared_ptr<xla::PjRtLoadedExecutable> executable, PJRT_Client* client)
Expand Down
6 changes: 6 additions & 0 deletions third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ struct PJRT_Executable {
// Must be shared_ptr so that we can share with PJRT_LoadedExecutable.
std::shared_ptr<xla::PjRtExecutable> executable;

xla::StatusOr<std::string> fingerprint;

// Used to synchronize concurrent setting of cached values.
mutable absl::Mutex mutex;

Expand Down Expand Up @@ -262,6 +264,7 @@ PJRT_Error* PJRT_LoadedExecutable_AddressableDevices(
PJRT_Error* PJRT_Executable_NumOutputs(PJRT_Executable_NumOutputs_Args* args);
PJRT_Error* PJRT_Executable_SizeOfGeneratedCodeInBytes(
PJRT_Executable_SizeOfGeneratedCodeInBytes_Args* args);
PJRT_Error* PJRT_Executable_Fingerprint(PJRT_Executable_Fingerprint_Args* args);
PJRT_Error* PJRT_Executable_GetCostAnalysis(
PJRT_Executable_GetCostAnalysis_Args* args);
PJRT_Error* PJRT_Executable_OutputElementTypes(
Expand All @@ -286,6 +289,8 @@ PJRT_Error* PJRT_Executable_DeserializeAndLoad(
PJRT_Executable_DeserializeAndLoad_Args* args);
PJRT_Error* PJRT_LoadedExecutable_GetExecutable(
PJRT_LoadedExecutable_GetExecutable_Args* args);
// TODO: b/306669267 - this method is deprecated. When can we return
// unimplemented?
PJRT_Error* PJRT_LoadedExecutable_Fingerprint(
PJRT_LoadedExecutable_Fingerprint_Args* args);

Expand Down Expand Up @@ -563,6 +568,7 @@ constexpr PJRT_Api CreatePjrtApi(
pjrt::PJRT_Buffer_CopyToMemory,
/*PJRT_Client_CreateViewOfDeviceBuffer=*/
pjrt::PJRT_Client_CreateViewOfDeviceBuffer,
/*PJRT_Executable_Fingerprint=*/pjrt::PJRT_Executable_Fingerprint,
};
}

Expand Down
13 changes: 13 additions & 0 deletions third_party/xla/xla/pjrt/pjrt_c_api_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1140,6 +1140,19 @@ StatusOr<std::string> PjRtCApiExecutable::SerializeExecutable() const {
return std::string(ser_args.serialized_bytes, ser_args.serialized_bytes_size);
}

StatusOr<std::string> PjRtCApiExecutable::FingerprintExecutable() const {
PJRT_Executable_Fingerprint_Args args;
args.struct_size = PJRT_Executable_Fingerprint_Args_STRUCT_SIZE;
args.priv = nullptr;
args.executable = c_executable();

RETURN_STATUS_IF_PJRT_ERROR(c_api_->PJRT_Executable_Fingerprint(&args),
c_api_);

return std::string(args.executable_fingerprint,
args.executable_fingerprint_size);
}

// ------------------------ Loaded Executables ---------------------------------

PjRtCApiLoadedExecutable::PjRtCApiLoadedExecutable(
Expand Down
2 changes: 2 additions & 0 deletions third_party/xla/xla/pjrt/pjrt_c_api_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,8 @@ class PjRtCApiExecutable : public PjRtExecutable {

StatusOr<std::string> SerializeExecutable() const override;

StatusOr<std::string> FingerprintExecutable() const override;

private:
const PJRT_Api* c_api_;
std::unique_ptr<PJRT_Executable, ::pjrt::PJRT_ExecutableDeleter> executable_;
Expand Down

0 comments on commit b27b5d3

Please sign in to comment.