Skip to content

Commit

Permalink
Drop legacy compare
Browse files Browse the repository at this point in the history
  • Loading branch information
t-jankowski committed Sep 19, 2024
1 parent 327f8e2 commit cf4a4f2
Showing 1 changed file with 37 additions and 57 deletions.
94 changes: 37 additions & 57 deletions src/plugins/template/tests/functional/op_reference/mish.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include <gtest/gtest.h>

#include <cmath>
#include <random>

#include "base_reference_test.hpp"
Expand All @@ -16,24 +17,22 @@ using namespace ov;
namespace {
struct MishParams {
template <class IT>
MishParams(const ov::PartialShape& dynamicShape,
const ov::Shape& inputShape,
MishParams(const ov::PartialShape& parameterShape,
const ov::Shape& tensorShape,
const ov::element::Type& iType,
const std::vector<IT>& iValues,
const std::vector<IT>& oValues,
const std::string& test_name = "")
: dynamicShape(dynamicShape),
inputShape(inputShape),
: parameterShape(parameterShape),
tensorShape(tensorShape),
inType(iType),
outType(iType),
inputData(CreateTensor(inputShape, iType, iValues)),
refData(CreateTensor(inputShape, iType, oValues)),
inputData(CreateTensor(tensorShape, iType, iValues)),
refData(CreateTensor(tensorShape, iType, oValues)),
testcaseName(test_name) {}

ov::PartialShape dynamicShape;
ov::PartialShape inputShape;
ov::PartialShape parameterShape;
ov::Shape tensorShape;
ov::element::Type inType;
ov::element::Type outType;
ov::Tensor inputData;
ov::Tensor refData;
std::string testcaseName;
Expand All @@ -42,32 +41,27 @@ struct MishParams {
class ReferenceMishLayerTest : public testing::TestWithParam<MishParams>, public CommonReferenceTest {
public:
void SetUp() override {
legacy_compare = true;
auto params = GetParam();
function = CreateFunction(params.dynamicShape, params.inType);
const auto& params = GetParam();
function = CreateFunction(params.parameterShape, params.inType);
inputData = {params.inputData};
refOutData = {params.refData};
}
static std::string getTestCaseName(const testing::TestParamInfo<MishParams>& obj) {
auto param = obj.param;
const auto& param = obj.param;
std::ostringstream result;
result << "dShape=" << param.dynamicShape << "_";
result << "iShape=" << param.inputShape << "_";
result << "iType=" << param.inType << "_";
if (param.testcaseName != "") {
result << "oType=" << param.outType << "_";
result << param.testcaseName;
} else {
result << "oType=" << param.outType;
}
result << "dShape=" << param.parameterShape << "_";
result << "iShape=" << param.tensorShape << "_";
result << "iType=" << param.inType;
if (!param.testcaseName.empty())
result << "_" << param.testcaseName;
return result.str();
}

private:
static std::shared_ptr<Model> CreateFunction(const PartialShape& input_shape, const element::Type& input_type) {
const auto in = std::make_shared<op::v0::Parameter>(input_type, input_shape);
const auto Mish = std::make_shared<op::v4::Mish>(in);
return std::make_shared<ov::Model>(NodeVector{Mish}, ParameterVector{in});
return std::make_shared<Model>(NodeVector{Mish}, ParameterVector{in});
}
};

Expand All @@ -76,57 +70,43 @@ TEST_P(ReferenceMishLayerTest, CompareWithRefs) {
}

template <element::Type_t IN_ET>
std::vector<MishParams> generateMishFloatParams(const PartialShape& dynamicShape,
const Shape& staticShape,
const std::string& test_name = "") {
MishParams generateMishFloatParams(const PartialShape& parameterShape,
const Shape& tensorShape,
const std::string& test_name = "") {
using T = typename element_type_traits<IN_ET>::value_type;

// generate input tensor (with possible type conversion)
auto staticSize = shape_size(staticShape);
const auto staticSize = shape_size(tensorShape);
std::vector<T> expected;
std::vector<T> input;
{
std::mt19937 gen{0}; // use fixed seed for reproducibility of the test
std::normal_distribution<> d{0.0, 20.0};

for (auto i = staticSize; i > 0; i--) {
auto x = static_cast<T>(d(gen));
auto y = static_cast<T>(static_cast<double>(x) * std::tanh(std::log(1.0 + std::exp(x))));
const auto x = static_cast<T>(d(gen));
const auto y = static_cast<T>(x * std::tanh(std::log(std::exp(x) + T{1})));
input.push_back(x);
expected.push_back(y);
}
}

std::vector<MishParams> mishParams;

if (test_name != "") {
mishParams = {MishParams(dynamicShape, staticShape, IN_ET, input, expected, test_name)};
} else {
mishParams = {MishParams(dynamicShape, staticShape, IN_ET, input, expected)};
}
return mishParams;
return MishParams{parameterShape, tensorShape, IN_ET, input, expected, test_name};
}

std::vector<MishParams> generateMishCombinedParams() {
const std::vector<std::vector<MishParams>> mishTypeParams{
generateMishFloatParams<element::Type_t::f32>({2, 5}, {2, 5}),
generateMishFloatParams<element::Type_t::f32>({2, 3, 4, 5}, {2, 3, 4, 5}),
generateMishFloatParams<element::Type_t::f32>(PartialShape::dynamic(), {2, 3, 4, 5}),
generateMishFloatParams<element::Type_t::f32>({2, Dimension::dynamic(), 4, 5},
{2, 3, 4, 5},
"dimensionDynamic"),
generateMishFloatParams<element::Type_t::f16>({2, 5}, {2, 5}),
generateMishFloatParams<element::Type_t::f16>({2, 3, 4, 5}, {2, 3, 4, 5}),
generateMishFloatParams<element::Type_t::f16>(PartialShape::dynamic(), {2, 3, 4, 5}),
generateMishFloatParams<element::Type_t::f16>({2, Dimension::dynamic(), 4, 5},
{2, 3, 4, 5},
"dimensionDynamic")};
std::vector<MishParams> combinedParams;

for (const auto& params : mishTypeParams) {
combinedParams.insert(combinedParams.end(), params.begin(), params.end());
}
return combinedParams;
return std::vector<MishParams>{generateMishFloatParams<element::Type_t::f32>({2, 5}, {2, 5}),
generateMishFloatParams<element::Type_t::f32>({2, 3, 4, 5}, {2, 3, 4, 5}),
generateMishFloatParams<element::Type_t::f32>(PartialShape::dynamic(), {2, 3, 4, 5}),
generateMishFloatParams<element::Type_t::f32>({2, Dimension::dynamic(), 4, 5},
{2, 3, 4, 5},
"dimensionDynamic"),
generateMishFloatParams<element::Type_t::f16>({2, 5}, {2, 5}),
generateMishFloatParams<element::Type_t::f16>({2, 3, 4, 5}, {2, 3, 4, 5}),
generateMishFloatParams<element::Type_t::f16>(PartialShape::dynamic(), {2, 3, 4, 5}),
generateMishFloatParams<element::Type_t::f16>({2, Dimension::dynamic(), 4, 5},
{2, 3, 4, 5},
"dimensionDynamic")};
}

INSTANTIATE_TEST_SUITE_P(smoke_Mish_With_Hardcoded_Refs,
Expand Down

0 comments on commit cf4a4f2

Please sign in to comment.