diff --git a/tensorflow/lite/experimental/lrt/BUILD b/tensorflow/lite/experimental/lrt/BUILD index 7ea120900c12f3..cd9efefb75dcab 100644 --- a/tensorflow/lite/experimental/lrt/BUILD +++ b/tensorflow/lite/experimental/lrt/BUILD @@ -16,19 +16,3 @@ package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = ["//tensorflow/lite/experimental/lrt:__subpackages__"], ) - -cc_binary( - name = "apply_plugin", - srcs = [ - "apply_plugin.cc", - # TODO: b/366821557 - Support pre-compiled plugins as data dependencies. - "//tensorflow/lite/experimental/lrt/vendors/examples:example_plugin_so", - ], - deps = [ - "//tensorflow/lite/experimental/lrt/core:api_internal", - "//tensorflow/lite/experimental/lrt/core:lite_rt_model_init", - "//tensorflow/lite/experimental/lrt/core:model", - "//tensorflow/lite/experimental/lrt/core/compiler_plugin:algo", - "@llvm-project//llvm:Support", - ], -) diff --git a/tensorflow/lite/experimental/lrt/apply_plugin.cc b/tensorflow/lite/experimental/lrt/apply_plugin.cc deleted file mode 100644 index 934ebae5bdf895..00000000000000 --- a/tensorflow/lite/experimental/lrt/apply_plugin.cc +++ /dev/null @@ -1,204 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include -#include -#include -#include -#include -#include - -#include "llvm/Support/CommandLine.h" -#include "tensorflow/lite/experimental/lrt/c/lite_rt_common.h" -#include "tensorflow/lite/experimental/lrt/c/lite_rt_model.h" -#include "tensorflow/lite/experimental/lrt/c/lite_rt_support.h" -#include "tensorflow/lite/experimental/lrt/cc/lite_rt_support.h" -#include "tensorflow/lite/experimental/lrt/core/compiler_plugin/algo.h" -#include "tensorflow/lite/experimental/lrt/core/lite_rt_model_init.h" -#include "tensorflow/lite/experimental/lrt/core/model.h" -#include "tensorflow/lite/experimental/lrt/vendors/c/lite_rt_compiler_plugin.h" - -using ::lrt::internal::GroupPartitions; -using ::lrt::internal::OutlinePartition; - -// NOLINTNEXTLINE -static llvm::cl::opt model_path( - "model_path", llvm::cl::desc("Path to flatbuffer file."), - llvm::cl::init("")); - -// TODO: b/366821557 - Support path to pre-compiled plugin in flags. -// NOLINTNEXTLINE -static llvm::cl::opt soc_manufacturer( - "soc_man", - llvm::cl::desc( - "String identifier of SoC manufacturer (e.g., Pixel, Qualcomm)."), - llvm::cl::init("ExampleSocManufacturer")); - -// NOLINTNEXTLINE -static llvm::cl::opt soc_model("soc_model", - llvm::cl::desc("Target SoC model."), - llvm::cl::init("ExampleSocModel")); - -// TODO swap "dry_run" for optional "don't delete partitioned subgraphs". -// NOLINTNEXTLINE -static llvm::cl::opt dry_run( - "dry_run", - llvm::cl::desc( - "Only run \"partition\" phase and output the spliced out subgraphs."), - llvm::cl::init(true)); - -#define EXIT_IF_NULL(val, msg) \ - if (!val) { \ - std::cerr << msg << "\n"; \ - return 1; \ - } - -bool IsSocModelSupported(LrtCompilerPlugin plugin, - std::string_view requested_soc_model) { - const auto num_supported_configs = LrtPluginNumSupportedSocModels(plugin); - for (int i = 0; i < num_supported_configs; ++i) { - const char* soc_model; - LRT_RETURN_VAL_IF_NOT_OK( - LrtPluginGetSupportedSocModel(plugin, i, &soc_model), false); - if (requested_soc_model == soc_model) { - return true; - } - } - - return false; -} - -// TODO: b/366821557 - Replace loading pre-compiled plugin. -UniqueLrtCompilerPlugin LoadPlugin() { - if (soc_manufacturer != LrtPluginSocManufacturer()) { - std::cerr << "Only Test currently supported"; - return nullptr; - } - - LrtCompilerPlugin plugin; - LRT_RETURN_VAL_IF_NOT_OK(LrtPluginInit(&plugin), nullptr); - auto result = UniqueLrtCompilerPlugin(plugin); - - if (!IsSocModelSupported(result.get(), soc_model)) { - std::cerr << "Only ExampleSocModel currently supported\n"; - return nullptr; - } - - return result; -} - -UniqueLrtModel LoadModel(std::string_view filename) { - LrtModel model; - LRT_RETURN_VAL_IF_NOT_OK(LoadModelFromFile(filename.data(), &model), nullptr); - return UniqueLrtModel(model); -} - -LrtStatus ApplyPlugin(LrtModel model, LrtCompilerPlugin plugin, - std::string_view soc_model) { - LRT_RETURN_STATUS_IF_NOT_OK( - RegisterCustomOpCode(model, LrtPluginSocManufacturer())); - - LrtOpListT selected_ops; - LRT_RETURN_STATUS_IF_NOT_OK( - LrtPluginPartitionModel(plugin, model, &selected_ops)); - - auto partitions = GroupPartitions(selected_ops.ops); - - // TODO: b/366821557 - Support multiple subgraphs in plugin application. - auto& main_subgraph = model->subgraphs.front(); - - std::vector slices; - std::vector custom_ops; - slices.reserve(partitions.size()); - custom_ops.reserve(partitions.size()); - - for (auto& partition : partitions) { - LrtSubgraph new_subgraph = &model->subgraphs.emplace_back(); - - LrtOp custom_op = OutlinePartition(main_subgraph, new_subgraph, partition); - custom_ops.push_back(custom_op); - slices.push_back(new_subgraph); - } - - if (dry_run) { - return kLrtStatusOk; - } - - LrtCompiledResult compiled_result; - LRT_RETURN_STATUS_IF_NOT_OK(LrtPluginCompile(plugin, soc_model.data(), - slices.data(), slices.size(), - &compiled_result)); - - lrt_param_index_t num_calls_compiled; - LRT_RETURN_STATUS_IF_NOT_OK( - LrtCompiledResultGetNumCalls(compiled_result, &num_calls_compiled)); - - if (num_calls_compiled != slices.size()) { - std::cerr - << "Plugin must provide and entry point for each compiled partition\n"; - return kLrtStatusErrorNotFound; - } - - for (int i = 0; i < num_calls_compiled; ++i) { - const void* call_info; - size_t call_info_size; - - LRT_RETURN_STATUS_IF_NOT_OK(LrtCompiledResultGetCallInfo( - compiled_result, i, &call_info, &call_info_size)); - - auto* custom_op = custom_ops.at(i); - custom_op->custom_options.assign(reinterpret_cast(call_info), - call_info_size); - } - - const void* byte_code; - size_t byte_code_size; - - LRT_RETURN_STATUS_IF_NOT_OK(LrtCompiledResultGetByteCode( - compiled_result, &byte_code, &byte_code_size)); - - LRT_RETURN_STATUS_IF_NOT_OK(AppendMetadata(model, byte_code, byte_code_size, - LrtPluginSocManufacturer())); - - return kLrtStatusOk; -} - -int main(int argc, char** argv) { - llvm::cl::ParseCommandLineOptions(argc, argv); - - auto model = LoadModel(model_path); - EXIT_IF_NULL(model, "Failed to load model"); - - auto plugin = LoadPlugin(); - EXIT_IF_NULL(plugin, "Failed to load plugin."); - - LRT_RETURN_VAL_IF_NOT_OK(ApplyPlugin(model.get(), plugin.get(), soc_model), - 1); - - uint8_t* buf; - size_t buf_size; - size_t buf_offset; - - LRT_RETURN_VAL_IF_NOT_OK( - SerializeModel(model.release(), &buf, &buf_size, &buf_offset), 1); - - std::string out(reinterpret_cast(buf) + buf_offset, - buf_size - buf_offset); - std::cout << out; - - delete[] buf; - - return 0; -} diff --git a/tensorflow/lite/experimental/lrt/c/lite_rt_common.h b/tensorflow/lite/experimental/lrt/c/lite_rt_common.h index 21d1cc8c4bd128..469423d1ec98bc 100644 --- a/tensorflow/lite/experimental/lrt/c/lite_rt_common.h +++ b/tensorflow/lite/experimental/lrt/c/lite_rt_common.h @@ -60,17 +60,20 @@ typedef enum { kLrtStatusErrorTimeoutExpired = 7, // File and loading related errors. - kLrtStatusBadFileOp = 500, - kLrtStatusFlatbufferFailedVerify = 501, - kLrtStatusDynamicLoadErr = 502, + kLrtStatusErrorFileIO = 500, + kLrtStatusErrorInvalidFlatbuffer = 501, + kLrtStatusErrorDynamicLoading = 502, + kLrtStatusSerializationErr = 503, + kLrtStatusCompilationError = 504, // IR related errors. - kLrtStatusParamIndexOOB = 1000, - kLrtStatusBadTensorType = 1001, - kLrtStatusGraphInvariantError = 1002, + kLrtStatusErrorIndexOOB = 1000, + kLrtStatusErrorInvalidIrType = 1001, + kLrtStatusErrorInvalidGraphInvariant = 1002, + kLrtStatusErrorGraphModification = 1003, // Tool related errors. - kLrtStatusToolBadConfig = 1500, + kLrtStatusErrorInvalidToolConfig = 1500, } LrtStatus; #ifdef __cplusplus diff --git a/tensorflow/lite/experimental/lrt/cc/BUILD b/tensorflow/lite/experimental/lrt/cc/BUILD index ff789d628d00cf..bc0ac22abf549f 100644 --- a/tensorflow/lite/experimental/lrt/cc/BUILD +++ b/tensorflow/lite/experimental/lrt/cc/BUILD @@ -28,7 +28,6 @@ cc_library( "//tensorflow/compiler/mlir/lite/core:model_builder_base", "//tensorflow/lite/c:c_api_types", "//tensorflow/lite/experimental/lrt/c:lite_rt_c_api", - "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/types:span", ], ) diff --git a/tensorflow/lite/experimental/lrt/cc/lite_rt_support.h b/tensorflow/lite/experimental/lrt/cc/lite_rt_support.h index 7dc5d305f8b51e..5e05a52765dd6b 100644 --- a/tensorflow/lite/experimental/lrt/cc/lite_rt_support.h +++ b/tensorflow/lite/experimental/lrt/cc/lite_rt_support.h @@ -132,6 +132,16 @@ class LrtResult { } \ decl = result.Value(); +#define _MOVE_OR_BLOCK(decl, expr, block, result) \ + auto result = (expr); \ + if (!result.HasValue()) { \ + block; \ + } \ + decl = std::move(result.Value()); + +#define _MOVE_OR_RETURN_VAL(decl, expr, val, result) \ + _MOVE_OR_BLOCK(decl, expr, _RETURN_VAL(val), result) + #define _ASSIGN_OR_RETURN_VAL(decl, expr, val, result) \ _ASSIGN_OR_BLOCK(decl, expr, _RETURN_VAL(val), result) @@ -144,10 +154,17 @@ class LrtResult { #define _ASSIGN_OR_RETURN_STATUS(decl, expr, result) \ _ASSIGN_OR_RETURN_VAL(decl, expr, _STATUS_FROM_RESULT(result), result) +#define _MOVE_OR_RETURN_STATUS(decl, expr, result) \ + _MOVE_OR_RETURN_VAL(decl, expr, _STATUS_FROM_RESULT(result), result) + // Assign value behind result returned from expr. If not ok, return status. #define LRT_ASSIGN_OR_RETURN_STATUS(decl, expr) \ _ASSIGN_OR_RETURN_STATUS(decl, expr, _CONCAT_NAME(_result, __COUNTER__)) +// Assign value behind result returned from expr. If not ok, return status. +#define LRT_MOVE_OR_RETURN_STATUS(decl, expr) \ + _MOVE_OR_RETURN_STATUS(decl, expr, _CONCAT_NAME(_result, __COUNTER__)) + #define _FORWARD_RESULT(result, ty) LrtResult::FromStatus(result.Status()); #define _ASSIGN_OR_RETURN_RESULT(decl, expr, ty, result) \ @@ -157,6 +174,13 @@ class LrtResult { #define LRT_ASSIGN_OR_RETURN_RESULT(decl, expr, ty) \ _ASSIGN_OR_RETURN_RESULT(decl, expr, ty, _CONCAT_NAME(_result, __COUNTER__)) +#define _MOVE_OR_RETURN_RESULT(decl, expr, ty, result) \ + _MOVE_OR_RETURN_VAL(decl, expr, _FORWARD_RESULT(result, ty), result) + +// Move value behind result returned from expr. If not ok, return result. +#define LRT_MOVE_OR_RETURN_RESULT(decl, expr, ty) \ + _MOVE_OR_RETURN_RESULT(decl, expr, ty, _CONCAT_NAME(_result, __COUNTER__)) + #define LRT_ENSURE_SUPPORTED(cond, msg) \ if (!(cond)) { \ std::cerr << __FILE__ << ":" << __LINE__ << " " << msg << "\n"; \ diff --git a/tensorflow/lite/experimental/lrt/core/compiler_plugin/compiler_plugin.cc b/tensorflow/lite/experimental/lrt/core/compiler_plugin/compiler_plugin.cc index 171ef065ec058d..5950a02f41108f 100644 --- a/tensorflow/lite/experimental/lrt/core/compiler_plugin/compiler_plugin.cc +++ b/tensorflow/lite/experimental/lrt/core/compiler_plugin/compiler_plugin.cc @@ -138,14 +138,14 @@ CompilerPlugin::ResultT CompilerPlugin::LoadPlugin( if (OpenLib(lib_path, &plugin.lib_handle_) != kLrtStatusOk) { LITE_RT_LOG(LRT_WARNING, "Failed to load plugin at: %s", lib_path.data()); - return ResultT::FromStatus(kLrtStatusDynamicLoadErr); + return ResultT::FromStatus(kLrtStatusErrorDynamicLoading); } if (ResolvePluginApi(plugin.lib_handle_, plugin.plugin_api_) != kLrtStatusOk) { LITE_RT_LOG(LRT_WARNING, "Failed to resolve plugin api at: %s", lib_path.data()); - return ResultT::FromStatus(kLrtStatusDynamicLoadErr); + return ResultT::FromStatus(kLrtStatusErrorDynamicLoading); } if (plugin.plugin_api_.init(&plugin.plugin_handle_) != kLrtStatusOk) { @@ -155,7 +155,7 @@ CompilerPlugin::ResultT CompilerPlugin::LoadPlugin( LITE_RT_LOG(LRT_WARNING, "Failed to close loaded library at: %s", lib_path.data()); } - return ResultT::FromStatus(kLrtStatusDynamicLoadErr); + return ResultT::FromStatus(kLrtStatusErrorDynamicLoading); } // This should never change throughout the lifetime of the compiler @@ -235,8 +235,7 @@ LrtResult> CompilerPlugin::PartitionModel( LRT_RETURN_RESULT_IF_NOT_OK( plugin_api_.partition_model(plugin_handle_, c_model, &ops), std::vector); - - return LrtResult>::TakeValue(std::move(ops.ops)); + return LrtResult>::TakeValue(ops.Vec()); } LrtStatus CompilerPlugin::Compile(const absl::string_view soc_model, diff --git a/tensorflow/lite/experimental/lrt/core/dynamic_loading.cc b/tensorflow/lite/experimental/lrt/core/dynamic_loading.cc index 643c56c425eb81..150d7039ca13b0 100644 --- a/tensorflow/lite/experimental/lrt/core/dynamic_loading.cc +++ b/tensorflow/lite/experimental/lrt/core/dynamic_loading.cc @@ -46,7 +46,7 @@ LrtStatus OpenLib(absl::string_view so_path, void** lib_handle) { "Failed to load .so at path: %s, with error:\n\t %s\n", so_path, ::dlerror()); - return kLrtStatusDynamicLoadErr; + return kLrtStatusErrorDynamicLoading; } *lib_handle = res; return kLrtStatusOk; @@ -55,7 +55,7 @@ LrtStatus OpenLib(absl::string_view so_path, void** lib_handle) { LrtStatus CloseLib(void* lib_handle) { if (0 != ::dlclose(lib_handle)) { LITE_RT_LOG(LRT_ERROR, "Failed to close .so with error: %s", ::dlerror()); - return kLrtStatusDynamicLoadErr; + return kLrtStatusErrorDynamicLoading; } return kLrtStatusOk; } diff --git a/tensorflow/lite/experimental/lrt/core/dynamic_loading.h b/tensorflow/lite/experimental/lrt/core/dynamic_loading.h index 2a12f3f6470fae..07ce6e887dd2f0 100644 --- a/tensorflow/lite/experimental/lrt/core/dynamic_loading.h +++ b/tensorflow/lite/experimental/lrt/core/dynamic_loading.h @@ -41,7 +41,7 @@ inline static LrtStatus ResolveLibSymbol(void* lib_handle, if (ptr == nullptr) { LITE_RT_LOG(LRT_ERROR, "Faild to resolve symbol: %s, with err: %s\n", sym_name, ::dlerror()); - return kLrtStatusDynamicLoadErr; + return kLrtStatusErrorDynamicLoading; } *sym_handle = ptr; return kLrtStatusOk; diff --git a/tensorflow/lite/experimental/lrt/core/graph_tools.h b/tensorflow/lite/experimental/lrt/core/graph_tools.h index c3c58c1a778af4..59375ef2ccbc93 100644 --- a/tensorflow/lite/experimental/lrt/core/graph_tools.h +++ b/tensorflow/lite/experimental/lrt/core/graph_tools.h @@ -101,7 +101,8 @@ inline LrtResult> GetTensorUses( inline LrtResult GetTensorOnlyUse(LrtTensor tensor) { LRT_ASSIGN_OR_RETURN_RESULT(auto uses, GetTensorUses(tensor), TensorUseInfo); if (uses.size() != 1) { - return LrtResult::FromStatus(kLrtStatusGraphInvariantError); + return LrtResult::FromStatus( + kLrtStatusErrorInvalidGraphInvariant); } return LrtResult::FromValue(uses[0]); } @@ -123,7 +124,8 @@ inline LrtResult> GetOpIns(LrtOp op) { inline LrtResult GetOnlyOpIn(LrtOp op) { LRT_ASSIGN_OR_RETURN_RESULT(auto ins, GetOpIns(op), LrtTensor); if (ins.size() != 1) { - return LrtResult::FromStatus(kLrtStatusGraphInvariantError); + return LrtResult::FromStatus( + kLrtStatusErrorInvalidGraphInvariant); } return LrtResult::FromValue(ins[0]); } @@ -145,7 +147,8 @@ inline LrtResult> GetOpOuts(LrtOp op) { inline LrtResult GetOnlyOpOut(LrtOp op) { LRT_ASSIGN_OR_RETURN_RESULT(auto outs, GetOpOuts(op), LrtTensor); if (outs.size() != 1) { - return LrtResult::FromStatus(kLrtStatusGraphInvariantError); + return LrtResult::FromStatus( + kLrtStatusErrorInvalidGraphInvariant); } return LrtResult::FromValue(outs[0]); } diff --git a/tensorflow/lite/experimental/lrt/core/lite_rt_model_init.cc b/tensorflow/lite/experimental/lrt/core/lite_rt_model_init.cc index 116a263bc70fdc..0f3a3af77babc1 100644 --- a/tensorflow/lite/experimental/lrt/core/lite_rt_model_init.cc +++ b/tensorflow/lite/experimental/lrt/core/lite_rt_model_init.cc @@ -54,7 +54,7 @@ LrtStatus VerifyFlatbuffer(const uint8_t* buf, size_t buf_size) { flatbuffers::Verifier verifier(buf, buf_size, options); if (!tflite::VerifyModelBuffer(verifier)) { _LRT_D_MSG("Failed to verify fb"); - return kLrtStatusFlatbufferFailedVerify; + return kLrtStatusErrorInvalidFlatbuffer; } return kLrtStatusOk; } @@ -332,7 +332,7 @@ LrtStatus LoadModelFromFile(const char* path, LrtModel* model) { std::unique_ptr alloc = tflite::GetAllocationFromFile(path, tflite::DefaultErrorReporter()); if (!alloc->valid()) { - return kLrtStatusBadFileOp; + return kLrtStatusErrorFileIO; } return LoadModel(reinterpret_cast(alloc->base()), diff --git a/tensorflow/lite/experimental/lrt/core/model.cc b/tensorflow/lite/experimental/lrt/core/model.cc index 2a89b339e93630..84bc13b3969fa0 100644 --- a/tensorflow/lite/experimental/lrt/core/model.cc +++ b/tensorflow/lite/experimental/lrt/core/model.cc @@ -33,7 +33,7 @@ LrtStatus GetModelNumSubgraphs(LrtModel model, LrtStatus GetModelSubgraph(LrtModel model, lrt_param_index_t subgraph_index, LrtSubgraph* subgraph) { if (subgraph_index >= model->subgraphs.size()) { - return kLrtStatusParamIndexOOB; + return kLrtStatusErrorIndexOOB; } *subgraph = model->subgraphs.data() + subgraph_index; return kLrtStatusOk; @@ -46,10 +46,14 @@ LrtStatus GetModelMainSubgraph(LrtModel model, return kLrtStatusOk; } -void ModelDestroy(LrtModel model) { delete model; } +void ModelDestroy(LrtModel model) { + if (model != nullptr) { + delete model; + } +} LrtStatus PushOp(LrtOpList op_list, LrtOp op) { - op_list->ops.push_back(op); + op_list->Push(op); return kLrtStatusOk; } @@ -149,7 +153,7 @@ LrtStatus GetTensorTypeId(LrtTensor tensor, LrtTensorTypeId* type_id) { LrtStatus GetUrankedTensorType(LrtTensor tensor, LrtUnrankedTensorType* unranked_tensor_type) { if (tensor->type_id != kLrtUnrankedTensorType) { - return kLrtStatusBadTensorType; + return kLrtStatusErrorInvalidIrType; } *unranked_tensor_type = tensor->type_detail.unranked_tensor_type; return kLrtStatusOk; @@ -158,7 +162,7 @@ LrtStatus GetUrankedTensorType(LrtTensor tensor, LrtStatus GetRankedTensorType(LrtTensor tensor, LrtRankedTensorType* ranked_tensor_type) { if (tensor->type_id != kLrtRankedTensorType) { - return kLrtStatusBadTensorType; + return kLrtStatusErrorInvalidIrType; } *ranked_tensor_type = tensor->type_detail.ranked_tensor_type; return kLrtStatusOk; diff --git a/tensorflow/lite/experimental/lrt/core/model.h b/tensorflow/lite/experimental/lrt/core/model.h index 6cf61810798346..b54bf04a705890 100644 --- a/tensorflow/lite/experimental/lrt/core/model.h +++ b/tensorflow/lite/experimental/lrt/core/model.h @@ -136,8 +136,23 @@ struct LrtModelT { // // Used for communicating selections of ops. -struct LrtOpListT { - std::vector ops; +class LrtOpListT { + public: + void Push(LrtOp op) { ops_.push_back(op); } + + std::vector Vec() const { + std::vector res; + res.reserve(ops_.size()); + res.assign(ops_.begin(), ops_.end()); + return res; + } + + private: + // NOTE: This was originally a vector. Was encountering really odd + // segfaults when freeing after code on another side of a compilation boundary + // was doing pushes that resized. A list+copy to vector is not optimimal, + // revisit if bottleneck. + std::list ops_; }; #endif // TENSORFLOW_LITE_EXPERIMENTAL_LRT_CORE_MODEL_H_ diff --git a/tensorflow/lite/experimental/lrt/core/model_test.cc b/tensorflow/lite/experimental/lrt/core/model_test.cc index 0c085c5d9d88d4..73452a6b729779 100644 --- a/tensorflow/lite/experimental/lrt/core/model_test.cc +++ b/tensorflow/lite/experimental/lrt/core/model_test.cc @@ -23,7 +23,6 @@ #include // IWYU pragma: keep #include -#include "flatbuffers/verifier.h" // from @flatbuffers #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" #include "tensorflow/lite/experimental/lrt/c/lite_rt_common.h" @@ -37,11 +36,7 @@ namespace { -inline bool VerifyFlatbuffer(const uint8_t* buf, size_t buf_size) { - flatbuffers::Verifier::Options options; - flatbuffers::Verifier verifier(buf, buf_size, options); - return tflite::VerifyModelBuffer(verifier); -} +using ::lrt::testing::VerifyFlatbuffer; inline UniqueLrtModel LoadModelThroughRoundTrip(std::string_view path) { auto model = lrt::testing::LoadTestFileModel(path); @@ -83,7 +78,7 @@ class TopologyTest : public ::testing::TestWithParam { TEST(LrtModelTest, TestLoadTestDataBadFilepath) { LrtModel model = nullptr; ASSERT_STATUS_HAS_CODE(LoadModelFromFile("bad_path", &model), - kLrtStatusBadFileOp); + kLrtStatusErrorFileIO); } TEST(LrtModelTest, TestLoadTestDataBadFileData) { @@ -103,7 +98,7 @@ TEST(LrtModelTest, TestLoadTestDataBadFileData) { LrtModel model = nullptr; ASSERT_STATUS_HAS_CODE(LoadModelFromFile(test_file_path.c_str(), &model), - kLrtStatusFlatbufferFailedVerify); + kLrtStatusErrorInvalidFlatbuffer); // NOLINTEND } diff --git a/tensorflow/lite/experimental/lrt/test/BUILD b/tensorflow/lite/experimental/lrt/test/BUILD index 5fded2cfa5bbf5..76d18768127aa9 100644 --- a/tensorflow/lite/experimental/lrt/test/BUILD +++ b/tensorflow/lite/experimental/lrt/test/BUILD @@ -52,10 +52,12 @@ cc_library( "//tensorflow/lite/experimental/lrt/c:lite_rt_c_api", "//tensorflow/lite/experimental/lrt/cc:lite_rt_cc_api", "//tensorflow/lite/experimental/lrt/core:lite_rt_model_init", + "//tensorflow/lite/schema:schema_fbs", "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", + "@flatbuffers//:runtime_cc", "@local_tsl//tsl/platform", ], ) diff --git a/tensorflow/lite/experimental/lrt/test/common.cc b/tensorflow/lite/experimental/lrt/test/common.cc index 66590ed122878c..8e3c258b66f3d0 100644 --- a/tensorflow/lite/experimental/lrt/test/common.cc +++ b/tensorflow/lite/experimental/lrt/test/common.cc @@ -14,6 +14,8 @@ #include "tensorflow/lite/experimental/lrt/test/common.h" +#include +#include // NOLINTNEXTLINE #include #include @@ -24,9 +26,11 @@ #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "flatbuffers/verifier.h" // from @flatbuffers #include "tensorflow/lite/experimental/lrt/c/lite_rt_model.h" #include "tensorflow/lite/experimental/lrt/cc/lite_rt_support.h" #include "tensorflow/lite/experimental/lrt/core/lite_rt_model_init.h" +#include "tensorflow/lite/schema/schema_generated.h" #include "tsl/platform/platform.h" namespace lrt { @@ -79,5 +83,11 @@ void TouchTestFile(absl::string_view filename, absl::string_view dir) { std::ofstream f(path); } +bool VerifyFlatbuffer(const uint8_t* buf, size_t buf_size) { + flatbuffers::Verifier::Options options; + flatbuffers::Verifier verifier(buf, buf_size, options); + return tflite::VerifyModelBuffer(verifier); +} + } // namespace testing } // namespace lrt diff --git a/tensorflow/lite/experimental/lrt/test/common.h b/tensorflow/lite/experimental/lrt/test/common.h index f2b20bfbdc04dd..d342b8a111655c 100644 --- a/tensorflow/lite/experimental/lrt/test/common.h +++ b/tensorflow/lite/experimental/lrt/test/common.h @@ -58,6 +58,8 @@ UniqueLrtModel LoadTestFileModel(absl::string_view filename); void TouchTestFile(absl::string_view filename, absl::string_view dir); +bool VerifyFlatbuffer(const uint8_t* buf, size_t buf_size); + } // namespace testing } // namespace lrt diff --git a/tensorflow/lite/experimental/lrt/tools/BUILD b/tensorflow/lite/experimental/lrt/tools/BUILD index 83463fbc8eae7e..70625215ff676b 100644 --- a/tensorflow/lite/experimental/lrt/tools/BUILD +++ b/tensorflow/lite/experimental/lrt/tools/BUILD @@ -19,11 +19,19 @@ package( cc_library( name = "apply_plugin", + testonly = 1, srcs = ["apply_plugin.cc"], hdrs = ["apply_plugin.h"], deps = [ + ":dump", + ":tool_display", "//tensorflow/lite/experimental/lrt/c:lite_rt_c_api", "//tensorflow/lite/experimental/lrt/cc:lite_rt_cc_api", + "//tensorflow/lite/experimental/lrt/core:lite_rt_model_init", + "//tensorflow/lite/experimental/lrt/core/compiler_plugin", + "//tensorflow/lite/experimental/lrt/core/compiler_plugin:algo", + "//tensorflow/lite/experimental/lrt/test:common", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", @@ -33,10 +41,20 @@ cc_library( cc_test( name = "apply_plugin_test", srcs = ["apply_plugin_test.cc"], - data = ["//tensorflow/lite/experimental/lrt/vendors/examples:example_plugin_so"], + data = [ + "//tensorflow/lite/experimental/lrt/test:tflite_test_data", + "//tensorflow/lite/experimental/lrt/vendors/examples:example_plugin_so", + ], + tags = [ + "noasan", + "nomsan", + "nosan", + ], deps = [ ":apply_plugin", "//tensorflow/lite/experimental/lrt/c:lite_rt_c_api", + "//tensorflow/lite/experimental/lrt/core:lite_rt_model_init", + "//tensorflow/lite/experimental/lrt/core:model", "//tensorflow/lite/experimental/lrt/test:common", "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/strings:string_view", @@ -44,6 +62,46 @@ cc_test( ], ) +cc_binary( + name = "apply_plugin_main", + testonly = 1, + srcs = ["apply_plugin_main.cc"], + data = ["//tensorflow/lite/experimental/lrt/vendors/examples:example_plugin_so"], + linkstatic = 1, + tags = [ + "noasan", + "nomsan", + "nosan", + ], + deps = [ + ":apply_plugin", + "@com_google_absl//absl/strings:string_view", + "@llvm-project//llvm:Support", + ], +) + +cc_library( + name = "tool_display", + srcs = ["tool_display.cc"], + hdrs = ["tool_display.h"], + deps = [ + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + ], +) + +cc_test( + name = "tool_display_test", + srcs = ["tool_display_test.cc"], + data = ["//tensorflow/lite/experimental/lrt/test:tflite_test_data"], + deps = [ + ":tool_display", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + ], +) + cc_library( name = "dump", srcs = ["dump.cc"], @@ -65,7 +123,6 @@ cc_test( ":dump", "//tensorflow/lite/experimental/lrt/core:model", "//tensorflow/lite/experimental/lrt/test:common", - "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest_main", ], diff --git a/tensorflow/lite/experimental/lrt/tools/apply_plugin.cc b/tensorflow/lite/experimental/lrt/tools/apply_plugin.cc index 21edb2733e8d68..5475868be323ca 100644 --- a/tensorflow/lite/experimental/lrt/tools/apply_plugin.cc +++ b/tensorflow/lite/experimental/lrt/tools/apply_plugin.cc @@ -14,34 +14,62 @@ #include "tensorflow/lite/experimental/lrt/tools/apply_plugin.h" +#include +#include #include #include +#include +#include #include +#include +#include "absl/log/absl_check.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "tensorflow/lite/experimental/lrt/c/lite_rt_common.h" +#include "tensorflow/lite/experimental/lrt/c/lite_rt_model.h" #include "tensorflow/lite/experimental/lrt/c/lite_rt_support.h" #include "tensorflow/lite/experimental/lrt/cc/lite_rt_support.h" +#include "tensorflow/lite/experimental/lrt/core/compiler_plugin/algo.h" +#include "tensorflow/lite/experimental/lrt/core/compiler_plugin/compiler_plugin.h" +#include "tensorflow/lite/experimental/lrt/core/lite_rt_model_init.h" +#include "tensorflow/lite/experimental/lrt/test/common.h" +#include "tensorflow/lite/experimental/lrt/tools/dump.h" +#include "tensorflow/lite/experimental/lrt/tools/tool_display.h" namespace lrt::tools { -#define _ENSURE_CONFIG(expr) \ - if (!(expr)) { \ - return kLrtStatusToolBadConfig; \ - } - +using ::lrt::internal::CompilerPlugin; +using ::lrt::internal::Dump; +using ::lrt::internal::GroupPartitions; +using ::lrt::internal::OutlinePartition; +using ::lrt::testing::VerifyFlatbuffer; using ::lrt::tools::ApplyPluginRun; +#define _ENSURE_CONFIG(expr) \ + if (!(expr)) { \ + return kLrtStatusErrorInvalidToolConfig; \ + } + namespace { -class ApplyPluginContext { +static constexpr absl::string_view kArt = R"( + __ _ __ ____ __ + / / (_/ /____ / __ \/ /_ + / / / / __/ _ \/ /_/ / __/ + / /___/ / /_/ __/ _, _/ /_ +/_____/_/\__/\___/_/ |_|\__/ +)"; + +class Context { public: - using Ptr = std::unique_ptr; - using ResultT = LrtResult; + using Ptr = std::unique_ptr; + using ResultT = LrtResult; - explicit ApplyPluginContext(ApplyPluginRun::Ptr run) : run_(std::move(run)) {} + explicit Context(ApplyPluginRun::Ptr run) + : run_(std::move(run)), + display_(ToolDisplay(run_->dump_out, Context::CmdStr(run_->cmd))) {} ApplyPluginRun::Cmd Cmd() const { return run_->cmd; } @@ -50,41 +78,231 @@ class ApplyPluginContext { run_->lib_search_paths.size()); } - std::ostream& Dump() { - if (!run_->dump_out.has_value()) { - return null_stream_; - } - return *run_->dump_out; + absl::string_view SocModelTarget() const { + ABSL_CHECK_EQ(run_->soc_models.size(), 1); + return run_->soc_models.front(); + } + + std::ostream& Out() { + ABSL_CHECK_EQ(run_->outs.size(), 1); + return run_->outs.front(); + } + + ApplyPluginRun::OutStreamT SwapOut(ApplyPluginRun::OutStreamT out) { + ABSL_CHECK_EQ(run_->outs.size(), 1); + auto res = run_->outs.front(); + run_->outs.at(0) = out; + return res; } const ApplyPluginRun& Run() const { return *run_; } + ApplyPluginRun& Run() { return *run_; } + + ToolDisplay& Dump() { return display_; } void DumpPrelude(); + static absl::string_view CmdStr(ApplyPluginRun::Cmd cmd); + private: ApplyPluginRun::Ptr run_; - std::ostream null_stream_ = std::ostream(nullptr); + ToolDisplay display_; }; -void ApplyPluginContext::DumpPrelude() { - static constexpr absl::string_view kCmdTpl = "ApplyPlugin: %s\n"; - switch (Run().cmd) { +absl::string_view Context::CmdStr(ApplyPluginRun::Cmd cmd) { + switch (cmd) { case ApplyPluginRun::Cmd::INFO: - Dump() << absl::StreamFormat(kCmdTpl, "INFO"); - break; + return "INFO"; case ApplyPluginRun::Cmd::NOOP: - Dump() << absl::StreamFormat(kCmdTpl, "NOOP"); - break; + return "NOOP"; case ApplyPluginRun::Cmd::PARTITION: - Dump() << absl::StreamFormat(kCmdTpl, "PARTITION"); - break; + return "PARTITION"; case ApplyPluginRun::Cmd::COMPILE: - Dump() << absl::StreamFormat(kCmdTpl, "COMPILE"); - break; + return "COMPILE"; case ApplyPluginRun::Cmd::APPLY: - Dump() << absl::StreamFormat(kCmdTpl, "APPLY"); - break; + return "APPLY"; + } +} + +void Context::DumpPrelude() { + Dump().Display() << kArt << "\n"; + // TODO pretty print run struct. +} + +CompilerPlugin::ResultVecT LoadAllPlugins(Context* ctx) { + ctx->Dump().Start("Load Plugins"); + ctx->Dump().Labeled() << "Loading plugins from: "; + const auto paths = ctx->LibSearchPaths(); + for (auto it = paths.begin(); it < paths.end(); ++it) { + ctx->Dump().Display() << *it; + if (it < paths.end() - 1) { + ctx->Dump().Display() << ", "; + } + } + ctx->Dump().Display() << "\n"; + + auto plugins = CompilerPlugin::LoadPlugins(ctx->LibSearchPaths()); + if (!plugins.HasValue()) { + ctx->Dump().Fail(); + return plugins; + } + ctx->Dump().Labeled() << "Found plugins\n"; + ctx->Dump().Labeled() << absl::StreamFormat("Loaded %lu plugins\n", + plugins.Value().size()); + + ctx->Dump().Done(); + return plugins; +} + +CompilerPlugin::ResultT LoadPlugin(Context* ctx) { + LRT_MOVE_OR_RETURN_RESULT(auto plugins, LoadAllPlugins(ctx), CompilerPlugin); + ctx->Dump().Start("Select Plugin"); + + for (auto& plugin : plugins) { + if (plugin.SocManufacturer() == ctx->Run().soc_manufacturer) { + ctx->Dump().Done(); + return CompilerPlugin::ResultT::TakeValue(std::move(plugin)); + } + } + + ctx->Dump().Fail(); + return CompilerPlugin::ResultT::FromStatus(kLrtStatusErrorNotFound); +} + +LrtResult LoadModel(Context* ctx) { + ctx->Dump().Start("Load Model"); + ctx->Dump().Labeled() << absl::StreamFormat("Loading model from: %s\n", + ctx->Run().model.value()); + + LrtModel model; + if (LoadModelFromFile(ctx->Run().model->data(), &model) != kLrtStatusOk) { + ctx->Dump().Fail(); + return LrtResult::FromStatus(kLrtStatusErrorFileIO); + } + + ctx->Dump().Labeled(); + Dump(*model, ctx->Dump().Display()); + + ctx->Dump().Done(); + return LrtResult::TakeValue(UniqueLrtModel(model)); +} + +LrtStatus SerializeModel(Context* ctx, UniqueLrtModel model) { + ctx->Dump().Start("Serialize Model"); + + uint8_t* buf; + size_t size; + size_t offset; + if (SerializeModel(model.release(), &buf, &size, &offset) != kLrtStatusOk) { + delete[] buf; + ctx->Dump().Fail(); + return kLrtStatusSerializationErr; + } + + auto out_buf = buf + offset; + const size_t out_size = size - offset; + if (!VerifyFlatbuffer(out_buf, out_size)) { + ctx->Dump().Labeled() << "Failed to verify flatbuffer\n"; + ctx->Dump().Fail(); + delete[] buf; + return kLrtStatusErrorInvalidFlatbuffer; + } + + ctx->Out().write(reinterpret_cast(out_buf), out_size); + ctx->Dump().Labeled() << absl::StreamFormat( + "Serialized a model of size: %lu\n", out_size); + + delete[] buf; + + ctx->Dump().Done(); + return kLrtStatusOk; +} + +std::vector ApplyPartition(Context* ctx, LrtModelT& model, + CompilerPlugin& plugin) { + ctx->Dump().Start("Partition Model"); + LRT_RETURN_VAL_IF_NOT_OK( + RegisterCustomOpCode(&model, ctx->Run().soc_manufacturer->data()), {}); + + ctx->Dump().Labeled() << "Input model: \n"; + for (auto it = model.subgraphs.begin(); it < model.subgraphs.end(); ++it) { + ctx->Dump().Labeled(); + ctx->Dump().Indented() << "(input graph) "; + Dump(*it, ctx->Dump().Display()); + } + + auto partiion = plugin.PartitionModel(model); + if (!partiion.HasValue()) { + return {}; + } + auto grouped_partitions = GroupPartitions(partiion.Value()); + if (grouped_partitions.empty()) { + return {}; + } + ctx->Dump().Labeled() << absl::StreamFormat( + "Plugin selected %lu ops, yielding %lu partitions\n", + partiion.Value().size(), grouped_partitions.size()); + + std::vector res; + for (auto& partition : grouped_partitions) { + LrtOp custom_op = OutlinePartition( + model.subgraphs.front(), &model.subgraphs.emplace_back(), partition); + res.push_back(custom_op); + } + + ctx->Dump().Labeled() << "Partitioned model: \n"; + ctx->Dump().Labeled(); + ctx->Dump().Indented() << "(initial graph) "; + Dump(model.subgraphs.front(), ctx->Dump().Display()); + for (auto it = model.subgraphs.begin() + 1; it < model.subgraphs.end(); + ++it) { + ctx->Dump().Labeled(); + ctx->Dump().Indented() << "(new graph) "; + Dump(*it, ctx->Dump().Display()); } + + ctx->Dump().Done(); + return res; +} + +LrtResult PartitionModel(Context* ctx, UniqueLrtModel model, + CompilerPlugin& plugin) { + auto custom_ops = ApplyPartition(ctx, *model, plugin); + if (custom_ops.empty()) { + return LrtResult::FromStatus( + kLrtStatusErrorGraphModification); + } + return LrtResult::TakeValue(std::move(model)); +} + +LrtResult> CompilePartitions( + Context* ctx, std::vector& partitions, + CompilerPlugin& plugin) { + ctx->Dump().Start("Compile Model"); + ctx->Dump().Labeled() << absl::StreamFormat( + "Requesting compilation for target \"%s\" on %lu subgraphs\n", + ctx->SocModelTarget(), partitions.size()); + + std::vector call_info_out; + if (plugin.Compile(ctx->SocModelTarget(), partitions, ctx->Out(), + call_info_out) != kLrtStatusOk) { + ctx->Dump().Fail(); + return LrtResult>::FromStatus( + kLrtStatusCompilationError); + } + + ctx->Dump().Labeled() << "Entry point info: "; + for (auto it = call_info_out.begin(); it < call_info_out.end(); ++it) { + ctx->Dump().Display() << absl::StreamFormat("\"%s\"", *it); + if (it < call_info_out.end() - 1) { + ctx->Dump().Display() << ", "; + } + } + ctx->Dump().Display() << "\n"; + + ctx->Dump().Done(); + return LrtResult>::TakeValue( + std::move(call_info_out)); } // @@ -93,13 +311,24 @@ void ApplyPluginContext::DumpPrelude() { LrtStatus ValidateInfoRun(const ApplyPluginRun& run) { _ENSURE_CONFIG(!run.lib_search_paths.empty()); - _ENSURE_CONFIG(run.dump_out.has_value()); + _ENSURE_CONFIG(run.outs.size() == 1); return kLrtStatusOk; } -LrtStatus Info(ApplyPluginContext* context) { - // TODO - return kLrtStatusErrorUnsupported; +LrtStatus Info(Context* ctx) { + LRT_MOVE_OR_RETURN_STATUS(auto plugins, LoadAllPlugins(ctx)); + for (auto& plugin : plugins) { + ctx->Out() << absl::StreamFormat("< LrtCompilerPlugin > \"%s\" | ", + plugin.SocManufacturer()); + const auto& models = plugin.SocModels(); + for (auto it = models.begin(); it < models.end(); ++it) { + ctx->Out() << absl::StreamFormat("\"%s\"", *it); + if (it < models.end() - 1) { + ctx->Out() << ", "; + } + } + } + return kLrtStatusOk; } // @@ -112,9 +341,10 @@ LrtStatus ValidateNoopRun(const ApplyPluginRun& run) { return kLrtStatusOk; } -LrtStatus Noop(ApplyPluginContext* context) { - // TODO - return kLrtStatusErrorUnsupported; +LrtStatus Noop(Context* ctx) { + LRT_MOVE_OR_RETURN_STATUS(auto model, LoadModel(ctx)); + LRT_RETURN_STATUS_IF_NOT_OK(SerializeModel(ctx, std::move(model))); + return kLrtStatusOk; } // @@ -129,9 +359,14 @@ LrtStatus ValidatePartitionRun(const ApplyPluginRun& run) { return kLrtStatusOk; } -LrtStatus Partition(ApplyPluginContext* context) { - // TODO - return kLrtStatusErrorUnsupported; +LrtStatus Partition(Context* ctx) { + LRT_MOVE_OR_RETURN_STATUS(auto plugin, LoadPlugin(ctx)); + LRT_MOVE_OR_RETURN_STATUS(auto model, LoadModel(ctx)); + + LRT_MOVE_OR_RETURN_STATUS(auto new_model, + PartitionModel(ctx, std::move(model), plugin)); + LRT_RETURN_STATUS_IF_NOT_OK(SerializeModel(ctx, std::move(new_model))); + return kLrtStatusOk; } // @@ -153,9 +388,19 @@ LrtStatus ValidateCompileRun(const ApplyPluginRun& run) { return kLrtStatusOk; } -LrtStatus Compile(ApplyPluginContext* context) { - // TODO - return kLrtStatusErrorUnsupported; +LrtStatus Compile(Context* ctx) { + LRT_MOVE_OR_RETURN_STATUS(auto model, LoadModel(ctx)); + LRT_MOVE_OR_RETURN_STATUS(auto plugin, LoadPlugin(ctx)); + + std::vector compilation_input; + compilation_input.reserve(model->subgraphs.size()); + for (auto& subgraph : model->subgraphs) { + compilation_input.push_back(&subgraph); + } + LRT_MOVE_OR_RETURN_STATUS(auto entry_point_info, + CompilePartitions(ctx, compilation_input, plugin)); + + return kLrtStatusOk; } // @@ -177,15 +422,55 @@ LrtStatus ValidateApplyRun(const ApplyPluginRun& run) { return kLrtStatusOk; } -LrtStatus Apply(ApplyPluginContext* context) { - // TODO - return kLrtStatusErrorUnsupported; +LrtStatus Apply(Context* ctx) { + LRT_MOVE_OR_RETURN_STATUS(auto model, LoadModel(ctx)); + LRT_MOVE_OR_RETURN_STATUS(auto plugin, LoadPlugin(ctx)); + ctx->Dump().Labeled() << "Loaded assets\n"; + static constexpr size_t kNumInputSubgraphs = 1; + LRT_ENSURE_SUPPORTED(model->subgraphs.size() == kNumInputSubgraphs, + "Only single subgraph models currently supported."); + + auto custom_ops = ApplyPartition(ctx, *model, plugin); + LRT_ENSURE(!custom_ops.empty(), kLrtStatusErrorGraphModification, + "Failed to partiion graph."); + + std::vector compilation_input; + for (auto it = model->subgraphs.begin() + kNumInputSubgraphs; + it < model->subgraphs.end(); ++it) { + compilation_input.push_back(&*it); + } + + std::stringstream compilation_out; + ApplyPluginRun::OutStreamT out = ctx->SwapOut(compilation_out); + LRT_MOVE_OR_RETURN_STATUS(auto call_info, + CompilePartitions(ctx, compilation_input, plugin)); + LRT_ENSURE(call_info.size() == custom_ops.size(), kLrtStatusCompilationError, + "Failed to verify entry point information."); + + auto call_it = call_info.begin(); + auto custom_op_it = custom_ops.begin(); + for (; call_it < call_info.end() && custom_op_it < custom_ops.end();) { + (*custom_op_it)->custom_options.swap(*call_it); + ++call_it; + ++custom_op_it; + } + + model->subgraphs.resize(kNumInputSubgraphs); + + LRT_RETURN_STATUS_IF_NOT_OK(AppendMetadata( + model.get(), compilation_out.str().data(), compilation_out.str().size(), + plugin.SocManufacturer().data())); + + ctx->SwapOut(out); + LRT_RETURN_STATUS_IF_NOT_OK(SerializeModel(ctx, std::move(model))); + + return kLrtStatusOk; } } // namespace LrtStatus ApplyPlugin(ApplyPluginRun::Ptr run) { - ApplyPluginContext context(std::move(run)); + Context context(std::move(run)); context.DumpPrelude(); switch (context.Cmd()) { diff --git a/tensorflow/lite/experimental/lrt/tools/apply_plugin.h b/tensorflow/lite/experimental/lrt/tools/apply_plugin.h index e354564c451198..8405019a6f3ada 100644 --- a/tensorflow/lite/experimental/lrt/tools/apply_plugin.h +++ b/tensorflow/lite/experimental/lrt/tools/apply_plugin.h @@ -39,15 +39,14 @@ struct ApplyPluginRun { // A specific command implemented by the tool to run. enum class Cmd { // Displays info about all plugins found in given search paths. - // Writes all output to the "dump_out" stream. // // FLAG SEMANTICS: // "lib_search_paths": Required, at least one. // "model": Ignored. // "soc_manufacturer": Optional, filters plugins to display. // "soc_models": Ignored. - // "outs": Ignored. - // "dump_out": Required. + // "outs": Required, must be size one. + // "dump_out": Optional. // "serialization": Ignored. INFO, diff --git a/tensorflow/lite/experimental/lrt/tools/apply_plugin_main.cc b/tensorflow/lite/experimental/lrt/tools/apply_plugin_main.cc new file mode 100644 index 00000000000000..dc52b4a74431b6 --- /dev/null +++ b/tensorflow/lite/experimental/lrt/tools/apply_plugin_main.cc @@ -0,0 +1,127 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/Support/CommandLine.h" +#include "tensorflow/lite/experimental/lrt/tools/apply_plugin.h" + +using ::lrt::tools::ApplyPlugin; +using ::lrt::tools::ApplyPluginRun; + +// NOLINTNEXTLINE +static llvm::cl::opt cmd( + llvm::cl::Positional, + llvm::cl::desc("Routine to run (apply, partition, compile, info, noop)."), + llvm::cl::init("partition")); + +// NOLINTNEXTLINE +static llvm::cl::opt model( + "model", llvm::cl::desc("Path to flatbuffer file."), llvm::cl::init("")); + +// TODO: b/366821557 - Support path to pre-compiled plugin in flags. +// NOLINTNEXTLINE +static llvm::cl::opt soc_manufacturer( + "soc_man", + llvm::cl::desc( + "String identifier of SoC manufacturer (e.g., Pixel, Qualcomm)."), + llvm::cl::init("ExampleSocManufacturer")); + +// TODO: Support multi target compilation. +// NOLINTNEXTLINE +static llvm::cl::opt soc_model("soc_model", + llvm::cl::desc("Target SoC model."), + llvm::cl::init("ExampleSocModel")); + +// NOLINTNEXTLINE +static llvm::cl::list libs( + "libs", + llvm::cl::desc("List of directories in which to search for suitable " + "compiler plugin shared libraries."), + llvm::cl::list_init(llvm::ArrayRef{ + "third_party/tensorflow/lite/experimental/lrt/vendors/examples"})); + +// NOLINTNEXTLINE +static llvm::cl::opt out( + "o", + llvm::cl::desc("Path to file for output, \"-\" indicates standard out."), + llvm::cl::init("-")); + +// NOLINTNEXTLINE +static llvm::cl::opt err( + "err", + llvm::cl::desc("Path to file for error output, \"-\" indicates stdandard " + "error and \"none\" indicates silent."), + llvm::cl::init("-")); + +// NOLINTNEXTLINE +static llvm::cl::opt serialization( + "serialization", llvm::cl::desc("Serialization strategy to use."), + llvm::cl::init("METADATA")); + +ApplyPluginRun::Ptr ParseFlags() { + auto res = std::make_unique(); + + std::ofstream file_out; + if (out != "-") { + file_out.open(out); + res->outs.clear(); + res->outs.push_back(file_out); + } + + std::ofstream file_err; + if (err != "-") { + file_err.open(err); + res->dump_out.emplace(file_err); + } + + if (!model.empty()) { + res->model = model; + } + + res->soc_manufacturer = soc_manufacturer; + res->soc_models.push_back(soc_model); + + res->lib_search_paths.assign(libs.begin(), libs.end()); + + if (cmd == "apply") { + res->cmd = ApplyPluginRun::Cmd::APPLY; + } else if (cmd == "partition") { + res->cmd = ApplyPluginRun::Cmd::PARTITION; + } else if (cmd == "compile") { + res->cmd = ApplyPluginRun::Cmd::COMPILE; + } else if (cmd == "info") { + res->cmd = ApplyPluginRun::Cmd::INFO; + } else if (cmd == "noop") { + res->cmd = ApplyPluginRun::Cmd::NOOP; + } + + return res; +} + +int main(int argc, char* argv[]) { + llvm::cl::ParseCommandLineOptions(argc, argv); + + auto run = ParseFlags(); + if (run == nullptr) { + return 1; + } + + return ApplyPlugin(std::move(run)); +} diff --git a/tensorflow/lite/experimental/lrt/tools/apply_plugin_test.cc b/tensorflow/lite/experimental/lrt/tools/apply_plugin_test.cc index 5e835206a80613..cdc89755649e96 100644 --- a/tensorflow/lite/experimental/lrt/tools/apply_plugin_test.cc +++ b/tensorflow/lite/experimental/lrt/tools/apply_plugin_test.cc @@ -14,20 +14,26 @@ #include "tensorflow/lite/experimental/lrt/tools/apply_plugin.h" +#include #include #include #include +#include #include #include "absl/log/absl_check.h" #include "absl/strings/string_view.h" #include "tensorflow/lite/experimental/lrt/c/lite_rt_common.h" +#include "tensorflow/lite/experimental/lrt/c/lite_rt_model.h" +#include "tensorflow/lite/experimental/lrt/core/lite_rt_model_init.h" +#include "tensorflow/lite/experimental/lrt/core/model.h" #include "tensorflow/lite/experimental/lrt/test/common.h" namespace { using ::lrt::tools::ApplyPlugin; using ::lrt::tools::ApplyPluginRun; +using ::testing::HasSubstr; static constexpr absl::string_view kPluginSearchPath = "third_party/tensorflow/lite/experimental/lrt/vendors/examples"; @@ -38,7 +44,7 @@ static constexpr absl::string_view kSocModel = "ExampleSocModel"; absl::string_view TestModelPath() { static char kModelPath[512] = {}; - if (kModelPath[0] != '\0') { + if (kModelPath[0] == '\0') { const auto model_path = ::lrt::testing::GetTestFilePath("one_mul.tflite"); ABSL_CHECK(model_path.size() < 512); model_path.copy(kModelPath, model_path.size(), 0); @@ -61,69 +67,96 @@ TEST(TestApplyPluginTool, TestInfoBadConfig) { auto run = MakeBaseRun(ApplyPluginRun::Cmd::INFO); run->dump_out = {}; run->lib_search_paths.clear(); - ASSERT_STATUS_HAS_CODE(ApplyPlugin(std::move(run)), kLrtStatusToolBadConfig); + ASSERT_STATUS_HAS_CODE(ApplyPlugin(std::move(run)), + kLrtStatusErrorInvalidToolConfig); } TEST(TestApplyPluginTool, TestInfo) { auto run = MakeBaseRun(ApplyPluginRun::Cmd::INFO); - ASSERT_STATUS_HAS_CODE(ApplyPlugin(std::move(run)), - kLrtStatusErrorUnsupported); + std::stringstream out; + run->outs.push_back(out); + ASSERT_STATUS_OK(ApplyPlugin(std::move(run))); + EXPECT_THAT( + out.str(), + ::testing::HasSubstr("< LrtCompilerPlugin > \"ExampleSocManufacturer\" | " + "\"ExampleSocModel\"")); } TEST(TestApplyPluginTool, TestNoopBadConfig) { auto run = MakeBaseRun(ApplyPluginRun::Cmd::NOOP); run->model.reset(); - ASSERT_STATUS_HAS_CODE(ApplyPlugin(std::move(run)), kLrtStatusToolBadConfig); + ASSERT_STATUS_HAS_CODE(ApplyPlugin(std::move(run)), + kLrtStatusErrorInvalidToolConfig); } TEST(TestApplyPluginTool, TestNoop) { auto run = MakeBaseRun(ApplyPluginRun::Cmd::NOOP); std::stringstream out; run->outs.push_back(out); - ASSERT_STATUS_HAS_CODE(ApplyPlugin(std::move(run)), - kLrtStatusErrorUnsupported); + ASSERT_STATUS_OK(ApplyPlugin(std::move(run))); + + LrtModel model; + ASSERT_STATUS_OK( + LoadModel(reinterpret_cast(out.view().data()), + out.view().size(), &model)); + UniqueLrtModel u_model(model); + + EXPECT_EQ(model->subgraphs.size(), 1); } TEST(TestApplyPluginTool, TestPartitionBadConfig) { auto run = MakeBaseRun(ApplyPluginRun::Cmd::PARTITION); run->model.reset(); - ASSERT_STATUS_HAS_CODE(ApplyPlugin(std::move(run)), kLrtStatusToolBadConfig); + ASSERT_STATUS_HAS_CODE(ApplyPlugin(std::move(run)), + kLrtStatusErrorInvalidToolConfig); } TEST(TestApplyPluginTool, TestPartition) { auto run = MakeBaseRun(ApplyPluginRun::Cmd::PARTITION); std::stringstream out; run->outs.push_back(out); - ASSERT_STATUS_HAS_CODE(ApplyPlugin(std::move(run)), - kLrtStatusErrorUnsupported); + ASSERT_STATUS_OK(ApplyPlugin(std::move(run))); + EXPECT_FALSE(out.str().empty()); } TEST(TestApplyPluginTool, TestCompileBadConfig) { auto run = MakeBaseRun(ApplyPluginRun::Cmd::COMPILE); run->model.reset(); - ASSERT_STATUS_HAS_CODE(ApplyPlugin(std::move(run)), kLrtStatusToolBadConfig); + ASSERT_STATUS_HAS_CODE(ApplyPlugin(std::move(run)), + kLrtStatusErrorInvalidToolConfig); } TEST(TestApplyPluginTool, TestCompile) { auto run = MakeBaseRun(ApplyPluginRun::Cmd::COMPILE); std::stringstream out; run->outs.push_back(out); - ASSERT_STATUS_HAS_CODE(ApplyPlugin(std::move(run)), - kLrtStatusErrorUnsupported); + ASSERT_STATUS_OK(ApplyPlugin(std::move(run))); + EXPECT_FALSE(out.str().empty()); + EXPECT_THAT(out.str(), HasSubstr("Partition_0_with_1_muls")); } TEST(TestApplyPluginTool, TestApplyBadConfig) { auto run = MakeBaseRun(ApplyPluginRun::Cmd::APPLY); run->model.reset(); - ASSERT_STATUS_HAS_CODE(ApplyPlugin(std::move(run)), kLrtStatusToolBadConfig); + ASSERT_STATUS_HAS_CODE(ApplyPlugin(std::move(run)), + kLrtStatusErrorInvalidToolConfig); } TEST(TestApplyPluginTool, TestApply) { auto run = MakeBaseRun(ApplyPluginRun::Cmd::APPLY); std::stringstream out; run->outs.push_back(out); - ASSERT_STATUS_HAS_CODE(ApplyPlugin(std::move(run)), - kLrtStatusErrorUnsupported); + ASSERT_STATUS_OK(ApplyPlugin(std::move(run))); + + LrtModel model; + ASSERT_STATUS_OK( + LoadModel(reinterpret_cast(out.view().data()), + out.view().size(), &model)); + UniqueLrtModel u_model(model); + + EXPECT_EQ(model->subgraphs.size(), 1); + ASSERT_EQ(model->flatbuffer_model->metadata.size(), 2); + EXPECT_EQ(model->flatbuffer_model->metadata[1]->name, kSocManufacturer); } } // namespace diff --git a/tensorflow/lite/experimental/lrt/tools/dump.cc b/tensorflow/lite/experimental/lrt/tools/dump.cc index dbb24979d5747d..4f7632ee721702 100644 --- a/tensorflow/lite/experimental/lrt/tools/dump.cc +++ b/tensorflow/lite/experimental/lrt/tools/dump.cc @@ -245,6 +245,11 @@ void Dump(void* lib_handle, std::ostream& out) { #endif } +void Dump(const LrtModelT& model, std::ostream& out) { + out << absl::StreamFormat("LrtModel : [ #subgraphs=%d ]\n", + model.subgraphs.size()); +} + void DumpOptions(const LrtOpT& op, std::ostream& out) { switch (op.op_code) { case kLrtOpCodeTflAdd: diff --git a/tensorflow/lite/experimental/lrt/tools/dump.h b/tensorflow/lite/experimental/lrt/tools/dump.h index edf70ff9550001..885cf0652b5e6b 100644 --- a/tensorflow/lite/experimental/lrt/tools/dump.h +++ b/tensorflow/lite/experimental/lrt/tools/dump.h @@ -46,6 +46,9 @@ void Dump(LrtElementType type, std::ostream& out = std::cerr); // Dump details about the given LrtRankedTensorType to the given stream. void Dump(const LrtRankedTensorType& type, std::ostream& out = std::cerr); +// Dump details about the given LrtModel to the given stream. +void Dump(const LrtModelT& model, std::ostream& out = std::cerr); + // Dump details about options void DumpOptions(const LrtOpT& op, std::ostream& out = std::cerr); diff --git a/tensorflow/lite/experimental/lrt/tools/dump_test.cc b/tensorflow/lite/experimental/lrt/tools/dump_test.cc index 7744bd70312c65..2375e5531e290b 100644 --- a/tensorflow/lite/experimental/lrt/tools/dump_test.cc +++ b/tensorflow/lite/experimental/lrt/tools/dump_test.cc @@ -30,6 +30,12 @@ using ::lrt::testing::LoadTestFileModel; TEST(DumpTest, TestDump) { auto model = LoadTestFileModel("one_mul.tflite"); + { + std::ostringstream model_dump; + Dump(*model, model_dump); + EXPECT_EQ(model_dump.view(), "LrtModel : [ #subgraphs=1 ]\n"); + } + { const LrtTensorT& in_tensor = *model->subgraphs.front().inputs.front(); std::ostringstream in_tensor_dump; diff --git a/tensorflow/lite/experimental/lrt/tools/temp.txt b/tensorflow/lite/experimental/lrt/tools/temp.txt new file mode 100644 index 00000000000000..5aa3cb613f126e --- /dev/null +++ b/tensorflow/lite/experimental/lrt/tools/temp.txt @@ -0,0 +1,237 @@ +// std::pair GetModelAndPlugin() { +// std::vector plugins; +// LRT_CHECK_STATUS_OK(PluginManager::LoadPlugins({kPluginSearchPath}, +// plugins)); ABSL_CHECK_EQ(plugins.size(), 1); return +// {LoadTestFileModel(kModel), std::move(plugins.front())}; +// } + +// TEST(PluginToolTest, SerializeRoundTrip) { +// auto test_data = GetModelAndPlugin(); +// { +// ASSERT_EQ(test_data.first->subgraphs.size(), 1); +// const LrtSubgraphT& subgraph = test_data.first->subgraphs.front(); +// EXPECT_EQ(subgraph.inputs.size(), 2); +// EXPECT_EQ(subgraph.outputs.size(), 1); +// ASSERT_EQ(subgraph.ops.size(), 1); +// EXPECT_EQ(subgraph.ops.front()->op_code, kLrtOpCodeTflMul); +// } + +// PluginTool tool(std::move(test_data.first), std::move(test_data.second)); + +// std::stringstream serialized; +// ASSERT_STATUS_OK(tool.Serialize(serialized)); + +// LrtModel model; +// ASSERT_STATUS_OK( +// LoadModel(reinterpret_cast(serialized.str().data()), +// serialized.str().size(), &model)); +// UniqueLrtModel umodel(model); + +// { +// ASSERT_EQ(model->subgraphs.size(), 1); +// const LrtSubgraphT& subgraph = model->subgraphs.front(); +// EXPECT_EQ(subgraph.inputs.size(), 2); +// EXPECT_EQ(subgraph.outputs.size(), 1); +// ASSERT_EQ(subgraph.ops.size(), 1); +// EXPECT_EQ(subgraph.ops.front()->op_code, kLrtOpCodeTflMul); +// } +// } + +// TEST(PluginToolTest, DumpCompilationStats) { +// auto test_data = GetModelAndPlugin(); +// PluginTool tool(std::move(test_data.first), std::move(test_data.second)); + +// std::ostringstream dump_out; +// tool.DumpCompilationStats(dump_out); +// EXPECT_EQ(dump_out.view(), "LrtCompiledResult : +// \n"); +// } + +// TEST(PluginToolTest, TestPartition) { +// auto test_data = GetModelAndPlugin(); +// PluginTool tool(std::move(test_data.first), std::move(test_data.second)); +// ASSERT_STATUS_OK(tool.Partiion()); +// ASSERT_EQ(tool.Model().subgraphs.size(), 2); +// ASSERT_EQ(tool.MainSubgraph().ops.size(), 1); +// ASSERT_EQ(tool.Partitions().size(), 1); +// ASSERT_EQ(tool.Partitions().at(0).ops.size(), 1); +// } + +// TEST(PluginToolTest, TestDumpPartitionDetails) { +// auto test_data = GetModelAndPlugin(); +// PluginTool tool(std::move(test_data.first), std::move(test_data.second)); +// ASSERT_STATUS_OK(tool.Partiion()); +// std::ostringstream dump_out; +// tool.DumpPartitionDetails(dump_out); +// EXPECT_TRUE( +// absl::StrContains(dump_out.view(), +// "(main subgraph) LrtSubgraph : [ #ops=1 #tensors=3 ] +// " +// "(<2x2xf32>, <2x2xf32>) -> <2x2xf32>")); +// EXPECT_TRUE(absl::StrContains(dump_out.view(), +// "(partition) LrtSubgraph : [ #ops=1 +// #tensors=3 " +// "] (<2x2xf32>, <2x2xf32>) -> <2x2xf32>")); +// } + +// // Utility for applying various functions from given compiler +// // plugin to the given model. Writes details about the process to "dump". +// class PluginTool { +// public: +// // Perform the partition step. Plugin selects ops which are sliced from +// // the original graph. +// LrtStatus Partiion(); + +// // Perform the compilation step for "soc_model" provided. Writes +// // a new flatbuffer with embedded compiled module and custom ops to +// // the given stream. +// // NOTE: Currently this invalidates the underlying input model so it +// // cannot be called more than once. +// // TODO: Implement model copy to support compiling for multiple soc_models +// // in one run. +// LrtStatus Compile(const absl::string_view soc_model); + +// PluginTool(UniqueLrtModel model, internal::PluginManager plugin, +// std::ostream& dump = std::cerr) +// : model_(std::move(model)), plugin_(std::move(plugin)), dump_(dump) {} + +// PluginTool(const PluginTool&) = delete; +// PluginTool& operator=(const PluginTool&) = delete; +// PluginTool(PluginTool&&) = delete; +// PluginTool& operator=(PluginTool&&) = delete; + +// private: +// const LrtModelT& Model() const { return *model_; } +// LrtModelT& Model() { return *model_; } + +// const LrtSubgraphT& MainSubgraph() const { return +// Model().subgraphs.front(); } LrtSubgraphT& MainSubgraph() { return +// Model().subgraphs.front(); } + +// const absl::Span Partitions() const; + +// std::ostream& Dump() { return dump_; } +// std::ostream& dump_; + +// void DumpPartitionDetails() const; +// void DumpCompilationStats(const absl::string_view soc_model) const; + +// std::vector& CustomOps() { return custom_ops_; } +// std::vector custom_ops_; + +// UniqueLrtModel model_; + +// internal::PluginManager plugin_; +// }; + +// void PluginTool::DumpCompilationStats(const absl::string_view soc_model) +// const { +// static constexpr absl::string_view kCompiledResultTpl = +// "LrtCompiledResult : [ module_size=%lu (bytes), +// #compiled_partitions=%lu " +// "]\n"; +// static constexpr absl::string_view kCompiledResultErr = +// "LrtCompiledResult : \n"; +// if (plugin_.CompiledResultHandle(soc_model) == nullptr) { +// Dump() << kCompiledResultErr; +// return; +// } +// const void* byte_code; +// size_t byte_code_size; +// if (kLrtStatusOk != +// plugin_.Api().compiled_result_get_byte_code( +// plugin_.CompiledResultHandle(), &byte_code, &byte_code_size)) { +// Dump() << kCompiledResultErr; +// return; +// } + +// size_t num_compiled_partitions; +// if (kLrtStatusOk != +// plugin_.Api().compiled_result_get_num_calls( +// plugin_.CompiledResultHandle(), &num_compiled_partitions)) { +// Dump() << kCompiledResultErr; +// return; +// } + +// Dump() << absl::StreamFormat(kCompiledResultTpl, byte_code_size, +// num_compiled_partitions); +// } + +// void PluginTool::DumpPartitionDetails() const { +// Dump() << "[[ Partition Results ]]\n"; +// Dump() << "(main subgraph) "; +// lrt::internal::Dump(MainSubgraph(), Dump()); +// for (const auto& partition : Partitions()) { +// Dump() << "(partition) "; +// lrt::internal::Dump(partition, Dump()); +// } +// } + +// // Currently new partitioned subgraphs are appended to the model subgraphs +// and +// // there is only support of input models with one subgraph. +// const absl::Span PluginTool::Partitions() const { +// return absl::MakeConstSpan(model_->subgraphs.data() + 1, +// model_->subgraphs.size() - 1); +// } + +// LrtStatus PluginTool::Partiion() { +// LrtOpListT selected_ops; +// LRT_RETURN_STATUS_IF_NOT_OK(plugin_.Api().partition_model( +// plugin_.PluginHandle(), model_.get(), &selected_ops)); +// auto partitions = GroupPartitions(selected_ops.ops); + +// CustomOps().reserve(partitions.size()); + +// for (auto& partition : partitions) { +// LrtSubgraph new_subgraph = &model_->subgraphs.emplace_back(); +// CustomOps().push_back( +// OutlinePartition(MainSubgraph(), new_subgraph, partition)); +// } + +// return kLrtStatusOk; +// } + +// LrtStatus PluginTool::Compile(const absl::string_view soc_models) { +// LRT_RETURN_STATUS_IF_NOT_OK( +// plugin_.Api().compile(plugin_.PluginHandle(), soc_model.data(), +// slices.data(), slices.size(), &compiled_result)); + +// lrt_param_index_t num_calls_compiled; +// LRT_RETURN_STATUS_IF_NOT_OK( +// LrtCompiledResultGetNumCalls(compiled_result, &num_calls_compiled)); + +// if (num_calls_compiled != slices.size()) { +// std::cerr +// << "Plugin must provide and entry point for each compiled +// partition\n"; +// return kLrtStatusErrorNotFound; +// } + +// for (int i = 0; i < num_calls_compiled; ++i) { +// const void* call_info; +// size_t call_info_size; + +// LRT_RETURN_STATUS_IF_NOT_OK(LrtCompiledResultGetCallInfo( +// compiled_result, i, &call_info, &call_info_size)); + +// auto* custom_op = custom_ops.at(i); +// custom_op->custom_options.assign(reinterpret_cast(call_info), +// call_info_size); +// } +// return kLrtStatusOk; +// } + +// LrtStatus PluginTool::Serialize(const absl::string_view soc_model, +// std::ostream& out) { +// uint8_t* buf; +// size_t size; +// size_t offset; +// LRT_RETURN_STATUS_IF_NOT_OK( +// SerializeModel(model_.release(), &buf, &size, &offset)); +// const char* cbuf = reinterpret_cast(buf); +// out.write(cbuf + offset, size - offset); +// delete[] buf; +// return kLrtStatusOk; +// } \ No newline at end of file diff --git a/tensorflow/lite/experimental/lrt/tools/tool_display.cc b/tensorflow/lite/experimental/lrt/tools/tool_display.cc new file mode 100644 index 00000000000000..f455d0d21172d3 --- /dev/null +++ b/tensorflow/lite/experimental/lrt/tools/tool_display.cc @@ -0,0 +1,61 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/lrt/tools/tool_display.h" + +#include +#include + +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" + +namespace lrt::tools { + +ToolDisplay::ToolDisplay(OptOstreamRefT display_stream, + const absl::string_view tool_label) + : display_(display_stream) { + label_ = absl::StrFormat( + "[LRT_TOOLS%s] ", + tool_label.empty() ? tool_label : absl::StrFormat(":%s", tool_label)); +} + +std::ostream& ToolDisplay::Display() { + return display_.has_value() ? display_.value().get() : null_display_; +} + +std::ostream& ToolDisplay::Labeled() { + Display() << label_; + return Display(); +} + +std::ostream& ToolDisplay::Indented() { + Display() << "\t"; + return Display(); +} + +void ToolDisplay::Start(const absl::string_view start_label) { + Labeled() << absl::StreamFormat("Starting %s...\n", start_label); +} + +void ToolDisplay::Done() { + Labeled(); + Indented() << "Done!\n"; +} + +void ToolDisplay::Fail() { + Labeled(); + Indented() << "Failed\n"; +} + +} // namespace lrt::tools diff --git a/tensorflow/lite/experimental/lrt/tools/tool_display.h b/tensorflow/lite/experimental/lrt/tools/tool_display.h new file mode 100644 index 00000000000000..8f6452d8a935b5 --- /dev/null +++ b/tensorflow/lite/experimental/lrt/tools/tool_display.h @@ -0,0 +1,64 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LRT_TOOLS_TOOL_DISPLAY_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LRT_TOOLS_TOOL_DISPLAY_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" + +namespace lrt::tools { + +// Utility class for interactive logging for usage in command line tools only. +// Allows user to explicitly set target stream. +class ToolDisplay { + using OptOstreamRefT = std::optional>; + + public: + // Construct configured ToolDisplay. Label is used for prefixing dumps + // in "LabeledStream". If "dump" is null, all printing through this class + // is silenced. + explicit ToolDisplay(OptOstreamRefT display_stream = std::nullopt, + absl::string_view tool_label = ""); + + // Get out stream. + std::ostream& Display(); + + // Get Display with label prefix. + std::ostream& Labeled(); + + // Get Display with indent. + std::ostream& Indented(); + + // Log string indicating a sub rountine is beginning. + void Start(absl::string_view start_label); + + // Log string indicating a sub rountine is done and succeeded. + void Done(); + + // Log string indicating a sub rountine is done and failed. + void Fail(); + + private: + std::string label_; + std::ostream null_display_ = std::ostream(nullptr); + OptOstreamRefT display_; +}; + +} // namespace lrt::tools + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LRT_TOOLS_TOOL_DISPLAY_H_ diff --git a/tensorflow/lite/experimental/lrt/tools/tool_display_test.cc b/tensorflow/lite/experimental/lrt/tools/tool_display_test.cc new file mode 100644 index 00000000000000..2a60b5fc6a18d9 --- /dev/null +++ b/tensorflow/lite/experimental/lrt/tools/tool_display_test.cc @@ -0,0 +1,82 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/lrt/tools/tool_display.h" + +#include + +#include +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" + +namespace { + +using ::lrt::tools::ToolDisplay; + +static constexpr absl::string_view kToolName = "test-tool"; +static constexpr absl::string_view kLabel = "[LRT_TOOLS:test-tool]"; +static constexpr absl::string_view kStartLabel = "Test Routine"; +static constexpr absl::string_view kDisplayInfo = "info"; + +TEST(TestToolDisplay, Display) { + std::stringstream out; + ToolDisplay display(out, kToolName); + display.Display() << kDisplayInfo; + EXPECT_EQ(out.view(), kDisplayInfo); +} + +TEST(TestToolDisplay, Indented) { + std::stringstream out; + ToolDisplay display(out, kToolName); + display.Indented() << kDisplayInfo; + EXPECT_EQ(out.view(), absl::StrFormat("\t%s", kDisplayInfo)); +} + +TEST(TestToolDisplay, Labeled) { + std::stringstream out; + ToolDisplay display(out, kToolName); + display.Labeled() << kDisplayInfo; + EXPECT_EQ(out.view(), absl::StrFormat("%s %s", kLabel, kDisplayInfo)); +} + +TEST(TestToolDisplay, LabeledNoToolName) { + std::stringstream out; + ToolDisplay display(out); + display.Labeled() << kDisplayInfo; + EXPECT_EQ(out.view(), absl::StrFormat("%s %s", "[LRT_TOOLS]", kDisplayInfo)); +} + +TEST(TestToolDisplay, Start) { + std::stringstream out; + ToolDisplay display(out, kToolName); + display.Start(kStartLabel); + EXPECT_EQ(out.view(), + absl::StrFormat("%s Starting %s...\n", kLabel, kStartLabel)); +} + +TEST(TestToolDisplay, Done) { + std::stringstream out; + ToolDisplay display(out, kToolName); + display.Done(); + EXPECT_EQ(out.view(), absl::StrFormat("%s \tDone!\n", kLabel)); +} + +TEST(TestToolDisplay, Fail) { + std::stringstream out; + ToolDisplay display(out, kToolName); + display.Fail(); + EXPECT_EQ(out.view(), absl::StrFormat("%s \tFailed\n", kLabel)); +} + +} // namespace diff --git a/tensorflow/lite/experimental/lrt/vendors/examples/BUILD b/tensorflow/lite/experimental/lrt/vendors/examples/BUILD index 3f2c2181b596a6..66995f5bf3acac 100644 --- a/tensorflow/lite/experimental/lrt/vendors/examples/BUILD +++ b/tensorflow/lite/experimental/lrt/vendors/examples/BUILD @@ -32,6 +32,7 @@ lite_rt_dynamic_lib( "//tensorflow/lite/experimental/lrt/c:lite_rt_c_api", "//tensorflow/lite/experimental/lrt/cc:lite_rt_cc_api", "//tensorflow/lite/experimental/lrt/core:graph_tools", + "//tensorflow/lite/experimental/lrt/core:model", ], ) diff --git a/tensorflow/lite/experimental/lrt/vendors/examples/example_plugin.cc b/tensorflow/lite/experimental/lrt/vendors/examples/example_plugin.cc index ecc444134963c9..f54dac5776ebe8 100644 --- a/tensorflow/lite/experimental/lrt/vendors/examples/example_plugin.cc +++ b/tensorflow/lite/experimental/lrt/vendors/examples/example_plugin.cc @@ -76,7 +76,7 @@ LrtStatus LrtCompiledResultGetCallInfo(LrtCompiledResult compiled_result, const void** call_info, size_t* call_info_size) { if (call_idx >= compiled_result->per_op_data.size()) { - return kLrtStatusParamIndexOOB; + return kLrtStatusErrorIndexOOB; } *call_info = compiled_result->per_op_data.at(call_idx).data(); diff --git a/tensorflow/lite/experimental/lrt/vendors/examples/example_plugin_test.cc b/tensorflow/lite/experimental/lrt/vendors/examples/example_plugin_test.cc index 1884682dcdbaaf..103cca219e3e5c 100644 --- a/tensorflow/lite/experimental/lrt/vendors/examples/example_plugin_test.cc +++ b/tensorflow/lite/experimental/lrt/vendors/examples/example_plugin_test.cc @@ -53,13 +53,14 @@ TEST(TestCallDummyPlugin, PartitionSimpleMultiAdd) { auto plugin = GetDummyPlugin(); auto model = lrt::testing::LoadTestFileModel("simple_multi_op.tflite"); - LrtOpListT selected_ops; + LrtOpListT selected_op_list; ASSERT_STATUS_OK( - LrtPluginPartitionModel(plugin.get(), model.get(), &selected_ops)); + LrtPluginPartitionModel(plugin.get(), model.get(), &selected_op_list)); + const auto selected_ops = selected_op_list.Vec(); - ASSERT_EQ(selected_ops.ops.size(), 2); - ASSERT_EQ(selected_ops.ops[0]->op_code, kLrtOpCodeTflMul); - ASSERT_EQ(selected_ops.ops[1]->op_code, kLrtOpCodeTflMul); + ASSERT_EQ(selected_ops.size(), 2); + ASSERT_EQ(selected_ops[0]->op_code, kLrtOpCodeTflMul); + ASSERT_EQ(selected_ops[1]->op_code, kLrtOpCodeTflMul); } TEST(TestCallDummyPlugin, CompileMulSubgraph) { diff --git a/tensorflow/lite/experimental/lrt/vendors/qualcomm/compiler/qnn_compiler_plugin.cc b/tensorflow/lite/experimental/lrt/vendors/qualcomm/compiler/qnn_compiler_plugin.cc index 91e3cfd2f3515d..05a0479e8f3222 100644 --- a/tensorflow/lite/experimental/lrt/vendors/qualcomm/compiler/qnn_compiler_plugin.cc +++ b/tensorflow/lite/experimental/lrt/vendors/qualcomm/compiler/qnn_compiler_plugin.cc @@ -106,7 +106,7 @@ LrtStatus LrtCompiledResultGetCallInfo(LrtCompiledResult compiled_result, const void** call_info, size_t* call_info_size) { if (call_idx >= compiled_result->graph_names.size()) { - return kLrtStatusParamIndexOOB; + return kLrtStatusErrorIndexOOB; } *call_info = compiled_result->graph_names.at(call_idx).data(); diff --git a/tensorflow/lite/experimental/lrt/vendors/qualcomm/compiler/qnn_compiler_plugin_test.cc b/tensorflow/lite/experimental/lrt/vendors/qualcomm/compiler/qnn_compiler_plugin_test.cc index bd725f1b2d8b82..546f1018b21e26 100644 --- a/tensorflow/lite/experimental/lrt/vendors/qualcomm/compiler/qnn_compiler_plugin_test.cc +++ b/tensorflow/lite/experimental/lrt/vendors/qualcomm/compiler/qnn_compiler_plugin_test.cc @@ -18,6 +18,7 @@ #include #include "absl/log/absl_check.h" #include "tensorflow/lite/experimental/lrt/c/lite_rt_model.h" +#include "tensorflow/lite/experimental/lrt/c/lite_rt_op_code.h" #include "tensorflow/lite/experimental/lrt/cc/lite_rt_support.h" #include "tensorflow/lite/experimental/lrt/core/graph_tools.h" #include "tensorflow/lite/experimental/lrt/core/model.h" @@ -50,11 +51,13 @@ TEST(TestQnnPlugin, PartitionMulOps) { auto plugin = GetQnnPlugin(); auto model = lrt::testing::LoadTestFileModel("one_mul.tflite"); - LrtOpListT selected_ops; + LrtOpListT selected_op_list; ASSERT_STATUS_OK( - LrtPluginPartitionModel(plugin.get(), model.get(), &selected_ops)); + LrtPluginPartitionModel(plugin.get(), model.get(), &selected_op_list)); + const auto selected_ops = selected_op_list.Vec(); - EXPECT_EQ(selected_ops.ops.size(), 1); + ASSERT_EQ(selected_ops.size(), 1); + EXPECT_EQ(selected_ops[0]->op_code, kLrtOpCodeTflMul); } TEST(TestQnnPlugin, CompileMulSubgraph) { diff --git a/tensorflow/lite/experimental/lrt/vendors/qualcomm/qnn_manager.cc b/tensorflow/lite/experimental/lrt/vendors/qualcomm/qnn_manager.cc index d9b1c791534075..9b173f2ed5e7d5 100644 --- a/tensorflow/lite/experimental/lrt/vendors/qualcomm/qnn_manager.cc +++ b/tensorflow/lite/experimental/lrt/vendors/qualcomm/qnn_manager.cc @@ -107,7 +107,7 @@ LrtStatus QnnManager::ResolveApi() { if (lib_so_ == nullptr) { LITE_RT_LOG(LRT_ERROR, "%s", "Cannot resolve functions: libQnn*.so has not been loaded.\n"); - return kLrtStatusDynamicLoadErr; + return kLrtStatusErrorDynamicLoading; } auto providers = LoadProvidersFromLib(lib_so_); @@ -129,7 +129,7 @@ LrtStatus QnnManager::ResolveApi() { if (interface_ == nullptr) { LITE_RT_LOG(LRT_ERROR, "%s", "No valid interface was provided\n"); - return kLrtStatusDynamicLoadErr; + return kLrtStatusErrorDynamicLoading; } return kLrtStatusOk; @@ -139,7 +139,7 @@ LrtStatus QnnManager::ResolveSystemApi() { if (lib_so_ == nullptr) { LITE_RT_LOG(LRT_ERROR, "%s", "Cannot resolve functions: libQnn*.so has not been loaded.\n"); - return kLrtStatusDynamicLoadErr; + return kLrtStatusErrorDynamicLoading; } auto system_providers = LoadSystemProvidersFromLib(lib_system_so_); @@ -161,7 +161,7 @@ LrtStatus QnnManager::ResolveSystemApi() { if (system_interface_ == nullptr) { LITE_RT_LOG(LRT_ERROR, "%s", "No valid system interface was provided\n"); - return kLrtStatusDynamicLoadErr; + return kLrtStatusErrorDynamicLoading; } return kLrtStatusOk;