forked from tensorflow/tensorflow
-
Notifications
You must be signed in to change notification settings - Fork 16
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
PR tensorflow#15403: Handle multiple users in all-gather dynamic-slic…
…e simplification. Add AllGatherDynamicSliceSimplifier pass Imported from GitHub PR openxla/xla#15403 I have found in some models that have poor SPMD partitioning the below pattern. ``` all-gather.1 = all-gather(x) dot.1 = dot(all-gather.1, y) dynamic-slice.1 = dynamic-slice(all-gather.1) // can be cancelled ``` In this case, the all-gather has multiple users but the dynamic-slice can be cancelled. This is applicable to all-reduce and reduce-scatter also. My changes now support multiple users, but it also depends how this utility is used by internal TPU compiler and the GPU ReduceScatterCreator pass. My changes assume the cancellation is run like this -- 1. Find a dynamic-slice 2. Check if dynamic-slice can be cancelled 3. Delete dynamic-slice but do not delete the collective 4. The collective is deleted by the DCE pass if it has no users The above workflow then supports removing dynamic-slices even if the collective has multiple users. The above is what we are using in our internal Neuron workflow. Interested to hear thoughts on this. Copybara import of the project: -- f518bd6e3164aa10b60b4689f2aa2ee8d8faa7ae by ptoulme-aws <ptoulme@amazon.com>: Handle multiple users in all-gather dynamic-slice simplification. Add AllGatherDynamicSliceSimplifier pass Merging this change closes tensorflow#15403 PiperOrigin-RevId: 675370754
- Loading branch information
1 parent
6ed0d07
commit 424e3a5
Showing
8 changed files
with
417 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
83 changes: 83 additions & 0 deletions
83
third_party/xla/xla/service/gpu/transforms/all_gather_dynamic_slice_simplifier.cc
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
/* Copyright 2024 The OpenXLA Authors. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
==============================================================================*/ | ||
|
||
#include "xla/service/gpu/transforms/all_gather_dynamic_slice_simplifier.h" | ||
|
||
#include "xla/hlo/ir/hlo_casting_utils.h" | ||
#include "xla/service/collective_opt_utils.h" | ||
|
||
namespace xla { | ||
bool AllGatherDynamicSliceSimplifier::InstructionMatchesPattern( | ||
HloInstruction* instruction) { | ||
if (instruction->opcode() != HloOpcode::kDynamicSlice) { | ||
return false; | ||
} | ||
|
||
HloDynamicSliceInstruction* dynamic_slice = | ||
Cast<HloDynamicSliceInstruction>(instruction); | ||
HloInstruction* operand = dynamic_slice->mutable_operand(0); | ||
|
||
// Check if the operand is a reshape or all-gather instruction | ||
bool is_reshape = operand->opcode() == HloOpcode::kReshape; | ||
bool is_all_gather = operand->opcode() == HloOpcode::kAllGather; | ||
|
||
if (!is_reshape && !is_all_gather) { | ||
return false; | ||
} | ||
|
||
if (is_reshape && operand->operand(0)->opcode() != HloOpcode::kAllGather) { | ||
return false; | ||
} | ||
|
||
const HloModuleConfig& config = instruction->GetModule()->config(); | ||
HloAllGatherInstruction* all_gather = | ||
is_reshape ? Cast<HloAllGatherInstruction>(operand->mutable_operand(0)) | ||
: Cast<HloAllGatherInstruction>(operand); | ||
|
||
bool match = AllGatherDynamicSliceCancellation( | ||
all_gather, config.num_partitions(), config.replica_count(), | ||
/*allow_multiple_split_dims=*/true, | ||
/*allow_intervening_reshape=*/true, /*min_rank=*/1, | ||
HloPredicateIsOp<HloOpcode::kPartitionId>, | ||
HloPredicateIsOp<HloOpcode::kReplicaId>, | ||
/*allow_intervening_bitcast=*/false, | ||
/*allow_multiple_users=*/true); | ||
|
||
return match; | ||
} | ||
|
||
StatusOr<HloInstruction*> AllGatherDynamicSliceSimplifier::ExpandInstruction( | ||
HloInstruction* instruction) { | ||
HloDynamicSliceInstruction* dynamic_slice = | ||
Cast<HloDynamicSliceInstruction>(instruction); | ||
HloInstruction* operand = dynamic_slice->mutable_operand(0); | ||
|
||
if (operand->opcode() != HloOpcode::kReshape) { | ||
// dynamic-slice(all-gather) case | ||
return operand->mutable_operand(0); | ||
} | ||
|
||
// dynamic-slice(reshape(all-gather)) case | ||
HloReshapeInstruction* reshape = Cast<HloReshapeInstruction>(operand); | ||
HloAllGatherInstruction* all_gather = | ||
Cast<HloAllGatherInstruction>(reshape->mutable_operand(0)); | ||
HloInstruction* all_gather_input = all_gather->mutable_operand(0); | ||
|
||
auto* new_reshape = instruction->parent()->AddInstruction( | ||
HloInstruction::CreateReshape(dynamic_slice->shape(), all_gather_input)); | ||
return new_reshape; | ||
} | ||
|
||
} // namespace xla |
48 changes: 48 additions & 0 deletions
48
third_party/xla/xla/service/gpu/transforms/all_gather_dynamic_slice_simplifier.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
/* Copyright 2024 The OpenXLA Authors. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
==============================================================================*/ | ||
|
||
#ifndef XLA_SERVICE_GPU_TRANSFORMS_ALL_GATHER_DYNAMIC_SLICE_SIMPLIFIER_H_ | ||
#define XLA_SERVICE_GPU_TRANSFORMS_ALL_GATHER_DYNAMIC_SLICE_SIMPLIFIER_H_ | ||
|
||
#include "xla/service/op_expander_pass.h" | ||
|
||
namespace xla { | ||
|
||
// A pass that simplifies a dynamic-slice of an all-gather | ||
// whose slice is the same as the original operand of the all-gather. | ||
// As an example: | ||
// | ||
// ag = all-gather(x) replica_groups={{0,1,2,3,4,5,6,7}} | ||
// offset = multiply(partition_id, slice_size) | ||
// ds = dynamic-slice(ag, offset, 0, 0) | ||
// | ||
// Can be simplified to the all-gather operand. | ||
|
||
class AllGatherDynamicSliceSimplifier : public OpExpanderPass { | ||
public: | ||
absl::string_view name() const override { | ||
return "all-gather-dynamic-slice-simplifier"; | ||
} | ||
|
||
protected: | ||
bool InstructionMatchesPattern(HloInstruction* instruction) override; | ||
|
||
StatusOr<HloInstruction*> ExpandInstruction( | ||
HloInstruction* instruction) override; | ||
}; | ||
|
||
} // namespace xla | ||
|
||
#endif // XLA_SERVICE_GPU_TRANSFORMS_ALL_GATHER_DYNAMIC_SLICE_SIMPLIFIER_H_ |
Oops, something went wrong.