[XLA] Fix big memory allocation and compile time slowdown for WhileLoopInvariantCodeMotion.

In some cases the pass allocates the entire input of the while-cond computation, even though it
needs only one tuple element of the parameter, not the whole tuple.

Because the inputs to the loop can be big, allocating the whole tuple can result in very large
memory allocations and significant time spent on memory allocation/initialization.

PiperOrigin-RevId: 567543997
Marcello Maggioni authored and tensorflower-gardener committed Sep 22, 2023
1 parent 41154a6 commit 2635d85
Showing 2 changed files with 19 additions and 3 deletions.
1 change: 1 addition & 0 deletions third_party/xla/xla/service/BUILD
@@ -3209,6 +3209,7 @@ cc_library(
         "//xla/hlo/ir:hlo",
         "//xla/hlo/ir:hlo_reachability",
         "@com_google_absl//absl/base",
+        "@com_google_absl//absl/container:flat_hash_map",
     ],
 )

21 changes: 18 additions & 3 deletions third_party/xla/xla/service/while_loop_analysis.cc
@@ -16,6 +16,7 @@ limitations under the License.
 #include "xla/service/while_loop_analysis.h"

 #include "absl/base/casts.h"
+#include "absl/container/flat_hash_map.h"
 #include "xla/hlo/evaluator/hlo_evaluator.h"
 #include "xla/hlo/ir/hlo_instruction.h"
 #include "xla/hlo/ir/hlo_module.h"
@@ -614,15 +615,29 @@ optional<int64_t> ComputeWhileLoopTripCountUpperBound(
             << while_body_indvar->ToString();
     return nullopt;
   }
+  // Create a new while cond computation accessing only the single parameter
+  // extracted by the GTE above to avoid excessive memory allocation for the
+  // evaluator.
+  absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
+      replacements;
+  auto new_param = HloInstruction::CreateParameter(
+      0, ShapeUtil::MakeTupleShape({cond_gte->shape()}), "temp");
+  replacements[cond_gte] =
+      HloInstruction::CreateGetTupleElement(new_param.get(), 0);
+  replacements[while_cond_param] = std::move(new_param);
+  auto new_module = std::make_unique<HloModule>("temp_mod", HloModuleConfig{});
+  auto* new_computation = new_module->AddEmbeddedComputation(
+      while_cond->CloneWithReplacements(&replacements));

   // We have a constant. Evaluate the condition on this constant.
   HloEvaluator evaluator(/*max_loop_iterations=*/0);
-  Literal fake_input = Literal::CreateFromShape(while_cond_param->shape());
+  Literal fake_input = Literal::CreateFromShape(
+      new_computation->parameter_instruction(0)->shape());
   TF_CHECK_OK(fake_input.CopyFrom(while_body_indvar->literal(),
-                                  /*dest_shape_index=*/{indvar_index},
+                                  /*dest_shape_index=*/{0},
                                   /*src_shape_index=*/{}));
   StatusOr<Literal> eval_result =
-      evaluator.Evaluate(*while_cond, {std::move(fake_input)});
+      evaluator.Evaluate(*new_computation, {std::move(fake_input)});

   if (!eval_result.ok()) {
     VLOG(2) << "Couldn't evaluate while loop condition.";
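
Reviewer note: the effect of this change can be illustrated with a toy model that does not use XLA at all. The sketch below is only an illustration under stated assumptions: `ToyTupleShape`, `ToyTuple`, `EvalCond`, `EvalWithFullTuple`, and `EvalWithNarrowTuple` are hypothetical names invented here, and plain `std::vector` stands in for XLA literals. It shows the same idea as the hunk above: when the condition reads only one tuple element, a one-element tuple (with the index remapped to 0) gives the same answer as a zero-initialized copy of the whole tuple.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical toy types, NOT XLA's Shape/Literal: element i of the tuple
// holds shape[i] int64 values.
using ToyTupleShape = std::vector<std::size_t>;
using ToyTuple = std::vector<std::vector<std::int64_t>>;

// Zero-initializes every element of the tuple, mimicking a literal that is
// created from the full tuple shape.
ToyTuple MakeZeroTuple(const ToyTupleShape& shape) {
  ToyTuple tuple;
  tuple.reserve(shape.size());
  for (std::size_t num_values : shape) {
    tuple.emplace_back(num_values, 0);
  }
  return tuple;
}

// Total payload bytes held by the tuple elements.
std::size_t AllocatedBytes(const ToyTuple& tuple) {
  std::size_t bytes = 0;
  for (const auto& element : tuple) {
    bytes += element.size() * sizeof(std::int64_t);
  }
  return bytes;
}

// A loop condition that reads only one scalar: tuple[index][0] < limit.
bool EvalCond(const ToyTuple& tuple, std::size_t index, std::int64_t limit) {
  return tuple[index][0] < limit;
}

struct EvalResult {
  bool value;         // Result of the condition.
  std::size_t bytes;  // Payload bytes allocated for the fake input.
};

// "Before" behavior: materialize the whole tuple, then write the induction
// variable at its original index.
EvalResult EvalWithFullTuple(const ToyTupleShape& shape,
                             std::size_t indvar_index, std::int64_t indvar,
                             std::int64_t limit) {
  ToyTuple input = MakeZeroTuple(shape);
  input[indvar_index][0] = indvar;
  return {EvalCond(input, indvar_index, limit), AllocatedBytes(input)};
}

// "After" behavior: build a one-element tuple that contains only the element
// the condition reads, and remap its index to 0.
EvalResult EvalWithNarrowTuple(const ToyTupleShape& shape,
                               std::size_t indvar_index, std::int64_t indvar,
                               std::int64_t limit) {
  const ToyTupleShape narrow_shape = {shape[indvar_index]};
  ToyTuple input = MakeZeroTuple(narrow_shape);
  input[0][0] = indvar;
  return {EvalCond(input, 0, limit), AllocatedBytes(input)};
}
```

For a toy shape such as `{1, 1 << 20, 1}` with the induction variable at index 2, both functions return the same condition value, but the full-tuple path allocates roughly 8 MiB of payload while the narrow path allocates 8 bytes.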