-
Notifications
You must be signed in to change notification settings - Fork 2.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
NPUW: Adding new config option to reshape weights #25691
Changes from 57 commits
d5fc909
5b7601c
0cf6a28
7a7a95c
cfb0b4b
36fde73
106a4c8
a938c39
ce72c44
e55169e
ec04cb2
ce76d55
a6a1abe
d649336
c25f928
0fb3efb
b300fe3
834cdfa
650ec8b
f37a9a6
1c09664
db80df5
cd90ae3
d75cc58
4eea38d
b1f3d53
72523b4
7b27abe
b59160b
3aee8bb
87f59ce
21bfcff
794c453
3055d6a
30ac9b6
ebaec3f
3bc472c
1380cc6
4a6551b
fc64a0e
7120553
98931cb
317a874
0114d19
73060e4
f208181
01d867b
13274c8
7bfa2ae
ab7d61a
7d0b326
957b2f3
f56e601
a09a867
462d27c
8f5de02
ed3682b
e1abf9d
73de9d9
95cbba9
7aeea0a
0d1b68f
fee3370
236d4f9
1af1103
02758b5
51094f8
3781f84
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -1606,29 +1606,32 @@ void Partitioner::decompressionCutOff(const std::string& func_name) { | |||||
{ | ||||||
LOG_BLOCK(); | ||||||
|
||||||
bool enable_transpose = cfg.get<::intel_npu::NPUW_TRANSPOSE_WEIGHTS>(); | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
|
||||||
ov::npuw::patterns::DCOFFParams params_to; | ||||||
|
||||||
ov::pass::GraphRewrite rewr; | ||||||
|
||||||
// Old LLaMa-v2 patterns (Symmetric) | ||||||
rewr.add_matcher<ov::npuw::patterns::SymmNoZP::DCOFFPassMatMul>(dcoff_mode, dcoff_type, std::ref(params_to)) | ||||||
rewr.add_matcher<ov::npuw::patterns::SymmNoZP::DCOFFPassMatMul>(dcoff_mode, dcoff_type, std::ref(params_to), enable_transpose) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. as the configuration parameters go first, and the "output" remapping goes last, probably it should be
Suggested change
|
||||||
->build(); | ||||||
rewr.add_matcher<ov::npuw::patterns::SymmNoZP::DCOFFPassGather>(dcoff_mode, dcoff_type, std::ref(params_to)) | ||||||
rewr.add_matcher<ov::npuw::patterns::SymmNoZP::DCOFFPassGather>(dcoff_mode, dcoff_type, std::ref(params_to), enable_transpose) | ||||||
->build(); | ||||||
|
||||||
// ChatGLM (GPTQ) and New LLaMa-v2 patterns (Symmetric) | ||||||
rewr.add_matcher<ov::npuw::patterns::SymmZP::DCOFFPassReshape1>(dcoff_mode, dcoff_type, std::ref(params_to)) | ||||||
rewr.add_matcher<ov::npuw::patterns::SymmZP::DCOFFPassReshape1>(dcoff_mode, dcoff_type, std::ref(params_to), enable_transpose) | ||||||
->build(); | ||||||
rewr.add_matcher<ov::npuw::patterns::SymmZP::DCOFFPassConvert1>(dcoff_mode, dcoff_type, std::ref(params_to)) | ||||||
rewr.add_matcher<ov::npuw::patterns::SymmZP::DCOFFPassConvert1>(dcoff_mode, dcoff_type, std::ref(params_to), enable_transpose) | ||||||
->build(); | ||||||
|
||||||
// LLaMaGPTQ | ||||||
rewr.add_matcher<ov::npuw::patterns::SymmZP::DCOFFPassReshape2>(dcoff_mode, dcoff_type, std::ref(params_to)); | ||||||
rewr.add_matcher<ov::npuw::patterns::SymmZP::DCOFFPassReshape2>(dcoff_mode, dcoff_type, std::ref(params_to), enable_transpose); | ||||||
|
||||||
// Phi-3 4SymW16A/GPTQ | ||||||
rewr.add_matcher<ov::npuw::patterns::SymmZP::DCOFFPassCWAI3>(dcoff_mode, dcoff_type, std::ref(params_to)); | ||||||
// Phi-3 4SymW16A | ||||||
rewr.add_matcher<ov::npuw::patterns::SymmZP::DCOFFPassReshape3>(dcoff_mode, dcoff_type, std::ref(params_to), enable_transpose); | ||||||
|
||||||
// Asymmetric zeropoints | ||||||
rewr.add_matcher<ov::npuw::patterns::AsymmZP::DCOFFPassReshape>(dcoff_mode, dcoff_type, std::ref(params_to)); | ||||||
rewr.add_matcher<ov::npuw::patterns::AsymmZP::DCOFFPassReshape>(dcoff_mode, dcoff_type, std::ref(params_to), enable_transpose); | ||||||
|
||||||
rewr.run_on_model(f._model); | ||||||
|
||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,6 +21,97 @@ | |
namespace ov { | ||
namespace npuw { | ||
|
||
namespace pattern_utils { | ||
|
||
std::shared_ptr<ov::op::v0::MatMul> find_matmul_downwards(const std::shared_ptr<ov::Node>& start_node) { | ||
std::shared_ptr<ov::Node> current_node = start_node; | ||
while (current_node) { | ||
// Check if the current node is a MatMul | ||
if (auto matmul = std::dynamic_pointer_cast<ov::op::v0::MatMul>(current_node)) { | ||
return matmul; | ||
} | ||
// Move to the next node in the path if there is one | ||
if (!current_node->outputs().empty()) { | ||
auto output = current_node->outputs().at(0); | ||
if (!output.get_target_inputs().empty()) { | ||
current_node = output.get_target_inputs().begin()->get_node()->shared_from_this(); | ||
} else { | ||
// No further outputs, end the search | ||
break; | ||
} | ||
} else { | ||
// No outputs, end the search | ||
break; | ||
} | ||
} | ||
Comment on lines
+30
to
+48
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. There should be no loop; it must be a straight direct link from your There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. That might not be true for every case. After the root node (a Reshape, for example) there may be a Convert present. |
||
return nullptr; // MatMul not found | ||
} | ||
|
||
|
||
// Finds the MatMul reached from a matched pattern.
//
// Takes the first entry of the matcher's pattern value map and walks the
// graph downwards from that node via find_matmul_downwards().
//
// Returns the MatMul that was found, or nullptr when the map is empty or the
// walk ends without meeting a MatMul. Both failure cases are logged.
//
// NOTE(review): begin() yields whatever entry the map orders first, which is
// not necessarily the pattern's root node -- confirm this starting point is
// the intended one.
std::shared_ptr<ov::op::v0::MatMul> get_root_matmul(ov::pass::pattern::Matcher& m) {
    auto& node_to_output = m.get_pattern_value_map();

    std::shared_ptr<ov::op::v0::MatMul> matmul;
    if (!node_to_output.empty()) {
        auto start_node = node_to_output.begin()->second.get_node_shared_ptr();
        matmul = find_matmul_downwards(start_node);
    }

    // Report the failure whether the map was empty or the search came back
    // empty-handed, so a missing MatMul never passes silently.
    if (!matmul) {
        LOG_DEBUG("NO MATMUL FOUND!");
    }
    return matmul;
}
|
||
// Decides whether the second input of a MatMul should be transposed.
//
// matmul_node: the MatMul to inspect. May be null: get_root_matmul() returns
//              nullptr when no MatMul is found, and its result is passed here.
//
// Returns true when input 1 has rank >= 2 and its first dimension is not the
// largest one. Returns false otherwise, including for a null node.
//
// NOTE(review): a reviewer suggested asserting on a null pointer instead. An
// early return is used because callers in this file do pass the unchecked
// result of get_root_matmul().
// NOTE(review): a reviewer doubted that the "largest dimension first" test is
// the check that is supposed to be here -- confirm the intended criterion.
bool transpose_required(const std::shared_ptr<ov::op::v0::MatMul>& matmul_node) {
    if (!matmul_node) {
        LOG_DEBUG("NOT a MATMUL NODE!");
        // Without this return the null pointer is dereferenced just below.
        return false;
    }

    // Shape of the second input to the MatMul node.
    const auto& input_shape = matmul_node->input_value(1).get_shape();
    if (input_shape.size() < 2) {
        return false;  // Nothing to transpose for rank < 2
    }

    // Transpose is required when the highest dimension is not at index 0.
    const size_t max_dim = *std::max_element(input_shape.begin(), input_shape.end());
    return input_shape[0] != max_dim;
}
|
||
void transpose_param_shape(std::shared_ptr<ov::op::v0::Parameter>& param) { | ||
auto partial_shape = param->get_partial_shape(); | ||
|
||
// Ensure the shape is static before proceeding | ||
if (partial_shape.is_static()) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. replace with ASSERT. |
||
auto shape = partial_shape.to_shape(); | ||
|
||
// Check if the shape is 2D or 3D and needs transposing | ||
if (shape.size() == 2 && shape[0] < shape[1]) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I believe we shouldn't look at There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. (Stopped review at this point) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Rectified! |
||
// For 2D shapes, swap the dimensions if the second dimension is larger | ||
std::swap(shape[0], shape[1]); | ||
} else if (shape.size() == 3) { | ||
// For 3D shapes, bring the largest dimension to the front | ||
auto max_dim_it = std::max_element(shape.begin(), shape.end()); | ||
if (max_dim_it != shape.begin()) { | ||
std::rotate(shape.begin(), max_dim_it, shape.end()); | ||
} | ||
} | ||
|
||
// Set the new shape to the parameter | ||
param->set_partial_shape(ov::PartialShape(shape)); | ||
LOG_DEBUG("Modifying the shape of: " << param << " to " << param->get_partial_shape()); | ||
} | ||
} | ||
|
||
// Placeholder for transposing a closure tensor.
//
// The transposition itself is not implemented yet. The previous empty body
// flowed off the end of a value-returning function, which is undefined
// behaviour; this version fails explicitly instead.
//
// Returns a default-constructed ov::Tensor if the assertion does not stop
// execution.
// NOTE(review): presumably NPUW_ASSERT throws on failure -- confirm against
// its definition.
ov::Tensor transpose_tensor(const ov::Tensor&) {
    NPUW_ASSERT(false && "transpose_tensor is not implemented");
    return ov::Tensor();
}
|
||
} // namespace pattern_utils | ||
|
||
namespace patterns { | ||
|
||
namespace opp = ov::pass::pattern; | ||
|
@@ -73,6 +164,10 @@ ClosureRemap build_remap(const Function& fbody, const DCOFFParams& params_to) { | |
LOG_DEBUG("Checking the function parameter " << param); | ||
LOG_BLOCK(); | ||
|
||
if(params_to.transpose_required.find(param) != params_to.transpose_required.end()) { | ||
m.transpose_indices.push_back(i - fbody._param_offset); | ||
} | ||
|
||
// First find among scale factors... | ||
auto pscale_iter = params_to.scales.find(param); | ||
auto pzerop_iter = params_to.zerops_asymm.find(param); | ||
|
@@ -122,15 +217,48 @@ void apply_remap(Subgraph& fcall, const ClosureRemap& m) { | |
// reserve a new_scales vector to have the same size, filled with | ||
// empty tensors by default. | ||
for (auto&& i : m.closure_remap) { | ||
new_closure.push_back(fcall._closure[i]); | ||
// Check if the index is marked for transposition | ||
if (std::find(m.transpose_indices.begin(), m.transpose_indices.end(), i) != m.transpose_indices.end()) { | ||
// Transpose the tensor before adding it to new_closure | ||
new_closure.push_back(pattern_utils::transpose_tensor(fcall._closure[i])); | ||
} else { | ||
// Add the original tensor to new_closure | ||
new_closure.push_back(fcall._closure[i]); | ||
} | ||
|
||
// Handle scale remap | ||
auto scale_iter = m.scale_remap.find(i); | ||
new_scales.push_back(scale_iter != m.scale_remap.end() ? fcall._closure[scale_iter->second] : ov::Tensor()); | ||
// Check for asymmetric zero points and add them to new_zerops | ||
if (scale_iter != m.scale_remap.end()) { | ||
// Check if the scale index is marked for transposition | ||
if (std::find(m.transpose_indices.begin(), m.transpose_indices.end(), scale_iter->second) != m.transpose_indices.end()) { | ||
// Transpose the tensor before adding it to new_scales | ||
new_scales.push_back(pattern_utils::transpose_tensor(fcall._closure[scale_iter->second])); | ||
} else { | ||
// Add the original tensor to new_scales | ||
new_scales.push_back(fcall._closure[scale_iter->second]); | ||
} | ||
} else { | ||
new_scales.push_back(ov::Tensor()); | ||
} | ||
|
||
// Handle zero point remap | ||
auto zerop_iter = m.zerop_remap.find(i); | ||
const auto& zerop = zerop_iter != m.zerop_remap.end() ? fcall._closure[zerop_iter->second] : m.zero_points[i]; | ||
new_zerops.push_back(zerop); | ||
if (zerop_iter != m.zerop_remap.end()) { | ||
// Check if the zero point index is marked for transposition | ||
if (std::find(m.transpose_indices.begin(), m.transpose_indices.end(), zerop_iter->second) != m.transpose_indices.end()) { | ||
// Transpose the tensor before adding it to new_zerops | ||
new_zerops.push_back(pattern_utils::transpose_tensor(fcall._closure[zerop_iter->second])); | ||
} else { | ||
// Add the original tensor to new_zerops | ||
new_zerops.push_back(fcall._closure[zerop_iter->second]); | ||
} | ||
} else { | ||
// Add the zero point tensor from the closure remap | ||
const auto& zerop = m.zero_points[i]; | ||
new_zerops.push_back(zerop); | ||
} | ||
} | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This looks like too much; you can form the tensors first and then transpose only the affected ones. It may take one more separate loop, but it will make the code clearer. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Changed! |
||
fcall._closure = std::move(new_closure); | ||
fcall._scales = std::move(new_scales); | ||
fcall._zerops = std::move(new_zerops); | ||
|
@@ -192,10 +320,14 @@ void finalize_remap(Function& fbody, const ClosureRemap& m) { | |
// its Parameter B). | ||
namespace SymmNoZP { | ||
|
||
DCOFFPassBase::DCOFFPassBase(DCOffMode dcoff_mode, ov::element::Type dcoff_type, DCOFFParamRef pref) | ||
DCOFFPassBase::DCOFFPassBase(DCOffMode dcoff_mode, | ||
ov::element::Type dcoff_type, | ||
DCOFFParamRef pref, | ||
bool enable_transpose) | ||
: m_dcoff_mode(dcoff_mode), | ||
m_dcoff_type(dcoff_type), | ||
m_params_to(pref) {} | ||
m_params_to(pref), | ||
m_enable_transpose(enable_transpose) {} | ||
|
||
void DCOFFPassBase::build() { | ||
paramA = opp::wrap_type<ov::op::v0::Parameter>(); | ||
|
@@ -327,10 +459,14 @@ namespace SymmZP { | |
// V > | ||
// | ||
|
||
DCOFFPassBase::DCOFFPassBase(DCOffMode dcoff_mode, ov::element::Type dcoff_type, DCOFFParamRef pref) | ||
DCOFFPassBase::DCOFFPassBase(DCOffMode dcoff_mode, | ||
ov::element::Type dcoff_type, | ||
DCOFFParamRef pref, | ||
bool enable_transpose) | ||
: m_dcoff_mode(dcoff_mode), | ||
m_dcoff_type(dcoff_type), | ||
m_params_to(pref) {} | ||
m_params_to(pref), | ||
m_enable_transpose(enable_transpose) {} | ||
|
||
void DCOFFPassBase::build() { | ||
paramA = opp::wrap_type<ov::op::v0::Parameter>(); | ||
|
@@ -342,6 +478,7 @@ void DCOFFPassBase::build() { | |
mulply = opp::wrap_type<ov::op::v1::Multiply>({subtr, paramC}); | ||
} | ||
|
||
|
||
bool DCOFFPassBase::matcher_callback(ov::pass::pattern::Matcher& m) { | ||
auto& node_to_output = m.get_pattern_value_map(); | ||
auto matched_nodeA = node_to_output.at(paramA).get_node_shared_ptr(); | ||
|
@@ -362,6 +499,17 @@ bool DCOFFPassBase::matcher_callback(ov::pass::pattern::Matcher& m) { | |
LOG_DEBUG("Matched: " << matched_paramA << ", set element type to " << m_dcoff_type); | ||
matched_paramA->set_element_type(m_dcoff_type); | ||
|
||
auto matched_MM = pattern_utils::get_root_matmul(m); | ||
const bool need_transpose = pattern_utils::transpose_required(matched_MM); | ||
if (m_enable_transpose && need_transpose) { | ||
m_params_to.get().transpose_required.insert(matched_paramA); | ||
m_params_to.get().transpose_required.insert(matched_paramC); | ||
pattern_utils::transpose_param_shape(matched_paramA); | ||
pattern_utils::transpose_param_shape(matched_paramC); | ||
matched_MM->set_transpose_b(true); | ||
} | ||
|
||
|
||
if (m_dcoff_mode == DCOffMode::CAST_SCALE) { | ||
NPUW_ASSERT(m_dcoff_type == ov::element::f16); | ||
|
||
|
@@ -464,7 +612,7 @@ void DCOFFPassConvert1::reconnect_root(ov::pass::pattern::Matcher& m) { | |
// V > | ||
// | ||
|
||
DCOFFPassReshape2::DCOFFPassReshape2(DCOffMode dcoff_mode, ov::element::Type dcoff_type, DCOFFParamRef pref) { | ||
DCOFFPassReshape2::DCOFFPassReshape2(DCOffMode dcoff_mode, ov::element::Type dcoff_type, DCOFFParamRef pref, bool enable_transpose) { | ||
auto paramA = opp::wrap_type<ov::op::v0::Parameter>(); | ||
auto constB = opp::wrap_type<ov::op::v0::Constant>(); | ||
auto paramC = opp::wrap_type<ov::op::v0::Parameter>(); | ||
|
@@ -555,7 +703,7 @@ DCOFFPassReshape2::DCOFFPassReshape2(DCOffMode dcoff_mode, ov::element::Type dco | |
// V > | ||
// Convert | ||
|
||
DCOFFPassCWAI3::DCOFFPassCWAI3(DCOffMode dcoff_mode, ov::element::Type dcoff_type, DCOFFParamRef pref) { | ||
DCOFFPassReshape3::DCOFFPassReshape3(DCOffMode dcoff_mode, ov::element::Type dcoff_type, DCOFFParamRef pref, bool enable_transpose) { | ||
auto paramA = opp::wrap_type<ov::op::v0::Parameter>(); | ||
auto paramC = opp::wrap_type<ov::op::v0::Parameter>(); | ||
auto cvtA = opp::wrap_type<ov::op::v0::Convert>({paramA}); | ||
|
@@ -847,7 +995,7 @@ namespace AsymmZP { | |
// : > | ||
// V > | ||
// | ||
DCOFFPassReshape::DCOFFPassReshape(DCOffMode dcoff_mode, ov::element::Type dcoff_type, DCOFFParamRef pref) { | ||
DCOFFPassReshape::DCOFFPassReshape(DCOffMode dcoff_mode, ov::element::Type dcoff_type, DCOFFParamRef pref, bool enable_transpose) { | ||
auto paramA = opp::wrap_type<ov::op::v0::Parameter>(); | ||
auto paramB = opp::wrap_type<ov::op::v0::Parameter>(); | ||
auto paramC = opp::wrap_type<ov::op::v0::Parameter>(); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.