Rollback of [XLA:SPMD] Remove LookaheadUserSharding in sharding propagation.

Reverts 339d798

PiperOrigin-RevId: 664283560
tensorflower-gardener committed Aug 18, 2024
1 parent ccf9f4e commit d35ae2f
Showing 4 changed files with 117 additions and 161 deletions.
@@ -446,13 +446,13 @@ std::optional<HloSharding> HandlerBase::GetShardingFromUser(
CHECK_OK(ins_clone->ReplaceOperandWith(1, rhs_clone.get()));
if (ins_->opcode() == HloOpcode::kConvolution) {
xla::InferConvolutionShardingFromOperands(
ins_clone.get(), /* aggressiveness */ 10,
/* may_combine_partial_sharding */ true);
ins_clone.get(), call_graph_, 10,
/* may_combine_partial_sharding */ true, /* is_spmd */ true);
} else {
xla::InferDotShardingFromOperands(
ins_clone.get(),
ins_clone.get(), call_graph_,
dot_as_convolution_util::ParseDotGeneralFromDot(ins_clone.get()),
/* aggressiveness */ 10, /* may_combine_partial_sharding */ true);
/* may_combine_partial_sharding */ true, /* is_spmd */ true);
}
if (!ins_clone->has_sharding()) {
return std::nullopt;
153 changes: 96 additions & 57 deletions third_party/xla/xla/service/sharding_propagation.cc
@@ -16,10 +16,10 @@ limitations under the License.
#include "xla/service/sharding_propagation.h"

#include <algorithm>
#include <array>
#include <cstdint>
#include <functional>
#include <iterator>
#include <list>
#include <map>
#include <memory>
#include <optional>
@@ -36,20 +36,17 @@ limitations under the License.
#include "absl/log/check.h"
#include "absl/status/status.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "xla/array.h"
#include "xla/hlo/ir/hlo_casting_utils.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_domain_metadata.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/hlo/ir/hlo_sharding.h"
#include "xla/hlo/ir/hlo_sharding_metadata.h"
#include "xla/hlo/utils/hlo_sharding_util.h"
#include "xla/protobuf_util.h"
#include "xla/service/call_graph.h"
#include "xla/service/dot_as_convolution_util.h"
#include "xla/service/host_memory_offload_annotations.h"
#include "xla/service/spmd/shard_barrier_partitioner.h"
@@ -419,6 +416,55 @@ bool SupportSpatialPartitioning(
}
}

// Helper to lookahead sharding of user of an instruction to be used as guidance
// for ambiguous cases.
std::optional<HloSharding> LookaheadUserSharding(HloInstruction* instr,
bool is_spmd,
const CallGraph& call_graph) {
if (instr->user_count() != 1) {
return std::nullopt;
}
HloInstruction* current_user = instr->users()[0];
std::optional<HloSharding> sharding;
std::vector<HloInstruction*> users_chain = {instr, current_user};
// Collect single user instructions along the way.
while (!current_user->has_sharding()) {
// Only consider single user chains.
if (current_user->users().size() != 1) {
users_chain.clear();
break;
}
current_user = current_user->users()[0];
users_chain.push_back(current_user);
}
// Early exit for unsupported cases.
if (users_chain.empty()) {
return std::nullopt;
}
for (int i = users_chain.size() - 1; i >= 1; --i) {
HloInstruction* user = users_chain[i];
HloInstruction* current = users_chain[i - 1];
CHECK(user->has_sharding());
sharding = ShardingPropagation::GetShardingFromUser(
*current, *user, INT64_MAX, is_spmd, call_graph,
/*sharding_helper=*/nullptr);
// We need to set the sharding to the instruction, because
// GetShardingFromUser() interface uses sharding from the instruction
// itself. It will be cleared out later.
if (sharding.has_value() && i != 1) {
current->set_sharding(*sharding);
continue;
}
break;
}
// Clear the sharding of the middle instructions we set the sharding of
// because they were unsharded.
for (int i = 1; i < users_chain.size() - 1; ++i) {
users_chain[i]->clear_sharding();
}
return sharding;
}

// Infer output sharding on index parallel dimensions for gather from operand
// and indices.
bool InferGatherParallelShardingFromOperands(
@@ -1025,9 +1071,9 @@ bool IsCSEPreventionSharding(const HloSharding& sharding) {
} // namespace

bool InferDotShardingFromOperands(
HloInstruction* instruction,
HloInstruction* instruction, const CallGraph& call_graph,
const dot_as_convolution_util::DotConvolutionDimsInfo& dnums,
int64_t aggressiveness, bool may_combine_partial_sharding) {
bool may_combine_partial_sharding, bool is_spmd) {
auto from_operand = [&](int64_t operand_index) {
auto operand = instruction->operand(operand_index);
const HloSharding& operand_sharding = operand->sharding();
@@ -1082,66 +1128,55 @@ bool InferDotShardingFromOperands(
from_operand(1), instruction, may_combine_partial_sharding,
/*allow_aggressive_resharding=*/false);
}

// Four cases based on if improved_operand_0 and improved_operand_1 are
// available.
// Case 0. Both operands have no improved sharding.
// If not improved sharding found then do not set any sharding.
if (!improved_operand_0.has_value() && !improved_operand_1.has_value()) {
return false;
}
// Case 1. Sharding found from operand 0 but not operand 1. Set sharding from
// operand 0.
// Sharding found from operand 0 but not operand 1. Set sharding from operand
// 0
if (improved_operand_0.has_value() && !improved_operand_1.has_value()) {
instruction->set_sharding(*improved_operand_0);
return true;
}
// Case 2. Sharding found from operand 1 but not operand 0. Set sharding from
// operand 1.
// Sharding found from operand 1 but not operand 0. Set sharding from operand
// 1
if (!improved_operand_0.has_value() && improved_operand_1.has_value()) {
instruction->set_sharding(*improved_operand_1);
return true;
}
// Case 3. Both operands have improved shardings.
CHECK(improved_operand_0.has_value() && improved_operand_1.has_value());

// If one of the improved shardings is a sub-tiling or equal to the other, use
// the better sharding with more tiles.
if (hlo_sharding_util::IsSubTilingOrEqualSharding(
instruction->shape(), *improved_operand_0, *improved_operand_1)) {
instruction->set_sharding(*improved_operand_0);
return true;
}
if (hlo_sharding_util::IsSubTilingOrEqualSharding(
instruction->shape(), *improved_operand_1, *improved_operand_0)) {
instruction->set_sharding(*improved_operand_1);
return true;
}

// If the two improved shardings are mergeable, there is no conflict.
if (std::optional<HloSharding> improved_sharding =
hlo_sharding_util::ReturnImprovedShardingImpl(
*improved_operand_0, &improved_operand_1.value(),
instruction->shape(), may_combine_partial_sharding,
/*allow_aggressive_resharding=*/false)) {
instruction->set_sharding(*improved_sharding);
return true;
}

if (aggressiveness < 3) {
// We can improve the dot with different shardings. Pause the propagation
// and wait for the winner between the two operands.
return false;
}

// The two improved sharding are different and we are at the highest
// aggressiveness. Prioritize the operand with larger size.
std::optional<HloSharding> lookahead_sharding =
LookaheadUserSharding(instruction, is_spmd, call_graph);
std::array<HloSharding, 2> sharding_priority = {*improved_operand_0,
*improved_operand_1};
if (ShapeUtil::ByteSizeOf(instruction->operand(0)->shape()) <
ShapeUtil::ByteSizeOf(instruction->operand(1)->shape())) {
bool priority_defined_with_lookahead = false;
// Found sharding from lookahead.
if (lookahead_sharding.has_value()) {
const bool operand_0_is_lookahead_subtiling =
hlo_sharding_util::IsSubTilingOrEqualSharding(
instruction->shape(), *lookahead_sharding, *improved_operand_0);
const bool operand_1_is_lookahead_subtiling =
hlo_sharding_util::IsSubTilingOrEqualSharding(
instruction->shape(), *lookahead_sharding, *improved_operand_1);
// If the sharding from operand 0 is a subtiling of the user, but not the
// one from operand 1 prioritize that sharding.
if (operand_0_is_lookahead_subtiling && !operand_1_is_lookahead_subtiling) {
priority_defined_with_lookahead = true;
}
// If the sharding from operand 1 is a subtiling of the user, but not the
// one from operand 0 prioritize that sharding.
if (!operand_0_is_lookahead_subtiling && operand_1_is_lookahead_subtiling) {
instruction->set_sharding(*improved_operand_1);
std::swap(sharding_priority[0], sharding_priority[1]);
priority_defined_with_lookahead = true;
}
}
// If lookahead didn't define a priority then use size.
if (!priority_defined_with_lookahead &&
ShapeUtil::ByteSizeOf(instruction->operand(0)->shape()) <
ShapeUtil::ByteSizeOf(instruction->operand(1)->shape())) {
std::swap(sharding_priority[0], sharding_priority[1]);
}

// Set primary sharding to the instruction and then try to improve it with
// the secondary sharding.
instruction->set_sharding(sharding_priority[0]);
@@ -1152,8 +1187,10 @@ bool InferDotShardingFromOperands(

// Convolution handling for InferShardingFromOperands().
bool InferConvolutionShardingFromOperands(HloInstruction* instruction,
const CallGraph& call_graph,
int64_t aggressiveness,
bool may_combine_partial_sharding) {
bool may_combine_partial_sharding,
bool is_spmd) {
auto get_partitions_for_dims =
[&](const HloInstruction* inst,
absl::Span<
@@ -1188,8 +1225,8 @@ bool InferConvolutionShardingFromOperands(HloInstruction* instruction,
(lhs_conv_spatial_partitions == 1 && rhs_conv_spatial_partitions == 1 &&
instruction->batch_group_count() == 1 &&
instruction->feature_group_count() == 1)) {
return InferDotShardingFromOperands(instruction, dot_dims, aggressiveness,
may_combine_partial_sharding);
return InferDotShardingFromOperands(instruction, call_graph, dot_dims,
may_combine_partial_sharding, is_spmd);
}
const auto& dnums = instruction->convolution_dimension_numbers();
const HloInstruction* lhs = instruction->operand(0);
@@ -2292,8 +2329,9 @@ bool ShardingPropagation::InferShardingFromOperands(
1);
}
case HloOpcode::kConvolution:
return InferConvolutionShardingFromOperands(instruction, aggressiveness,
may_combine_partial_sharding);
return InferConvolutionShardingFromOperands(
instruction, call_graph, aggressiveness, may_combine_partial_sharding,
is_spmd_);
case HloOpcode::kTranspose: {
const HloInstruction* input = instruction->operand(0);
if (!hlo_sharding_util::IsSpatiallyPartitioned(input)) {
@@ -2382,8 +2420,9 @@ bool ShardingPropagation::InferShardingFromOperands(
case HloOpcode::kDot: {
const auto& dnums =
dot_as_convolution_util::ParseDotGeneralFromDot(instruction);
return InferDotShardingFromOperands(instruction, dnums, aggressiveness,
may_combine_partial_sharding);
return InferDotShardingFromOperands(instruction, call_graph, dnums,
may_combine_partial_sharding,
is_spmd_);
}
case HloOpcode::kParameter: {
auto parent_it = computation_map.find(instruction->parent());
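For context, a minimal standalone sketch of the user-chain walk that the restored LookaheadUserSharding helper performs. Node, its string-valued sharding, and LookaheadSketch are hypothetical stand-ins for HloInstruction, HloSharding, and the real helper; the backward pass through GetShardingFromUser over the collected chain is omitted here.

#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for HloInstruction: only the pieces the sketch needs.
struct Node {
  std::vector<Node*> users;
  std::optional<std::string> sharding;  // stand-in for HloSharding
};

// Walks forward through a chain of single-user nodes until a node with a
// sharding annotation is found; any fan-out along the way aborts the lookahead.
std::optional<std::string> LookaheadSketch(Node* start) {
  if (start->users.size() != 1) {
    return std::nullopt;
  }
  Node* current = start->users[0];
  while (!current->sharding.has_value()) {
    if (current->users.size() != 1) {
      return std::nullopt;  // only single-user chains are considered
    }
    current = current->users[0];
  }
  return current->sharding;
}
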
9 changes: 5 additions & 4 deletions third_party/xla/xla/service/sharding_propagation.h
@@ -16,7 +16,6 @@ limitations under the License.
#ifndef XLA_SERVICE_SHARDING_PROPAGATION_H_
#define XLA_SERVICE_SHARDING_PROPAGATION_H_

#include <cstdint>
#include <memory>
#include <optional>
#include <utility>
@@ -36,15 +35,17 @@ namespace xla {
// Infers the shardings for a dot HLO op from the shardings on its operands,
// which are expected to have sharding annotations.
bool InferDotShardingFromOperands(
HloInstruction* instruction,
HloInstruction* instruction, const CallGraph& call_graph,
const dot_as_convolution_util::DotConvolutionDimsInfo& dnums,
int64_t aggressiveness, bool may_combine_partial_sharding);
bool may_combine_partial_sharding, bool is_spmd);

// Infers the shardings for a convolution HLO op from the shardings on its
// operands, which are expected to have sharding annotations.
bool InferConvolutionShardingFromOperands(HloInstruction* instruction,
const CallGraph& call_graph,
int64_t aggressiveness,
bool may_combine_partial_sharding);
bool may_combine_partial_sharding,
bool is_spmd);

// Remove Sharding custom-call instruction by folding the sharding attribute
// to its operand. If the operand already has a different sharding, insert a
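As a rough illustration of the operand-priority rule restored in InferDotShardingFromOperands, the sketch below uses hypothetical names (LookaheadVerdict, PickPrimaryOperand) in place of the real HloSharding and IsSubTilingOrEqualSharding machinery: the lookahead decides the primary operand when it cleanly favors one side, otherwise the larger operand by byte size wins.

#include <cstdint>

// Hypothetical inputs: whether each operand's improved sharding is a
// sub-tiling of (or equal to) the sharding found by the user lookahead.
struct LookaheadVerdict {
  bool found = false;
  bool operand0_subtiles_user = false;
  bool operand1_subtiles_user = false;
};

// Returns which operand (0 or 1) supplies the primary sharding candidate.
int PickPrimaryOperand(const LookaheadVerdict& lookahead,
                       int64_t operand0_bytes, int64_t operand1_bytes) {
  if (lookahead.found) {
    if (lookahead.operand0_subtiles_user && !lookahead.operand1_subtiles_user) {
      return 0;  // lookahead clearly favors operand 0's sharding
    }
    if (!lookahead.operand0_subtiles_user && lookahead.operand1_subtiles_user) {
      return 1;  // lookahead clearly favors operand 1's sharding
    }
  }
  // Lookahead was inconclusive: fall back to prioritizing the larger operand.
  return operand0_bytes < operand1_bytes ? 1 : 0;
}
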
