Merged PR 5420139: DmlDev RI from GitHub master 2020-11-18 #2
Fix some more future merge issues manually.
fdwr committed Nov 18, 2020
2 parents 395fac9 + 26882c2 commit 63ee7b4
Showing 27 changed files with 3,008 additions and 1,946 deletions.
8 changes: 5 additions & 3 deletions onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
@@ -51,6 +51,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, QAttention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, DynamicQuantizeMatMul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, MatMulIntegerToFloat);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, DynamicQuantizeLSTM);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QLinearConv);
// ******** End: Quantization ******************* //

@@ -92,7 +93,7 @@ KernelCreateInfo BuildKernelCreateInfo<void>() {

Status RegisterNchwcKernels(KernelRegistry& kernel_registry) {
static const BuildKernelCreateInfoFn function_table[] = {
BuildKernelCreateInfo<void>, //default entry to avoid the list become empty after ops-reducing
BuildKernelCreateInfo<void>, //default entry to avoid the list become empty after ops-reducing
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSNchwcDomain, 1, float, ReorderInput)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSNchwcDomain, 1, float, ReorderOutput)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSNchwcDomain, 1, float, Conv)>,
@@ -115,7 +116,7 @@ Status RegisterNchwcKernels(KernelRegistry& kernel_registry) {

Status RegisterQuantizationKernels(KernelRegistry& kernel_registry) {
static const BuildKernelCreateInfoFn function_table[] = {
BuildKernelCreateInfo<void>, //default entry to avoid the list become empty after ops-reducing
BuildKernelCreateInfo<void>, //default entry to avoid the list become empty after ops-reducing
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulInteger16)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, DequantizeLinear)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, DequantizeLinear)>,
@@ -132,6 +133,7 @@ Status RegisterQuantizationKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, QAttention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, DynamicQuantizeMatMul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, MatMulIntegerToFloat)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, DynamicQuantizeLSTM)>,
#if defined(MLAS_TARGET_AMD64_IX86)
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QLinearConv)>,
#endif
@@ -149,7 +151,7 @@ Status RegisterQuantizationKernels(KernelRegistry& kernel_registry) {

Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
static const BuildKernelCreateInfoFn function_table[] = {
BuildKernelCreateInfo<void>, //default entry to avoid the list become empty after ops-reducing
BuildKernelCreateInfo<void>, //default entry to avoid the list become empty after ops-reducing
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SampleOp)>,

// add more kernels here
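
For orientation, these hunks follow a two-step kernel registration pattern: the kernel class name is forward-declared through a macro, then a matching entry is appended to the registry's function table. A condensed sketch of that pattern for the newly added DynamicQuantizeLSTM entry, built only from the lines shown above (not the complete file):

    // Step 1: forward-declare the macro-generated kernel class name.
    class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, DynamicQuantizeLSTM);

    // Step 2: add a creation entry to the registry's function table.
    static const BuildKernelCreateInfoFn function_table[] = {
        BuildKernelCreateInfo<void>,  // default entry so the table is never empty after op reduction
        BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(
            kCpuExecutionProvider, kMSDomain, 1, float, DynamicQuantizeLSTM)>,
    };
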
39 changes: 19 additions & 20 deletions onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc
@@ -202,15 +202,16 @@ Status QAttention<T>::Compute(OpKernelContext* context) const {
// A: input (BxSxNxH) (B.)S x NH S x NH
// B: weights (NxHx3xNxH) NH x (3.N.)H NH x H
// C: QKV[qkv_index] (3xBxNxSxH) (3.B.N.)S x H S x H

MLAS_QGEMM_SCALE_BIAS_OUTPUT_PROCESSOR scale_bias_processor(qkv_dest + qkv_offset,
head_size,
&dequant_scale,
bias_data + weights_offset);
#ifdef MLAS_SUPPORTS_PACKED_GEMM_U8X8
if (packed_weights_) {
const auto* packed_weight =
static_cast<const uint8_t*>(packed_weights_.get()) + packed_weights_size_ * (weights_offset / head_size);

MLAS_QGEMM_SCALE_BIAS_OUTPUT_PROCESSOR scale_bias_processor(qkv_dest + qkv_offset,
head_size,
&dequant_scale,
bias_data + weights_offset);
MlasGemm(
sequence_length, // M = S
head_size, // N = H
@@ -229,22 +230,20 @@ Status QAttention<T>::Compute(OpKernelContext* context) const {
continue;
}
#endif
QGemm(sequence_length, // M = S
head_size, // N = H
hidden_size, // K = NH
input_data + input_offset, // A
hidden_size, // lda = NH
input_zero_point, // input zero point
weights_data + weights_offset, // B
3 * hidden_size, // ldb = 3NH
weight_zero_point, // weight zero point
weights_is_signed, // weight data type
qkv_dest + qkv_offset, // C
head_size, // ldc
&dequant_scale, // output scale
bias_data + weights_offset, // bias
nullptr // use single-thread
);
QGemm(sequence_length, // M = S
head_size, // N = H
hidden_size, // K = NH
input_data + input_offset, // A
hidden_size, // lda = NH
input_zero_point, // input zero point
weights_data + weights_offset, // B
3 * hidden_size, // ldb = 3NH
weight_zero_point, // weight zero point
weights_is_signed, // weight data type
reinterpret_cast<int32_t*>(qkv_dest + qkv_offset), // C
head_size, // ldc
nullptr, // use single-thread
&scale_bias_processor); // post processor
}
});
}
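
The substantive change in this file is that the quantized GEMM now writes raw int32 accumulators and the dequantization (scale plus per-column bias) is delegated to a scale/bias output processor passed to the GEMM. A minimal sketch of what that post-processing step computes, assuming a single per-tensor scale; the helper name and separate buffers below are illustrative, not the MLAS implementation:

    #include <cstddef>
    #include <cstdint>

    // Illustrative only: conceptual equivalent of the scale/bias output processor.
    // Each int32 accumulator is scaled back to float and a per-column bias is added.
    static void ApplyScaleBias(float* out, const int32_t* accum,
                               std::size_t rows, std::size_t cols, std::size_t ldc,
                               float scale, const float* bias) {
      for (std::size_t m = 0; m < rows; ++m) {
        for (std::size_t n = 0; n < cols; ++n) {
          out[m * ldc + n] = static_cast<float>(accum[m * ldc + n]) * scale +
                             (bias != nullptr ? bias[n] : 0.0f);
        }
      }
    }
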
202 changes: 202 additions & 0 deletions onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_lstm.cc
@@ -0,0 +1,202 @@
#include "core/providers/cpu/rnn/lstm_base.h"
#include "core/providers/cpu/rnn/rnn_helpers.h"
#include "core/providers/cpu/rnn/uni_directional_lstm.h"

namespace onnxruntime {
namespace contrib {

using namespace rnn::detail;

class DynamicQuantizeLSTM : public OpKernel, public LSTMBase {
public:
DynamicQuantizeLSTM(const OpKernelInfo& info) : OpKernel(info), LSTMBase(info) {}

#ifdef MLAS_SUPPORTS_PACKED_GEMM_U8X8
Status PrePack(const Tensor& tensor, int input_idx, bool& is_packed) override;
#endif

Status Compute(OpKernelContext* context) const override;

~DynamicQuantizeLSTM() override = default;

private:
#ifdef MLAS_SUPPORTS_PACKED_GEMM_U8X8
Status TryPackWeights(const Tensor& weights, PackedWeights& packed_weights, bool& is_packed, bool& is_weight_signed);
#endif

template <typename T>
Status ComputeImpl(OpKernelContext& context) const;

PackedWeights packed_W_;
PackedWeights packed_R_;
bool is_W_signed_;
bool is_R_signed_;
};

#ifdef MLAS_SUPPORTS_PACKED_GEMM_U8X8
Status DynamicQuantizeLSTM::TryPackWeights(const Tensor& weights, PackedWeights& packed_weights, bool& is_packed, bool& is_weight_signed) {
const auto& shape = weights.Shape();
if (shape.NumDimensions() != 3) {
return Status::OK();
}

// weights: [num_directions, input_size, 4*hidden_size]
// recurrence weights: [num_directions, hidden_size, 4*hidden_size]
const size_t K = static_cast<size_t>(shape[1]);
const size_t N = static_cast<size_t>(shape[2]);

if ((shape[0] != num_directions_) || (N != static_cast<size_t>(hidden_size_ * 4))) {
return Status::OK();
}

is_weight_signed = weights.IsDataType<int8_t>();
const size_t packed_weights_size = MlasGemmPackBSize(N, K, is_weight_signed);
if (packed_weights_size == 0) {
return Status::OK();
}

auto alloc = Info().GetAllocator(0, OrtMemTypeDefault);
auto* packed_weights_data = alloc->Alloc(SafeInt<size_t>(packed_weights_size) * num_directions_);
packed_weights.buffer_ = BufferUniquePtr(packed_weights_data, BufferDeleter(alloc));
packed_weights.weights_size_ = packed_weights_size;
packed_weights.shape_ = shape;

const auto* weights_data = static_cast<const uint8_t*>(weights.DataRaw());
for (int i = 0; i < num_directions_; i++) {
MlasGemmPackB(N, K, weights_data, N, is_weight_signed, packed_weights_data);
packed_weights_data = static_cast<uint8_t*>(packed_weights_data) + packed_weights_size;
weights_data += N * K;
}

is_packed = true;
return Status::OK();
}

Status DynamicQuantizeLSTM::PrePack(const Tensor& tensor, int input_idx, bool& is_packed) {
is_packed = false;

if (input_idx == 1) {
return TryPackWeights(tensor, packed_W_, is_packed, is_W_signed_);
} else if (input_idx == 2) {
return TryPackWeights(tensor, packed_R_, is_packed, is_R_signed_);
}

return Status::OK();
}
#endif

#define WeightCheck(weight_shape, weight_name) \
if (weight_shape.NumDimensions() != 1 && weight_shape.NumDimensions() != 2 || \
weight_shape.NumDimensions() == 2 && weight_shape[1] != hidden_size_ * 4 || \
weight_shape[0] != num_directions_) { \
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, \
"Input ", #weight_name, " must have shape {", num_directions_, "} for per-tensor/layer quantization or shape {", \
num_directions_, ", 4*", hidden_size_, "} for per-channel quantization. Actual:", weight_shape); \
}

#define ZeroPointCheck(w_zp, zp_shape, is_W_signed, weight_name) \
if (zp_shape.NumDimensions() == 2) { \
const int64_t zp_size = zp_shape.Size(); \
const uint8_t* w_zp_data = static_cast<const uint8_t*>(w_zp->DataRaw()); \
if (is_W_signed) { \
for (int64_t i = 0; i < zp_size; i++) { \
if (w_zp_data[i] != 0) { \
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "DynamicQuantizeLSTM : ", #weight_name, "Weight zero point must be zero"); \
} \
} \
} else { \
const uint8_t W_zero_point_value = w_zp_data[0]; \
for (int64_t i = 1; i < zp_size; i++) { \
if (w_zp_data[i] != W_zero_point_value) { \
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "DynamicQuantizeLSTM : ", #weight_name, "Weight zero point must be constant"); \
} \
} \
} \
}

Status DynamicQuantizeLSTM::Compute(OpKernelContext* context) const {
// weights. [num_directions, input_size, 4*hidden_size]
const Tensor* W = packed_W_.buffer_ ? nullptr : context->Input<Tensor>(1);
// recurrence weights. [num_directions, hidden_size, 4*hidden_size]
const Tensor* R = packed_R_.buffer_ ? nullptr : context->Input<Tensor>(2);

const auto& W_shape = (W != nullptr) ? W->Shape() : packed_W_.shape_;
const auto& R_shape = (R != nullptr) ? R->Shape() : packed_R_.shape_;

const Tensor* w_scale = context->Input<Tensor>(8);
const Tensor* w_zp = context->Input<Tensor>(9);
const Tensor* r_scale = context->Input<Tensor>(10);
const Tensor* r_zp = context->Input<Tensor>(11);

const TensorShape& W_zp_shape = w_zp->Shape();
const TensorShape& R_zp_shape = r_zp->Shape();
const TensorShape& W_scale_shape = w_scale->Shape();
const TensorShape& R_scale_shape = r_scale->Shape();

WeightCheck(W_zp_shape, W_zero_point);
WeightCheck(R_zp_shape, R_zero_point);
WeightCheck(W_scale_shape, W_scale);
WeightCheck(R_scale_shape, R_scale);

const bool is_W_signed = (W != nullptr) ? W->IsDataType<int8_t>() : is_W_signed_;
const bool is_R_signed = (R != nullptr) ? R->IsDataType<int8_t>() : is_R_signed_;

ZeroPointCheck(w_zp, W_zp_shape, is_W_signed, Input);
ZeroPointCheck(r_zp, R_zp_shape, is_R_signed, Recurrent);

size_t W_scale_size = W_scale_shape.NumDimensions() == 2 ? W_scale_shape[1] : 1;
size_t R_scale_size = R_scale_shape.NumDimensions() == 2 ? R_scale_shape[1] : 1;

QuantizationParameter quant_para_W_1(w_scale->Data<float>(),
static_cast<const uint8_t*>(w_zp->DataRaw()),
is_W_signed,
W_scale_size);
QuantizationParameter quant_para_R_1(r_scale->Data<float>(),
static_cast<const uint8_t*>(r_zp->DataRaw()),
is_R_signed,
R_scale_size);

const uint8_t* W_data = W != nullptr ? static_cast<const uint8_t*>(W->DataRaw()) : nullptr;
const uint8_t* R_data = R != nullptr ? static_cast<const uint8_t*>(R->DataRaw()) : nullptr;

// spans for first direction
const size_t W_size_per_direction = W_shape[1] * W_shape[2];
const size_t R_size_per_direction = R_shape[1] * R_shape[2];

GemmWeights<uint8_t> W_1(0, W_data, W_size_per_direction, packed_W_, &quant_para_W_1);
GemmWeights<uint8_t> R_1(0, R_data, R_size_per_direction, packed_R_, &quant_para_R_1);

GemmWeights<uint8_t> W_2;
GemmWeights<uint8_t> R_2;

QuantizationParameter quant_para_W_2(quant_para_W_1);
QuantizationParameter quant_para_R_2(quant_para_R_1);

if (direction_ == Direction::kBidirectional) {
quant_para_W_2.scale += W_scale_size;
quant_para_R_2.scale += R_scale_size;

quant_para_W_2.zero_point += W_scale_size; // zero_point and scale have same size
quant_para_R_2.zero_point += R_scale_size; // zero_point and scale have same size

W_2.Init(1, W_data, W_size_per_direction, packed_W_, &quant_para_W_2);
R_2.Init(1, R_data, R_size_per_direction, packed_R_, &quant_para_R_2);
}

return LSTMBase::ComputeImpl<float, uint8_t>(*context, W_1, W_2, R_1, R_2);
}

ONNX_OPERATOR_TYPED_KERNEL_EX(
DynamicQuantizeLSTM,
kMSDomain,
1,
float,
kCpuExecutionProvider,
KernelDefBuilder()
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
.TypeConstraint("T1", DataTypeImpl::GetTensorType<int32_t>())
.TypeConstraint("T2", {DataTypeImpl::GetTensorType<uint8_t>(), DataTypeImpl::GetTensorType<int8_t>()}),
DynamicQuantizeLSTM);

} // namespace contrib
} // namespace onnxruntime
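
A usage note on the new kernel, derived from the shape checks and the bidirectional branch above: the weight scale and zero-point inputs are either per-tensor (shape {num_directions}) or per-channel (shape {num_directions, 4*hidden_size}), and the second direction's parameters are obtained by advancing the same buffers by one per-direction block. A minimal sketch of that indexing; the function and variable names are assumptions for illustration:

    #include <cstddef>
    #include <cstdint>

    // Illustrative sketch mirroring quant_para_W_2.scale += W_scale_size above.
    struct DirectionQuantParams {
      const float* scale;
      const std::uint8_t* zero_point;
    };

    DirectionQuantParams ParamsForDirection(const float* scale_data, const std::uint8_t* zp_data,
                                            std::size_t hidden_size, bool per_channel,
                                            std::size_t direction) {
      // Per-channel: one value per gate column (4 * hidden_size); per-tensor: one value per direction.
      const std::size_t block = per_channel ? 4 * hidden_size : 1;
      return {scale_data + direction * block, zp_data + direction * block};
    }
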
Additional changed file (filename not shown)
@@ -50,6 +50,12 @@ Status MatMulIntegerToFloatBase::ComputeCommon(OpKernelContext* ctx,
concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool();

for (size_t i = 0; i < helper.OutputOffsets().size(); i++) {
MLAS_QGEMM_SCALE_BIAS_OUTPUT_PROCESSOR scale_bias_processor(
y_data + helper.OutputOffsets()[i],
static_cast<size_t>(helper.N()),
&multiplier,
bias_data);

#ifdef MLAS_SUPPORTS_PACKED_GEMM_U8X8
if (packed_b_) {
MLAS_QGEMM_SCALE_BIAS_OUTPUT_PROCESSOR scale_bias_processor(y_data + helper.OutputOffsets()[i],
@@ -84,11 +90,10 @@ static_cast<int>(helper.N()),
static_cast<int>(helper.N()),
b_zero_point,
b_is_signed,
y_data + helper.OutputOffsets()[i],
reinterpret_cast<int32_t*>(y_data + helper.OutputOffsets()[i]),
static_cast<int>(helper.N()),
&multiplier,
bias_data,
thread_pool);
thread_pool,
&scale_bias_processor);
}

return Status::OK();
@@ -108,24 +113,6 @@ class MatMulIntegerToFloat final : public MatMulIntegerToFloatBase {
Status Compute(OpKernelContext* context) const override;
};

static void GetQuantizationParameter(const float* data, int64_t num_of_elements, float& scale, uint8_t& zp) {
// find input range min and max
float min, max;
MlasFindMinMaxElement(data, &min, &max, num_of_elements);

// ensure the input range includes zero
min = std::min(min, 0.0f);
max = std::max(max, 0.0f);

// find scale and zero point
uint8_t qmin = 0;
uint8_t qmax = 255;
scale = max == min ? 1.0f : (max - min) / (qmax - qmin);

float initial_zero_point = qmin - min / scale;
zp = static_cast<uint8_t>(RoundHalfToEven(std::max(float(qmin), std::min(float(qmax), initial_zero_point))));
}

Status DynamicQuantizeMatMul::Compute(OpKernelContext* ctx) const {
const Tensor* a = ctx->Input<Tensor>(0);
const Tensor* b = packed_b_ ? nullptr : ctx->Input<Tensor>(1);
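
The GetQuantizationParameter helper removed above computes per-tensor dynamic quantization parameters for the uint8 range [0, 255]. A self-contained sketch of the same math, with a plain loop standing in for MlasFindMinMaxElement and std::nearbyint standing in for RoundHalfToEven (both substitutions are assumptions for this example):

    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    // Pick scale/zero-point so the observed [min, max] range, forced to include 0,
    // maps onto the quantized range [0, 255].
    static void GetDynamicQuantizationParams(const float* data, std::int64_t n,
                                             float& scale, std::uint8_t& zero_point) {
      float min = 0.0f, max = 0.0f;  // starting at 0 keeps the range anchored at zero
      for (std::int64_t i = 0; i < n; ++i) {
        min = std::min(min, data[i]);
        max = std::max(max, data[i]);
      }
      const float qmin = 0.0f, qmax = 255.0f;
      scale = (max == min) ? 1.0f : (max - min) / (qmax - qmin);
      const float initial_zp = qmin - min / scale;
      // The default FP rounding mode is round-to-nearest-even, matching RoundHalfToEven.
      zero_point = static_cast<std::uint8_t>(
          std::nearbyint(std::max(qmin, std::min(qmax, initial_zp))));
    }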
