Skip to content

Commit

Permalink
fix lint issues + one bug in IsCompatible
Browse files Browse the repository at this point in the history
  • Loading branch information
xadupre committed Nov 17, 2023
1 parent 9ba122f commit 2364589
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 6 deletions.
10 changes: 6 additions & 4 deletions onnxruntime/core/session/custom_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#endif

#include <type_traits>
#include <unordered_set>

#include "core/common/gsl.h"
#include "core/framework/data_types.h"
Expand Down Expand Up @@ -804,7 +805,8 @@ ONNX_NAMESPACE::OpSchema CreateSchema(const std::string& domain, const std::vect
if (!all_types.empty()) {
std::vector<std::string> input_types;
for (auto type : all_types) {
const ONNX_NAMESPACE::TypeProto* type_proto = DataTypeImpl::TensorTypeFromONNXEnum(static_cast<int>(type))->GetTypeProto();
const ONNX_NAMESPACE::TypeProto* type_proto =
DataTypeImpl::TensorTypeFromONNXEnum(static_cast<int>(type))->GetTypeProto();
input_types.push_back(*ONNX_NAMESPACE::Utils::DataTypeUtils::ToType(*type_proto));
}
schema.TypeConstraint(input_name, input_types, "defined list of types");
Expand Down Expand Up @@ -864,7 +866,8 @@ ONNX_NAMESPACE::OpSchema CreateSchema(const std::string& domain, const std::vect
if (!all_types.empty()) {
std::vector<std::string> output_types;
for (auto otype : all_types) {
const ONNX_NAMESPACE::TypeProto* type_proto = DataTypeImpl::TensorTypeFromONNXEnum(static_cast<int>(otype))->GetTypeProto();
const ONNX_NAMESPACE::TypeProto* type_proto =
DataTypeImpl::TensorTypeFromONNXEnum(static_cast<int>(otype))->GetTypeProto();
output_types.push_back(*ONNX_NAMESPACE::Utils::DataTypeUtils::ToType(*type_proto));
}
schema.TypeConstraint(output_name, output_types, "defined list of types");
Expand All @@ -891,7 +894,6 @@ ONNX_NAMESPACE::OpSchema CreateSchema(const std::string& domain, const std::vect
return schema;
}


Status IsCompatible(const ONNX_NAMESPACE::OpSchema& schema, const OrtCustomOp* op) {
const size_t input_count = op->GetInputTypeCount(op);
const size_t output_count = op->GetOutputTypeCount(op);
Expand Down Expand Up @@ -949,7 +951,7 @@ Status IsCompatible(const ONNX_NAMESPACE::OpSchema& schema, const OrtCustomOp* o
"custom op schemas mismatch, expecting ", i + 1,
i == 0 ? "st" : (i == 1 ? "nd" : "th"),
" output to keep same homogeneity");
ORT_RETURN_IF_NOT(formal_parameter.GetMinArity() == op->GetVariadicInputMinArity(op),
ORT_RETURN_IF_NOT(formal_parameter.GetMinArity() == op->GetVariadicOutputMinArity(op),
"custom op schemas mismatch, expecting ", i + 1,
i == 0 ? "st" : (i == 1 ? "nd" : "th"),
" output to keep same arity");
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <string>
#include <utility>

#include "custom_gemm.h"

#ifndef ORT_ENFORCE
Expand Down Expand Up @@ -328,7 +331,7 @@ const char* CustomGemmOp::GetExecutionProviderType() const {
return "CPUExecutionProvider";
}

size_t CustomGemmOp::GetInputTypeCount() const { return 6; };
size_t CustomGemmOp::GetInputTypeCount() const { return 6; }

ONNXTensorElementDataType CustomGemmOp::GetInputType(size_t index) const {
switch (index) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@

#include "custom_op_local_function.h"

#include <vector>
#include <cmath>
#include <mutex>
#include <utility>
#include <vector>

#include "core/common/common.h"
#include "core/framework/ortdevice.h"
Expand Down

0 comments on commit 2364589

Please sign in to comment.