diff --git a/src/frontends/tensorflow_common/src/op/expand_dims.cpp b/src/frontends/tensorflow_common/src/op/expand_dims.cpp
index b3b37ad38cc302..a40e5c9b1bc6df 100644
--- a/src/frontends/tensorflow_common/src/op/expand_dims.cpp
+++ b/src/frontends/tensorflow_common/src/op/expand_dims.cpp
@@ -3,7 +3,13 @@
 //
 
 #include "common_op_table.hpp"
+#include "helper_ops/complex_type_mark.hpp"
+#include "openvino/op/constant.hpp"
+#include "openvino/op/less.hpp"
+#include "openvino/op/select.hpp"
+#include "openvino/op/subtract.hpp"
 #include "openvino/op/unsqueeze.hpp"
+#include "utils.hpp"
 
 using namespace std;
 using namespace ov::op;
@@ -14,9 +20,31 @@ namespace tensorflow {
 namespace op {
 
 OutputVector translate_expand_dims_op(const NodeContext& node) {
-    default_op_checks(node, 2, {"ExpandDims", "EXPAND_DIMS"});
+    default_op_checks(node, 2, {"ExpandDims", "EXPAND_DIMS"}, true);
     auto input = node.get_input(0);
     auto axis = node.get_input(1);
+    auto complex_type_mark = as_type_ptr<ComplexTypeMark>(input.get_node_shared_ptr());
+
+    if (complex_type_mark) {
+        element::Type complex_part_type = complex_type_mark->get_complex_part_type();
+        input = complex_type_mark->input_value(0);
+
+        auto const_zero = create_same_type_const_scalar<int32_t>(axis, 0);
+
+        auto is_axis_neg = make_shared<v1::Less>(axis, const_zero);
+
+        auto const_one = create_same_type_const_scalar<int32_t>(axis, 1);
+        auto axis_min_one = make_shared<v1::Subtract>(axis, const_one);
+
+        auto new_axis = make_shared<v1::Select>(is_axis_neg, axis_min_one, axis);
+
+        auto unsqueeze = make_shared<v0::Unsqueeze>(input, new_axis);
+
+        set_node_name(node.get_name(), unsqueeze);
+        auto complex_result = make_shared<ComplexTypeMark>(unsqueeze, complex_part_type);
+        return {complex_result};
+    }
+
     auto unsqueeze = make_shared<v0::Unsqueeze>(input, axis);
     set_node_name(node.get_name(), unsqueeze);
     return {unsqueeze};
diff --git a/src/plugins/intel_gpu/include/intel_gpu/graph/network.hpp b/src/plugins/intel_gpu/include/intel_gpu/graph/network.hpp
index 71623f32843eac..63adae28ddabf3 100644
--- a/src/plugins/intel_gpu/include/intel_gpu/graph/network.hpp
+++ b/src/plugins/intel_gpu/include/intel_gpu/graph/network.hpp
@@ -4,17 +4,15 @@
 #pragma once
 
-#include "openvino/runtime/threading/cpu_streams_executor.hpp"
+#include "openvino/runtime/threading/istreams_executor.hpp"
 #include "intel_gpu/graph/topology.hpp"
 #include "intel_gpu/graph/program.hpp"
 #include "intel_gpu/graph/serialization/binary_buffer.hpp"
-#include "intel_gpu/runtime/compounds.hpp"
 #include "intel_gpu/runtime/memory.hpp"
 #include "intel_gpu/runtime/engine.hpp"
 #include "intel_gpu/runtime/event.hpp"
 #include "intel_gpu/runtime/stream.hpp"
-#include "intel_gpu/runtime/lru_cache.hpp"
 #include "intel_gpu/runtime/shape_predictor.hpp"
 #include "intel_gpu/plugin/variable_state.hpp"
 
@@ -211,7 +209,7 @@ struct network {
     bool is_dynamic() const { return _is_dynamic; }
     size_t get_weights_cache_capacity() const { return _weights_cache_capacity; }
 
-    memory_pool& get_memory_pool() {
+    memory_pool& get_memory_pool() const {
         return *_memory_pool;
     }
 
@@ -284,7 +282,9 @@ struct network {
     void dump_memory_pool(std::string dump_path, int64_t curr_iter);
 
 #ifdef GPU_DEBUG_CONFIG
-    int64_t iteration = 0;
+    mutable int64_t iteration = 0;
+    friend class NetworkDebugHelper;
+    friend class NodeDebugHelper;
 #endif
 };
 }  // namespace cldnn
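Reviewer note (not part of the patch): ComplexTypeMark models a complex tensor as a real tensor with an extra innermost dimension of size 2 holding the real and imaginary parts, so a negative ExpandDims axis has to be shifted down by one before the Unsqueeze, while non-negative axes stay as-is; the Less/Subtract/Select subgraph above computes exactly that at runtime. A minimal standalone sketch of the same axis arithmetic (the helper name below is illustrative, not part of the frontend API):

```cpp
#include <cassert>
#include <cstdint>

// Mirrors Select(Less(axis, 0), axis - 1, axis) from the translator above:
// negative axes count from the end, and the real/imag pair occupies the last
// dimension, so they must be shifted by one to keep the same logical position.
int64_t adjust_complex_expand_dims_axis(int64_t axis) {
    return axis < 0 ? axis - 1 : axis;
}

int main() {
    // A logical complex shape [2, 3] is stored as the real shape [2, 3, 2].
    assert(adjust_complex_expand_dims_axis(0) == 0);    // [1, 2, 3] -> stored [1, 2, 3, 2]
    assert(adjust_complex_expand_dims_axis(-1) == -2);  // insert before the hidden size-2 dim
    return 0;
}
```

diff --git a/src/plugins/intel_gpu/src/graph/debug_helper.cpp b/src/plugins/intel_gpu/src/graph/debug_helper.cpp
new file mode 100644
index 00000000000000..7f7071e704683e
--- /dev/null
+++ b/src/plugins/intel_gpu/src/graph/debug_helper.cpp
@@ -0,0 +1,526 @@
+// Copyright (C) 2024 Intel Corporation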
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "debug_helper.hpp"
+#include "openvino/util/file_util.hpp"
+
+#ifdef GPU_DEBUG_CONFIG
+
+#include "to_string_utils.h"
+#include "loop_inst.h"
+#include "condition_inst.h"
+#include "program_dump_graph.h"
+
+#include
+#include
+#include
+
+namespace cldnn {
+
+namespace {
+
+float convert_element(int64_t i) { return static_cast<float>(i); }
+float convert_element(int32_t i) { return static_cast<float>(i); }
+
+float convert_element(float f) { return f; }
+
+float convert_element(ov::float16 h) { return static_cast<float>(h); }
+
+size_t get_x_pitch(const layout& layout) {
+    try {
+        auto tensor_x0 = tensor(batch(0), feature(0), spatial(0, 0, 0, 0));
+        auto tensor_x1 = tensor(batch(0), feature(0), spatial(1, 0, 0, 0));
+        auto x0 = layout.get_linear_offset(tensor_x0);
+        auto x1 = layout.get_linear_offset(tensor_x1);
+        return (x1 - x0);
+    } catch (...) {
+        // When spatial size of x=0, x_pitch is meaningless
+        return 0;
+    }
+}
+
+template <typename T>
+void dump(memory::ptr mem, stream& stream, std::ofstream& file_stream, bool dump_raw) {
+    auto&& size = mem->get_layout().get_tensor();
+
+    GPU_DEBUG_GET_INSTANCE(debug_config);
+    auto batch_size = std::max(std::min(debug_config->dump_layers_limit_batch, size.batch[0]), 1);
+    tensor tmp_size(size);
+    tmp_size.batch[0] = batch_size;
+    if (tmp_size == size) {
+        file_stream << "shape: " << size.to_string() << " ";
+        file_stream << "(count: " << size.count()
+                    << ", original format: " << cldnn::fmt_to_str(mem->get_layout().format) << ")"
+                    << (dump_raw ? " raw data" : "") << std::endl;
+    } else {
+        file_stream << "shape: " << tmp_size.to_string() << " ";
+        file_stream << "(count: " << tmp_size.count()
+                    << ", original format: " << cldnn::fmt_to_str(mem->get_layout().format)
+                    << ", original shape: " << size.to_string() << ")"
+                    << (dump_raw ? " raw data" : "") << std::endl;
+    }
+
+    if (size.count() == 0) {
+        file_stream << "Empty buffer" << std::endl;
+        return;
+    }
+
+    mem_lock<T, mem_lock_type::read> lock(mem, stream);
+    auto mem_ptr = lock.data();
+    auto x_pitch = get_x_pitch(mem->get_layout());
+    std::stringstream buffer;
+
+    if (!dump_raw) {
+        for (cldnn::tensor::value_type g = 0; g < size.group[0]; ++g) {
+            for (cldnn::tensor::value_type b = 0; b < batch_size; ++b) {
+                for (cldnn::tensor::value_type f = 0; f < size.feature[0]; ++f) {
+                    for (cldnn::tensor::value_type w = 0; w < size.spatial[3]; ++w) {
+                        for (cldnn::tensor::value_type z = 0; z < size.spatial[2]; ++z) {
+                            for (cldnn::tensor::value_type y = 0; y < size.spatial[1]; ++y) {
+                                cldnn::tensor t(cldnn::group(g), cldnn::batch(b), cldnn::feature(f), cldnn::spatial(0, y, z, w));
+                                size_t input_it = mem->get_layout().get_linear_offset(t);
+
+                                for (cldnn::tensor::value_type x = 0; x < size.spatial[0]; ++x, input_it += x_pitch) {
+                                    buffer << std::fixed << std::setprecision(6) << convert_element(mem_ptr[input_it]) << std::endl;
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    } else {
+        for (size_t i = 0; i < lock.size(); ++i) {
+            buffer << std::fixed << std::setprecision(6) << convert_element(mem_ptr[i]) << std::endl;
+        }
+    }
+    file_stream << buffer.str();
+}
+
+void unpack(cldnn::data_types type, uint8_t input, int8_t &v0, int8_t &v1) {
+    if (type == cldnn::data_types::i4) {
+        char s_bit = (input & 0x08);
+        char mask = s_bit > 0 ? 0xF0 : 0x00;
+        v0 = (input & 0x0F) | mask;
+
+        input >>= 4;
+        s_bit = (input & 0x08);
+        mask = s_bit > 0 ?
0xF0 : 0x00; + v1 = (input & 0x0F) | mask; + } else if (type == cldnn::data_types::u4) { + v0 = input & 0x0F; + v1 = input >> 4; + } else { + OPENVINO_ASSERT(false, "not supported unpacking"); + } +} + +void dump_i4u4(cldnn::data_types type, memory::ptr mem, stream& stream, std::ofstream& file_stream, bool dump_raw) { + auto&& size = mem->get_layout().get_tensor(); + + GPU_DEBUG_GET_INSTANCE(debug_config); + auto batch_size = std::max(std::min(debug_config->dump_layers_limit_batch, size.batch[0]), 1); + tensor tmp_size(size); + tmp_size.batch[0] = batch_size; + if (tmp_size == size) { + file_stream << "shape: " << size.to_string() << " "; + file_stream << "(count: " << size.count() + << ", original format: " << cldnn::fmt_to_str(mem->get_layout().format) << ")" + << (dump_raw ? " raw data" : "") << std::endl; + } else { + file_stream << "shape: " << tmp_size.to_string() << " "; + file_stream << "(count: " << tmp_size.count() + << ", original format: " << cldnn::fmt_to_str(mem->get_layout().format) + << ", original shape: " << size.to_string() << ")" + << (dump_raw ? " raw data" : "") << std::endl; + } + + if (size.count() == 0) { + file_stream << "Empty buffer" << std::endl; + return; + } + + mem_lock lock(mem, stream); + auto mem_ptr = lock.data(); + std::stringstream buffer; + + if (dump_raw) { + for (size_t i = 0; i < lock.size(); ++i) { + int8_t v0, v1; + unpack(type, mem_ptr[i], v0, v1); + buffer << std::fixed << std::setprecision(6) << static_cast(v0) << std::endl; + buffer << std::fixed << std::setprecision(6) << static_cast(v1) << std::endl; + } + } else { + std::cout << __func__ << " supports raw dump only" << std::endl; + } + file_stream << buffer.str(); +} + +void log_memory_to_file(memory::ptr mem, layout data_layout, stream& stream, std::string layerName, bool dump_raw) { + std::cout << "Dump " << (dump_raw ? 
"raw " : "") << layerName << std::endl; + GPU_DEBUG_GET_INSTANCE(debug_config); + std::string filename = debug_config->get_name_for_dump(layerName); + filename = debug_config->dump_layers_path + filename + ".txt"; + std::ofstream file_stream(filename); + if (!mem) { + file_stream << "Empty" << std::endl; + return; + } + + // Reinterpret buffer to represent actual data layout + auto actual_mem = mem->get_engine()->reinterpret_buffer(*mem, data_layout); + + auto mem_dt = actual_mem->get_layout().data_type; + if (mem_dt == cldnn::data_types::f32) + dump(actual_mem, stream, file_stream, dump_raw); + else if (mem_dt == cldnn::data_types::f16) + dump(actual_mem, stream, file_stream, dump_raw); + else if (mem_dt == cldnn::data_types::i64) + dump(actual_mem, stream, file_stream, dump_raw); + else if (mem_dt == cldnn::data_types::i32) + dump(actual_mem, stream, file_stream, dump_raw); + else if (mem_dt == cldnn::data_types::i8) + dump(actual_mem, stream, file_stream, dump_raw); + else if (mem_dt == cldnn::data_types::u8) + dump(actual_mem, stream, file_stream, dump_raw); + else if (mem_dt == cldnn::data_types::u8) + dump(actual_mem, stream, file_stream, dump_raw); + else if (mem_dt == cldnn::data_types::i4 || mem_dt == cldnn::data_types::u4) + dump_i4u4(mem_dt, actual_mem, stream, file_stream, dump_raw); + else + std::cout << "Dump for this data type is not supported: " << dt_to_str(mem_dt) << std::endl; +} + +} // namespace + +static std::string get_file_path_for_binary_dump(cldnn::layout layout, std::string name) { + std::string filename; + std::string data_type = ov::element::Type(layout.data_type).get_type_name(); + std::string format = layout.format.to_string(); + std::string tensor; + auto dims = layout.get_dims(); + for (size_t r = 0 ; r < layout.get_rank() ; r++) { + tensor += ("_" + to_string(dims[r])); + } + +#ifdef GPU_DEBUG_CONFIG + GPU_DEBUG_GET_INSTANCE(debug_config); + std::string layer_name = debug_config->get_name_for_dump(name); + filename = debug_config->dump_layers_path + layer_name + + "__" + data_type + "_" + tensor + "__" + format + ".bin"; +#endif + return filename; +} + +NodeDebugHelper::NodeDebugHelper(const primitive_inst& inst) + : m_inst(inst) + , m_stream(inst.get_network().get_stream()) + , m_network(inst.get_network()) + , m_program(inst.get_network().get_program().get()) + , m_iter(m_network.iteration) { + // Load binary dump for input layers + if (!debug_config->load_layers_raw_dump.empty()) { + const std::string layer_name = m_inst.id(); + auto files = debug_config->get_filenames_for_matched_layer_loading_binaries(layer_name); + if (!files.empty()) { + if (m_inst.is_input()) { + // Loading binary dumps for output tensors of input-layers : only one output exists or index(dstN) exists + auto dump_file = debug_config->get_matched_from_filelist(files, "_dst0__"); + OPENVINO_ASSERT((files.size() == 1 || dump_file.length() != 0), "Unexpected binary dump for input layer"); + + OPENVINO_ASSERT(files.size() == m_inst.outputs_memory_count(), "Mis-match dump file count"); + + for (size_t i = 0; i < m_inst.outputs_memory_count(); i++) { + auto dump_file = files[0]; + if (files.size() > 1 || m_inst.outputs_memory_count() != 1) { + std::string pattern = "_dst" + std::to_string(i) + "__"; + dump_file = debug_config->get_matched_from_filelist(files, pattern); + } + OPENVINO_ASSERT((dump_file.length() > 0), "Could not find expected pattern '_dst[N]__' for binary dump"); + GPU_DEBUG_COUT << " Load binary dump : " << dump_file << " for " << layer_name << std::endl; + + std::vector 
bin = ov::util::load_binary(dump_file); + OPENVINO_ASSERT(!bin.empty(), "Failure loading binary from OV_GPU_LoadDumpRawBinary : " + dump_file); + + auto output_mem = m_inst.output_memory_ptr(i); + OPENVINO_ASSERT(output_mem->size() == bin.size(), "memory size mis-match for OV_GPU_LoadDumpRawBinary : " + layer_name + + "\n Expected size : " + to_string(output_mem->size()) + ", Binary : " + to_string(bin.size())); + + output_mem->copy_from(m_stream, static_cast(&bin[0]), true); + } + } else { + auto check_dst = debug_config->get_matched_from_filelist(files, "_dst0__"); + OPENVINO_ASSERT(check_dst.length() == 0, "Expected to load binaries for inputs of " + layer_name); + + // Loading input tensors for any layer + auto dump_file = debug_config->get_matched_from_filelist(files, "_src0__"); + OPENVINO_ASSERT(dump_file.length() != 0, "Could not find expected pattern '_src[N]__' for binary dump input : " + layer_name); + + for (size_t i = 0; i < m_inst.dependencies().size(); i++) { + auto dump_file = files[0]; + if (files.size() > 1 || m_inst.dependencies().size() != 1) { + std::string pattern = "_src" + std::to_string(i) + "__"; + dump_file = debug_config->get_matched_from_filelist(files, pattern); + } + if (dump_file.length() == 0) { + GPU_DEBUG_COUT << " Skip loading for input(" << i << ") of " << layer_name << std::endl; + continue; + } + OPENVINO_ASSERT((dump_file.length() > 0), "Could not find expected pattern '_src[N]__' for binary dump input"); + GPU_DEBUG_COUT << " Load binary dump : " << dump_file << " for input(" << i << ") of " << layer_name << std::endl; + + std::vector bin = ov::util::load_binary(dump_file); + OPENVINO_ASSERT(!bin.empty(), "Failure loading binary from OV_GPU_LoadDumpRawBinary : " + dump_file); + + auto input_mem = m_inst.dep_memory_ptr(i); + if (input_mem->size() != bin.size()) { + std::cout << "WARNING: memory size mis-match for OV_GPU_LoadDumpRawBinary : " + layer_name + << " " << input_mem->size() << " / " << bin.size() << std::endl; + bin.resize(input_mem->size()); + } + + input_mem->copy_from(m_stream, static_cast(&bin[0]), true); + } + } + } + } + + // Dump input buffers of 'inst' + if (debug_config->dump_layers_path.length() > 0) { + const std::string layer_name = inst.id(); + + if (debug_config->is_target_iteration(m_iter) && + debug_config->dump_layers_dst_only == 0 && debug_config->is_layer_for_dumping(layer_name)) { + std::string debug_str_for_bin_load = " Command for loading : OV_GPU_LoadDumpRawBinary=\"" + layer_name + ":"; + for (size_t i = 0; i < m_inst.dependencies().size(); i++) { + std::string name = get_file_prefix() + layer_name + "_src" + std::to_string(i); + auto input_mem = m_inst.dep_memory_ptr(i); + if (input_mem == nullptr) { + GPU_DEBUG_COUT << " input_mem_" << i << " is nullptr. Nothing to dump." 
<< std::endl; + continue; + } + + auto dep = m_inst.dependencies().at(i); + auto input_layout = dep.first->get_output_layout(dep.second); + GPU_DEBUG_IF(debug_config->dump_layers_binary) { + // Binary dump : raw + auto filename = get_file_path_for_binary_dump(input_layout, name); + + mem_lock lock(input_mem, m_stream); + ov::util::save_binary(filename, lock.data(), input_mem->size()); + GPU_DEBUG_COUT << " Dump layer src : " << layer_name << " to " << filename << std::endl; + debug_str_for_bin_load += (filename + ","); + } else { + log_memory_to_file(input_mem, + input_layout, + m_stream, + name, + debug_config->dump_layers_raw); + } + } + + if (debug_config->dump_layers_binary && !inst.is_input()) { + debug_str_for_bin_load[debug_str_for_bin_load.size()-1] = '\"'; + GPU_DEBUG_COUT << debug_str_for_bin_load << std::endl; + } + } + } +} + + +NodeDebugHelper::~NodeDebugHelper() { + // Dump output buffers of 'inst' + if (debug_config->dump_layers_path.length() > 0) { + m_stream.finish(); + const std::string layer_name = m_inst.id(); + + GPU_DEBUG_IF(debug_config->is_target_iteration(m_iter) && + debug_config->is_layer_for_dumping(layer_name, m_inst.is_output(), m_inst.is_input())) { + std::string debug_str_for_bin_load = " Command for loading : OV_GPU_LoadDumpRawBinary=\"" + + layer_name + ":"; + for (size_t i = 0; i < m_inst.outputs_memory_count(); i++) { + std::string name = get_file_prefix() + "_dst" + std::to_string(i); + auto output_mem = m_inst.output_memory_ptr(i); + if (output_mem == nullptr) { + GPU_DEBUG_COUT << " output_mem is nullptr. Nothing to dump." << std::endl; + continue; + } + + GPU_DEBUG_IF(debug_config->dump_layers_binary) { + // Binary dump : raw + auto output_layout = m_inst.get_output_layout(i); + auto filename = get_file_path_for_binary_dump(output_layout, name); + + mem_lock lock(output_mem, m_stream); + ov::util::save_binary(filename, lock.data(), output_mem->size()); + GPU_DEBUG_COUT << " Dump layer dst : " << layer_name << " to " << filename << std::endl; + debug_str_for_bin_load += (filename + ","); + } else { + // Text dump + log_memory_to_file(output_mem, m_inst.get_output_layout(i), m_stream, name, debug_config->dump_layers_raw); + } + } + + GPU_DEBUG_IF(debug_config->dump_layers_binary && m_inst.is_input()) { + debug_str_for_bin_load[debug_str_for_bin_load.size()-1] = '\"'; + GPU_DEBUG_COUT << debug_str_for_bin_load << std::endl;; + } + } + } +} + +NetworkDebugHelper::NetworkDebugHelper(const network& net) + : m_network(net) + , m_iter(net.iteration) { + auto net_id = m_network.get_id(); + GPU_DEBUG_IF(debug_config->dump_memory_pool > 0) { + auto& iters = debug_config->dump_memory_pool_iters; + if (iters.empty() || iters.find(m_iter) != iters.end()) { + GPU_DEBUG_COUT << "============================================================================" << std::endl; + GPU_DEBUG_COUT << "Start network execution (net_id : " << net_id << ", iter :" << m_iter << ")" << std::endl; + if (m_iter == 0 && net_id > 0) { + dump_memory_pool(debug_config->dump_memory_pool_path, m_iter); + GPU_DEBUG_COUT << "============================================================================" << std::endl; + } + } + } else { + GPU_DEBUG_TRACE << "============================================================================" << std::endl; + GPU_DEBUG_TRACE << "Start network execution (net_id : " << net_id << ", iter :" << m_iter << ")" << std::endl; + } + + if (debug_config->list_layers == 1) { + for (auto& inst : m_network._exec_order) { + GPU_DEBUG_COUT << inst->id() << std::endl; + 
if (inst->get_node().is_type()) { + auto& loop_node = inst->get_node().as(); + for (auto& prim : loop_node.get_body_program()->get_processing_order()) { + GPU_DEBUG_COUT << "\t" << prim->id() << std::endl; + } + } else if (inst->get_node().is_type()) { + auto& cond_node = inst->get_node().as(); + GPU_DEBUG_COUT << "* Branch_True" << std::endl; + for (auto& prim : cond_node.get_branch_true().inner_program->get_processing_order()) { + GPU_DEBUG_COUT << "\t" << prim->id() << std::endl; + } + GPU_DEBUG_COUT << "* Branch_False" << std::endl; + for (auto& prim : cond_node.get_branch_false().inner_program->get_processing_order()) { + GPU_DEBUG_COUT << "\t" << prim->id() << std::endl; + } + } + } + + if (!m_network.is_internal()) + exit(0); + } +} + +NetworkDebugHelper::~NetworkDebugHelper() { + auto prog = m_network.get_program().get(); + auto net_id = m_network.get_id(); + // print '-data_shape' option for benchmark_app + if (debug_config->print_input_data_shapes == 1) { + std::stringstream data_shape_str; + auto add_string = [&data_shape_str](std::string str) { + data_shape_str << ((data_shape_str.rdbuf()->in_avail() == 0) ? " -data_shape " : ",") << str; + }; + + for (auto& inst : m_network._exec_order) { + auto name = inst->id(); + auto pos = name.find(':'); + auto type = name.substr(0, pos); + name.erase(0, pos + 1); + if (inst->is_input() && type == "parameter") { + add_string(name + inst->get_output_layout().get_partial_shape().to_string()); + } + } + + GPU_DEBUG_COUT << "[program:" << std::setw(2) << ((prog != nullptr) ? prog->get_id() : 0) + << "|network:" << std::setw(2) << net_id << "|iter:" << std::setw(4) << m_iter << "] benchmark_app cmd: " + << data_shape_str.str() << std::endl; + } + + if (!debug_config->dump_graphs.empty() && debug_config->is_target_iteration(m_iter)) { + auto get_fixed_str = [](int value, int length = 2) -> std::string { + std::ostringstream ss; + ss << std::setw(length) << std::setfill('0') << std::to_string(value); + return ss.str(); + }; + std::string path = get_dir_path(m_network.get_config()); + if (!path.empty()) { + std::ofstream ofs(path + "cldnn_program_exec_p" + get_fixed_str(prog->get_id()) + "_n" + get_fixed_str(net_id) + + "_" + get_fixed_str(m_iter, 5) + ".graph"); + dump_graph_init(ofs, *prog, [this](const primitive_id& id) -> std::shared_ptr { + return m_network.get_primitive(id); + }); + } + } + + if (debug_config->dump_memory_pool > 0) { + auto& iters = debug_config->dump_memory_pool_iters; + if (iters.empty() || iters.find(m_iter) != iters.end()) { + dump_memory_pool(debug_config->dump_memory_pool_path, m_iter); + GPU_DEBUG_COUT << "============================================================================" << std::endl; + } + } + + m_network.iteration++; +} + +void NetworkDebugHelper::dump_memory_pool(std::string dump_path, int64_t curr_iter) const { + m_network.get_memory_pool().dump(m_network.get_id(), curr_iter, dump_path); + auto get_constants_mem_size = [&](allocation_type type) -> size_t { + size_t mem_size = 0; + for (auto& prim : m_network._primitives) { + if (prim.second->get_node().is_constant()) { + for (size_t i = 0; i < prim.second->outputs_memory_count(); i++) { + if (prim.second->output_memory_ptr(i)->get_allocation_type() == type) + mem_size += prim.second->output_memory_ptr(i)->size(); + } + } + } + return mem_size; + }; + auto get_variables_mem_size = [&](allocation_type type) -> size_t { + size_t mem_size = 0; + for (auto& var : m_network.get_variables()) { + if (var.second->get_memory() && 
var.second->get_memory()->get_allocation_type() == type) + mem_size += var.second->get_actual_mem_size(); + } + return mem_size; + }; + auto get_mb_size = [&](int64_t size) -> std::string { + if (size == 0) return "0 MB"; + return std::to_string(static_cast(size) / (1024 * 1024)) + " MB"; + }; + int64_t usm_host_const_mem_size = get_constants_mem_size(allocation_type::usm_host); + int64_t usm_device_const_mem_size = get_constants_mem_size(allocation_type::usm_device); + int64_t usm_host_var_mem_size = get_variables_mem_size(allocation_type::usm_host); + int64_t usm_device_var_mem_size = get_variables_mem_size(allocation_type::usm_device); + int64_t host_mem_size = m_network.get_engine().get_used_device_memory(allocation_type::usm_host); + int64_t device_mem_size = m_network.get_engine().get_used_device_memory(allocation_type::usm_device); + int64_t usm_host_mem_pool_size = m_network.get_memory_pool().get_total_mem_pool_size(allocation_type::usm_host); + int64_t usm_host_etc_size = host_mem_size - usm_host_mem_pool_size + - usm_host_const_mem_size - usm_host_var_mem_size; + int64_t usm_device_mem_pool_size = m_network.get_memory_pool().get_total_mem_pool_size(allocation_type::usm_device); + int64_t usm_device_etc_size = device_mem_size - usm_device_mem_pool_size + - usm_device_const_mem_size - usm_device_var_mem_size; + GPU_DEBUG_COUT << "------------------------------------------------------------------------" << std::endl; + GPU_DEBUG_COUT << "Memory statistics for (net_id:" << m_network.get_id() << ", iter:" << curr_iter << ")" << std::endl; + GPU_DEBUG_COUT << " Total host mem size : " << get_mb_size(host_mem_size) << std::endl; + GPU_DEBUG_COUT << " * Memory pool : " << get_mb_size(usm_host_mem_pool_size) << std::endl; + GPU_DEBUG_COUT << " * Constant : " << get_mb_size(usm_host_const_mem_size) << std::endl; + GPU_DEBUG_COUT << " * Variable : " << get_mb_size(usm_host_var_mem_size) << std::endl; + GPU_DEBUG_COUT << " * ETC : " << get_mb_size(usm_host_etc_size) << std::endl; + GPU_DEBUG_COUT << " Total device mem size : " << get_mb_size(device_mem_size) << std::endl; + GPU_DEBUG_COUT << " * Memory pool : " << get_mb_size(usm_device_mem_pool_size) << std::endl; + GPU_DEBUG_COUT << " * Constant : " << get_mb_size(usm_device_const_mem_size) << std::endl; + GPU_DEBUG_COUT << " * Variable : " << get_mb_size(usm_device_var_mem_size) << std::endl; + GPU_DEBUG_COUT << " * ETC : " << get_mb_size(usm_device_etc_size) << std::endl; + GPU_DEBUG_COUT << "------------------------------------------------------------------------" << std::endl; +} + +} // namespace cldnn + +#endif // GPU_DEBUG_CONFIG diff --git a/src/plugins/intel_gpu/src/graph/debug_helper.hpp b/src/plugins/intel_gpu/src/graph/debug_helper.hpp new file mode 100644 index 00000000000000..c7c6bd006af1db --- /dev/null +++ b/src/plugins/intel_gpu/src/graph/debug_helper.hpp @@ -0,0 +1,69 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "intel_gpu/graph/network.hpp" +#include "intel_gpu/graph/program.hpp" +#include "intel_gpu/runtime/stream.hpp" +#include "intel_gpu/runtime/debug_configuration.hpp" +#include "primitive_inst.h" + +namespace cldnn { + +#ifdef GPU_DEBUG_CONFIG + +class NodeDebugHelper { +public: + NodeDebugHelper(const primitive_inst& inst); + ~NodeDebugHelper(); + +private: + std::string get_iteration_prefix() { + if (m_iter < 0) + return std::string(""); + return std::to_string(m_iter) + "_"; + } + + std::string get_file_prefix() { + auto prog_id = 
((m_program != nullptr) ? m_program->get_id() : 0); + auto net_id = m_network.get_id(); + + return "program" + std::to_string(prog_id) + "_network" + std::to_string(net_id) + "_" + get_iteration_prefix() + m_inst.id(); + } + + + const primitive_inst& m_inst; + stream& m_stream; + const network& m_network; + const program* m_program; + const size_t m_iter; + + const debug_configuration* debug_config = cldnn ::debug_configuration ::get_instance(); +}; + +class NetworkDebugHelper { +public: + NetworkDebugHelper(const network& net); + ~NetworkDebugHelper(); + +private: + void dump_memory_pool(std::string dump_path, int64_t curr_iter) const; + const network& m_network; + const size_t m_iter; + + const debug_configuration* debug_config = cldnn ::debug_configuration ::get_instance(); +}; + +#define NETWORK_DEBUG(net) NetworkDebugHelper __network_debug_helper(net) +#define NODE_DEBUG(inst) NodeDebugHelper __node_debug_helper(inst) + +#else + +#define NETWORK_DEBUG(...) +#define NODE_DEBUG(...) + +#endif // GPU_DEBUG_CONFIG + +} // namespace cldnn diff --git a/src/plugins/intel_gpu/src/graph/include/program_dump_graph.h b/src/plugins/intel_gpu/src/graph/include/program_dump_graph.h index 075422a4196b38..cf5111de6b247e 100644 --- a/src/plugins/intel_gpu/src/graph/include/program_dump_graph.h +++ b/src/plugins/intel_gpu/src/graph/include/program_dump_graph.h @@ -14,6 +14,6 @@ std::string get_dir_path(const ExecutionConfig& config); void dump_graph_optimized(std::ofstream&, const program&); void dump_graph_processing_order(std::ofstream&, const program&); void dump_graph_init(std::ofstream&, const program&, - std::function(const primitive_id&)> get_primitive_inst = nullptr); + std::function(const primitive_id&)> get_primitive_inst = nullptr); void dump_graph_info(std::ofstream&, const program&); } // namespace cldnn diff --git a/src/plugins/intel_gpu/src/graph/network.cpp b/src/plugins/intel_gpu/src/graph/network.cpp index 92d62782828d78..8f0e97dd51ee12 100644 --- a/src/plugins/intel_gpu/src/graph/network.cpp +++ b/src/plugins/intel_gpu/src/graph/network.cpp @@ -4,7 +4,6 @@ #include "intel_gpu/plugin/variable_state.hpp" #include "intel_gpu/primitives/read_value.hpp" -#include "openvino/util/file_util.hpp" #include "intel_gpu/primitives/data.hpp" #include "intel_gpu/primitives/mutable_data.hpp" @@ -31,13 +30,10 @@ #include "deconvolution_inst.h" #include "mutable_data_inst.h" #include "condition_inst.h" -#include "loop_inst.h" -#include "assign_inst.h" #include "read_value_inst.h" #include "reshape_inst.h" #include "kv_cache_inst.h" #include "program_helpers.h" -#include "to_string_utils.h" #include "program_dump_graph.h" #include @@ -51,8 +47,8 @@ #include #include +#include "debug_helper.hpp" #ifdef GPU_DEBUG_CONFIG -#include #include #include #include @@ -60,7 +56,6 @@ #endif namespace cldnn { - namespace { #ifdef GPU_DEBUG_CONFIG @@ -143,179 +138,6 @@ void dump_perf_data_raw(std::string dump_path, const std::list(i); } -float convert_element(int32_t i) { return static_cast(i); } - -float convert_element(float f) { return f; } - -float convert_element(ov::float16 h) { return static_cast(h); } - -size_t get_x_pitch(const layout& layout) { - try { - auto tensor_x0 = tensor(batch(0), feature(0), spatial(0, 0, 0, 0)); - auto tensor_x1 = tensor(batch(0), feature(0), spatial(1, 0, 0, 0)); - auto x0 = layout.get_linear_offset(tensor_x0); - auto x1 = layout.get_linear_offset(tensor_x1); - return (x1 - x0); - } catch (...) 
{ - // When spatial size of x=0, x_pitch is meaningless - return 0; - } -} - -template -void dump(memory::ptr mem, stream& stream, std::ofstream& file_stream, bool dump_raw) { - auto&& size = mem->get_layout().get_tensor(); - - GPU_DEBUG_GET_INSTANCE(debug_config); - auto batch_size = std::max(std::min(debug_config->dump_layers_limit_batch, size.batch[0]), 1); - tensor tmp_size(size); - tmp_size.batch[0] = batch_size; - if (tmp_size == size) { - file_stream << "shape: " << size.to_string() << " "; - file_stream << "(count: " << size.count() - << ", original format: " << cldnn::fmt_to_str(mem->get_layout().format) << ")" - << (dump_raw ? " raw data" : "") << std::endl; - } else { - file_stream << "shape: " << tmp_size.to_string() << " "; - file_stream << "(count: " << tmp_size.count() - << ", original format: " << cldnn::fmt_to_str(mem->get_layout().format) - << ", original shape: " << size.to_string() << ")" - << (dump_raw ? " raw data" : "") << std::endl; - } - - if (size.count() == 0) { - file_stream << "Empty buffer" << std::endl; - return; - } - - mem_lock lock(mem, stream); - auto mem_ptr = lock.data(); - auto x_pitch = get_x_pitch(mem->get_layout()); - std::stringstream buffer; - - if (!dump_raw) { - for (cldnn::tensor::value_type g = 0; g < size.group[0]; ++g) { - for (cldnn::tensor::value_type b = 0; b < batch_size; ++b) { - for (cldnn::tensor::value_type f = 0; f < size.feature[0]; ++f) { - for (cldnn::tensor::value_type w = 0; w < size.spatial[3]; ++w) { - for (cldnn::tensor::value_type z = 0; z < size.spatial[2]; ++z) { - for (cldnn::tensor::value_type y = 0; y < size.spatial[1]; ++y) { - cldnn::tensor t(cldnn::group(g), cldnn::batch(b), cldnn::feature(f), cldnn::spatial(0, y, z, w)); - size_t input_it = mem->get_layout().get_linear_offset(t); - - for (cldnn::tensor::value_type x = 0; x < size.spatial[0]; ++x, input_it += x_pitch) { - buffer << std::fixed << std::setprecision(6) << convert_element(mem_ptr[input_it]) << std::endl; - } - } - } - } - } - } - } - } else { - for (size_t i = 0; i < lock.size(); ++i) { - buffer << std::fixed << std::setprecision(6) << convert_element(mem_ptr[i]) << std::endl; - } - } - file_stream << buffer.str(); -} - -void unpack(cldnn::data_types type, uint8_t input, int8_t &v0, int8_t &v1) { - if (type == cldnn::data_types::i4) { - char s_bit = (input & 0x08); - char mask = s_bit > 0 ? 0xF0 : 0x00; - v0 = (input & 0x0F) | mask; - - input >>= 4; - s_bit = (input & 0x08); - mask = s_bit > 0 ? 0xF0 : 0x00; - v1 = (input & 0x0F) | mask; - } else if (type == cldnn::data_types::u4) { - v0 = input & 0x0F; - v1 = input >> 4; - } else { - OPENVINO_ASSERT(false, "not supported unpacking"); - } -} - -void dump_i4u4(cldnn::data_types type, memory::ptr mem, stream& stream, std::ofstream& file_stream, bool dump_raw) { - auto&& size = mem->get_layout().get_tensor(); - - GPU_DEBUG_GET_INSTANCE(debug_config); - auto batch_size = std::max(std::min(debug_config->dump_layers_limit_batch, size.batch[0]), 1); - tensor tmp_size(size); - tmp_size.batch[0] = batch_size; - if (tmp_size == size) { - file_stream << "shape: " << size.to_string() << " "; - file_stream << "(count: " << size.count() - << ", original format: " << cldnn::fmt_to_str(mem->get_layout().format) << ")" - << (dump_raw ? 
" raw data" : "") << std::endl; - } else { - file_stream << "shape: " << tmp_size.to_string() << " "; - file_stream << "(count: " << tmp_size.count() - << ", original format: " << cldnn::fmt_to_str(mem->get_layout().format) - << ", original shape: " << size.to_string() << ")" - << (dump_raw ? " raw data" : "") << std::endl; - } - - if (size.count() == 0) { - file_stream << "Empty buffer" << std::endl; - return; - } - - mem_lock lock(mem, stream); - auto mem_ptr = lock.data(); - std::stringstream buffer; - - if (dump_raw) { - for (size_t i = 0; i < lock.size(); ++i) { - int8_t v0, v1; - unpack(type, mem_ptr[i], v0, v1); - buffer << std::fixed << std::setprecision(6) << static_cast(v0) << std::endl; - buffer << std::fixed << std::setprecision(6) << static_cast(v1) << std::endl; - } - } else { - std::cout << __func__ << " supports raw dump only" << std::endl; - } - file_stream << buffer.str(); -} - -void log_memory_to_file(memory::ptr mem, layout data_layout, stream& stream, std::string layerName, bool dump_raw) { - std::cout << "Dump " << (dump_raw ? "raw " : "") << layerName << std::endl; - GPU_DEBUG_GET_INSTANCE(debug_config); - std::string filename = debug_config->get_name_for_dump(layerName); - filename = debug_config->dump_layers_path + filename + ".txt"; - std::ofstream file_stream(filename); - if (!mem) { - file_stream << "Empty" << std::endl; - return; - } - - // Reinterpret buffer to represent actual data layout - auto actual_mem = mem->get_engine()->reinterpret_buffer(*mem, data_layout); - - auto mem_dt = actual_mem->get_layout().data_type; - if (mem_dt == cldnn::data_types::f32) - dump(actual_mem, stream, file_stream, dump_raw); - else if (mem_dt == cldnn::data_types::f16) - dump(actual_mem, stream, file_stream, dump_raw); - else if (mem_dt == cldnn::data_types::i64) - dump(actual_mem, stream, file_stream, dump_raw); - else if (mem_dt == cldnn::data_types::i32) - dump(actual_mem, stream, file_stream, dump_raw); - else if (mem_dt == cldnn::data_types::i8) - dump(actual_mem, stream, file_stream, dump_raw); - else if (mem_dt == cldnn::data_types::u8) - dump(actual_mem, stream, file_stream, dump_raw); - else if (mem_dt == cldnn::data_types::u8) - dump(actual_mem, stream, file_stream, dump_raw); - else if (mem_dt == cldnn::data_types::i4 || mem_dt == cldnn::data_types::u4) - dump_i4u4(mem_dt, actual_mem, stream, file_stream, dump_raw); - else - std::cout << "Dump for this data type is not supported: " << dt_to_str(mem_dt) << std::endl; -} - void wait_for_the_turn() { GPU_DEBUG_GET_INSTANCE(debug_config); bool need_to_wait; @@ -336,7 +158,6 @@ void wait_for_the_turn() { #else void dump_perf_data_raw(std::string, const std::list>&) {} -void log_memory_to_file(memory::ptr, layout, stream&, std::string, bool dump_raw) {} void wait_for_the_turn() {} #endif } // namespace @@ -346,25 +167,6 @@ static uint32_t get_unique_net_id() { return ++id_gen; } -static std::string get_file_path_for_binary_dump(cldnn::layout layout, std::string name) { - std::string filename; - std::string data_type = ov::element::Type(layout.data_type).get_type_name(); - std::string format = layout.format.to_string(); - std::string tensor; - auto dims = layout.get_dims(); - for (size_t r = 0 ; r < layout.get_rank() ; r++) { - tensor += ("_" + to_string(dims[r])); - } - -#ifdef GPU_DEBUG_CONFIG - GPU_DEBUG_GET_INSTANCE(debug_config); - std::string layer_name = debug_config->get_name_for_dump(name); - filename = debug_config->dump_layers_path + layer_name - + "__" + data_type + "_" + tensor + "__" + format + ".bin"; 
-#endif - return filename; -} - /* Network will always have net_id = 0 when it will be cldnn internal micronetwork (created i.e by propagate_constants opt pass). @@ -939,28 +741,10 @@ std::map network::execute(const std::vector& events) { OV_ITT_SCOPED_TASK(ov::intel_gpu::itt::domains::intel_gpu_plugin, "NetworkImpl::Execute"); - int64_t curr_iter = -1; - GPU_DEBUG_GET_INSTANCE(debug_config); -#ifdef GPU_DEBUG_CONFIG - curr_iter = iteration; -#endif + NETWORK_DEBUG(*this); // Wait for previous execution completion reset_execution(false); - GPU_DEBUG_IF(debug_config->dump_memory_pool > 0) { - auto& iters = debug_config->dump_memory_pool_iters; - if (iters.empty() || iters.find(curr_iter) != iters.end()) { - GPU_DEBUG_COUT << "============================================================================" << std::endl; - GPU_DEBUG_COUT << "Start network execution (net_id : " << get_id() << ", iter :" << curr_iter << ")" << std::endl; - if (curr_iter == 0 && get_id() > 0) { - dump_memory_pool(debug_config->dump_memory_pool_path, curr_iter); - GPU_DEBUG_COUT << "============================================================================" << std::endl; - } - } - } else { - GPU_DEBUG_TRACE << "============================================================================" << std::endl; - GPU_DEBUG_TRACE << "Start network execution (net_id : " << get_id() << ", iter :" << curr_iter << ")" << std::endl; - } std::vector in_out_mem; auto is_surface_lock_check_needed = [&](const shared_mem_type& shared_mem_type) { @@ -996,33 +780,6 @@ void network::execute_impl(const std::vector& events) { auto surf_lock = surfaces_lock::create(get_engine().type(), in_out_mem, get_stream()); set_arguments(); - GPU_DEBUG_IF(debug_config->list_layers == 1) { - for (auto& inst : _exec_order) { - GPU_DEBUG_COUT << inst->id() << std::endl; - if (inst->get_node().is_type()) { - auto& loop_node = inst->get_node().as(); - for (auto& prim : loop_node.get_body_program()->get_processing_order()) { - GPU_DEBUG_COUT << "\t" << prim->id() << std::endl; - } - } else if (inst->get_node().is_type()) { - auto& cond_node = inst->get_node().as(); - GPU_DEBUG_COUT << "* Branch_True" << std::endl; - for (auto& prim : cond_node.get_branch_true().inner_program->get_processing_order()) { - GPU_DEBUG_COUT << "\t" << prim->id() << std::endl; - } - GPU_DEBUG_COUT << "* Branch_False" << std::endl; - for (auto& prim : cond_node.get_branch_false().inner_program->get_processing_order()) { - GPU_DEBUG_COUT << "\t" << prim->id() << std::endl; - } - } - } - if (!is_internal()) exit(0); - } - auto get_iteration_prefix = [](int64_t iter) { - if (iter < 0) - return std::string(""); - return std::to_string(iter) + "_"; - }; // This extra flush command is needed for dynamic models in both cases of out_of_order / in_order operating mode // since it reduces `bubbles` number in pipeline and GPU's idle time by timely flushing new kernels to device. 
@@ -1033,233 +790,43 @@ void network::execute_impl(const std::vector& events) { size_t executed_prims = 0; for (auto& inst : _exec_order) { - // Load binary dump for input layers - GPU_DEBUG_IF(!debug_config->load_layers_raw_dump.empty()) { - const std::string layer_name = inst->id(); - auto files = debug_config->get_filenames_for_matched_layer_loading_binaries(layer_name); - if (!files.empty()) { - if (inst->is_input()) { - // Loading binary dumps for output tensors of input-layers : only one output exists or index(dstN) exists - auto dump_file = debug_config->get_matched_from_filelist(files, "_dst0__"); - OPENVINO_ASSERT((files.size() == 1 || dump_file.length() != 0), "Unexpected binary dump for input layer"); - - OPENVINO_ASSERT(files.size() == get_primitive(inst->id())->outputs_memory_count(), "Mis-match dump file count"); - - for (size_t i = 0; i < get_primitive(inst->id())->outputs_memory_count(); i++) { - auto dump_file = files[0]; - if (files.size() > 1 || get_primitive(inst->id())->outputs_memory_count() != 1) { - std::string pattern = "_dst" + std::to_string(i) + "__"; - dump_file = debug_config->get_matched_from_filelist(files, pattern); - } - OPENVINO_ASSERT((dump_file.length() > 0), "Could not find expected pattern '_dst[N]__' for binary dump"); - GPU_DEBUG_COUT << " Load binary dump : " << dump_file << " for " << layer_name << std::endl; - - std::vector bin = ov::util::load_binary(dump_file); - OPENVINO_ASSERT(!bin.empty(), "Failure loading binary from OV_GPU_LoadDumpRawBinary : " + dump_file); - - auto output_mem = get_primitive(layer_name)->output_memory_ptr(i); - OPENVINO_ASSERT(output_mem->size() == bin.size(), "memory size mis-match for OV_GPU_LoadDumpRawBinary : " + layer_name - + "\n Expected size : " + to_string(output_mem->size()) + ", Binary : " + to_string(bin.size())); - - output_mem->copy_from(get_stream(), static_cast(&bin[0]), true); - } - } else { - auto check_dst = debug_config->get_matched_from_filelist(files, "_dst0__"); - OPENVINO_ASSERT(check_dst.length() == 0, "Expected to load binaries for inputs of " + layer_name); - - // Loading input tensors for any layer - auto dump_file = debug_config->get_matched_from_filelist(files, "_src0__"); - OPENVINO_ASSERT(dump_file.length() != 0, "Could not find expected pattern '_src[N]__' for binary dump input : " + layer_name); - - for (size_t i = 0; i < get_primitive(inst->id())->dependencies().size(); i++) { - auto dump_file = files[0]; - if (files.size() > 1 || get_primitive(inst->id())->dependencies().size() != 1) { - std::string pattern = "_src" + std::to_string(i) + "__"; - dump_file = debug_config->get_matched_from_filelist(files, pattern); - } - if (dump_file.length() == 0) { - GPU_DEBUG_COUT << " Skip loading for input(" << i << ") of " << layer_name << std::endl; - continue; - } - OPENVINO_ASSERT((dump_file.length() > 0), "Could not find expected pattern '_src[N]__' for binary dump input"); - GPU_DEBUG_COUT << " Load binary dump : " << dump_file << " for input(" << i << ") of " << layer_name << std::endl; - - std::vector bin = ov::util::load_binary(dump_file); - OPENVINO_ASSERT(!bin.empty(), "Failure loading binary from OV_GPU_LoadDumpRawBinary : " + dump_file); - - auto input_mem = get_primitive(inst->id())->dep_memory_ptr(i); - if (input_mem->size() != bin.size()) { - std::cout << "WARNING: memory size mis-match for OV_GPU_LoadDumpRawBinary : " + layer_name - << " " << input_mem->size() << " / " << bin.size() << std::endl; - bin.resize(input_mem->size()); - } - - input_mem->copy_from(get_stream(), 
static_cast(&bin[0]), true); - } - } - } - } - - // Dump input buffers of 'inst' - GPU_DEBUG_IF(debug_config->dump_layers_path.length() > 0) { - const std::string layer_name = inst->id(); - - GPU_DEBUG_IF(debug_config->is_target_iteration(curr_iter) && - debug_config->dump_layers_dst_only == 0 && debug_config->is_layer_for_dumping(layer_name)) { - std::string debug_str_for_bin_load = " Command for loading : OV_GPU_LoadDumpRawBinary=\"" + layer_name + ":"; - for (size_t i = 0; i < get_primitive(inst->id())->dependencies().size(); i++) { - std::string name = "program" + std::to_string((get_program() != nullptr) ? get_program()->get_id() : 0) + - "_network" + std::to_string(get_id()) + - "_" + get_iteration_prefix(curr_iter) + - layer_name + "_src" + std::to_string(i); - auto input_mem = get_primitive(inst->id())->dep_memory_ptr(i); - if (input_mem == nullptr) { - GPU_DEBUG_COUT << " input_mem_" << i << " is nullptr. Nothing to dump." << std::endl; - continue; - } - - auto dep = inst->dependencies().at(i); - auto input_layout = dep.first->get_output_layout(dep.second); - GPU_DEBUG_IF(debug_config->dump_layers_binary) { - // Binary dump : raw - auto filename = get_file_path_for_binary_dump(input_layout, name); - - mem_lock lock(input_mem, get_stream()); - ov::util::save_binary(filename, lock.data(), input_mem->size()); - GPU_DEBUG_COUT << " Dump layer src : " << layer_name << " to " << filename << std::endl; - debug_str_for_bin_load += (filename + ","); - } else { - log_memory_to_file(input_mem, - input_layout, - get_stream(), - name, - debug_config->dump_layers_raw); - } - } - - GPU_DEBUG_IF(debug_config->dump_layers_binary && !inst->is_input()) { - debug_str_for_bin_load[debug_str_for_bin_load.size()-1] = '\"'; - GPU_DEBUG_COUT << debug_str_for_bin_load << std::endl;; - } - } - } + NODE_DEBUG(*inst); execute_primitive(inst, events); executed_prims++; if (needs_flushing && executed_prims % flush_frequency == 0) get_stream().flush(); - - // Dump output buffers of 'inst' - GPU_DEBUG_IF(debug_config->dump_layers_path.length() > 0) { - get_stream().finish(); - const std::string layer_name = inst->id(); - auto prog_id = ((get_program() != nullptr) ? get_program()->get_id() : 0); - auto net_id = get_id(); - GPU_DEBUG_IF(debug_config->is_target_iteration(curr_iter) && - debug_config->is_layer_for_dumping(layer_name, inst->is_output(), inst->is_input())) { - std::string debug_str_for_bin_load = " Command for loading : OV_GPU_LoadDumpRawBinary=\"" - + layer_name + ":"; - for (size_t i = 0; i < get_primitive(layer_name)->outputs_memory_count(); i++) { - std::string name = "program" + std::to_string(prog_id) + - "_network" + std::to_string(net_id) + - "_" + get_iteration_prefix(curr_iter) + - layer_name + "_dst" + std::to_string(i); - auto output_mem = get_primitive(layer_name)->output_memory_ptr(i); - if (output_mem == nullptr) { - GPU_DEBUG_COUT << " output_mem is nullptr. Nothing to dump." 
<< std::endl; - continue; - } - - GPU_DEBUG_IF(debug_config->dump_layers_binary) { - // Binary dump : raw - auto output_layout = inst->get_output_layout(i); - auto filename = get_file_path_for_binary_dump(output_layout, name); - - mem_lock lock(output_mem, get_stream()); - ov::util::save_binary(filename, lock.data(), output_mem->size()); - GPU_DEBUG_COUT << " Dump layer dst : " << layer_name << " to " << filename << std::endl; - debug_str_for_bin_load += (filename + ","); - } else { - // Text dump - log_memory_to_file(output_mem, inst->get_output_layout(i), get_stream(), name, debug_config->dump_layers_raw); - } - } - - GPU_DEBUG_IF(debug_config->dump_layers_binary && inst->is_input()) { - debug_str_for_bin_load[debug_str_for_bin_load.size()-1] = '\"'; - GPU_DEBUG_COUT << debug_str_for_bin_load << std::endl;; - } - } - } - } - - // print '-data_shape' option for benchmark_app - GPU_DEBUG_IF(debug_config->print_input_data_shapes == 1) { - std::stringstream data_shape_str; - auto add_string = [&data_shape_str](std::string str) { - data_shape_str << ((data_shape_str.rdbuf()->in_avail() == 0) ? " -data_shape " : ",") << str; - }; - - for (auto& inst : _exec_order) { - auto name = inst->id(); - auto pos = name.find(':'); - auto type = name.substr(0, pos); - name.erase(0, pos + 1); - if (inst->is_input() && type == "parameter") { - add_string(name + inst->get_output_layout().get_partial_shape().to_string()); - } - } - - GPU_DEBUG_COUT << "[program:" << std::setw(2) << ((get_program() != nullptr) ? get_program()->get_id() : 0) - << "|network:" << std::setw(2) << get_id() << "|iter:" << std::setw(4) << curr_iter << "] benchmark_app cmd: " - << data_shape_str.str() << std::endl; - } - - GPU_DEBUG_IF(!debug_config->dump_graphs.empty() && debug_config->is_target_iteration(curr_iter)) { - auto get_fixed_str = [](int value, int length = 2) -> std::string { - std::ostringstream ss; - ss << std::setw(length) << std::setfill('0') << std::to_string(value); - return ss.str(); - }; - std::string path = get_dir_path(get_config()); - if (!path.empty()) { - std::ofstream ofs(path + "cldnn_program_exec_p" + get_fixed_str(get_program()->get_id()) + "_n" + get_fixed_str(get_id()) - + "_" + get_fixed_str(curr_iter, 5) + ".graph"); - dump_graph_init(ofs, *get_program(), [&](const primitive_id& id) -> std::shared_ptr { - return get_primitive(id); - }); - } } // Store events only in case of OOO queue or enabled Profiling auto store_events = is_out_of_order_queue || _enable_profiling; if (store_events) { if (_program != nullptr) { - for (auto& inst : _program->get_processing_order()) { - // Special handling for mutable data. The event should be the same as the user or dependency with highest - // processing_num as the mutable_data can be updated when is both user or dependency. - if (inst->is_type()) { - decltype(_program->get_processing_order().get_processing_number(inst)) proc_num = 0; - for (auto& user : inst->get_users()) { - auto user_proc_num = _program->get_processing_order().get_processing_number(user); - if (user_proc_num > proc_num) { - _events[inst->id()] = _events[user->id()]; - proc_num = user_proc_num; + for (auto& inst : _program->get_processing_order()) { + // Special handling for mutable data. The event should be the same as the user or dependency with highest + // processing_num as the mutable_data can be updated when is both user or dependency. 
+ if (inst->is_type()) { + decltype(_program->get_processing_order().get_processing_number(inst)) proc_num = 0; + for (auto& user : inst->get_users()) { + auto user_proc_num = _program->get_processing_order().get_processing_number(user); + if (user_proc_num > proc_num) { + _events[inst->id()] = _events[user->id()]; + proc_num = user_proc_num; + } } - } - if (!inst->get_dependencies().empty()) { - for (auto& dep : inst->get_dependencies()) { - auto dep_proc_num = _program->get_processing_order().get_processing_number(dep.first); - if (dep_proc_num > proc_num) { - _events[inst->id()] = _events[dep.first->id()]; - proc_num = dep_proc_num; + if (!inst->get_dependencies().empty()) { + for (auto& dep : inst->get_dependencies()) { + auto dep_proc_num = _program->get_processing_order().get_processing_number(dep.first); + if (dep_proc_num > proc_num) { + _events[inst->id()] = _events[dep.first->id()]; + proc_num = dep_proc_num; + } } } } } } - } for (auto& dout : _data_outputs) { // data primitives are not executed so if they are marked as output we need to add // them valid events manually @@ -1278,73 +845,6 @@ void network::execute_impl(const std::vector& events) { // Deallocate events from the previos iteration _old_events.clear(); - - GPU_DEBUG_IF(debug_config->dump_memory_pool > 0) { - auto& iters = debug_config->dump_memory_pool_iters; - if (iters.empty() || iters.find(curr_iter) != iters.end()) { - dump_memory_pool(debug_config->dump_memory_pool_path, curr_iter); - GPU_DEBUG_COUT << "============================================================================" << std::endl; - } - } - -#ifdef GPU_DEBUG_CONFIG - iteration++; -#endif -} - -void network::dump_memory_pool(std::string dump_path, int64_t curr_iter) { -#ifdef GPU_DEBUG_CONFIG - get_memory_pool().dump(get_id(), curr_iter, dump_path); - auto get_constants_mem_size = [&](allocation_type type) -> size_t { - size_t mem_size = 0; - for (auto& prim : _primitives) { - if (prim.second->get_node().is_constant()) { - for (size_t i = 0; i < prim.second->outputs_memory_count(); i++) { - if (prim.second->output_memory_ptr(i)->get_allocation_type() == type) - mem_size += prim.second->output_memory_ptr(i)->size(); - } - } - } - return mem_size; - }; - auto get_variables_mem_size = [&](allocation_type type) -> size_t { - size_t mem_size = 0; - for (auto& var : get_variables()) { - if (var.second->get_memory() && var.second->get_memory()->get_allocation_type() == type) - mem_size += var.second->get_actual_mem_size(); - } - return mem_size; - }; - auto get_mb_size = [&](int64_t size) -> std::string { - if (size == 0) return "0 MB"; - return std::to_string(static_cast(size) / (1024 * 1024)) + " MB"; - }; - int64_t usm_host_const_mem_size = get_constants_mem_size(allocation_type::usm_host); - int64_t usm_device_const_mem_size = get_constants_mem_size(allocation_type::usm_device); - int64_t usm_host_var_mem_size = get_variables_mem_size(allocation_type::usm_host); - int64_t usm_device_var_mem_size = get_variables_mem_size(allocation_type::usm_device); - int64_t host_mem_size = get_engine().get_used_device_memory(allocation_type::usm_host); - int64_t device_mem_size = get_engine().get_used_device_memory(allocation_type::usm_device); - int64_t usm_host_mem_pool_size = get_memory_pool().get_total_mem_pool_size(allocation_type::usm_host); - int64_t usm_host_etc_size = host_mem_size - usm_host_mem_pool_size - - usm_host_const_mem_size - usm_host_var_mem_size; - int64_t usm_device_mem_pool_size = 
get_memory_pool().get_total_mem_pool_size(allocation_type::usm_device); - int64_t usm_device_etc_size = device_mem_size - usm_device_mem_pool_size - - usm_device_const_mem_size - usm_device_var_mem_size; - GPU_DEBUG_COUT << "------------------------------------------------------------------------" << std::endl; - GPU_DEBUG_COUT << "Memory statistics for (net_id:" << get_id() << ", iter:" << curr_iter << ")" << std::endl; - GPU_DEBUG_COUT << " Total host mem size : " << get_mb_size(host_mem_size) << std::endl; - GPU_DEBUG_COUT << " * Memory pool : " << get_mb_size(usm_host_mem_pool_size) << std::endl; - GPU_DEBUG_COUT << " * Constant : " << get_mb_size(usm_host_const_mem_size) << std::endl; - GPU_DEBUG_COUT << " * Variable : " << get_mb_size(usm_host_var_mem_size) << std::endl; - GPU_DEBUG_COUT << " * ETC : " << get_mb_size(usm_host_etc_size) << std::endl; - GPU_DEBUG_COUT << " Total device mem size : " << get_mb_size(device_mem_size) << std::endl; - GPU_DEBUG_COUT << " * Memory pool : " << get_mb_size(usm_device_mem_pool_size) << std::endl; - GPU_DEBUG_COUT << " * Constant : " << get_mb_size(usm_device_const_mem_size) << std::endl; - GPU_DEBUG_COUT << " * Variable : " << get_mb_size(usm_device_var_mem_size) << std::endl; - GPU_DEBUG_COUT << " * ETC : " << get_mb_size(usm_device_etc_size) << std::endl; - GPU_DEBUG_COUT << "------------------------------------------------------------------------" << std::endl; -#endif } std::vector network::get_input_ids() const { diff --git a/src/plugins/intel_gpu/src/graph/program_dump_graph.cpp b/src/plugins/intel_gpu/src/graph/program_dump_graph.cpp index bff45cd81f9900..4a2f43b28d9360 100644 --- a/src/plugins/intel_gpu/src/graph/program_dump_graph.cpp +++ b/src/plugins/intel_gpu/src/graph/program_dump_graph.cpp @@ -170,7 +170,7 @@ std::string get_dir_path(const ExecutionConfig& config) { void dump_graph_init(std::ofstream& graph, const program& program, - std::function(const primitive_id&)> get_primitive_inst) { + std::function(const primitive_id&)> get_primitive_inst) { const std::string invalid_layout_msg = "(invalid layout)"; const auto dump_mem_info = [&invalid_layout_msg, &get_primitive_inst](const program_node* ptr) { diff --git a/src/plugins/intel_gpu/src/graph/program_node.cpp b/src/plugins/intel_gpu/src/graph/program_node.cpp index 3c9ad0f7317a27..21ba4e656fae0d 100644 --- a/src/plugins/intel_gpu/src/graph/program_node.cpp +++ b/src/plugins/intel_gpu/src/graph/program_node.cpp @@ -611,9 +611,9 @@ bool program_node::is_padded_spatial(size_t idx) const { auto& layout = get_output_layout(idx); const auto& lower_size = layout.data_padding._lower_size; const auto& upper_size = layout.data_padding._upper_size; - return std::any_of(std::begin(lower_size) + 2, std::begin(lower_size) + layout.get_spatial_rank() - 1, + return std::any_of(std::begin(lower_size) + 2, std::begin(lower_size) + 2 + layout.get_spatial_rank(), [](const tensor::value_type& el) { return el != 0; }) || - std::any_of(std::begin(upper_size) + 2, std::begin(upper_size) + layout.get_spatial_rank() - 1, + std::any_of(std::begin(upper_size) + 2, std::begin(upper_size) + 2 + layout.get_spatial_rank(), [](const tensor::value_type& el) { return el != 0; }); } diff --git a/src/plugins/intel_npu/src/plugin/npuw/partitioning/online/snapshot.hpp b/src/plugins/intel_npu/src/plugin/npuw/partitioning/online/snapshot.hpp index 72a62781580cda..e7e5121b1240e7 100644 --- a/src/plugins/intel_npu/src/plugin/npuw/partitioning/online/snapshot.hpp +++ 
b/src/plugins/intel_npu/src/plugin/npuw/partitioning/online/snapshot.hpp @@ -16,8 +16,6 @@ namespace ov { namespace npuw { namespace online { -class Group; // forward declaration - namespace detail { // At partitioning level we exclude some "non-Ops" to not interfere with the passes. // We include some of them back to properly link everything at plugin level @@ -33,6 +31,8 @@ class Snapshot : public std::enable_shared_from_this { m_node_to_prod_cons(std::make_shared()), m_node_to_gr(std::make_shared()) {} + friend class Group; // forward declaration + // Simple passes void singleGroup(); @@ -49,27 +49,27 @@ class Snapshot : public std::enable_shared_from_this { void repeatedBlocks(); void earlyAvoids(); void earlyRegroup(); - void markInternalCompute(); - void resetExcludedRep(); // Utility std::shared_ptr getGraph() const; - size_t graphSize() const; - const detail::OVNodeSet& getNodeProducers(const detail::OVNodePtr& node) const; - const detail::OVNodeSet& getNodeConsumers(const detail::OVNodePtr& node) const; const detail::OVPortsMap& getPortsMap() const; const detail::OVNodeToGroupMapPtr& getNodeToGroupMap() const; const std::map>>& getMatches() const; - detail::GPtrSet getRepGroups(const std::shared_ptr& group) const; void repeat(detail::Pass&& pass); void setCtx(const PassContext& ctx); + size_t graphSize() const; private: + detail::GPtrSet getRepGroups(const std::shared_ptr& group) const; + const detail::OVNodeSet& getNodeProducers(const detail::OVNodePtr& node) const; + const detail::OVNodeSet& getNodeConsumers(const detail::OVNodePtr& node) const; void identifyUniques(); void mergeUniques(); void mergeTriangles(); void cleanUpUniques(); void afterUniques(); + void markInternalCompute(); + void resetExcludedRep(); bool cleanUpUniquesImpl(const detail::GPtrSet& gset); std::shared_ptr tryGrowRepeatingGroups(const detail::GPtrSet& repeating_groups); std::shared_ptr tryMergeTriangles(const detail::GPtrSet& repeating_groups); diff --git a/src/plugins/intel_npu/tests/CMakeLists.txt b/src/plugins/intel_npu/tests/CMakeLists.txt index 4c41f008eb7f81..0f5bd7a6b093b2 100644 --- a/src/plugins/intel_npu/tests/CMakeLists.txt +++ b/src/plugins/intel_npu/tests/CMakeLists.txt @@ -8,3 +8,4 @@ if (MSVC) ov_add_compiler_flags(/wd5105) endif() add_subdirectory(functional) +add_subdirectory(unit) diff --git a/src/plugins/intel_npu/tests/unit/CMakeLists.txt b/src/plugins/intel_npu/tests/unit/CMakeLists.txt new file mode 100644 index 00000000000000..861a0ff6a47076 --- /dev/null +++ b/src/plugins/intel_npu/tests/unit/CMakeLists.txt @@ -0,0 +1,46 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# + +set(TARGET_NAME "ov_npu_unit_tests") + +set(MANDATORY_UNIT_TESTS_LIBS + "openvino::commonTestUtils" + "openvino::gmock" + "openvino::gtest" + "openvino::gtest_main" + "openvino::runtime" + "openvino::npu_al" + "openvino::npu_logger_utils" +) + +ov_add_test_target( + NAME ${TARGET_NAME} + ROOT ${CMAKE_CURRENT_SOURCE_DIR} + ADDITIONAL_SOURCE_DIRS + ${OpenVINO_SOURCE_DIR}/src/plugins/intel_npu/src/plugin/npuw/ + DEPENDENCIES + openvino::runtime + INCLUDES + ${CMAKE_CURRENT_SOURCE_DIR} + ${CMAKE_CURRENT_SOURCE_DIR}/npuw + ${OpenVINO_SOURCE_DIR}/src/plugins/intel_npu/src/plugin/npuw + ${OpenVINO_SOURCE_DIR}/src/plugins/intel_npu/src/utils/include + ${OpenVINO_SOURCE_DIR}/src/plugins/intel_npu/src/plugin/include + ${OpenVINO_SOURCE_DIR}/src/plugins/intel_npu/src/al/include + LINK_LIBRARIES + ${MANDATORY_UNIT_TESTS_LIBS} + LABELS + NPUW +) + +if(ENABLE_AVX2) + 
ov_avx2_optimization_flags(avx2_flags) + target_compile_options(${TARGET_NAME} PRIVATE "${avx2_flags}") +endif() + +install(TARGETS ${TARGET_NAME} + RUNTIME DESTINATION tests + COMPONENT tests + EXCLUDE_FROM_ALL +) diff --git a/src/plugins/intel_npu/tests/unit/npuw/online_partitioning.cpp b/src/plugins/intel_npu/tests/unit/npuw/online_partitioning.cpp new file mode 100644 index 00000000000000..af1fc5de8e92c7 --- /dev/null +++ b/src/plugins/intel_npu/tests/unit/npuw/online_partitioning.cpp @@ -0,0 +1,692 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include + +#include "partitioning/online/compiler.hpp" +#include "partitioning/online/snapshot.hpp" +#include "partitioning/online/group.hpp" + +#include "intel_npu/al/config/config.hpp" +#include "intel_npu/al/config/npuw.hpp" + +#include "openvino/openvino.hpp" +#include "openvino/op/ops.hpp" +#include "openvino/op/util/op_types.hpp" + +bool isEqualEns(ov::npuw::Ensemble& ens1, ov::npuw::Ensemble& ens2); +bool isEqualEns(ov::npuw::Ensemble& ens1, ov::npuw::Ensemble& ens2) { + if (ens1.groups.size() != ens2.groups.size()) { + return false; + } + + for (auto& g : ens1.groups) { + std::sort(g.input_layers.begin(), g.input_layers.end()); + std::sort(g.output_layers.begin(), g.output_layers.end()); + std::sort(g.all_layers.begin(), g.all_layers.end()); + } + + for (auto& g : ens2.groups) { + std::sort(g.input_layers.begin(), g.input_layers.end()); + std::sort(g.output_layers.begin(), g.output_layers.end()); + std::sort(g.all_layers.begin(), g.all_layers.end()); + } + + std::sort(ens1.groups.begin(), ens1.groups.end(), [](const ov::npuw::Group& g1, + const ov::npuw::Group& g2){ + return g1.all_layers.front() < g2.all_layers.front(); + }); + + std::sort(ens2.groups.begin(), ens2.groups.end(), [](const ov::npuw::Group& g1, + const ov::npuw::Group& g2){ + return g1.all_layers.front() < g2.all_layers.front(); + }); + + for (size_t i = 0; i < ens1.groups.size(); ++i) { + const auto& g1 = ens1.groups.at(i); + const auto& g2 = ens2.groups.at(i); + + if (g1.avoid_list != g2.avoid_list || + g1.input_layers != g2.input_layers || + g1.output_layers != g2.output_layers || + g1.all_layers != g2.all_layers) { + return false; + } + + // Can't compare them directly since they are random, but dont't affect the structure + if ((g1.repeated_id.empty() && !g2.repeated_id.empty()) || + (!g1.repeated_id.empty() && g2.repeated_id.empty())) { + return false; + } + } + + if (ens1.repeated.size() != ens2.repeated.size()) { + return false; + } + + auto get_sorted_rep = [](const std::map& rep) { + std::vector>> sorted_rep; + + std::transform(rep.begin(), rep.end(), std::back_inserter(sorted_rep), [](const auto& v) { + return v.second.matches; + }); + + for (auto& g : sorted_rep) { + std::sort(g.begin(), g.end(), + [](const auto& a, const auto& b) {return *a.begin() < *b.begin();}); + } + + std::sort(sorted_rep.begin(), sorted_rep.end(), + [](const auto& a, const auto& b) {return *a.front().begin() < *b.front().begin();}); + + return sorted_rep; + }; + + + if (get_sorted_rep(ens1.repeated) != get_sorted_rep(ens2.repeated)) { + return false; + } + + return true; +} + +class ModelGenerator { +public: + ModelGenerator() = default; + + std::shared_ptr get_model_without_repeated_blocks() { + std::shared_ptr input = std::make_shared(ov::element::i32, ov::Shape{1, 1, 40}); + m_nodes.push_back(input); + set_name(input); + + std::shared_ptr res = get_block(input); + + auto result = std::make_shared(res); + 
m_nodes.push_back(result); + set_name(result); + + ov::ParameterVector params = {input}; + ov::ResultVector results = {result}; + + return std::make_shared(results, params); + } + + std::shared_ptr get_model_with_repeated_blocks() { + // Generate head + std::shared_ptr input = std::make_shared(ov::element::i32, ov::Shape{1, 1, 40}); + m_nodes.push_back(input); + set_name(input); + + std::vector> head(7, nullptr); + head[0] = std::make_shared(input, input); + head[1] = std::make_shared(ov::element::i32, ov::Shape{1}, std::vector{2}); + head[2] = std::make_shared(head[0], head[1], true); + head[3] = std::make_shared(ov::element::i64, ov::Shape{4}, std::vector{1, 1, 4, 10}); + head[4] = std::make_shared(ov::element::i64, ov::Shape{3}, std::vector{1, 1, 40}); + head[5] = std::make_shared(head[2], head[3], false); + head[6] = std::make_shared(head[5], head[4], false); + + for (const auto& h : head) { + m_nodes.push_back(h); + set_name(h); + } + + // Generate repeated blocks + std::shared_ptr output = get_block(head[6]); + std::vector> outputs; + outputs.push_back(output); + + for (size_t i = 0; i < 9; ++i) { + output = get_block(output); + outputs.push_back(output); + } + + // Generate tail + std::vector> tail(6, nullptr); + tail[0] = std::make_shared(outputs, -1); + tail[1] = std::make_shared(ov::element::i32, ov::Shape{3}, std::vector{1, 20, 20}); + tail[2] = std::make_shared(tail[0], tail[1], false); + tail[3] = std::make_shared(ov::element::i32, ov::Shape{1, 1, 1}); + tail[4] = std::make_shared(tail[2], tail[3]); + tail[5] = std::make_shared(tail[4], tail[4]); + + for (const auto& t : tail) { + m_nodes.push_back(t); + set_name(t); + } + + // Create model + auto result = std::make_shared(tail[5]); + m_nodes.push_back(result); + set_name(result); + + ov::ParameterVector params = {input}; + ov::ResultVector results = {result}; + + return std::make_shared(results, params); + } + + std::shared_ptr get_block(const std::shared_ptr& input) { + // Parameters + // input + + // Constants + std::vector> model_c(18, nullptr); + model_c[0] = std::make_shared(ov::element::i32, ov::Shape{4}, std::vector{0, 2, 1, 3}); + model_c[1] = std::make_shared(ov::element::i64, ov::Shape{1}, std::vector{1}); + model_c[2] = std::make_shared(ov::element::i64, ov::Shape{1}, std::vector{0}); + model_c[3] = std::make_shared(ov::element::i64, ov::Shape{1}, std::vector{2}); + model_c[4] = std::make_shared(ov::element::i64, ov::Shape{1}, std::vector{0}); + model_c[5] = std::make_shared(ov::element::i64, ov::Shape{4}, std::vector{1, 1, 1, 1}); + model_c[6] = std::make_shared(ov::element::i64, ov::Shape{1}, std::vector{1}); + model_c[7] = std::make_shared(ov::element::i64, ov::Shape{1}, std::vector{0}); + model_c[8] = std::make_shared(ov::element::i64, ov::Shape{4}, std::vector{1, 1, 1, 1}); + model_c[9] = std::make_shared(ov::element::i32, ov::Shape{4}, std::vector{1, 1, 1, 2}); + model_c[10] = std::make_shared(ov::element::i32, ov::Shape{4}, std::vector{1, 1, 1, 1}); + model_c[11] = std::make_shared(ov::element::i32, ov::Shape{4}, std::vector{1, 1, 1, 2}); + model_c[12] = std::make_shared(ov::element::i32, ov::Shape{1, 1, 1, 1}); + model_c[13] = std::make_shared(ov::element::i32, ov::Shape{1, 1, 1, 1}); + model_c[14] = std::make_shared(ov::element::i32, ov::Shape{1, 1, 1, 1}); + model_c[15] = std::make_shared(ov::element::f32, ov::Shape{40, 40}); + model_c[16] = std::make_shared(ov::element::i64, ov::Shape{4}, std::vector{1, 1, 4, 10}); + model_c[17] = std::make_shared(ov::element::i32, ov::Shape{3}, std::vector{1, 1, 40}); + 
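+        // Register every constant in the node list and give it a sequential "node_<N>" friendly name.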
+ for (const auto& c : model_c) { + m_nodes.push_back(c); + set_name(c); + } + + // Converts + std::vector> convert(3, nullptr); + convert[0] = std::make_shared(model_c[15], ov::element::f16); + convert[1] = std::make_shared(convert[0], ov::element::i32); + convert[2] = std::make_shared(model_c[12], ov::element::i32); + + for (const auto& c : convert) { + m_nodes.push_back(c); + set_name(c); + } + + // Ops + std::vector> op(16, nullptr); + op[0] = std::make_shared(input, convert[1], false, true); + op[1] = std::make_shared(op[0], model_c[16], false); + op[2] = std::make_shared(op[1], model_c[0]); + op[3] = std::make_shared(op[2]); + op[4] = std::make_shared(op[3], model_c[1], model_c[2]); + op[5] = std::make_shared(op[4], model_c[3], true); + op[6] = std::make_shared(op[5]); + op[7] = std::make_shared(model_c[5], model_c[6], op[6], model_c[7]); + op[8] = std::make_shared(op[2], + model_c[8], + op[7], + model_c[9], + std::vector{1, 1, 1, 1}, + std::vector{1, 1, 1, 1}); + op[9] = std::make_shared(op[2], + op[7], + model_c[10], + model_c[11], + std::vector{1, 1, 1, 1}, + std::vector{1, 1, 1, 1}); + op[10] = std::make_shared(op[9], convert[2]); + op[11] = std::make_shared(std::vector>{op[10], op[8]}, -1); + op[12] = std::make_shared(model_c[13], op[11]); + op[13] = std::make_shared(model_c[14], op[2]); + op[14] = std::make_shared(op[13], op[12]); + op[15] = std::make_shared(op[14], model_c[17], false); + + for (const auto& o : op) { + m_nodes.push_back(o); + set_name(o); + } + + return op[15]; + } + +private: + void set_name(const std::shared_ptr& node) { + node->set_friendly_name("node_" + std::to_string(m_name_idx++)); + } + + std::vector> m_nodes; + size_t m_name_idx; +}; + +TEST(OnlinePartitioningTest, Partitioning_IsTheSame_SmallModel) { + ModelGenerator mg; + auto model = mg.get_model_without_repeated_blocks(); + + auto opt_desc = std::make_shared<::intel_npu::OptionsDesc>(); + auto cfg = ::intel_npu::Config(opt_desc); + ::intel_npu::registerNPUWOptions(*opt_desc); + std::map cfg_map = {{ "NPUW_ONLINE_KEEP_BLOCK_SIZE", "9" }}; + cfg.update(cfg_map); + + auto ens = ov::npuw::online::buildPartitioning(model, cfg); + + for (size_t i = 0; i < 100; ++i) { + auto ens_again = ov::npuw::online::buildPartitioning(model, cfg); + EXPECT_TRUE(isEqualEns(ens, ens_again)); + } +} + +TEST(OnlinePartitioningTest, Partitioning_IsTheSame_RepeatedModel) { + ModelGenerator mg; + auto model = mg.get_model_with_repeated_blocks(); + + auto opt_desc = std::make_shared<::intel_npu::OptionsDesc>(); + auto cfg = ::intel_npu::Config(opt_desc); + ::intel_npu::registerNPUWOptions(*opt_desc); + std::map cfg_map = {{ "NPUW_ONLINE_KEEP_BLOCK_SIZE", "9" }}; + cfg.update(cfg_map); + + auto ens = ov::npuw::online::buildPartitioning(model, cfg); + + for (size_t i = 0; i < 100; ++i) { + auto ens_again = ov::npuw::online::buildPartitioning(model, cfg); + EXPECT_TRUE(isEqualEns(ens, ens_again)); + } +} + +TEST(OnlinePartitioningTest, Partitioning_SingleGroup_SmallModel) { + ModelGenerator mg; + auto model = mg.get_model_without_repeated_blocks(); + + auto snap = std::make_shared(model); + snap->singleGroup(); + EXPECT_EQ(snap->graphSize(), 1); +} + +TEST(OnlinePartitioningTest, Partitioning_SingleGroup_RepeatedModel) { + ModelGenerator mg; + auto model = mg.get_model_with_repeated_blocks(); + + auto snap = std::make_shared(model); + snap->singleGroup(); + EXPECT_EQ(snap->graphSize(), 1); +} + +TEST(OnlinePartitioningTest, Partitioning_buildGraph_SmallModel) { + ModelGenerator mg; + auto model = 
mg.get_model_without_repeated_blocks(); + + auto snap = std::make_shared(model); + snap->buildGraph(); + auto g = snap->getGraph(); + for (const auto& nh : g->sorted()) { + ov::npuw::online::Group::GPtr group = g->meta(nh).get(); + EXPECT_EQ(group->size(), 1); + } + EXPECT_EQ(snap->getNodeToGroupMap()->size(), snap->graphSize()); +} + +TEST(OnlinePartitioningTest, Partitioning_buildGraph_RepeatedModel) { + ModelGenerator mg; + auto model = mg.get_model_with_repeated_blocks(); + + auto snap = std::make_shared(model); + snap->buildGraph(); + auto g = snap->getGraph(); + for (const auto& nh : g->sorted()) { + ov::npuw::online::Group::GPtr group = g->meta(nh).get(); + EXPECT_EQ(group->size(), 1); + } + EXPECT_EQ(snap->getNodeToGroupMap()->size(), snap->graphSize()); +} + +TEST(OnlinePartitioningTest, Partitioning_earlyAvoids_SmallModel) { + ModelGenerator mg; + auto model = mg.get_model_without_repeated_blocks(); + + auto snap = std::make_shared(model); + ov::npuw::online::PassContext ctx; + ctx.avoids = {{ov::npuw::online::PatternType::OP, "Gather", "mydevice"}, {ov::npuw::online::PatternType::OP, "MatMul", "mydevice"}}; + snap->setCtx(ctx); + snap->buildGraph(); + snap->earlyAvoids(); + auto g = snap->getGraph(); + size_t count = 0; + for (const auto& nh : g->sorted()) { + ov::npuw::online::Group::GPtr group = g->meta(nh).get(); + EXPECT_EQ(group->size(), 1); + if (group->avoidedTargets().size() == 1 && *(group->avoidedTargets().begin()) == "mydevice") { + ++count; + } + } + EXPECT_EQ(count, 2); +} + +TEST(OnlinePartitioningTest, Partitioning_earlyAvoids_RepeatedModel) { + ModelGenerator mg; + auto model = mg.get_model_with_repeated_blocks(); + + auto snap = std::make_shared(model); + ov::npuw::online::PassContext ctx; + ctx.avoids = {{ov::npuw::online::PatternType::OP, "Gather", "mydevice"}, {ov::npuw::online::PatternType::OP, "MatMul", "mydevice"}}; + snap->setCtx(ctx); + snap->buildGraph(); + snap->earlyAvoids(); + auto g = snap->getGraph(); + size_t count = 0; + for (const auto& nh : g->sorted()) { + ov::npuw::online::Group::GPtr group = g->meta(nh).get(); + EXPECT_EQ(group->size(), 1); + if (group->avoidedTargets().size() == 1 && *(group->avoidedTargets().begin()) == "mydevice") { + ++count; + } + } + EXPECT_EQ(count, 20); +} + +TEST(OnlinePartitioningTest, Partitioning_collectLHF_SmallModel) { + ModelGenerator mg; + auto model = mg.get_model_without_repeated_blocks(); + + auto snap = std::make_shared(model); + snap->buildGraph(); + + std::vector sizes = {10, 10}; + size_t iter = 0; + + snap->repeat([&]{ + snap->collectLHF(); + EXPECT_LT(iter, sizes.size()); + EXPECT_EQ(snap->graphSize(), sizes[iter++]); + }); +} + +TEST(OnlinePartitioningTest, Partitioning_collectLHF_RepeatedModel) { + ModelGenerator mg; + auto model = mg.get_model_with_repeated_blocks(); + + auto snap = std::make_shared(model); + snap->buildGraph(); + + std::vector sizes = {82, 82}; + size_t iter = 0; + + snap->repeat([&]{ + snap->collectLHF(); + EXPECT_LT(iter, sizes.size()); + EXPECT_EQ(snap->graphSize(), sizes[iter++]); + }); +} + +TEST(OnlinePartitioningTest, Partitioning_fuseRemnants_SmallModel) { + ModelGenerator mg; + auto model = mg.get_model_without_repeated_blocks(); + + auto snap = std::make_shared(model); + snap->buildGraph(); + + std::vector sizes = {10, 10}; + size_t iter = 0; + + snap->repeat([&]{ + snap->fuseRemnants(); + EXPECT_LT(iter, sizes.size()); + EXPECT_EQ(snap->graphSize(), sizes[iter++]); + }); +} + +TEST(OnlinePartitioningTest, Partitioning_fuseRemnants_RepeatedModel) { + ModelGenerator mg; 
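+    // Each fuseRemnants() pass is expected to roughly halve the group count: 75 -> 38 -> 19 -> 10.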
+ auto model = mg.get_model_with_repeated_blocks(); + + auto snap = std::make_shared(model); + snap->buildGraph(); + + std::vector sizes = {75, 38, 19, 10}; + size_t iter = 0; + + snap->repeat([&]{ + snap->fuseRemnants(); + EXPECT_LT(iter, sizes.size()); + EXPECT_EQ(snap->graphSize(), sizes[iter++]); + }); +} + +TEST(OnlinePartitioningTest, Partitioning_fuseRemnantsExtended_SmallModel) { + ModelGenerator mg; + auto model = mg.get_model_without_repeated_blocks(); + + auto snap = std::make_shared(model); + snap->buildGraph(); + + std::vector sizes = {10, 10}; + size_t iter = 0; + + snap->repeat([&]{ + snap->fuseRemnantsExtended(); + EXPECT_LT(iter, sizes.size()); + EXPECT_EQ(snap->graphSize(), sizes[iter++]); + }); +} + +TEST(OnlinePartitioningTest, Partitioning_fuseRemnantsExtended_RepeatedModel) { + ModelGenerator mg; + auto model = mg.get_model_with_repeated_blocks(); + + auto snap = std::make_shared(model); + snap->buildGraph(); + + std::vector sizes = {10, 10}; + size_t iter = 0; + + snap->repeat([&]{ + snap->fuseRemnantsExtended(); + EXPECT_LT(iter, sizes.size()); + EXPECT_EQ(snap->graphSize(), sizes[iter++]); + }); +} + +TEST(OnlinePartitioningTest, Partitioning_fuseInputs_SmallModel) { + ModelGenerator mg; + auto model = mg.get_model_without_repeated_blocks(); + + auto snap = std::make_shared(model); + snap->buildGraph(); + + std::vector sizes = {15, 14, 14}; + size_t iter = 0; + + snap->repeat([&]{ + snap->fuseInputs(); + EXPECT_LT(iter, sizes.size()); + EXPECT_EQ(snap->graphSize(), sizes[iter++]); + }); +} + +TEST(OnlinePartitioningTest, Partitioning_fuseInputs_RepeatedModel) { + ModelGenerator mg; + auto model = mg.get_model_with_repeated_blocks(); + + auto snap = std::make_shared(model); + snap->buildGraph(); + + std::vector sizes = {148, 138, 138}; + size_t iter = 0; + + snap->repeat([&]{ + snap->fuseInputs(); + EXPECT_LT(iter, sizes.size()); + EXPECT_EQ(snap->graphSize(), sizes[iter++]); + }); +} + +TEST(OnlinePartitioningTest, Partitioning_Compiler_Just_SmallModel) { + ModelGenerator mg; + auto model = mg.get_model_without_repeated_blocks(); + + auto snap = std::make_shared(model); + snap->buildGraph(); + + std::vector sizes_lhf = {10, 10}; + size_t iter_lhf = 0; + + std::vector sizes_fr = {10, 10}; + size_t iter_fr = 0; + + snap->repeat([&] { + snap->collectLHF(); + EXPECT_LT(iter_lhf, sizes_lhf.size()); + EXPECT_EQ(snap->graphSize(), sizes_lhf[iter_lhf++]); + }); + snap->repeat([&] { + snap->fuseRemnants(); + EXPECT_LT(iter_fr, sizes_fr.size()); + EXPECT_EQ(snap->graphSize(), sizes_fr[iter_fr++]); + }); +} + +TEST(OnlinePartitioningTest, Partitioning_Compiler_Just_RepeatedModel) { + ModelGenerator mg; + auto model = mg.get_model_with_repeated_blocks(); + + auto snap = std::make_shared(model); + snap->buildGraph(); + + std::vector sizes_lhf = {82, 82}; + size_t iter_lhf = 0; + + std::vector sizes_fr = {41, 21, 11, 10, 10}; + size_t iter_fr = 0; + + snap->repeat([&] { + snap->collectLHF(); + EXPECT_LT(iter_lhf, sizes_lhf.size()); + EXPECT_EQ(snap->graphSize(), sizes_lhf[iter_lhf++]); + }); + snap->repeat([&] { + snap->fuseRemnants(); + EXPECT_LT(iter_fr, sizes_fr.size()); + EXPECT_EQ(snap->graphSize(), sizes_fr[iter_fr++]); + }); +} + +TEST(OnlinePartitioningTest, Partitioning_Compiler_RepeatedBlocks_SmallModel) { + ModelGenerator mg; + auto model = mg.get_model_without_repeated_blocks(); + + auto snap = std::make_shared(model); + snap->buildGraph(); + + + std::vector sizes_fr = {10, 10}; + size_t iter_fr = 0; + + snap->earlyAvoids(); + snap->earlyRegroup(); + 
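+    // The small model contains no repeated blocks, so repeatedBlocks() should leave
+    // 17 individual groups and report no pattern matches.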
snap->repeatedBlocks(); + EXPECT_EQ(snap->graphSize(), 17); + + auto matches = snap->getMatches(); + EXPECT_EQ(matches.size(), 0); + + snap->repeat([&] { + snap->fuseRemnantsExtended(); + EXPECT_LT(iter_fr, sizes_fr.size()); + EXPECT_EQ(snap->graphSize(), sizes_fr[iter_fr++]); + }); +} + +TEST(OnlinePartitioningTest, Partitioning_Compiler_RepeatedBlocks_RepeatedModel) { + ModelGenerator mg; + auto model = mg.get_model_with_repeated_blocks(); + + auto snap = std::make_shared(model); + snap->buildGraph(); + + + std::vector sizes_fr = {12, 12}; + size_t iter_fr = 0; + + snap->earlyAvoids(); + snap->earlyRegroup(); + snap->repeatedBlocks(); + EXPECT_EQ(snap->graphSize(), 18); + + auto matches = snap->getMatches(); + EXPECT_EQ(matches.size(), 1); + + for (const auto& m : matches) { + EXPECT_EQ(m.second.size(), 17); + for (const auto& layers : m.second) { + EXPECT_EQ(layers.size(), 10); + } + } + + snap->repeat([&] { + snap->fuseRemnantsExtended(); + EXPECT_LT(iter_fr, sizes_fr.size()); + EXPECT_EQ(snap->graphSize(), sizes_fr[iter_fr++]); + }); +} + +TEST(OnlinePartitioningTest, Partitioning_Compiler_Compute_SmallModel) { + ModelGenerator mg; + auto model = mg.get_model_without_repeated_blocks(); + + auto snap = std::make_shared(model); + + std::vector sizes_fr = {10, 10}; + size_t iter_fr = 0; + + ov::npuw::online::PassContext ctx; + ctx.isolates = {{ov::npuw::online::PatternType::OP, "Transpose", "test_compute"}, {ov::npuw::online::PatternType::OP, "ScatterUpdate", "test_compute"}}; + ctx.nofolds = {"test_compute"}; + snap->setCtx(ctx); + + snap->buildGraph(); + snap->earlyAvoids(); + snap->earlyRegroup(); + snap->repeatedBlocks(); + EXPECT_EQ(snap->graphSize(), 17); + + auto matches = snap->getMatches(); + EXPECT_EQ(matches.size(), 0); + + snap->repeat([&] { + snap->fuseRemnantsExtended(); + EXPECT_LT(iter_fr, sizes_fr.size()); + EXPECT_EQ(snap->graphSize(), sizes_fr[iter_fr++]); + }); +} + +TEST(OnlinePartitioningTest, Partitioning_Compiler_Compute_RepeatedModel) { + ModelGenerator mg; + auto model = mg.get_model_with_repeated_blocks(); + + auto snap = std::make_shared(model); + + std::vector sizes_fr = {10, 10}; + size_t iter_fr = 0; + + ov::npuw::online::PassContext ctx; + ctx.isolates = {{ov::npuw::online::PatternType::OP, "Gather", "test_compute"}, + {ov::npuw::online::PatternType::OP, "ScatterUpdate", "test_compute"}, + {ov::npuw::online::PatternType::OP, "ShapeOf", "test_compute"}, + {ov::npuw::online::PatternType::OP, "Divide", "test_compute"}, + {ov::npuw::online::PatternType::OP, "Floor", "test_compute"}}; + ctx.nofolds = {"test_compute"}; + snap->setCtx(ctx); + + snap->buildGraph(); + snap->earlyAvoids(); + snap->earlyRegroup(); + snap->repeatedBlocks(); + EXPECT_EQ(snap->graphSize(), 29); + + // FIXME: create a config in which there will be repeated blocks + auto matches = snap->getMatches(); + EXPECT_EQ(matches.size(), 0); + + snap->repeat([&] { + snap->fuseRemnantsExtended(); + EXPECT_LT(iter_fr, sizes_fr.size()); + EXPECT_EQ(snap->graphSize(), sizes_fr[iter_fr++]); + }); +} diff --git a/src/plugins/intel_npu/tests/unit/npuw/unpack.cpp b/src/plugins/intel_npu/tests/unit/npuw/unpack.cpp new file mode 100644 index 00000000000000..1049832f6ead7c --- /dev/null +++ b/src/plugins/intel_npu/tests/unit/npuw/unpack.cpp @@ -0,0 +1,103 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#ifdef HAVE_AVX2 +#include "unpack.hpp" + +namespace { + +const auto TestCases = ::testing::Combine( + ::testing::ValuesIn({ov::element::Type_t::i4}), + 
::testing::ValuesIn({ov::element::Type_t::i8, ov::element::Type_t::f16}), + ::testing::ValuesIn({ov::element::Type_t::undefined}), // no used in this test + ::testing::ValuesIn({ov::element::Type_t::undefined}), // no used in this test + ::testing::ValuesIn({3lu, 0lu}), + ::details::ShapesIn({Tensors{input={1, 1, 1, 32};}, + Tensors{input={1,1,1, 128};}, + Tensors{input={1,1,1, 390};}, + Tensors{input={1,1,1, 82};}}), + ::testing::ValuesIn({true, false}), + ::testing::ValuesIn({true, false}) +); + +INSTANTIATE_TEST_SUITE_P(UnpackTests, UnpackTests, + TestCases, + UnpackTests::getTestCaseName); + +const auto TestCasesScale = ::testing::Combine( + ::testing::ValuesIn({ov::element::Type_t::i4}), // TODO: add i8 as input for test + ::testing::ValuesIn({ov::element::Type_t::f16, ov::element::Type_t::f32}), + ::testing::ValuesIn({ov::element::Type_t::f16, ov::element::Type_t::f32}), + ::testing::ValuesIn({ov::element::Type_t::undefined}), // no used in this test + ::testing::ValuesIn({3lu, 0lu}), + ::details::ShapesIn({Tensors{input={1,32, 128}; scale = {1, 32, 1};}, + Tensors{input={32, 128}; scale = {32, 1};}, + Tensors{input={64, 160}; scale = {64, 1};}, + Tensors{input={1024, 4}; scale = {64, 1};}, + Tensors{input={1, 1, 1024, 4}; scale = {1, 1, 64, 1};}}), + ::testing::ValuesIn({true, false}), + ::testing::ValuesIn({true, false}) +); + +INSTANTIATE_TEST_SUITE_P(UnpackWithScaleTests, UnpackWithScaleTests, + TestCasesScale, + UnpackWithScaleTests::getTestCaseName); + + +const auto TestCasesScaleAndZeroPoints = ::testing::Combine( + ::testing::ValuesIn({ov::element::Type_t::u4}), + ::testing::ValuesIn({ov::element::Type_t::f16}), + ::testing::ValuesIn({ov::element::Type_t::f16}), + ::testing::ValuesIn({ov::element::Type_t::u4}), + ::testing::ValuesIn({3lu, 0lu}), + ::details::ShapesIn({Tensors{input={1,32, 128}; scale = {1, 32, 1};}, + Tensors{input={1,64, 160}; scale = {1, 64, 1};}, + Tensors{input={1,1024, 4}; scale = {1, 64, 1};}, + Tensors{input={1,1, 1024, 4}; scale = {1, 1, 64, 1};}, + Tensors{input={64, 1}; scale = {64, 1};}}), + ::testing::ValuesIn({true, false}), + ::testing::ValuesIn({true, false}) +); + +INSTANTIATE_TEST_SUITE_P(UnpackTestsWithScaleAndZeroPoint, UnpackTestsWithScaleAndZeroPoint, + TestCasesScaleAndZeroPoints, + UnpackTestsWithScaleAndZeroPoint::getTestCaseName); + +const auto TestCasesScaleAndZeroPoints2 = ::testing::Combine( + ::testing::ValuesIn({ov::element::Type_t::u4}), + ::testing::ValuesIn({ov::element::Type_t::f16}), + ::testing::ValuesIn({ov::element::Type_t::f32}), + ::testing::ValuesIn({ov::element::Type_t::f32}), + ::testing::ValuesIn({3lu, 0lu}), + ::details::ShapesIn({Tensors{input={32, 32, 64}; scale = {32, 1, 64};}, + Tensors{input={64, 64, 128}; scale = {64, 1, 128};}, + Tensors{input={64, 32, 32}; scale = {64, 1, 32};}}), + ::testing::ValuesIn({true, false}), + ::testing::ValuesIn({true, false}) +); + +INSTANTIATE_TEST_SUITE_P(UnpackTestsWithScaleAndZeroPointTest2, UnpackTestsWithScaleAndZeroPointTest2, + TestCasesScaleAndZeroPoints2, + UnpackTestsWithScaleAndZeroPointTest2::getTestCaseName); + +const auto TestCasesScaleAndZeroPoints3 = ::testing::Combine( + ::testing::ValuesIn({ov::element::Type_t::u4}), + ::testing::ValuesIn({ov::element::Type_t::f16}), + ::testing::ValuesIn({ov::element::Type_t::f16}), + ::testing::ValuesIn({ov::element::Type_t::u4}), + ::testing::ValuesIn({3lu, 0lu}), + ::details::ShapesIn({Tensors{input={1, 32, 128}; scale = {1, 32, 1}; zerop = {1, 32, 1};}, + Tensors{input={16, 64, 64}; scale = {16, 64, 1}; zerop = {16, 64, 
1};}, + Tensors{input={1, 1024, 4}; scale = {1, 64, 1}; zerop = {1, 32, 1};}}), + ::testing::ValuesIn({true, false}), + ::testing::ValuesIn({true, false}) +); + +INSTANTIATE_TEST_SUITE_P(UnpackTestsWithScaleAndZeroPointTest3, UnpackTestsWithScaleAndZeroPointTest3, + TestCasesScaleAndZeroPoints3, + UnpackTestsWithScaleAndZeroPointTest3::getTestCaseName); + +} // anonymous namespace + +#endif // __AVX2__ diff --git a/src/plugins/intel_npu/tests/unit/npuw/unpack.hpp b/src/plugins/intel_npu/tests/unit/npuw/unpack.hpp new file mode 100644 index 00000000000000..da5bb4e4720f3e --- /dev/null +++ b/src/plugins/intel_npu/tests/unit/npuw/unpack.hpp @@ -0,0 +1,628 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include +#include +#include +#include +#include + +#include "openvino/runtime/make_tensor.hpp" + +#include "util.hpp" + +namespace { + +#define ASSERT_NO_THROW_WITH_MESSAGE(code) do{ \ + try {\ + code;\ + }catch (const std::exception &ex ) {\ + FAIL()<> 4) | ((x & (1 << 6)) >> 4) | ((x & (1 << 5)) >> 4) | ((x & (1 << 4)) >> 4); +} + +inline int8_t lo4(int8_t x) { + return (x & (1 << 3)) | (x & (1 << 2)) | (x & (1 << 1)) | (x & (1 << 0)); +} + +inline uint8_t hi4(uint8_t x) { + return x >> 4; +} + +inline uint8_t lo4(uint8_t x) { + return x & 0x0F; +} + +inline int8_t upc(int8_t h) { + return h | (-((h & (1 << 3)) >> 3) & (-8)); +} + +typedef unsigned short ushort; +typedef unsigned int uint; + +float half_to_float(const ushort x) { + + __m128i halfVector = _mm_cvtsi32_si128(x); + __m128 floatVector = _mm_cvtph_ps(halfVector); + return _mm_cvtss_f32(floatVector); +} + +ushort float_to_half(const float x) { + __m128 floatVector = _mm_set_ss(x); + __m128i halfVector = _mm_cvtps_ph(floatVector, _MM_FROUND_TO_NEAREST_INT); + return _mm_extract_epi16(halfVector, 0); +} + +inline uint16_t int2hfloat(int8_t x) +{ + float inputFl32 = static_cast(x); + float* inputFl32_ptr = &inputFl32; + unsigned int* fltInt32Ptr = reinterpret_cast(inputFl32_ptr); + unsigned int fltInt32 = *fltInt32Ptr; + unsigned short fltInt16; + + fltInt16 = (fltInt32 >> 31) << 5; + unsigned short tmp = (fltInt32 >> 23) & 0xff; + tmp = (tmp - 0x70) & ((unsigned int)((int)(0x70 - tmp) >> 4) >> 27); + fltInt16 = (fltInt16 | tmp) << 10; + fltInt16 |= (fltInt32 >> 13) & 0x3ff; + + return fltInt16; +} + + +void unpack(const int8_t* in, int8_t* out, int size) { + for (int i = 0; i < size / 2; i++) { + *(out++) = upc(lo4(*in)); + *(out++) = upc(hi4(*in)); + in++; + } +} + +void unpack_i4f16(const int8_t* in, int8_t* out, int size) { + uint16_t *hFloatOut = reinterpret_cast(out); + + for (int i = 0; i < size / 2; i++) { + *(hFloatOut++) = int2hfloat(upc(lo4(*in))); + *(hFloatOut++) = int2hfloat(upc(hi4(*in))); + in++; + } +} + +/*u4 order*/ +void unpack_u4f32(const int8_t* in, float* out, int size) { + for (int i = 0; i < size / 2; i++) { + *(out++) = static_cast(lo4(*in)); + *(out++) = static_cast(hi4(*in)); + in++; + } +} + +template +::testing::AssertionResult fp16ArraysMatch(const T &actual, + const T &expected, + const T &i4Input, + bool int4 = 1 /*i4 or u4*/){ + for (size_t i = 0; i < expected.size() / 2; ++i) { + + int int8Input[] ={ + details::lo4(i4Input[i / 2]), + details::hi4(i4Input[i / 2]) + }; + + if (int4) { + int8Input[0] = details::upc(int8Input[1]); + int8Input[1] = details::upc(int8Input[0]); + }; + + auto fp16ref = int{*((uint16_t*)expected.data() + i)}; + auto fp16out = int{*((uint16_t*)actual.data() + i)}; + +#define _P(x) std::dec << std::setw(5) << (x) << 
'(' << std::setw(4) << std::hex << (x) << ')' + if (fp16ref != fp16out) { + return ::testing::AssertionFailure() << std::dec << std::setw(4) << i << ", i4:" + << std::setw(2) << int8Input[i % 2] + << " | ref " << _P(fp16ref) + << ", test " << _P(fp16out) << "\n"; + } +#undef _P + + } + + return ::testing::AssertionSuccess(); +} + +} // namespace details + +using ShapesInitializer = std::function&, std::vector&, std::vector&)>; + + +using UnpackTestsParams = std::tuple< + ov::element::Type_t, // fromPrecision + ov::element::Type_t, // toPrecision + ov::element::Type_t, // scalePrecision + ov::element::Type_t, // zeroPointPrecision + unsigned long, // nPartitions + ShapesInitializer, // input_shape , scale_shape, zerop initializer + bool, // use parallel_for + bool // strict partitioning + >; + +class UnpackTestsBase { +protected: + ov::element::Type fromType; + ov::element::Type toType; + ov::element::Type scaleType; + ov::element::Type zeropType; + std::shared_ptr from, to, scale, zerop; + + std::vector input; + std::vector output; + std::vector ref_output; + std::vector scalesStorage; + std::vector zeropStorage; + float zeropValue; + ov::Shape input_shape; + ov::Shape scale_shape; + ov::Shape zerop_shape; + + size_t nPartitions; + bool useParallelFor = false; + bool strictPartitions = false; + + void make_zeropoints() { + if (zeropType == ov::element::undefined) { + return; + } + + const std::vector zeropValues = {15.0f, 12.0f, 0.0f, 31.0f}; + const size_t nElements = shape_size(zerop_shape); + + // Set zeropValue if there's only one element + if (nElements == 1) { + zeropValue = zeropValues.front(); + } + + // Determine the size of the storage based on the type and resize the storage vector + if (zeropType == ov::element::Type_t::u4) { + zeropStorage.resize((nElements + 1) / 2, 0); // Each u4 zeropoint is 4 bits, so two zeropoints fit in one byte + } else if (zeropType == ov::element::Type_t::f32) { + zeropStorage.resize(nElements * sizeof(float), 0); + } else { + ASSERT_TRUE(zeropType == ov::element::u4 || zeropType == ov::element::f32); + } + + // Fill the storage with the appropriate values + if (zeropType == ov::element::Type_t::u4) { + for (size_t i = 0; i < nElements; ++i) { + uint8_t zeropValueU4 = static_cast(zeropValues[i % zeropValues.size()]) & 0x0F; + size_t byteIndex = i / 2; + if (i % 2 == 0) { + zeropStorage[byteIndex] = zeropValueU4; + } else { + zeropStorage[byteIndex] = (zeropValueU4 << 4); + } + } + } else if (zeropType == ov::element::Type_t::f32) { + float* ptrWork = reinterpret_cast(zeropStorage.data()); + for (size_t i = 0; i < nElements; ++i) { + ptrWork[i] = zeropValues[i % zeropValues.size()]; + } + } + + // Create the tensor + zerop = ov::make_tensor(zeropType, zerop_shape, zeropStorage.data()); + } + + void make_scales() { + if (scaleType == ov::element::undefined) { + return; + } + ASSERT_TRUE(scaleType == ov::element::f16 || scaleType == ov::element::f32); + size_t nElements = shape_size(scale_shape); + + // creating custom scale factors + const size_t nScaleBytes = scaleType.bitwidth() * nElements / 8; + + std::vector sc(nElements); + float coeffTable[] = { + 0.1f, + 0.5f, + 1.f, + 2.f + }; + for (size_t i = 0; i != nElements; i++) { + sc[i] = coeffTable[i % (sizeof (coeffTable) / sizeof(*coeffTable))]; + } + scalesStorage.resize(nScaleBytes); + + if (scaleType == ov::element::f16) { + uint16_t * ptrWork = reinterpret_cast(scalesStorage.data()); + for (size_t i = 0; i != nElements; i++) { + ptrWork[i] = details::float_to_half(sc[i]); + } + } + if (scaleType 
== ov::element::f32) { + float* ptrWork = reinterpret_cast(scalesStorage.data()); + for (size_t i = 0; i != nElements; i++) { + ptrWork[i] = sc[i]; + } + } + scale = ov::make_tensor(scaleType, scale_shape, scalesStorage.data()); + } + + void make_input() { + + size_t nElements = shape_size(input_shape); + + ASSERT_EQ((fromType.bitwidth() * nElements) % 8, 0) << "Input len has to be byte boundary aligned, but was " + << fromType.bitwidth() * nElements << " bits"; + ASSERT_EQ((toType.bitwidth() * nElements) % 8, 0) << "Output len has to be byte boundary aligned"; + + const size_t nInputBytes = fromType.bitwidth() * nElements / 8; + const size_t nOutputBytes = toType.bitwidth() * nElements / 8; + + input.resize(nInputBytes); + ref_output.resize(nOutputBytes); + output.resize(nOutputBytes); + std::fill(ref_output.begin(), ref_output.end(), 0); + std::fill(output.begin(), output.end(), 0); + + std::array input_local = { + 0x0A, 0x0B, 0x1C, 0x1D, 0x2E, 0x2F, 0x35, 0x36, + 0x4A, 0x4B, 0x5A, 0x5B, 0x6A, 0x6B, 0x7A, 0x7B, + 0x0C, 0x0D, 0x1C, 0x1D, 0x2C, 0x2D, 0x3C, 0x3D, + 0x4C, 0x4D, 0x5C, 0x5D, 0x6C, 0x6D, 0x7C, 0x7D, + }; + + for (size_t idx = 0, k = 0; k < nInputBytes; k++, idx = (idx + 1) % input_local.size()) { + input[k] = input_local[idx]; + } + + from = ov::make_tensor(fromType, input_shape, input.data()); + to = ov::make_tensor(toType, input_shape, output.data()); + } +public: + void SetUp(const UnpackTestsParams & getParam) { + ShapesInitializer shapeInit; + + std::tie(fromType, toType, scaleType, zeropType, nPartitions, shapeInit, useParallelFor, strictPartitions) = getParam; + + std::vector input, scale, zerop; + shapeInit(input, scale, zerop); + + input_shape = ov::Shape{input.begin(), input.end()}; + scale_shape = ov::Shape{scale.begin(), scale.end()}; + if (zerop.empty()) { + zerop_shape = ov::Shape({1}); + } else { + zerop_shape = ov::Shape{zerop.begin(), zerop.end()}; + } + + make_input(); + make_scales(); + make_zeropoints(); + + make_ref_output(); + } + std::string ToString() const { + std::ostringstream result; + result << (isNegative() ? "NEGATIVE_" : "") + <<"["; + + for (size_t i = 0; i != input_shape.size(); i++) { + result << input_shape[i] << ((i + 1 == input_shape.size()) ? "" : "x"); + } + result <<"]" + << "_p" << nPartitions + << (strictPartitions ? "_SP" : "") + << (useParallelFor ? 
"_parallel" : "_serial") + << "_from_" << fromType + << "_to_" << toType; + if (scaleType != ov::element::Type_t::undefined) + result << "_scale_" << scaleType; + if (zeropType != ov::element::Type_t::undefined) + result << "_zerop_" << zeropType; + + return result.str(); + } + + /** + * Negative test cases has to be carefully reviewed, to still remain positive runs at some points + * @return + */ + virtual bool isNegative() const { + return false; + } + + virtual void make_ref_output() { + size_t nElements = 1; + for (size_t dim : input_shape) { + nElements *= dim; + } + if (toType == ov::element::i8) { + details::unpack(input.data(), ref_output.data(), static_cast(nElements)); + } else if (toType == ov::element::f16) { + details::unpack_i4f16(input.data(), ref_output.data(), static_cast(nElements)); + } + } +}; + +template +class UnpackTestsTmpl : + public ::testing::Test, + public T, + public ::testing::WithParamInterface { +protected: + + void SetUp() override { + T::SetUp(GetParam()); + } +public: + static std::string getTestCaseName(const testing::TestParamInfo& obj) { + T _bt; + _bt.SetUp(obj.param); + return _bt.ToString(); + } +}; + +using UnpackTests = UnpackTestsTmpl; +class UnpackTestsRef : public UnpackTests {}; + +TEST_P(UnpackTests, i4) { + ASSERT_NO_THROW_WITH_MESSAGE(ov::npuw::util::unpack(from, to, ov::npuw::util::UnpackOptions{useParallelFor, nPartitions, strictPartitions})); + ASSERT_TRUE(details::fp16ArraysMatch(output, ref_output, input)); +} + +class UnpackWithScaleTestsBase : public UnpackTestsBase { +protected: + bool isNegative() const override { + if (scale_shape.size() != 3 && scale_shape.size() != 2) return true; + if (input_shape.back() % 64) return true; + if ((from->get_size() / scale->get_size()) % 64) return true; + if (toType != ov::element::f16) return true; + + return false; + } + + void make_ref_output() override { + if (isNegative()) return; + + size_t nElements = from->get_size(); + + const size_t nOutputElementsPerScale = ref_output.size() / (toType.bitwidth() / 8) / scale->get_size(); + + details::unpack_i4f16(input.data(), ref_output.data(), static_cast(nElements)); + + // lets apply per channel scale + uint16_t * pRef = reinterpret_cast(ref_output.data()); + uint16_t * pScale_f16 = reinterpret_cast(scale->data()); + float * pScale_f32 = reinterpret_cast(scale->data()); + + for (size_t i = 0; i < scale->get_size(); i++) { + for (size_t sc = 0; sc != nOutputElementsPerScale; sc++) { + float ref_scaled = details::half_to_float(pRef[0]); + if (scaleType == ov::element::f32) { + ref_scaled *= pScale_f32[0]; + } else if (scaleType == ov::element::f16) { + ref_scaled *= details::half_to_float(pScale_f16[0]); + } + *pRef = details::float_to_half(ref_scaled); + pRef++; + } + pScale_f32++; + pScale_f16++; + } + } + +}; + +using UnpackWithScaleTests = UnpackTestsTmpl; + + +TEST_P(UnpackWithScaleTests, i4_scale) { + ASSERT_NO_THROW_IF(!isNegative(), + ov::npuw::util::unpack(from, scale, to, ov::npuw::util::UnpackOptions{useParallelFor, nPartitions, strictPartitions})); + if (!isNegative()) { + ASSERT_TRUE(details::fp16ArraysMatch(output, ref_output, input)); + } +} + + +class UnpackTestsWithScaleAndZeroPointBase : public UnpackTestsBase { +protected: + bool isNegative() const override { + if (scale_shape.size() != 3 && scale_shape.size() != 2) return true; + if (input_shape.back() % 64) return true; + + return false; + } + + void make_ref_output() override { + if (isNegative()) return; + + size_t nElements = from->get_size(); + + const size_t 
nOutputElementsPerScale = ref_output.size() / (toType.bitwidth() / 8) / scale->get_size(); + + std::vector floatRef(nElements); + details::unpack_u4f32(input.data(), floatRef.data(), static_cast(nElements)); + + + // lets apply per channel scale + uint16_t * pRef = reinterpret_cast(ref_output.data()); + float * pFloatRef = reinterpret_cast(floatRef.data()); + const uint16_t * pScale_f16 = reinterpret_cast(scale->data()); + const float * pScale_f32 = reinterpret_cast(scale->data()); + + for (size_t i = 0; i < scale->get_size(); i++) { + for (size_t sc = 0; sc != nOutputElementsPerScale; sc++) { + // applying zeropoint + float ref_scaled = *pFloatRef - zeropValue; + + if (scaleType == ov::element::f32) { + ref_scaled *= pScale_f32[0]; + } else if (scaleType == ov::element::f16) { + ref_scaled *= details::half_to_float(pScale_f16[0]); + } + *pRef = details::float_to_half(ref_scaled); + + pFloatRef++; + pRef++; + } + pScale_f32++; + pScale_f16++; + } + } +}; + +using UnpackTestsWithScaleAndZeroPoint = UnpackTestsTmpl; + +TEST_P(UnpackTestsWithScaleAndZeroPoint, u4) { + ASSERT_NO_THROW_IF(!isNegative(), + ov::npuw::util::unpack(from, zerop, scale, to, ov::npuw::util::UnpackOptions{useParallelFor, nPartitions, strictPartitions})); + if (!isNegative()) { + ASSERT_TRUE(details::fp16ArraysMatch(output, ref_output, input, false)); + } +} + +class UnpackTestsWithScaleAndZeroPoint2 : public UnpackTestsWithScaleAndZeroPointBase { +protected: + bool isNegative() const override { + if (input_shape.back() % 64 || input_shape.size() != 3) return true; + if (scale_shape.back() % 64 || scale_shape.size() != 3) return true; + + return false; + } + + void make_ref_output() override { + if (isNegative()) return; + + size_t nElements = from->get_size(); + const auto from_shape = from->get_shape(); + + const size_t C = from_shape[from_shape.size() - 3]; + const size_t H = from_shape[from_shape.size() - 2]; + const size_t W = from_shape[from_shape.size() - 1]; + + std::vector floatRef(nElements); + details::unpack_u4f32(input.data(), floatRef.data(), static_cast(nElements)); + + uint16_t * pRef = reinterpret_cast(ref_output.data()); + float * pFloatRef = reinterpret_cast(floatRef.data()); + const uint16_t * pScale_f16 = reinterpret_cast(scale->data()); + const float * pScale_f32 = reinterpret_cast(scale->data()); + + for (size_t c = 0; c < C; ++c) { + for (size_t h = 0; h < H; ++h) { + for (size_t w = 0; w < W; ++w) { + size_t input_index = w + W * h + W * H * c; + size_t scale_index = w + W * c; + float ref_scaled = pFloatRef[input_index] - zeropValue; + if (scaleType == ov::element::f32) { + ref_scaled *= pScale_f32[scale_index]; + } else if (scaleType == ov::element::f16) { + ref_scaled *= details::half_to_float(pScale_f16[scale_index]); + } + pRef[w + W * h + c * W * H] = details::float_to_half(ref_scaled); + } + } + } + } +}; + +using UnpackTestsWithScaleAndZeroPointTest2 = UnpackTestsTmpl; + +TEST_P(UnpackTestsWithScaleAndZeroPointTest2, u4) { + ASSERT_NO_THROW_IF(!isNegative(), + ov::npuw::util::unpack(from, zerop, scale, to, ov::npuw::util::UnpackOptions{useParallelFor, nPartitions, strictPartitions})); + if (!isNegative()) { + ASSERT_TRUE(details::fp16ArraysMatch(output, ref_output, input, false)); + } +} + +class UnpackTestsWithScaleAndZeroPoint3 : public UnpackTestsWithScaleAndZeroPointBase { +protected: + bool isNegative() const override { + if (scale_shape.size() != 3 || zerop_shape.size() != 3) return true; + if (input_shape[2] % 64 || input_shape.size() != 3) return true; + + return false; + } + + 
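+    // Reference path: unpack the u4 input to float, subtract a per-row zero point decoded
+    // from the packed u4 zero-point tensor (low/high nibble per element), apply the
+    // per-row scale, and store the result as f16.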
void make_ref_output() override { + if (isNegative()) return; + + size_t nElements = from->get_size(); + + const size_t nOutputElementsPerScale = ref_output.size() / (toType.bitwidth() / 8) / scale->get_size(); + + std::vector floatRef(nElements); + details::unpack_u4f32(input.data(), floatRef.data(), static_cast(nElements)); + + + // lets apply per channel scale + uint16_t * pRef = reinterpret_cast(ref_output.data()); + const uint8_t* pZer = static_cast(zerop->data()); + float * pFloatRef = reinterpret_cast(floatRef.data()); + const uint16_t * pScale_f16 = reinterpret_cast(scale->data()); + const float * pScale_f32 = reinterpret_cast(scale->data()); + + for (size_t i = 0; i < scale->get_size(); i++) { + float zeroPointValue = static_cast((i % 2 == 0) ? details::lo4(pZer[i / 2]) : details::hi4(pZer[i / 2])); + for (size_t sc = 0; sc != nOutputElementsPerScale; sc++) { + // applying zeropoint + float ref_scaled = *pFloatRef - zeroPointValue; + + if (scaleType == ov::element::f32) { + ref_scaled *= pScale_f32[0]; + } else if (scaleType == ov::element::f16) { + ref_scaled *= details::half_to_float(pScale_f16[0]); + } + *pRef = details::float_to_half(ref_scaled); + + pFloatRef++; + pRef++; + } + pScale_f32++; + pScale_f16++; + } + } +}; + +using UnpackTestsWithScaleAndZeroPointTest3 = UnpackTestsTmpl; + +TEST_P(UnpackTestsWithScaleAndZeroPointTest3, u4) { + ASSERT_NO_THROW_IF(!isNegative(), + ov::npuw::util::unpack(from, zerop, scale, to, ov::npuw::util::UnpackOptions{useParallelFor, nPartitions, strictPartitions})); + if (!isNegative()) { + ASSERT_TRUE(details::fp16ArraysMatch(output, ref_output, input, false)); + } +} + +#define Tensors [](std::vector& input, std::vector&scale, std::vector&zerop) + + +namespace details { +::testing::internal::ParamGenerator::value_type> ShapesIn( + const std::vector& container) { + return ::testing::ValuesIn(container.begin(), container.end()); +} + +} // namespace details +} // anonymous namespace diff --git a/tests/layer_tests/tensorflow_tests/test_tf_ExpandDims.py b/tests/layer_tests/tensorflow_tests/test_tf_ExpandDims.py index f0f9085d32ba2f..e982867c9ac08d 100644 --- a/tests/layer_tests/tensorflow_tests/test_tf_ExpandDims.py +++ b/tests/layer_tests/tensorflow_tests/test_tf_ExpandDims.py @@ -6,6 +6,7 @@ import tensorflow as tf from common.tf_layer_test_class import CommonTFLayerTest +rng = np.random.default_rng(62362) class TestExpandDims(CommonTFLayerTest): def _prepare_input(self, inputs_info): @@ -40,3 +41,54 @@ def test_expand_dims_basic(self, params, ie_device, precision, ir_version, temp_ self._test(*self.create_expand_dims_net(**params), ie_device, precision, ir_version, temp_dir=temp_dir, use_legacy_frontend=use_legacy_frontend) + + +class TestExpandDimsComplex(CommonTFLayerTest): + def _prepare_input(self, inputs_info): + # generate elements so that the input tensor may contain repeating elements + assert 'param_real:0' in inputs_info + assert 'param_imag:0' in inputs_info + + input_shape = inputs_info['param_real:0'] + + inputs_data = {} + inputs_data['param_real:0'] = rng.integers(-10.0, 10.0, input_shape).astype(np.float32) + inputs_data['param_imag:0'] = rng.integers(-10.0, 10.0, input_shape).astype(np.float32) + + return inputs_data + + def create_expand_dims_complex_net(self, axis_dtype, input_shape, axis): + tf.compat.v1.reset_default_graph() + with tf.compat.v1.Session() as sess: + param_real = tf.compat.v1.placeholder(np.float32, input_shape, 'param_real') + param_imag = tf.compat.v1.placeholder(np.float32, input_shape, 'param_imag') + 
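+            # Pack the real and imaginary placeholders into one complex tensor so that
+            # ExpandDims is applied to a complex-typed input.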
+ complex = tf.raw_ops.Complex(real=param_real, imag=param_imag) + + axis = tf.constant(axis, dtype=axis_dtype) + + result = tf.raw_ops.ExpandDims(input=complex, axis=axis) + + tf.raw_ops.Real(input=result) + tf.raw_ops.Imag(input=result) + + tf.compat.v1.global_variables_initializer() + tf_net = sess.graph_def + + return tf_net, None + + test_basic = [ + dict(input_shape=[], axis=0), + dict(input_shape=[2, 3], axis=1), + dict(input_shape=[2, 3, 4], axis=-1), + dict(input_shape=[2, 6, 5], axis=-2), + ] + + @pytest.mark.parametrize("axis_dtype", [np.int32, np.int64]) + @pytest.mark.parametrize("op_args", test_basic) + @pytest.mark.nightly + @pytest.mark.precommit + def test_expand_dims_basic_complex(self, axis_dtype, op_args, ie_device, precision, ir_version, temp_dir, use_legacy_frontend): + self._test(*self.create_expand_dims_complex_net(axis_dtype, **op_args), + ie_device, precision, ir_version, temp_dir=temp_dir, + use_legacy_frontend=use_legacy_frontend)
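For reference, the behaviour these cases check can be reproduced with a short NumPy sketch. This is illustrative only and not part of the test suite: reference_expand_dims_complex is a hypothetical helper name, and the shapes mirror the input_shape=[2, 3, 4], axis=-1 case from test_basic. Expanding a complex tensor along an axis and then taking its real and imaginary parts must match expanding each part independently.

import numpy as np

rng = np.random.default_rng(62362)

def reference_expand_dims_complex(real, imag, axis):
    # Expected result of Real/Imag(ExpandDims(Complex(real, imag), axis)):
    # each part is expanded independently along the same axis.
    expanded = np.expand_dims(real + 1j * imag, axis)
    return np.real(expanded), np.imag(expanded)

real = rng.integers(-10, 10, [2, 3, 4]).astype(np.float32)
imag = rng.integers(-10, 10, [2, 3, 4]).astype(np.float32)
re_out, im_out = reference_expand_dims_complex(real, imag, axis=-1)
assert re_out.shape == (2, 3, 4, 1)
assert np.array_equal(re_out, np.expand_dims(real, -1))
assert np.array_equal(im_out, np.expand_dims(imag, -1))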