Merge branch 'master' into itikhono/bug_fix/eliminate_convert_fp32_to_fp32

itikhono authored Oct 3, 2024
2 parents bab21cc + 17ecf03 commit 2d79d21
Showing 7 changed files with 119 additions and 23 deletions.
@@ -11,6 +11,7 @@ NPU Device
:hidden:

npu-device/remote-tensor-api-npu-plugin
npu-device/batching-on-npu-plugin


The Neural Processing Unit is a low-power hardware solution, introduced with the
@@ -0,0 +1,37 @@
NPU Plugin Batching
===============================


.. meta::
:description: OpenVINO™ NPU plugin supports batching
either by executing concurrent inferences or by
relying on native compiler support for batching.

The OpenVINO™ NPU plugin supports batching either by executing concurrent inferences or by relying on native compiler support for batching.

First, the NPU plugin checks if the following conditions are met:

* The batch size is on the first axis.
* All inputs and outputs have the same batch size.
* The model does not contain states.

**If the conditions are met**, the NPU plugin attempts to compile and execute the original model with batch_size forced to 1. This approach is a workaround for current compiler limitations, while work on improving performance for batch sizes greater than one is ongoing.
If compilation succeeds, the plugin detects the difference in batch size between the original model layout (with a batch size set to N)
and the transformed/compiled layout (with a batch size set to 1). It then executes the following steps:

1. Internally constructs multiple command lists, one for each input.
2. Executes each command list for the proper offsets of input/output buffers.
3. Notifies the user of the completion of the inference request after all command lists have been executed.

This concurrency-based batching mode is transparent to the application. A single inference request handles all inputs from the batch.
While performance may be lower compared to regular batching (based on native compiler support), this mode provides basic batching functionality for use either with older drivers
or when the model cannot yet be compiled with a batch size larger than one.

**If the conditions are not met**, the NPU plugin tries to compile and execute the original model with the given
batch size (N), just like any other regular model.
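
In both cases the batching is handled inside the plugin, so the application code stays the same: the model is given a batch size of N and a single batched inference request is submitted. A minimal sketch of that application-side flow, using the standard OpenVINO C++ API, is shown below (the model path, layout, and batch size are illustrative placeholders):

.. code-block:: cpp

   #include <openvino/openvino.hpp>

   int main() {
       ov::Core core;
       auto model = core.read_model("model.xml");   // placeholder path

       // Batch must be on the first axis (see the conditions above).
       model->get_parameters().at(0)->set_layout("N...");
       ov::set_batch(model, 4);                     // illustrative batch size

       auto compiled = core.compile_model(model, "NPU");
       auto request = compiled.create_infer_request();

       // One batched tensor, one inference request; the plugin chooses between
       // native batching and the concurrency-based fallback described above.
       ov::Tensor input(compiled.input().get_element_type(), compiled.input().get_shape());
       request.set_input_tensor(input);
       request.infer();

       ov::Tensor output = request.get_output_tensor();
       // ... read per-sample results from the batched output tensor.
       return 0;
   }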

.. note::

With future performance improvements and support for compiling multiple models with a batch size larger
than one, the default order will change: the NPU plugin will first try to compile and execute the original model with the
given batch size, and fall back to concurrent batching if compilation fails.
@@ -809,7 +809,20 @@ inline void FUNC(fc_bf_tiled_kernel_dyn_quan)(
uint input_offset = out_b * TILE_IN_B_PITCH + INPUT0_OFFSET;
#endif

#if FILTER_LAYOUT_OS_IS_YX_OSV64_ISV2
const int power_of_two_for_simd = 5;
const int power_of_two_for_osv = 6;
const uint osv64_weight_base = (( (int) (out_f >> power_of_two_for_osv) ) << power_of_two_for_osv);
const uint osv_weight_stride = (INPUT_ELEMENTS_COUNT >> 1);
const uint out_f_offset = (int)((out_f >> power_of_two_for_simd) & 0x1) << power_of_two_for_simd;
// out_f(32) : 0 * osv_weight_stride + 32;
// out_f(64) : 64 * osv_weight_stride + 0;
// out_f(128) : 64 * osv_weight_stride + 32;
// ...
uint weights_offset = osv64_weight_base * osv_weight_stride + out_f_offset;
#else
uint weights_offset = out_f * (INPUT_ELEMENTS_COUNT / 2);
#endif

ACCUMULATOR_VEC_TYPE acc[TILE_B] = { };

@@ -905,7 +918,11 @@ inline void FUNC(fc_bf_tiled_kernel_dyn_quan)(

__local int* char_slm_weight = (__local int*)wei_local_mem;

#if FILTER_LAYOUT_OS_IS_YX_OSV64_ISV2
uint weights_idx = weights_offset + local_id * SIMD * FILTER_LOAD_ITERS * FILTER_LOAD_BLOCK_SIZE * 2;
#else
uint weights_idx = weights_offset + local_id * SIMD * FILTER_LOAD_ITERS * FILTER_ACTUAL_LOAD_BLOCK_SIZE;
#endif
uint wei_local_idx = local_id * SIMD * FILTER_LOAD_ITERS * (FILTER_LOAD_BLOCK_SIZE/2) + sglid * 2;

// DECOMPRESSION_SCALE_POST_OP SHOULD be enabled for dynamic quantize FC : scale is ACCUMULATOR_VAL_ONE
@@ -917,6 +934,17 @@ inline void FUNC(fc_bf_tiled_kernel_dyn_quan)(
// loaded weights 'wei_packed' of os_iyx_osv16 format have continuous values along TILE_K. So no need to transpose while unpacking
dq_wei_unpacked.s0123 = UNPACK_INT4(DQ_TYPE, *((INT4_PACKED_TYPE_PRELOAD*)&wei_packed0));
dq_wei_unpacked.s4567 = UNPACK_INT4(DQ_TYPE, *((INT4_PACKED_TYPE_PRELOAD*)&wei_packed1));
#elif FILTER_LAYOUT_OS_IS_YX_OSV64_ISV2
SLM_FILTER_PACKED_VEC wei_packed0 = BLOCK_READN(FILTER_TYPE, FILTER_ACTUAL_LOAD_BLOCK_SIZE, weights, weights_idx);
SLM_FILTER_PACKED_VEC wei_packed1 = BLOCK_READN(FILTER_TYPE, FILTER_ACTUAL_LOAD_BLOCK_SIZE, weights, (weights_idx + (FILTER_LOAD_BLOCK_SIZE * SIMD)));
DQ_SLM_FILTER_UNPACKED_VEC dq_wei_unpacked;
DQ_SLM_FILTER_UNPACKED_VEC dq_wei_unpacked_tmp;
dq_wei_unpacked_tmp.s0123 = UNPACK_INT4(DQ_TYPE, *((INT4_PACKED_TYPE_PRELOAD*)&wei_packed0));
dq_wei_unpacked_tmp.s4567 = UNPACK_INT4(DQ_TYPE, *((INT4_PACKED_TYPE_PRELOAD*)&wei_packed1));
dq_wei_unpacked.s01 = dq_wei_unpacked_tmp.s01;
dq_wei_unpacked.s23 = dq_wei_unpacked_tmp.s45;
dq_wei_unpacked.s45 = dq_wei_unpacked_tmp.s23;
dq_wei_unpacked.s67 = dq_wei_unpacked_tmp.s67;
#else
SLM_FILTER_PACKED_VEC wei_packed = BLOCK_READN(FILTER_TYPE, FILTER_LOAD_BLOCK_SIZE, weights, weights_idx);
DQ_SLM_FILTER_UNPACKED_VEC dq_wei_unpacked = UNPACK_TRANSPOSED_INT4(DQ_TYPE, *((INT4_PACKED_TYPE_PRELOAD *)&wei_packed));
@@ -996,11 +1024,7 @@ inline void FUNC(fc_bf_tiled_kernel_dyn_quan)(
acc_tmp[1][bi] = imad_SW(acc_tmp[1][bi], input_val, second_weight);
}

#if FILTER_LAYOUT_OS_IYX_OSV16 && TILE_OFM == 2
weights_offset += (TILE_K_OFM_PACKED/2) * SIMD;
#else
weights_offset += TILE_K_OFM_PACKED * SIMD;
#endif
weights_offset += TILE_K_OFM_PACKED * TILE_OFM_PER_OSV_SIZE * SIMD;

#if DECOMPRESSION_SCALE_POST_OP && (TILE_IFM_ELEMENTS_SIZE > DECOMPRESSION_SCALE_GROUP_SIZE)
unroll_for (uint bi = 0; bi < TILE_B; ++bi) {
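For readers who want to sanity-check the new OSV64_ISV2 addressing, the weight-offset arithmetic introduced in this kernel can be mirrored on the host. The snippet below is only an illustration: the helper name is made up, and 4096 stands in for INPUT_ELEMENTS_COUNT.

#include <cstdint>
#include <cstdio>
#include <initializer_list>

// Mirrors the OSV64_ISV2 branch of the kernel: outputs are grouped into blocks of 64,
// each output row spans osv_weight_stride = INPUT_ELEMENTS_COUNT / 2 packed entries
// (two int4 values per entry), and the upper 32-lane SIMD half of a block starts
// 32 elements further in.
static uint32_t osv64_isv2_weights_offset(uint32_t out_f, uint32_t input_elements_count) {
    const uint32_t osv64_weight_base = (out_f >> 6) << 6;
    const uint32_t osv_weight_stride = input_elements_count >> 1;
    const uint32_t out_f_offset = ((out_f >> 5) & 0x1u) << 5;
    return osv64_weight_base * osv_weight_stride + out_f_offset;
}

int main() {
    const uint32_t k = 4096;  // illustrative stand-in for INPUT_ELEMENTS_COUNT
    // e.g. out_f = 32 -> 0 * stride + 32 and out_f = 64 -> 64 * stride + 0,
    // matching the first mapping comments in the kernel.
    for (uint32_t out_f : {0u, 32u, 64u, 96u}) {
        std::printf("out_f=%3u -> weights_offset=%u\n", out_f, osv64_isv2_weights_offset(out_f, k));
    }
    return 0;
}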
@@ -781,8 +781,7 @@ KernelsData FullyConnected_bf_tiled::GetTunedKernelsDataByIndex(const Params &pa
auto output_f = get_output_aligned_bf_size(fc_params, false).second;

WeightsLayout weights_layout = WeightsLayout::os_iyx_osv16;
// TODO: Update may also be required to fc_bf_tiled_kernel_dyn_quan kernel to support os_is_yx_osv64_isv2 format as needed
if (!should_dynamic_quantize(fc_params) && fc_params.compressed && fc_params.inputs[0].GetDType() == Datatype::F16
if (fc_params.compressed && fc_params.inputs[0].GetDType() == Datatype::F16
&& (fc_params.weights.GetLayout() == WeightsLayout::oiyx || fc_params.weights.GetLayout() == WeightsLayout::os_is_yx_osv64_isv2)
&& (fc_params.weights.GetDType() == WeightsType::INT4 || fc_params.weights.GetDType() == WeightsType::UINT4)
&& is_weight_horizontal(fc_params, output_f)) {
1 change: 1 addition & 0 deletions src/plugins/intel_npu/src/plugin/npuw/compiled_model.cpp
@@ -178,6 +178,7 @@ ov::npuw::CompiledModel::CompiledModel(const std::shared_ptr<ov::Model>& model,
}
auto process_params = [&](const ov::ParameterVector& _parameters) {
for (size_t i = 0; i < _parameters.size(); i++) {
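// Sanity check: the parameter pointer must be valid before it is logged and matched below.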
NPUW_ASSERT(_parameters[i]);
LOG_VERB(_parameters[i]);
for (size_t j = 0; j < orig_parameters.size(); j++) {
if (_parameters[i] == orig_parameters[j]) {
58 changes: 45 additions & 13 deletions src/plugins/intel_npu/src/plugin/npuw/partitioning/partitioning.cpp
@@ -4,6 +4,8 @@

#include "partitioning.hpp"

#include <memory>

#include "../logging.hpp"
#include "../util.hpp"
#include "intel_npu/al/config/npuw.hpp"
@@ -20,6 +22,26 @@
#include "patterns/dcoff.hpp"
#include "patterns/opt.hpp"

namespace ov {
namespace npuw {
inline bool operator==(const std::reference_wrapper<Subgraph>& lhs, const std::reference_wrapper<Subgraph>& rhs) {
ov::npuw::Subgraph& llink = lhs.get();
ov::npuw::Subgraph& rlink = rhs.get();
return &llink == &rlink;
}
} // namespace npuw
} // namespace ov

template <typename T2>
struct std::hash<std::pair<ov::npuw::Subgraph::Ref, T2>> {
std::size_t operator()(std::pair<ov::npuw::Subgraph::Ref, T2> const& p) const noexcept {
ov::npuw::Subgraph& sg = p.first.get();
std::size_t h1 = std::hash<void*>{}(&sg);
std::size_t h2 = std::hash<T2>{}(p.second);
return h1 ^ (h2 << 1);
}
};

namespace {

class FuncallEverywhere {
@@ -161,6 +183,8 @@ class Partitioner {

using PPtr = std::shared_ptr<ov::op::v0::Parameter>;
using RPtr = std::shared_ptr<ov::op::v0::Result>;
using SubgParam = std::pair<ov::npuw::Subgraph::Ref, PPtr>;
using SubgResult = std::pair<ov::npuw::Subgraph::Ref, RPtr>;
using LinkPtrTo = std::pair<size_t /*submodel_idx*/
,
PPtr /*param ptr*/
@@ -182,8 +206,8 @@

// Map every function call instance's Parameter and Result
// back to its prototype Parameter and Result
std::unordered_map<PPtr, PPtr> param_call_to_proto;
std::unordered_map<RPtr, RPtr> result_call_to_proto;
std::unordered_map<SubgParam, PPtr> param_call_to_proto;
std::unordered_map<SubgResult, RPtr> result_call_to_proto;
};
std::map<std::string, FunctionPipeline> all_functions;

Expand All @@ -203,7 +227,10 @@ class Partitioner {
void createFunction(FunctionPipeline& func_ggg);

template <typename T, typename M>
void rearrange_to_function_protocol(const std::vector<T>& protocol, std::vector<T>& call, const M& call_to_proto) {
void rearrange_to_function_protocol(ov::npuw::Subgraph::Ref func_ref,
const std::vector<T>& protocol,
std::vector<T>& call,
const M& call_to_proto) {
LOG_DEBUG("Rearranging...");
LOG_BLOCK();
LOG_DEBUG("Protocol: " << protocol.size());
@@ -215,7 +242,7 @@
LOG_DEBUG("Call: " << call.size());
for (auto&& c : call) {
LOG_BLOCK();
auto p_c = call_to_proto.at(c);
auto p_c = call_to_proto.at(typename M::key_type(func_ref, c));
to_proto.push_back(p_c);
LOG_DEBUG(c << " (which is " << p_c << ")");
}
@@ -536,7 +563,7 @@ void Partitioner::identifySubgraphs() {
LOG_VERB("Processing group's output layer " << output_layer_name);
LOG_BLOCK();
auto output_layer_ptr = node_id_cache.at(output_layer_name);
if (output_layer_ptr->inputs().empty()) {
if (output_layer_ptr->outputs().empty()) {
OPENVINO_THROW("The group's output layer ",
output_layer_name,
" has NO OUTPUTS!! - Graph contracts are broken??");
@@ -1327,9 +1354,12 @@ void Partitioner::matchParameters(const std::string& func_name) {

// Now walk other submodels and match parameters with the same key
// (yes, including the first one)
for (auto&& call : model_group) {
for (std::size_t call_id = 0; call_id < model_group.size(); ++call_id) {
LOG_DEBUG("Handle function call...");
LOG_BLOCK();
auto call = model_group[call_id];
auto subg_ref = func.refs[call_id];

std::unordered_set<ov::Node*> this_model_nodes;
for (auto&& node_ptr : call->get_ordered_ops()) {
this_model_nodes.insert(node_ptr.get());
@@ -1348,7 +1378,7 @@
LOG_DEBUG("Find orig parameter for " << node);
auto& orig_param = proto_parameters.at(pkey);
auto this_param = std::dynamic_pointer_cast<PPtr::element_type>(node);
func.param_call_to_proto[this_param] = orig_param;
func.param_call_to_proto[SubgParam(subg_ref, this_param)] = orig_param;
}
}
}
@@ -1386,14 +1416,16 @@ void Partitioner::matchResults(const std::string& func_name) {

// Now walk all submodels and match parameters with the same key
// (yes, including the first one)
for (auto&& call : model_group) {
for (std::size_t call_idx = 0; call_idx < model_group.size(); ++call_idx) {
auto call = model_group[call_idx];
auto subg_ref = func.refs[call_idx];
for (auto&& node : call->get_ordered_ops()) {
if (ov::op::util::is_output(node)) {
auto&& port = node->input(0).get_source_output();
RKey rkey = {layer_to_prototype.at(port.get_node()->get_friendly_name()), port.get_index()};
auto& orig_result = proto_results.at(rkey);
auto this_result = std::dynamic_pointer_cast<RPtr::element_type>(node);
func.result_call_to_proto[this_result] = orig_result;
func.result_call_to_proto[SubgResult(subg_ref, this_result)] = orig_result;
}
}
}
@@ -1517,8 +1549,8 @@ void Partitioner::matchRepeatedSubgraphs(const std::string& func_name) {
funcall._gflops = this_sg._gflops; // duplicated code again!
funcall._ops = this_sg._ops; // duplicated code again!
funcall._avoid_list = this_sg._avoid_list; // duplicated code again!
rearrange_to_function_protocol(body_params, funcall._parameters, func_ggg.param_call_to_proto);
rearrange_to_function_protocol(body_results, funcall._results, func_ggg.result_call_to_proto);
rearrange_to_function_protocol(this_sg, body_params, funcall._parameters, func_ggg.param_call_to_proto);
rearrange_to_function_protocol(this_sg, body_results, funcall._results, func_ggg.result_call_to_proto);

auto func_iter = P.functions.find(func_name);
NPUW_ASSERT(func_iter != P.functions.end());
@@ -1883,7 +1915,7 @@ void Partitioner::finalizeLinks() {
auto& params = P.functions.at(sg_desc._funcall)._model->get_parameters();
auto& proto = func_pipeline_type == FunctionPipelineType::CWAI
? ptr // no protos in the CWAI case..
: all_functions.at(sg_desc._funcall).param_call_to_proto.at(ptr);
: all_functions.at(sg_desc._funcall).param_call_to_proto.at(SubgParam(sg_desc, ptr));
auto param_iter = std::find(params.begin(), params.end(), proto);
NPUW_ASSERT(param_iter != params.end());
return std::distance(params.begin(), param_iter);
@@ -1904,7 +1936,7 @@
auto& results = P.functions.at(sg_desc._funcall)._model->get_results();
auto& proto = func_pipeline_type == FunctionPipelineType::CWAI
? ptr // no protos in the CWAI case...
: all_functions.at(sg_desc._funcall).result_call_to_proto.at(ptr);
: all_functions.at(sg_desc._funcall).result_call_to_proto.at(SubgResult(sg_desc, ptr));
auto result_iter = std::find(results.begin(), results.end(), proto);
NPUW_ASSERT(result_iter != results.end());
return std::distance(results.begin(), result_iter);
8 changes: 5 additions & 3 deletions src/plugins/intel_npu/tools/single-image-test/main.cpp
@@ -1200,7 +1200,8 @@ bool computeRRMSE(const ov::Tensor& output, const ov::Tensor& reference) {

double rrmseLoss = sqrt(error / sum);

std::cout << "RRMSE loss : " << rrmseLoss << " RRMSE threshold : " << FLAGS_rrmse_loss_threshold << std::endl;
std::cout << "RRMSE loss : " << std::fixed << std::setprecision(4) << rrmseLoss
<< " RRMSE threshold : " << FLAGS_rrmse_loss_threshold << std::endl;
return rrmseLoss <= FLAGS_rrmse_loss_threshold;
}

@@ -1267,7 +1268,8 @@ bool computeNRMSE(const ov::Tensor& output, const ov::Tensor& reference) {
double nrmseLoss =
sqrt(error / size) / std::max(0.001f, std::max(maxOutput - minOutput, maxReference - minReference));

std::cout << "NRMSE loss : " << nrmseLoss << " NRMSE threshold : " << FLAGS_nrmse_loss_threshold << std::endl;
std::cout << "NRMSE loss : " << std::fixed << std::setprecision(4) << nrmseLoss
<< " NRMSE threshold : " << FLAGS_nrmse_loss_threshold << std::endl;
return nrmseLoss <= FLAGS_nrmse_loss_threshold;
}

@@ -1319,7 +1321,7 @@ bool testPSNR(const TensorMap& outputs, const TensorMap& references, const int d

auto result = utils::runPSNRMetric(actOutput, refOutput, dstHeight, dstWidth, scaleBorder, normalizedImage);

if (std::fabs(result - FLAGS_psnr_reference) > FLAGS_psnr_tolerance) {
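// One-sided check: the test fails only when the measured PSNR falls below the reference
// value by more than the tolerance; a result above the reference always passes.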
if (FLAGS_psnr_reference - result > FLAGS_psnr_tolerance) {
std::cout << "Absolute difference between actual value " << result << " and reference value "
<< FLAGS_psnr_reference << " larger then tolerance " << FLAGS_psnr_tolerance << std::endl;
return false;
Expand Down
