Commit 986d083

Add gather/squeeze/unsqueeze/tile etc ops

chenfeiyue-cfy committed Jun 4, 2024
1 parent a47a8ff
Showing 18 changed files with 565 additions and 49 deletions.
8 changes: 7 additions & 1 deletion onnxruntime/core/framework/node_unit.cc
@@ -284,7 +284,7 @@ void NodeUnit::InitForSingleNode() {
const auto& input_defs = target_node_.InputDefs();
const auto& output_defs = target_node_.OutputDefs();
auto qlinear_type = GetQLinearOpType(target_node_);
if (qlinear_type == QLinearOpType::Unknown || IsVariadicQLinearOp(qlinear_type)) { // TODO, add variadic support
if (qlinear_type == QLinearOpType::Unknown) {
// Not a Qlinear op, add all inputs / outputs
auto add_all_io = [](std::vector<NodeUnitIODef>& defs,
const ConstPointerContainer<std::vector<NodeArg*>>& node_defs) {
@@ -334,6 +334,12 @@ void NodeUnit::InitForSingleNode() {
outputs_.push_back(NodeUnitIODef{*output_defs[0], NodeUnitIODef::QuantParam{*input_defs[1], input_defs.size() == 3
? input_defs[2]
: nullptr}});
} else if (IsVariadicQLinearOp(qlinear_type)) {
int input_num = (input_defs.size() - 2) / 3;
for (int i = 0; i < input_num; i++) {
inputs_.push_back(NodeUnitIODef{*input_defs[3 * i + 2], NodeUnitIODef::QuantParam{*input_defs[3 * i + 3], input_defs[3 * i + 4]}});
}
outputs_.push_back(NodeUnitIODef{*output_defs[0], NodeUnitIODef::QuantParam{*input_defs[0], input_defs[1]}});
} else {
ORT_THROW("The QLinear op [", static_cast<uint8_t>(qlinear_type), "] is not supported");
}
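For context, the index arithmetic in the new IsVariadicQLinearOp branch assumes a QLinearConcat-style layout: the output scale and zero point come first, followed by one (data, scale, zero_point) triple per variadic input. The sketch below is illustrative only and not part of the change.

// Illustrative sketch of the variadic QLinear input layout assumed above:
//   input_defs = {Y_scale, Y_zero_point,
//                 X0, X0_scale, X0_zero_point,
//                 X1, X1_scale, X1_zero_point, ...}
#include <cassert>
#include <cstddef>

constexpr std::size_t DataIndex(std::size_t i) { return 3 * i + 2; }
constexpr std::size_t ScaleIndex(std::size_t i) { return 3 * i + 3; }
constexpr std::size_t ZeroPointIndex(std::size_t i) { return 3 * i + 4; }

int main() {
  // With 8 input defs there are (8 - 2) / 3 = 2 quantized data inputs.
  constexpr std::size_t num_defs = 8;
  constexpr std::size_t input_num = (num_defs - 2) / 3;
  static_assert(input_num == 2, "two variadic data inputs expected");
  assert(DataIndex(0) == 2 && ScaleIndex(0) == 3);          // first triple starts right after Y params
  assert(ZeroPointIndex(input_num - 1) == num_defs - 1);    // last triple ends at the last def
  return 0;
}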
18 changes: 15 additions & 3 deletions onnxruntime/core/providers/vsinpu/builders/impl/base_op_builder.cc
@@ -49,13 +49,13 @@ bool BaseOpBuilder::HasSupportedInputOutputs(const InitializedTensorSet& initial
return false;
}

// We do not support dynamic shape input yet
// We do not support dynamic shape input yet, but the Resize op's second input
// can be empty because we do not care about its value.
for (const auto& dim : shape_proto->dim()) {
if (!dim.has_dim_value()) {
LOGS_DEFAULT(WARNING) << "Dynamic shape is not supported for now, for input:" << node_arg.Name();
return false;
}
if (dim.dim_value() == 0) {
if (dim.dim_value() == 0 && op_type != "Resize") {
LOGS_DEFAULT(WARNING) << "Zero in shape is not supported for now, for input:" << node_arg.Name();
return false;
}
@@ -91,6 +91,10 @@ bool BaseOpBuilder::HasSupportedInputOutputs(const InitializedTensorSet& initial
return false;
if (!has_initialized_quant_param(*input.quant_param->zero_point, initializers))
return false;
if (input.quant_param->zero_point->Type() != input.node_arg.Type()) {
LOGS_DEFAULT(ERROR) << "Invalid input type because the data type mismatch with its' quant param type.";
return false;
}
}
}
}
@@ -115,7 +119,7 @@ bool BaseOpBuilder::HasSupportedInputOutputs(const InitializedTensorSet& initial

bool BaseOpBuilder::HasSupportedInputOutputsImpl(
const InitializedTensorSet& /* initializers */, const NodeUnit& node_unit) const {
// Check input data type, int64 is generally unsupported
// Check input/output data type, int64 is generally unsupported
// specific op builder can override this if the int64 input corresponds to VSINPU param
for (const auto& input : node_unit.Inputs()) {
auto input_type = input.node_arg.Type();
@@ -125,6 +129,14 @@ bool BaseOpBuilder::HasSupportedInputOutputsImpl(
return false;
}
}
for (const auto& output : node_unit.Outputs()) {
auto output_type = output.node_arg.Type();
if (*output_type == "tensor(int64)" || !util::IsTypeSupported(&output.node_arg)) {
LOGS_DEFAULT(WARNING) << node_unit.OpType() << " has unsupported output type : "
<< *output_type;
return false;
}
}
return true;
}

43 changes: 43 additions & 0 deletions onnxruntime/core/providers/vsinpu/builders/impl/cast_op_builder.h
@@ -0,0 +1,43 @@
/****************************************************************************
*
* Copyright (c) 2024 Vivante Corporation
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and associated documentation files (the "Software"),
* to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
* and/or sell copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
* DEALINGS IN THE SOFTWARE.
*
*****************************************************************************/
#include "core/providers/shared/utils/utils.h"
#include "core/providers/vsinpu/builders/impl/base_op_builder.h"
namespace onnxruntime {
namespace vsi {
namespace npu {
class CastOpBuilder : public BaseOpBuilder {
protected:
bool HandleBuildOp(vsi::npu::GraphEP* graph_ep, std::vector<std::shared_ptr<tim::vx::Tensor>>& inputs,
std::vector<std::shared_ptr<tim::vx::Tensor>>& outputs, const NodeUnit& node_unit) override {
LOGS_DEFAULT(VERBOSE) << "Creating Cast Op.";
NodeAttrHelper helper(node_unit.GetNode());
auto op = graph_ep->GetGraph()->CreateOperation<tim::vx::ops::DataConvert>();
(*op).BindInput(inputs[0]).BindOutputs(outputs);
return true;
}
};

} // namespace npu
} // namespace vsi
} // namespace onnxruntime
@@ -28,9 +28,9 @@ namespace onnxruntime {
namespace vsi {
namespace npu {
class ClipOpBuilder final : public BaseOpBuilder {
bool IsOpSupported(const onnxruntime::GraphViewer& graph_viewer,
bool IsOpSupported(const onnxruntime::GraphViewer& graph_viewer,
const Node* node) const override {
if (node->SinceVersion() > 6) {
if (node->SinceVersion() > 6) {
if (node->InputDefs().size() > 1 && !Contains(graph_viewer.GetAllInitializedTensors(), node->InputDefs()[1]->Name())) {
LOGS_DEFAULT(WARNING) << "Min/Max value must be const input or attribute.";
return false;
@@ -36,7 +36,6 @@ class DequantizeLinearOpBuilder : public BaseOpBuilder {
};
bool HasSupportedInputOutputsImpl(const InitializedTensorSet& initializers,
const NodeUnit& node_unit) const override {

auto input_type = node_unit.Inputs()[0].node_arg.Type();
if (*input_type == "tensor(int64)" || !util::IsTypeSupported(&node_unit.Inputs()[0].node_arg)) {
LOGS_DEFAULT(WARNING) << node_unit.OpType() << " has unsupported input type : "
onnxruntime/core/providers/vsinpu/builders/impl/gather_op_builder.h
@@ -0,0 +1,82 @@
/****************************************************************************
*
* Copyright (c) 2024 Vivante Corporation
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and associated documentation files (the "Software"),
* to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
* and/or sell copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
* DEALINGS IN THE SOFTWARE.
*
*****************************************************************************/
#include "core/providers/vsinpu/builders/impl/base_op_builder.h"
#include "core/providers/shared/utils/utils.h"

namespace onnxruntime {
namespace vsi {
namespace npu {
class GatherOpBuilder : public BaseOpBuilder {
bool HasSupportedInputOutputsImpl(const InitializedTensorSet& initializers,
const NodeUnit& node_unit) const override {
auto input = node_unit.Inputs()[0];
auto indices = node_unit.Inputs()[1];
if (util::IsTypeSupported(&input.node_arg) && util::IsTypeSupported(&indices.node_arg)) {
if (*input.node_arg.Type() == "tensor(int64)") {
LOGS_DEFAULT(WARNING) << "Only support indices tensor to be int64 type in gather op.";
return false;
}
if (*indices.node_arg.Type() != "tensor(int64)" && *indices.node_arg.Type() != "tensor(int32)") {
LOGS_DEFAULT(WARNING) << "Unsupported indices tensor type in gather op.";
return false;
}
if (*indices.node_arg.Type() == "tensor(int64)" && !Contains(initializers, indices.node_arg.Name())) {
LOGS_DEFAULT(WARNING) << "Only support const attribute if indice tensor is in int64 type.";
return false;
}
return true;
}
return false;
}

bool HandleBuildOp(vsi::npu::GraphEP* graph_ep,
std::vector<std::shared_ptr<tim::vx::Tensor>>& inputs,
std::vector<std::shared_ptr<tim::vx::Tensor>>& outputs,
const NodeUnit& node_unit) override {
LOGS_DEFAULT(VERBOSE) << "Creating Gather Op.";
NodeAttrHelper helper(node_unit.GetNode());
auto axis = helper.Get("axis", 0);
axis = util::ReverseAxis(axis, inputs[0]->GetShape().size());
auto op = graph_ep->GetGraph()->CreateOperation<tim::vx::ops::Gather>(axis, 0);

bool is_i64_indices = inputs[1]->GetDataType() == tim::vx::DataType::INT64;
if (!is_i64_indices) {
(*op).BindInputs(inputs).BindOutputs(outputs);
} else {
std::vector<int64_t> origin_data(inputs[1]->GetSpec().GetElementNum());
inputs[1]->CopyDataFromTensor(origin_data.data());
std::vector<int32_t> transformed_data(origin_data.begin(), origin_data.end());
auto transformed_indices = graph_ep->GetGraph()->CreateTensor(
inputs[1]->GetSpec().SetAttribute(tim::vx::TensorAttribute::INPUT).SetDataType(tim::vx::DataType::INT32), transformed_data.data());
(*op).BindInput(inputs[0]).BindInput(transformed_indices).BindOutput(outputs[0]);
}
graph_ep->GetOps().push_back(std::move(op));
return true;
}
};

} // namespace npu

} // namespace vsi
} // namespace onnxruntime
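When the gather indices are int64, the builder above copies the initializer data out of the tensor and narrows it to int32 before binding a new INT32 tensor. A minimal standalone sketch of that narrowing, assuming every index fits in 32 bits (the builder itself does not range-check):

#include <cstdint>
#include <vector>

// Narrow ONNX int64 gather indices to the int32 type used for the TIM-VX tensor.
// Assumption: all indices fit in int32; larger values would be silently truncated.
std::vector<int32_t> NarrowIndices(const std::vector<int64_t>& origin_data) {
  return std::vector<int32_t>(origin_data.begin(), origin_data.end());
}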
@@ -36,7 +36,7 @@ class BatchNormOpBuilder : public BaseOpBuilder {
mean_tensor = 3,
var_tensor = 4
};
int GetMinSupportedOpSet(const NodeUnit& /* node_unit */) const override{ return 9; }
int GetMinSupportedOpSet(const NodeUnit& /* node_unit */) const override { return 9; }

bool IsOpSupported(const onnxruntime::GraphViewer& graph_viewer,
const Node* node) const override {
24 changes: 12 additions & 12 deletions onnxruntime/core/providers/vsinpu/builders/impl/reduce_op_builder.h
@@ -38,40 +38,40 @@ class ReduceMeanOpBuilder : public BaseOpBuilder {
return true;
}
bool HandleBuildOp(vsi::npu::GraphEP* graph_ep,
std::vector<std::shared_ptr<tim::vx::Tensor>>& inputs,
std::vector<std::shared_ptr<tim::vx::Tensor>>& outputs,
const NodeUnit& node_unit) override {
std::vector<std::shared_ptr<tim::vx::Tensor>>& inputs,
std::vector<std::shared_ptr<tim::vx::Tensor>>& outputs,
const NodeUnit& node_unit) override {
LOGS_DEFAULT(INFO) << "Creating ReduceMean Op.";

NodeAttrHelper helper(node_unit.GetNode());
std::vector<int64_t> def_axes;
auto input_shape_size = inputs[0]->GetShape().size();

if (node_unit.SinceVersion() < 18 && helper.HasAttr("axes")) {
def_axes = helper.Get("axes", def_axes);
def_axes = helper.Get("axes", def_axes);
} else if (inputs.size() > 1) {
def_axes.resize(inputs[1]->GetSpec().GetElementNum());
inputs[1]->CopyDataFromTensor(def_axes.data());
def_axes.resize(inputs[1]->GetSpec().GetElementNum());
inputs[1]->CopyDataFromTensor(def_axes.data());
} else {
for (int64_t i = 0; i < input_shape_size; ++i) {
def_axes.push_back(i);
}
for (int64_t i = 0; i < input_shape_size; ++i) {
def_axes.push_back(i);
}
}

std::vector<int32_t> axes(def_axes.begin(), def_axes.end());
axes = util::ReverseAxis(axes, input_shape_size);

if (helper.HasAttr("noop_with_empty_axes") && inputs.size() == 1 && helper.Get("noop_with_empty_axes", 0) == 1) {
outputs[0] = inputs[0];
return true;
outputs[0] = inputs[0];
return true;
}

bool keepdims = helper.Get("keepdims", 1) == 1;
auto op = graph_ep->GetGraph()->CreateOperation<tim::vx::ops::ReduceMean>(axes, keepdims);
(*op).BindInput(inputs[0]).BindOutputs(outputs);
graph_ep->GetOps().push_back(std::move(op));
return true;
}
}
};
} // namespace npu
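Both the new Gather builder and this ReduceMean builder pass their axes through util::ReverseAxis before creating the TIM-VX op, presumably because TIM-VX enumerates dimensions in the reverse of ONNX order. A plausible sketch of that mapping under that assumption (the provider's real util::ReverseAxis may differ in detail):

#include <cstddef>
#include <cstdint>
#include <vector>

// Sketch: map ONNX axes on a rank-r tensor to reversed (TIM-VX-style) dimension
// order, i.e. axis k becomes (r - 1 - k); negative ONNX axes are normalized first.
std::vector<int32_t> ReverseAxesSketch(std::vector<int32_t> axes, std::size_t rank) {
  for (auto& axis : axes) {
    if (axis < 0) axis += static_cast<int32_t>(rank);
    axis = static_cast<int32_t>(rank) - 1 - axis;
  }
  return axes;
}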
