Add conv fp16 kernel in xnnpack EP #22301
base: main
Changes from all commits
onnxruntime/core/providers/xnnpack/nn/conv.cc

@@ -10,8 +10,8 @@
 #include "core/framework/tensorprotoutils.h"
 #include "core/framework/transpose_helper.h"
 #include "core/providers/utils.h"
-#include "core/providers/xnnpack/xnnpack_init.h"
 #include "core/providers/xnnpack/detail/utils.h"
+#include "core/providers/xnnpack/xnnpack_init.h"

 namespace onnxruntime {
 namespace xnnpack {
@@ -22,8 +22,10 @@
                     /*out*/ PrePackedWeights* /*prepacked_weights*/) {
   is_packed = false;
   // only layout of weight input is adjusted via PrePack
-  if ((conv_type_ == OpComputeType::op_compute_type_fp32 && input_idx == 1) ||
-      (conv_type_ != OpComputeType::op_compute_type_fp32 && input_idx == 3)) {  // InputTensors::IN_W
+  const bool conv_type_is_float = (conv_type_ == OpComputeType::op_compute_type_fp32 ||
+                                   conv_type_ == OpComputeType::op_compute_type_fp16);
+  if ((conv_type_is_float && input_idx == 1) ||
+      (!conv_type_is_float && input_idx == 3)) {  // InputTensors::IN_W
     // Transpose from {M, C/group, kH, kW} to {M, kH, kW, C/group}
     auto orig_shape = tensor.Shape();
     const auto rank = orig_shape.NumDimensions();
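As an aside, here is a minimal sketch of the {M, C/group, kH, kW} to {M, kH, kW, C/group} weight transpose the hunk above refers to. This is illustrative only: the helper name and the use of std::vector are mine, not the EP's actual implementation, which repacks the ORT weight Tensor during PrePack.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical helper: repack a 4-D conv weight from {M, C, kH, kW} (NCHW-style)
// into {M, kH, kW, C} (NHWC-style), the layout XNNPACK expects.
template <typename T>
std::vector<T> TransposeWeightToNhwc(const std::vector<T>& src,
                                     size_t M, size_t C, size_t kH, size_t kW) {
  std::vector<T> dst(src.size());
  for (size_t m = 0; m < M; ++m)
    for (size_t c = 0; c < C; ++c)
      for (size_t h = 0; h < kH; ++h)
        for (size_t w = 0; w < kW; ++w)
          // source index in {M, C, kH, kW} -> destination index in {M, kH, kW, C}
          dst[((m * kH + h) * kW + w) * C + c] = src[((m * C + c) * kH + h) * kW + w];
  return dst;
}
```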
@@ -52,11 +54,9 @@
     }

     is_packed = true;

     // we can create the kernel now
     ORT_RETURN_IF_ERROR(CreateKernel());
   }

Review comment: Should we keep the blank line here?

   return Status::OK();
 }
@@ -102,6 +102,8 @@
     reshape_fn = xnn_reshape_convolution2d_nhwc_qu8;
   } else if (conv_type_ == OpComputeType::op_compute_type_qs8_per_channel) {
     reshape_fn = xnn_reshape_convolution2d_nhwc_qs8_qc8w;
+  } else if (conv_type_ == OpComputeType::op_compute_type_fp16) {
+    reshape_fn = xnn_reshape_convolution2d_nhwc_f16;
   }

   auto status = reshape_fn(op0_.get(), N, H, W,
@@ -112,12 +114,14 @@
     return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_reshape_convolution2d_nhwc_", OpTypeToString(conv_type_),
                            "returned ", status);
   }

   workspace.reset(allocator->aligned_allocate(allocator->context, XNN_ALLOCATION_ALIGNMENT, workspace_size));

   if (conv_type_ == OpComputeType::op_compute_type_fp32) {
     status = xnn_setup_convolution2d_nhwc_f32(op0_.get(), workspace.get(), X.Data<float>(),
                                               Y->MutableData<float>());
+  } else if (conv_type_ == OpComputeType::op_compute_type_fp16) {
+    status = xnn_setup_convolution2d_nhwc_f16(op0_.get(), workspace.get(), X.Data<MLFloat16>(),
+                                              Y->MutableData<MLFloat16>());
   } else if (conv_type_ == OpComputeType::op_compute_type_qs8) {
     status = xnn_setup_convolution2d_nhwc_qs8(op0_.get(), workspace.get(), X.Data<int8_t>(),
                                               Y->MutableData<int8_t>());
@@ -149,6 +153,15 @@
 ONNX_OPERATOR_KERNEL_EX(Conv, kMSInternalNHWCDomain, 11, kXnnpackExecutionProvider,
                         KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
                         Conv);
+#ifdef XNNPACK_FP16_SUPPORTED

Review comment: I don't quite understand why we need the macro here. (See the sketch after this hunk.)

+ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(Conv, kMSInternalNHWCDomain, 1, 10, MLFloat16, kXnnpackExecutionProvider,
+                                        KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<MLFloat16>()),
+                                        Conv);
+
+ONNX_OPERATOR_TYPED_KERNEL_EX(Conv, kMSInternalNHWCDomain, 11, MLFloat16, kXnnpackExecutionProvider,
+                              KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<MLFloat16>()),
+                              Conv);
+#endif

(GitHub Actions / Optional Lint C++ flagged warnings on the newly added registration lines in conv.cc.)

 ONNX_OPERATOR_TYPED_KERNEL_EX(
     QLinearConv,
Second file in the diff:
@@ -81,6 +81,28 @@ Status CreateXnnpackKernel(const ConvAttributes& conv_attrs,
                          foutput_min, foutput_max, flags,
                          code_cache, weights_cache,
                          &p);
+  } else if (conv_type == OpComputeType::op_compute_type_fp16) {
+    const auto* B_data = Bias ? Bias->Data<MLFloat16>() : nullptr;
+    // 65504 is the max value of float16
+    // https://en.wikipedia.org/wiki/Half-precision_floating-point_format
+    float output_min = clip_min_max ? onnxruntime::math::floatToHalf(clip_min_max->first) : -65504.0;

Review comment: Why do we assign a half value to a float here?
Review comment: Add the suffix 'f' after a float constant to avoid some compiler warnings.
(See the sketch after this hunk.)

+    float output_max = clip_min_max ? onnxruntime::math::floatToHalf(clip_min_max->second) : 65504.0;
+    auto create_func = is_transpose ? xnn_create_deconvolution2d_nhwc_f16
+                                    : xnn_create_convolution2d_nhwc_f16;
+    status = create_func(
+        input_padding_top, input_padding_right, input_padding_bottom, input_padding_left,
+        kernel_height, kernel_width,
+        subsampling_height, subsampling_width,
+        dilation_height, dilation_width,
+        group_count,
+        group_input_channels,
+        group_output_channels,
+        C, M,  // input channel stride, output channel stride
+        Weight.Data<MLFloat16>(), B_data,  // kernel, bias
+        output_min, output_max,
+        flags,
+        code_cache, weights_cache,
+        &p);
   } else if (conv_type == OpComputeType::op_compute_type_qs8) {
     const float output_scale = quant_param[2].first[0];
     const int8_t output_zero_point = quant_param[2].second;
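On the two review comments above, here is a hedged sketch (not the PR's code) of how the fp16 clipping bounds could be produced. It assumes, as the diff suggests, that the XNNPACK f16 create function accepts output_min/output_max as plain float, so the fp32 clip values can be passed through unconverted and the half-precision limits written as float literals with an 'f' suffix:

```cpp
#include <optional>
#include <utility>

// Hedged sketch: clipping bounds for xnn_create_convolution2d_nhwc_f16,
// assuming the operator takes the bounds as float and converts internally.
std::pair<float, float> Fp16OutputClip(const std::optional<std::pair<float, float>>& clip_min_max) {
  // 65504 is the largest finite value representable in IEEE half precision.
  constexpr float kFp16Min = -65504.0f;
  constexpr float kFp16Max = 65504.0f;
  const float output_min = clip_min_max ? clip_min_max->first : kFp16Min;
  const float output_max = clip_min_max ? clip_min_max->second : kFp16Max;
  return {output_min, output_max};
}
```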
@@ -236,6 +258,13 @@ OpComputeType GetConvCompType(
         return op_compute_type_qu8;
       }
       break;
+    case TensorTypeFp16:
+      if (input_datatype == TensorTypeFp16 &&
+          (!bias_datatype || *bias_datatype == TensorTypeInt32) &&
+          output_datatype == TensorTypeFp16) {
+        return op_compute_type_fp16;
+      }
+      break;
     default:
       break;
   }
@@ -326,10 +355,7 @@ bool ConvBase::IsOnnxNodeSupported(const NodeUnit& node_unit, const GraphViewer&

       // we only support float and u8 currently
       const auto* x_type = x_arg.TypeAsProto();
-      if (x_type == nullptr ||
-          (x_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT &&
-           x_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_UINT8 &&
-           x_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_INT8)) {
+      if (x_type == nullptr || !IsComputeTypeSupported(x_type->tensor_type().elem_type())) {

mszhanyi marked this conversation as resolved.

         break;
       }
       // require C, H, W to be known so we can construct the xnnpack kernel prior to Compute
@@ -420,9 +446,11 @@ ConvBase::ConvBase(const OpKernelInfo& info, bool is_transpose)
              input_dtype == ONNX_NAMESPACE::TensorProto_DataType_UINT8) {
     weight_index = 3;
     conv_type_ = ParseQuantParamAndConType(info, quant_param_, input_dtype);
+  } else if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
+    conv_type_ = OpComputeType::op_compute_type_fp16;
   } else {
     auto stype = DataTypeImpl::ToString(DataTypeImpl::TypeFromProto(*X.TypeAsProto()));
-    ORT_THROW("unsupported Conv in XnnpackEP, we have FLOAT|UINT8|INT8, but got ", stype);
+    ORT_THROW("unsupported Conv in XnnpackEP, we have FLOAT|UINT8|INT8|FLOAT16, but got ", stype);
   }

   ORT_ENFORCE(info.TryGetConstantInput(weight_index, &Weight),
@@ -491,7 +519,6 @@ ConvBase::ConvBase(const OpKernelInfo& info, bool is_transpose)

     output_shape_.push_back(M_);
   }

   // have to delay creating the xnnpack kernel until after the weights are pre-packed.
 }
Review comment: <unordered_set> should be included directly here rather than relying on it being pulled in indirectly via the default_supported_types instance.
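To tie together the resolved IsComputeTypeSupported thread and the <unordered_set> comment above, here is a hedged sketch of what such a helper could look like. The body, the exact type list, and the header choices are assumptions for illustration, not the PR's actual code:

```cpp
#include <unordered_set>

#include "core/graph/onnx_protobuf.h"  // assumed include for the TensorProto_DataType enum

namespace onnxruntime {
namespace xnnpack {

// Hedged sketch: element types the XNNPACK EP Conv might accept. The PR's
// actual default_supported_types may differ (e.g. fp16 only when
// XNNPACK_FP16_SUPPORTED is defined).
inline bool IsComputeTypeSupported(int32_t elem_type) {
  static const std::unordered_set<int32_t> default_supported_types{
      ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
      ONNX_NAMESPACE::TensorProto_DataType_FLOAT16,
      ONNX_NAMESPACE::TensorProto_DataType_UINT8,
      ONNX_NAMESPACE::TensorProto_DataType_INT8,
  };
  return default_supported_types.count(elem_type) > 0;
}

}  // namespace xnnpack
}  // namespace onnxruntime
```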