From d35ae2f8c5810615ca11ea3d5c44dec9250a952c Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Sat, 17 Aug 2024 19:31:30 -0700
Subject: [PATCH] Rollback of [XLA:SPMD] Remove LookaheadUserSharding in sharding propagation.

Reverts 339d798df2aff84af89bf269cab2ff3743dea4be

PiperOrigin-RevId: 664283560
---
 .../auto_sharding_dot_handler.cc              |   8 +-
 .../xla/xla/service/sharding_propagation.cc   | 153 +++++++++++-------
 .../xla/xla/service/sharding_propagation.h    |   9 +-
 .../xla/service/sharding_propagation_test.cc  | 108 ++-----------
 4 files changed, 117 insertions(+), 161 deletions(-)

diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc
index 815085f4a1294d..dbd161b365d698 100644
--- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc
+++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc
@@ -446,13 +446,13 @@ std::optional<HloSharding> HandlerBase::GetShardingFromUser(
   CHECK_OK(ins_clone->ReplaceOperandWith(1, rhs_clone.get()));
   if (ins_->opcode() == HloOpcode::kConvolution) {
     xla::InferConvolutionShardingFromOperands(
-        ins_clone.get(), /* aggressiveness */ 10,
-        /* may_combine_partial_sharding */ true);
+        ins_clone.get(), call_graph_, 10,
+        /* may_combine_partial_sharding */ true, /* is_spmd */ true);
   } else {
     xla::InferDotShardingFromOperands(
-        ins_clone.get(),
+        ins_clone.get(), call_graph_,
         dot_as_convolution_util::ParseDotGeneralFromDot(ins_clone.get()),
-        /* aggressiveness */ 10, /* may_combine_partial_sharding */ true);
+        /* may_combine_partial_sharding */ true, /* is_spmd */ true);
   }
   if (!ins_clone->has_sharding()) {
     return std::nullopt;
diff --git a/third_party/xla/xla/service/sharding_propagation.cc b/third_party/xla/xla/service/sharding_propagation.cc
index 45d47b44a096fd..5239d6c7d30575 100644
--- a/third_party/xla/xla/service/sharding_propagation.cc
+++ b/third_party/xla/xla/service/sharding_propagation.cc
@@ -16,10 +16,10 @@ limitations under the License.
 #include "xla/service/sharding_propagation.h"
 
 #include
-#include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -36,12 +36,10 @@ limitations under the License.
 #include "absl/log/check.h"
 #include "absl/status/status.h"
 #include "absl/strings/str_join.h"
-#include "absl/strings/string_view.h"
 #include "absl/types/span.h"
 #include "xla/array.h"
 #include "xla/hlo/ir/hlo_casting_utils.h"
 #include "xla/hlo/ir/hlo_computation.h"
-#include "xla/hlo/ir/hlo_domain_metadata.h"
 #include "xla/hlo/ir/hlo_instruction.h"
 #include "xla/hlo/ir/hlo_instructions.h"
 #include "xla/hlo/ir/hlo_opcode.h"
@@ -49,7 +47,6 @@ limitations under the License.
 #include "xla/hlo/ir/hlo_sharding_metadata.h"
 #include "xla/hlo/utils/hlo_sharding_util.h"
 #include "xla/protobuf_util.h"
-#include "xla/service/call_graph.h"
 #include "xla/service/dot_as_convolution_util.h"
 #include "xla/service/host_memory_offload_annotations.h"
 #include "xla/service/spmd/shard_barrier_partitioner.h"
@@ -419,6 +416,55 @@ bool SupportSpatialPartitioning(
   }
 }
 
+// Helper to look ahead at the sharding of a user of an instruction, to be
+// used as guidance for ambiguous cases.
+std::optional<HloSharding> LookaheadUserSharding(HloInstruction* instr,
+                                                 bool is_spmd,
+                                                 const CallGraph& call_graph) {
+  if (instr->user_count() != 1) {
+    return std::nullopt;
+  }
+  HloInstruction* current_user = instr->users()[0];
+  std::optional<HloSharding> sharding;
+  std::vector<HloInstruction*> users_chain = {instr, current_user};
+  // Collect single user instructions along the way.
+  while (!current_user->has_sharding()) {
+    // Only consider single user chains.
+    if (current_user->users().size() != 1) {
+      users_chain.clear();
+      break;
+    }
+    current_user = current_user->users()[0];
+    users_chain.push_back(current_user);
+  }
+  // Early exit for unsupported cases.
+  if (users_chain.empty()) {
+    return std::nullopt;
+  }
+  for (int i = users_chain.size() - 1; i >= 1; --i) {
+    HloInstruction* user = users_chain[i];
+    HloInstruction* current = users_chain[i - 1];
+    CHECK(user->has_sharding());
+    sharding = ShardingPropagation::GetShardingFromUser(
+        *current, *user, INT64_MAX, is_spmd, call_graph,
+        /*sharding_helper=*/nullptr);
+    // We need to set the sharding to the instruction, because
+    // GetShardingFromUser() interface uses sharding from the instruction
+    // itself. It will be cleared out later.
+    if (sharding.has_value() && i != 1) {
+      current->set_sharding(*sharding);
+      continue;
+    }
+    break;
+  }
+  // Clear the sharding of the middle instructions we set the sharding of
+  // because they were unsharded.
+  for (int i = 1; i < users_chain.size() - 1; ++i) {
+    users_chain[i]->clear_sharding();
+  }
+  return sharding;
+}
+
 // Infer output sharding on index parallel dimensions for gather from operand
 // and indices.
 bool InferGatherParallelShardingFromOperands(
@@ -1025,9 +1071,9 @@ bool IsCSEPreventionSharding(const HloSharding& sharding) {
 }
 }  // namespace
 
 bool InferDotShardingFromOperands(
-    HloInstruction* instruction,
+    HloInstruction* instruction, const CallGraph& call_graph,
     const dot_as_convolution_util::DotConvolutionDimsInfo& dnums,
-    int64_t aggressiveness, bool may_combine_partial_sharding) {
+    bool may_combine_partial_sharding, bool is_spmd) {
   auto from_operand = [&](int64_t operand_index) {
     auto operand = instruction->operand(operand_index);
     const HloSharding& operand_sharding = operand->sharding();
@@ -1082,66 +1128,55 @@ bool InferDotShardingFromOperands(
       from_operand(1), instruction, may_combine_partial_sharding,
       /*allow_aggressive_resharding=*/false);
   }
-
-  // Four cases based on if improved_operand_0 and improved_operand_1 are
-  // available.
-  // Case 0. Both operands have no improved sharding.
+  // If not improved sharding found then do not set any sharding.
   if (!improved_operand_0.has_value() && !improved_operand_1.has_value()) {
     return false;
   }
-  // Case 1. Sharding found from operand 0 but not operand 1. Set sharding from
-  // operand 0.
+  // Sharding found from operand 0 but not operand 1. Set sharding from operand
+  // 0
   if (improved_operand_0.has_value() && !improved_operand_1.has_value()) {
     instruction->set_sharding(*improved_operand_0);
     return true;
   }
-  // Case 2. Sharding found from operand 1 but not operand 0. Set sharding from
-  // operand 1.
+  // Sharding found from operand 1 but not operand 0. Set sharding from operand
+  // 1
   if (!improved_operand_0.has_value() && improved_operand_1.has_value()) {
     instruction->set_sharding(*improved_operand_1);
     return true;
   }
-  // Case 3. Both operands have improved shardings.
   CHECK(improved_operand_0.has_value() && improved_operand_1.has_value());
-
-  // If one of the improved shardings is a sub-tiling or equal to the other, use
-  // the better sharding with more tiles.
-  if (hlo_sharding_util::IsSubTilingOrEqualSharding(
-          instruction->shape(), *improved_operand_0, *improved_operand_1)) {
-    instruction->set_sharding(*improved_operand_0);
-    return true;
-  }
-  if (hlo_sharding_util::IsSubTilingOrEqualSharding(
-          instruction->shape(), *improved_operand_1, *improved_operand_0)) {
-    instruction->set_sharding(*improved_operand_1);
-    return true;
-  }
-
-  // If the two improved shardings are mergeable, there is no conflict.
-  if (std::optional<HloSharding> improved_sharding =
-          hlo_sharding_util::ReturnImprovedShardingImpl(
-              *improved_operand_0, &improved_operand_1.value(),
-              instruction->shape(), may_combine_partial_sharding,
-              /*allow_aggressive_resharding=*/false)) {
-    instruction->set_sharding(*improved_sharding);
-    return true;
-  }
-
-  if (aggressiveness < 3) {
-    // We can improve the dot with different shardings. Pause the propagation
-    // and wait for the winner between the two operands.
-    return false;
-  }
-
-  // The two improved sharding are different and we are at the highest
-  // aggressiveness. Prioritize the operand with larger size.
+  std::optional<HloSharding> lookahead_sharding =
+      LookaheadUserSharding(instruction, is_spmd, call_graph);
   std::array<HloSharding, 2> sharding_priority = {*improved_operand_0,
                                                   *improved_operand_1};
+  bool priority_defined_with_lookahead = false;
+  // Found sharding from lookahead.
+  if (lookahead_sharding.has_value()) {
+    const bool operand_0_is_lookahead_subtiling =
+        hlo_sharding_util::IsSubTilingOrEqualSharding(
+            instruction->shape(), *lookahead_sharding, *improved_operand_0);
+    const bool operand_1_is_lookahead_subtiling =
+        hlo_sharding_util::IsSubTilingOrEqualSharding(
+            instruction->shape(), *lookahead_sharding, *improved_operand_1);
+    // If the sharding from operand 0 is a subtiling of the user, but not the
+    // one from operand 1 prioritize that sharding.
+    if (operand_0_is_lookahead_subtiling && !operand_1_is_lookahead_subtiling) {
+      priority_defined_with_lookahead = true;
+    }
+    // If the sharding from operand 1 is a subtiling of the user, but not the
+    // one from operand 0 prioritize that sharding.
+    if (!operand_0_is_lookahead_subtiling && operand_1_is_lookahead_subtiling) {
+      instruction->set_sharding(*improved_operand_1);
+      std::swap(sharding_priority[0], sharding_priority[1]);
+      priority_defined_with_lookahead = true;
+    }
+  }
+  // If lookahead didn't define a priority then use size.
+  if (!priority_defined_with_lookahead &&
+      ShapeUtil::ByteSizeOf(instruction->operand(0)->shape()) <
+          ShapeUtil::ByteSizeOf(instruction->operand(1)->shape())) {
     std::swap(sharding_priority[0], sharding_priority[1]);
   }
-
   // Set primary sharding to the instruction and then try to improve it with
   // the secondary sharding.
   instruction->set_sharding(sharding_priority[0]);
@@ -1152,8 +1187,10 @@ bool InferDotShardingFromOperands(
 // Convolution handling for InferShardingFromOperands().
 bool InferConvolutionShardingFromOperands(HloInstruction* instruction,
+                                          const CallGraph& call_graph,
                                           int64_t aggressiveness,
-                                          bool may_combine_partial_sharding) {
+                                          bool may_combine_partial_sharding,
+                                          bool is_spmd) {
   auto get_partitions_for_dims =
       [&](const HloInstruction* inst,
           absl::Span<
@@ -1188,8 +1225,8 @@ bool InferConvolutionShardingFromOperands(HloInstruction* instruction,
       (lhs_conv_spatial_partitions == 1 && rhs_conv_spatial_partitions == 1 &&
        instruction->batch_group_count() == 1 &&
        instruction->feature_group_count() == 1)) {
-    return InferDotShardingFromOperands(instruction, dot_dims, aggressiveness,
-                                        may_combine_partial_sharding);
+    return InferDotShardingFromOperands(instruction, call_graph, dot_dims,
+                                        may_combine_partial_sharding, is_spmd);
   }
   const auto& dnums = instruction->convolution_dimension_numbers();
   const HloInstruction* lhs = instruction->operand(0);
@@ -2292,8 +2329,9 @@ bool ShardingPropagation::InferShardingFromOperands(
                  1);
     }
     case HloOpcode::kConvolution:
-      return InferConvolutionShardingFromOperands(instruction, aggressiveness,
-                                                  may_combine_partial_sharding);
+      return InferConvolutionShardingFromOperands(
+          instruction, call_graph, aggressiveness, may_combine_partial_sharding,
+          is_spmd_);
     case HloOpcode::kTranspose: {
       const HloInstruction* input = instruction->operand(0);
       if (!hlo_sharding_util::IsSpatiallyPartitioned(input)) {
@@ -2382,8 +2420,9 @@ bool ShardingPropagation::InferShardingFromOperands(
     case HloOpcode::kDot: {
       const auto& dnums =
           dot_as_convolution_util::ParseDotGeneralFromDot(instruction);
-      return InferDotShardingFromOperands(instruction, dnums, aggressiveness,
-                                          may_combine_partial_sharding);
+      return InferDotShardingFromOperands(instruction, call_graph, dnums,
+                                          may_combine_partial_sharding,
+                                          is_spmd_);
     }
     case HloOpcode::kParameter: {
       auto parent_it = computation_map.find(instruction->parent());
diff --git a/third_party/xla/xla/service/sharding_propagation.h b/third_party/xla/xla/service/sharding_propagation.h
index 8d8289de719a13..22cb7af042545d 100644
--- a/third_party/xla/xla/service/sharding_propagation.h
+++ b/third_party/xla/xla/service/sharding_propagation.h
@@ -16,7 +16,6 @@ limitations under the License.
 #ifndef XLA_SERVICE_SHARDING_PROPAGATION_H_
 #define XLA_SERVICE_SHARDING_PROPAGATION_H_
 
-#include
 #include
 #include
 #include
@@ -36,15 +35,17 @@ namespace xla {
 // Infers the shardings for a dot HLO op from the shardings on its operands,
 // which are expected to have sharding annotations.
 bool InferDotShardingFromOperands(
-    HloInstruction* instruction,
+    HloInstruction* instruction, const CallGraph& call_graph,
     const dot_as_convolution_util::DotConvolutionDimsInfo& dnums,
-    int64_t aggressiveness, bool may_combine_partial_sharding);
+    bool may_combine_partial_sharding, bool is_spmd);
 
 // Infers the shardings for a convolution HLO op from the shardings on its
 // operands, which are expected to have sharding annotations.
 bool InferConvolutionShardingFromOperands(HloInstruction* instruction,
+                                          const CallGraph& call_graph,
                                           int64_t aggressiveness,
-                                          bool may_combine_partial_sharding);
+                                          bool may_combine_partial_sharding,
+                                          bool is_spmd);
 
 // Remove Sharding custom-call instruction by folding the sharding attribute
 // to its operand. If the operand already has a different sharding, insert a
diff --git a/third_party/xla/xla/service/sharding_propagation_test.cc b/third_party/xla/xla/service/sharding_propagation_test.cc
index 52f0a197e05c04..072f43644ccd83 100644
--- a/third_party/xla/xla/service/sharding_propagation_test.cc
+++ b/third_party/xla/xla/service/sharding_propagation_test.cc
@@ -3324,7 +3324,7 @@ ENTRY %conv {
   EXPECT_THAT(instruction, op::Sharding("{devices=[2,2,2]0,1,2,3,4,5,6,7}"));
   if (GetParam().propagate_metadata && !GetParam().clear_metadata) {
     EXPECT_THAT(instruction->sharding(),
-                ShardingMetadata({CreateMetadata("a"), CreateMetadata("b")}));
+                ShardingMetadata({CreateMetadata("b"), CreateMetadata("a")}));
   } else {
     EXPECT_THAT(instruction->sharding(), ShardingMetadata({}));
   }
@@ -3396,7 +3396,7 @@ ENTRY %conv {
   EXPECT_THAT(instruction, op::Sharding("{devices=[2,4]0,2,3,1,4,6,7,5}"));
   if (GetParam().propagate_metadata && !GetParam().clear_metadata) {
     EXPECT_THAT(instruction->sharding(),
-                ShardingMetadata({CreateMetadata("a"), CreateMetadata("b")}));
+                ShardingMetadata({CreateMetadata("b"), CreateMetadata("a")}));
   } else {
     EXPECT_THAT(instruction->sharding(), ShardingMetadata({}));
   }
@@ -11864,7 +11864,7 @@ ENTRY main.9 {
       op::Sharding("{{devices=[4]<=[4]}, {devices=[4]<=[4]}}"));
 }
 
-TEST_F(ShardingPropagationTest, InferDotShardingFromOperands1) {
+TEST_F(ShardingPropagationTest, LookaheadUsersOfDot) {
   const char* const hlo_string = R"(
 HloModule module
@@ -11881,108 +11881,24 @@ ENTRY %entry {
   TF_ASSERT_OK_AND_ASSIGN(auto module,
                           ParseAndReturnVerifiedModule(hlo_string));
   TF_ASSERT_OK_AND_ASSIGN(
-      bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get()));
+      bool changed,
+      ShardingPropagation(
+          /*is_spmd=*/true, /*propagate_metadata=*/true,
+          /*allow_spmd_sharding_propagation_to_output=*/{true},
+          /*allow_spmd_sharding_propagation_to_parameters=*/{true})
+          .Run(module.get()));
   EXPECT_TRUE(changed);
 
   XLA_VLOG_LINES(1, module->ToString());
+  // Check dangling sharding custom-call can be removed by DCE after
+  // propagation.
   auto* instruction = FindInstruction(module.get(), "dot.1");
+  // Check sharding is correctly propagated.
   EXPECT_THAT(instruction,
               op::Sharding(
                   "{devices=[4,4,1,4]<=[4,16]T(1,0) last_tile_dim_replicate}"));
 }
 
-TEST_F(ShardingPropagationTest, InferDotShardingFromOperands2) {
-  const char* const hlo_string = R"(
-HloModule module
-
-ENTRY %entry {
-  p0 = bf16[16,32] parameter(0), sharding={devices=[16,1]<=[16]}
-  p1 = bf16[32,64] parameter(1), sharding={devices=[1,16]<=[16]}
-  dot = bf16[16,64] dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  ROOT copy = bf16[16,64] copy(dot), sharding={replicated}
-})";
-  TF_ASSERT_OK_AND_ASSIGN(auto module,
-                          ParseAndReturnVerifiedModule(hlo_string));
-  TF_ASSERT_OK_AND_ASSIGN(
-      bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get()));
-  EXPECT_TRUE(changed);
-
-  XLA_VLOG_LINES(1, module->ToString());
-  auto* instruction = FindInstruction(module.get(), "dot");
-  EXPECT_THAT(instruction, op::Sharding("{devices=[1,16]<=[16]}"));
-}
-
-TEST_F(ShardingPropagationTest, InferDotShardingFromOperands3) {
-  const char* const hlo_string = R"(
-HloModule module
-
-ENTRY %entry {
-  p0 = bf16[4,16,32] parameter(0), sharding={devices=[2,4,2]<=[16]}
-  p1 = bf16[4,32,64] parameter(1), sharding={devices=[2,8,1]<=[16]}
-  dot = bf16[4,16,64] dot(p0, p1), lhs_batch_dims={0}, rhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_contracting_dims={1}
-  ROOT copy = bf16[4,16,64] copy(dot)
-})";
-  TF_ASSERT_OK_AND_ASSIGN(auto module,
-                          ParseAndReturnVerifiedModule(hlo_string));
-  TF_ASSERT_OK_AND_ASSIGN(
-      bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get()));
-  EXPECT_TRUE(changed);
-
-  XLA_VLOG_LINES(1, module->ToString());
-  auto* instruction = FindInstruction(module.get(), "dot");
-  EXPECT_THAT(
-      instruction,
-      op::Sharding("{devices=[2,4,1,2]<=[16] last_tile_dim_replicate}"));
-}
-
-TEST_F(ShardingPropagationTest, InferDotShardingFromOperands4) {
-  const char* const hlo_string = R"(
-HloModule module
-
-ENTRY %entry {
-  p0 = bf16[4,16,32] parameter(0), sharding={devices=[2,1,8]<=[16]}
-  p1 = bf16[4,32,64] parameter(1), sharding={devices=[4,1,4]<=[16]}
-  dot = bf16[4,16,64] dot(p0, p1), lhs_batch_dims={0}, rhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_contracting_dims={1}
-  ROOT copy = bf16[4,16,64] copy(dot)
-})";
-  TF_ASSERT_OK_AND_ASSIGN(auto module,
-                          ParseAndReturnVerifiedModule(hlo_string));
-  TF_ASSERT_OK_AND_ASSIGN(
-      bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get()));
-  EXPECT_TRUE(changed);
-
-  XLA_VLOG_LINES(1, module->ToString());
-  auto* instruction = FindInstruction(module.get(), "dot");
-  EXPECT_THAT(instruction, op::Sharding("{devices=[4,1,4]<=[16]}"));
-}
-
-TEST_F(ShardingPropagationTest, InferDotShardingFromOperands5) {
-  const char* const hlo_string = R"(
-HloModule module
-
-ENTRY %entry {
-  p0 = bf16[16,16] parameter(0), sharding={devices=[4,4]<=[4,4]T(1,0)}
-  p1 = bf16[16,16] parameter(1), sharding={devices=[4,4]<=[4,4]T(1,0)}
-  dot.0 = bf16[16,16] dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={1}
-  p2 = bf16[16,16] parameter(2), sharding={devices=[4,4]<=[16]}
-  p3 = bf16[16,16] parameter(3), sharding={devices=[4,4]<=[16]}
-  dot.1 = bf16[16,16] dot(p2, p3), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  add = bf16[16,16] add(dot.0, dot.1)
-  ROOT copy = bf16[16,16] copy(add)
-})";
-  TF_ASSERT_OK_AND_ASSIGN(auto module,
-                          ParseAndReturnVerifiedModule(hlo_string));
-  TF_ASSERT_OK_AND_ASSIGN(
-      bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get()));
-  EXPECT_TRUE(changed);
-
-  XLA_VLOG_LINES(1, module->ToString());
-  for (absl::string_view name : {"dot.0", "dot.1", "add"}) {
-    auto* instruction = FindInstruction(module.get(), name);
-    EXPECT_THAT(instruction, op::Sharding("{devices=[4,4]<=[16]}"));
-  }
-}
-
 TEST_F(ShardingPropagationTest, AsyncInstructionManualShardingArray) {
   const char* const hlo_string = R"(
 HloModule module