Skip to content

Commit

Permalink
Fix Brgemm shape inference
Browse files Browse the repository at this point in the history
  • Loading branch information
v-Golubev committed Sep 18, 2024
1 parent c86c9ae commit f8b7baa
Showing 1 changed file with 3 additions and 10 deletions.
13 changes: 3 additions & 10 deletions src/common/snippets/src/shape_inference/shape_infer_instances.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -197,16 +197,9 @@ Result BrgemmShapeInfer::infer(const std::vector<VectorDimsRef>& input_shapes) {
size_t max_rank = arg0_shape_tmp.size();
VectorDims output_shape(max_rank);
for (size_t i = 0; i < max_rank - 2; ++i) {
if (arg0_shape_tmp[i] == arg1_shape_tmp[i]) {
output_shape[i] = arg0_shape_tmp[i];
} else {
if (arg0_shape_tmp[i] == 1 || utils::is_dynamic_value(arg0_shape_tmp[i]))
output_shape[i] = arg1_shape_tmp[i];
else if (arg1_shape_tmp[i] == 1 || utils::is_dynamic_value(arg1_shape_tmp[i]))
output_shape[i] = arg0_shape_tmp[i];
else
OPENVINO_THROW("Incompatible Brgemm batch dimension");
}
OPENVINO_ASSERT(utils::broadcast_merge_dim(output_shape[i], arg0_shape_tmp[i], arg1_shape_tmp[i]),
"Incompatible MatMul batch dimension. Can't merge dim ", arg0_shape_tmp[i],
" with dim ", arg1_shape_tmp[i], " at index=", i);
}
output_shape[output_shape.size() - 2] = arg0_shape_tmp[arg0_shape_tmp.size() - 2]; // M
output_shape[output_shape.size() - 1] = arg1_shape_tmp[arg1_shape_tmp.size() - 1]; // N
Expand Down

0 comments on commit f8b7baa

Please sign in to comment.