From b27b5d3ea312a73b64193d02e852717dbf3c33c0 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sat, 21 Oct 2023 21:40:02 -0700 Subject: [PATCH] [PJRT:C] Add PJRT_Executable_Fingerprint to support AOT compilation. PiperOrigin-RevId: 575543476 --- third_party/xla/xla/pjrt/c/CHANGELOG.md | 5 +++ third_party/xla/xla/pjrt/c/pjrt_c_api.h | 31 ++++++++++++++++--- .../xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc | 17 +++++++++- .../xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.h | 6 ++++ third_party/xla/xla/pjrt/pjrt_c_api_client.cc | 13 ++++++++ third_party/xla/xla/pjrt/pjrt_c_api_client.h | 2 ++ 6 files changed, 68 insertions(+), 6 deletions(-) diff --git a/third_party/xla/xla/pjrt/c/CHANGELOG.md b/third_party/xla/xla/pjrt/c/CHANGELOG.md index db6d03cd136233..908d96082b68e8 100644 --- a/third_party/xla/xla/pjrt/c/CHANGELOG.md +++ b/third_party/xla/xla/pjrt/c/CHANGELOG.md @@ -1,5 +1,10 @@ # PJRT C API changelog +## 0.35 (Oct 20, 2023) + +* Added PJRT_Executable_Fingerprint method +* Deprecated PJRT_LoadedExecutable_Fingerprint + ## 0.34 (Oct 9, 2023) * Added PJRT_Structure_Type::PJRT_Structure_Type_Profiler. diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api.h b/third_party/xla/xla/pjrt/c/pjrt_c_api.h index 20f35608a8058a..8295597ee0e586 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api.h +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api.h @@ -53,7 +53,7 @@ extern "C" { // Changes include: // * Adding a new field to the PJRT_Api or argument structs // * Renaming a method or argument (doesn't affect ABI) -#define PJRT_API_MINOR 34 +#define PJRT_API_MINOR 35 // The plugin should set the major_version and minor_version of // PJRT_Api.pjrt_api_version to be the `PJRT_API_MAJOR` and `PJRT_API_MINOR` in @@ -1315,6 +1315,24 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_Executable_SizeOfGeneratedCodeInBytes_Args, typedef PJRT_Error* PJRT_Executable_SizeOfGeneratedCodeInBytes( PJRT_Executable_SizeOfGeneratedCodeInBytes_Args* args); +struct PJRT_Executable_Fingerprint_Args { + size_t struct_size; + void* priv; + PJRT_Executable* executable; + // Has the lifetime of `executable` + const char* executable_fingerprint; // out + size_t executable_fingerprint_size; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Executable_Fingerprint_Args, + executable_fingerprint_size); + +// A unique fingerprint for `executable`. Two executables that were produced by +// compiling with identical inputs (same program, compile options, compiler +// version, etc.) should have the same fingerprint. May not be implemented by +// all platforms. +typedef PJRT_Error* PJRT_Executable_Fingerprint( + PJRT_Executable_Fingerprint_Args* args); + struct PJRT_Executable_GetCostAnalysis_Args { size_t struct_size; void* priv; @@ -1434,10 +1452,11 @@ struct PJRT_LoadedExecutable_Fingerprint_Args { }; PJRT_DEFINE_STRUCT_TRAITS(PJRT_LoadedExecutable_Fingerprint_Args, executable_fingerprint_size); -// A unique fingerprint for `executable`. Two executables that were produced by -// compiling with identical inputs (same program, compile options, compiler -// version, etc.) should have the same fingerprint. May not be implemented by -// all platforms. +// DEPRECATED. Will be removed in PJRT version 2.0. Please use +// PJRT_Executable_Fingerprint instead. A unique fingerprint for `executable`. +// Two executables that were produced by compiling with identical inputs (same +// program, compile options, compiler version, etc.) should have the same +// fingerprint. May not be implemented by all platforms. typedef PJRT_Error* PJRT_LoadedExecutable_Fingerprint( PJRT_LoadedExecutable_Fingerprint_Args* args); @@ -2090,6 +2109,8 @@ typedef struct { _PJRT_API_STRUCT_FIELD(PJRT_Buffer_CopyToMemory); _PJRT_API_STRUCT_FIELD(PJRT_Client_CreateViewOfDeviceBuffer); + + _PJRT_API_STRUCT_FIELD(PJRT_Executable_Fingerprint); } PJRT_Api; const size_t PJRT_Api_STRUCT_SIZE = diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc b/third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc index 9e7570ea57cf84..7ba96e3efa7fa9 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc @@ -416,6 +416,8 @@ PJRT_Error* PJRT_Client_LookupAddressableDevice( return nullptr; } +// TODO: b/306669267 - this method is deprecated. When can we return +// unimplemented? PJRT_Error* PJRT_LoadedExecutable_Fingerprint( PJRT_LoadedExecutable_Fingerprint_Args* args) { PJRT_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual( @@ -1115,6 +1117,18 @@ PJRT_Error* PJRT_Executable_OptimizedProgram( } } +PJRT_Error* PJRT_Executable_Fingerprint( + PJRT_Executable_Fingerprint_Args* args) { + PJRT_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual( + "PJRT_Executable_Fingerprint_Args", + PJRT_Executable_Fingerprint_Args_STRUCT_SIZE, args->struct_size)); + PJRT_RETURN_IF_ERROR(args->executable->fingerprint.status()); + args->executable_fingerprint = args->executable->fingerprint.value().c_str(); + args->executable_fingerprint_size = + args->executable->fingerprint.value().size(); + return nullptr; +} + PJRT_Error* PJRT_Executable_GetCostAnalysis( PJRT_Executable_GetCostAnalysis_Args* args) { PJRT_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual( @@ -2154,7 +2168,8 @@ PJRT_TopologyDescription* CreateWrapperDeviceTopology( PJRT_Executable::PJRT_Executable( std::shared_ptr executable) - : executable(std::move(executable)) {} + : executable(std::move(executable)), + fingerprint(this->executable->FingerprintExecutable()) {} PJRT_LoadedExecutable::PJRT_LoadedExecutable( std::shared_ptr executable, PJRT_Client* client) diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.h b/third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.h index f26450cc027547..acd0a26ea7ee22 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.h +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.h @@ -93,6 +93,8 @@ struct PJRT_Executable { // Must be shared_ptr so that we can share with PJRT_LoadedExecutable. std::shared_ptr executable; + xla::StatusOr fingerprint; + // Used to synchronize concurrent setting of cached values. mutable absl::Mutex mutex; @@ -262,6 +264,7 @@ PJRT_Error* PJRT_LoadedExecutable_AddressableDevices( PJRT_Error* PJRT_Executable_NumOutputs(PJRT_Executable_NumOutputs_Args* args); PJRT_Error* PJRT_Executable_SizeOfGeneratedCodeInBytes( PJRT_Executable_SizeOfGeneratedCodeInBytes_Args* args); +PJRT_Error* PJRT_Executable_Fingerprint(PJRT_Executable_Fingerprint_Args* args); PJRT_Error* PJRT_Executable_GetCostAnalysis( PJRT_Executable_GetCostAnalysis_Args* args); PJRT_Error* PJRT_Executable_OutputElementTypes( @@ -286,6 +289,8 @@ PJRT_Error* PJRT_Executable_DeserializeAndLoad( PJRT_Executable_DeserializeAndLoad_Args* args); PJRT_Error* PJRT_LoadedExecutable_GetExecutable( PJRT_LoadedExecutable_GetExecutable_Args* args); +// TODO: b/306669267 - this method is deprecated. When can we return +// unimplemented? PJRT_Error* PJRT_LoadedExecutable_Fingerprint( PJRT_LoadedExecutable_Fingerprint_Args* args); @@ -563,6 +568,7 @@ constexpr PJRT_Api CreatePjrtApi( pjrt::PJRT_Buffer_CopyToMemory, /*PJRT_Client_CreateViewOfDeviceBuffer=*/ pjrt::PJRT_Client_CreateViewOfDeviceBuffer, + /*PJRT_Executable_Fingerprint=*/pjrt::PJRT_Executable_Fingerprint, }; } diff --git a/third_party/xla/xla/pjrt/pjrt_c_api_client.cc b/third_party/xla/xla/pjrt/pjrt_c_api_client.cc index a8c82265fd17b0..2e2266db53696a 100644 --- a/third_party/xla/xla/pjrt/pjrt_c_api_client.cc +++ b/third_party/xla/xla/pjrt/pjrt_c_api_client.cc @@ -1140,6 +1140,19 @@ StatusOr PjRtCApiExecutable::SerializeExecutable() const { return std::string(ser_args.serialized_bytes, ser_args.serialized_bytes_size); } +StatusOr PjRtCApiExecutable::FingerprintExecutable() const { + PJRT_Executable_Fingerprint_Args args; + args.struct_size = PJRT_Executable_Fingerprint_Args_STRUCT_SIZE; + args.priv = nullptr; + args.executable = c_executable(); + + RETURN_STATUS_IF_PJRT_ERROR(c_api_->PJRT_Executable_Fingerprint(&args), + c_api_); + + return std::string(args.executable_fingerprint, + args.executable_fingerprint_size); +} + // ------------------------ Loaded Executables --------------------------------- PjRtCApiLoadedExecutable::PjRtCApiLoadedExecutable( diff --git a/third_party/xla/xla/pjrt/pjrt_c_api_client.h b/third_party/xla/xla/pjrt/pjrt_c_api_client.h index 1b1eecf8e188b5..378ec3cf3f5e4f 100644 --- a/third_party/xla/xla/pjrt/pjrt_c_api_client.h +++ b/third_party/xla/xla/pjrt/pjrt_c_api_client.h @@ -518,6 +518,8 @@ class PjRtCApiExecutable : public PjRtExecutable { StatusOr SerializeExecutable() const override; + StatusOr FingerprintExecutable() const override; + private: const PJRT_Api* c_api_; std::unique_ptr executable_;