Skip to content

Commit

Permalink
[XLA] Propagate the layout of layout constrained custom calls with hi…
Browse files Browse the repository at this point in the history
…gher

priority because they have no ability to accept another layout.

PiperOrigin-RevId: 674755625
  • Loading branch information
blakehechtman authored and tensorflower-gardener committed Sep 15, 2024
1 parent 84a4740 commit 14fbade
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 21 deletions.
48 changes: 27 additions & 21 deletions third_party/xla/xla/service/layout_assignment.cc
Original file line number Diff line number Diff line change
Expand Up @@ -741,27 +741,6 @@ absl::Status LayoutAssignment::AddMandatoryConstraints(
}
}
}
} else if (IsLayoutConstrainedCustomCall(instruction)) {
const HloCustomCallInstruction* custom_call =
DynCast<HloCustomCallInstruction>(instruction);

TF_RETURN_IF_ERROR(SetInstructionLayout(custom_call->shape(), custom_call,
/*mandatory=*/true, /*dfs=*/true,
/*allow_alias=*/true));
if (custom_call->IsCustomCall("LayoutConstraint")) {
TF_RETURN_IF_ERROR(
SetOperandLayout(custom_call->shape(), custom_call, 0));
} else {
for (int64_t i = 0; i < custom_call->operand_count(); ++i) {
if (AnyOperandBufferForwarded(custom_call, i)) {
TF_RET_CHECK(AllOperandBuffersForwarded(custom_call, i))
<< "Partial alias of an operand is not supported";
} else {
TF_RETURN_IF_ERROR(SetOperandLayout(
custom_call->operand_shapes_with_layout()[i], custom_call, i));
}
}
}
} else if (IsLayoutConstrainedCollective(instruction)) {
TF_RETURN_IF_ERROR(
SetInstructionLayout(instruction->shape(), instruction));
Expand Down Expand Up @@ -2476,6 +2455,33 @@ absl::Status LayoutAssignment::RunOnComputation(
// Add any backend-specific constraints.
TF_RETURN_IF_ERROR(AddBackendConstraints(constraints));

for (HloInstruction* instruction :
constraints->computation()->MakeInstructionPostOrder()) {
if (!IsLayoutConstrainedCustomCall(instruction)) {
continue;
}
const HloCustomCallInstruction* custom_call =
DynCast<HloCustomCallInstruction>(instruction);

TF_RETURN_IF_ERROR(SetInstructionLayout(custom_call->shape(), custom_call,
/*mandatory=*/true, /*dfs=*/true,
/*allow_alias=*/true));
if (custom_call->IsCustomCall("LayoutConstraint")) {
TF_RETURN_IF_ERROR(
SetOperandLayout(custom_call->shape(), custom_call, 0));
} else {
for (int64_t i = 0; i < custom_call->operand_count(); ++i) {
if (AnyOperandBufferForwarded(custom_call, i)) {
TF_RET_CHECK(AllOperandBuffersForwarded(custom_call, i))
<< "Partial alias of an operand is not supported";
} else {
TF_RETURN_IF_ERROR(SetOperandLayout(
custom_call->operand_shapes_with_layout()[i], custom_call, i));
}
}
}
}

// Propagates layouts from mandatory and backend constraints.
TF_RETURN_IF_ERROR(PropagateConstraints(constraints));

Expand Down
36 changes: 36 additions & 0 deletions third_party/xla/xla/service/layout_assignment_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1250,6 +1250,42 @@ ENTRY %CustomCallWithLayoutConstraints (p0: f32[4,4], p1: f32[2,3]) -> f32[1,2,3
ExpectLayoutIs(custom_call->operand(1)->shape(), {1, 0});
}

TEST_F(LayoutAssignmentTest, CustomCallLayoutConstrainedAndElementwise) {
const char* module_str = R"(
HloModule CustomCallLayoutConstrained
ENTRY %CustomCallWithLayoutConstraints (p0: f32[4,4], p1: f32[2,3]) -> f32[1,2,3,4] {
p0 = f32[4,4] parameter(0)
p1 = f32[2,3] parameter(1)
cc = f32[1,2,3,4]{3,2,0,1} custom-call(f32[4,4] %p0, f32[2,3] %p1), custom_call_target="baz", operand_layout_constraints={f32[4,4]{0,1}, f32[2,3]{1,0}}
ROOT e = f32[1,2,3,4] exponential(cc)
}
)";
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<VerifiedHloModule> m,
ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest()));
ComputationLayout computation_layout = m->entry_computation_layout();
*computation_layout.mutable_parameter_layout(0) =
ShapeLayout(ShapeUtil::MakeShapeWithDenseLayout(F32, {4, 4}, {1, 0}));
*computation_layout.mutable_parameter_layout(1) =
ShapeLayout(ShapeUtil::MakeShapeWithDenseLayout(F32, {2, 3}, {1, 0}));
*computation_layout.mutable_result_layout() = ShapeLayout(
ShapeUtil::MakeShapeWithDenseLayout(F32, {1, 2, 3, 4}, {2, 1, 0, 3}));
AssignLayouts(m.get(), &computation_layout);

// The custom call should be partially encapsulated in kCopy instructions
// because of the layout mismatches.
ASSERT_THAT(
m->entry_computation()->root_instruction(),
GmockMatch(m::Copy(m::Exp(m::CustomCall(m::Copy(), m::Parameter())))));

const HloInstruction* custom_call =
m->entry_computation()->root_instruction()->operand(0)->operand(0);
ExpectLayoutIs(custom_call->shape(), {3, 2, 0, 1});
ExpectLayoutIs(custom_call->operand(0)->shape(), {0, 1});
ExpectLayoutIs(custom_call->operand(1)->shape(), {1, 0});
}

TEST_F(LayoutAssignmentTest, CustomCallLayoutConstrainedAliasedOutput) {
const char* module_str = R"(
HloModule customcall.4
Expand Down

0 comments on commit 14fbade

Please sign in to comment.