Skip to content

Commit

Permalink
[xla:gatherExpander:NFC] Tidy up GatherIsBroadcast.
Browse files Browse the repository at this point in the history
Change the return type from int64_t to bool. Use the routine in
InstructionMatchesPattern to avoid duplicating the code and help understanding.

PiperOrigin-RevId: 685024418
  • Loading branch information
bixia1 authored and tensorflower-gardener committed Oct 12, 2024
1 parent a3fa648 commit 40da335
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions third_party/xla/xla/service/gather_expander.cc
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ int64_t GatherLoopTripCount(HloInstruction* gather_instr) {
return trip_count;
}

int64_t GatherIsBroadcast(HloInstruction* gather_instr) {
bool GatherIsBroadcast(HloInstruction* gather_instr) {
return absl::c_equal(gather_instr->gather_slice_sizes(),
gather_instr->operand(0)->shape().dimensions());
}
Expand Down Expand Up @@ -412,8 +412,7 @@ bool GatherExpander::InstructionMatchesPattern(HloInstruction* inst) {
// which can be represented without a loop -- i.e. we only simplify
// gathers which have a trip count of 1.
(mode_ == kEliminateAllGathers || GatherLoopTripCount(inst) == 1 ||
absl::c_equal(inst->gather_slice_sizes(),
inst->operand(0)->shape().dimensions()));
GatherIsBroadcast(inst));
}

} // namespace xla

0 comments on commit 40da335

Please sign in to comment.