From 40da335cd5aa66ec1f51b043b2eff6b21e6af0cf Mon Sep 17 00:00:00 2001 From: Bixia Zheng Date: Fri, 11 Oct 2024 18:40:34 -0700 Subject: [PATCH] [xla:gatherExpander:NFC] Tidy up GatherIsBroadcast. Change the return type from int64_t to bool. Use the routine in InstructionMatchesPattern to avoid duplicating the code and help understanding. PiperOrigin-RevId: 685024418 --- third_party/xla/xla/service/gather_expander.cc | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/third_party/xla/xla/service/gather_expander.cc b/third_party/xla/xla/service/gather_expander.cc index 8277a7b902ad02..e873733d00a069 100644 --- a/third_party/xla/xla/service/gather_expander.cc +++ b/third_party/xla/xla/service/gather_expander.cc @@ -287,7 +287,7 @@ int64_t GatherLoopTripCount(HloInstruction* gather_instr) { return trip_count; } -int64_t GatherIsBroadcast(HloInstruction* gather_instr) { +bool GatherIsBroadcast(HloInstruction* gather_instr) { return absl::c_equal(gather_instr->gather_slice_sizes(), gather_instr->operand(0)->shape().dimensions()); } @@ -412,8 +412,7 @@ bool GatherExpander::InstructionMatchesPattern(HloInstruction* inst) { // which can be represented without a loop -- i.e. we only simplify // gathers which have a trip count of 1. (mode_ == kEliminateAllGathers || GatherLoopTripCount(inst) == 1 || - absl::c_equal(inst->gather_slice_sizes(), - inst->operand(0)->shape().dimensions())); + GatherIsBroadcast(inst)); } } // namespace xla