
Commit

Merge pull request #7 from ynimmaga/dyn_shapes
Dynamic shape support for llama3
cavusmustafa authored Jun 12, 2024
2 parents 924b368 + 1d545b2 commit 063b33b
Showing 5 changed files with 27 additions and 15 deletions.
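
At a glance: the FX frontend stops folding every aten.expand target size into a static Constant (with -1 placeholders for unknown dims) and instead assembles the target shape at runtime, registers aten.triu.default, and fixes a rank guard in slice_scatter — the pieces llama3-style models with dynamic sequence lengths need. Below is a minimal sketch (not part of this PR) of the kind of workload this enables; it assumes the OpenVINO torch.compile backend is installed and registered via import openvino.torch.

import torch
import openvino.torch  # registers the "openvino" torch.compile backend

def causal_scores(scores):
    # triu builds the causal mask; expand broadcasts it across batch and heads.
    # With dynamic=True the expand sizes arrive as SymInts rather than Python
    # ints, which is the case this commit teaches the frontend to convert.
    b, h, q, k = scores.shape
    mask = torch.triu(torch.ones(q, k, dtype=torch.bool), diagonal=1)
    return scores.masked_fill(mask.expand(b, h, q, k), float("-inf"))

compiled = torch.compile(causal_scores, backend="openvino", dynamic=True)
out = compiled(torch.randn(1, 8, 16, 16))  # any sequence length after compile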
@@ -241,6 +241,7 @@ def __init__(self, options)
"torch.ops.aten.transpose.int": None,
"torch.ops.aten.tril.default": None,
"torch.ops.aten.tril_.default": None,
"torch.ops.aten.triu.default": None,
"torch.ops.aten.unbind.int": None,
"torch.ops.aten.unfold.default": None,
"torch.ops.aten.unsqueeze.default": None,
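
(The file name for this first hunk wasn't captured above, but the dictionary of torch.ops.aten.* keys inside __init__(self, options) looks like the TorchDynamo operator-support table; adding torch.ops.aten.triu.default lets the partitioner keep triu nodes inside OpenVINO subgraphs instead of falling back to eager PyTorch.)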
33 changes: 21 additions & 12 deletions src/frontends/pytorch/src/op/expand.cpp
@@ -5,6 +5,7 @@
#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/abs.hpp"
#include "openvino/op/broadcast.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/shape_of.hpp"
#include "utils.hpp"

@@ -46,23 +47,31 @@ OutputVector translate_expand_fx(const NodeContext& context) {
num_inputs_check(context, 2, num_inputs);
auto x = context.get_input(0);
std::vector<int32_t> shape_vec;
auto sizes = context.get_input(1);
if (num_inputs != 2) {
if (context.get_input_type(1).is<type::List>()) {
std::deque<Output<Node>> list_elems;
for (size_t i = 1; i < num_inputs; i++) {
auto a = context.get_input_from_visible_context(i).get_node_shared_ptr();
auto shape_input = context.get_input(static_cast<int>(i));
if (std::dynamic_pointer_cast<ov::op::v0::Parameter>(a) ||
shape_input.get_partial_shape().rank().is_dynamic() ||
shape_input.get_partial_shape().rank().get_length() == 0) {
shape_vec.push_back(-1);
if (context.get_input_type(i).as<type::List>().element_type.is<type::PyScalar>()) {
auto const_val = context.const_input<int32_t>(i);
std::vector<int32_t> dim_vec;
dim_vec.push_back(const_val);
auto dim_const = ov::op::v0::Constant::create(element::i32, Shape{1}, dim_vec);
list_elems.push_back(dim_const);
} else {
auto val = context.const_input<int32_t>(i);
shape_vec.push_back(val);
auto converted_dim = context.mark_node(std::make_shared<ov::op::v0::Convert>(context.get_input(static_cast<int>(i)), element::i32));
list_elems.push_back(converted_dim);
}
}
sizes = ov::op::v0::Constant::create(element::i32, Shape{num_inputs - 1}, shape_vec);
auto concat = std::make_shared<ov::op::v0::Concat>(OutputVector(list_elems.begin(), list_elems.end()), 0);
return base_expand(context, x, concat);
} else {
auto x = context.get_input(0);
auto sizes = context.get_input(1);
// TODO: figure out what implicit means
PYTORCH_OP_CONVERSION_CHECK(context.input_is_none(2) || context.const_input<bool>(2) == false,
"Unexpected value of implicit for expand operation");
return base_expand(context, x, sizes);
}
return base_expand(context, x, sizes);

};

} // namespace op
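
The core of the dynamic-shape support: when the sizes argument is a list, each element that is a constant PyScalar becomes a one-element i32 Constant, while anything computed at runtime is Converted to i32, and a Concat along axis 0 assembles the full target shape handed to base_expand. A rough mirror of that logic in the OpenVINO Python API, for illustration only — build_target_shape is a hypothetical helper, and the opset10 import path is an assumption:

import numpy as np
from openvino.runtime import opset10 as ops

def build_target_shape(size_elems):
    # size_elems: Python ints for static dims, or OpenVINO nodes producing a
    # one-element integer vector for dims only known at runtime (hypothetical
    # stand-ins for what the C++ NodeContext hands the converter).
    parts = []
    for dim in size_elems:
        if isinstance(dim, int):
            # constant list element -> one-element i32 Constant
            parts.append(ops.constant(np.array([dim], dtype=np.int32)))
        else:
            # runtime element -> Convert to i32 so all Concat inputs agree
            parts.append(ops.convert(dim, destination_type="i32"))
    # Concat on axis 0 yields the target shape as a runtime tensor, replacing
    # the old all-static Constant that used -1 placeholders for dynamic dims
    return ops.concat(parts, axis=0)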
1 change: 1 addition & 0 deletions src/frontends/pytorch/src/op/full.cpp
@@ -9,6 +9,7 @@
#include "openvino/op/constant.hpp"
#include "openvino/op/convert_like.hpp"
#include "openvino/op/divide.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/power.hpp"
4 changes: 2 additions & 2 deletions src/frontends/pytorch/src/op/slice_scatter.cpp
@@ -44,7 +44,7 @@ OutputVector translate_slice_scatter_fx(const NodeContext& context) {
ov::Output<ov::Node> end;
if (!context.input_is_none(4)) {
end = context.get_input(4);
if (end.get_partial_shape().rank().is_dynamic() || end.get_partial_shape().rank().get_length() == 0) {
if (!(end.get_partial_shape().rank().is_dynamic()) && end.get_partial_shape().rank().get_length() == 0) {
end = context.mark_node(std::make_shared<v0::Unsqueeze>(end, axis_0));
}
} else {
@@ -65,4 +65,4 @@ OutputVector translate_slice_scatter_fx(const NodeContext& context) {
} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov
} // namespace ov
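
The slice_scatter change flips an inverted guard: the old condition unsqueezed end when its rank was dynamic or zero, so inputs of unknown rank got an Unsqueeze they may not need; the new condition applies the Unsqueeze only when the rank is statically known to be 0 — that is, when end is definitely a scalar — and leaves dynamic-rank inputs untouched.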
3 changes: 2 additions & 1 deletion src/frontends/pytorch/src/op_table.cpp
@@ -758,7 +758,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_fx() {
{"aten._scaled_dot_product_flash_attention_for_cpu.default", op::translate_scaled_dot_product_attention_fx},
{"aten._softmax.default", op::translate_softmax_fx},
{"aten._to_copy.default", op::translate_to_fx},
{"aten._unsafe_view.default", op::translate_reshape},
{"aten._unsafe_view.default", op::translate_reshape_fx},
{"aten.abs.default", op::translate_1to1_match_1_inputs<opset10::Abs>},
{"aten.acos.default", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Acos>},
{"aten.acosh.default", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Acosh>},
@@ -963,6 +963,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_fx() {
{"aten.topk.default", op::translate_topk_fx},
{"aten.transpose.int", op::translate_transpose},
{"aten.tril.default", op::translate_tril},
{"aten.triu.default", op::translate_triu},
{"aten.unbind.int", op::translate_unbind_int_fx},
{"aten.unfold.default", op::translate_unfold},
{"aten.unsqueeze.default", op::translate_1to1_match_2_inputs<opset10::Unsqueeze>},
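
Two op-table updates round this out: aten._unsafe_view.default now routes through translate_reshape_fx (presumably the dynamic-shape-aware FX variant) instead of the generic translate_reshape, and aten.triu.default is registered against translate_triu, mirroring the Python-side operator-support entry added in the first file.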
