diff --git a/src/bindings/js/node/include/model_wrap.hpp b/src/bindings/js/node/include/model_wrap.hpp
index 1e3373b6185095..5ecb9dc86389bd 100644
--- a/src/bindings/js/node/include/model_wrap.hpp
+++ b/src/bindings/js/node/include/model_wrap.hpp
@@ -84,6 +84,14 @@ class ModelWrap : public Napi::ObjectWrap<ModelWrap> {
      */
     Napi::Value is_dynamic(const Napi::CallbackInfo& info);
 
+    /**
+     * @brief Returns the number of outputs for this model.
+     * @param info Contains information about the environment and passed arguments.
+     * This method does not accept any arguments. If arguments are provided, it throws Napi::Error.
+     * @return A number indicating the quantity of outputs for the model.
+     */
+    Napi::Value get_output_size(const Napi::CallbackInfo& info);
+
     /**
      * @brief Sets a friendly name for a model.
      * @param info Contains information about the environment and passed arguments
diff --git a/src/bindings/js/node/lib/addon.ts b/src/bindings/js/node/lib/addon.ts
index 5f6348d780a1b2..7acdd5f0be56d6 100644
--- a/src/bindings/js/node/lib/addon.ts
+++ b/src/bindings/js/node/lib/addon.ts
@@ -73,6 +73,7 @@ interface Model {
   input(nameOrId?: string | number): Output;
   getName(): string;
   isDynamic(): boolean;
+  getOutputSize(): number;
   setFriendlyName(name: string): void;
   getFriendlyName(): string;
 }
diff --git a/src/bindings/js/node/src/model_wrap.cpp b/src/bindings/js/node/src/model_wrap.cpp
index d22dd6701ab156..5198f82cffc0ab 100644
--- a/src/bindings/js/node/src/model_wrap.cpp
+++ b/src/bindings/js/node/src/model_wrap.cpp
@@ -20,6 +20,7 @@ Napi::Function ModelWrap::get_class(Napi::Env env) {
                        InstanceMethod("output", &ModelWrap::get_output),
                        InstanceMethod("input", &ModelWrap::get_input),
                        InstanceMethod("isDynamic", &ModelWrap::is_dynamic),
+                       InstanceMethod("getOutputSize", &ModelWrap::get_output_size),
                        InstanceMethod("setFriendlyName", &ModelWrap::set_friendly_name),
                        InstanceMethod("getFriendlyName", &ModelWrap::get_friendly_name),
                        InstanceAccessor<&ModelWrap::get_inputs>("inputs"),
@@ -131,6 +132,16 @@ Napi::Value ModelWrap::is_dynamic(const Napi::CallbackInfo& info) {
     return Napi::Boolean::New(env, result);
 }
 
+Napi::Value ModelWrap::get_output_size(const Napi::CallbackInfo& info) {
+    Napi::Env env = info.Env();
+    if (info.Length() > 0) {
+        reportError(env, "getOutputSize() does not accept any arguments.");
+        return env.Undefined();
+    }
+    const auto size = static_cast<double>(_model->get_output_size());
+    return Napi::Number::New(env, size);
+}
+
 Napi::Value ModelWrap::set_friendly_name(const Napi::CallbackInfo& info) {
     Napi::Env env = info.Env();
     try {
diff --git a/src/bindings/js/node/tests/model.test.js b/src/bindings/js/node/tests/model.test.js
index c157c003e18f6f..184df9dc143be8 100644
--- a/src/bindings/js/node/tests/model.test.js
+++ b/src/bindings/js/node/tests/model.test.js
@@ -94,3 +94,21 @@ describe('Node.js getFriendlyName() / setFriendlyName()', () => {
     });
   });
 });
+
+describe('Model.getOutputSize()', () => {
+
+  it('should return a number indicating number of outputs for the model', () => {
+    const result = model.getOutputSize();
+    assert.strictEqual(typeof result, 'number', 'getOutputSize() should return a number');
+  });
+
+  it('should not accept any arguments', () => {
+    assert.throws(() => {
+      model.getOutputSize('unexpected argument');
+    }, /^Error: getOutputSize\(\) does not accept any arguments\.$/, 'Expected getOutputSize to throw an error when called with arguments');
+  });
+
+  it('should return 1 for the default model', () => {
+    assert.strictEqual(model.getOutputSize(), 1, 'Expected getOutputSize to return 1 for the default model');
+  });
+});
diff --git a/src/frontends/onnx/frontend/src/op/org.openvinotoolkit/deformable_conv_2d.cpp b/src/frontends/onnx/frontend/src/op/org.openvinotoolkit/deformable_conv_2d.cpp
index 7abb16859b9dcc..ff02d0f6e0a0d5 100644
--- a/src/frontends/onnx/frontend/src/op/org.openvinotoolkit/deformable_conv_2d.cpp
+++ b/src/frontends/onnx/frontend/src/op/org.openvinotoolkit/deformable_conv_2d.cpp
@@ -16,6 +16,7 @@
 
 #include "op/org.openvinotoolkit/deformable_conv_2d.hpp"
 
+#include "openvino/frontend/exception.hpp"
 #include "openvino/op/deformable_convolution.hpp"
 #include "utils/convpool.hpp"
 
@@ -36,16 +37,32 @@ ov::OutputVector deformable_conv_2d(const ov::frontend::onnx::Node& node) {
     const auto deformable_groups = node.get_attribute_value<int64_t>("deformable_groups", 1);
     const auto auto_pad_type = convpool::get_auto_pad(node);
 
-    return {std::make_shared<v8::DeformableConvolution>(inputs.at(0),
-                                                        inputs.at(1),
-                                                        inputs.at(2),
-                                                        strides,
-                                                        paddings.first,
-                                                        paddings.second,
-                                                        dilations,
-                                                        auto_pad_type,
-                                                        group,
-                                                        deformable_groups)};
+    if (inputs.size() == 3) {
+        return {std::make_shared<v8::DeformableConvolution>(inputs.at(0),
+                                                            inputs.at(1),
+                                                            inputs.at(2),
+                                                            strides,
+                                                            paddings.first,
+                                                            paddings.second,
+                                                            dilations,
+                                                            auto_pad_type,
+                                                            group,
+                                                            deformable_groups)};
+    } else if (inputs.size() == 4) {
+        return {std::make_shared<v8::DeformableConvolution>(inputs.at(0),
+                                                            inputs.at(1),
+                                                            inputs.at(2),
+                                                            inputs.at(3),
+                                                            strides,
+                                                            paddings.first,
+                                                            paddings.second,
+                                                            dilations,
+                                                            auto_pad_type,
+                                                            group,
+                                                            deformable_groups)};
+    } else {
+        FRONT_END_GENERAL_CHECK(false, "Invalid number of inputs");
+    }
 }
 }  // namespace set_1
 }  // namespace op
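Reviewer note: below is a minimal standalone sketch (not part of this change) of the graph that the new 4-input branch builds, using the shapes from the test model added next. The helper name `make_deformable_conv_with_mask` is hypothetical; the constructor call mirrors the branch above, and the offsets/mask channel counts follow the DeformableConvolution-8 spec (deformable_group * 2 * K_h * K_w and deformable_group * K_h * K_w, i.e. 8 and 4 here).

```cpp
#include <memory>

#include "openvino/core/model.hpp"
#include "openvino/op/deformable_convolution.hpp"
#include "openvino/op/parameter.hpp"

std::shared_ptr<ov::Model> make_deformable_conv_with_mask() {
    using namespace ov;
    // Shapes match the prototxt test model below: 2x2 kernel, 3x3 output.
    auto data = std::make_shared<op::v0::Parameter>(element::f32, Shape{1, 1, 4, 4});
    auto deformation = std::make_shared<op::v0::Parameter>(element::f32, Shape{1, 8, 3, 3});
    auto filters = std::make_shared<op::v0::Parameter>(element::f32, Shape{1, 1, 2, 2});
    auto mask = std::make_shared<op::v0::Parameter>(element::f32, Shape{1, 4, 3, 3});

    auto conv = std::make_shared<op::v8::DeformableConvolution>(data,
                                                                deformation,
                                                                filters,
                                                                mask,
                                                                Strides{1, 1},         // strides
                                                                CoordinateDiff{0, 0},  // pads_begin
                                                                CoordinateDiff{0, 0},  // pads_end
                                                                Strides{1, 1},         // dilations
                                                                op::PadType::EXPLICIT,
                                                                1,   // group
                                                                1);  // deformable_group
    return std::make_shared<Model>(OutputVector{conv}, ParameterVector{data, deformation, filters, mask});
}
```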
diff --git a/src/frontends/onnx/tests/models/org.openvinotoolkit/deformable_conv_2d_with_mask.prototxt b/src/frontends/onnx/tests/models/org.openvinotoolkit/deformable_conv_2d_with_mask.prototxt
new file mode 100644
index 00000000000000..c317dca2669536
--- /dev/null
+++ b/src/frontends/onnx/tests/models/org.openvinotoolkit/deformable_conv_2d_with_mask.prototxt
@@ -0,0 +1,138 @@
+ir_version: 7
+producer_name: "OpenVINO ONNX Frontend"
+graph {
+  node {
+    input: "data"
+    input: "deformation"
+    input: "filters"
+    input: "mask"
+    output: "out"
+    op_type: "DeformableConv2D"
+  }
+  name: "test_graph"
+  input {
+    name: "data"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 1
+          }
+          dim {
+            dim_value: 1
+          }
+          dim {
+            dim_value: 4
+          }
+          dim {
+            dim_value: 4
+          }
+        }
+      }
+    }
+  }
+  input {
+    name: "deformation"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 1
+          }
+          dim {
+            dim_value: 8
+          }
+          dim {
+            dim_value: 3
+          }
+          dim {
+            dim_value: 3
+          }
+        }
+      }
+    }
+  }
+  input {
+    name: "filters"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 1
+          }
+          dim {
+            dim_value: 1
+          }
+          dim {
+            dim_value: 2
+          }
+          dim {
+            dim_value: 2
+          }
+        }
+      }
+    }
+  }
+  input {
+    name: "mask"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 1
+          }
+          dim {
+            dim_value: 4
+          }
+          dim {
+            dim_value: 3
+          }
+          dim {
+            dim_value: 3
+          }
+        }
+      }
+    }
+  }
+  initializer {
+    name: "filters"
+    dims: 1
+    dims: 1
+    dims: 2
+    dims: 2
+    data_type: 1
+    float_data: 0.1
+    float_data: 0.2
+    float_data: 0.3
+    float_data: 0.4
+  }
+  output {
+    name: "out"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 1
+          }
+          dim {
+            dim_value: 1
+          }
+          dim {
+            dim_value: 3
+          }
+          dim {
+            dim_value: 3
+          }
+        }
+      }
+    }
+  }
+}
+opset_import {
+  version: 7
+}
diff --git a/src/frontends/onnx/tests/onnx_import_org_openvino.in.cpp b/src/frontends/onnx/tests/onnx_import_org_openvino.in.cpp
index 332467a0b92c2f..9116771f352b9e 100644
--- a/src/frontends/onnx/tests/onnx_import_org_openvino.in.cpp
+++ b/src/frontends/onnx/tests/onnx_import_org_openvino.in.cpp
@@ -565,6 +565,41 @@ OPENVINO_TEST(${BACKEND_NAME}, onnx_model_deformable_conv_2d) {
     test_case.run();
 }
 
+OPENVINO_TEST(${BACKEND_NAME}, onnx_model_deformable_conv_2d_with_mask) {
+    auto model = convert_model("org.openvinotoolkit/deformable_conv_2d_with_mask.onnx");
+
+    auto test_case = ov::test::TestCase(model, s_device);
+
+    // data
+    test_case.add_input<float>(
+        {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f});
+
+    // deformations
+    test_case.add_input<float>({0.5f, -0.5f, 0.0f, 1.0f, 0.5f, -0.5f, 0.0f, 1.0f, 1.0f, 0.5f, -0.5f, 0.0f,
+                                1.0f, 0.5f, -0.5f, 0.0f, 1.0f, 1.0f, 0.5f, -0.5f, 0.0f, 1.0f, 0.5f, -0.5f,
+                                0.0f, 1.0f, 1.0f, 0.5f, -0.5f, 0.0f, 1.0f, 0.5f, -0.5f, 0.0f, 1.0f, 1.0f,
+                                0.5f, -0.5f, 0.0f, 1.0f, 0.5f, -0.5f, 0.0f, 1.0f, 1.0f, 0.5f, -0.5f, 0.0f,
+                                1.0f, 0.5f, -0.5f, 0.0f, 1.0f, 1.0f, 0.5f, -0.5f, 0.0f, 1.0f, 0.5f, -0.5f,
+                                0.0f, 1.0f, 1.0f, 0.5f, -0.5f, 0.0f, 1.0f, 0.5f, -0.5f, 0.0f, 1.0f, 1.0f});
+
+    // mask
+    test_case.add_input<float>({0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f, 0.9f, 1.0f, 1.1f, 1.2f,
+                                1.3f, 1.4f, 1.5f, 1.6f, 1.7f, 1.8f, 1.9f, 2.0f, 2.1f, 2.2f, 2.3f, 2.4f,
+                                2.5f, 2.6f, 2.7f, 2.8f, 2.9f, 3.0f, 3.1f, 3.2f, 3.3f, 3.4f, 3.5f, 3.6f});
+
+    test_case.add_expected_output<float>(Shape{1, 1, 3, 3},
+                                         {14.7299995f,
+                                          7.3200006f,
+                                          15.0600004f,
+                                          31.1000004f,
+                                          28.9899998f,
+                                          20.5800018f,
+                                          32.6200027f,
+                                          6.6400003f,
+                                          1.4399999f});
+    test_case.run();
+}
+
 OPENVINO_TEST(${BACKEND_NAME}, onnx_model_generate_proposals) {
     auto model = convert_model("org.openvinotoolkit/generate_proposals.onnx");
diff --git a/src/frontends/onnx/tests/runtime/interpreter/unit_test.manifest b/src/frontends/onnx/tests/runtime/interpreter/unit_test.manifest
index cd97b44b2f3d82..82fa651aaf6e94 100644
--- a/src/frontends/onnx/tests/runtime/interpreter/unit_test.manifest
+++ b/src/frontends/onnx/tests/runtime/interpreter/unit_test.manifest
@@ -45,6 +45,7 @@ onnx_controlflow_loop_power
 
 # No evaluator for DeformableConv2D
 onnx_model_deformable_conv_2d
+onnx_model_deformable_conv_2d_with_mask
 
 # New fails
 onnx_model_quant_conv_linear_3d
"LookupTableImportV2") { + // all LookupTableImport nodes must be preserved in a graph for conversion because + // they can be terminating nodes and contain input values for HashTable initialization + FRONT_END_GENERAL_CHECK(m_op_places_map.count(op_name), + "[TensorFlow Frontend] internal error or inconsistent model: LookupTableImport " + "operation is not found among operation places " + + op_name); + ops_to_do.push(m_op_places_map.at(op_name)); } } diff --git a/src/frontends/tensorflow_lite/src/op_table.cpp b/src/frontends/tensorflow_lite/src/op_table.cpp index d6d5cc9dc2dfa5..9adcc0b8db7fd4 100644 --- a/src/frontends/tensorflow_lite/src/op_table.cpp +++ b/src/frontends/tensorflow_lite/src/op_table.cpp @@ -5,7 +5,6 @@ #include "op_table.hpp" #include "openvino/opsets/opset10.hpp" -#include "openvino/opsets/opset8.hpp" #include "utils.hpp" using namespace std; @@ -21,7 +20,7 @@ using namespace ov::frontend::tensorflow::op; return func(context); \ } -#define OP_CONVERT_TYPE_RENAME(func, name) \ +#define DEQUANTIZE_INPUTS_WITH_NAMED_OUTPUTS(func) \ [](const ov::frontend::tensorflow_lite::NodeContext& node) -> OutputVector { \ auto decoder = node.get_decoder(); \ auto inputs = node.get_inputs(); \ @@ -36,7 +35,7 @@ namespace tensorflow_lite { namespace op { std::map get_supported_ops() { return { - {"ABS", translate_unary}, + {"ABS", translate_unary}, {"ADD", translate_binary_op_with_activation}, {"ADD_N", DEQUANTIZE_INPUTS(translate_add_n_op)}, {"ARG_MAX", DEQUANTIZE_INPUTS(translate_arg_max_op)}, @@ -45,23 +44,23 @@ std::map get_supported_ops() { // ATAN2 {"AVERAGE_POOL_2D", DEQUANTIZE_INPUTS(avg_pool_2d)}, {"BATCH_MATMUL", DEQUANTIZE_INPUTS(translate_batch_mat_mul_op)}, - {"BATCH_TO_SPACE_ND", OP_CONVERT_TYPE_RENAME(translate_batch_to_space_nd_op, "BatchToSpaceND")}, + {"BATCH_TO_SPACE_ND", DEQUANTIZE_INPUTS(translate_batch_to_space_nd_op)}, // BIDIRECTIONAL_SEQUENCE_LSTM // BIDIRECTIONAL_SEQUENCE_RNN - {"BROADCAST_ARGS", OP_CONVERT_TYPE_RENAME(translate_broadcast_args_op, "BroadcastArgs")}, - {"BROADCAST_TO", OP_CONVERT_TYPE_RENAME(translate_broadcast_to_op, "BroadcastTo")}, + {"BROADCAST_ARGS", DEQUANTIZE_INPUTS(translate_broadcast_args_op)}, + {"BROADCAST_TO", DEQUANTIZE_INPUTS(translate_broadcast_to_op)}, // BUCKETIZE // CALL // CALL_ONCE {"CAST", DEQUANTIZE_INPUTS(translate_cast_op)}, - {"CEIL", translate_unary}, + {"CEIL", translate_unary}, {"COMPLEX_ABS", DEQUANTIZE_INPUTS(complex_abs)}, // CONCAT_EMBEDDINGS {"CONCATENATION", DEQUANTIZE_INPUTS(concatenation)}, {"CONV_2D", DEQUANTIZE_INPUTS(conv2d)}, // CONV_3D // CONV_3D_TRANSPOSE - {"COS", translate_unary}, + {"COS", translate_unary}, // CUMSUM // CUSTOM // DELEGATE @@ -74,21 +73,21 @@ std::map get_supported_ops() { {"ELU", DEQUANTIZE_INPUTS(translate_elu_op)}, // EMBEDDING_LOOKUP // EMBEDDING_LOOKUP_SPARSE - {"EQUAL", translate_binary}, - {"EXP", translate_unary}, - {"EXPAND_DIMS", OP_CONVERT_TYPE_RENAME(translate_expand_dims_op, "ExpandDims")}, + {"EQUAL", translate_binary}, + {"EXP", translate_unary}, + {"EXPAND_DIMS", DEQUANTIZE_INPUTS(translate_expand_dims_op)}, // FAKE_QUANT {"FILL", DEQUANTIZE_INPUTS(translate_fill_op)}, - {"FLOOR", translate_unary}, + {"FLOOR", translate_unary}, {"FLOOR_DIV", DEQUANTIZE_INPUTS(translate_floor_div_op)}, - {"FLOOR_MOD", translate_binary}, + {"FLOOR_MOD", translate_binary}, {"FULLY_CONNECTED", DEQUANTIZE_INPUTS(fully_connected)}, {"GATHER", DEQUANTIZE_INPUTS(gather)}, {"GATHER_ND", DEQUANTIZE_INPUTS(translate_gather_nd_op)}, {"GELU", DEQUANTIZE_INPUTS(translate_gelu_op)}, - {"GREATER", 
diff --git a/src/frontends/tensorflow_lite/src/op_table.cpp b/src/frontends/tensorflow_lite/src/op_table.cpp
index d6d5cc9dc2dfa5..9adcc0b8db7fd4 100644
--- a/src/frontends/tensorflow_lite/src/op_table.cpp
+++ b/src/frontends/tensorflow_lite/src/op_table.cpp
@@ -5,7 +5,6 @@
 #include "op_table.hpp"
 
 #include "openvino/opsets/opset10.hpp"
-#include "openvino/opsets/opset8.hpp"
 #include "utils.hpp"
 
 using namespace std;
@@ -21,7 +20,7 @@ using namespace ov::frontend::tensorflow::op;
         return func(context);                                                    \
     }
 
-#define OP_CONVERT_TYPE_RENAME(func, name)                                       \
+#define DEQUANTIZE_INPUTS_WITH_NAMED_OUTPUTS(func)                               \
     [](const ov::frontend::tensorflow_lite::NodeContext& node) -> OutputVector { \
         auto decoder = node.get_decoder();                                       \
         auto inputs = node.get_inputs();                                         \
@@ -36,7 +35,7 @@ namespace tensorflow_lite {
 namespace op {
 std::map<std::string, CreatorFunction> get_supported_ops() {
     return {
-        {"ABS", translate_unary<opset8::Abs>},
+        {"ABS", translate_unary<opset10::Abs>},
        {"ADD", translate_binary_op_with_activation},
        {"ADD_N", DEQUANTIZE_INPUTS(translate_add_n_op)},
        {"ARG_MAX", DEQUANTIZE_INPUTS(translate_arg_max_op)},
@@ -45,23 +44,23 @@
        // ATAN2
        {"AVERAGE_POOL_2D", DEQUANTIZE_INPUTS(avg_pool_2d)},
        {"BATCH_MATMUL", DEQUANTIZE_INPUTS(translate_batch_mat_mul_op)},
-        {"BATCH_TO_SPACE_ND", OP_CONVERT_TYPE_RENAME(translate_batch_to_space_nd_op, "BatchToSpaceND")},
+        {"BATCH_TO_SPACE_ND", DEQUANTIZE_INPUTS(translate_batch_to_space_nd_op)},
        // BIDIRECTIONAL_SEQUENCE_LSTM
        // BIDIRECTIONAL_SEQUENCE_RNN
-        {"BROADCAST_ARGS", OP_CONVERT_TYPE_RENAME(translate_broadcast_args_op, "BroadcastArgs")},
-        {"BROADCAST_TO", OP_CONVERT_TYPE_RENAME(translate_broadcast_to_op, "BroadcastTo")},
+        {"BROADCAST_ARGS", DEQUANTIZE_INPUTS(translate_broadcast_args_op)},
+        {"BROADCAST_TO", DEQUANTIZE_INPUTS(translate_broadcast_to_op)},
        // BUCKETIZE
        // CALL
        // CALL_ONCE
        {"CAST", DEQUANTIZE_INPUTS(translate_cast_op)},
-        {"CEIL", translate_unary<opset8::Ceiling>},
+        {"CEIL", translate_unary<opset10::Ceiling>},
        {"COMPLEX_ABS", DEQUANTIZE_INPUTS(complex_abs)},
        // CONCAT_EMBEDDINGS
        {"CONCATENATION", DEQUANTIZE_INPUTS(concatenation)},
        {"CONV_2D", DEQUANTIZE_INPUTS(conv2d)},
        // CONV_3D
        // CONV_3D_TRANSPOSE
-        {"COS", translate_unary<opset8::Cos>},
+        {"COS", translate_unary<opset10::Cos>},
        // CUMSUM
        // CUSTOM
        // DELEGATE
@@ -74,21 +73,21 @@
        {"ELU", DEQUANTIZE_INPUTS(translate_elu_op)},
        // EMBEDDING_LOOKUP
        // EMBEDDING_LOOKUP_SPARSE
-        {"EQUAL", translate_binary<opset8::Equal>},
-        {"EXP", translate_unary<opset8::Exp>},
-        {"EXPAND_DIMS", OP_CONVERT_TYPE_RENAME(translate_expand_dims_op, "ExpandDims")},
+        {"EQUAL", translate_binary<opset10::Equal>},
+        {"EXP", translate_unary<opset10::Exp>},
+        {"EXPAND_DIMS", DEQUANTIZE_INPUTS(translate_expand_dims_op)},
        // FAKE_QUANT
        {"FILL", DEQUANTIZE_INPUTS(translate_fill_op)},
-        {"FLOOR", translate_unary<opset8::Floor>},
+        {"FLOOR", translate_unary<opset10::Floor>},
        {"FLOOR_DIV", DEQUANTIZE_INPUTS(translate_floor_div_op)},
-        {"FLOOR_MOD", translate_binary<opset8::FloorMod>},
+        {"FLOOR_MOD", translate_binary<opset10::FloorMod>},
        {"FULLY_CONNECTED", DEQUANTIZE_INPUTS(fully_connected)},
        {"GATHER", DEQUANTIZE_INPUTS(gather)},
        {"GATHER_ND", DEQUANTIZE_INPUTS(translate_gather_nd_op)},
        {"GELU", DEQUANTIZE_INPUTS(translate_gelu_op)},
-        {"GREATER", translate_binary<opset8::Greater>},
-        {"GREATER_EQUAL", translate_binary<opset8::GreaterEqual>},
-        {"HARD_SWISH", translate_unary<opset8::HSwish>},
+        {"GREATER", translate_binary<opset10::Greater>},
+        {"GREATER_EQUAL", translate_binary<opset10::GreaterEqual>},
+        {"HARD_SWISH", translate_unary<opset10::HSwish>},
        // HASHTABLE
        // HASHTABLE_FIND
        // HASHTABLE_IMPORT
@@ -99,48 +98,48 @@
        {"L2_NORMALIZATION", DEQUANTIZE_INPUTS(l2_normalization)},
        // L2_POOL_2D
        {"LEAKY_RELU", DEQUANTIZE_INPUTS(translate_leaky_relu_op)},
-        {"LESS", translate_binary<opset8::Less>},
-        {"LESS_EQUAL", translate_binary<opset8::LessEqual>},
+        {"LESS", translate_binary<opset10::Less>},
+        {"LESS_EQUAL", translate_binary<opset10::LessEqual>},
        // LOCAL_RESPONSE_NORMALIZATION
-        {"LOG", translate_unary<opset8::Log>},
+        {"LOG", translate_unary<opset10::Log>},
        {"LOG_SOFTMAX", DEQUANTIZE_INPUTS(translate_log_softmax_op)},
-        {"LOGICAL_AND", translate_binary<opset8::LogicalAnd>},
-        {"LOGICAL_NOT", translate_unary<opset8::LogicalNot>},
-        {"LOGICAL_OR", translate_binary<opset8::LogicalOr>},
+        {"LOGICAL_AND", translate_binary<opset10::LogicalAnd>},
+        {"LOGICAL_NOT", translate_unary<opset10::LogicalNot>},
+        {"LOGICAL_OR", translate_binary<opset10::LogicalOr>},
        {"LOGISTIC", translate_unary<opset10::Sigmoid>},
        // LSH_PROJECTION
        // LSTM
        {"MATRIX_DIAG", DEQUANTIZE_INPUTS(translate_matrix_diag_op)},
        // MATRIX_SET_DIAG
        {"MAX_POOL_2D", DEQUANTIZE_INPUTS(max_pool_2d)},
-        {"MAXIMUM", translate_binary<opset8::Maximum>},
-        {"MEAN", translate_reduce_op<opset8::ReduceMean>},
-        {"MINIMUM", translate_binary<opset8::Minimum>},
+        {"MAXIMUM", translate_binary<opset10::Maximum>},
+        {"MEAN", translate_reduce_op<opset10::ReduceMean>},
+        {"MINIMUM", translate_binary<opset10::Minimum>},
        {"MIRROR_PAD", DEQUANTIZE_INPUTS(translate_mirror_pad_op)},
        {"MUL", translate_binary_op_with_activation},
        // MULTINOMIAL
-        {"NEG", translate_unary<opset8::Negative>},
+        {"NEG", translate_unary<opset10::Negative>},
        // NON_MAX_SUPPRESSION_V4
        // NON_MAX_SUPPRESSION_V5
-        {"NOT_EQUAL", translate_binary<opset8::NotEqual>},
+        {"NOT_EQUAL", translate_binary<opset10::NotEqual>},
        {"ONE_HOT", DEQUANTIZE_INPUTS(translate_one_hot_op)},
        {"PACK", DEQUANTIZE_INPUTS(translate_pack_op)},
-        {"PAD", OP_CONVERT_TYPE_RENAME(translate_pad_op, "Pad")},
-        {"PADV2", OP_CONVERT_TYPE_RENAME(translate_padv2_op, "PadV2")},
-        {"POW", translate_binary<opset8::Power>},
+        {"PAD", DEQUANTIZE_INPUTS(translate_pad_op)},
+        {"PADV2", DEQUANTIZE_INPUTS(translate_padv2_op)},
+        {"POW", translate_binary<opset10::Power>},
        {"PRELU", translate_binary<opset10::PRelu>},
        {"QUANTIZE", quantize},
        // RANDOM_STANDARD_NORMAL
        // RANDOM_UNIFORM
        {"RANGE", DEQUANTIZE_INPUTS(translate_range_op)},
-        {"RANK", OP_CONVERT_TYPE_RENAME(translate_rank_op, "Rank")},
+        {"RANK", DEQUANTIZE_INPUTS(translate_rank_op)},
        // READ_VARIABLE
        // REAL
-        {"REDUCE_ALL", translate_reduce_op<opset8::ReduceLogicalAnd>},
-        {"REDUCE_ANY", translate_reduce_op<opset8::ReduceLogicalOr>},
-        {"REDUCE_MAX", translate_reduce_op<opset8::ReduceMax>},
-        {"REDUCE_MIN", translate_reduce_op<opset8::ReduceMin>},
-        {"REDUCE_PROD", translate_reduce_op<opset8::ReduceProd>},
+        {"REDUCE_ALL", translate_reduce_op<opset10::ReduceLogicalAnd>},
+        {"REDUCE_ANY", translate_reduce_op<opset10::ReduceLogicalOr>},
+        {"REDUCE_MAX", translate_reduce_op<opset10::ReduceMax>},
+        {"REDUCE_MIN", translate_reduce_op<opset10::ReduceMin>},
+        {"REDUCE_PROD", translate_reduce_op<opset10::ReduceProd>},
        {"RELU", translate_unary<opset10::Relu>},
        // RELU_0_TO_1
        // RELU_N1_TO_1
@@ -149,37 +148,37 @@
        {"RESIZE_BILINEAR", DEQUANTIZE_INPUTS(translate_interpolate_op)},
        {"RESIZE_NEAREST_NEIGHBOR", DEQUANTIZE_INPUTS(translate_interpolate_op)},
        {"REVERSE_SEQUENCE", DEQUANTIZE_INPUTS(translate_reverse_sequence_op)},
-        {"REVERSE_V2", OP_CONVERT_TYPE_RENAME(translate_reverse_v2_op, "ReverseV2")},
+        {"REVERSE_V2", DEQUANTIZE_INPUTS(translate_reverse_v2_op)},
        {"RFFT2D", DEQUANTIZE_INPUTS(rfft2d)},
        // RNN
        {"ROUND", DEQUANTIZE_INPUTS(translate_round_op)},
        {"RSQRT", DEQUANTIZE_INPUTS(translate_rsqrt_op)},
        {"SCATTER_ND", DEQUANTIZE_INPUTS(translate_scatter_nd_op)},
-        {"SEGMENT_SUM", OP_CONVERT_TYPE_RENAME(translate_segment_sum_op, "SegmentSum")},
-        {"SELECT", OP_CONVERT_TYPE_RENAME(translate_select_op, "Select")},
-        {"SELECT_V2", OP_CONVERT_TYPE_RENAME(translate_select_v2_op, "SelectV2")},
+        {"SEGMENT_SUM", DEQUANTIZE_INPUTS(translate_segment_sum_op)},
+        {"SELECT", DEQUANTIZE_INPUTS(translate_select_op)},
+        {"SELECT_V2", DEQUANTIZE_INPUTS(translate_select_v2_op)},
        {"SHAPE", translate_shape_op},
-        {"SIGN", translate_unary<opset8::Sign>},
-        {"SIN", translate_unary<opset8::Sin>},
+        {"SIGN", translate_unary<opset10::Sign>},
+        {"SIN", translate_unary<opset10::Sin>},
        // SKIP_GRAM
-        {"SLICE", OP_CONVERT_TYPE_RENAME(translate_slice_op, "Slice")},
+        {"SLICE", DEQUANTIZE_INPUTS(translate_slice_op)},
        {"SOFTMAX", DEQUANTIZE_INPUTS(softmax)},
-        {"SPACE_TO_BATCH_ND", OP_CONVERT_TYPE_RENAME(translate_space_to_batch_nd_op, "SpaceToBatchND")},
+        {"SPACE_TO_BATCH_ND", DEQUANTIZE_INPUTS(translate_space_to_batch_nd_op)},
        {"SPACE_TO_DEPTH", DEQUANTIZE_INPUTS(translate_space_to_depth_op)},
        // SPARSE_TO_DENSE
        {"SPLIT", DEQUANTIZE_INPUTS(translate_split_op)},
        {"SPLIT_V", DEQUANTIZE_INPUTS(translate_split_v_op)},
        {"SQRT", DEQUANTIZE_INPUTS(translate_sqrt_op)},
        {"SQUARE", DEQUANTIZE_INPUTS(translate_square_op)},
-        {"SQUARED_DIFFERENCE", translate_binary<opset8::SquaredDifference>},
+        {"SQUARED_DIFFERENCE", translate_binary<opset10::SquaredDifference>},
        {"SQUEEZE", DEQUANTIZE_INPUTS(translate_squeeze_op)},
        {"STRIDED_SLICE", DEQUANTIZE_INPUTS(translate_strided_slice_op)},
        {"SUB", translate_binary_op_with_activation},
-        {"SUM", translate_reduce_op<opset8::ReduceSum>},
+        {"SUM", translate_reduce_op<opset10::ReduceSum>},
        // SVDF
-        {"TANH", translate_unary<opset8::Tanh>},
+        {"TANH", translate_unary<opset10::Tanh>},
        {"TILE", DEQUANTIZE_INPUTS(translate_tile_op)},
-        {"TOPK_V2", OP_CONVERT_TYPE_RENAME(translate_top_k_v2_op, "TopKV2")},
+        {"TOPK_V2", DEQUANTIZE_INPUTS_WITH_NAMED_OUTPUTS(translate_top_k_v2_op)},
        {"TRANSPOSE", DEQUANTIZE_INPUTS(translate_transpose_op)},
        {"TRANSPOSE_CONV", DEQUANTIZE_INPUTS(transpose_conv)},
        // UNIDIRECTIONAL_SEQUENCE_LSTM
@@ -191,7 +190,7 @@
        // UNSORTED_SEGMENT_PROD
        // UNSORTED_SEGMENT_SUM
        // VAR_HANDLE
-        {"WHERE", OP_CONVERT_TYPE_RENAME(translate_where_op, "Where")},
+        {"WHERE", DEQUANTIZE_INPUTS(translate_where_op)},
        {"WHILE", while_op},
        {"ZEROS_LIKE", DEQUANTIZE_INPUTS(translate_zeros_like_op)},
     };
diff --git a/src/plugins/intel_gpu/README.md b/src/plugins/intel_gpu/README.md
index cf220164645416..1b49895790778e 100644
--- a/src/plugins/intel_gpu/README.md
+++ b/src/plugins/intel_gpu/README.md
@@ -30,6 +30,7 @@ GPU Plugin contains the following components:
 * [Debug utils](./docs/gpu_debug_utils.md)
 * [OpenCL Runtime issues troubleshooting](./docs/gpu_plugin_driver_troubleshooting.md)
 * [GPU plugin unit test](./docs/gpu_plugin_unit_test.md)
+* [Run benchmark from device_mem](./docs/use_device_mem.md)
 
 ## Documentation on dynamic-shape
 This contents explain the internal implementation of dynamic shape support in the GPU Plugin. For general usage of dynamic shape and limitations of the GPU plugin, please refer to this link: [GPU Device — OpenVINO™ documentation - Version(2023.1)](https://docs.openvino.ai/2023.1/openvino_docs_OV_UG_supported_plugins_GPU.html#dynamic-shapes).
diff --git a/src/plugins/intel_gpu/docs/dynamic_shape/in_memory_cache.md b/src/plugins/intel_gpu/docs/dynamic_shape/in_memory_cache.md
index 79847c6648bb38..e961178d48acd2 100644
--- a/src/plugins/intel_gpu/docs/dynamic_shape/in_memory_cache.md
+++ b/src/plugins/intel_gpu/docs/dynamic_shape/in_memory_cache.md
@@ -2,21 +2,21 @@
 
 ## Motivation
 
-When creating a primitive_impl in the Dynamic Shape model, if each primitive_impls are created about the same primitive with the same type and input / output shapes, it creates duplicated primitive_impl including new cl kernel build for same kernel source. this may result in inefficiency and performance degradation due to build the exactly same cl kernel source code multiple times for same layout and primitive type on the run time for dynamic model. To resolve this issue, ImplementationCache handle is newly introduced.
+When primitive_impls are created in a dynamic-shape model, primitives of the same type with the same input / output shapes produce duplicated primitive_impls, each including a new CL kernel build from the same kernel source. This may result in inefficiency and performance degradation, because exactly the same CL kernel source code is built multiple times at runtime for the same layout and primitive type. To resolve this issue, `ImplementationsCache` is newly introduced.
 
 ## Property
 
-* ImplementationCache only handles primitive_impl which is created in primitive_inst::update_impl() and primitive_inst::update_weights() on dynamic shape model. In the case of static shape, kernels_cache handles static shape kernel duplication.
-* ImplementationCache inherits LRUCacheThreadSafe which is ThreadSafe version of LRUCache which handles primitive_impl cache by increasing the cache hit rate for frequently used items. Therefore, ImplementationCache optimizes the performance of dynamic execution through frequently used primitive_impl.
-* Since cldnn::program creates ImplementationCache as unique_ptr at cldnn::program constructor, its lifecycle is set to cldnn::program.
-* ImplementationCache supports multi-stream, so the cldnn::network of each stream manages primitive_impl in same cache.
-* ImplementationCache Capacity is set to 10000 by default, but may change in the future optimization.
+* `ImplementationsCache` only handles primitive_impls which are created in `primitive_inst::update_impl()` and `primitive_inst::update_weights()` for dynamic-shape models. In the case of static shapes, kernels_cache handles kernel deduplication.
+* `ImplementationsCache` inherits `LruCacheThreadSafe`, a thread-safe version of `LruCache`, which caches primitive_impls and increases the cache hit rate for frequently used items. Therefore, `ImplementationsCache` optimizes the performance of dynamic execution by reusing frequently used primitive_impls.
+* Since `cldnn::program` creates `ImplementationsCache` as a unique_ptr in the `cldnn::program` constructor, its lifecycle is tied to `cldnn::program`.
+* `ImplementationsCache` supports multi-stream execution, so the `cldnn::network` of each stream manages primitive_impls in the same cache.
+* `ImplementationsCache` capacity is set to 10000 by default, but it may change with future optimizations.
 
 ## Usages
 
-ImplementationCache is used to handle primitive_impl cache at primitive_inst::update_impl() and primitive_inst::update_weights() in dynamic shape model.
+`ImplementationsCache` is used to cache primitive_impls in `primitive_inst::update_impl()` and `primitive_inst::update_weights()` for dynamic-shape models.
 
-* In primitive_inst::update_impl(), it looks up the cache with key which is hash value of kernel_impl_param which is updated by the current primitive_inst. If it is not found from ImplementationCache, new primitive_impl is created and save it into the cache.
-* In primitive_inst::update_weights(), if it is not found a primitive_impl with a hash key value which matches the weights_reorder_kernel_params of the primitive inst, it also create a new primitive_impl for weight reorder and put it in the cache.
+* In `primitive_inst::update_impl()`, the cache is looked up with a key which is the hash value of the kernel_impl_params updated by the current primitive_inst. If it is not found in `ImplementationsCache`, a new primitive_impl is created and saved into the cache.
+* In `primitive_inst::update_weights()`, if no primitive_impl is found whose hash key matches the weights_reorder_kernel_params of the primitive_inst, a new primitive_impl for weight reorder is also created and put into the cache.
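Reviewer note: the lookup-or-create flow described in the doc above is the classic LRU pattern. A simplified, hypothetical sketch for illustration (the plugin's actual `LruCacheThreadSafe`/`ImplementationsCache` classes differ in detail):

```cpp
#include <cstddef>
#include <functional>
#include <list>
#include <mutex>
#include <unordered_map>
#include <utility>

template <typename Key, typename Value>
class LruCacheThreadSafeSketch {
public:
    explicit LruCacheThreadSafeSketch(size_t capacity) : m_capacity(capacity) {}

    // Returns the cached value for `key`; on a miss, runs `builder` (e.g. create a
    // primitive_impl, compiling its CL kernel) and stores the result.
    Value get_or_create(const Key& key, const std::function<Value()>& builder) {
        std::lock_guard<std::mutex> guard(m_mutex);  // one cache shared by all streams
        auto it = m_lookup.find(key);
        if (it != m_lookup.end()) {
            // Cache hit: mark the entry most-recently-used.
            m_entries.splice(m_entries.begin(), m_entries, it->second);
            return it->second->second;
        }
        // Cache miss: build once, reuse for subsequent inferences and streams.
        m_entries.emplace_front(key, builder());
        m_lookup[key] = m_entries.begin();
        if (m_lookup.size() > m_capacity) {
            // Evict the least-recently-used entry (capacity is 10000 in the plugin).
            m_lookup.erase(m_entries.back().first);
            m_entries.pop_back();
        }
        return m_entries.front().second;
    }

private:
    using Entry = std::pair<Key, Value>;
    size_t m_capacity;
    std::mutex m_mutex;
    std::list<Entry> m_entries;  // most-recently-used first
    std::unordered_map<Key, typename std::list<Entry>::iterator> m_lookup;
};
```

With `Key` being the hash of kernel_impl_params and `Value` a shared pointer to a primitive_impl, `primitive_inst::update_impl()` maps onto a single `get_or_create()` call whose builder performs the CL kernel compilation the cache is meant to avoid repeating.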
diff --git a/src/plugins/intel_gpu/docs/use_device_mem.md b/src/plugins/intel_gpu/docs/use_device_mem.md
new file mode 100644
index 00000000000000..370f86815b93d6
--- /dev/null
+++ b/src/plugins/intel_gpu/docs/use_device_mem.md
@@ -0,0 +1,28 @@
+# Introduction
+
+This document describes the use of the '--use_device_mem' option of benchmark_app. It makes a performance difference on platforms where access to host memory and device memory is not identical. Discrete GPUs and recent versions of integrated GPUs get a performance boost from this feature.
+
+# Motivation
+You can achieve the best GPU performance when input data is placed in device memory. Intel OpenCL allows specifying such placement with the USM (Unified Shared Memory) feature. It is recommended to place the input data in device memory if possible. For example, if the input data is decoded from a video stream by the GPU, it is recommended to consume that data directly on the GPU. On the other hand, if the input data is generated by the CPU, it is not possible to use this feature.
+The bottom line is that the usage of this feature depends on the application data flow. If possible, please place the input data in device memory.
+
+# Benchmark_app support for device memory
+The OpenVINO benchmark_app sample contains a feature that mimics the behavior of placing input data in device memory. It allocates the inputs and outputs of the neural network in device memory. You can use this feature with the use_device_mem option of benchmark_app.
+
+### Restriction of use_device_mem
+Currently, benchmark_app does not support filling input data when use_device_mem is on. Input data is filled with random numbers. It is fine to measure performance for networks whose performance does not depend on the input data. However, if the target network performance depends on the input data, this option might report an incorrect result. For example, some object detection networks contain an NMS layer whose execution time depends on the input data. For such detection networks, it is not recommended to measure performance with the use_device_mem option.
+
+### How to build the sample for use_device_mem (on Windows)
+The option depends on the Intel OpenCL USM memory feature. To use the option, you need to build the sample with OpenCL enabled. Here are the steps to build the sample application with OpenCL:
+1. Set up env variables for the compiler and the OpenVINO release package
+1. \> git clone https://github.com/microsoft/vcpkg
+1. \> cd vcpkg
+1. \> .\bootstrap-vcpkg.bat
+1. \> vcpkg search opencl
+1. \> vcpkg install opencl
+1. openvino_install\samples\cpp> cmake -DCMAKE_BUILD_TYPE=Release -B build -DCMAKE_TOOLCHAIN_FILE=path/to/vcpkg/scripts/buildsystems/vcpkg.cmake
+1. openvino_install\samples\cpp> cmake --build build --config Release --parallel
+
+### How to build the sample for use_device_mem (on Ubuntu)
+1. \# apt install opencl-c-headers opencl-clhpp-headers
+1. Build the OpenVINO cpp sample with the build script
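Reviewer note: outside benchmark_app, an application would place inputs in device memory through the GPU plugin's remote tensor API. A hedged sketch, assuming the `ov::intel_gpu::ocl::ClContext` wrapper and its `create_usm_device_tensor()` helper from `openvino/runtime/intel_gpu/ocl/ocl.hpp`, and a hypothetical `model.xml`:

```cpp
#include <openvino/openvino.hpp>
#include <openvino/runtime/intel_gpu/ocl/ocl.hpp>

int main() {
    ov::Core core;
    auto compiled = core.compile_model("model.xml", "GPU");  // hypothetical model path

    // The plugin's OpenCL context owns USM allocations usable by the device.
    auto gpu_context = compiled.get_context().as<ov::intel_gpu::ocl::ClContext>();

    auto input = compiled.input();
    auto device_tensor = gpu_context.create_usm_device_tensor(input.get_element_type(),
                                                              input.get_shape());
    // A real pipeline would fill device_tensor on the GPU (e.g. from a decoded frame);
    // benchmark_app's use_device_mem mode leaves random/uninitialized contents instead.

    auto request = compiled.create_infer_request();
    request.set_tensor(input, device_tensor);
    request.infer();
}
```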
diff --git a/tests/model_hub_tests/models_hub_common/utils.py b/tests/model_hub_tests/models_hub_common/utils.py
index 9c5a02207b5729..d223f3bc984c6e 100644
--- a/tests/model_hub_tests/models_hub_common/utils.py
+++ b/tests/model_hub_tests/models_hub_common/utils.py
@@ -56,11 +56,26 @@ def get_models_list_not_skipped(model_list_file: str, skip_list_file: str):
 
 def compare_two_tensors(ov_res, fw_res, eps):
     is_ok = True
-    if not np.allclose(ov_res, fw_res, atol=eps, rtol=eps, equal_nan=True):
+    if ov_res.dtype.type == str or ov_res.dtype.type == np.str_ or ov_res.dtype.type == np.object_:
+        ov_res = ov_res.astype('U')
+        # TF can represent string tensors in a different format: an array of byte streams,
+        # so we have to align the formats of both string tensors, for example, to unicode
+        if ov_res.dtype.type != fw_res.dtype.type:
+            try:
+                fw_res = fw_res.astype('U')
+            except:
+                # fw_res is of object type, so each element must be UTF-8 decoded
+                utf8_decoded_elems = [elem.decode('UTF-8') for elem in fw_res.flatten()]
+                fw_res = np.array(utf8_decoded_elems, dtype=str).reshape(fw_res.shape)
+        is_ok = np.array_equal(ov_res, fw_res)
+    elif ov_res.dtype == bool:
+        is_ok = np.array_equal(ov_res, fw_res)
+    elif not np.allclose(ov_res, fw_res, atol=eps, rtol=eps, equal_nan=True):
         is_ok = False
         max_diff = np.abs(ov_res.astype(np.float32) - fw_res.astype(np.float32)).max()
         print("Max diff is {}".format(max_diff))
-    else:
+
+    if is_ok:
         print("Accuracy validation successful!\n")
         print("absolute eps: {}, relative eps: {}".format(eps, eps))
     return is_ok
diff --git a/tests/model_hub_tests/tensorflow/model_lists/precommit_read_model b/tests/model_hub_tests/tensorflow/model_lists/precommit_read_model
index 8e24ef1cbd8833..2a40914e660d84 100644
--- a/tests/model_hub_tests/tensorflow/model_lists/precommit_read_model
+++ b/tests/model_hub_tests/tensorflow/model_lists/precommit_read_model
@@ -4,4 +4,6 @@ mil-nce/s3d,https://www.kaggle.com/models/deepmind/mil-nce/frameworks/tensorFlow
 yamnet,https://www.kaggle.com/models/google/yamnet/frameworks/tensorFlow2/variations/yamnet/versions/1
 universal-sentence-encoder-multilingual,https://www.kaggle.com/models/google/universal-sentence-encoder/frameworks/tensorFlow2/variations/multilingual/versions/2
 movenet/singlepose/lightning,https://www.kaggle.com/models/google/movenet/frameworks/tensorFlow2/variations/singlepose-lightning/versions/4
-imagenet/resnet_v2_50/feature_vector,https://www.kaggle.com/models/google/resnet-v2/frameworks/tensorFlow2/variations/50-feature-vector/versions/2
\ No newline at end of file
+imagenet/resnet_v2_50/feature_vector,https://www.kaggle.com/models/google/resnet-v2/frameworks/tensorFlow2/variations/50-feature-vector/versions/2
+# LookupTableImportV2 is a terminating node and is needed for conversion
+openimages_v4/ssd/mobilenet_v2,https://www.kaggle.com/models/google/mobilenet-v2/frameworks/tensorFlow1/variations/openimages-v4-ssd-mobilenet-v2/versions/1
\ No newline at end of file