Skip to content

Commit

Permalink
lite: enable group conv -> conv2d conversion
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 565268295
  • Loading branch information
zichuan-wei authored and tensorflower-gardener committed Sep 14, 2023
1 parent 20196d5 commit ab7c193
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 2 deletions.
14 changes: 14 additions & 0 deletions tensorflow/compiler/mlir/lite/stablehlo/tests/legalize_hlo.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2066,6 +2066,20 @@ func.func @convert_conv2d(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16
func.return %0 : tensor<1x8x8x16xf32>
}

// CHECK-LABEL: func @convert_group_conv2d(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x14x14x2240xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<3x3x112x2240xf32>) -> tensor<1x7x7x2240xf32> {
// CHECK: %[[VAL_2:.*]] = "tf.Conv2D"(%[[VAL_0]], %[[VAL_1]]) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [0, 0, 1, 1, 1, 1, 0, 0], padding = "EXPLICIT", strides = [1, 2, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x14x14x2240xf32>, tensor<3x3x112x2240xf32>) -> tensor<1x7x7x2240xf32>
// CHECk: return %[[VAL_2]] : tensor<1x7x7x2240xf32>
// CHECK: }
func.func @convert_group_conv2d(%arg0: tensor<1x14x14x2240xf32>, %arg1: tensor<3x3x112x2240xf32>) -> tensor<1x7x7x2240xf32> {
%0 = "mhlo.convolution"(%arg0, %arg1) {batch_group_count = 1 :i64,
dimension_numbers = #mhlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>,
feature_group_count = 20 : i64, lhs_dilation = dense<1> : tensor<2xi64>, padding = dense<1> : tensor<2x2xi64>, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>], rhs_dilation = dense<1> : tensor<2xi64>, window_reversal = dense<false> : tensor<2xi1>, window_strides = dense<2> : tensor<2xi64>} :
(tensor<1x14x14x2240xf32>, tensor<3x3x112x2240xf32>) -> tensor<1x7x7x2240xf32>
func.return %0 : tensor<1x7x7x2240xf32>
}

// CHECK-LABEL: func @convert_conv2d_no_padding(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x6x6x207xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<3x3x207x16xf32>) -> tensor<1x4x4x16xf32> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -406,13 +406,19 @@ class Convert2DConvOp : public OpConversionPattern<mhlo::ConvolutionOp>,

mhlo::ConvDimensionNumbersAttr dnums = conv_op.getDimensionNumbers();
const int input_feature_dimension = dnums.getInputFeatureDimension();
const int kernel_input_feature_dimension =
dnums.getKernelInputFeatureDimension();
const int input_channels =
conv_op.getLhs().getType().cast<ShapedType>().getDimSize(
input_feature_dimension);
const int kernel_input_channels =
conv_op.getRhs().getType().cast<ShapedType>().getDimSize(
kernel_input_feature_dimension);
int feature_group_count = conv_op.getFeatureGroupCount();

if (feature_group_count != 1 && feature_group_count != input_channels) {
// Group convolution is not supported yet.
// check if group count is valid
if (feature_group_count != input_channels / kernel_input_channels ||
input_channels % kernel_input_channels != 0) {
return failure();
}

Expand Down

0 comments on commit ab7c193

Please sign in to comment.