Skip to content

Commit

Permalink
[ONNX] Added support for dynamic input shapes in com.microsoft.MatMulNBits
Browse files Browse the repository at this point in the history
  • Loading branch information
gkrivor authored Oct 3, 2024
1 parent 17ecf03 commit a57340d
Showing 1 changed file with 27 additions and 6 deletions.
33 changes: 27 additions & 6 deletions src/frontends/onnx/frontend/src/op/com.microsoft/matmulnbits.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,9 @@ ov::OutputVector matmulnbits(const ov::frontend::onnx::Node& node) {
"Expected rank of quantized weights is 3 [N][n_blocks_per_col][blob_size], got: ",
b_quantized.get_partial_shape().rank());
CHECK_VALID_NODE(node,
a.get_element_type() == ov::element::f16 || a.get_element_type() == ov::element::f32,
"Unsupported input A type, accepted FP16, FP32, got: ",
a.get_element_type() == ov::element::f16 || a.get_element_type() == ov::element::f32 ||
a.get_element_type() == ov::element::dynamic,
"Unsupported input A type, accepted dynamic, FP16, FP32, got: ",
a.get_element_type());
CHECK_VALID_NODE(
node,
Expand Down Expand Up @@ -96,7 +97,9 @@ ov::OutputVector matmulnbits(const ov::frontend::onnx::Node& node) {
if (inputs.size() > 5) {
bias = inputs[5];
CHECK_VALID_NODE(node,
bias.get_element_type() == a.get_element_type(),
bias.get_element_type() == a.get_element_type() ||
a.get_element_type() == ov::element::dynamic ||
bias.get_element_type() == ov::element::dynamic,
"Unsupported input bias type, must be equal to input A type, got: ",
bias.get_element_type());
CHECK_VALID_NODE(node,
Expand All @@ -121,17 +124,35 @@ ov::OutputVector matmulnbits(const ov::frontend::onnx::Node& node) {
case 2:
casted_b_shape = ov::Shape{static_cast<size_t>(N * n_blocks_per_col), static_cast<size_t>(blob_size * 4)};
casted_b = std::make_shared<v0::Constant>(ov::element::u2, casted_b_shape, b_const->get_data_ptr());
default_zp = std::make_shared<v0::Constant>(a.get_element_type(), Shape{}, 2);
if (a.get_element_type() != ov::element::dynamic) {
default_zp = std::make_shared<v0::Constant>(a.get_element_type(), Shape{}, 2);
} else {
default_zp =
std::make_shared<v1::ConvertLike>(a,
std::make_shared<v0::Constant>(ov::element::f32, Shape{}, 2.f));
}
break;
case 4:
casted_b_shape = ov::Shape{static_cast<size_t>(N * n_blocks_per_col), static_cast<size_t>(blob_size * 2)};
casted_b = std::make_shared<v0::Constant>(ov::element::u4, casted_b_shape, b_const->get_data_ptr());
default_zp = std::make_shared<v0::Constant>(a.get_element_type(), Shape{}, 8);
if (a.get_element_type() != ov::element::dynamic) {
default_zp = std::make_shared<v0::Constant>(a.get_element_type(), Shape{}, 8);
} else {
default_zp =
std::make_shared<v1::ConvertLike>(a,
std::make_shared<v0::Constant>(ov::element::f32, Shape{}, 8.f));
}
break;
case 8:
casted_b_shape = ov::Shape{static_cast<size_t>(N * n_blocks_per_col), static_cast<size_t>(blob_size)};
casted_b = op::util::reshape(b_const, casted_b_shape);
default_zp = std::make_shared<v0::Constant>(a.get_element_type(), Shape{}, 128);
if (a.get_element_type() != ov::element::dynamic) {
default_zp = std::make_shared<v0::Constant>(a.get_element_type(), Shape{}, 128);
} else {
default_zp =
std::make_shared<v1::ConvertLike>(a,
std::make_shared<v0::Constant>(ov::element::f32, Shape{}, 128.f));
}
break;
default:
FRONT_END_THROW("Unsupported bits count");
Expand Down

0 comments on commit a57340d

Please sign in to comment.