From 2635d85c1cac2562197f5aedfeba4274c0454e9c Mon Sep 17 00:00:00 2001 From: Marcello Maggioni Date: Fri, 22 Sep 2023 00:36:41 -0700 Subject: [PATCH] [XLA] Fix big memory allocation and compile time slowdown for WhileLoopInvariantCodeMotion. The pass allocates in some cases the input for the while-cond computation, but it really only needs one of the tuple elements of the parameter, not the whole tuple. Allocating the whole tuple (because the inputs to the loop can be big) can result in very large memory allocations an significant time spent on memory allocation/initialization. PiperOrigin-RevId: 567543997 --- third_party/xla/xla/service/BUILD | 1 + .../xla/xla/service/while_loop_analysis.cc | 21 ++++++++++++++++--- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index 2aee4c3a9ec378..b32c8fb37c070f 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -3209,6 +3209,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/ir:hlo_reachability", "@com_google_absl//absl/base", + "@com_google_absl//absl/container:flat_hash_map", ], ) diff --git a/third_party/xla/xla/service/while_loop_analysis.cc b/third_party/xla/xla/service/while_loop_analysis.cc index 7a881be0474fe8..dae16d93928085 100644 --- a/third_party/xla/xla/service/while_loop_analysis.cc +++ b/third_party/xla/xla/service/while_loop_analysis.cc @@ -16,6 +16,7 @@ limitations under the License. #include "xla/service/while_loop_analysis.h" #include "absl/base/casts.h" +#include "absl/container/flat_hash_map.h" #include "xla/hlo/evaluator/hlo_evaluator.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" @@ -614,15 +615,29 @@ optional ComputeWhileLoopTripCountUpperBound( << while_body_indvar->ToString(); return nullopt; } + // Create a new while cond computation accessing only the single parameter + // extracted by the GTE above to avoid excessive memory allocation for the + // evaluator. + absl::flat_hash_map> + replacements; + auto new_param = HloInstruction::CreateParameter( + 0, ShapeUtil::MakeTupleShape({cond_gte->shape()}), "temp"); + replacements[cond_gte] = + HloInstruction::CreateGetTupleElement(new_param.get(), 0); + replacements[while_cond_param] = std::move(new_param); + auto new_module = std::make_unique("temp_mod", HloModuleConfig{}); + auto* new_computation = new_module->AddEmbeddedComputation( + while_cond->CloneWithReplacements(&replacements)); // We have a constant. Evaluate the condition on this constant. HloEvaluator evaluator(/*max_loop_iterations=*/0); - Literal fake_input = Literal::CreateFromShape(while_cond_param->shape()); + Literal fake_input = Literal::CreateFromShape( + new_computation->parameter_instruction(0)->shape()); TF_CHECK_OK(fake_input.CopyFrom(while_body_indvar->literal(), - /*dest_shape_index=*/{indvar_index}, + /*dest_shape_index=*/{0}, /*src_shape_index=*/{})); StatusOr eval_result = - evaluator.Evaluate(*while_cond, {std::move(fake_input)}); + evaluator.Evaluate(*new_computation, {std::move(fake_input)}); if (!eval_result.ok()) { VLOG(2) << "Couldn't evaluate while loop condition.";