// Patch summary (unified diff touching two files; text before the first
// `diff --git` header is commentary only).
//
// 1. onnxruntime/core/providers/cpu/math/matmul.cc, hunk inside MatMul::Compute:
//    directly after the existing early return for a zero-sized output, adds a
//    branch for helper.K() == 0. That branch takes the output tensor's mutable
//    span, std::fill()s it with float{} and returns Status::OK() before the
//    a_data / b_data / y_data pointers are obtained.
//
// 2. onnxruntime/test/providers/cpu/math/matmul_test.cc:
//    - TEST(MathOpTest, MatMul_ZeroK) is replaced by a helper function
//      RunMatMulZeroKTest(). It feeds empty inputs A {4, 0} and B {0, 4} and
//      expects output Y {4, 4}; the expected data changes from sixteen literal
//      zeros to `expected_output(4 * 4, T{})`.
//    - The second argument of OpTester test("MatMul", ...) changes from 14 to 13
//      (presumably the opset version -- confirm).
//    - Two new cases, MatMulZeroKFloatType and MatMulZeroKInt32Type, call the helper.
//
// NOTE(review): template argument lists appear to have been stripped from this
// copy (bare `template`, `std::array`, `std::vector`, `MutableDataAsSpan()`,
// `RunMatMulZeroKTest()`). Presumably the helper is parameterized on T and
// instantiated for float and int32_t -- confirm against the original commit.
// NOTE(review): the patch is collapsed onto two physical lines, so the hunk
// headers no longer match a line layout; restore the original line breaks
// before trying to apply it.
diff --git a/onnxruntime/core/providers/cpu/math/matmul.cc b/onnxruntime/core/providers/cpu/math/matmul.cc index e95ac44af158..6a71283f9dbd 100644 --- a/onnxruntime/core/providers/cpu/math/matmul.cc +++ b/onnxruntime/core/providers/cpu/math/matmul.cc @@ -238,6 +238,14 @@ Status MatMul::Compute(OpKernelContext* ctx) const { if (y->Shape().Size() == 0) return Status::OK(); + if (helper.K() == 0) { + // When we have (M, 0, N) then the inputs are empty, but the output should + // be filled out with zeros. + auto output_span = y->MutableDataAsSpan(); + std::fill(output_span.begin(), output_span.end(), float{}); + return Status::OK(); + } + const auto* a_data = a->Data(); const auto* b_data = b ? b->Data() : nullptr; auto* y_data = y->MutableData(); diff --git a/onnxruntime/test/providers/cpu/math/matmul_test.cc b/onnxruntime/test/providers/cpu/math/matmul_test.cc index b7ae0a9f0d66..90370560859a 100644 --- a/onnxruntime/test/providers/cpu/math/matmul_test.cc +++ b/onnxruntime/test/providers/cpu/math/matmul_test.cc @@ -219,18 +219,16 @@ TEST(MathOpTest, MatMulUint64Type) { RunMatMulTest(9); } -TEST(MathOpTest, MatMul_ZeroK) { +template +void RunMatMulZeroKTest() { // test with empty inputs and zero filled output - constexpr const std::array empty_input{}; - const std::vector expected_output{0, 0, 0, 0, - 0, 0, 0, 0, - 0, 0, 0, 0, - 0, 0, 0, 0}; - OpTester test("MatMul", 14); + constexpr const std::array empty_input{}; + const std::vector expected_output(4 * 4, T{}); + OpTester test("MatMul", 13); - test.AddInput("A", {4, 0}, empty_input); - test.AddInput("B", {0, 4}, empty_input); - test.AddOutput("Y", {4, 4}, expected_output); + test.AddInput("A", {4, 0}, empty_input); + test.AddInput("B", {0, 4}, empty_input); + test.AddOutput("Y", {4, 4}, expected_output); // No special case is implemented. 
test.ConfigExcludeEps({kCoreMLExecutionProvider, kNnapiExecutionProvider, @@ -240,6 +238,14 @@ TEST(MathOpTest, MatMul_ZeroK) { .RunWithConfig(); } +TEST(MathOpTest, MatMulZeroKFloatType) { + RunMatMulZeroKTest(); +} + +TEST(MathOpTest, MatMulZeroKInt32Type) { + RunMatMulZeroKTest(); +} + #if defined(USE_CUDA) || defined(USE_ROCM) TEST(MathOpTest, MatMul_Float16) { #ifdef USE_CUDA