Skip to content

Commit

Permalink
[PT FE] Use GroupNormalization for aten::group_norm instead of MVN (#26063)
Browse files Browse the repository at this point in the history

### Details:
 - *Use GroupNormalization for `aten::group_norm` instead of MVN*

### Tickets:
 - *ticket-id* (none linked for this change)
  • Loading branch information
mvafin authored Aug 14, 2024
1 parent aa4455a commit 5a119fb
Showing 1 changed file with 24 additions and 27 deletions.
51 changes: 24 additions & 27 deletions src/frontends/pytorch/src/op/group_norm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,13 @@
//

#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/broadcast.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/mvn.hpp"
#include "openvino/op/range.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/subtract.hpp"
#include "openvino/op/convert_like.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/group_normalization.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/unsqueeze.hpp"
#include "openvino/op/util/framework_node.hpp"
#include "utils.hpp"

namespace ov {
Expand All @@ -33,30 +31,29 @@ OutputVector translate_group_norm_common(const NodeContext& context,
auto data = context.get_input(0);
auto num_groups = context.const_input<int64_t>(group_idx);
// input 2 - weights and input 3 - bias are optional without default value, we handle them later
auto eps = static_cast<float>(context.const_input<double>(eps_idx));
Output<Node> input_shape;
Output<Node> input_rank;
std::tie(input_shape, input_rank) = get_shape_rank(context, data, true, element::i32);
auto scalar_one = context.mark_node(v0::Constant::create(element::i32, {}, {1}));
auto shape = context.mark_node(
std::make_shared<v0::Constant>(element::i32, Shape({3}), std::vector<int64_t>{0, num_groups, -1}));
auto reshaped_input = context.mark_node(std::make_shared<v1::Reshape>(data, shape, true));
auto reduction_axes = context.mark_node(v0::Constant::create(element::i32, Shape({1}), std::vector<int64_t>(1, 2)));
auto reshaped_norm = context.mark_node(
std::make_shared<v6::MVN>(reshaped_input, reduction_axes, true, eps, MVNEpsMode::INSIDE_SQRT));
auto norm = context.mark_node(std::make_shared<v1::Reshape>(reshaped_norm, input_shape, true));
auto skip_last = context.mark_node(std::make_shared<v1::Subtract>(input_rank, scalar_one));
auto axes = context.mark_node(std::make_shared<v4::Range>(scalar_one, skip_last, scalar_one, element::i32));
auto eps = context.const_input<double>(eps_idx);

auto zero = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
auto one = context.mark_node(v0::Constant::create(element::i32, Shape{}, {1}));
auto shape = context.mark_node(std::make_shared<v3::ShapeOf>(data, element::i32));
auto channels = context.mark_node(std::make_shared<v8::Gather>(shape, one, zero));
channels = context.mark_node(std::make_shared<v0::Unsqueeze>(channels, zero));

Output<Node> scale;
if (!context.input_is_none(weights_idx)) {
auto weights = context.get_input(static_cast<int>(weights_idx));
weights = context.mark_node(std::make_shared<v0::Unsqueeze>(weights, axes));
norm = context.mark_node(std::make_shared<v1::Multiply>(norm, weights));
scale = context.get_input(static_cast<int>(weights_idx));
} else {
scale = context.mark_node(std::make_shared<v3::Broadcast>(one, channels));
scale = context.mark_node(std::make_shared<v1::ConvertLike>(scale, data));
}
Output<Node> bias;
if (!context.input_is_none(bias_idx)) {
auto bias = context.get_input(static_cast<int>(bias_idx));
bias = context.mark_node(std::make_shared<v0::Unsqueeze>(bias, axes));
norm = context.mark_node(std::make_shared<v1::Add>(norm, bias));
bias = context.get_input(static_cast<int>(bias_idx));
} else {
bias = context.mark_node(std::make_shared<v3::Broadcast>(zero, channels));
bias = context.mark_node(std::make_shared<v1::ConvertLike>(bias, data));
}
auto norm = context.mark_node(std::make_shared<v12::GroupNormalization>(data, scale, bias, num_groups, eps));
// Input with index 5 is flag "cudnn_enabled" we can ignore it
return {norm};
};
Expand Down

0 comments on commit 5a119fb

Please sign in to comment.