[XLA:GPU] Triton GEMM: enable more fusions of binary elementwise operations of broadcasts.

PiperOrigin-RevId: 573143351
Ilia Sergachev authored and tensorflower-gardener committed Oct 13, 2023
1 parent b5953b8 commit 16350d0
Showing 2 changed files with 120 additions and 29 deletions.
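In brief: a binary elementwise operation feeding a Triton GEMM's inputs was previously rejected with "Not obviously profitable to fuse as input." unless IsInputWorthFusing approved it. This commit adds a look-ahead: if one operand of the binary operation is a broadcast of a parameter or constant whose dimension orders also propagate, the operation (and the broadcast) are fused after all. A minimal sketch of the new acceptance rule follows, assuming only the public HloInstruction API; the actual implementation (in the diff below) is inlined in AnalyzeForFusionImpl and additionally re-runs the dimension-order analysis on the operand before accepting it, and the helper name here is ours:

    #include "xla/hlo/ir/hlo_instruction.h"
    #include "xla/hlo/ir/hlo_opcode.h"

    namespace xla::gpu {

    // Sketch of the look-ahead added by this commit (simplified: the real
    // code also verifies, via AnalyzeForFusionImpl, that the broadcast's
    // dimension orders are supported before accepting it).
    bool IsBinaryElementwiseOfFusibleBroadcast(const HloInstruction& hlo) {
      if (!hlo.IsElementwise() || hlo.operand_count() != 2) return false;
      for (const HloInstruction* operand : hlo.operands()) {
        // A broadcast of a parameter or constant adds no extra DRAM traffic
        // if it is pulled into the fusion together with the elementwise op.
        if (operand->opcode() == HloOpcode::kBroadcast &&
            (operand->operand(0)->opcode() == HloOpcode::kParameter ||
             operand->operand(0)->opcode() == HloOpcode::kConstant)) {
          return true;
        }
      }
      return false;
    }

    }  // namespace xla::gpu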
110 changes: 81 additions & 29 deletions third_party/xla/xla/service/gpu/gemm_rewriter_triton.cc
@@ -370,11 +370,16 @@ class FusionContext {
   // around `hlo`.
   FusionDecision RequireSupportedDimOrders(const HloInstruction& hlo,
                                            DimOrderUpdates& updates) const;
+  // Try to calculate transformations of dimensions defined by the
+  // instruction, then check that the resulting dimension orders are supported.
+  DimOrderUpdatesOrError RequireSupportedInstruction(
+      const HloInstruction& hlo, const DimOrderMap& dim_orders,
+      TransformDirection direction) const;
   // Checks if the instruction is possible and profitable to fuse.
   // If so, tries to transform the dim_order describing one side of `hlo` into
   // description(s) of its other side, if supported.
   DimOrderUpdatesOrError AnalyzeForFusion(
-      const HloInstruction& hlo, bool as_input,
+      const HloInstruction& hlo, TransformDirection transform_direction,
       absl::flat_hash_map<const HloInstruction*, HloInstruction*>&
           old_to_new_mapping,
       se::GpuComputeCapability gpu_version) const;
@@ -421,6 +426,12 @@ class FusionContext {
   const DimOrderMap& DimOrders() const { return dim_orders_; }

  private:
+  DimOrderUpdatesOrError AnalyzeForFusionImpl(
+      const HloInstruction& hlo, TransformDirection transform_direction,
+      absl::flat_hash_map<const HloInstruction*, HloInstruction*>&
+          old_to_new_mapping,
+      const DimOrderMap& dim_orders,
+      se::GpuComputeCapability gpu_version) const;
   bool SetSplittableDimensionMajorPartSize(const int64_t size) {
     if (IsSupportedSplittableDimensionMajorPartSize(size)) {
       std::get<DotProperties>(properties_)
@@ -927,7 +938,7 @@ DimOrderUpdatesOrError FusionContext::HandleDimensionAlteringOp(
 DimOrderUpdatesOrError FusionContext::HandleInstruction(
     const HloInstruction* hlo, const DimOrderMap& dim_orders,
     const TransformDirection direction) const {
-  VLOG(7) << hlo->ToString();
+  VLOG(7) << "Analyzing " << hlo->ToString();
   if (hlo->opcode() == HloOpcode::kParameter ||
       hlo_query::IsScalarConstant(hlo)) {
     return DimOrderUpdates{};
@@ -940,6 +951,9 @@ DimOrderUpdatesOrError FusionContext::HandleInstruction(
     }
     return HandleDimensionAlteringOp(hlo, dim_orders, direction);
   } else if (hlo->opcode() == HloOpcode::kReduce) {
+    if (!std::holds_alternative<SoftmaxProperties>(properties_)) {
+      return "Reductions are not supported in GEMM fusions yet.";
+    }
     if (direction != TransformDirection::kOutputToInput) {
       return "Unsupported direction of reduction.";
     }
@@ -1011,17 +1025,37 @@ bool IsOutputWorthFusing(const HloInstruction& hlo) {
          InputMinusOutputBytes(hlo) >= -kIoToleranceBytes;
 }

+DimOrderUpdatesOrError FusionContext::RequireSupportedInstruction(
+    const HloInstruction& hlo, const DimOrderMap& dim_orders,
+    const TransformDirection transform_direction) const {
+  auto result = HandleInstruction(&hlo, dim_orders, transform_direction);
+  if (!std::holds_alternative<DimOrderUpdates>(result)) {
+    return std::get<FusionDecision>(result);
+  }
+
+  if (FusionDecision supported =
+          RequireSupportedDimOrders(hlo, std::get<DimOrderUpdates>(result));
+      !supported) {
+    return supported;
+  }
+  return std::get<DimOrderUpdates>(result);
+}
+
 DimOrderUpdatesOrError FusionContext::AnalyzeForFusion(
-    const HloInstruction& hlo, bool as_input,
+    const HloInstruction& hlo, const TransformDirection transform_direction,
     absl::flat_hash_map<const HloInstruction*, HloInstruction*>&
         old_to_new_mapping,
     const se::GpuComputeCapability gpu_version) const {
-  int fusion_level =
-      hlo.GetModule()->config().debug_options().xla_gpu_triton_fusion_level();
-  if (!std::get<se::CudaComputeCapability>(gpu_version)
-           .IsAtLeast(se::CudaComputeCapability::AMPERE)) {
-    fusion_level = std::min(fusion_level, 1);
-  }
+  return AnalyzeForFusionImpl(hlo, transform_direction, old_to_new_mapping,
+                              dim_orders_, gpu_version);
+}
+
+DimOrderUpdatesOrError FusionContext::AnalyzeForFusionImpl(
+    const HloInstruction& hlo, const TransformDirection transform_direction,
+    absl::flat_hash_map<const HloInstruction*, HloInstruction*>&
+        old_to_new_mapping,
+    const DimOrderMap& dim_orders,
+    const se::GpuComputeCapability gpu_version) const {
   if (hlo.opcode() == HloOpcode::kTuple ||
       hlo.opcode() == HloOpcode::kGetTupleElement) {
     return "Unsupported instruction.";
@@ -1041,7 +1075,18 @@ DimOrderUpdatesOrError FusionContext::AnalyzeForFusion(
   if (!IsTritonSupportedDataType(hlo.shape().element_type(), gpu_version)) {
     return "Unsupported output data type.";
   }
-  if (as_input) {
+  DimOrderUpdatesOrError result =
+      RequireSupportedInstruction(hlo, dim_orders, transform_direction);
+  if (!std::holds_alternative<DimOrderUpdates>(result)) {
+    return result;
+  }
+  int fusion_level =
+      hlo.GetModule()->config().debug_options().xla_gpu_triton_fusion_level();
+  if (!std::get<se::CudaComputeCapability>(gpu_version)
+           .IsAtLeast(se::CudaComputeCapability::AMPERE)) {
+    fusion_level = std::min(fusion_level, 1);
+  }
+  if (transform_direction == TransformDirection::kOutputToInput) {
     if (fusion_level < 2) {
       if (hlo.opcode() == HloOpcode::kConvert) {
         if (FusionDecision decision =
@@ -1053,7 +1098,25 @@ DimOrderUpdatesOrError FusionContext::AnalyzeForFusion(
         return "Ignored elementwise operation";
       }
     } else {
-      if (!IsInputWorthFusing(hlo)) {
+      // Exception for binary elementwise operations: in most cases these are
+      // not trivial to fuse because they increase DRAM traffic, but if one
+      // of the inputs is, for example, a broadcast that can be fused too, it
+      // becomes worth fusing. Look ahead and analyze the operands here.
+      bool accepted = false;
+      if (hlo.IsElementwise() && hlo.operand_count() == 2) {
+        for (const HloInstruction* operand : hlo.operands()) {
+          if (operand->opcode() == HloOpcode::kBroadcast &&
+              (operand->operand(0)->opcode() == HloOpcode::kParameter ||
+               operand->operand(0)->opcode() == HloOpcode::kConstant) &&
+              std::holds_alternative<DimOrderUpdates>(AnalyzeForFusionImpl(
+                  *operand, transform_direction, old_to_new_mapping,
+                  std::get<DimOrderUpdates>(result).map, gpu_version))) {
+            accepted = true;
+            break;
+          }
+        }
+      }
+      if (!accepted && !IsInputWorthFusing(hlo)) {
         return "Not obviously profitable to fuse as input.";
       }
     }
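For reference, this look-ahead accepts HLO of the following shape, where the add and the broadcast both join the GEMM fusion (same module as the first new test below; the constant name is ours and the shapes are the test's, not a requirement):

    // Pattern now accepted on the input side of a GEMM fusion: `a` is a
    // binary elementwise op and `b` is a broadcast of a parameter, so fusing
    // both adds no DRAM traffic beyond reading p2 itself.
    constexpr char kBroadcastOperandPattern[] = R"(
    ENTRY e {
      p2 = f32[3072] parameter(2)
      b = f32[8192,3072] broadcast(p2), dimensions={1}
      p0 = f16[8192,3072] parameter(0)
      p0c = f32[8192,3072] convert(p0)
      a = f32[8192,3072] add(p0c, b)
      p1 = f32[3072,768] parameter(1)
      ROOT r = f32[8192,768] dot(a, p1),
        lhs_contracting_dims={1}, rhs_contracting_dims={0}
    })";

The second new test covers the rejected case: there the broadcast feeds a bitcast, its dimension orders do not propagate through AnalyzeForFusionImpl, and the module is left unchanged.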
@@ -1079,20 +1142,6 @@ DimOrderUpdatesOrError FusionContext::AnalyzeForFusion(
       return "Not obviously profitable to fuse as output.";
     }
   }
-
-  auto result =
-      HandleInstruction(&hlo, dim_orders_,
-                        as_input ? TransformDirection::kOutputToInput
-                                 : TransformDirection::kInputToOutput);
-  if (!std::holds_alternative<DimOrderUpdates>(result)) {
-    return std::get<FusionDecision>(result);
-  }
-
-  if (FusionDecision supported =
-          RequireSupportedDimOrders(hlo, std::get<DimOrderUpdates>(result));
-      !supported) {
-    return supported;
-  }
   return std::get<DimOrderUpdates>(result);
 }

@@ -1193,8 +1242,9 @@ void FusionContext::TryToFuseWithInputsRecursively(
       continue;
     }
     num_requeued = 0;
-    const DimOrderUpdatesOrError result = AnalyzeForFusion(
-        *hlo, /*as_input=*/true, old_to_new_mapping, gpu_version);
+    const DimOrderUpdatesOrError result =
+        AnalyzeForFusion(*hlo, TransformDirection::kOutputToInput,
+                         old_to_new_mapping, gpu_version);
     if (!std::holds_alternative<DimOrderUpdates>(result) ||
         !MergeUpdates(std::get<DimOrderUpdates>(result))) {
       continue;
@@ -1207,6 +1257,7 @@ void FusionContext::TryToFuseWithInputsRecursively(
     to_fuse_list.push_back(hlo);
     for (HloInstruction* operand : hlo->operands()) {
       if (enqueued.insert(operand).second) {
+        VLOG(6) << "Enqueueing " << operand->ToString();
         to_visit.push(operand);
       }
     }
@@ -1293,8 +1344,9 @@ StatusOr<FusionDecision> FuseDot(HloInstruction& dot,
     if (!IsDistributiveOverAddition(*user)) {
       break;
     }
-    auto result = context.AnalyzeForFusion(*user, /*as_input=*/false,
-                                           old_to_new_mapping, gpu_version);
+    auto result =
+        context.AnalyzeForFusion(*user, TransformDirection::kInputToOutput,
+                                 old_to_new_mapping, gpu_version);
     if (!std::holds_alternative<DimOrderUpdates>(result)) {
       continue;
     }
39 changes: 39 additions & 0 deletions third_party/xla/xla/service/gpu/gemm_rewriter_triton_test.cc
@@ -833,6 +833,45 @@ ENTRY e {
       GmockMatch(m::Fusion(m::Parameter(), m::Parameter())));
 }

+TEST_F(GemmRewriterTritonTest, BinaryElementwiseOfBroadcastIsFused) {
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+                          ParseAndReturnVerifiedModule(R"(
+ENTRY e {
+  p2 = f32[3072] parameter(2)
+  b = f32[8192,3072] broadcast(p2), dimensions={1}
+  p0 = f16[8192,3072] parameter(0)
+  p0c = f32[8192,3072] convert(p0)
+  a = f32[8192,3072] add(p0c, b)
+  p1 = f32[3072,768] parameter(1)
+  ROOT r = f32[8192,768] dot(a, p1),
+    lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})"));
+  const se::CudaComputeCapability cc{se::CudaComputeCapability::AMPERE, 0};
+  EXPECT_TRUE(GemmRewriterTriton(cc).Run(module.get()).value());
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      GmockMatch(m::Fusion(m::Parameter(), m::Parameter(), m::Parameter())));
+}
+
+TEST_F(GemmRewriterTritonTest,
+       BinaryElementwiseOfUnsupportedBroadcastIsNotFused) {
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+                          ParseAndReturnVerifiedModule(R"(
+ENTRY e {
+  p2 = f32[768] parameter(2)
+  b = f32[8192,768,4] broadcast(p2), dimensions={1}
+  s = f32[8192,3072] bitcast(b)
+  p0 = f16[8192,3072] parameter(0)
+  p0c = f32[8192,3072] convert(p0)
+  a = f32[8192,3072] add(p0c, s)
+  p1 = f32[3072,768] parameter(1)
+  ROOT r = f32[8192,768] dot(a, p1),
+    lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})"));
+  const se::CudaComputeCapability cc{se::CudaComputeCapability::AMPERE, 0};
+  EXPECT_FALSE(GemmRewriterTriton(cc).Run(module.get()).value());
+}
+
 class GemmRewriterTritonLevel2Test : public GemmRewriterTritonTest {
  public:
   DebugOptions GetDebugOptionsForTest() override {
