Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NPUW] extend DQ & PMM processing and make reduceSum not to keep axis #26779

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ DQMatMulGQ2i::DQMatMulGQ2i(Context::Ref ctx) {
auto qcvtw = opp::wrap_type<ov::op::v0::Convert>({qweight});
auto qmuls = opp::wrap_type<ov::op::v1::Multiply>({qcvtw, qcoeff});
auto qreshp = opp::wrap_type<ov::op::v1::Reshape>({qmuls, opp::any_input()});
auto qcvtr = opp::wrap_type<ov::op::v0::Convert>({qreshp});
auto qcvtr = opp::optional<ov::op::v0::Convert>({qreshp->output(0)});
auto qmmi = opp::any_input();
auto qmm = opp::wrap_type<ov::op::v0::MatMul>({qmmi, qcvtr});

Expand Down Expand Up @@ -409,13 +409,18 @@ DQMatMulGQ2i::DQMatMulGQ2i(Context::Ref ctx) {
auto rshp_ccat = std::make_shared<ov::op::v1::Reshape>(scaled, rshp_ccat_c, false);

auto reduce_axis = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{}, 1);
auto reduce = std::make_shared<ov::op::v1::ReduceSum>(rshp_ccat, reduce_axis, true);
// Make reduceSum not to keep axis because then it will convert to poolings in compiler.
// Otherwise reduceSum will convert to the convolution which is less efficient than poolings.
auto reduce = std::make_shared<ov::op::v1::ReduceSum>(rshp_ccat, reduce_axis, false);

auto rshp_out_c = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{3}, out_shape);
auto rshp_out = std::make_shared<ov::op::v1::Reshape>(reduce, rshp_out_c, false);

// Convert the result to f32 to maintain the graph contracts. FIXME should be avoided
auto out = std::make_shared<ov::op::v0::Convert>(rshp_out, ov::element::f32);
// Convert the result to f32 to maintain the graph contracts if required.
std::shared_ptr<ov::Node> out = rshp_out;
if (matched_matmul->get_element_type() == ov::element::f32) {
out = std::make_shared<ov::op::v0::Convert>(rshp_out, ov::element::f32);
}

// Now.. Reconnect the matmul readers to the new output (reducesum)
for (auto&& r : matched_matmul->output(0).get_target_inputs()) {
Expand Down Expand Up @@ -746,7 +751,7 @@ void mergeParallelMatMuls(const std::shared_ptr<ov::Model>& m, Context& ctx) {
auto new_cvt = std::make_shared<ov::op::v0::Convert>(new_w, new_s->get_element_type());

std::shared_ptr<ov::Node> new_mul = std::make_shared<ov::op::v1::Multiply>(new_cvt, new_s);
if (new_s->get_element_type() == ov::element::f16) {
if ((new_s->get_element_type() == ov::element::f16) && (orig_multiply.get_element_type() == ov::element::f32)) {
new_mul = std::make_shared<ov::op::v0::Convert>(new_mul, ov::element::f32);
}
auto new_w_shape = new_w->get_shape();
Expand Down
Loading