From 15324c4d313556d0d1324a33ce66e9172534fd1d Mon Sep 17 00:00:00 2001
From: Michael
Date: Wed, 9 Oct 2024 17:17:00 +0300
Subject: [PATCH] Add EvalTranspose pattern to StablehloAggressiveFolder
 (#2570)

This patch folds `stablehlo.transpose` operation with constant operand
into `stablehlo.constant`.

I have considered doing this by iterating over source index space, i.e.

```
auto initialValue = *std::begin(data);
SmallVector<ElementType> result(resultType.getNumElements(), initialValue);
for (int64_t i = 0; i < operandType.getNumElements(); ++i) {
  auto srcDimIndex = delinearize(i, operandStrides);
  auto dstDimIndex = applyPermutation(srcDimIndex, permutation);
  auto dstLinearIndex = linearize(dstDimIndex, resultStrides);
  result[dstLinearIndex] = data[i];
}
```

but that requires preinitialization of result vector with some value,
which is twice as slow on simple case:

```
func.func @eval_transpose() -> (tensor<5000x80x30xi32>) {
  %0 = stablehlo.iota dim = 0 : tensor<30x80x5000xi32>
  %1 = stablehlo.transpose %0, dims = [2, 1, 0] : (tensor<30x80x5000xi32>) -> tensor<5000x80x30xi32>
  func.return %1 : tensor<5000x80x30xi32>
}
```
---
 .../stablehlo_aggressive_folder.mlir          | 64 +++++++++++++++++++
 .../transforms/StablehloAggressiveFolder.cpp  | 55 ++++++++++++++++
 2 files changed, 119 insertions(+)

diff --git a/stablehlo/tests/transforms/stablehlo_aggressive_folder.mlir b/stablehlo/tests/transforms/stablehlo_aggressive_folder.mlir
index df87f70469f..914063131e8 100644
--- a/stablehlo/tests/transforms/stablehlo_aggressive_folder.mlir
+++ b/stablehlo/tests/transforms/stablehlo_aggressive_folder.mlir
@@ -110,3 +110,67 @@ func.func @eval_convert_f64_precision_loss() -> (tensor<1xf32>, tensor<f32>) {
   %3 = stablehlo.convert %1 : (tensor<f64>) -> tensor<f32>
   func.return %2, %3 : tensor<1xf32>, tensor<f32>
 }
+
+// -----
+
+// CHECK-LABEL: func @eval_transpose
+func.func @eval_transpose() -> (tensor<2x3x2xi32>, tensor<2x4x3xi32>, tensor<4x3x2xi32>) {
+  // CHECK: [[RESULT0:%.*]] = stablehlo.constant dense<
+  // CHECK-SAME: {{\[\[}}[1, 7], [3, 9], [5, 11]],
+  // CHECK-SAME: {{\[}}[2, 8], [4, 10], [6, 12]]]> : tensor<2x3x2xi32>
+  //
+  // CHECK: [[RESULT1:%.*]] = stablehlo.constant dense<
+  // CHECK-SAME: {{\[\[}}[1, 3, 5], [7, 9, 11], [13, 15, 17], [19, 21, 23]],
+  // CHECK-SAME: {{\[}}[2, 4, 6], [8, 10, 12], [14, 16, 18], [20, 22, 24]]]> : tensor<2x4x3xi32>
+  //
+  // CHECK: [[RESULT2:%.*]] = stablehlo.constant dense<
+  // CHECK-SAME: {{\[\[}}[1, 2], [3, 4], [5, 6]]
+  // CHECK-SAME: {{\[}}[7, 8], [9, 10], [11, 12]],
+  // CHECK-SAME: {{\[}}[13, 14], [15, 16], [17, 18]],
+  // CHECK-SAME: {{\[}}[19, 20], [21, 22], [23, 24]]]> : tensor<4x3x2xi32>
+  //
+  // CHECK: return [[RESULT0]], [[RESULT1]], [[RESULT2]]
+  %0 = stablehlo.constant dense<[[[1,2], [3,4], [5,6]],
+                                 [[7,8], [9,10], [11,12]]]> : tensor<2x3x2xi32>
+  %1 = stablehlo.constant dense<[[[1, 2], [3, 4], [5, 6]],
+                                 [[7, 8], [9, 10], [11,12]],
+                                 [[13,14], [15,16], [17,18]],
+                                 [[19,20], [21,22], [23,24]]]> : tensor<4x3x2xi32>
+  %2 = stablehlo.transpose %0, dims = [2, 1, 0] : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>
+  %3 = stablehlo.transpose %1, dims = [2, 0, 1] : (tensor<4x3x2xi32>) -> tensor<2x4x3xi32>
+  %4 = stablehlo.transpose %3, dims = [1, 2, 0] : (tensor<2x4x3xi32>) -> tensor<4x3x2xi32>
+  func.return %2, %3, %4 : tensor<2x3x2xi32>, tensor<2x4x3xi32>, tensor<4x3x2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func @eval_transpose_zerodim
+func.func @eval_transpose_zerodim() -> (tensor<10x3x0xf32>) {
+  // CHECK: [[RESULT0:%.*]] = stablehlo.constant dense<> : tensor<10x3x0xf32>
+  // CHECK: return [[RESULT0]]
+  %0 = stablehlo.constant dense<> : tensor<3x0x10xf32>
+  %1 = stablehlo.transpose %0, dims = [2, 0, 1] : (tensor<3x0x10xf32>) -> tensor<10x3x0xf32>
+  func.return %1 : tensor<10x3x0xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @eval_transpose_zerorank
+func.func @eval_transpose_zerorank() -> tensor<i32> {
+  // CHECK: [[RESULT0:%.*]] = stablehlo.constant dense<1> : tensor<i32>
+  // CHECK: return [[RESULT0]]
+  %0 = stablehlo.constant dense<1> : tensor<i32>
+  %1 = stablehlo.transpose %0, dims = [] : (tensor<i32>) -> tensor<i32>
+  func.return %1 : tensor<i32>
+}
+
+// -----
+
+// CHECK-LABEL: func @eval_transpose_splat
+func.func @eval_transpose_splat() -> (tensor<10x3x1xi32>) {
+  // CHECK: [[RESULT0:%.*]] = stablehlo.constant dense<1> : tensor<10x3x1xi32>
+  // CHECK: return [[RESULT0]]
+  %0 = stablehlo.constant dense<1> : tensor<3x1x10xi32>
+  %1 = stablehlo.transpose %0, dims = [2, 0, 1] : (tensor<3x1x10xi32>) -> tensor<10x3x1xi32>
+  func.return %1 : tensor<10x3x1xi32>
+}
diff --git a/stablehlo/transforms/StablehloAggressiveFolder.cpp b/stablehlo/transforms/StablehloAggressiveFolder.cpp
index a5768f4c9ff..2aee5398bc3 100644
--- a/stablehlo/transforms/StablehloAggressiveFolder.cpp
+++ b/stablehlo/transforms/StablehloAggressiveFolder.cpp
@@ -643,6 +643,60 @@ struct EvalIotaOpPattern : public OpRewritePattern<IotaOp> {
   }
 };
 
+template <typename RangeType>
+DenseElementsAttr transposeType(TransposeOp& op, const RangeType& data) {
+  using ElementType = std::decay_t<decltype(*std::begin(data))>;
+
+  RankedTensorType operandType = op.getOperand().getType();
+  RankedTensorType resultType = op.getResult().getType();
+
+  const auto operandStrides = computeStrides(operandType.getShape());
+  const auto resultStrides = computeStrides(resultType.getShape());
+  const auto inversePermutation = invertPermutationVector(op.getPermutation());
+
+  SmallVector<ElementType> result;
+  result.reserve(resultType.getNumElements());
+
+  for (int64_t i = 0; i < resultType.getNumElements(); ++i) {
+    auto dstDimIndex = delinearize(i, resultStrides);
+    auto srcDimIndex = applyPermutation(dstDimIndex, inversePermutation);
+    auto srcLinearIndex = linearize(srcDimIndex, operandStrides);
+    result.push_back(data[srcLinearIndex]);
+  }
+
+  return DenseElementsAttr::get(resultType, ArrayRef(result));
+}
+
+// transpose(constant) => constant with permuted dimensions
+// This covers ranked tensor types with 0 dimensions(zero elements) and 0
+// rank(scalar), as well as splat values.
+struct EvalTransposeOpPattern : public OpRewritePattern<TransposeOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(TransposeOp op,
+                                PatternRewriter& rewriter) const override {
+    auto resultType = op.getType();
+    if (failed(validateResultTypeForEval(rewriter, op, resultType)))
+      return failure();
+
+    ElementsAttr els;
+    if (!matchPattern(op.getOperand(), m_Constant(&els)))
+      return rewriter.notifyMatchFailure(
+          op, "expected constant integer or float operand");
+
+    DenseElementsAttr resAttr;
+    if (auto data = els.tryGetValues<APInt>())
+      resAttr = transposeType(op, *data);
+    else if (auto data = els.tryGetValues<APFloat>())
+      resAttr = transposeType(op, *data);
+    else
+      return rewriter.notifyMatchFailure(op.getLoc(),
+                                         "unsupported element type");
+
+    rewriter.replaceOpWithNewOp<ConstantOp>(op, resAttr);
+    return success();
+  }
+};
+
 struct StablehloAggressiveFolderPass
     : public impl::StablehloAggressiveFolderPassBase<
           StablehloAggressiveFolderPass> {
@@ -672,6 +726,7 @@ void populateStablehloAggressiveFolderPatterns(RewritePatternSet* patterns,
                                                bool foldFloat) {
   populateStablehloShapeFolderPatterns(patterns, context, foldFloat);
   patterns->add<EvalIotaOpPattern>(context);
+  patterns->add<EvalTransposeOpPattern>(context);
 }
 
 void populateStablehloShapeFolderPatterns(RewritePatternSet* patterns,