diff --git a/tensorflow/compiler/mlir/lite/stablehlo/BUILD b/tensorflow/compiler/mlir/lite/stablehlo/BUILD index 72626222a133b7..17cda004c88245 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/BUILD +++ b/tensorflow/compiler/mlir/lite/stablehlo/BUILD @@ -277,6 +277,7 @@ cc_library( copts = ["-Ithird_party"], deps = [ ":passes_inc_gen", + "//tensorflow/compiler/mlir/quantization/stablehlo:uniform_quantized_types", "@com_google_absl//absl/algorithm:container", "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/compose_uniform_quantized_type_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/compose_uniform_quantized_type_pass.cc index 3a8317ef1d0fdc..4a355834a396ad 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/compose_uniform_quantized_type_pass.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/compose_uniform_quantized_type_pass.cc @@ -37,6 +37,7 @@ limitations under the License. #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types.h" #define DEBUG_TYPE "stablehlo-compose-uniform-quantized-type" @@ -44,8 +45,9 @@ namespace mlir { namespace odml { namespace { -using quant::UniformQuantizedPerAxisType; -using quant::UniformQuantizedType; +using ::mlir::quant::CreateI8F32UniformQuantizedType; +using ::mlir::quant::UniformQuantizedPerAxisType; +using ::mlir::quant::UniformQuantizedType; #define GEN_PASS_DEF_COMPOSEUNIFORMQUANTIZEDTYPEPASS #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h.inc" @@ -97,20 +99,6 @@ bool IsI32ToF32Cast(stablehlo::ConvertOp convert_op) { return is_i32_operand && is_f32_result; } -// Creates a `UniformQuantizedType` with the given `scale` and `zero_point` -// values. 
The produced type has f32 as its expressed type and i8 as its -// storage type with default storage type min and max values, set to -128 and -// 127, respectively. -UniformQuantizedType CreateI8F32UniformQuantizedType(Location loc, - PatternRewriter& rewriter, - const double scale, - const int64_t zero_point) { - return UniformQuantizedType::getChecked( - loc, /*flags=*/true, /*storageType=*/rewriter.getI8Type(), - /*expressedType=*/rewriter.getF32Type(), scale, zero_point, - /*storageTypeMin=*/-128, /*storageTypeMax=*/127); -} - // Creates a `UniformQuantizedPerAxisType` with the given `scales` and // `zero_points` values. The produced type has f32 as its expressed type and // i8 as its storage type with default storage type min and max values, set to @@ -693,9 +681,9 @@ class ComposeUniformQuantizedConvolutionOp Value input_value = uniform_quantize_call_pattern_for_input.GetInputValue(); UniformQuantizedType input_quantized_element_type = - CreateI8F32UniformQuantizedType(uniform_quantize_call_op.getLoc(), - rewriter, input_scale_value, - input_zero_point_value); + CreateI8F32UniformQuantizedType( + uniform_quantize_call_op.getLoc(), *rewriter.getContext(), + input_scale_value, input_zero_point_value); auto input_uniform_quantize_op = rewriter.create( uniform_quantize_call_op.getLoc(), @@ -801,7 +789,7 @@ class ComposeUniformQuantizedConvolutionOp UniformQuantizedType output_uniform_quantized_type = CreateI8F32UniformQuantizedType( - output_uniform_quantize_call_op.getLoc(), rewriter, + output_uniform_quantize_call_op.getLoc(), *rewriter.getContext(), /*scale=*/1.0 / output_inverse_scale_value, output_zero_point_value); @@ -1036,9 +1024,9 @@ class ComposeUniformQuantizedDotGeneralOp .getSExtValue(); const UniformQuantizedType input_uniform_quantized_type = - CreateI8F32UniformQuantizedType(input_uniform_quantize_call_op.getLoc(), - rewriter, input_scale_value, - input_zero_point_value); + CreateI8F32UniformQuantizedType( + input_uniform_quantize_call_op.getLoc(), 
*rewriter.getContext(), + input_scale_value, input_zero_point_value); Value input_value = input_uniform_quantize_call_pattern->GetInputValue(); auto input_uniform_quantize_op = @@ -1157,7 +1145,7 @@ class ComposeUniformQuantizedDotGeneralOp const UniformQuantizedType output_uniform_quantized_type = CreateI8F32UniformQuantizedType( - output_uniform_quantize_call_op.getLoc(), rewriter, + output_uniform_quantize_call_op.getLoc(), *rewriter.getContext(), output_scale_value, output_zero_point_value); auto new_dot_general_op = rewriter.create( @@ -1478,7 +1466,7 @@ class ComposeUniformQuantizedDotGeneralOpWithTwoQuantizedActivations const UniformQuantizedType input1_uniform_quantized_type = CreateI8F32UniformQuantizedType( - input1_uniform_quantize_call_op.getLoc(), rewriter, + input1_uniform_quantize_call_op.getLoc(), *rewriter.getContext(), input1_scale_value, input1_zero_point_value); Value input1_value = input1_uniform_quantize_call_pattern->GetInputValue(); @@ -1517,7 +1505,7 @@ class ComposeUniformQuantizedDotGeneralOpWithTwoQuantizedActivations const UniformQuantizedType input2_uniform_quantized_type = CreateI8F32UniformQuantizedType( - input2_uniform_quantize_call_op.getLoc(), rewriter, + input2_uniform_quantize_call_op.getLoc(), *rewriter.getContext(), input2_scale_value, input2_zero_point_value); Value input2_value = input2_uniform_quantize_call_pattern->GetInputValue(); @@ -1566,7 +1554,7 @@ class ComposeUniformQuantizedDotGeneralOpWithTwoQuantizedActivations const UniformQuantizedType output_uniform_quantized_type = CreateI8F32UniformQuantizedType( - output_uniform_quantize_call_op.getLoc(), rewriter, + output_uniform_quantize_call_op.getLoc(), *rewriter.getContext(), output_scale_value, output_zero_point_value); auto new_dot_general_op = rewriter.create( diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/BUILD index 39941daf03dc99..9137d0b83b3383 100644 --- 
a/tensorflow/compiler/mlir/quantization/stablehlo/BUILD +++ b/tensorflow/compiler/mlir/quantization/stablehlo/BUILD @@ -372,3 +372,28 @@ tf_cc_binary( "@stablehlo//:stablehlo_ops", ], ) + +cc_library( + name = "uniform_quantized_types", + srcs = ["uniform_quantized_types.cc"], + hdrs = ["uniform_quantized_types.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:Support", + ], +) + +tf_cc_test( + name = "uniform_quantized_types_test", + srcs = ["uniform_quantized_types_test.cc"], + deps = [ + ":uniform_quantized_types", + "@com_google_googletest//:gtest_main", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:QuantOps", + ], +) diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types.cc b/tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types.cc new file mode 100644 index 00000000000000..89b372ff72f007 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types.cc @@ -0,0 +1,40 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types.h" + +#include <cstdint> + +#include "llvm/Support/MathExtras.h" +#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project + +namespace mlir { +namespace quant { + +UniformQuantizedType CreateI8F32UniformQuantizedType(Location loc, + MLIRContext& context, + const float scale, + const int8_t zero_point) { + return UniformQuantizedType::getChecked( + loc, /*flags=*/QuantizationFlags::Signed, + /*storageType=*/IntegerType::get(&context, /*width=*/8), + /*expressedType=*/FloatType::getF32(&context), scale, zero_point, + /*storageTypeMin=*/llvm::minIntN(8), /*storageTypeMax=*/llvm::maxIntN(8)); +} + +} // namespace quant +} // namespace mlir diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types.h b/tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types.h new file mode 100644 index 00000000000000..d7138aee27449d --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types.h @@ -0,0 +1,38 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
+==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_UNIFORM_QUANTIZED_TYPES_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_UNIFORM_QUANTIZED_TYPES_H_ + +#include <cstdint> + +#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project + +namespace mlir { +namespace quant { + +// Creates a `UniformQuantizedType` with the given `scale` and `zero_point` +// values. The produced type has f32 as its expressed type and i8 as its +// storage type. The available values use the full range of the storage value, +// i.e. [-128, 127]. Assumes asymmetric quantization, meaning the zero point +// values may be nonzero. +quant::UniformQuantizedType CreateI8F32UniformQuantizedType( + Location loc, MLIRContext& context, float scale, int8_t zero_point); + +} // namespace quant +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_UNIFORM_QUANTIZED_TYPES_H_ diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types_test.cc new file mode 100644 index 00000000000000..bab4f00a86004e --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types_test.cc @@ -0,0 +1,82 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types.h" + +#include <gtest/gtest.h> +#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project + +namespace mlir { +namespace quant { +namespace { + +using ::mlir::quant::UniformQuantizedType; + +class CreateI8F32UniformQuantizedTypeTest : public ::testing::Test { + protected: + CreateI8F32UniformQuantizedTypeTest() : ctx_() { + ctx_.loadDialect<quant::QuantizationDialect>(); + } + + MLIRContext ctx_; +}; + +TEST_F(CreateI8F32UniformQuantizedTypeTest, HasI8StorageType) { + const UniformQuantizedType quantized_type = + CreateI8F32UniformQuantizedType(UnknownLoc::get(&ctx_), ctx_, + /*scale=*/1.0, /*zero_point=*/0); + + EXPECT_TRUE(quantized_type.getStorageType().isSignlessInteger(8)); +} + +TEST_F(CreateI8F32UniformQuantizedTypeTest, HasF32ExpressedType) { + const UniformQuantizedType quantized_type = + CreateI8F32UniformQuantizedType(UnknownLoc::get(&ctx_), ctx_, + /*scale=*/1.0, /*zero_point=*/0); + + EXPECT_TRUE(quantized_type.getExpressedType().isF32()); + } + +TEST_F(CreateI8F32UniformQuantizedTypeTest, IsSigned) { + const UniformQuantizedType quantized_type = + CreateI8F32UniformQuantizedType(UnknownLoc::get(&ctx_), ctx_, + /*scale=*/1.0, /*zero_point=*/0); + + EXPECT_TRUE(quantized_type.isSigned()); +} + +TEST_F(CreateI8F32UniformQuantizedTypeTest, StorageTypeMinMaxEqualToI8MinMax) { + const UniformQuantizedType quantized_type = + CreateI8F32UniformQuantizedType(UnknownLoc::get(&ctx_), ctx_, + /*scale=*/1.0, /*zero_point=*/0); + + EXPECT_EQ(quantized_type.getStorageTypeMin(), -128); + EXPECT_EQ(quantized_type.getStorageTypeMax(), 127); +} + 
+TEST_F(CreateI8F32UniformQuantizedTypeTest, HasScaleAndZeroPointProperlySet) { + const UniformQuantizedType quantized_type = + CreateI8F32UniformQuantizedType(UnknownLoc::get(&ctx_), ctx_, + /*scale=*/8.0, /*zero_point=*/99); + + EXPECT_EQ(quantized_type.getScale(), 8.0); + EXPECT_EQ(quantized_type.getZeroPoint(), 99); +} + +} // namespace +} // namespace quant +} // namespace mlir