diff --git a/tensorflow/compiler/xla/tools/hlo_slicer.cc b/tensorflow/compiler/xla/tools/hlo_slicer.cc index 3697dd27863773..b287a3c7e73f70 100644 --- a/tensorflow/compiler/xla/tools/hlo_slicer.cc +++ b/tensorflow/compiler/xla/tools/hlo_slicer.cc @@ -427,93 +427,118 @@ std::vector> SliceModuleAndExtract( const HloModule* hlo_module, absl::Span slice_starting_instructions, const SlicingConfiguration& slicing_configuration) { - // Forward slicing. - SliceOutput forward_slice_output; - if (slicing_configuration.forward_slicing == - SlicingConfiguration::ForwardSlicingConfig::kRoot) { - // Slice to the root instruction of the entry computation of `hlo_module`. - forward_slice_output = SliceModule( - hlo_module, slice_starting_instructions, /*frontier_selector=*/nullptr, - /*ignore_control_dependency=*/false, /*forward_slice=*/true, - /*nearest_common_ancestor_as_root=*/false); - } else if (slicing_configuration.forward_slicing == - SlicingConfiguration::ForwardSlicingConfig::kNca) { - // slice to the nearest common ancestors of `slice_starting_instructions` - forward_slice_output = SliceModule( - hlo_module, slice_starting_instructions, /*frontier_selector=*/nullptr, - /*ignore_control_dependency=*/false, /*forward_slice=*/true, - /*nearest_common_ancestor_as_root=*/true); - } - VLOG(1) << "[Num of forward sliced insts]: " - << forward_slice_output.NumSlicedInstructions(); - - // Backward slicing. - SliceOutput backward_slice_output; - if (slicing_configuration.backward_slicing) { - backward_slice_output = SliceModule( - hlo_module, slice_starting_instructions, /*frontier_selector=*/nullptr, - /*ignore_control_dependency=*/false, /*forward_slice=*/false); + std::vector> sliced_modules; + + // Group `slice_starting_instructions` based on `slicing_group` configuration. + int slicing_group = slicing_configuration.slicing_group; + CHECK(slicing_group >= 1 || slicing_group == -1); + std::vector> grouped_instructions; + if (slicing_group == -1) { + grouped_instructions = {slice_starting_instructions}; } else { - // Return the empty SliceOutput if backward slicing is not enabled. - backward_slice_output = SliceOutput(); + for (int i = 0; i < slice_starting_instructions.size(); + i += slicing_group) { + // subspan can correctly handel the last group, which may be smaller than + // `slicing_group`. + grouped_instructions.push_back( + slice_starting_instructions.subspan(i, slicing_group)); + } } - // Combine forward slicing output and backward slicing output. - auto sliced_result = SliceOutput(SliceOutput::UnionSlicedInstructions( - forward_slice_output, backward_slice_output)); - - // Decide Root to start extraction based on `forward_slicing_config`. - const HloInstruction* extraction_root = - slicing_configuration.forward_slicing == - SlicingConfiguration::ForwardSlicingConfig::kNca - ? forward_slice_output.nearest_common_ancestor_root() - : hlo_module->entry_computation()->root_instruction(); - VLOG(1) << "[Root instruction of the sliced module]: " - << extraction_root->ToString(); - - // Exclude the instructions that are not in the slicing results. - auto extract_selector = [&sliced_result](const HloInstruction* hlo_inst) { - for (const auto& [computation, instructions] : - sliced_result.sliced_instructions()) { - if (instructions.contains(hlo_inst)) { - return true; + for (const auto& grouped_slice_starting_instructions : grouped_instructions) { + // Forward slicing. + SliceOutput forward_slice_output; + if (slicing_configuration.forward_slicing == + SlicingConfiguration::ForwardSlicingConfig::kRoot) { + // Slice to the root instruction of the entry computation of `hlo_module`. + forward_slice_output = SliceModule( + hlo_module, grouped_slice_starting_instructions, + /*frontier_selector=*/nullptr, + /*ignore_control_dependency=*/false, /*forward_slice=*/true, + /*nearest_common_ancestor_as_root=*/false); + } else if (slicing_configuration.forward_slicing == + SlicingConfiguration::ForwardSlicingConfig::kNca) { + // slice to the nearest common ancestors of + // `grouped_slice_starting_instructions` + forward_slice_output = SliceModule( + hlo_module, grouped_slice_starting_instructions, + /*frontier_selector=*/nullptr, + /*ignore_control_dependency=*/false, /*forward_slice=*/true, + /*nearest_common_ancestor_as_root=*/true); + } + VLOG(1) << "[Num of forward sliced insts]: " + << forward_slice_output.NumSlicedInstructions(); + + // Backward slicing. + SliceOutput backward_slice_output; + if (slicing_configuration.backward_slicing) { + backward_slice_output = SliceModule( + hlo_module, grouped_slice_starting_instructions, + /*frontier_selector=*/nullptr, + /*ignore_control_dependency=*/false, /*forward_slice=*/false); + } else { + // Return the empty SliceOutput if backward slicing is not enabled. + backward_slice_output = SliceOutput(); + } + + // Combine forward slicing output and backward slicing output. + auto sliced_result = SliceOutput(SliceOutput::UnionSlicedInstructions( + forward_slice_output, backward_slice_output)); + + // Decide Root to start extraction based on `forward_slicing_config`. + const HloInstruction* extraction_root = + slicing_configuration.forward_slicing == + SlicingConfiguration::ForwardSlicingConfig::kNca + ? forward_slice_output.nearest_common_ancestor_root() + : hlo_module->entry_computation()->root_instruction(); + VLOG(1) << "[Root instruction of the sliced module]: " + << extraction_root->ToString(); + + // Exclude the instructions that are not in the slicing results. + auto extract_selector = [&sliced_result](const HloInstruction* hlo_inst) { + for (const auto& [computation, instructions] : + sliced_result.sliced_instructions()) { + if (instructions.contains(hlo_inst)) { + return true; + } } + return false; + }; + + // Replace the excluded instructions in the entry computation with zeros. + auto replace_type_selector = + [](const HloInstruction* hlo_inst) -> ReplaceType { + return ReplaceType::kReplaceZeroBroadcast; + }; + + // Extract from the original module. + auto extracted_module = + ExtractModule(/*instruction=*/extraction_root, /*height=*/-1, + /*extract_selector=*/extract_selector, + /*replace_type_selector=*/replace_type_selector, + /*cross_computation=*/true); + + // Remove the custom-call to sharding if `remove_sharding` is specified. + if (slicing_configuration.remove_sharding) { + RemoveSharding(extracted_module.get()); } - return false; - }; - - // Replace the excluded instructions in the entry computation with zeros. - auto replace_type_selector = - [](const HloInstruction* hlo_inst) -> ReplaceType { - return ReplaceType::kReplaceZeroBroadcast; - }; - - // Extract from the original module. - auto extracted_module = - ExtractModule(/*instruction=*/extraction_root, /*height=*/-1, - /*extract_selector=*/extract_selector, - /*replace_type_selector=*/replace_type_selector, - /*cross_computation=*/true); - - // Remove the custom-call to sharding if `remove_sharding` is specified. - if (slicing_configuration.remove_sharding) { - RemoveSharding(extracted_module.get()); - } - // Reduce the parameter instructions of tuple shape if - // `reduce_tuple_parameter` is specified. - if (slicing_configuration.reduce_tuple_parameter) { - ReduceTupleParameter(extracted_module.get()); - } + // Reduce the parameter instructions of tuple shape if + // `reduce_tuple_parameter` is specified. + if (slicing_configuration.reduce_tuple_parameter) { + ReduceTupleParameter(extracted_module.get()); + } - // Verify if the extracted module (after processing) is valid or not. - HloVerifier verifier(/*layout_sensitive=*/false, - /*allow_mixed_precision=*/true); - TF_CHECK_OK(verifier.Run(extracted_module.get()).status()); + // Verify if the extracted module (after processing) is valid or not. + HloVerifier verifier(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/true); + TF_CHECK_OK(verifier.Run(extracted_module.get()).status()); + + sliced_modules.emplace_back(std::move(extracted_module)); + } // Return all the sliced modules. - std::vector> sliced_modules; - sliced_modules.emplace_back(std::move(extracted_module)); + CHECK_EQ(sliced_modules.size(), grouped_instructions.size()); return sliced_modules; } diff --git a/tensorflow/compiler/xla/tools/hlo_slicer.h b/tensorflow/compiler/xla/tools/hlo_slicer.h index 89e3f67eb8ebd2..e2e6cc9d1e5d86 100644 --- a/tensorflow/compiler/xla/tools/hlo_slicer.h +++ b/tensorflow/compiler/xla/tools/hlo_slicer.h @@ -215,12 +215,27 @@ SliceOutput SliceModule( // each parameters of entry computation, if it is of tuple type, we will remove // the elements that are not used by any other instructions. This is useful when // slicing from a large module. +// +// `slicing_group`: `SliceModuleAndExtract` groups +// `slicing_starting_instructions` into multiple non-overlapping groups, and +// for each group of `slicing_starting_instructions`, slice/extract an HLO +// module. The `slicing_group` specifies the number of +// `slicing_starting_instructions` each group contains. For example, +// say `slicing_start_instructions` = {a, b, c ,d}. If `slicing_group` = 1, +// there would be 4 sliced/extracted HLO modules, sliced from {a}, {b}, {c}, +// {d}, respectively. If `slicing_group` = 2, there would be 2 sliced/extracted +// HLO modules, sliced from {a, b}, {c, d}, respectively. The +// `slicing_starting_instructions` are grouped accoding to order in the +// absl::Span. When `slicing_group` = -1, there would be only one group which +// contains all the `slice_starting_instructions`, so there would be only 1 +// sliced/extracted module. `slicing_group` can only be -1 or positive integer. struct SlicingConfiguration { enum class ForwardSlicingConfig { kRoot, kNca }; ForwardSlicingConfig forward_slicing = ForwardSlicingConfig::kRoot; bool backward_slicing = false; bool remove_sharding = false; bool reduce_tuple_parameter = false; + int slicing_group = -1; }; // Slices from the `hlo_module` from the `slicing_starting_instructions`, diff --git a/tensorflow/compiler/xla/tools/hlo_slicer_test.cc b/tensorflow/compiler/xla/tools/hlo_slicer_test.cc index f120ef395a6c21..2d99da1f20b53e 100644 --- a/tensorflow/compiler/xla/tools/hlo_slicer_test.cc +++ b/tensorflow/compiler/xla/tools/hlo_slicer_test.cc @@ -1014,19 +1014,19 @@ TEST_F(HloSlicerTest, TestSliceModuleAndExtractRemoveSharding) { /*slice_starting_instructions=*/ absl::MakeSpan(relevant_instructions), /*slicing_configuration=*/slicing_config); - CHECK_EQ(sliced_modules.size(), 1); + EXPECT_EQ(sliced_modules.size(), 1); auto sliced_module = std::move(sliced_modules[0]); // Test if the custom-call to sharding is removed. for (HloInstruction* instruction : sliced_module->entry_computation()->instructions()) { - CHECK_NE(instruction->opcode(), HloOpcode::kCustomCall); + EXPECT_NE(instruction->opcode(), HloOpcode::kCustomCall); } // Check that both the operands of %add.39786 are %multiply.39766. for (HloInstruction* instruction : sliced_module->entry_computation()->root_instruction()->operands()) { - CHECK_EQ(instruction->name(), "multiply.39766"); + EXPECT_EQ(instruction->name(), "multiply.39766"); } } } @@ -1066,18 +1066,75 @@ TEST_F(HloSlicerTest, TestSliceModuleAndExtractReduceTupleParameter) { /*slice_starting_instructions=*/ absl::MakeSpan(relevant_instructions), /*slicing_configuration=*/slicing_config); - CHECK_EQ(sliced_modules.size(), 1); + EXPECT_EQ(sliced_modules.size(), 1); auto sliced_module = std::move(sliced_modules[0]); // Check that the new p.0 only has one element. HloInstruction* p_0 = FindInstruction(sliced_module.get(), "p.0"); - CHECK_NE(p_0, nullptr); - CHECK_EQ(p_0->shape().tuple_shapes_size(), 1); + EXPECT_NE(p_0, nullptr); + EXPECT_EQ(p_0->shape().tuple_shapes_size(), 1); // Check that the new p.1 only has one element. HloInstruction* p_1 = FindInstruction(sliced_module.get(), "p.1"); - CHECK_NE(p_1, nullptr); - CHECK_EQ(p_1->shape().tuple_shapes_size(), 1); + EXPECT_NE(p_1, nullptr); + EXPECT_EQ(p_1->shape().tuple_shapes_size(), 1); + } +} + +TEST_F(HloSlicerTest, TestSliceModuleAndExtractSlicingGroup) { + const std::string& hlo_string = R"( + HloModule axpy_module + ENTRY axpy_computation (p.0: (s32[], s32[3]{0}), p.1: (s32[3]{0}, s32[])) -> s32[] { + p.0 = (s32[], s32[3]{0}) parameter(0) + gte.0 = s32[] get-tuple-element(p.0), index=0 + p.1 = (s32[3]{0}, s32[]) parameter(1) + gte.1 = s32[] get-tuple-element(p.1), index=1 + ROOT add.0 = s32[] add(gte.0, gte.1) + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, + ParseAndReturnVerifiedModule(hlo_string)); + + HloInstruction* gte_0 = FindInstruction(hlo_module.get(), "gte.0"); + CHECK_NE(gte_0, nullptr); + HloInstruction* gte_1 = FindInstruction(hlo_module.get(), "gte.1"); + CHECK_NE(gte_1, nullptr); + + // slice_starting_instructions: {gte.0, gte.1}. + // forward_slicing: kNca. + // backward_slicing: true. + // remove_sharding: false. + // reduce_tuple_parameter: false. + // slicing_group: 1 + { + // Generate two sliced modules, sliced from gte.0 and gte.1, respectively + // (`slicing_group` = 1). + std::vector relevant_instructions({gte_0, gte_1}); + SlicingConfiguration slicing_config = { + /*forward_slicing=*/SlicingConfiguration::ForwardSlicingConfig::kNca, + /*backward_slicing=*/true, /*remove_sharding=*/false, + /*reduce_tuple_parameter=*/false, /*slicing_group=*/1}; + std::vector> sliced_modules = + SliceModuleAndExtract(hlo_module.get(), + /*slice_starting_instructions=*/ + absl::MakeSpan(relevant_instructions), + /*slicing_configuration=*/slicing_config); + + // There are two sliced module. + EXPECT_EQ(sliced_modules.size(), 2); + + // The first sliced module contains gte.0 and p.0. + auto sliced_module_0 = std::move(sliced_modules[0]); + EXPECT_EQ(sliced_module_0->entry_computation()->instruction_count(), 2); + HloInstruction* p_0 = FindInstruction(sliced_module_0.get(), "p.0"); + EXPECT_NE(p_0, nullptr); + + // The second sliced module contains gte.1 and p.1. + auto sliced_module_1 = std::move(sliced_modules[1]); + EXPECT_EQ(sliced_module_0->entry_computation()->instruction_count(), 2); + HloInstruction* p_1 = FindInstruction(sliced_module_1.get(), "p.1"); + EXPECT_NE(p_1, nullptr); } }