Added embedding_bag and fixed unbind int #4

Merged · 2 commits · Mar 6, 2024
@@ -33,6 +33,7 @@ def __init__(self, options):
"torch.ops.aten._adaptive_avg_pool2d.default": None,
"torch.ops.aten._adaptive_avg_pool3d.default": None,
"torch.ops.aten._convolution.default": None,
"torch.ops.aten._embedding_bag.default": None,
"torch.ops.aten._fake_quantize_per_tensor_affine_cachemask_tensor_qparams.default": None,
"torch.ops.aten._local_scalar_dense.default": None,
"torch.ops.aten._log_softmax.default": None,
src/frontends/pytorch/src/op/embedding_bag.cpp (16 additions & 4 deletions)
@@ -15,10 +15,9 @@ namespace frontend {
namespace pytorch {
namespace op {

OutputVector translate_embedding_bag(const NodeContext& context) {
OutputVector translate_embedding_bag_common(const NodeContext& context) {
// aten::embedding_bag(weight, input, offsets=None, scale_grad_by_freq=False, mode_enum=1, sparse=False,
// per_sample_weights=None, include_last_offset=False, padding_idx=None)
num_inputs_check(context, 9, 9);
// we have only EmbeddingBagSum case support, check it before translation
auto mode = context.const_input<int64_t>(4);
PYTORCH_OP_CONVERSION_CHECK(mode == 0, "Only sum mode supported for aten::embedding_bag translation");
@@ -43,7 +42,9 @@ OutputVector translate_embedding_bag(const NodeContext& context) {
// with offsets case
auto offsets = context.get_input(2);
offsets = context.mark_node(std::make_shared<ov::op::v0::Convert>(offsets, element::i32));
auto include_last_offset = context.const_input<bool>(7);
bool include_last_offset = false;
if (!context.input_is_none(7))
include_last_offset = context.const_input<bool>(7);
PYTORCH_OP_CONVERSION_CHECK(!include_last_offset, "Inclusion last offset is not supported");
// no per_sample_weights
if (context.input_is_none(6)) {
@@ -63,7 +64,18 @@ OutputVector translate_embedding_bag(const NodeContext& context) {
return {result, zero, zero, zero};
};

OutputVector translate_embedding_bag(const NodeContext& context) {
num_inputs_check(context, 9, 9);
return translate_embedding_bag_common(context);
}

OutputVector translate_embedding_bag_fx(const NodeContext& context) {
num_inputs_check(context, 7, 9);
ov::OutputVector output = translate_embedding_bag_common(context);
return {context.mark_node(make_list_construct(output))};
}

} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov
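
For reference, a minimal PyTorch-side sketch (not part of this PR) of the case the converter supports: `mode='sum'` with explicit offsets and no `per_sample_weights`; `include_last_offset=True` is rejected by the conversion above.

```python
# Minimal sketch of the PyTorch behavior the converter targets (not from this PR).
# embedding_bag in sum mode with offsets; per_sample_weights left as None.
import torch
import torch.nn.functional as F

weight = torch.randn(10, 3)                 # embedding table: 10 rows of dim 3
indices = torch.tensor([1, 2, 4, 5, 4, 3])  # flat index list
offsets = torch.tensor([0, 3])              # bag starts: [1, 2, 4] and [5, 4, 3]

out = F.embedding_bag(indices, weight, offsets, mode="sum",
                      include_last_offset=False)
print(out.shape)  # torch.Size([2, 3]) -- one summed embedding per bag
```

The underlying aten op returns a tuple of four tensors, which is why the converter pads its result with zero placeholders (`{result, zero, zero, zero}`).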
src/frontends/pytorch/src/op/split.cpp (10 additions & 4 deletions)
@@ -37,12 +37,18 @@ OutputVector translate_chunk_fx(const NodeContext& context) {
}

OutputVector translate_unbind_int_fx(const NodeContext& context) {
num_inputs_check(context, 2, 3);
num_inputs_check(context, 1, 3);
auto input = context.get_input(0);
auto dim = context.get_input(1);
auto dim_val = context.const_input<int>(1);
Output<Node> dim;
int64_t dim_val = 0;
if (context.input_is_none(1)) {
dim = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
} else {
dim = context.get_input(1);
dim_val = context.const_input<int>(1);
}

auto shape = input.get_shape();

if (dim_val < 0) {
dim_val = static_cast<int>(shape.size()) + dim_val;
}
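
For context, a small sketch (not from the PR) of the op being converted: `torch.unbind` defaults `dim` to 0, which is why FX graphs may omit the dim argument and the converter now falls back to a zero constant; negative dims are normalized against the input rank, matching the `dim_val` handling above.

```python
# Sketch of torch.unbind semantics (not from this PR); dim defaults to 0.
import torch

x = torch.arange(6).reshape(3, 2)
rows = torch.unbind(x)          # same as torch.unbind(x, dim=0)
print(len(rows), rows[0])       # 3 tensor([0, 1])

cols = torch.unbind(x, dim=-1)  # negative dim counts from the back (-1 -> 1 here)
print(len(cols), cols[0])       # 2 tensor([0, 2, 4])
```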
src/frontends/pytorch/src/op_table.cpp (2 additions & 0 deletions)
@@ -253,6 +253,7 @@ OP_CONVERTER(translate_constant_pad_nd_fx);
OP_CONVERTER(translate_cumsum_fx);
OP_CONVERTER(translate_chunk_fx);
OP_CONVERTER(translate_div_fx);
OP_CONVERTER(translate_embedding_bag_fx);
OP_CONVERTER(translate_expand_fx);
OP_CONVERTER(translate_fake_quantize_per_channel_affine_fx);
OP_CONVERTER(translate_fake_quantize_per_tensor_affine_fx);
@@ -691,6 +692,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_fx() {
{"aten._adaptive_avg_pool2d.default", op::translate_adaptive_avg_pool2d},
{"aten._adaptive_avg_pool3d.default", op::translate_adaptive_avg_pool3d},
{"aten._convolution.default", op::translate_convolution},
{"aten._embedding_bag.default", op::translate_embedding_bag_fx},
{"aten._fake_quantize_per_tensor_affine_cachemask_tensor_qparams.default",
op::translate_fake_quantize_per_tensor_affine_fx},
{"aten._local_scalar_dense.default", op::skip_node},
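
With the FX converter registered here, a model using `nn.EmbeddingBag` in sum mode should be convertible through the torch.compile path. A hedged end-to-end sketch, assuming the OpenVINO torch.compile integration (the "openvino" backend, registered by importing `openvino.torch`) is installed:

```python
# End-to-end sketch (assumption: OpenVINO's torch.compile backend is installed
# and registered via `import openvino.torch`; not part of this PR).
import torch
import openvino.torch  # registers the "openvino" backend for torch.compile

bag = torch.nn.EmbeddingBag(10, 3, mode="sum")
compiled = torch.compile(bag, backend="openvino")

indices = torch.tensor([1, 2, 4, 5, 4, 3])
offsets = torch.tensor([0, 3])
print(compiled(indices, offsets).shape)  # torch.Size([2, 3])
```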