Skip to content

Commit

Permalink
NPUW Spatial: Enable PMM for spatial subgraphs
Browse files Browse the repository at this point in the history
  • Loading branch information
dmatveev committed Oct 3, 2024
1 parent d99450a commit c8975f1
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1744,6 +1744,7 @@ void Partitioner::optimize(const std::string& func_name) {
// Regardless of DQ setting, run this first
{
ov::npuw::patterns::opt::Context ctx;
ctx.is_spatial = f._spatial.has_value();
ctx.pmm_dims = cfg.get<::intel_npu::NPUW_PMM>();

// Run Head/Tail passes
Expand Down Expand Up @@ -1890,6 +1891,8 @@ void Partitioner::optimize(const std::string& func_name) {

// Run "dynamic quantization"
ov::npuw::patterns::opt::Context ctx;
ctx.is_spatial = f._spatial.has_value();

ov::pass::GraphRewrite rewr;
rewr.add_matcher<ov::npuw::patterns::opt::DQMatMulCWi>();
rewr.add_matcher<ov::npuw::patterns::opt::DQMatMulGQi>(std::ref(ctx));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -687,8 +687,13 @@ DQParMMGQ::DQParMMGQ(Context::Ref ctx) {

auto qmmi_shape = node_to_output.at(qmm).get_shape();

if (qmmi_shape.size() != 3 || qmmi_shape[0] != 1 || qmmi_shape[1] != 1) {
// Limit token to 1-token shapes only (prefill requires its own tranformation)
if (qmmi_shape.size() != 3 || qmmi_shape[0] != 1) {
// Not handling such cases
return false;
}

if (qmmi_shape[1] != 1 && !ctx.get().is_spatial) {
// For non 1-token cases, do transformation if and only if and only if the block is spatial
return false;
}

Expand All @@ -709,9 +714,12 @@ void mergeParallelMatMuls(const std::shared_ptr<ov::Model>& m, Context& ctx) {
continue;
}
ov::Output<ov::Node> orig_multiply;

std::size_t axis_to_concat = -1;
std::tie(orig_multiply, axis_to_concat) = mul_to_mms.first;

const ov::Shape orig_act_shape = orig_multiply.get_shape();

if (!util::is_set(axis_to_concat, ctx.pmm_dims)) {
LOG_VERB("Parallel MatMuls found, but fusion over dim " << axis_to_concat << " is not enabled");
continue;
Expand Down Expand Up @@ -776,7 +784,7 @@ void mergeParallelMatMuls(const std::shared_ptr<ov::Model>& m, Context& ctx) {
auto this_slice_end =
std::make_shared<ov::op::v0::Constant>(ov::element::i32,
ov::Shape{3},
S{1, 1, offset + this_orig_wshape[axis_to_concat]});
S{1, orig_act_shape[1], offset + this_orig_wshape[axis_to_concat]});
auto this_slice_step = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{3}, S{1, 1, 1});
auto this_slice =
std::make_shared<ov::op::v8::Slice>(new_mm, this_slice_start, this_slice_end, this_slice_step);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class DQMatMulCWi : public ov::pass::MatcherPass {

struct Context {
std::string pmm_dims;
bool is_spatial = false;

using PPtr = std::shared_ptr<ov::op::v0::Parameter>;
using NPtr = std::shared_ptr<ov::Node>;
Expand Down

0 comments on commit c8975f1

Please sign in to comment.