Skip to content

Commit

Permalink
Revert MLFloat16 changes (except checker fix and test name typos)
Browse files Browse the repository at this point in the history
  • Loading branch information
adamreeve committed Jul 29, 2024
1 parent 51ed042 commit 4abbd41
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 118 deletions.
16 changes: 6 additions & 10 deletions onnxruntime/core/providers/cpu/math/element_wise_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -757,11 +757,9 @@ static Status MinMaxMLFloat16(const OpKernel& inst, OpKernelContext* context) {
EigenVectorArrayMap<Eigen::half> output_vec_map(output, num_elements);

if (is_min) {
output_vec_map = input_1_vec_map.template min<Eigen::PropagateNaN>(
static_cast<Eigen::half>(per_iter_bh.ScalarInput0<MLFloat16>()));
output_vec_map = input_1_vec_map.min(static_cast<Eigen::half>(per_iter_bh.ScalarInput0<MLFloat16>()));
} else {
output_vec_map = input_1_vec_map.template max<Eigen::PropagateNaN>(
static_cast<Eigen::half>(per_iter_bh.ScalarInput0<MLFloat16>()));
output_vec_map = input_1_vec_map.max(static_cast<Eigen::half>(per_iter_bh.ScalarInput0<MLFloat16>()));
}
},
[](BroadcastHelper& per_iter_bh) {
Expand All @@ -774,11 +772,9 @@ static Status MinMaxMLFloat16(const OpKernel& inst, OpKernelContext* context) {
EigenVectorArrayMap<Eigen::half> output_vec_map(output, num_elements);

if (is_min) {
output_vec_map = input_0_vec_map.template min<Eigen::PropagateNaN>(
static_cast<Eigen::half>(per_iter_bh.ScalarInput1<MLFloat16>()));
output_vec_map = input_0_vec_map.min(static_cast<Eigen::half>(per_iter_bh.ScalarInput1<MLFloat16>()));
} else {
output_vec_map = input_0_vec_map.template max<Eigen::PropagateNaN>(
static_cast<Eigen::half>(per_iter_bh.ScalarInput1<MLFloat16>()));
output_vec_map = input_0_vec_map.max(static_cast<Eigen::half>(per_iter_bh.ScalarInput1<MLFloat16>()));
}
},
[](BroadcastHelper& per_iter_bh) {
Expand All @@ -794,9 +790,9 @@ static Status MinMaxMLFloat16(const OpKernel& inst, OpKernelContext* context) {
EigenVectorArrayMap<Eigen::half> output_vec_map(output, num_elements);

if (is_min) {
output_vec_map = input_0_vec_map.template min<Eigen::PropagateNaN>(input_1_vec_map);
output_vec_map = input_0_vec_map.min(input_1_vec_map);
} else {
output_vec_map = input_0_vec_map.template max<Eigen::PropagateNaN>(input_1_vec_map);
output_vec_map = input_0_vec_map.max(input_1_vec_map);
}
}};

Expand Down
108 changes: 0 additions & 108 deletions onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1787,60 +1787,6 @@ TEST(MathOpTest, Min_12_MLFloat16_Scalar1) {
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT: Input batch size is inconsistent
}

// Min (opset 12) with MLFloat16 tensors: NaN elements in one input must
// propagate to the output, while non-NaN rows follow broadcast min semantics.
TEST(MathOpTest, Min_12_MLFloat16_Nan) {
  const float kNaN = std::numeric_limits<float>::quiet_NaN();

  OpTester tester("Min", 12);
  tester.AddInput<MLFloat16>("data_2", {3, 3},
                             MakeMLFloat16({kNaN, kNaN, kNaN,
                                            -0.5f, 0.0f, -2.0f,
                                            0.5f, 0.0f, 2.0f}));
  tester.AddInput<MLFloat16>("data_1", {3, 1},
                             MakeMLFloat16({0.0f, -1.0f, 1.0f}));
  tester.AddOutput<MLFloat16>("min", {3, 3},
                              MakeMLFloat16({kNaN, kNaN, kNaN,
                                             -1.0f, -1.0f, -2.0f,
                                             0.5f, 0.0f, 1.0f}));

  // Execute only when the CPU execution provider is available, and only on it.
  if (DefaultCpuExecutionProvider() == nullptr) return;
  std::vector<std::unique_ptr<IExecutionProvider>> providers;
  providers.push_back(DefaultCpuExecutionProvider());
  tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &providers);
}

// Min (opset 12), MLFloat16: a NaN element in the tensor input must survive
// a min against a broadcast scalar; finite elements take min(value, scalar).
TEST(MathOpTest, Min_12_MLFloat16_Nan_with_scalar) {
  const float kNaN = std::numeric_limits<float>::quiet_NaN();
  constexpr float kScalar = 0.25f;

  OpTester tester("Min", 12);
  tester.AddInput<MLFloat16>("data_1", {3, 1}, MakeMLFloat16({kNaN, -0.5f, 0.5f}));
  tester.AddInput<MLFloat16>("data_2", {1}, MakeMLFloat16({kScalar}));
  tester.AddOutput<MLFloat16>("min", {3, 1}, MakeMLFloat16({kNaN, -0.5f, kScalar}));

  // Run on the CPU execution provider only.
  if (DefaultCpuExecutionProvider() != nullptr) {
    std::vector<std::unique_ptr<IExecutionProvider>> providers;
    providers.push_back(DefaultCpuExecutionProvider());
    tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &providers);
  }
}

// Min (opset 12), MLFloat16: a NaN scalar operand forces every element of the
// broadcast result to NaN.
TEST(MathOpTest, Min_12_MLFloat16_with_scalar_Nan) {
  const float kNaN = std::numeric_limits<float>::quiet_NaN();

  OpTester tester("Min", 12);
  tester.AddInput<MLFloat16>("data_1", {2, 2},
                             MakeMLFloat16({0.25f, -0.25f, -0.5f, 0.5f}));
  tester.AddInput<MLFloat16>("data_2", {1}, MakeMLFloat16({kNaN}));
  tester.AddOutput<MLFloat16>("min", {2, 2},
                              MakeMLFloat16({kNaN, kNaN, kNaN, kNaN}));

  // Execute only when the CPU execution provider is available, and only on it.
  if (DefaultCpuExecutionProvider() == nullptr) return;
  std::vector<std::unique_ptr<IExecutionProvider>> providers;
  providers.push_back(DefaultCpuExecutionProvider());
  tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &providers);
}

TEST(MathOpTest, Max_6) {
OpTester test("Max", 6);
std::vector<int64_t> dims{3, 3};
Expand Down Expand Up @@ -2191,60 +2137,6 @@ TEST(MathOpTest, Max_12_MLFloat16_Scalar1) {
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT: Input batch size is inconsistent
}

// Max (opset 12) with MLFloat16 tensors: NaN elements in one input must
// propagate to the output, while non-NaN rows follow broadcast max semantics.
TEST(MathOpTest, Max_12_MLFloat16_Nan) {
  const float kNaN = std::numeric_limits<float>::quiet_NaN();

  OpTester tester("Max", 12);
  tester.AddInput<MLFloat16>("data_2", {3, 3},
                             MakeMLFloat16({kNaN, kNaN, kNaN,
                                            -0.5f, 0.0f, -2.0f,
                                            0.5f, 0.0f, 2.0f}));
  tester.AddInput<MLFloat16>("data_1", {3, 1},
                             MakeMLFloat16({0.0f, -1.0f, 1.0f}));
  tester.AddOutput<MLFloat16>("max", {3, 3},
                              MakeMLFloat16({kNaN, kNaN, kNaN,
                                             -0.5f, 0.0f, -1.0f,
                                             1.0f, 1.0f, 2.0f}));

  // Execute only when the CPU execution provider is available, and only on it.
  if (DefaultCpuExecutionProvider() == nullptr) return;
  std::vector<std::unique_ptr<IExecutionProvider>> providers;
  providers.push_back(DefaultCpuExecutionProvider());
  tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &providers);
}

// Max (opset 12), MLFloat16: a NaN element in the tensor input must survive
// a max against a broadcast scalar; finite elements take max(value, scalar).
TEST(MathOpTest, Max_12_MLFloat16_Nan_with_scalar) {
  const float kNaN = std::numeric_limits<float>::quiet_NaN();
  constexpr float kScalar = 0.25f;

  OpTester tester("Max", 12);
  tester.AddInput<MLFloat16>("data_1", {3, 1}, MakeMLFloat16({kNaN, -0.5f, 0.5f}));
  tester.AddInput<MLFloat16>("data_2", {1}, MakeMLFloat16({kScalar}));
  tester.AddOutput<MLFloat16>("max", {3, 1}, MakeMLFloat16({kNaN, kScalar, 0.5f}));

  // Run on the CPU execution provider only.
  if (DefaultCpuExecutionProvider() != nullptr) {
    std::vector<std::unique_ptr<IExecutionProvider>> providers;
    providers.push_back(DefaultCpuExecutionProvider());
    tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &providers);
  }
}

// Max (opset 12), MLFloat16: a NaN scalar operand forces every element of the
// broadcast result to NaN.
TEST(MathOpTest, Max_12_MLFloat16_with_scalar_Nan) {
  const float kNaN = std::numeric_limits<float>::quiet_NaN();

  OpTester tester("Max", 12);
  tester.AddInput<MLFloat16>("data_1", {2, 2},
                             MakeMLFloat16({0.25f, -0.25f, -0.5f, 0.5f}));
  tester.AddInput<MLFloat16>("data_2", {1}, MakeMLFloat16({kNaN}));
  tester.AddOutput<MLFloat16>("max", {2, 2},
                              MakeMLFloat16({kNaN, kNaN, kNaN, kNaN}));

  // Execute only when the CPU execution provider is available, and only on it.
  if (DefaultCpuExecutionProvider() == nullptr) return;
  std::vector<std::unique_ptr<IExecutionProvider>> providers;
  providers.push_back(DefaultCpuExecutionProvider());
  tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &providers);
}

TEST(MathOpTest, Not) {
OpTester test("Not");
std::vector<int64_t> dims{2};
Expand Down

0 comments on commit 4abbd41

Please sign in to comment.