From 14fbade4a33070bcdbbfd1f9c118354b7618b848 Mon Sep 17 00:00:00 2001 From: Blake Hechtman Date: Sat, 14 Sep 2024 19:36:10 -0700 Subject: [PATCH] [XLA] Propagate the layout of layout constrained custom calls with higher priority because they have no ability to accept another layout. PiperOrigin-RevId: 674755625 --- .../xla/xla/service/layout_assignment.cc | 48 +++++++++++-------- .../xla/xla/service/layout_assignment_test.cc | 36 ++++++++++++++ 2 files changed, 63 insertions(+), 21 deletions(-) diff --git a/third_party/xla/xla/service/layout_assignment.cc b/third_party/xla/xla/service/layout_assignment.cc index 02885038251c06..7bb3e9ad89a9b4 100644 --- a/third_party/xla/xla/service/layout_assignment.cc +++ b/third_party/xla/xla/service/layout_assignment.cc @@ -741,27 +741,6 @@ absl::Status LayoutAssignment::AddMandatoryConstraints( } } } - } else if (IsLayoutConstrainedCustomCall(instruction)) { - const HloCustomCallInstruction* custom_call = - DynCast(instruction); - - TF_RETURN_IF_ERROR(SetInstructionLayout(custom_call->shape(), custom_call, - /*mandatory=*/true, /*dfs=*/true, - /*allow_alias=*/true)); - if (custom_call->IsCustomCall("LayoutConstraint")) { - TF_RETURN_IF_ERROR( - SetOperandLayout(custom_call->shape(), custom_call, 0)); - } else { - for (int64_t i = 0; i < custom_call->operand_count(); ++i) { - if (AnyOperandBufferForwarded(custom_call, i)) { - TF_RET_CHECK(AllOperandBuffersForwarded(custom_call, i)) - << "Partial alias of an operand is not supported"; - } else { - TF_RETURN_IF_ERROR(SetOperandLayout( - custom_call->operand_shapes_with_layout()[i], custom_call, i)); - } - } - } } else if (IsLayoutConstrainedCollective(instruction)) { TF_RETURN_IF_ERROR( SetInstructionLayout(instruction->shape(), instruction)); @@ -2476,6 +2455,33 @@ absl::Status LayoutAssignment::RunOnComputation( // Add any backend-specific constraints. TF_RETURN_IF_ERROR(AddBackendConstraints(constraints)); + for (HloInstruction* instruction : + constraints->computation()->MakeInstructionPostOrder()) { + if (!IsLayoutConstrainedCustomCall(instruction)) { + continue; + } + const HloCustomCallInstruction* custom_call = + DynCast(instruction); + + TF_RETURN_IF_ERROR(SetInstructionLayout(custom_call->shape(), custom_call, + /*mandatory=*/true, /*dfs=*/true, + /*allow_alias=*/true)); + if (custom_call->IsCustomCall("LayoutConstraint")) { + TF_RETURN_IF_ERROR( + SetOperandLayout(custom_call->shape(), custom_call, 0)); + } else { + for (int64_t i = 0; i < custom_call->operand_count(); ++i) { + if (AnyOperandBufferForwarded(custom_call, i)) { + TF_RET_CHECK(AllOperandBuffersForwarded(custom_call, i)) + << "Partial alias of an operand is not supported"; + } else { + TF_RETURN_IF_ERROR(SetOperandLayout( + custom_call->operand_shapes_with_layout()[i], custom_call, i)); + } + } + } + } + // Propagates layouts from mandatory and backend constraints. TF_RETURN_IF_ERROR(PropagateConstraints(constraints)); diff --git a/third_party/xla/xla/service/layout_assignment_test.cc b/third_party/xla/xla/service/layout_assignment_test.cc index 0b294c46ddef17..fab2df5eb1acfb 100644 --- a/third_party/xla/xla/service/layout_assignment_test.cc +++ b/third_party/xla/xla/service/layout_assignment_test.cc @@ -1250,6 +1250,42 @@ ENTRY %CustomCallWithLayoutConstraints (p0: f32[4,4], p1: f32[2,3]) -> f32[1,2,3 ExpectLayoutIs(custom_call->operand(1)->shape(), {1, 0}); } +TEST_F(LayoutAssignmentTest, CustomCallLayoutConstrainedAndElementwise) { + const char* module_str = R"( +HloModule CustomCallLayoutConstrained + +ENTRY %CustomCallWithLayoutConstraints (p0: f32[4,4], p1: f32[2,3]) -> f32[1,2,3,4] { + p0 = f32[4,4] parameter(0) + p1 = f32[2,3] parameter(1) + cc = f32[1,2,3,4]{3,2,0,1} custom-call(f32[4,4] %p0, f32[2,3] %p1), custom_call_target="baz", operand_layout_constraints={f32[4,4]{0,1}, f32[2,3]{1,0}} + ROOT e = f32[1,2,3,4] exponential(cc) +} +)"; + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr m, + ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest())); + ComputationLayout computation_layout = m->entry_computation_layout(); + *computation_layout.mutable_parameter_layout(0) = + ShapeLayout(ShapeUtil::MakeShapeWithDenseLayout(F32, {4, 4}, {1, 0})); + *computation_layout.mutable_parameter_layout(1) = + ShapeLayout(ShapeUtil::MakeShapeWithDenseLayout(F32, {2, 3}, {1, 0})); + *computation_layout.mutable_result_layout() = ShapeLayout( + ShapeUtil::MakeShapeWithDenseLayout(F32, {1, 2, 3, 4}, {2, 1, 0, 3})); + AssignLayouts(m.get(), &computation_layout); + + // The custom call should be partially encapsulated in kCopy instructions + // because of the layout mismatches. + ASSERT_THAT( + m->entry_computation()->root_instruction(), + GmockMatch(m::Copy(m::Exp(m::CustomCall(m::Copy(), m::Parameter()))))); + + const HloInstruction* custom_call = + m->entry_computation()->root_instruction()->operand(0)->operand(0); + ExpectLayoutIs(custom_call->shape(), {3, 2, 0, 1}); + ExpectLayoutIs(custom_call->operand(0)->shape(), {0, 1}); + ExpectLayoutIs(custom_call->operand(1)->shape(), {1, 0}); +} + TEST_F(LayoutAssignmentTest, CustomCallLayoutConstrainedAliasedOutput) { const char* module_str = R"( HloModule customcall.4