
Commit a5e55b4

Merge pull request tensorflow#61706 from philipphack:u_fp8_dynamicslice_xla

PiperOrigin-RevId: 560628637
tensorflower-gardener committed Aug 28, 2023
2 parents a620914 + 3e96679 commit a5e55b4
Showing 2 changed files with 84 additions and 10 deletions.
30 changes: 20 additions & 10 deletions tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc
@@ -148,9 +148,9 @@ HloInstruction *PadOperandToMultipleOf16(absl::Span<const int64_t> batch_dims,
   return PadOperandToTargetShape(padded_shape, x);
 }
 
-// Recursively collects unary, pad, divide or multiply operands of instr until
-// an instruction with FP8 element type is reached. Returns std::nullopt when no
-// FP8 instruction is reached.
+// Recursively collects unary, divide, dynamic-slice, pad or multiply operands
+// of instr until an instruction with FP8 element type is reached. Returns
+// std::nullopt when no FP8 instruction is reached.
 std::optional<std::vector<HloInstruction *>> FindF8SubgraphRecursive(
     HloInstruction *instr, absl::flat_hash_set<int> &visited_instrs,
     std::vector<HloInstruction *> subgraph) {
@@ -163,6 +163,7 @@ std::optional<std::vector<HloInstruction *>> FindF8SubgraphRecursive(
     return subgraph;
   } else {
     if (instr->operand_count() == 1 || instr->opcode() == HloOpcode::kDivide ||
+        instr->opcode() == HloOpcode::kDynamicSlice ||
         instr->opcode() == HloOpcode::kPad) {
       return FindF8SubgraphRecursive(instr->mutable_operand(0), visited_instrs,
                                      subgraph);
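
For readers outside the XLA codebase, the following is a minimal, self-contained sketch of the traversal this hunk extends. Node, Op, and the is_f8 flag are hypothetical stand-ins for HloInstruction, HloOpcode, and the FP8 element-type check; the multiply handling mentioned in the comment lies outside this hunk and is omitted here.

#include <optional>
#include <unordered_set>
#include <vector>

enum class Op { kConvert, kDivide, kDynamicSlice, kPad, kMultiply };

struct Node {
  int id;
  Op op;
  bool is_f8;                   // True if the element type is an FP8 type.
  std::vector<Node*> operands;
};

// Walks operand 0 through unary, divide, dynamic-slice and pad nodes,
// recording the path, until an FP8-typed node is reached. Returns
// std::nullopt when the walk ends without finding an FP8 instruction.
std::optional<std::vector<Node*>> FindF8Subgraph(
    Node* n, std::unordered_set<int>& visited, std::vector<Node*> path) {
  if (!visited.insert(n->id).second) return std::nullopt;  // Already seen.
  path.push_back(n);
  if (n->is_f8) return path;
  if (n->operands.size() == 1 || n->op == Op::kDivide ||
      n->op == Op::kDynamicSlice || n->op == Op::kPad) {
    return FindF8Subgraph(n->operands[0], visited, path);
  }
  return std::nullopt;
}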
@@ -230,15 +231,17 @@ bool IsSupportedF8Pattern(HloInstruction *instr, HloInstruction *&x,
   };
   for (int i = 3; i < subgraph->size(); ++i) {
     // The remaining instructions must be commutative with dequantization.
-    // Bitcast, broadcast, copy, pad, reshape, slice and all-gather instructions
-    // are supported. Specifically, the 'all-gather' operation is permitted only
-    // in SPMD or no-partition cases since the optimization cannot be guaranteed
-    // to be applied to all replicas in the MPMD scenario.
+    // Bitcast, broadcast, copy, dynamic-slice, pad, reshape, slice,
+    // all-gather, all-to-all and collective-permute instructions are supported.
+    // Specifically, the all-gather, all-to-all and collective-permute
+    // operations are permitted only in SPMD cases since the optimization cannot
+    // be guaranteed to be applied to all replicas in the MPMD scenario.
     if (!Match(
             (*subgraph)[i],
             m::AnyOf<HloInstruction>(
                 m::Bitcast().WithPredicate(preserves_element_type),
-                m::Broadcast(), m::Copy(), m::Pad(), m::Reshape(), m::Slice(),
+                m::Broadcast(), m::Copy(), m::DynamicSlice(), m::Pad(),
+                m::Reshape(), m::Slice(),
                 m::AllGather().WithPredicate(use_spmd_partitioning),
                 m::AllToAll().WithPredicate(use_spmd_partitioning),
                 m::CollectivePermute().WithPredicate(use_spmd_partitioning)))) {
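
The property this allowlist relies on is that each listed op commutes with dequantization, i.e. op(x * s) == op(x) * s for a scalar (broadcast) scale s. Below is a standalone numeric check of that property for the dynamic-slice case, with plain 1-D vectors standing in for HLO values; DynamicSlice and Scale are illustrative helpers, not XLA APIs.

#include <cassert>
#include <vector>

// Length-`size` slice of `v` starting at `start` -- a 1-D stand-in for
// HLO dynamic-slice.
std::vector<float> DynamicSlice(const std::vector<float>& v, int start,
                                int size) {
  return std::vector<float>(v.begin() + start, v.begin() + start + size);
}

// Elementwise multiplication by a scalar -- a stand-in for dequantization.
std::vector<float> Scale(const std::vector<float>& v, float s) {
  std::vector<float> out = v;
  for (float& e : out) e *= s;
  return out;
}

int main() {
  const std::vector<float> x = {1, 2, 3, 4, 5, 6, 7, 8};
  const float s = 0.25f;
  // slice(x * s) == slice(x) * s: dynamic-slice commutes with scaling by a
  // broadcast scalar, so the rewriter may sink the dequantization below it.
  assert(DynamicSlice(Scale(x, s), 2, 4) == Scale(DynamicSlice(x, 2, 4), s));
  return 0;
}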
@@ -906,13 +909,20 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
       return false;
     }
 
-    // Sequentially apply the collected unary and pad ops to the unconverted and
-    // unscaled operands.
+    // Sequentially apply the collected unary, dynamic-slice and pad ops to the
+    // unconverted and unscaled operands.
     auto shift_unary_ops =
         [&instr](HloInstruction *&x,
                  std::vector<HloInstruction *> &x_unary_ops) -> void {
       for (HloInstruction *unary_op : x_unary_ops) {
         std::vector<HloInstruction *> operands = {x};
+        // Insert the additional operands of dynamic-slice ops.
+        if (unary_op->opcode() == HloOpcode::kDynamicSlice) {
+          for (int i = 1; i < unary_op->operand_count(); ++i) {
+            operands.emplace_back(unary_op->mutable_operand(i));
+          }
+        }
         // Convert the second operand of pad ops.
         if (unary_op->opcode() == HloOpcode::kPad) {
           HloInstruction *convert =
               instr->AddInstruction(HloInstruction::CreateConvert(
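
A simplified model of the operand shift performed in this hunk, reusing the hypothetical Node/Op stand-ins from the earlier sketch: each collected op is re-created on the raw FP8 value, and for dynamic-slice the start-index operands (operands 1..n-1) are carried over unchanged. In XLA the re-creation step roughly corresponds to HloInstruction::CloneWithNewOperands; the pad-value conversion handled just below in the real code is omitted.

#include <memory>
#include <vector>

// Re-applies the collected ops to the unconverted, unscaled operand `x`.
Node* ShiftOps(Node* x, const std::vector<Node*>& collected_ops,
               std::vector<std::unique_ptr<Node>>& arena) {
  for (Node* op : collected_ops) {
    std::vector<Node*> operands = {x};
    // Dynamic-slice keeps its start-index operands (operands 1..n-1).
    if (op->op == Op::kDynamicSlice) {
      for (size_t i = 1; i < op->operands.size(); ++i) {
        operands.push_back(op->operands[i]);
      }
    }
    arena.push_back(std::make_unique<Node>(*op));  // Copy of the original op.
    arena.back()->operands = std::move(operands);  // Now consumes `x`.
    x = arena.back().get();
  }
  return x;
}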
64 changes: 64 additions & 0 deletions tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc
@@ -4970,6 +4970,70 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDUnaryOpsF8) {
)");
}

TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDDynamicSliceF8) {
#if CUDA_VERSION < 12000
GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
#endif

const char* hlo_text = R"(
HloModule test
ENTRY test {
x = f8e4m3fn[32,32] parameter(0)
y = f8e4m3fn[16,32] parameter(1)
zero = s32[] constant(0)
x_f32 = f32[32,32] convert(x)
y_f32 = f32[16,32] convert(y)
x_scale = f32[] parameter(2)
y_scale = f32[] parameter(3)
x_scale_bcast = f32[32,32] broadcast(x_scale), dimensions={}
y_scale_bcast = f32[16,32] broadcast(y_scale), dimensions={}
x_unscaled = f32[32,32] multiply(x_f32, x_scale_bcast)
y_unscaled = f32[16,32] multiply(y_f32, y_scale_bcast)
dyn_slice = f32[16,32]{1,0} dynamic-slice(x_unscaled, zero, zero), dynamic_slice_sizes={16,32}
ROOT dot_a = f32[16,16] dot(dyn_slice, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={1}
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(hlo_text));
GemmRewriter pass(
se::CudaComputeCapability{se::CudaComputeCapability::HOPPER, 0});
TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get()));
EXPECT_TRUE(changed);

CheckFp8IfSupported(hlo_text);
RunAndFilecheckHloRewrite(hlo_text,
GemmRewriter(se::CudaComputeCapability{
se::CudaComputeCapability::HOPPER, 0}),
R"(
; CHECK-LABEL: ENTRY %test (x: f8e4m3fn[32,32], y: f8e4m3fn[16,32], x_scale: f32[], y_scale: f32[]) -> f32[16,16] {
; CHECK: [[P0:%[^ ]+]] = f8e4m3fn[32,32]{1,0} parameter(0)
; CHECK: [[C0:%[^ ]+]] = s32[] constant(0)
; CHECK: [[DYN_SLICE:%[^ ]+]] = f8e4m3fn[16,32]{1,0} dynamic-slice([[P0]], [[C0]], [[C0]]), dynamic_slice_sizes={16,32}
; CHECK: [[P1:%[^ ]+]] = f8e4m3fn[16,32]{1,0} parameter(1)
; CHECK: [[P2:%[^ ]+]] = f32[] parameter(2)
; CHECK: [[P3:%[^ ]+]] = f32[] parameter(3)
; CHECK: [[C1:%[^ ]+]] = f32[] constant(1)
; CHECK: ROOT [[OUT:%[^ ]+]] = f32[16,16]{1,0} custom-call([[DYN_SLICE]], [[P1]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]),
; CHECK: custom_call_target="__cublas$lt$matmul$f8",
; CHECK: backend_config={
; CHECK-DAG: "alpha_real":1
; CHECK-DAG: "alpha_imag":0
; CHECK-DAG: "beta":0
; CHECK-DAG: "dot_dimension_numbers":{
; CHECK-DAG: "lhs_contracting_dimensions":["1"]
; CHECK-DAG: "rhs_contracting_dimensions":["1"]
; CHECK-DAG: "lhs_batch_dimensions":[]
; CHECK-DAG: "rhs_batch_dimensions":[]
; CHECK-DAG: }
; CHECK-DAG: "precision_config":{
; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
; CHECK-DAG: }
; CHECK-DAG: "epilogue":"DEFAULT"
; CHECK: }
)");
}

TEST_P(ParameterizedFp8GemmRewriteTest, BatchedScaledABUnscaledDF8) {
#if CUDA_VERSION < 12000
GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above.";
