Skip to content

Commit

Permalink
Slice an HLO module into multiple sub-modules, with each sub-module s…
Browse files Browse the repository at this point in the history
…liced from a subset of starting HLO instructions. Test case added.

PiperOrigin-RevId: 556825722
  • Loading branch information
tensorflower-gardener committed Aug 14, 2023
1 parent 6d6e73c commit 0b9ade2
Show file tree
Hide file tree
Showing 3 changed files with 183 additions and 86 deletions.
181 changes: 103 additions & 78 deletions tensorflow/compiler/xla/tools/hlo_slicer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -427,93 +427,118 @@ std::vector<std::unique_ptr<HloModule>> SliceModuleAndExtract(
const HloModule* hlo_module,
absl::Span<const HloInstruction*> slice_starting_instructions,
const SlicingConfiguration& slicing_configuration) {
// Forward slicing.
SliceOutput forward_slice_output;
if (slicing_configuration.forward_slicing ==
SlicingConfiguration::ForwardSlicingConfig::kRoot) {
// Slice to the root instruction of the entry computation of `hlo_module`.
forward_slice_output = SliceModule(
hlo_module, slice_starting_instructions, /*frontier_selector=*/nullptr,
/*ignore_control_dependency=*/false, /*forward_slice=*/true,
/*nearest_common_ancestor_as_root=*/false);
} else if (slicing_configuration.forward_slicing ==
SlicingConfiguration::ForwardSlicingConfig::kNca) {
// slice to the nearest common ancestors of `slice_starting_instructions`
forward_slice_output = SliceModule(
hlo_module, slice_starting_instructions, /*frontier_selector=*/nullptr,
/*ignore_control_dependency=*/false, /*forward_slice=*/true,
/*nearest_common_ancestor_as_root=*/true);
}
VLOG(1) << "[Num of forward sliced insts]: "
<< forward_slice_output.NumSlicedInstructions();

// Backward slicing.
SliceOutput backward_slice_output;
if (slicing_configuration.backward_slicing) {
backward_slice_output = SliceModule(
hlo_module, slice_starting_instructions, /*frontier_selector=*/nullptr,
/*ignore_control_dependency=*/false, /*forward_slice=*/false);
std::vector<std::unique_ptr<HloModule>> sliced_modules;

// Group `slice_starting_instructions` based on `slicing_group` configuration.
int slicing_group = slicing_configuration.slicing_group;
CHECK(slicing_group >= 1 || slicing_group == -1);
std::vector<absl::Span<const HloInstruction*>> grouped_instructions;
if (slicing_group == -1) {
grouped_instructions = {slice_starting_instructions};
} else {
// Return the empty SliceOutput if backward slicing is not enabled.
backward_slice_output = SliceOutput();
for (int i = 0; i < slice_starting_instructions.size();
i += slicing_group) {
// subspan can correctly handel the last group, which may be smaller than
// `slicing_group`.
grouped_instructions.push_back(
slice_starting_instructions.subspan(i, slicing_group));
}
}

// Combine forward slicing output and backward slicing output.
auto sliced_result = SliceOutput(SliceOutput::UnionSlicedInstructions(
forward_slice_output, backward_slice_output));

// Decide Root to start extraction based on `forward_slicing_config`.
const HloInstruction* extraction_root =
slicing_configuration.forward_slicing ==
SlicingConfiguration::ForwardSlicingConfig::kNca
? forward_slice_output.nearest_common_ancestor_root()
: hlo_module->entry_computation()->root_instruction();
VLOG(1) << "[Root instruction of the sliced module]: "
<< extraction_root->ToString();

// Exclude the instructions that are not in the slicing results.
auto extract_selector = [&sliced_result](const HloInstruction* hlo_inst) {
for (const auto& [computation, instructions] :
sliced_result.sliced_instructions()) {
if (instructions.contains(hlo_inst)) {
return true;
for (const auto& grouped_slice_starting_instructions : grouped_instructions) {
// Forward slicing.
SliceOutput forward_slice_output;
if (slicing_configuration.forward_slicing ==
SlicingConfiguration::ForwardSlicingConfig::kRoot) {
// Slice to the root instruction of the entry computation of `hlo_module`.
forward_slice_output = SliceModule(
hlo_module, grouped_slice_starting_instructions,
/*frontier_selector=*/nullptr,
/*ignore_control_dependency=*/false, /*forward_slice=*/true,
/*nearest_common_ancestor_as_root=*/false);
} else if (slicing_configuration.forward_slicing ==
SlicingConfiguration::ForwardSlicingConfig::kNca) {
// slice to the nearest common ancestors of
// `grouped_slice_starting_instructions`
forward_slice_output = SliceModule(
hlo_module, grouped_slice_starting_instructions,
/*frontier_selector=*/nullptr,
/*ignore_control_dependency=*/false, /*forward_slice=*/true,
/*nearest_common_ancestor_as_root=*/true);
}
VLOG(1) << "[Num of forward sliced insts]: "
<< forward_slice_output.NumSlicedInstructions();

// Backward slicing.
SliceOutput backward_slice_output;
if (slicing_configuration.backward_slicing) {
backward_slice_output = SliceModule(
hlo_module, grouped_slice_starting_instructions,
/*frontier_selector=*/nullptr,
/*ignore_control_dependency=*/false, /*forward_slice=*/false);
} else {
// Return the empty SliceOutput if backward slicing is not enabled.
backward_slice_output = SliceOutput();
}

// Combine forward slicing output and backward slicing output.
auto sliced_result = SliceOutput(SliceOutput::UnionSlicedInstructions(
forward_slice_output, backward_slice_output));

// Decide Root to start extraction based on `forward_slicing_config`.
const HloInstruction* extraction_root =
slicing_configuration.forward_slicing ==
SlicingConfiguration::ForwardSlicingConfig::kNca
? forward_slice_output.nearest_common_ancestor_root()
: hlo_module->entry_computation()->root_instruction();
VLOG(1) << "[Root instruction of the sliced module]: "
<< extraction_root->ToString();

// Exclude the instructions that are not in the slicing results.
auto extract_selector = [&sliced_result](const HloInstruction* hlo_inst) {
for (const auto& [computation, instructions] :
sliced_result.sliced_instructions()) {
if (instructions.contains(hlo_inst)) {
return true;
}
}
return false;
};

// Replace the excluded instructions in the entry computation with zeros.
auto replace_type_selector =
[](const HloInstruction* hlo_inst) -> ReplaceType {
return ReplaceType::kReplaceZeroBroadcast;
};

// Extract from the original module.
auto extracted_module =
ExtractModule(/*instruction=*/extraction_root, /*height=*/-1,
/*extract_selector=*/extract_selector,
/*replace_type_selector=*/replace_type_selector,
/*cross_computation=*/true);

// Remove the custom-call to sharding if `remove_sharding` is specified.
if (slicing_configuration.remove_sharding) {
RemoveSharding(extracted_module.get());
}
return false;
};

// Replace the excluded instructions in the entry computation with zeros.
auto replace_type_selector =
[](const HloInstruction* hlo_inst) -> ReplaceType {
return ReplaceType::kReplaceZeroBroadcast;
};

// Extract from the original module.
auto extracted_module =
ExtractModule(/*instruction=*/extraction_root, /*height=*/-1,
/*extract_selector=*/extract_selector,
/*replace_type_selector=*/replace_type_selector,
/*cross_computation=*/true);

// Remove the custom-call to sharding if `remove_sharding` is specified.
if (slicing_configuration.remove_sharding) {
RemoveSharding(extracted_module.get());
}

// Reduce the parameter instructions of tuple shape if
// `reduce_tuple_parameter` is specified.
if (slicing_configuration.reduce_tuple_parameter) {
ReduceTupleParameter(extracted_module.get());
}
// Reduce the parameter instructions of tuple shape if
// `reduce_tuple_parameter` is specified.
if (slicing_configuration.reduce_tuple_parameter) {
ReduceTupleParameter(extracted_module.get());
}

// Verify if the extracted module (after processing) is valid or not.
HloVerifier verifier(/*layout_sensitive=*/false,
/*allow_mixed_precision=*/true);
TF_CHECK_OK(verifier.Run(extracted_module.get()).status());
// Verify if the extracted module (after processing) is valid or not.
HloVerifier verifier(/*layout_sensitive=*/false,
/*allow_mixed_precision=*/true);
TF_CHECK_OK(verifier.Run(extracted_module.get()).status());

sliced_modules.emplace_back(std::move(extracted_module));
}

// Return all the sliced modules.
std::vector<std::unique_ptr<HloModule>> sliced_modules;
sliced_modules.emplace_back(std::move(extracted_module));
CHECK_EQ(sliced_modules.size(), grouped_instructions.size());
return sliced_modules;
}

Expand Down
15 changes: 15 additions & 0 deletions tensorflow/compiler/xla/tools/hlo_slicer.h
Original file line number Diff line number Diff line change
Expand Up @@ -215,12 +215,27 @@ SliceOutput SliceModule(
// each parameters of entry computation, if it is of tuple type, we will remove
// the elements that are not used by any other instructions. This is useful when
// slicing from a large module.
//
// `slicing_group`: `SliceModuleAndExtract` groups
// `slicing_starting_instructions` into multiple non-overlapping groups, and
// for each group of `slicing_starting_instructions`, slice/extract an HLO
// module. The `slicing_group` specifies the number of
// `slicing_starting_instructions` each group contains. For example,
// say `slicing_start_instructions` = {a, b, c ,d}. If `slicing_group` = 1,
// there would be 4 sliced/extracted HLO modules, sliced from {a}, {b}, {c},
// {d}, respectively. If `slicing_group` = 2, there would be 2 sliced/extracted
// HLO modules, sliced from {a, b}, {c, d}, respectively. The
// `slicing_starting_instructions` are grouped accoding to order in the
// absl::Span. When `slicing_group` = -1, there would be only one group which
// contains all the `slice_starting_instructions`, so there would be only 1
// sliced/extracted module. `slicing_group` can only be -1 or positive integer.
struct SlicingConfiguration {
enum class ForwardSlicingConfig { kRoot, kNca };
ForwardSlicingConfig forward_slicing = ForwardSlicingConfig::kRoot;
bool backward_slicing = false;
bool remove_sharding = false;
bool reduce_tuple_parameter = false;
int slicing_group = -1;
};

// Slices from the `hlo_module` from the `slicing_starting_instructions`,
Expand Down
73 changes: 65 additions & 8 deletions tensorflow/compiler/xla/tools/hlo_slicer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1014,19 +1014,19 @@ TEST_F(HloSlicerTest, TestSliceModuleAndExtractRemoveSharding) {
/*slice_starting_instructions=*/
absl::MakeSpan(relevant_instructions),
/*slicing_configuration=*/slicing_config);
CHECK_EQ(sliced_modules.size(), 1);
EXPECT_EQ(sliced_modules.size(), 1);
auto sliced_module = std::move(sliced_modules[0]);

// Test if the custom-call to sharding is removed.
for (HloInstruction* instruction :
sliced_module->entry_computation()->instructions()) {
CHECK_NE(instruction->opcode(), HloOpcode::kCustomCall);
EXPECT_NE(instruction->opcode(), HloOpcode::kCustomCall);
}

// Check that both the operands of %add.39786 are %multiply.39766.
for (HloInstruction* instruction :
sliced_module->entry_computation()->root_instruction()->operands()) {
CHECK_EQ(instruction->name(), "multiply.39766");
EXPECT_EQ(instruction->name(), "multiply.39766");
}
}
}
Expand Down Expand Up @@ -1066,18 +1066,75 @@ TEST_F(HloSlicerTest, TestSliceModuleAndExtractReduceTupleParameter) {
/*slice_starting_instructions=*/
absl::MakeSpan(relevant_instructions),
/*slicing_configuration=*/slicing_config);
CHECK_EQ(sliced_modules.size(), 1);
EXPECT_EQ(sliced_modules.size(), 1);
auto sliced_module = std::move(sliced_modules[0]);

// Check that the new p.0 only has one element.
HloInstruction* p_0 = FindInstruction(sliced_module.get(), "p.0");
CHECK_NE(p_0, nullptr);
CHECK_EQ(p_0->shape().tuple_shapes_size(), 1);
EXPECT_NE(p_0, nullptr);
EXPECT_EQ(p_0->shape().tuple_shapes_size(), 1);

// Check that the new p.1 only has one element.
HloInstruction* p_1 = FindInstruction(sliced_module.get(), "p.1");
CHECK_NE(p_1, nullptr);
CHECK_EQ(p_1->shape().tuple_shapes_size(), 1);
EXPECT_NE(p_1, nullptr);
EXPECT_EQ(p_1->shape().tuple_shapes_size(), 1);
}
}

TEST_F(HloSlicerTest, TestSliceModuleAndExtractSlicingGroup) {
const std::string& hlo_string = R"(
HloModule axpy_module
ENTRY axpy_computation (p.0: (s32[], s32[3]{0}), p.1: (s32[3]{0}, s32[])) -> s32[] {
p.0 = (s32[], s32[3]{0}) parameter(0)
gte.0 = s32[] get-tuple-element(p.0), index=0
p.1 = (s32[3]{0}, s32[]) parameter(1)
gte.1 = s32[] get-tuple-element(p.1), index=1
ROOT add.0 = s32[] add(gte.0, gte.1)
}
)";

TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
ParseAndReturnVerifiedModule(hlo_string));

HloInstruction* gte_0 = FindInstruction(hlo_module.get(), "gte.0");
CHECK_NE(gte_0, nullptr);
HloInstruction* gte_1 = FindInstruction(hlo_module.get(), "gte.1");
CHECK_NE(gte_1, nullptr);

// slice_starting_instructions: {gte.0, gte.1}.
// forward_slicing: kNca.
// backward_slicing: true.
// remove_sharding: false.
// reduce_tuple_parameter: false.
// slicing_group: 1
{
// Generate two sliced modules, sliced from gte.0 and gte.1, respectively
// (`slicing_group` = 1).
std::vector<const HloInstruction*> relevant_instructions({gte_0, gte_1});
SlicingConfiguration slicing_config = {
/*forward_slicing=*/SlicingConfiguration::ForwardSlicingConfig::kNca,
/*backward_slicing=*/true, /*remove_sharding=*/false,
/*reduce_tuple_parameter=*/false, /*slicing_group=*/1};
std::vector<std::unique_ptr<HloModule>> sliced_modules =
SliceModuleAndExtract(hlo_module.get(),
/*slice_starting_instructions=*/
absl::MakeSpan(relevant_instructions),
/*slicing_configuration=*/slicing_config);

// There are two sliced module.
EXPECT_EQ(sliced_modules.size(), 2);

// The first sliced module contains gte.0 and p.0.
auto sliced_module_0 = std::move(sliced_modules[0]);
EXPECT_EQ(sliced_module_0->entry_computation()->instruction_count(), 2);
HloInstruction* p_0 = FindInstruction(sliced_module_0.get(), "p.0");
EXPECT_NE(p_0, nullptr);

// The second sliced module contains gte.1 and p.1.
auto sliced_module_1 = std::move(sliced_modules[1]);
EXPECT_EQ(sliced_module_0->entry_computation()->instruction_count(), 2);
HloInstruction* p_1 = FindInstruction(sliced_module_1.get(), "p.1");
EXPECT_NE(p_1, nullptr);
}
}

Expand Down

0 comments on commit 0b9ade2

Please sign in to comment.