diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index 2aee4c3a9ec378..b32c8fb37c070f 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -3209,6 +3209,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/ir:hlo_reachability", "@com_google_absl//absl/base", + "@com_google_absl//absl/container:flat_hash_map", ], ) diff --git a/third_party/xla/xla/service/while_loop_analysis.cc b/third_party/xla/xla/service/while_loop_analysis.cc index 7a881be0474fe8..dae16d93928085 100644 --- a/third_party/xla/xla/service/while_loop_analysis.cc +++ b/third_party/xla/xla/service/while_loop_analysis.cc @@ -16,6 +16,7 @@ limitations under the License. #include "xla/service/while_loop_analysis.h" #include "absl/base/casts.h" +#include "absl/container/flat_hash_map.h" #include "xla/hlo/evaluator/hlo_evaluator.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" @@ -614,15 +615,29 @@ optional ComputeWhileLoopTripCountUpperBound( << while_body_indvar->ToString(); return nullopt; } + // Create a new while cond computation accessing only the single parameter + // extracted by the GTE above to avoid excessive memory allocation for the + // evaluator. + absl::flat_hash_map> + replacements; + auto new_param = HloInstruction::CreateParameter( + 0, ShapeUtil::MakeTupleShape({cond_gte->shape()}), "temp"); + replacements[cond_gte] = + HloInstruction::CreateGetTupleElement(new_param.get(), 0); + replacements[while_cond_param] = std::move(new_param); + auto new_module = std::make_unique("temp_mod", HloModuleConfig{}); + auto* new_computation = new_module->AddEmbeddedComputation( + while_cond->CloneWithReplacements(&replacements)); // We have a constant. Evaluate the condition on this constant. HloEvaluator evaluator(/*max_loop_iterations=*/0); - Literal fake_input = Literal::CreateFromShape(while_cond_param->shape()); + Literal fake_input = Literal::CreateFromShape( + new_computation->parameter_instruction(0)->shape()); TF_CHECK_OK(fake_input.CopyFrom(while_body_indvar->literal(), - /*dest_shape_index=*/{indvar_index}, + /*dest_shape_index=*/{0}, /*src_shape_index=*/{})); StatusOr eval_result = - evaluator.Evaluate(*while_cond, {std::move(fake_input)}); + evaluator.Evaluate(*new_computation, {std::move(fake_input)}); if (!eval_result.ok()) { VLOG(2) << "Couldn't evaluate while loop condition.";