Skip to content

Commit

Permalink
Add K=0 check to MatMul<float>::Compute() specialization. (#21803)
Browse files Browse the repository at this point in the history
Add K=0 check to `MatMul<float>::Compute()` specialization.
Add unit test to cover both primary template and float specialization.
  • Loading branch information
edgchen1 authored Aug 21, 2024
1 parent 0e827c2 commit fb9ce18
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 10 deletions.
8 changes: 8 additions & 0 deletions onnxruntime/core/providers/cpu/math/matmul.cc
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,14 @@ Status MatMul<float>::Compute(OpKernelContext* ctx) const {
if (y->Shape().Size() == 0)
return Status::OK();

if (helper.K() == 0) {
// When K == 0 (A is [M, 0] and B is [0, N]) both inputs are empty, but the
// output is not: every element of Y must be set to zero.
auto output_span = y->MutableDataAsSpan<float>();
std::fill(output_span.begin(), output_span.end(), float{});
return Status::OK();
}

const auto* a_data = a->Data<float>();
const auto* b_data = b ? b->Data<float>() : nullptr;
auto* y_data = y->MutableData<float>();
Expand Down
26 changes: 16 additions & 10 deletions onnxruntime/test/providers/cpu/math/matmul_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -219,18 +219,16 @@ TEST(MathOpTest, MatMulUint64Type) {
RunMatMulTest<uint64_t>(9);
}

TEST(MathOpTest, MatMul_ZeroK) {
template <typename T>
void RunMatMulZeroKTest() {
// test with empty inputs and zero filled output
constexpr const std::array<float, 0> empty_input{};
const std::vector<float> expected_output{0, 0, 0, 0,
0, 0, 0, 0,
0, 0, 0, 0,
0, 0, 0, 0};
OpTester test("MatMul", 14);
constexpr const std::array<T, 0> empty_input{};
const std::vector<T> expected_output(4 * 4, T{});
OpTester test("MatMul", 13);

test.AddInput<float>("A", {4, 0}, empty_input);
test.AddInput<float>("B", {0, 4}, empty_input);
test.AddOutput<float>("Y", {4, 4}, expected_output);
test.AddInput<T>("A", {4, 0}, empty_input);
test.AddInput<T>("B", {0, 4}, empty_input);
test.AddOutput<T>("Y", {4, 4}, expected_output);

// No special case is implemented.
test.ConfigExcludeEps({kCoreMLExecutionProvider, kNnapiExecutionProvider,
Expand All @@ -240,6 +238,14 @@ TEST(MathOpTest, MatMul_ZeroK) {
.RunWithConfig();
}

// Runs the shared zero-K MatMul test (RunMatMulZeroKTest) with float elements.
// NOTE(review): presumably this is the case that exercises the
// MatMul<float>::Compute() specialization on the CPU provider - confirm
// against the kernel registration.
TEST(MathOpTest, MatMulZeroKFloatType) {
RunMatMulZeroKTest<float>();
}

// Runs the shared zero-K MatMul test (RunMatMulZeroKTest) with int32_t elements.
// NOTE(review): presumably int32_t is handled by the primary MatMul<T>
// template rather than the float specialization - confirm against the
// kernel registration.
TEST(MathOpTest, MatMulZeroKInt32Type) {
RunMatMulZeroKTest<int32_t>();
}

#if defined(USE_CUDA) || defined(USE_ROCM)
TEST(MathOpTest, MatMul_Float16) {
#ifdef USE_CUDA
Expand Down

0 comments on commit fb9ce18

Please sign in to comment.