From 7841444f1ab3e3f9d1f36f6c7d37549d6c9abd7f Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Thu, 3 Oct 2024 11:53:01 +0000 Subject: [PATCH 01/18] add conv fp16 kernel in xnnpack --- .../core/providers/xnnpack/detail/utils.cc | 13 +++ .../core/providers/xnnpack/detail/utils.h | 3 + onnxruntime/core/providers/xnnpack/nn/conv.cc | 85 +++++++++++-------- .../core/providers/xnnpack/nn/conv_base.cc | 37 ++++++-- .../xnnpack/xnnpack_execution_provider.cc | 5 ++ 5 files changed, 102 insertions(+), 41 deletions(-) diff --git a/onnxruntime/core/providers/xnnpack/detail/utils.cc b/onnxruntime/core/providers/xnnpack/detail/utils.cc index f9cb45ebc8abc..f1781fc698985 100644 --- a/onnxruntime/core/providers/xnnpack/detail/utils.cc +++ b/onnxruntime/core/providers/xnnpack/detail/utils.cc @@ -4,14 +4,17 @@ #include "utils.h" #include #include +#include #include "core/common/common.h" #include "core/common/safeint.h" #include "core/framework/node_unit.h" #include "core/framework/tensorprotoutils.h" +#include "core/graph/graph.h" #include "core/graph/indexed_sub_graph.h" #include "core/graph/node_attr_utils.h" #include "core/optimizer/initializer.h" +#include "core/providers/xnnpack/xnnpack_init.h" #include "onnx/defs/attr_proto_util.h" @@ -111,6 +114,16 @@ bool IsPaddingTypeSupported(AutoPadType auto_pad) { auto_pad == AutoPadType::SAME_UPPER; } +bool IsComputeTypeSupported(uint8_t op_compute_type) { +#ifdef XNNPACK_FP16_SUPPORTED + std::set SupportedComputeType = {ONNX_NAMESPACE::TensorProto_DataType_FLOAT, ONNX_NAMESPACE::TensorProto_DataType_UINT8, ONNX_NAMESPACE::TensorProto_DataType_INT8, + ONNX_NAMESPACE::TensorProto_DataType_FLOAT16}; +#else + std::set SupportedComputeType = {ONNX_NAMESPACE::TensorProto_DataType_FLOAT, ONNX_NAMESPACE::TensorProto_DataType_UINT8, ONNX_NAMESPACE::TensorProto_DataType_INT8}; +#endif + return std::find(SupportedComputeType.begin(), SupportedComputeType.end(), op_compute_type) != SupportedComputeType.end(); +} + typedef std::string ONNXOpType; static const std::unordered_map qdq_to_onnx_type_map = { diff --git a/onnxruntime/core/providers/xnnpack/detail/utils.h b/onnxruntime/core/providers/xnnpack/detail/utils.h index d555ee2286b84..67898d286eda5 100644 --- a/onnxruntime/core/providers/xnnpack/detail/utils.h +++ b/onnxruntime/core/providers/xnnpack/detail/utils.h @@ -77,6 +77,8 @@ struct XnnpackOperatorDeleter { bool IsPaddingTypeSupported(AutoPadType auto_pad); +bool IsComputeTypeSupported(uint8_t op_compute_type); + using XnnpackOperator = std::unique_ptr; std::unique_ptr FuseActivation(const NodeUnit& conv_unit, const NodeUnit& activation, @@ -99,5 +101,6 @@ auto xnn_u8s8_quantize(float val, float scale, T zero_point) { auto zp = static_cast(zero_point); return static_cast(lrintf(fminf(fmaxf(val / scale + zp, typed_min), typed_max))); } + } // namespace xnnpack } // namespace onnxruntime diff --git a/onnxruntime/core/providers/xnnpack/nn/conv.cc b/onnxruntime/core/providers/xnnpack/nn/conv.cc index b815cc1570c96..e43fc37d8904c 100644 --- a/onnxruntime/core/providers/xnnpack/nn/conv.cc +++ b/onnxruntime/core/providers/xnnpack/nn/conv.cc @@ -22,41 +22,40 @@ Status Conv::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, /*out*/ PrePackedWeights* /*prepacked_weights*/) { is_packed = false; // only layout of weight input is adjusted via PrePack - if ((conv_type_ == OpComputeType::op_compute_type_fp32 && input_idx == 1) || - (conv_type_ != OpComputeType::op_compute_type_fp32 && input_idx == 3)) { // InputTensors::IN_W - // Transpose from {M, C/group, kH, kW} to 
{M, kH, kW, C/group} - auto orig_shape = tensor.Shape(); - const auto rank = orig_shape.NumDimensions(); - - if (rank == 4) { - InlinedVector perm{0, 2, 3, 1}; - TensorShapeVector new_dims{orig_shape[0], - orig_shape[2], - orig_shape[3], - orig_shape[1]}; - - packed_w_ = Tensor(tensor.DataType(), TensorShape(new_dims), std::move(alloc)); - - SingleAxisTranspose(perm, tensor, packed_w_, /*from*/ 1, /*to*/ 3); - } else { - assert(rank == 3); // ConvBase::IsOnnxNodeSupported validates this - - InlinedVector perm{0, 2, 1}; - TensorShapeVector new_dims{orig_shape[0], - orig_shape[2], - orig_shape[1]}; - - packed_w_ = Tensor(tensor.DataType(), TensorShape(new_dims), std::move(alloc)); - - SingleAxisTranspose(perm, tensor, packed_w_, /*from*/ 1, /*to*/ 2); + const bool conv_type_is_float = (conv_type_ == OpComputeType::op_compute_type_fp32 || + conv_type_ == OpComputeType::op_compute_type_fp16); + if((conv_type_is_float && input_idx == 1) ||(!conv_type_is_float && input_idx == 3)) { + // InputTensors::IN_W, Transpose from {M, C/group, kH, kW} to {M, kH, kW, C/group} + auto orig_shape = tensor.Shape(); + const auto rank = orig_shape.NumDimensions(); + + if (rank == 4) { + InlinedVector perm{0, 2, 3, 1}; + TensorShapeVector new_dims{orig_shape[0], + orig_shape[2], + orig_shape[3], + orig_shape[1]}; + + packed_w_ = Tensor(tensor.DataType(), TensorShape(new_dims), std::move(alloc)); + + SingleAxisTranspose(perm, tensor, packed_w_, /*from*/ 1, /*to*/ 3); + } else { + assert(rank == 3); // ConvBase::IsOnnxNodeSupported validates this + + InlinedVector perm{0, 2, 1}; + TensorShapeVector new_dims{orig_shape[0], + orig_shape[2], + orig_shape[1]}; + + packed_w_ = Tensor(tensor.DataType(), TensorShape(new_dims), std::move(alloc)); + + SingleAxisTranspose(perm, tensor, packed_w_, /*from*/ 1, /*to*/ 2); + } + + is_packed = true; + // we can create the kernel now + ORT_RETURN_IF_ERROR(CreateKernel()); } - - is_packed = true; - - // we can create the kernel now - ORT_RETURN_IF_ERROR(CreateKernel()); - } - return Status::OK(); } @@ -102,8 +101,13 @@ Status Conv::Compute(OpKernelContext* context) const { reshape_fn = xnn_reshape_convolution2d_nhwc_qu8; } else if (conv_type_ == OpComputeType::op_compute_type_qs8_per_channel) { reshape_fn = xnn_reshape_convolution2d_nhwc_qs8_qc8w; + } else if (conv_type_ == OpComputeType::op_compute_type_fp16) { + reshape_fn = xnn_reshape_convolution2d_nhwc_f16; } + if (!op0_.get()) { + throw std::invalid_argument("op0 ------"); + } auto status = reshape_fn(op0_.get(), N, H, W, &workspace_size, &workspace_alignment, /*output_height_out=*/nullptr, /*output_width_out=*/nullptr, @@ -112,7 +116,6 @@ Status Conv::Compute(OpKernelContext* context) const { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_reshape_convolution2d_nhwc_", OpTypeToString(conv_type_), "returned ", status); } - workspace.reset(allocator->aligned_allocate(allocator->context, XNN_ALLOCATION_ALIGNMENT, workspace_size)); if (conv_type_ == OpComputeType::op_compute_type_fp32) { @@ -127,6 +130,9 @@ Status Conv::Compute(OpKernelContext* context) const { } else if (conv_type_ == OpComputeType::op_compute_type_qs8_per_channel) { status = xnn_setup_convolution2d_nhwc_qs8_qc8w(op0_.get(), workspace.get(), X.Data(), Y->MutableData()); + } else if (conv_type_ == OpComputeType::op_compute_type_fp16) { + status = xnn_setup_convolution2d_nhwc_f16(op0_.get(), workspace.get(), X.Data(), + Y->MutableData()); } if (status != xnn_status_success) { @@ -149,6 +155,15 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(Conv, kMSInternalNHWCDomain, 1, 
10, kXnnpackEx ONNX_OPERATOR_KERNEL_EX(Conv, kMSInternalNHWCDomain, 11, kXnnpackExecutionProvider, KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), Conv); +#ifdef XNNPACK_FP16_SUPPORTED +ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(Conv, kMSInternalNHWCDomain, 1, 10, MLFloat16, kXnnpackExecutionProvider, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + Conv); + +ONNX_OPERATOR_TYPED_KERNEL_EX(Conv, kMSInternalNHWCDomain, 11, MLFloat16, kXnnpackExecutionProvider, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + Conv); +#endif ONNX_OPERATOR_TYPED_KERNEL_EX( QLinearConv, diff --git a/onnxruntime/core/providers/xnnpack/nn/conv_base.cc b/onnxruntime/core/providers/xnnpack/nn/conv_base.cc index 2aafc9be7ffd0..89bf4aac1d394 100644 --- a/onnxruntime/core/providers/xnnpack/nn/conv_base.cc +++ b/onnxruntime/core/providers/xnnpack/nn/conv_base.cc @@ -154,6 +154,26 @@ Status CreateXnnpackKernel(const ConvAttributes& conv_attrs, flags, code_cache, weights_cache, &p); + } else if (conv_type == OpComputeType::op_compute_type_fp16) { + const auto* B_data = Bias ? Bias->Data() : nullptr; + const float output_min = -65504.0; + const float output_max = 65504.0; + auto create_func = is_transpose ? xnn_create_deconvolution2d_nhwc_f16 + : xnn_create_convolution2d_nhwc_f16; + status = create_func( + input_padding_top, input_padding_right, input_padding_bottom, input_padding_left, + kernel_height, kernel_width, + subsampling_height, subsampling_width, + dilation_height, dilation_width, + group_count, + group_input_channels, + group_output_channels, + C, M, // input channel stride, output channel stride + Weight.Data(), B_data, // kernel, bias + output_min, output_max, + flags, + code_cache, weights_cache, + &p); } if (status != xnn_status_success) { @@ -236,6 +256,13 @@ OpComputeType GetConvCompType( return op_compute_type_qu8; } break; + case TensorTypeFp16: + if (input_datatype == TensorTypeFp16 && + (!bias_datatype || *bias_datatype == TensorTypeInt32) && + output_datatype == TensorTypeFp16) { + return op_compute_type_fp16; + } + break; default: break; } @@ -326,10 +353,7 @@ bool ConvBase::IsOnnxNodeSupported(const NodeUnit& node_unit, const GraphViewer& // we only support float and u8 currently const auto* x_type = x_arg.TypeAsProto(); - if (x_type == nullptr || - (x_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT && - x_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_UINT8 && - x_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_INT8)) { + if (x_type == nullptr || !IsComputeTypeSupported(x_type->tensor_type().elem_type())) { break; } // require C, H, W to be known so we can construct the xnnpack kernel prior to Compute @@ -420,9 +444,11 @@ ConvBase::ConvBase(const OpKernelInfo& info, bool is_transpose) input_dtype == ONNX_NAMESPACE::TensorProto_DataType_UINT8) { weight_index = 3; conv_type_ = ParseQuantParamAndConType(info, quant_param_, input_dtype); + } else if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { + conv_type_ = OpComputeType::op_compute_type_fp16; } else { auto stype = DataTypeImpl::ToString(DataTypeImpl::TypeFromProto(*X.TypeAsProto())); - ORT_THROW("unsupported Conv in XnnpackEP, we have FLOAT|UINT8|INT8, but got ", stype); + ORT_THROW("unsupported Conv in XnnpackEP, we have FLOAT|UINT8|INT8|FLOAT16, but got ", stype); } ORT_ENFORCE(info.TryGetConstantInput(weight_index, &Weight), @@ -491,7 +517,6 @@ ConvBase::ConvBase(const 
OpKernelInfo& info, bool is_transpose) output_shape_.push_back(M_); } - // have to delay creating the xnnpack kernel until after the weights are pre-packed. } diff --git a/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc b/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc index df7df0b4376ce..067efb92dd469 100644 --- a/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc +++ b/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc @@ -64,6 +64,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWC class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 1, 10, Conv); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 11, Conv); +CLASS_ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME_FP16(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 1, 10, Conv); +CLASS_ONNX_OPERATOR_KERNEL_CLASS_NAME_FP16(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 11, Conv); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 1, 10, ConvTranspose); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 11, ConvTranspose); @@ -161,6 +163,9 @@ std::unique_ptr RegisterKernels() { KERNEL_CREATE_INFO(1, QLinearSoftmax, kDynamicDomainByCreate), #ifdef XNNPACK_FP16_SUPPORTED + KERNEL_CREATE_INFO_VERSIONED_TYPED(1, 10, MLFloat16, Conv, kMSInternalNHWCDomain), + KERNEL_CREATE_INFO_TYPED(11, MLFloat16, Conv, kMSInternalNHWCDomain), + KERNEL_CREATE_INFO_VERSIONED_TYPED(8, 9, MLFloat16, MaxPool, kMSInternalNHWCDomain), KERNEL_CREATE_INFO_VERSIONED_TYPED(10, 10, MLFloat16, MaxPool, kMSInternalNHWCDomain), KERNEL_CREATE_INFO_VERSIONED_TYPED(11, 11, MLFloat16, MaxPool, kMSInternalNHWCDomain), From d4a863dca077b33d3067cab39c3e8719a6b8c04c Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Thu, 3 Oct 2024 11:55:30 +0000 Subject: [PATCH 02/18] fix lint --- onnxruntime/core/providers/xnnpack/nn/conv.cc | 64 +++++++++---------- 1 file changed, 32 insertions(+), 32 deletions(-) diff --git a/onnxruntime/core/providers/xnnpack/nn/conv.cc b/onnxruntime/core/providers/xnnpack/nn/conv.cc index e43fc37d8904c..7eaa79d09fb64 100644 --- a/onnxruntime/core/providers/xnnpack/nn/conv.cc +++ b/onnxruntime/core/providers/xnnpack/nn/conv.cc @@ -23,39 +23,39 @@ Status Conv::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, is_packed = false; // only layout of weight input is adjusted via PrePack const bool conv_type_is_float = (conv_type_ == OpComputeType::op_compute_type_fp32 || - conv_type_ == OpComputeType::op_compute_type_fp16); - if((conv_type_is_float && input_idx == 1) ||(!conv_type_is_float && input_idx == 3)) { - // InputTensors::IN_W, Transpose from {M, C/group, kH, kW} to {M, kH, kW, C/group} - auto orig_shape = tensor.Shape(); - const auto rank = orig_shape.NumDimensions(); - - if (rank == 4) { - InlinedVector perm{0, 2, 3, 1}; - TensorShapeVector new_dims{orig_shape[0], - orig_shape[2], - orig_shape[3], - orig_shape[1]}; - - packed_w_ = Tensor(tensor.DataType(), TensorShape(new_dims), std::move(alloc)); - - SingleAxisTranspose(perm, tensor, packed_w_, /*from*/ 1, /*to*/ 3); - } else { - assert(rank == 3); // ConvBase::IsOnnxNodeSupported validates this - - InlinedVector perm{0, 2, 1}; - TensorShapeVector new_dims{orig_shape[0], - orig_shape[2], - orig_shape[1]}; - - packed_w_ = Tensor(tensor.DataType(), TensorShape(new_dims), std::move(alloc)); - - SingleAxisTranspose(perm, tensor, packed_w_, 
/*from*/ 1, /*to*/ 2); - } - - is_packed = true; - // we can create the kernel now - ORT_RETURN_IF_ERROR(CreateKernel()); + conv_type_ == OpComputeType::op_compute_type_fp16); + if ((conv_type_is_float && input_idx == 1) || (!conv_type_is_float && input_idx == 3)) { + // InputTensors::IN_W, Transpose from {M, C/group, kH, kW} to {M, kH, kW, C/group} + auto orig_shape = tensor.Shape(); + const auto rank = orig_shape.NumDimensions(); + + if (rank == 4) { + InlinedVector perm{0, 2, 3, 1}; + TensorShapeVector new_dims{orig_shape[0], + orig_shape[2], + orig_shape[3], + orig_shape[1]}; + + packed_w_ = Tensor(tensor.DataType(), TensorShape(new_dims), std::move(alloc)); + + SingleAxisTranspose(perm, tensor, packed_w_, /*from*/ 1, /*to*/ 3); + } else { + assert(rank == 3); // ConvBase::IsOnnxNodeSupported validates this + + InlinedVector perm{0, 2, 1}; + TensorShapeVector new_dims{orig_shape[0], + orig_shape[2], + orig_shape[1]}; + + packed_w_ = Tensor(tensor.DataType(), TensorShape(new_dims), std::move(alloc)); + + SingleAxisTranspose(perm, tensor, packed_w_, /*from*/ 1, /*to*/ 2); } + + is_packed = true; + // we can create the kernel now + ORT_RETURN_IF_ERROR(CreateKernel()); + } return Status::OK(); } From abbacdb1febe63b24461a22f2067f8fc629f4809 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Thu, 3 Oct 2024 11:56:12 +0000 Subject: [PATCH 03/18] add missing changes --- onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc index ce1ac7591ec34..66bb34bb269dd 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc @@ -3,7 +3,7 @@ #include "core/mlas/inc/mlas.h" -#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) || defined(COREML_ENABLE_MLPROGRAM) +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) || defined(COREML_ENABLE_MLPROGRAM) || defined(XNNPACK_FP16_SUPPORTED) #include "gtest/gtest.h" #include "test/providers/provider_test_utils.h" From e84f9eb7d0db496bccc68994fa14f22a2679c1d3 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Thu, 3 Oct 2024 13:37:51 +0000 Subject: [PATCH 04/18] add convtranspose --- onnxruntime/core/providers/xnnpack/nn/conv.cc | 2 +- .../providers/xnnpack/nn/conv_transpose.cc | 34 +++++++++++++++---- .../xnnpack/xnnpack_execution_provider.cc | 6 ++++ 3 files changed, 34 insertions(+), 8 deletions(-) diff --git a/onnxruntime/core/providers/xnnpack/nn/conv.cc b/onnxruntime/core/providers/xnnpack/nn/conv.cc index 7eaa79d09fb64..3588e348c3810 100644 --- a/onnxruntime/core/providers/xnnpack/nn/conv.cc +++ b/onnxruntime/core/providers/xnnpack/nn/conv.cc @@ -10,8 +10,8 @@ #include "core/framework/tensorprotoutils.h" #include "core/framework/transpose_helper.h" #include "core/providers/utils.h" -#include "core/providers/xnnpack/xnnpack_init.h" #include "core/providers/xnnpack/detail/utils.h" +#include "core/providers/xnnpack/xnnpack_init.h" namespace onnxruntime { namespace xnnpack { diff --git a/onnxruntime/core/providers/xnnpack/nn/conv_transpose.cc b/onnxruntime/core/providers/xnnpack/nn/conv_transpose.cc index 01c8119fea79d..ed3cbe0b1edce 100644 --- a/onnxruntime/core/providers/xnnpack/nn/conv_transpose.cc +++ b/onnxruntime/core/providers/xnnpack/nn/conv_transpose.cc @@ -7,6 +7,7 @@ #include "core/framework/transpose_helper.h" #include "core/providers/utils.h" #include "core/providers/xnnpack/detail/utils.h" +#include "core/providers/xnnpack/xnnpack_init.h" 
#include "core/framework/tensorprotoutils.h" namespace onnxruntime { @@ -18,8 +19,9 @@ Status ConvTranspose::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr /*out*/ PrePackedWeights* /*prepacked_weights*/) { is_packed = false; // only layout of weight input is adjusted via PrePack - if ((conv_type_ == OpComputeType::op_compute_type_fp32 && input_idx == 1) || - (conv_type_ != OpComputeType::op_compute_type_fp32 && input_idx == 3)) { // InputTensors::IN_W + const bool conv_type_is_float = (conv_type_ == OpComputeType::op_compute_type_fp32 || + conv_type_ == OpComputeType::op_compute_type_fp16); + if ((conv_type_is_float && input_idx == 1) || (!conv_type_is_float && input_idx == 3)) { // InputTensors::IN_W auto orig_shape = tensor.Shape(); const auto rank = orig_shape.NumDimensions(); @@ -129,6 +131,8 @@ Status ConvTranspose::Compute(OpKernelContext* context) const { reshape_fn = xnn_reshape_deconvolution2d_nhwc_qs8; } else if (conv_type_ == OpComputeType::op_compute_type_qu8) { reshape_fn = xnn_reshape_deconvolution2d_nhwc_qu8; + } else if (conv_type_ == OpComputeType::op_compute_type_fp16) { + reshape_fn = xnn_reshape_deconvolution2d_nhwc_f16; } status = reshape_fn(op0_.get(), N, H, W, output_pad_0, output_pad_1, @@ -146,6 +150,8 @@ Status ConvTranspose::Compute(OpKernelContext* context) const { status = xnn_setup_deconvolution2d_nhwc_qs8(op0_.get(), X.Data(), Y->MutableData()); } else if (conv_type_ == OpComputeType::op_compute_type_qu8) { status = xnn_setup_deconvolution2d_nhwc_qu8(op0_.get(), X.Data(), Y->MutableData()); + } else if (conv_type_ == OpComputeType::op_compute_type_fp16) { + status = xnn_setup_deconvolution2d_nhwc_f16(op0_.get(), X.Data(), Y->MutableData()); } if (status != xnn_status_success) { @@ -161,16 +167,16 @@ Status ConvTranspose::Compute(OpKernelContext* context) const { return Status::OK(); } -ONNX_OPERATOR_KERNEL_EX(ConvTranspose, kMSInternalNHWCDomain, 11, kXnnpackExecutionProvider, - KernelDefBuilder().TypeConstraint( - "T", DataTypeImpl::GetTensorType()), - ConvTranspose); - ONNX_OPERATOR_VERSIONED_KERNEL_EX(ConvTranspose, kMSInternalNHWCDomain, 1, 10, kXnnpackExecutionProvider, KernelDefBuilder().TypeConstraint( "T", DataTypeImpl::GetTensorType()), ConvTranspose); +ONNX_OPERATOR_KERNEL_EX(ConvTranspose, kMSInternalNHWCDomain, 11, kXnnpackExecutionProvider, + KernelDefBuilder().TypeConstraint( + "T", DataTypeImpl::GetTensorType()), + ConvTranspose); + ONNX_OPERATOR_KERNEL_EX(QLinearConvTranspose, kMSInternalNHWCDomain, 1, kXnnpackExecutionProvider, KernelDefBuilder() .TypeConstraint( @@ -179,5 +185,19 @@ ONNX_OPERATOR_KERNEL_EX(QLinearConvTranspose, kMSInternalNHWCDomain, 1, kXnnpack DataTypeImpl::GetTensorType()}), ConvTranspose); +#ifdef XNNPACK_FP16_SUPPORTED +ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(ConvTranspose, kMSInternalNHWCDomain, 1, 10, MLFloat16, + kXnnpackExecutionProvider, + KernelDefBuilder().TypeConstraint( + "T", DataTypeImpl::GetTensorType()), + ConvTranspose); + +ONNX_OPERATOR_TYPED_KERNEL_EX(ConvTranspose, kMSInternalNHWCDomain, 11, MLFloat16, kXnnpackExecutionProvider, + KernelDefBuilder().TypeConstraint( + "T", DataTypeImpl::GetTensorType()), + ConvTranspose); + + +#endif } // namespace xnnpack } // namespace onnxruntime diff --git a/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc b/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc index 067efb92dd469..288bb95311cce 100644 --- a/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc +++ 
b/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc @@ -69,6 +69,9 @@ CLASS_ONNX_OPERATOR_KERNEL_CLASS_NAME_FP16(kXnnpackExecutionProvider, kMSInterna class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 1, 10, ConvTranspose); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 11, ConvTranspose); +CLASS_ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME_FP16(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 1, 10, \ + ConvTranspose); +CLASS_ONNX_OPERATOR_KERNEL_CLASS_NAME_FP16(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 11, ConvTranspose); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 10, uint8_t, QLinearConv); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 10, int8_t, QLinearConv); @@ -166,6 +169,9 @@ std::unique_ptr RegisterKernels() { KERNEL_CREATE_INFO_VERSIONED_TYPED(1, 10, MLFloat16, Conv, kMSInternalNHWCDomain), KERNEL_CREATE_INFO_TYPED(11, MLFloat16, Conv, kMSInternalNHWCDomain), + KERNEL_CREATE_INFO_VERSIONED_TYPED(1, 10, MLFloat16, ConvTranspose, kMSInternalNHWCDomain), + KERNEL_CREATE_INFO_TYPED(11, MLFloat16, ConvTranspose, kMSInternalNHWCDomain), + KERNEL_CREATE_INFO_VERSIONED_TYPED(8, 9, MLFloat16, MaxPool, kMSInternalNHWCDomain), KERNEL_CREATE_INFO_VERSIONED_TYPED(10, 10, MLFloat16, MaxPool, kMSInternalNHWCDomain), KERNEL_CREATE_INFO_VERSIONED_TYPED(11, 11, MLFloat16, MaxPool, kMSInternalNHWCDomain), From e33b78041e11c905692cb79bd762f68812c79abf Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Thu, 3 Oct 2024 13:44:41 +0000 Subject: [PATCH 05/18] fix lint --- .../core/providers/xnnpack/nn/conv_transpose.cc | 17 ++++++++--------- .../xnnpack/xnnpack_execution_provider.cc | 4 ++-- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/onnxruntime/core/providers/xnnpack/nn/conv_transpose.cc b/onnxruntime/core/providers/xnnpack/nn/conv_transpose.cc index ed3cbe0b1edce..f9ed3c32160c5 100644 --- a/onnxruntime/core/providers/xnnpack/nn/conv_transpose.cc +++ b/onnxruntime/core/providers/xnnpack/nn/conv_transpose.cc @@ -21,7 +21,7 @@ Status ConvTranspose::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr // only layout of weight input is adjusted via PrePack const bool conv_type_is_float = (conv_type_ == OpComputeType::op_compute_type_fp32 || conv_type_ == OpComputeType::op_compute_type_fp16); - if ((conv_type_is_float && input_idx == 1) || (!conv_type_is_float && input_idx == 3)) { // InputTensors::IN_W + if ((conv_type_is_float && input_idx == 1) || (!conv_type_is_float && input_idx == 3)) { // InputTensors::IN_W auto orig_shape = tensor.Shape(); const auto rank = orig_shape.NumDimensions(); @@ -187,16 +187,15 @@ ONNX_OPERATOR_KERNEL_EX(QLinearConvTranspose, kMSInternalNHWCDomain, 1, kXnnpack #ifdef XNNPACK_FP16_SUPPORTED ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(ConvTranspose, kMSInternalNHWCDomain, 1, 10, MLFloat16, - kXnnpackExecutionProvider, - KernelDefBuilder().TypeConstraint( - "T", DataTypeImpl::GetTensorType()), - ConvTranspose); + kXnnpackExecutionProvider, + KernelDefBuilder().TypeConstraint( + "T", DataTypeImpl::GetTensorType()), + ConvTranspose); ONNX_OPERATOR_TYPED_KERNEL_EX(ConvTranspose, kMSInternalNHWCDomain, 11, MLFloat16, kXnnpackExecutionProvider, - KernelDefBuilder().TypeConstraint( - "T", DataTypeImpl::GetTensorType()), - ConvTranspose); - + KernelDefBuilder().TypeConstraint( + "T", DataTypeImpl::GetTensorType()), + ConvTranspose); #endif } // namespace xnnpack 
diff --git a/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc b/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc index 288bb95311cce..647122d5603e2 100644 --- a/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc +++ b/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc @@ -69,8 +69,8 @@ CLASS_ONNX_OPERATOR_KERNEL_CLASS_NAME_FP16(kXnnpackExecutionProvider, kMSInterna class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 1, 10, ConvTranspose); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 11, ConvTranspose); -CLASS_ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME_FP16(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 1, 10, \ - ConvTranspose); +CLASS_ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME_FP16(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 1, 10, + ConvTranspose); CLASS_ONNX_OPERATOR_KERNEL_CLASS_NAME_FP16(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 11, ConvTranspose); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 10, uint8_t, QLinearConv); From 582c3c3616b3c97673d9743d303b8d1f04fc389c Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 4 Oct 2024 03:01:18 +0000 Subject: [PATCH 06/18] update torlerance --- onnxruntime/core/providers/xnnpack/detail/utils.cc | 2 +- onnxruntime/core/providers/xnnpack/detail/utils.h | 2 +- onnxruntime/test/providers/checkers.cc | 3 +++ 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/xnnpack/detail/utils.cc b/onnxruntime/core/providers/xnnpack/detail/utils.cc index f1781fc698985..b0f00a8aa628e 100644 --- a/onnxruntime/core/providers/xnnpack/detail/utils.cc +++ b/onnxruntime/core/providers/xnnpack/detail/utils.cc @@ -114,7 +114,7 @@ bool IsPaddingTypeSupported(AutoPadType auto_pad) { auto_pad == AutoPadType::SAME_UPPER; } -bool IsComputeTypeSupported(uint8_t op_compute_type) { +bool IsComputeTypeSupported(int32_t op_compute_type) { #ifdef XNNPACK_FP16_SUPPORTED std::set SupportedComputeType = {ONNX_NAMESPACE::TensorProto_DataType_FLOAT, ONNX_NAMESPACE::TensorProto_DataType_UINT8, ONNX_NAMESPACE::TensorProto_DataType_INT8, ONNX_NAMESPACE::TensorProto_DataType_FLOAT16}; diff --git a/onnxruntime/core/providers/xnnpack/detail/utils.h b/onnxruntime/core/providers/xnnpack/detail/utils.h index 67898d286eda5..263da07b831ab 100644 --- a/onnxruntime/core/providers/xnnpack/detail/utils.h +++ b/onnxruntime/core/providers/xnnpack/detail/utils.h @@ -77,7 +77,7 @@ struct XnnpackOperatorDeleter { bool IsPaddingTypeSupported(AutoPadType auto_pad); -bool IsComputeTypeSupported(uint8_t op_compute_type); +bool IsComputeTypeSupported(int32_t op_compute_type); using XnnpackOperator = std::unique_ptr; diff --git a/onnxruntime/test/providers/checkers.cc b/onnxruntime/test/providers/checkers.cc index ff5895623fc9b..3c279a6eebeaf 100644 --- a/onnxruntime/test/providers/checkers.cc +++ b/onnxruntime/test/providers/checkers.cc @@ -80,6 +80,9 @@ struct DefaultTolerance { if (provider_type == kDmlExecutionProvider) { return 0.005f; } + if (provider_type == kXnnpackExecutionProvider) { + return 0.05f; + } return absolute; } }; From 72f69a9df9c1257a058c3cc5c9925a073f3c2efc Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 4 Oct 2024 03:33:26 +0000 Subject: [PATCH 07/18] add one comment --- onnxruntime/test/providers/checkers.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/onnxruntime/test/providers/checkers.cc b/onnxruntime/test/providers/checkers.cc index 
3c279a6eebeaf..d82d13afd353d 100644 --- a/onnxruntime/test/providers/checkers.cc +++ b/onnxruntime/test/providers/checkers.cc @@ -81,6 +81,7 @@ struct DefaultTolerance { return 0.005f; } if (provider_type == kXnnpackExecutionProvider) { + // To allow tests like ConvTranspose_2D_Bias_1 to pass return 0.05f; } return absolute; From 85f8b9a2489cb091beaf3a8edf9972632d951fa9 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 4 Oct 2024 04:39:04 +0000 Subject: [PATCH 08/18] update --- onnxruntime/core/providers/xnnpack/detail/utils.cc | 5 +++-- onnxruntime/core/providers/xnnpack/detail/utils.h | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/providers/xnnpack/detail/utils.cc b/onnxruntime/core/providers/xnnpack/detail/utils.cc index b0f00a8aa628e..b0d3f3e17e33e 100644 --- a/onnxruntime/core/providers/xnnpack/detail/utils.cc +++ b/onnxruntime/core/providers/xnnpack/detail/utils.cc @@ -114,14 +114,15 @@ bool IsPaddingTypeSupported(AutoPadType auto_pad) { auto_pad == AutoPadType::SAME_UPPER; } -bool IsComputeTypeSupported(int32_t op_compute_type) { +bool IsComputeTypeSupported(int32_t tp) { #ifdef XNNPACK_FP16_SUPPORTED std::set SupportedComputeType = {ONNX_NAMESPACE::TensorProto_DataType_FLOAT, ONNX_NAMESPACE::TensorProto_DataType_UINT8, ONNX_NAMESPACE::TensorProto_DataType_INT8, ONNX_NAMESPACE::TensorProto_DataType_FLOAT16}; #else std::set SupportedComputeType = {ONNX_NAMESPACE::TensorProto_DataType_FLOAT, ONNX_NAMESPACE::TensorProto_DataType_UINT8, ONNX_NAMESPACE::TensorProto_DataType_INT8}; #endif - return std::find(SupportedComputeType.begin(), SupportedComputeType.end(), op_compute_type) != SupportedComputeType.end(); + ONNX_NAMESPACE::TensorProto_DataType compute_type = static_cast(tp); + return std::find(SupportedComputeType.begin(), SupportedComputeType.end(), compute_type) != SupportedComputeType.end(); } typedef std::string ONNXOpType; diff --git a/onnxruntime/core/providers/xnnpack/detail/utils.h b/onnxruntime/core/providers/xnnpack/detail/utils.h index 263da07b831ab..000b4888ff379 100644 --- a/onnxruntime/core/providers/xnnpack/detail/utils.h +++ b/onnxruntime/core/providers/xnnpack/detail/utils.h @@ -77,7 +77,7 @@ struct XnnpackOperatorDeleter { bool IsPaddingTypeSupported(AutoPadType auto_pad); -bool IsComputeTypeSupported(int32_t op_compute_type); +bool IsComputeTypeSupported(int32_t tp); using XnnpackOperator = std::unique_ptr; From d0b507a11e36f2d4764c68fa4959ce909ab5221a Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 4 Oct 2024 09:02:00 +0000 Subject: [PATCH 09/18] fix lint --- .../core/providers/xnnpack/detail/utils.cc | 16 ++++++++++++---- onnxruntime/core/providers/xnnpack/nn/conv.cc | 14 ++++++-------- .../core/providers/xnnpack/nn/conv_transpose.cc | 3 ++- 3 files changed, 20 insertions(+), 13 deletions(-) diff --git a/onnxruntime/core/providers/xnnpack/detail/utils.cc b/onnxruntime/core/providers/xnnpack/detail/utils.cc index b0d3f3e17e33e..7453d7b075ecb 100644 --- a/onnxruntime/core/providers/xnnpack/detail/utils.cc +++ b/onnxruntime/core/providers/xnnpack/detail/utils.cc @@ -2,9 +2,9 @@ // Licensed under the MIT License. 
#include "utils.h" +#include #include #include -#include #include "core/common/common.h" #include "core/common/safeint.h" @@ -116,10 +116,18 @@ bool IsPaddingTypeSupported(AutoPadType auto_pad) { bool IsComputeTypeSupported(int32_t tp) { #ifdef XNNPACK_FP16_SUPPORTED - std::set SupportedComputeType = {ONNX_NAMESPACE::TensorProto_DataType_FLOAT, ONNX_NAMESPACE::TensorProto_DataType_UINT8, ONNX_NAMESPACE::TensorProto_DataType_INT8, - ONNX_NAMESPACE::TensorProto_DataType_FLOAT16}; + std::unordered_set SupportedComputeType = { + ONNX_NAMESPACE::TensorProto_DataType_FLOAT, + ONNX_NAMESPACE::TensorProto_DataType_UINT8, + ONNX_NAMESPACE::TensorProto_DataType_INT8, + ONNX_NAMESPACE::TensorProto_DataType_FLOAT16 + }; #else - std::set SupportedComputeType = {ONNX_NAMESPACE::TensorProto_DataType_FLOAT, ONNX_NAMESPACE::TensorProto_DataType_UINT8, ONNX_NAMESPACE::TensorProto_DataType_INT8}; + std::unordered_set SupportedComputeType = { + ONNX_NAMESPACE::TensorProto_DataType_FLOAT, + ONNX_NAMESPACE::TensorProto_DataType_UINT8, + ONNX_NAMESPACE::TensorProto_DataType_INT8 + }; #endif ONNX_NAMESPACE::TensorProto_DataType compute_type = static_cast(tp); return std::find(SupportedComputeType.begin(), SupportedComputeType.end(), compute_type) != SupportedComputeType.end(); diff --git a/onnxruntime/core/providers/xnnpack/nn/conv.cc b/onnxruntime/core/providers/xnnpack/nn/conv.cc index 3588e348c3810..a1603b744996e 100644 --- a/onnxruntime/core/providers/xnnpack/nn/conv.cc +++ b/onnxruntime/core/providers/xnnpack/nn/conv.cc @@ -24,8 +24,9 @@ Status Conv::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, // only layout of weight input is adjusted via PrePack const bool conv_type_is_float = (conv_type_ == OpComputeType::op_compute_type_fp32 || conv_type_ == OpComputeType::op_compute_type_fp16); - if ((conv_type_is_float && input_idx == 1) || (!conv_type_is_float && input_idx == 3)) { - // InputTensors::IN_W, Transpose from {M, C/group, kH, kW} to {M, kH, kW, C/group} + if ((conv_type_is_float && input_idx == 1) || + (!conv_type_is_float && input_idx == 3)) {// InputTensors::IN_W + // Transpose from {M, C/group, kH, kW} to {M, kH, kW, C/group} auto orig_shape = tensor.Shape(); const auto rank = orig_shape.NumDimensions(); @@ -105,9 +106,6 @@ Status Conv::Compute(OpKernelContext* context) const { reshape_fn = xnn_reshape_convolution2d_nhwc_f16; } - if (!op0_.get()) { - throw std::invalid_argument("op0 ------"); - } auto status = reshape_fn(op0_.get(), N, H, W, &workspace_size, &workspace_alignment, /*output_height_out=*/nullptr, /*output_width_out=*/nullptr, @@ -121,6 +119,9 @@ Status Conv::Compute(OpKernelContext* context) const { if (conv_type_ == OpComputeType::op_compute_type_fp32) { status = xnn_setup_convolution2d_nhwc_f32(op0_.get(), workspace.get(), X.Data(), Y->MutableData()); + } else if (conv_type_ == OpComputeType::op_compute_type_fp16) { + status = xnn_setup_convolution2d_nhwc_f16(op0_.get(), workspace.get(), X.Data(), + Y->MutableData()); } else if (conv_type_ == OpComputeType::op_compute_type_qs8) { status = xnn_setup_convolution2d_nhwc_qs8(op0_.get(), workspace.get(), X.Data(), Y->MutableData()); @@ -130,9 +131,6 @@ Status Conv::Compute(OpKernelContext* context) const { } else if (conv_type_ == OpComputeType::op_compute_type_qs8_per_channel) { status = xnn_setup_convolution2d_nhwc_qs8_qc8w(op0_.get(), workspace.get(), X.Data(), Y->MutableData()); - } else if (conv_type_ == OpComputeType::op_compute_type_fp16) { - status = xnn_setup_convolution2d_nhwc_f16(op0_.get(), workspace.get(), 
X.Data(), - Y->MutableData()); } if (status != xnn_status_success) { diff --git a/onnxruntime/core/providers/xnnpack/nn/conv_transpose.cc b/onnxruntime/core/providers/xnnpack/nn/conv_transpose.cc index f9ed3c32160c5..b399311cd8568 100644 --- a/onnxruntime/core/providers/xnnpack/nn/conv_transpose.cc +++ b/onnxruntime/core/providers/xnnpack/nn/conv_transpose.cc @@ -21,7 +21,8 @@ Status ConvTranspose::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr // only layout of weight input is adjusted via PrePack const bool conv_type_is_float = (conv_type_ == OpComputeType::op_compute_type_fp32 || conv_type_ == OpComputeType::op_compute_type_fp16); - if ((conv_type_is_float && input_idx == 1) || (!conv_type_is_float && input_idx == 3)) { // InputTensors::IN_W + if ((conv_type_is_float && input_idx == 1) || + (!conv_type_is_float && input_idx == 3)) { // InputTensors::IN_W auto orig_shape = tensor.Shape(); const auto rank = orig_shape.NumDimensions(); From c82465546b14edffa9b2cfd7d462957fa526d6c6 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 4 Oct 2024 12:53:39 +0000 Subject: [PATCH 10/18] improve compute type helper function --- .../core/providers/xnnpack/detail/utils.cc | 26 ++++++++----------- .../core/providers/xnnpack/detail/utils.h | 6 +++-- 2 files changed, 15 insertions(+), 17 deletions(-) diff --git a/onnxruntime/core/providers/xnnpack/detail/utils.cc b/onnxruntime/core/providers/xnnpack/detail/utils.cc index 7453d7b075ecb..912560dd75fe6 100644 --- a/onnxruntime/core/providers/xnnpack/detail/utils.cc +++ b/onnxruntime/core/providers/xnnpack/detail/utils.cc @@ -114,23 +114,19 @@ bool IsPaddingTypeSupported(AutoPadType auto_pad) { auto_pad == AutoPadType::SAME_UPPER; } -bool IsComputeTypeSupported(int32_t tp) { +bool IsComputeTypeSupported(int32_t compute_type, + std::optional> compute_type_set) { + std::unordered_set default_supported_types { + ONNX_NAMESPACE::TensorProto_DataType_FLOAT, + ONNX_NAMESPACE::TensorProto_DataType_UINT8, + ONNX_NAMESPACE::TensorProto_DataType_INT8 + }; + auto supported_types = compute_type_set == std::nullopt ? 
default_supported_types : compute_type_set->get(); #ifdef XNNPACK_FP16_SUPPORTED - std::unordered_set SupportedComputeType = { - ONNX_NAMESPACE::TensorProto_DataType_FLOAT, - ONNX_NAMESPACE::TensorProto_DataType_UINT8, - ONNX_NAMESPACE::TensorProto_DataType_INT8, - ONNX_NAMESPACE::TensorProto_DataType_FLOAT16 - }; -#else - std::unordered_set SupportedComputeType = { - ONNX_NAMESPACE::TensorProto_DataType_FLOAT, - ONNX_NAMESPACE::TensorProto_DataType_UINT8, - ONNX_NAMESPACE::TensorProto_DataType_INT8 - }; + supported_types.insert(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16); #endif - ONNX_NAMESPACE::TensorProto_DataType compute_type = static_cast(tp); - return std::find(SupportedComputeType.begin(), SupportedComputeType.end(), compute_type) != SupportedComputeType.end(); + ONNX_NAMESPACE::TensorProto_DataType tp = static_cast(compute_type); + return std::find(supported_types.begin(), supported_types.end(), tp) != supported_types.end(); } typedef std::string ONNXOpType; diff --git a/onnxruntime/core/providers/xnnpack/detail/utils.h b/onnxruntime/core/providers/xnnpack/detail/utils.h index 000b4888ff379..4bed0df1e7f95 100644 --- a/onnxruntime/core/providers/xnnpack/detail/utils.h +++ b/onnxruntime/core/providers/xnnpack/detail/utils.h @@ -6,8 +6,8 @@ #include #include #include -#include #include +#include #include #include "core/framework/node_unit.h" @@ -77,7 +77,9 @@ struct XnnpackOperatorDeleter { bool IsPaddingTypeSupported(AutoPadType auto_pad); -bool IsComputeTypeSupported(int32_t tp); +using COMPUTE_TYPE_SETS = std::unordered_set; +bool IsComputeTypeSupported(int32_t compute_type, + std::optional> compute_type_set = std::nullopt); using XnnpackOperator = std::unique_ptr; From 9c97d4a5829288e07629890b42a1be285152f0ff Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Sat, 5 Oct 2024 03:47:19 +0000 Subject: [PATCH 11/18] temp change --- onnxruntime/test/providers/checkers.cc | 4 -- .../test/providers/cpu/nn/conv_fp16_test.cc | 1 + .../cpu/nn/conv_transpose_op_test.cc | 39 ++++++++++++++++--- .../providers/cpu/nn/pool_fp16_op_test.cc | 1 + 4 files changed, 35 insertions(+), 10 deletions(-) diff --git a/onnxruntime/test/providers/checkers.cc b/onnxruntime/test/providers/checkers.cc index d82d13afd353d..ff5895623fc9b 100644 --- a/onnxruntime/test/providers/checkers.cc +++ b/onnxruntime/test/providers/checkers.cc @@ -80,10 +80,6 @@ struct DefaultTolerance { if (provider_type == kDmlExecutionProvider) { return 0.005f; } - if (provider_type == kXnnpackExecutionProvider) { - // To allow tests like ConvTranspose_2D_Bias_1 to pass - return 0.05f; - } return absolute; } }; diff --git a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc index 66bb34bb269dd..cfd4a21e5c3d1 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "core/mlas/inc/mlas.h" +#include "core/providers/xnnpack/xnnpack_init.h" #if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) || defined(COREML_ENABLE_MLPROGRAM) || defined(XNNPACK_FP16_SUPPORTED) diff --git a/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc b/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc index 29525f89ef544..7fcbbc56b90ec 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
+#include "core/providers/xnnpack/xnnpack_init.h" + #include "gtest/gtest.h" #include "test/providers/provider_test_utils.h" #include "default_providers.h" @@ -28,14 +30,17 @@ void TestConvTransposeOpInitializer(const ConvTransposeOpAttributes& attributes, const vector>& input_shapes, const std::vector& expected_output, const vector& expected_output_shape, + float rel_error, + float abs_error, bool is_weight_and_bias_initializer = false, OpTester::ExpectResult expect_result = OpTester::ExpectResult::kExpectSuccess, const std::string& err_str = "", - const std::unordered_set& excluded_provider_types = {kTensorrtExecutionProvider}) { + const std::unordered_set& excluded_provider_types = {kTensorrtExecutionProvider} + ) { OpTester test("ConvTranspose", 11); test.AddAttribute("kernel_shape", attributes.kernel_shape); test.AddAttribute("group", attributes.group); - + throw std::invalid_argument("initializer 0"); // Only one of pads / auto_pad can be present if (!attributes.pads.empty()) { test.AddAttribute("pads", attributes.pads); @@ -57,16 +62,17 @@ void TestConvTransposeOpInitializer(const ConvTransposeOpAttributes& attributes, if (!attributes.dilations.empty()) { test.AddAttribute("dilations", attributes.dilations); } - ORT_ENFORCE(inputs.size() <= 3, "Our name array is only setup to handle 3 inputs"); const char* input_names[] = {"X", "W", "B"}; bool is_initializers[] = {false, is_weight_and_bias_initializer, is_weight_and_bias_initializer}; for (size_t i = 0; i < inputs.size(); i++) { test.AddInput(input_names[i], input_shapes[i], inputs[i], is_initializers[i]); } - test.AddOutput("Y", expected_output_shape, expected_output); - test.Run(expect_result, err_str, excluded_provider_types); // Disable TensorRT because weight as input is not supported + test.AddOutput("Y", expected_output_shape, expected_output, false, rel_error, abs_error); + + // Disable TensorRT because weight as input isn't supported + test.Run(expect_result, err_str, excluded_provider_types); } template @@ -78,12 +84,18 @@ void TestConvTransposeOp(const ConvTransposeOpAttributes& attributes, OpTester::ExpectResult expect_result = OpTester::ExpectResult::kExpectSuccess, const std::string& err_str = "", const std::unordered_set& excluded_provider_types = - {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kQnnExecutionProvider}) { + {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kQnnExecutionProvider}, + float rel_error = 0.0, + float abs_error = 0.0 + ) { std::unordered_set extra_exclude_openvino_for_initializer_filter = excluded_provider_types; extra_exclude_openvino_for_initializer_filter.insert(kOpenVINOExecutionProvider); + TestConvTransposeOpInitializer(attributes, inputs, input_shapes, expected_output, expected_output_shape, + rel_error, abs_error, true, expect_result, err_str, extra_exclude_openvino_for_initializer_filter); TestConvTransposeOpInitializer(attributes, inputs, input_shapes, expected_output, expected_output_shape, + rel_error, abs_error, false, expect_result, err_str, excluded_provider_types); } @@ -245,8 +257,23 @@ TYPED_TEST(ConvTransposeTest, ConvTranspose_2D_Bias_1) { 0.07770107f, -0.09561026f, 0.13388641f, 0.30945939f, 0.14015588f, 0.13079405f, -0.00488365f, -0.06758944f, 0.45621645f, 0.01566098f, 0.00703105f, 0.12956856f, 0.0103332f, 0.04221053f, -0.21318194f}; +#ifdef XNNPACK_FP16_SUPPORTED + if constexpr (std::is_same::value) { + TestConvTransposeOp(attrs, {GetTypedArray(X), GetTypedArray(W), GetTypedArray(B)}, + {X_shape, W_shape, B_shape}, GetTypedArray(expected_vals), Y_shape, 
+ OpTester::ExpectResult::kExpectSuccess, "", + {kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kQnnExecutionProvider}, + 0.05, 0.05); + } else { + TestConvTransposeOp(attrs, {GetTypedArray(X), GetTypedArray(W), GetTypedArray(B)}, + {X_shape, W_shape, B_shape}, GetTypedArray(expected_vals), Y_shape); + } + + +#else TestConvTransposeOp(attrs, {GetTypedArray(X), GetTypedArray(W), GetTypedArray(B)}, {X_shape, W_shape, B_shape}, GetTypedArray(expected_vals), Y_shape); +#endif } TEST(ConvTransposeTest, ConvTranspose_2D_Bias_2) { diff --git a/onnxruntime/test/providers/cpu/nn/pool_fp16_op_test.cc b/onnxruntime/test/providers/cpu/nn/pool_fp16_op_test.cc index 46eb1180f4e7e..b5b137ff686c9 100644 --- a/onnxruntime/test/providers/cpu/nn/pool_fp16_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/pool_fp16_op_test.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "core/mlas/inc/mlas.h" +#include "core/providers/xnnpack/xnnpack_init.h" #if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) || defined(COREML_ENABLE_MLPROGRAM) || defined(XNNPACK_FP16_SUPPORTED) From d2c7c2f0218fe981bc4dbe42c4ce8287c822787f Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Sat, 5 Oct 2024 04:00:15 +0000 Subject: [PATCH 12/18] Revert "temp change" This reverts commit 9c97d4a5829288e07629890b42a1be285152f0ff. --- onnxruntime/test/providers/checkers.cc | 4 ++ .../test/providers/cpu/nn/conv_fp16_test.cc | 1 - .../cpu/nn/conv_transpose_op_test.cc | 39 +++---------------- .../providers/cpu/nn/pool_fp16_op_test.cc | 1 - 4 files changed, 10 insertions(+), 35 deletions(-) diff --git a/onnxruntime/test/providers/checkers.cc b/onnxruntime/test/providers/checkers.cc index ff5895623fc9b..d82d13afd353d 100644 --- a/onnxruntime/test/providers/checkers.cc +++ b/onnxruntime/test/providers/checkers.cc @@ -80,6 +80,10 @@ struct DefaultTolerance { if (provider_type == kDmlExecutionProvider) { return 0.005f; } + if (provider_type == kXnnpackExecutionProvider) { + // To allow tests like ConvTranspose_2D_Bias_1 to pass + return 0.05f; + } return absolute; } }; diff --git a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc index cfd4a21e5c3d1..66bb34bb269dd 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc @@ -2,7 +2,6 @@ // Licensed under the MIT License. #include "core/mlas/inc/mlas.h" -#include "core/providers/xnnpack/xnnpack_init.h" #if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) || defined(COREML_ENABLE_MLPROGRAM) || defined(XNNPACK_FP16_SUPPORTED) diff --git a/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc b/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc index 7fcbbc56b90ec..29525f89ef544 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc @@ -1,8 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-#include "core/providers/xnnpack/xnnpack_init.h" - #include "gtest/gtest.h" #include "test/providers/provider_test_utils.h" #include "default_providers.h" @@ -30,17 +28,14 @@ void TestConvTransposeOpInitializer(const ConvTransposeOpAttributes& attributes, const vector>& input_shapes, const std::vector& expected_output, const vector& expected_output_shape, - float rel_error, - float abs_error, bool is_weight_and_bias_initializer = false, OpTester::ExpectResult expect_result = OpTester::ExpectResult::kExpectSuccess, const std::string& err_str = "", - const std::unordered_set& excluded_provider_types = {kTensorrtExecutionProvider} - ) { + const std::unordered_set& excluded_provider_types = {kTensorrtExecutionProvider}) { OpTester test("ConvTranspose", 11); test.AddAttribute("kernel_shape", attributes.kernel_shape); test.AddAttribute("group", attributes.group); - throw std::invalid_argument("initializer 0"); + // Only one of pads / auto_pad can be present if (!attributes.pads.empty()) { test.AddAttribute("pads", attributes.pads); @@ -62,17 +57,16 @@ void TestConvTransposeOpInitializer(const ConvTransposeOpAttributes& attributes, if (!attributes.dilations.empty()) { test.AddAttribute("dilations", attributes.dilations); } + ORT_ENFORCE(inputs.size() <= 3, "Our name array is only setup to handle 3 inputs"); const char* input_names[] = {"X", "W", "B"}; bool is_initializers[] = {false, is_weight_and_bias_initializer, is_weight_and_bias_initializer}; for (size_t i = 0; i < inputs.size(); i++) { test.AddInput(input_names[i], input_shapes[i], inputs[i], is_initializers[i]); } + test.AddOutput("Y", expected_output_shape, expected_output); - test.AddOutput("Y", expected_output_shape, expected_output, false, rel_error, abs_error); - - // Disable TensorRT because weight as input isn't supported - test.Run(expect_result, err_str, excluded_provider_types); + test.Run(expect_result, err_str, excluded_provider_types); // Disable TensorRT because weight as input is not supported } template @@ -84,18 +78,12 @@ void TestConvTransposeOp(const ConvTransposeOpAttributes& attributes, OpTester::ExpectResult expect_result = OpTester::ExpectResult::kExpectSuccess, const std::string& err_str = "", const std::unordered_set& excluded_provider_types = - {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kQnnExecutionProvider}, - float rel_error = 0.0, - float abs_error = 0.0 - ) { + {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kQnnExecutionProvider}) { std::unordered_set extra_exclude_openvino_for_initializer_filter = excluded_provider_types; extra_exclude_openvino_for_initializer_filter.insert(kOpenVINOExecutionProvider); - TestConvTransposeOpInitializer(attributes, inputs, input_shapes, expected_output, expected_output_shape, - rel_error, abs_error, true, expect_result, err_str, extra_exclude_openvino_for_initializer_filter); TestConvTransposeOpInitializer(attributes, inputs, input_shapes, expected_output, expected_output_shape, - rel_error, abs_error, false, expect_result, err_str, excluded_provider_types); } @@ -257,23 +245,8 @@ TYPED_TEST(ConvTransposeTest, ConvTranspose_2D_Bias_1) { 0.07770107f, -0.09561026f, 0.13388641f, 0.30945939f, 0.14015588f, 0.13079405f, -0.00488365f, -0.06758944f, 0.45621645f, 0.01566098f, 0.00703105f, 0.12956856f, 0.0103332f, 0.04221053f, -0.21318194f}; -#ifdef XNNPACK_FP16_SUPPORTED - if constexpr (std::is_same::value) { - TestConvTransposeOp(attrs, {GetTypedArray(X), GetTypedArray(W), GetTypedArray(B)}, - {X_shape, W_shape, B_shape}, GetTypedArray(expected_vals), Y_shape, 
- OpTester::ExpectResult::kExpectSuccess, "", - {kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kQnnExecutionProvider}, - 0.05, 0.05); - } else { - TestConvTransposeOp(attrs, {GetTypedArray(X), GetTypedArray(W), GetTypedArray(B)}, - {X_shape, W_shape, B_shape}, GetTypedArray(expected_vals), Y_shape); - } - - -#else TestConvTransposeOp(attrs, {GetTypedArray(X), GetTypedArray(W), GetTypedArray(B)}, {X_shape, W_shape, B_shape}, GetTypedArray(expected_vals), Y_shape); -#endif } TEST(ConvTransposeTest, ConvTranspose_2D_Bias_2) { diff --git a/onnxruntime/test/providers/cpu/nn/pool_fp16_op_test.cc b/onnxruntime/test/providers/cpu/nn/pool_fp16_op_test.cc index b5b137ff686c9..46eb1180f4e7e 100644 --- a/onnxruntime/test/providers/cpu/nn/pool_fp16_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/pool_fp16_op_test.cc @@ -2,7 +2,6 @@ // Licensed under the MIT License. #include "core/mlas/inc/mlas.h" -#include "core/providers/xnnpack/xnnpack_init.h" #if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) || defined(COREML_ENABLE_MLPROGRAM) || defined(XNNPACK_FP16_SUPPORTED) From 4dcb7462d287139ae249fb4ea3d29ed9e06534ff Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Mon, 7 Oct 2024 01:11:34 +0000 Subject: [PATCH 13/18] add tolerance --- onnxruntime/test/providers/checkers.cc | 4 --- .../test/providers/cpu/nn/conv_fp16_test.cc | 1 + .../cpu/nn/conv_transpose_op_test.cc | 25 +++++++++++++++++-- 3 files changed, 24 insertions(+), 6 deletions(-) diff --git a/onnxruntime/test/providers/checkers.cc b/onnxruntime/test/providers/checkers.cc index d82d13afd353d..ff5895623fc9b 100644 --- a/onnxruntime/test/providers/checkers.cc +++ b/onnxruntime/test/providers/checkers.cc @@ -80,10 +80,6 @@ struct DefaultTolerance { if (provider_type == kDmlExecutionProvider) { return 0.005f; } - if (provider_type == kXnnpackExecutionProvider) { - // To allow tests like ConvTranspose_2D_Bias_1 to pass - return 0.05f; - } return absolute; } }; diff --git a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc index 66bb34bb269dd..cfd4a21e5c3d1 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "core/mlas/inc/mlas.h" +#include "core/providers/xnnpack/xnnpack_init.h" #if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) || defined(COREML_ENABLE_MLPROGRAM) || defined(XNNPACK_FP16_SUPPORTED) diff --git a/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc b/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc index 29525f89ef544..1965418c1f4dd 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
+#include "core/providers/xnnpack/xnnpack_init.h" #include "gtest/gtest.h" #include "test/providers/provider_test_utils.h" #include "default_providers.h" @@ -28,6 +29,8 @@ void TestConvTransposeOpInitializer(const ConvTransposeOpAttributes& attributes, const vector>& input_shapes, const std::vector& expected_output, const vector& expected_output_shape, + float rel_error = 0.0, + float abs_error = 0.0, bool is_weight_and_bias_initializer = false, OpTester::ExpectResult expect_result = OpTester::ExpectResult::kExpectSuccess, const std::string& err_str = "", @@ -64,7 +67,7 @@ void TestConvTransposeOpInitializer(const ConvTransposeOpAttributes& attributes, for (size_t i = 0; i < inputs.size(); i++) { test.AddInput(input_names[i], input_shapes[i], inputs[i], is_initializers[i]); } - test.AddOutput("Y", expected_output_shape, expected_output); + test.AddOutput("Y", expected_output_shape, expected_output, false, rel_error, abs_error); test.Run(expect_result, err_str, excluded_provider_types); // Disable TensorRT because weight as input is not supported } @@ -78,12 +81,16 @@ void TestConvTransposeOp(const ConvTransposeOpAttributes& attributes, OpTester::ExpectResult expect_result = OpTester::ExpectResult::kExpectSuccess, const std::string& err_str = "", const std::unordered_set& excluded_provider_types = - {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kQnnExecutionProvider}) { + {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kQnnExecutionProvider}, + float rel_error = 0.0, + float abs_error = 0.0) { std::unordered_set extra_exclude_openvino_for_initializer_filter = excluded_provider_types; extra_exclude_openvino_for_initializer_filter.insert(kOpenVINOExecutionProvider); TestConvTransposeOpInitializer(attributes, inputs, input_shapes, expected_output, expected_output_shape, + rel_error, abs_error, true, expect_result, err_str, extra_exclude_openvino_for_initializer_filter); TestConvTransposeOpInitializer(attributes, inputs, input_shapes, expected_output, expected_output_shape, + rel_error, abs_error, false, expect_result, err_str, excluded_provider_types); } @@ -245,8 +252,22 @@ TYPED_TEST(ConvTransposeTest, ConvTranspose_2D_Bias_1) { 0.07770107f, -0.09561026f, 0.13388641f, 0.30945939f, 0.14015588f, 0.13079405f, -0.00488365f, -0.06758944f, 0.45621645f, 0.01566098f, 0.00703105f, 0.12956856f, 0.0103332f, 0.04221053f, -0.21318194f}; +#ifdef XNNPACK_FP16_SUPPORTED + if constexpr (std::is_same::value) { + TestConvTransposeOp(attrs, {GetTypedArray(X), GetTypedArray(W), GetTypedArray(B)}, + {X_shape, W_shape, B_shape}, GetTypedArray(expected_vals), Y_shape, + OpTester::ExpectResult::kExpectSuccess, "", // defalut value + {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kQnnExecutionProvider}, //default value + 0.5, 0.5); + } else { + TestConvTransposeOp(attrs, {GetTypedArray(X), GetTypedArray(W), GetTypedArray(B)}, + {X_shape, W_shape, B_shape}, GetTypedArray(expected_vals), Y_shape); + } + +#else TestConvTransposeOp(attrs, {GetTypedArray(X), GetTypedArray(W), GetTypedArray(B)}, {X_shape, W_shape, B_shape}, GetTypedArray(expected_vals), Y_shape); +#endif } TEST(ConvTransposeTest, ConvTranspose_2D_Bias_2) { From 862daf618ab736647c7abcb33810cd68a5172fa4 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Mon, 7 Oct 2024 01:34:31 +0000 Subject: [PATCH 14/18] lint --- onnxruntime/core/providers/xnnpack/detail/utils.cc | 11 +++++------ onnxruntime/core/providers/xnnpack/detail/utils.h | 2 +- onnxruntime/core/providers/xnnpack/nn/conv.cc | 2 +- 
.../test/providers/cpu/nn/conv_transpose_op_test.cc | 4 ++-- 4 files changed, 9 insertions(+), 10 deletions(-) diff --git a/onnxruntime/core/providers/xnnpack/detail/utils.cc b/onnxruntime/core/providers/xnnpack/detail/utils.cc index 912560dd75fe6..83604abcb8657 100644 --- a/onnxruntime/core/providers/xnnpack/detail/utils.cc +++ b/onnxruntime/core/providers/xnnpack/detail/utils.cc @@ -115,12 +115,11 @@ bool IsPaddingTypeSupported(AutoPadType auto_pad) { } bool IsComputeTypeSupported(int32_t compute_type, - std::optional> compute_type_set) { - std::unordered_set default_supported_types { - ONNX_NAMESPACE::TensorProto_DataType_FLOAT, - ONNX_NAMESPACE::TensorProto_DataType_UINT8, - ONNX_NAMESPACE::TensorProto_DataType_INT8 - }; + std::optional> compute_type_set) { + std::unordered_set default_supported_types{ + ONNX_NAMESPACE::TensorProto_DataType_FLOAT, + ONNX_NAMESPACE::TensorProto_DataType_UINT8, + ONNX_NAMESPACE::TensorProto_DataType_INT8}; auto supported_types = compute_type_set == std::nullopt ? default_supported_types : compute_type_set->get(); #ifdef XNNPACK_FP16_SUPPORTED supported_types.insert(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16); diff --git a/onnxruntime/core/providers/xnnpack/detail/utils.h b/onnxruntime/core/providers/xnnpack/detail/utils.h index 4bed0df1e7f95..95dfd6fa22734 100644 --- a/onnxruntime/core/providers/xnnpack/detail/utils.h +++ b/onnxruntime/core/providers/xnnpack/detail/utils.h @@ -79,7 +79,7 @@ bool IsPaddingTypeSupported(AutoPadType auto_pad); using COMPUTE_TYPE_SETS = std::unordered_set; bool IsComputeTypeSupported(int32_t compute_type, - std::optional> compute_type_set = std::nullopt); + std::optional> compute_type_set = std::nullopt); using XnnpackOperator = std::unique_ptr; diff --git a/onnxruntime/core/providers/xnnpack/nn/conv.cc b/onnxruntime/core/providers/xnnpack/nn/conv.cc index a1603b744996e..b09072a2a6ce2 100644 --- a/onnxruntime/core/providers/xnnpack/nn/conv.cc +++ b/onnxruntime/core/providers/xnnpack/nn/conv.cc @@ -25,7 +25,7 @@ Status Conv::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, const bool conv_type_is_float = (conv_type_ == OpComputeType::op_compute_type_fp32 || conv_type_ == OpComputeType::op_compute_type_fp16); if ((conv_type_is_float && input_idx == 1) || - (!conv_type_is_float && input_idx == 3)) {// InputTensors::IN_W + (!conv_type_is_float && input_idx == 3)) { // InputTensors::IN_W // Transpose from {M, C/group, kH, kW} to {M, kH, kW, C/group} auto orig_shape = tensor.Shape(); const auto rank = orig_shape.NumDimensions(); diff --git a/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc b/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc index 1965418c1f4dd..0ce87fb65898b 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc @@ -256,8 +256,8 @@ TYPED_TEST(ConvTransposeTest, ConvTranspose_2D_Bias_1) { if constexpr (std::is_same::value) { TestConvTransposeOp(attrs, {GetTypedArray(X), GetTypedArray(W), GetTypedArray(B)}, {X_shape, W_shape, B_shape}, GetTypedArray(expected_vals), Y_shape, - OpTester::ExpectResult::kExpectSuccess, "", // defalut value - {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kQnnExecutionProvider}, //default value + OpTester::ExpectResult::kExpectSuccess, "", // defalut value + {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kQnnExecutionProvider}, // default value 0.5, 0.5); } else { TestConvTransposeOp(attrs, {GetTypedArray(X), GetTypedArray(W), GetTypedArray(B)}, 
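Since the template arguments are elided in this excerpt, a standalone sketch of the check that the refactored IsComputeTypeSupported above performs may help. The set alias, the const reference wrapper, and the spelled-out enum values are assumptions for readability, not code lifted from the patch.

    #include <cstdint>
    #include <functional>
    #include <optional>
    #include <unordered_set>

    // Sketch only: start from the default allow-list (fp32, u8, s8), let the
    // caller substitute an op-specific set, and append fp16 when the build
    // defines XNNPACK_FP16_SUPPORTED.
    using ComputeTypeSet = std::unordered_set<int32_t>;

    bool IsComputeTypeSupportedSketch(
        int32_t compute_type,
        std::optional<std::reference_wrapper<const ComputeTypeSet>> compute_type_set = std::nullopt) {
      ComputeTypeSet supported{/*FLOAT*/ 1, /*UINT8*/ 2, /*INT8*/ 3};
      if (compute_type_set.has_value()) {
        supported = compute_type_set->get();
      }
    #ifdef XNNPACK_FP16_SUPPORTED
      supported.insert(/*FLOAT16*/ 10);
    #endif
      return supported.count(compute_type) != 0;
    }

A Conv-style caller would pass the element type of its X input and, optionally, a narrower op-specific set; with no set supplied, the defaults above apply.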
From 56d0031ef4859f7ff55cdc313c21915cf3841a99 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Mon, 7 Oct 2024 02:10:11 +0000 Subject: [PATCH 15/18] F16 max value comment --- .../core/providers/xnnpack/nn/conv_base.cc | 42 ++++++++++--------- 1 file changed, 22 insertions(+), 20 deletions(-) diff --git a/onnxruntime/core/providers/xnnpack/nn/conv_base.cc b/onnxruntime/core/providers/xnnpack/nn/conv_base.cc index 89bf4aac1d394..6cadad514cdbd 100644 --- a/onnxruntime/core/providers/xnnpack/nn/conv_base.cc +++ b/onnxruntime/core/providers/xnnpack/nn/conv_base.cc @@ -81,6 +81,28 @@ Status CreateXnnpackKernel(const ConvAttributes& conv_attrs, foutput_min, foutput_max, flags, code_cache, weights_cache, &p); + } else if (conv_type == OpComputeType::op_compute_type_fp16) { + const auto* B_data = Bias ? Bias->Data() : nullptr; + // 65504 is the max value of float16 + // https://en.wikipedia.org/wiki/Half-precision_floating-point_format + const float output_min = -65504.0; + const float output_max = 65504.0; + auto create_func = is_transpose ? xnn_create_deconvolution2d_nhwc_f16 + : xnn_create_convolution2d_nhwc_f16; + status = create_func( + input_padding_top, input_padding_right, input_padding_bottom, input_padding_left, + kernel_height, kernel_width, + subsampling_height, subsampling_width, + dilation_height, dilation_width, + group_count, + group_input_channels, + group_output_channels, + C, M, // input channel stride, output channel stride + Weight.Data(), B_data, // kernel, bias + output_min, output_max, + flags, + code_cache, weights_cache, + &p); } else if (conv_type == OpComputeType::op_compute_type_qs8) { const float output_scale = quant_param[2].first[0]; const int8_t output_zero_point = quant_param[2].second; @@ -154,26 +176,6 @@ Status CreateXnnpackKernel(const ConvAttributes& conv_attrs, flags, code_cache, weights_cache, &p); - } else if (conv_type == OpComputeType::op_compute_type_fp16) { - const auto* B_data = Bias ? Bias->Data() : nullptr; - const float output_min = -65504.0; - const float output_max = 65504.0; - auto create_func = is_transpose ? xnn_create_deconvolution2d_nhwc_f16 - : xnn_create_convolution2d_nhwc_f16; - status = create_func( - input_padding_top, input_padding_right, input_padding_bottom, input_padding_left, - kernel_height, kernel_width, - subsampling_height, subsampling_width, - dilation_height, dilation_width, - group_count, - group_input_channels, - group_output_channels, - C, M, // input channel stride, output channel stride - Weight.Data(), B_data, // kernel, bias - output_min, output_max, - flags, - code_cache, weights_cache, - &p); } if (status != xnn_status_success) { From ebf5bf7f7c8c555d916464d8e6eb0809f170157a Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Mon, 7 Oct 2024 04:00:51 +0000 Subject: [PATCH 16/18] lint --- onnxruntime/core/providers/xnnpack/nn/conv_base.cc | 4 ++-- .../core/providers/xnnpack/xnnpack_execution_provider.cc | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/onnxruntime/core/providers/xnnpack/nn/conv_base.cc b/onnxruntime/core/providers/xnnpack/nn/conv_base.cc index 6cadad514cdbd..c314a02883d49 100644 --- a/onnxruntime/core/providers/xnnpack/nn/conv_base.cc +++ b/onnxruntime/core/providers/xnnpack/nn/conv_base.cc @@ -85,8 +85,8 @@ Status CreateXnnpackKernel(const ConvAttributes& conv_attrs, const auto* B_data = Bias ? 
Bias->Data() : nullptr; // 65504 is the max value of float16 // https://en.wikipedia.org/wiki/Half-precision_floating-point_format - const float output_min = -65504.0; - const float output_max = 65504.0; + const auto output_min = clip_min_max ? onnxruntime::math::floatToHalf(clip_min_max->first) : -65504.0; + const auto output_max = clip_min_max ? onnxruntime::math::floatToHalf(clip_min_max->second) : 65504.0; auto create_func = is_transpose ? xnn_create_deconvolution2d_nhwc_f16 : xnn_create_convolution2d_nhwc_f16; status = create_func( diff --git a/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc b/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc index 647122d5603e2..4515a31eb0da0 100644 --- a/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc +++ b/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc @@ -44,12 +44,12 @@ KernelCreateInfo BuildKernelCreateInfo() { ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, Domain, Start, Type, Op)> #ifdef XNNPACK_FP16_SUPPORTED -#define CLASS_ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME_FP16(provider, domain, startver, endver, name) \ - class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, \ +#define CLASS_ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME_FP16(provider, domain, startver, endver, name) \ + class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, domain, \ startver, endver, MLFloat16, name) -#define CLASS_ONNX_OPERATOR_KERNEL_CLASS_NAME_FP16(provider, domain, startver, name) \ - class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, startver, \ +#define CLASS_ONNX_OPERATOR_KERNEL_CLASS_NAME_FP16(provider, domain, startver, name) \ + class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, domain, startver, \ MLFloat16, name) #else #define CLASS_ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME_FP16(provider, domain, startver, endver, name) From 1e8f64b3d721600267a9a9e0753927541df11e2f Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Mon, 7 Oct 2024 06:39:37 +0000 Subject: [PATCH 17/18] update --- onnxruntime/core/providers/xnnpack/nn/conv_base.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/xnnpack/nn/conv_base.cc b/onnxruntime/core/providers/xnnpack/nn/conv_base.cc index c314a02883d49..7fcc236557c7e 100644 --- a/onnxruntime/core/providers/xnnpack/nn/conv_base.cc +++ b/onnxruntime/core/providers/xnnpack/nn/conv_base.cc @@ -85,8 +85,8 @@ Status CreateXnnpackKernel(const ConvAttributes& conv_attrs, const auto* B_data = Bias ? Bias->Data() : nullptr; // 65504 is the max value of float16 // https://en.wikipedia.org/wiki/Half-precision_floating-point_format - const auto output_min = clip_min_max ? onnxruntime::math::floatToHalf(clip_min_max->first) : -65504.0; - const auto output_max = clip_min_max ? onnxruntime::math::floatToHalf(clip_min_max->second) : 65504.0; + float output_min = clip_min_max ? onnxruntime::math::floatToHalf(clip_min_max->first) : -65504.0; + float output_max = clip_min_max ? onnxruntime::math::floatToHalf(clip_min_max->second) : 65504.0; auto create_func = is_transpose ? 
xnn_create_deconvolution2d_nhwc_f16 : xnn_create_convolution2d_nhwc_f16; status = create_func( From 8009a4ab6626acd2669c3552e74caf78f080c160 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Mon, 7 Oct 2024 08:49:14 +0000 Subject: [PATCH 18/18] update --- onnxruntime/core/providers/xnnpack/detail/utils.cc | 2 +- onnxruntime/core/providers/xnnpack/nn/conv.cc | 1 + onnxruntime/core/providers/xnnpack/nn/conv_base.cc | 4 ++-- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/providers/xnnpack/detail/utils.cc b/onnxruntime/core/providers/xnnpack/detail/utils.cc index 83604abcb8657..743d196c42a22 100644 --- a/onnxruntime/core/providers/xnnpack/detail/utils.cc +++ b/onnxruntime/core/providers/xnnpack/detail/utils.cc @@ -2,8 +2,8 @@ // Licensed under the MIT License. #include "utils.h" -#include #include +#include #include #include "core/common/common.h" diff --git a/onnxruntime/core/providers/xnnpack/nn/conv.cc b/onnxruntime/core/providers/xnnpack/nn/conv.cc index b09072a2a6ce2..6e404c62594fd 100644 --- a/onnxruntime/core/providers/xnnpack/nn/conv.cc +++ b/onnxruntime/core/providers/xnnpack/nn/conv.cc @@ -54,6 +54,7 @@ Status Conv::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, } is_packed = true; + // we can create the kernel now ORT_RETURN_IF_ERROR(CreateKernel()); } diff --git a/onnxruntime/core/providers/xnnpack/nn/conv_base.cc b/onnxruntime/core/providers/xnnpack/nn/conv_base.cc index 7fcc236557c7e..d299d09de1128 100644 --- a/onnxruntime/core/providers/xnnpack/nn/conv_base.cc +++ b/onnxruntime/core/providers/xnnpack/nn/conv_base.cc @@ -85,8 +85,8 @@ Status CreateXnnpackKernel(const ConvAttributes& conv_attrs, const auto* B_data = Bias ? Bias->Data() : nullptr; // 65504 is the max value of float16 // https://en.wikipedia.org/wiki/Half-precision_floating-point_format - float output_min = clip_min_max ? onnxruntime::math::floatToHalf(clip_min_max->first) : -65504.0; - float output_max = clip_min_max ? onnxruntime::math::floatToHalf(clip_min_max->second) : 65504.0; + auto output_min = clip_min_max ? onnxruntime::math::floatToHalf(clip_min_max->first) : -65504.0f; + auto output_max = clip_min_max ? onnxruntime::math::floatToHalf(clip_min_max->second) : 65504.0f; auto create_func = is_transpose ? xnn_create_deconvolution2d_nhwc_f16 : xnn_create_convolution2d_nhwc_f16; status = create_func(
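For context on the output_min/output_max edits in the last few patches: 65504 is the largest finite half-precision value, so with no fused Clip the full ±65504 range is passed, and a fused Clip narrows it. Below is a minimal standalone sketch of that clamping; it keeps the bounds as plain floats and uses std::clamp for illustration, whereas the patches themselves route the fused bounds through onnxruntime::math::floatToHalf. The helper name is hypothetical.

    #include <algorithm>
    #include <optional>
    #include <utility>

    // Sketch only: derive the f16 convolution's output clamp from an optional
    // fused-Clip [min, max], saturated to the finite fp16 range.
    std::pair<float, float> Fp16OutputRangeSketch(
        const std::optional<std::pair<float, float>>& clip_min_max) {
      constexpr float kFp16Max = 65504.0f;  // largest finite half-precision value
      const float lo = clip_min_max ? std::clamp(clip_min_max->first, -kFp16Max, kFp16Max) : -kFp16Max;
      const float hi = clip_min_max ? std::clamp(clip_min_max->second, -kFp16Max, kFp16Max) : kFp16Max;
      return {lo, hi};
    }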