diff --git a/third_party/xla/xla/service/gpu/fusions/BUILD b/third_party/xla/xla/service/gpu/fusions/BUILD index 188eaed0cffd40..cf9801fe291ffa 100644 --- a/third_party/xla/xla/service/gpu/fusions/BUILD +++ b/third_party/xla/xla/service/gpu/fusions/BUILD @@ -432,6 +432,7 @@ xla_test( "//xla/tests:xla_internal_test_main", "//xla/tsl/lib/core:status_test_util", "@com_google_googletest//:gtest", + "@llvm-project//mlir:IR", "@local_tsl//tsl/platform:statusor", ], ) diff --git a/third_party/xla/xla/service/gpu/fusions/transpose_mlir.cc b/third_party/xla/xla/service/gpu/fusions/transpose_mlir.cc index 000032ac42e97a..a06b9d9dc052de 100644 --- a/third_party/xla/xla/service/gpu/fusions/transpose_mlir.cc +++ b/third_party/xla/xla/service/gpu/fusions/transpose_mlir.cc @@ -125,6 +125,7 @@ MlirTransposeFusion::MlirTransposeFusion(const HloFusionAnalysis& analysis) block_sizes_.back() = block_size_; block_sizes_[permutation_.back()] = block_size_; } + output_block_sizes_ = Permute(block_sizes_, permutation_); block_counts_.resize(block_sizes_.size()); for (int64_t i = 0; i < block_sizes_.size(); ++i) { block_counts_[i] = CeilOfRatio(input_shape_[i], block_sizes_[i]); @@ -198,9 +199,16 @@ LaunchDimensions MlirTransposeFusion::launch_dimensions() const { IndexingMap MlirTransposeFusion::GetSharedMemoryIndexing( bool read, mlir::MLIRContext* ctx) const { - auto thread_offsets = GetThreadOffsets(ctx); + auto thread_offsets = GetThreadOffsets(/*read=*/true, ctx); if (!read) { - absl::c_copy(Permute(thread_offsets, permutation_), thread_offsets.begin()); + // Regarding shared memory indexing, the permutation we need to apply is + // just a swap of the two dimensions that are tiled. + if (MostMinorDimensionUnchanged()) { + std::swap(thread_offsets[thread_offsets.size() - 2], + thread_offsets[permutation_[permutation_.size() - 2]]); + } else { + std::swap(thread_offsets.back(), thread_offsets[permutation_.back()]); + } } std::vector dim_var_sizes(6, 1); dim_var_sizes[KernelFusionInterface::kIndexingMapThreadIdxDims[0]] = @@ -395,7 +403,7 @@ absl::Status MlirTransposeFusion::EmitEntryFunction( } llvm::SmallVector MlirTransposeFusion::GetThreadOffsets( - mlir::MLIRContext* ctx) const { + bool read, mlir::MLIRContext* ctx) const { auto thread = mlir::getAffineDimExpr( KernelFusionInterface::kIndexingMapThreadIdxDims[0], ctx); auto loop = mlir::getAffineSymbolExpr(0, ctx); @@ -406,7 +414,8 @@ llvm::SmallVector MlirTransposeFusion::GetThreadOffsets( auto minor_dim = mlir::getAffineSymbolExpr(2, ctx); linear_index = linear_index * input_shape_.back() + minor_dim; } - return DelinearizeInBoundsIndex(linear_index, block_sizes_); + return DelinearizeInBoundsIndex(linear_index, + read ? block_sizes_ : output_block_sizes_); } IndexingMap MlirTransposeFusion::GetIndexing(bool input, @@ -418,10 +427,11 @@ IndexingMap MlirTransposeFusion::GetIndexing(bool input, if (!input) { absl::c_copy(Permute(block_ids, permutation_), block_ids.begin()); } - auto thread_offsets = GetThreadOffsets(ctx); + auto thread_offsets = GetThreadOffsets(input, ctx); + const auto& permuted_block_sizes = input ? block_sizes_ : output_block_sizes_; llvm::SmallVector offsets; for (auto [block_id, block_size, thread] : - llvm::zip(block_ids, block_sizes_, thread_offsets)) { + llvm::zip(block_ids, permuted_block_sizes, thread_offsets)) { offsets.push_back(block_id * block_size + thread); } std::vector dim_var_sizes(6, 1); diff --git a/third_party/xla/xla/service/gpu/fusions/transpose_mlir.h b/third_party/xla/xla/service/gpu/fusions/transpose_mlir.h index 538ad2f77d6df0..afb2777967220e 100644 --- a/third_party/xla/xla/service/gpu/fusions/transpose_mlir.h +++ b/third_party/xla/xla/service/gpu/fusions/transpose_mlir.h @@ -97,13 +97,14 @@ class MlirTransposeFusion : public MlirFusionEmitterBase { mlir::MLIRContext* ctx) const; IndexingMap GetSharedMemoryIndexing(bool read, mlir::MLIRContext* ctx) const; llvm::SmallVector GetThreadOffsets( - mlir::MLIRContext* ctx) const; + bool read, mlir::MLIRContext* ctx) const; bool MostMinorDimensionUnchanged() const; TransposeDescription transpose_; absl::InlinedVector permutation_; std::vector input_shape_; std::vector block_sizes_; // In input elements. + std::vector output_block_sizes_; std::vector block_counts_; int vector_size_; int block_size_; diff --git a/third_party/xla/xla/service/gpu/fusions/transpose_mlir_test.cc b/third_party/xla/xla/service/gpu/fusions/transpose_mlir_test.cc index 95079ab303e4af..d773503859d934 100644 --- a/third_party/xla/xla/service/gpu/fusions/transpose_mlir_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/transpose_mlir_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include #include +#include "mlir/IR/MLIRContext.h" #include "xla/error_spec.h" #include "xla/service/gpu/fusions/mlir_emitter_test_base.h" #include "xla/service/gpu/hlo_fusion_analysis.h" @@ -144,6 +145,72 @@ TEST_F(MlirTransposeFusionTest, ThreadIndexing201_SimplifiedTo021) { )")); } +TEST_F(MlirTransposeFusionTest, Transpose_ThreadIndexing1302) { + auto kHloString = R"( + HloModule Transpose + + %fused_computation { + %param_0 = f32[19, 16, 16, 144] parameter(0) + ROOT %transpose= f32[16, 144, 19, 16] transpose( %param_0), + dimensions={1,3,0,2} + } + ENTRY main { + %param = f32[19, 16, 16, 144] parameter(0) + ROOT %fusion = f32[16, 144, 19, 16] fusion(%param), kind=kInput, + calls=%fused_computation + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kHloString)); + auto* root = module->entry_computation()->root_instruction(); + auto analysis = HloFusionAnalysis::Create(*root, device_info_); + + MlirTransposeFusion fusion(analysis); + EXPECT_THAT( + fusion.ComputeThreadIdToInputIndexing(0, 0, &mlir_context_)->ToString(), + MatchIndexingString(R"( + (d0, d1, d2, d3, d4, d5)[s0, s1] -> ( + d3 floordiv 80, + (d3 floordiv 5) mod 16, + d0 floordiv 32 + s0 * 4, + (d3 mod 5) * 32 + d0 mod 32 + ) + domain: + d0 in [0, 127] + d1 in [0, 0] + d2 in [0, 0] + d3 in [0, 1519] + d4 in [0, 0] + d5 in [0, 0] + + s0 in [0, 3] + s1 in [0, 0] + (d3 mod 5) * 32 + d0 mod 32 in [0, 143] + )")); + EXPECT_THAT( + fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context_)->ToString(), + MatchIndexingString(R"( + (d0, d1, d2, d3, d4, d5)[s0, s1] -> ( + (d3 floordiv 5) mod 16, + (d3 mod 5) * 32 + s0 * 4 + d0 floordiv 32, + d3 floordiv 80, + d0 mod 32 + ) + domain: + d0 in [0, 127] + d1 in [0, 0] + d2 in [0, 0] + d3 in [0, 1519] + d4 in [0, 0] + d5 in [0, 0] + + s0 in [0, 7] + s1 in [0, 0] + (d3 mod 5) * 8 + s0 in [0, 35] + d0 mod 32 in [0, 15] + )")); +} + TEST_F(MlirTransposeFusionTest, ThreadIndexingVectorized021) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( HloModule module @@ -464,6 +531,26 @@ TEST_F(MlirTransposeFusionTest, Transpose_2D) { calls=%fused_computation } )"; + + TF_EXPECT_OK(EmitAndCheckIR(kHloString, "// CHECK: xla_gpu.allocate_shared")); + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); +} + +TEST_F(MlirTransposeFusionTest, Transpose_4D) { + auto kHloString = R"( + HloModule Transpose + + %fused_computation { + %param_0 = f32[19, 16, 16, 144] parameter(0) + ROOT %transpose= f32[16, 144, 19, 16] transpose( %param_0), + dimensions={1,3,0,2} + } + ENTRY main { + %param = f32[19, 16, 16, 144] parameter(0) + ROOT %fusion = f32[16, 144, 19, 16] fusion(%param), kind=kInput, + calls=%fused_computation + } + )"; TF_EXPECT_OK(EmitAndCheckIR(kHloString, "// CHECK: xla_gpu.allocate_shared")); EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); } diff --git a/third_party/xla/xla/service/gpu/ir_emission_utils.cc b/third_party/xla/xla/service/gpu/ir_emission_utils.cc index 734d8891952e7a..aa638b17843aaf 100644 --- a/third_party/xla/xla/service/gpu/ir_emission_utils.cc +++ b/third_party/xla/xla/service/gpu/ir_emission_utils.cc @@ -598,36 +598,43 @@ static std::optional FindTiledLogicalTranspose( // call GetNormalizedLogicalTransposeShape here. absl::InlinedVector permutation(instr.dimensions().begin(), instr.dimensions().end()); + // A real transpose needs at least 2 transpose dimensions. + if (permutation.size() < 2) { + return std::nullopt; + } absl::InlinedVector dimensions(instr.shape().dimensions().begin(), instr.shape().dimensions().end()); + int64_t operand_most_minor_dim = + instr.operand(0)->shape().dimensions().back(); if (permutation == absl::InlinedVector{0, 2, 1} || - (IsMlirTransposeEmitterEnabled(instr) && - permutation == absl::InlinedVector{1, 0})) { - if ((dimensions[dimensions.size() - 2] >= kMinDimensionToTransposeTiled && - dimensions.back() >= kMinDimensionToTransposeTiled) || - (dimensions[dimensions.size() - 2] >= kMinDimensionToTransposeTiled2 && - dimensions.back() >= kMinDimensionToTransposeTiled2 && - dimensions[dimensions.size() - 2] * dimensions.back() >= - kMinTotalDimensionsToTransposeTiled)) { - return TransposeDescription{&instr, dimensions, permutation}; - } - } else if (permutation == absl::InlinedVector{2, 1, 0}) { - if ((dimensions[0] >= kMinDimensionToTransposeTiled && - dimensions[2] >= kMinDimensionToTransposeTiled) || - (dimensions[0] >= kMinDimensionToTransposeTiled2 && - dimensions[2] >= kMinDimensionToTransposeTiled2 && - dimensions[0] * dimensions[2] >= + permutation == absl::InlinedVector{2, 1, 0}) { + if ((dimensions.back() >= kMinDimensionToTransposeTiled && + operand_most_minor_dim >= kMinDimensionToTransposeTiled) || + (dimensions.back() >= kMinDimensionToTransposeTiled2 && + operand_most_minor_dim >= kMinDimensionToTransposeTiled2 && + dimensions.back() * operand_most_minor_dim >= kMinTotalDimensionsToTransposeTiled)) { return TransposeDescription{&instr, dimensions, permutation}; } } else if (IsMlirTransposeEmitterEnabled(instr)) { - if (permutation == absl::InlinedVector{1, 0, 2}) { + if (permutation.back() == dimensions.size() - 1) { + operand_most_minor_dim = + instr.operand(0)->shape().dimensions(dimensions.size() - 2); auto byte_width = primitive_util::ByteWidth(instr.shape().element_type()); - if (byte_width * dimensions[2] <= kMaxBytesInMostMinorDimension && - byte_width * dimensions[2] * std::min(dimensions[0], dimensions[1]) >= + if (byte_width * dimensions.back() <= kMaxBytesInMostMinorDimension && + byte_width * dimensions.back() * + std::min(operand_most_minor_dim, + dimensions[dimensions.size() - 2]) >= kMinDimensionToTransposeTiled) { return TransposeDescription{&instr, dimensions, permutation}; } + } else if ((operand_most_minor_dim >= kMinDimensionToTransposeTiled && + dimensions.back() >= kMinDimensionToTransposeTiled) || + (operand_most_minor_dim >= kMinDimensionToTransposeTiled2 && + dimensions.back() >= kMinDimensionToTransposeTiled2 && + operand_most_minor_dim * dimensions.back() >= + kMinTotalDimensionsToTransposeTiled)) { + return TransposeDescription{&instr, dimensions, permutation}; } } return std::nullopt; diff --git a/third_party/xla/xla/service/gpu/ir_emission_utils_test.cc b/third_party/xla/xla/service/gpu/ir_emission_utils_test.cc index e7e571f4f88d5a..a491c8f64ddecc 100644 --- a/third_party/xla/xla/service/gpu/ir_emission_utils_test.cc +++ b/third_party/xla/xla/service/gpu/ir_emission_utils_test.cc @@ -111,6 +111,52 @@ ENTRY entry { EXPECT_FALSE(result.has_value()); } +TEST_F(IrEmissionUtilsTest, FindTiledLogical2103Transpose) { + const char* hlo = R"( +HloModule module + +ENTRY entry { + p = f32[33,48,32,2]{3,2,1,0} parameter(0) + ROOT t = f32[32,48,33,2]{3,2,1,0} transpose(p), dimensions={2,1,0,3} +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo)); + auto& debug_options = module->mutable_config().mutable_debug_options(); + debug_options.set_xla_gpu_mlir_emitter_level(3); + + HloInstruction* tr = module->entry_computation()->root_instruction(); + + auto result = GetDescriptionForTiledTransposeEmitter(*tr, *tr); + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(result->instr, tr); + EXPECT_EQ(result->dimensions, InlinedVector({32, 48, 33, 2})); + EXPECT_EQ(result->permutation, InlinedVector({2, 1, 0, 3})); +} + +TEST_F(IrEmissionUtilsTest, FindTiledLogical1320Transpose) { + const char* hlo = R"( +HloModule module + +ENTRY entry { + p = f32[33,48,32,34]{3,2,1,0} parameter(0) + ROOT t = f32[48,34,32,33]{3,2,1,0} transpose(p), dimensions={1,3,2,0} +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo)); + auto& debug_options = module->mutable_config().mutable_debug_options(); + debug_options.set_xla_gpu_mlir_emitter_level(3); + + HloInstruction* tr = module->entry_computation()->root_instruction(); + + auto result = GetDescriptionForTiledTransposeEmitter(*tr, *tr); + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(result->instr, tr); + EXPECT_EQ(result->dimensions, InlinedVector({48, 34, 32, 33})); + EXPECT_EQ(result->permutation, InlinedVector({1, 3, 2, 0})); +} + TEST_F(IrEmissionUtilsTest, FindTiled102Transpose) { const char* hlo = R"( HloModule module diff --git a/third_party/xla/xla/service/gpu/transforms/instruction_fusion_test.cc b/third_party/xla/xla/service/gpu/transforms/instruction_fusion_test.cc index 454dab1ad0b654..140cc6e52641ea 100644 --- a/third_party/xla/xla/service/gpu/transforms/instruction_fusion_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/instruction_fusion_test.cc @@ -126,12 +126,14 @@ TEST_F(InstructionFusionTest, TEST_F(InstructionFusionTest, CostlyProducerAndNonOperandElementReusingConsumerFused_Transpose) { HloComputation::Builder builder(TestName()); - HloInstruction* const0 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(5.0f))); - HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary( - ShapeUtil::MakeShape(F32, {}), HloOpcode::kExp, const0)); - HloInstruction* transpose2 = builder.AddInstruction( - HloInstruction::CreateTranspose(ShapeUtil::MakeShape(F32, {}), exp1, {})); + Shape operand_shape = ShapeUtil::MakeShape(F32, {64, 32}); + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, operand_shape, "param0")); + HloInstruction* exp1 = builder.AddInstruction( + HloInstruction::CreateUnary(operand_shape, HloOpcode::kExp, param)); + HloInstruction* transpose2 = + builder.AddInstruction(HloInstruction::CreateTranspose( + ShapeUtil::MakeShape(F32, {32, 64}), exp1, {1, 0})); auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build());