diff --git a/src/common/snippets/src/shape_inference/shape_infer_instances.cpp b/src/common/snippets/src/shape_inference/shape_infer_instances.cpp index c456b6e2ba0254..1c06e5e4df17e4 100644 --- a/src/common/snippets/src/shape_inference/shape_infer_instances.cpp +++ b/src/common/snippets/src/shape_inference/shape_infer_instances.cpp @@ -197,16 +197,9 @@ Result BrgemmShapeInfer::infer(const std::vector& input_shapes) { size_t max_rank = arg0_shape_tmp.size(); VectorDims output_shape(max_rank); for (size_t i = 0; i < max_rank - 2; ++i) { - if (arg0_shape_tmp[i] == arg1_shape_tmp[i]) { - output_shape[i] = arg0_shape_tmp[i]; - } else { - if (arg0_shape_tmp[i] == 1 || utils::is_dynamic_value(arg0_shape_tmp[i])) - output_shape[i] = arg1_shape_tmp[i]; - else if (arg1_shape_tmp[i] == 1 || utils::is_dynamic_value(arg1_shape_tmp[i])) - output_shape[i] = arg0_shape_tmp[i]; - else - OPENVINO_THROW("Incompatible Brgemm batch dimension"); - } + OPENVINO_ASSERT(utils::broadcast_merge_dim(output_shape[i], arg0_shape_tmp[i], arg1_shape_tmp[i]), + "Incompatible MatMul batch dimension. Can't merge dim ", arg0_shape_tmp[i], + " with dim ", arg1_shape_tmp[i], " at index=", i); } output_shape[output_shape.size() - 2] = arg0_shape_tmp[arg0_shape_tmp.size() - 2]; // M output_shape[output_shape.size() - 1] = arg1_shape_tmp[arg1_shape_tmp.size() - 1]; // N