From b12aa805683ce6f4fdf527234dfc6aae21520373 Mon Sep 17 00:00:00 2001
From: Ilia Sergachev
Date: Mon, 12 Aug 2024 01:37:38 -0700
Subject: [PATCH] PR #15919: [GPU] Use CuDnnThunk for FMHA.

Imported from GitHub PR https://github.com/openxla/xla/pull/15919

CuDnnThunk, currently used for GEMM fusions, is capable of executing arbitrary
cuDNN graphs. Moving FMHA to use it lets us remove a lot of specialized runtime
code. The overview of the change is:
- cuda_dnn.cc: at cuDNN graph construction, assign tensor UIDs based on their
  order in the HLO to match the CuDnnThunk calling convention instead of using
  custom constants.
- cuda_dnn.h/cc: move the dropout seed / offset / increment to the CudnnGraph
  properties and handle them accordingly during graph execution.
- Rename cudnn_workspace_rewriter to cudnn_custom_call_compiler and let it set
  the workspace as it did before, and additionally compile and serialize graphs
  just like cudnn_fusion_compiler targeting CuDnnThunks already does.
- Move the remainder of the MHA config / descriptor logic from the deleted
  fused_mha_runner into cudnn_custom_call_compiler.
- ir_emitter_unnested.cc: remove MHA-specific logic and create CuDnnThunks for
  MHA custom calls in the same universal way that already works for cuDNN GEMM
  fusions.
- Delete the no longer needed special thunks, runners, lazy ops, and command
  buffer commands.

Copybara import of the project:

-- 5d5b046a6ee8771e33b6c6b0f41d380205277129 by Ilia Sergachev:

[GPU] Use CuDnnThunk for FMHA.
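As a rough illustration of the calling convention described in the first bullet, here is a minimal, hypothetical sketch. It assumes a simple 1-based UID scheme; BuildVariantPack, DevicePtr, and the exact offset are illustrative names and not the actual XLA or cudnn_frontend API. The point is that once tensor UIDs follow the position of each operand and result in the HLO custom call, the thunk can bind device buffers to graph tensors without any FMHA-specific code.

#include <cstdint>
#include <unordered_map>
#include <vector>

// Stand-in for se::DeviceMemoryBase; illustrative only.
using DevicePtr = void*;

// Buffers arrive in HLO order: all operands first, then the results (with the
// workspace last). Under the assumed convention, a tensor's UID is simply its
// 1-based position in that sequence, so the uid -> pointer map needed to
// execute the serialized cuDNN graph can be built generically.
std::unordered_map<int64_t, DevicePtr> BuildVariantPack(
    const std::vector<DevicePtr>& buffers_in_hlo_order) {
  std::unordered_map<int64_t, DevicePtr> uid_to_ptr;
  for (size_t i = 0; i < buffers_in_hlo_order.size(); ++i) {
    uid_to_ptr.emplace(static_cast<int64_t>(i) + 1, buffers_in_hlo_order[i]);
  }
  return uid_to_ptr;
}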
Merging this change closes #15919 PiperOrigin-RevId: 661991045 --- third_party/xla/xla/service/gpu/BUILD | 30 +- .../xla/xla/service/gpu/gpu_compiler.cc | 4 +- .../xla/xla/service/gpu/gpu_compiler.h | 8 +- .../xla/service/gpu/gpu_fused_mha_runner.cc | 719 ------------------ .../xla/service/gpu/gpu_fused_mha_runner.h | 431 ----------- .../xla/service/gpu/ir_emitter_unnested.cc | 235 +----- .../xla/xla/service/gpu/ir_emitter_unnested.h | 3 +- .../xla/xla/service/gpu/nvptx_compiler.cc | 13 +- .../xla/xla/service/gpu/nvptx_compiler.h | 6 +- third_party/xla/xla/service/gpu/runtime/BUILD | 24 - .../service/gpu/runtime/command_buffer_cmd.cc | 308 -------- .../service/gpu/runtime/command_buffer_cmd.h | 109 --- .../gpu/runtime/command_buffer_cmd_emitter.cc | 26 - .../xla/xla/service/gpu/transforms/BUILD | 8 +- .../transforms/cudnn_custom_call_compiler.cc | 660 ++++++++++++++++ ...ewriter.h => cudnn_custom_call_compiler.h} | 24 +- .../transforms/cudnn_workspace_rewriter.cc | 272 ------- .../xla/xla/stream_executor/cuda/cuda_dnn.cc | 465 +++-------- .../xla/xla/stream_executor/cuda/cuda_dnn.h | 34 +- .../cuda/cudnn_frontend_helpers.h | 5 + third_party/xla/xla/stream_executor/dnn.cc | 36 - third_party/xla/xla/stream_executor/dnn.h | 55 -- .../xla/xla/stream_executor/lazy_op_runner.h | 70 -- 23 files changed, 842 insertions(+), 2703 deletions(-) delete mode 100644 third_party/xla/xla/service/gpu/gpu_fused_mha_runner.cc delete mode 100644 third_party/xla/xla/service/gpu/gpu_fused_mha_runner.h create mode 100644 third_party/xla/xla/service/gpu/transforms/cudnn_custom_call_compiler.cc rename third_party/xla/xla/service/gpu/transforms/{cudnn_workspace_rewriter.h => cudnn_custom_call_compiler.h} (61%) delete mode 100644 third_party/xla/xla/service/gpu/transforms/cudnn_workspace_rewriter.cc diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 96388e0c2b7864..65ec5409801709 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -313,7 +313,6 @@ cc_library( ":execution_stream_assignment", ":gpu_asm_opts_util", ":gpu_conv_runner", - ":gpu_fused_mha_runner", ":gpu_norm_runner", ":hlo_fusion_analysis", ":ir_emission_utils", @@ -356,9 +355,9 @@ cc_library( "//xla/service/gpu/runtime:conditional_thunk", "//xla/service/gpu/runtime:convolution_thunk", "//xla/service/gpu/runtime:copy_thunk", + "//xla/service/gpu/runtime:cudnn_thunk", "//xla/service/gpu/runtime:custom_call_thunk", "//xla/service/gpu/runtime:fft_thunk", - "//xla/service/gpu/runtime:fused_mha_thunk", "//xla/service/gpu/runtime:gemm_thunk", "//xla/service/gpu/runtime:gpublas_lt_matmul_thunk", "//xla/service/gpu/runtime:infeed_thunk", @@ -1045,31 +1044,6 @@ cc_library( ]), ) -cc_library( - name = "gpu_fused_mha_runner", - srcs = ["gpu_fused_mha_runner.cc"], - hdrs = ["gpu_fused_mha_runner.h"], - deps = [ - ":backend_configs_cc", - ":cublas_cudnn", - ":stream_executor_util", - "//xla:shape_util", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/stream_executor", - "//xla/stream_executor:dnn", - "//xla/stream_executor:lazy_op_runner", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/log", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@eigen_archive//:eigen3", - "@local_tsl//tsl/platform:statusor", - ], -) - cc_library( name = "cusolver_context", srcs = if_gpu_is_configured(["cusolver_context.cc"]), @@ -1779,6 +1753,7 @@ 
cc_library( "//xla/service/gpu/transforms:conv_padding_legalization", "//xla/service/gpu/transforms:conv_rewriter", "//xla/service/gpu/transforms:cublas_pad_for_gemms", + "//xla/service/gpu/transforms:cudnn_custom_call_compiler", "//xla/service/gpu/transforms:cudnn_fused_conv_rewriter", "//xla/service/gpu/transforms:cudnn_fused_mha_rewriter", "//xla/service/gpu/transforms:cudnn_fused_mha_transpose_fusion", @@ -1787,7 +1762,6 @@ cc_library( "//xla/service/gpu/transforms:cudnn_pad_for_convolutions", "//xla/service/gpu/transforms:cudnn_simplify_padding", "//xla/service/gpu/transforms:cudnn_vectorize_convolutions", - "//xla/service/gpu/transforms:cudnn_workspace_rewriter", "//xla/service/gpu/transforms:dot_sparsity_rewriter", "//xla/service/gpu/transforms:gpusolver_rewriter", "//xla/service/gpu/transforms:sort_rewriter", diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.cc b/third_party/xla/xla/service/gpu/gpu_compiler.cc index d367419d3b4776..46eacbbe358869 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler.cc +++ b/third_party/xla/xla/service/gpu/gpu_compiler.cc @@ -2150,8 +2150,8 @@ absl::StatusOr> GpuCompiler::RunBackend( }}; BinaryMap dnn_compiled_graphs; if (stream_exec) { - TF_RETURN_IF_ERROR(RunCudnnFusionCompilerPass(module.get(), stream_exec, - &dnn_compiled_graphs)); + TF_RETURN_IF_ERROR(RunCudnnCompilerPasses(module.get(), stream_exec, + &dnn_compiled_graphs)); } const DebugOptions& debug_opts = module->config().debug_options(); diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.h b/third_party/xla/xla/service/gpu/gpu_compiler.h index aa22bfcf3ba338..456e6755b0d83a 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler.h +++ b/third_party/xla/xla/service/gpu/gpu_compiler.h @@ -171,10 +171,10 @@ class GpuCompiler : public LLVMCompiler { return absl::OkStatus(); } - // Runs cuDNN fusion compiler pass. - virtual absl::Status RunCudnnFusionCompilerPass( - HloModule* module, se::StreamExecutor* stream_exec, - BinaryMap* dnn_compiled_graphs) { + // Runs cuDNN fusion and custom call compiler passes. + virtual absl::Status RunCudnnCompilerPasses(HloModule* module, + se::StreamExecutor* stream_exec, + BinaryMap* dnn_compiled_graphs) { return absl::OkStatus(); } diff --git a/third_party/xla/xla/service/gpu/gpu_fused_mha_runner.cc b/third_party/xla/xla/service/gpu/gpu_fused_mha_runner.cc deleted file mode 100644 index 566c0068f5dbba..00000000000000 --- a/third_party/xla/xla/service/gpu/gpu_fused_mha_runner.cc +++ /dev/null @@ -1,719 +0,0 @@ -/* Copyright 2023 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "xla/service/gpu/gpu_fused_mha_runner.h" - -#include -#include -#include - -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_format.h" -#include "Eigen/Core" -#include "xla/service/gpu/backend_configs.pb.h" -#include "xla/service/gpu/cublas_cudnn.h" -#include "xla/service/gpu/stream_executor_util.h" -#include "xla/shape.h" -#include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/dnn.h" -#include "xla/stream_executor/lazy_op_runner.h" -#include "xla/stream_executor/stream.h" -#include "xla/util.h" -#include "tsl/platform/statusor.h" - -namespace xla { -namespace gpu { - -namespace { -using se::DeviceMemory; -using se::DeviceMemoryBase; -using se::dnn::DataType; -using se::dnn::MatmulTensorDescriptor; -using se::dnn::TensorDescriptor; - -template -absl::Status RunFusedMHA(GpufMHAParams params, se::Stream *stream, - RunFusedMHAOptions options, - DeviceMemory lhs_bmm1_buffer, - DeviceMemory rhs_bmm1_buffer, - DeviceMemory rhs_bmm2_buffer, - DeviceMemory output_buffer, - DeviceMemoryBase bias_buffer, - DeviceMemoryBase scratch_memory, - DeviceMemoryBase activation_output, - DeviceMemoryBase seqlen_q, DeviceMemoryBase seqlen_k) { - se::dnn::LazyOpRunner *lazy_runner = - options.runner_cache->AsFusedMHARunner(); - std::optional> local_runner; - if (!lazy_runner) { - local_runner.emplace(params.config->algorithm); - lazy_runner = &*local_runner; - } - std::optional dropout_rate; - if (params.config->dropout_rate) { - dropout_rate = *params.config->dropout_rate; - } - - std::optional seed; - if (params.config->seed) { - seed = *params.config->seed; - } - - TF_ASSIGN_OR_RETURN(se::dnn::FusedMHAOp::Config config, - params.config->AsDnnFusedMHAOpConfig()); - TF_ASSIGN_OR_RETURN(auto *runner, - lazy_runner->GetOrCreateRunner(config, stream)); - return (*runner)(stream, options.profile_result, scratch_memory, - lhs_bmm1_buffer, rhs_bmm1_buffer, rhs_bmm2_buffer, - output_buffer, bias_buffer, activation_output, seqlen_q, - seqlen_k); -} - -template -absl::Status RunGpuFMHAImpl(const GpufMHAParams ¶ms, se::Stream *stream, - se::DeviceMemoryBase scratch_memory, - RunFusedMHAOptions options) { - auto lhs_bmm1_buffer = se::DeviceMemory(params.lhs_bmm1_buffer); - auto rhs_bmm1_buffer = se::DeviceMemory(params.rhs_bmm1_buffer); - auto rhs_bmm2_buffer = se::DeviceMemory(params.rhs_bmm2_buffer); - auto output_buffer = se::DeviceMemory(params.output_buffer); - auto activation_buffer = - params.activation_buffer.has_value() - ? se::DeviceMemory(*params.activation_buffer) - : se::DeviceMemoryBase(); - auto bias_buffer = params.bias_buffer.has_value() - ? se::DeviceMemory(*params.bias_buffer) - : se::DeviceMemoryBase(); - auto seqlen_q_buffer = - params.seqlen_q_buffer.has_value() - ? se::DeviceMemory(*params.seqlen_q_buffer) - : se::DeviceMemoryBase(); - auto seqlen_k_buffer = - params.seqlen_k_buffer.has_value() - ? 
se::DeviceMemory(*params.seqlen_k_buffer) - : se::DeviceMemoryBase(); - se::dnn::AlgorithmDesc algorithm = params.config->algorithm; - if (options.runner_cache) { - algorithm = options.runner_cache->ToAlgorithmDesc(); - } - - absl::Status run_status = absl::OkStatus(); - switch (params.config->kind) { - case CudnnfMHAKind::kSoftmaxDropout: - case CudnnfMHAKind::kSoftmax: - case CudnnfMHAKind::kScaleBiasSoftmax: - case CudnnfMHAKind::kScaleBiasSoftmaxDropout: - run_status = RunFusedMHA( - params, stream, options, lhs_bmm1_buffer, rhs_bmm1_buffer, - rhs_bmm2_buffer, output_buffer, bias_buffer, scratch_memory, - activation_buffer, seqlen_q_buffer, seqlen_k_buffer); - break; - default: - return Internal("Invalid cuDNN fMHA kind"); - } - - if (!run_status.ok()) { - return run_status; - } - - if (!stream->ok()) { - return Internal("Unable to launch FMHA with type %s and algorithm %s", - CudnnfMHAKindToString(params.config->kind), - algorithm.ToString()); - } - - return absl::OkStatus(); -} - -template -absl::Status RunFusedMHABackward( - GpufMHABackwardParams params, se::Stream *stream, - RunFusedMHABackwardOptions options, - DeviceMemory bmm1_grad_gemm1_rhs_buffer, - DeviceMemory bmm1_grad_gemm2_rhs_buffer, - DeviceMemory bmm2_grad_gemm1_lhs_buffer, - DeviceMemory bmm2_grad_gemm2_rhs_buffer, - DeviceMemory d_output_buffer, - DeviceMemory d_bmm1_lhs_buffer, - DeviceMemory d_bmm1_rhs_buffer, - DeviceMemory d_bmm2_rhs_buffer, DeviceMemoryBase d_s_buffer, - DeviceMemoryBase d_bias_buffer, DeviceMemoryBase fwd_output_buffer, - DeviceMemoryBase bias_buffer, DeviceMemoryBase scratch_memory, - DeviceMemoryBase seqlen_q, DeviceMemoryBase seqlen_k) { - se::dnn::LazyOpRunner *lazy_runner = - options.runner_cache->AsFusedMHABackwardRunner(); - std::optional> - local_runner; - if (!lazy_runner) { - local_runner.emplace(params.config->algorithm); - lazy_runner = &*local_runner; - } - std::optional dropout_rate; - if (params.config->dropout_rate) { - dropout_rate = *params.config->dropout_rate; - } - - std::optional seed; - if (params.config->seed) { - seed = *params.config->seed; - } - - TF_ASSIGN_OR_RETURN(se::dnn::FusedMHABackwardOp::Config config, - params.config->AsDnnFusedMHABackwardOpConfig()); - TF_ASSIGN_OR_RETURN(auto *runner, - lazy_runner->GetOrCreateRunner(config, stream)); - // TODO: pass in real softmax_sum, dQ_accum, fwd_output - return (*runner)(stream, options.profile_result, scratch_memory, - bmm1_grad_gemm1_rhs_buffer, bmm1_grad_gemm2_rhs_buffer, - bmm2_grad_gemm1_lhs_buffer, bmm2_grad_gemm2_rhs_buffer, - d_output_buffer, d_bmm1_lhs_buffer, d_bmm1_rhs_buffer, - d_bmm2_rhs_buffer, d_s_buffer, d_bias_buffer, - fwd_output_buffer, bias_buffer, seqlen_q, seqlen_k); - return absl::OkStatus(); -} - -template -absl::Status RunGpuFMHABackwardImpl(const GpufMHABackwardParams ¶ms, - se::Stream *stream, - se::DeviceMemoryBase scratch_memory, - RunFusedMHABackwardOptions options) { - auto bmm1_grad_gemm1_rhs_buffer = - se::DeviceMemory(params.bmm1_grad_gemm1_rhs_buffer); - auto bmm1_grad_gemm2_rhs_buffer = - se::DeviceMemory(params.bmm1_grad_gemm2_rhs_buffer); - auto bmm2_grad_gemm1_lhs_buffer = - se::DeviceMemory(params.bmm2_grad_gemm1_lhs_buffer); - auto bmm2_grad_gemm2_rhs_buffer = - se::DeviceMemory(params.bmm2_grad_gemm2_rhs_buffer); - auto d_output_buffer = se::DeviceMemory(params.d_output_buffer); - auto d_bmm1_lhs_buffer = - se::DeviceMemory(params.d_bmm1_lhs_buffer); - auto d_bmm1_rhs_buffer = - se::DeviceMemory(params.d_bmm1_rhs_buffer); - auto d_bmm2_rhs_buffer = - 
se::DeviceMemory(params.d_bmm2_rhs_buffer); - - // optional buffers - auto d_s_buffer = params.d_s_buffer.has_value() - ? se::DeviceMemory(*params.d_s_buffer) - : se::DeviceMemoryBase(); - - auto d_bias_buffer = params.d_bias_buffer.has_value() - ? se::DeviceMemory(*params.d_bias_buffer) - : se::DeviceMemoryBase(); - - auto fwd_output_buffer = - params.fwd_output_buffer.has_value() - ? se::DeviceMemory(*params.fwd_output_buffer) - : se::DeviceMemoryBase(); - - auto bias_buffer = params.bias_buffer.has_value() - ? se::DeviceMemory(*params.bias_buffer) - : se::DeviceMemoryBase(); - - auto seqlen_q_buffer = - params.seqlen_q_buffer.has_value() - ? se::DeviceMemory(*params.seqlen_q_buffer) - : se::DeviceMemoryBase(); - - auto seqlen_k_buffer = - params.seqlen_k_buffer.has_value() - ? se::DeviceMemory(*params.seqlen_k_buffer) - : se::DeviceMemoryBase(); - - se::dnn::AlgorithmDesc algorithm = params.config->algorithm; - if (options.runner_cache) { - algorithm = options.runner_cache->ToAlgorithmDesc(); - } - - absl::Status run_status = absl::OkStatus(); - switch (params.config->kind) { - case CudnnfMHAKind::kBackwardSoftmaxDropout: - case CudnnfMHAKind::kBackwardSoftmax: - case CudnnfMHAKind::kBackwardScaleBiasSoftmax: - case CudnnfMHAKind::kBackwardScaleBiasSoftmaxDropout: - run_status = RunFusedMHABackward( - params, stream, options, bmm1_grad_gemm1_rhs_buffer, - bmm1_grad_gemm2_rhs_buffer, bmm2_grad_gemm1_lhs_buffer, - bmm2_grad_gemm2_rhs_buffer, d_output_buffer, d_bmm1_lhs_buffer, - d_bmm1_rhs_buffer, d_bmm2_rhs_buffer, d_s_buffer, d_bias_buffer, - fwd_output_buffer, bias_buffer, scratch_memory, seqlen_q_buffer, - seqlen_k_buffer); - break; - default: - return Internal("Invalid cuDNN fMHA kind"); - } - - if (!run_status.ok()) { - return run_status; - } - - if (!stream->ok()) { - return Internal("Unable to launch FMHA with type %s and algorithm %s", - CudnnfMHAKindToString(params.config->kind), - algorithm.ToString()); - } - - return run_status; -} -} // namespace - -/*static*/ absl::StatusOr GpufMHAConfig::For( - const GpufMHADescriptor &desc) { - // Get shapes from desc. 
- const Shape &lhs_bmm1_shape = desc.lhs_bmm1_shape; - const Shape &rhs_bmm1_shape = desc.rhs_bmm1_shape; - const Shape &rhs_bmm2_shape = desc.rhs_bmm2_shape; - const Shape &intermediate_lhs_bmm2_shape = desc.intermediate_lhs_bmm2_shape; - const Shape &output_shape = desc.output_shapes[0]; - - // Get DNN dtype from primtive types - TF_ASSIGN_OR_RETURN( - DataType lhs_bmm1_type, - GetDNNDataTypeFromPrimitiveType(lhs_bmm1_shape.element_type())); - TF_ASSIGN_OR_RETURN( - DataType rhs_bmm1_type, - GetDNNDataTypeFromPrimitiveType(rhs_bmm1_shape.element_type())); - - TF_ASSIGN_OR_RETURN( - DataType rhs_bmm2_type, - GetDNNDataTypeFromPrimitiveType(rhs_bmm2_shape.element_type())); - TF_ASSIGN_OR_RETURN(DataType lhs_bmm2_type, - GetDNNDataTypeFromPrimitiveType( - intermediate_lhs_bmm2_shape.element_type())); - TF_ASSIGN_OR_RETURN(DataType output_type, GetDNNDataTypeFromPrimitiveType( - output_shape.element_type())); - GpufMHAConfig config; - config.input_type = lhs_bmm1_shape.element_type(); - config.output_type = output_shape.element_type(); - - // Get MatmulTensorDescriptors for BMM1 - config.lhs_bmm1 = - MatmulTensorDescriptor::For(lhs_bmm1_type, lhs_bmm1_shape.dimensions(), - desc.lhs_bmm1_shape.layout().minor_to_major(), - desc.bmm1_dnums.lhs_batch_dimensions(), - desc.bmm1_dnums.lhs_contracting_dimensions()); - config.rhs_bmm1 = - MatmulTensorDescriptor::For(rhs_bmm1_type, rhs_bmm1_shape.dimensions(), - desc.rhs_bmm1_shape.layout().minor_to_major(), - desc.bmm1_dnums.rhs_batch_dimensions(), - desc.bmm1_dnums.rhs_contracting_dimensions()); - - // Get MatmulTensorDescriptors for BMM2 - config.rhs_bmm2 = - MatmulTensorDescriptor::For(rhs_bmm2_type, rhs_bmm2_shape.dimensions(), - desc.rhs_bmm2_shape.layout().minor_to_major(), - desc.bmm2_dnums.rhs_batch_dimensions(), - desc.bmm2_dnums.rhs_contracting_dimensions()); - - config.intermediate_lhs_bmm2 = MatmulTensorDescriptor::For( - lhs_bmm2_type, intermediate_lhs_bmm2_shape.dimensions(), - desc.intermediate_lhs_bmm2_shape.layout().minor_to_major(), - desc.bmm2_dnums.lhs_batch_dimensions(), - desc.bmm2_dnums.lhs_contracting_dimensions()); - - config.output = TensorDescriptor::For(output_type, output_shape.dimensions(), - output_shape.layout().minor_to_major()); - - if (desc.output_shapes.size() > 1) { - const Shape &activation_shape = desc.output_shapes.back(); - // Generally, activation should have same type as output, but set it - // explicityly just to be safe. 
- TF_ASSIGN_OR_RETURN( - DataType activation_type, - GetDNNDataTypeFromPrimitiveType(activation_shape.element_type())); - config.activation = - TensorDescriptor::For(activation_type, activation_shape.dimensions(), - activation_shape.layout().minor_to_major()); - } - - if (desc.mask_shape) { - const Shape &mask_shape = *desc.mask_shape; - TF_ASSIGN_OR_RETURN(DataType mask_type, GetDNNDataTypeFromPrimitiveType( - mask_shape.element_type())); - config.mask = TensorDescriptor::For(mask_type, mask_shape.dimensions(), - mask_shape.layout().minor_to_major()); - } - - if (desc.bias_shape) { - const Shape &bias_shape = *desc.bias_shape; - TF_ASSIGN_OR_RETURN(DataType bias_type, GetDNNDataTypeFromPrimitiveType( - bias_shape.element_type())); - config.bias = TensorDescriptor::For(bias_type, bias_shape.dimensions(), - bias_shape.layout().minor_to_major()); - } - config.kind = desc.kind; - config.mask_type = desc.mask_type; - const CudnnfMHABackendConfig &backend_config = desc.backend_config; - config.algorithm = se::dnn::AlgorithmDesc(backend_config.algorithm()); - config.fmha_scale.emplace(backend_config.fmha_scale()); - config.dropout_rate.emplace(backend_config.dropout_rate()); - config.seed.emplace(backend_config.seed()); - return config; -} - -absl::StatusOr -GpufMHAConfig::AsDnnFusedMHAOpConfig() const { - double scale = 1.0; - if (fmha_scale.has_value()) { - scale = *fmha_scale; - } - TF_ASSIGN_OR_RETURN(se::dnn::FMHAMaskKind mask_type, - GetDNNFmhaMaskKindFromCudnnFmhaMaskKind(mask_type)); - - return se::dnn::FusedMHAOp::Config{ - scale, lhs_bmm1, rhs_bmm1, rhs_bmm2, intermediate_lhs_bmm2, - output, bias, activation, dropout_rate, seed, - mask_type}; -} - -/*static*/ absl::StatusOr GpufMHABackwardConfig::For( - const GpufMHABackwardDescriptor &desc) { - // Get shapes from desc. 
- - const Shape &bmm1_grad_gemm1_rhs_shape = desc.bmm1_grad_gemm1_rhs_shape; - const Shape &bmm1_grad_gemm2_rhs_shape = desc.bmm1_grad_gemm2_rhs_shape; - const Shape &bmm2_grad_gemm1_lhs_shape = desc.bmm2_grad_gemm1_lhs_shape; - const Shape &bmm2_grad_gemm2_rhs_shape = desc.bmm2_grad_gemm2_rhs_shape; - const Shape &d_output_shape = desc.d_output_shape; - const Shape &d_bmm1_lhs_shape = desc.d_bmm1_lhs_shape; - const Shape &d_bmm1_rhs_shape = desc.d_bmm1_rhs_shape; - const Shape &d_bmm2_rhs_shape = desc.d_bmm2_rhs_shape; - // Get DNN dtype from primtive types - TF_ASSIGN_OR_RETURN(DataType bmm1_grad_gemm1_rhs_type, - GetDNNDataTypeFromPrimitiveType( - bmm1_grad_gemm1_rhs_shape.element_type())); - - TF_ASSIGN_OR_RETURN(DataType bmm1_grad_gemm2_rhs_type, - GetDNNDataTypeFromPrimitiveType( - bmm1_grad_gemm2_rhs_shape.element_type())); - - TF_ASSIGN_OR_RETURN(DataType bmm2_grad_gemm1_lhs_type, - GetDNNDataTypeFromPrimitiveType( - bmm2_grad_gemm1_lhs_shape.element_type())); - - TF_ASSIGN_OR_RETURN(DataType bmm2_grad_gemm2_rhs_type, - GetDNNDataTypeFromPrimitiveType( - bmm2_grad_gemm2_rhs_shape.element_type())); - - TF_ASSIGN_OR_RETURN( - DataType d_output_type, - GetDNNDataTypeFromPrimitiveType(d_output_shape.element_type())); - - TF_ASSIGN_OR_RETURN( - DataType d_bmm1_lhs_type, - GetDNNDataTypeFromPrimitiveType(d_bmm1_lhs_shape.element_type())); - - TF_ASSIGN_OR_RETURN( - DataType d_bmm1_rhs_type, - GetDNNDataTypeFromPrimitiveType(d_bmm1_rhs_shape.element_type())); - - TF_ASSIGN_OR_RETURN( - DataType d_bmm2_rhs_type, - GetDNNDataTypeFromPrimitiveType(d_bmm2_rhs_shape.element_type())); - - GpufMHABackwardConfig config; - config.input_type = bmm1_grad_gemm1_rhs_shape.element_type(); - config.output_type = d_bmm1_lhs_shape.element_type(); - - // Get MatmulTensorDescriptors for lhs of BMM1 grad GEMM 1 - config.bmm1_grad_gemm1_rhs = MatmulTensorDescriptor::For( - bmm1_grad_gemm1_rhs_type, bmm1_grad_gemm1_rhs_shape.dimensions(), - desc.bmm1_grad_gemm1_rhs_shape.layout().minor_to_major(), - desc.bmm1_grad_gemm1_dnums.rhs_batch_dimensions(), - desc.bmm1_grad_gemm1_dnums.rhs_contracting_dimensions()); - - // Get MatmulTensorDescriptors for rhs of BMM1 grad GEMM 2 - config.bmm1_grad_gemm2_rhs = MatmulTensorDescriptor::For( - bmm1_grad_gemm2_rhs_type, bmm1_grad_gemm2_rhs_shape.dimensions(), - desc.bmm1_grad_gemm2_rhs_shape.layout().minor_to_major(), - desc.bmm1_grad_gemm2_dnums.rhs_batch_dimensions(), - desc.bmm1_grad_gemm2_dnums.rhs_contracting_dimensions()); - - // Get MatmulTensorDescriptors for BMM2 grad GEMM 1 - config.bmm2_grad_gemm1_lhs = MatmulTensorDescriptor::For( - bmm2_grad_gemm1_lhs_type, bmm2_grad_gemm1_lhs_shape.dimensions(), - desc.bmm2_grad_gemm1_lhs_shape.layout().minor_to_major(), - desc.bmm2_grad_gemm1_dnums.lhs_batch_dimensions(), - desc.bmm2_grad_gemm1_dnums.lhs_contracting_dimensions()); - - config.d_output = MatmulTensorDescriptor::For( - d_output_type, d_output_shape.dimensions(), - desc.d_output_shape.layout().minor_to_major(), - desc.bmm2_grad_gemm1_dnums.rhs_batch_dimensions(), - desc.bmm2_grad_gemm1_dnums.rhs_contracting_dimensions()); - - // Get MatmulTensorDescriptors for BMM2 grad GEMM 2 - config.bmm2_grad_gemm2_rhs = MatmulTensorDescriptor::For( - bmm2_grad_gemm2_rhs_type, bmm2_grad_gemm2_rhs_shape.dimensions(), - desc.bmm2_grad_gemm2_rhs_shape.layout().minor_to_major(), - desc.bmm2_grad_gemm2_dnums.rhs_batch_dimensions(), - desc.bmm2_grad_gemm2_dnums - .rhs_contracting_dimensions()); // FMHA TODO: transpose here? 
- - config.d_bmm1_lhs = - TensorDescriptor::For(d_bmm1_lhs_type, d_bmm1_lhs_shape.dimensions(), - d_bmm1_lhs_shape.layout().minor_to_major()); - config.d_bmm1_rhs = - TensorDescriptor::For(d_bmm1_rhs_type, d_bmm1_rhs_shape.dimensions(), - d_bmm1_rhs_shape.layout().minor_to_major()); - config.d_bmm2_rhs = - TensorDescriptor::For(d_bmm2_rhs_type, d_bmm2_rhs_shape.dimensions(), - d_bmm2_rhs_shape.layout().minor_to_major()); - config.d_s = TensorDescriptor::For( - bmm2_grad_gemm1_lhs_type, bmm2_grad_gemm1_lhs_shape.dimensions(), - bmm2_grad_gemm1_lhs_shape.layout().minor_to_major()); - - if (desc.d_bias_shape) { - const Shape &d_bias_shape = *desc.d_bias_shape; - // Get DNN dtype from primtive types - TF_ASSIGN_OR_RETURN(DataType d_bias_type, GetDNNDataTypeFromPrimitiveType( - d_bias_shape.element_type())); - config.d_bias = - TensorDescriptor::For(d_bias_type, d_bias_shape.dimensions(), - d_bias_shape.layout().minor_to_major()); - } - - if (desc.mask_shape) { - const Shape &mask_shape = *desc.mask_shape; - TF_ASSIGN_OR_RETURN(DataType mask_type, GetDNNDataTypeFromPrimitiveType( - mask_shape.element_type())); - config.mask = TensorDescriptor::For(mask_type, mask_shape.dimensions(), - mask_shape.layout().minor_to_major()); - } - if (desc.fwd_output_shape) { - const Shape &fwd_output_shape = *desc.fwd_output_shape; - TF_ASSIGN_OR_RETURN( - DataType fwd_output_type, - GetDNNDataTypeFromPrimitiveType(fwd_output_shape.element_type())); - config.fwd_output = - TensorDescriptor::For(fwd_output_type, fwd_output_shape.dimensions(), - fwd_output_shape.layout().minor_to_major()); - } - - if (desc.bias_shape) { - const Shape &bias_shape = *desc.bias_shape; - TF_ASSIGN_OR_RETURN(DataType bias_type, GetDNNDataTypeFromPrimitiveType( - bias_shape.element_type())); - config.bias = TensorDescriptor::For(bias_type, bias_shape.dimensions(), - bias_shape.layout().minor_to_major()); - } - - config.kind = desc.kind; - config.mask_type = desc.mask_type; - config.force_deterministic = desc.force_deterministic; - const CudnnfMHABackendConfig &backend_config = desc.backend_config; - config.algorithm = se::dnn::AlgorithmDesc(backend_config.algorithm()); - config.fmha_scale.emplace(backend_config.fmha_scale()); - config.dropout_rate.emplace(backend_config.dropout_rate()); - config.seed.emplace(backend_config.seed()); - return config; -} - -absl::StatusOr -GpufMHABackwardConfig::AsDnnFusedMHABackwardOpConfig() const { - double scale = 1.0; - if (fmha_scale.has_value()) { - scale = *fmha_scale; - } - TF_ASSIGN_OR_RETURN(se::dnn::FMHAMaskKind mask_type, - GetDNNFmhaMaskKindFromCudnnFmhaMaskKind(mask_type)); - - return se::dnn::FusedMHABackwardOp::Config{scale, - bmm1_grad_gemm1_rhs, - bmm1_grad_gemm2_rhs, - bmm2_grad_gemm1_lhs, - bmm2_grad_gemm2_rhs, - d_output, - d_bmm1_lhs, - d_bmm1_rhs, - d_bmm2_rhs, - d_s, - d_bias, - fwd_output, - bias, - dropout_rate, - seed, - mask_type, - force_deterministic}; -} - -/*static*/ absl::StatusOr GpufMHAParams::For( - const GpufMHAConfig &config, se::DeviceMemoryBase lhs_bmm1_buffer, - se::DeviceMemoryBase rhs_bmm1_buffer, se::DeviceMemoryBase rhs_bmm2_buffer, - se::DeviceMemoryBase output_buffer, - std::optional bias_buffer, - std::optional activation_buffer, - std::optional seqlen_q_buffer, - std::optional seqlen_k_buffer) { - GpufMHAParams params; - params.config = &config; - params.lhs_bmm1_buffer = lhs_bmm1_buffer; - params.rhs_bmm1_buffer = rhs_bmm1_buffer; - params.rhs_bmm2_buffer = rhs_bmm2_buffer; - params.output_buffer = output_buffer; - params.activation_buffer = 
activation_buffer; - params.bias_buffer = bias_buffer; - params.seqlen_q_buffer = seqlen_q_buffer; - params.seqlen_k_buffer = seqlen_k_buffer; - return params; -} - -/*static*/ absl::StatusOr GpufMHABackwardParams::For( - const GpufMHABackwardConfig &config, - se::DeviceMemoryBase bmm1_grad_gemm1_rhs_buffer, - se::DeviceMemoryBase bmm1_grad_gemm2_rhs_buffer, - se::DeviceMemoryBase bmm2_grad_gemm1_lhs_buffer, - se::DeviceMemoryBase bmm2_grad_gemm2_rhs_buffer, - se::DeviceMemoryBase d_output_buffer, - se::DeviceMemoryBase d_bmm1_lhs_buffer, - se::DeviceMemoryBase d_bmm1_rhs_buffer, - se::DeviceMemoryBase d_bmm2_rhs_buffer, - std::optional d_s_buffer, - std::optional d_bias_buffer, - std::optional fwd_output_buffer, - std::optional bias_buffer, - std::optional seqlen_q_buffer, - std::optional seqlen_k_buffer) { - GpufMHABackwardParams params; - params.config = &config; - params.bmm1_grad_gemm1_rhs_buffer = bmm1_grad_gemm1_rhs_buffer; - params.bmm1_grad_gemm2_rhs_buffer = bmm1_grad_gemm2_rhs_buffer; - params.bmm2_grad_gemm1_lhs_buffer = bmm2_grad_gemm1_lhs_buffer; - params.bmm2_grad_gemm2_rhs_buffer = bmm2_grad_gemm2_rhs_buffer; - params.d_output_buffer = d_output_buffer; - params.d_bmm1_lhs_buffer = d_bmm1_lhs_buffer; - params.d_bmm1_rhs_buffer = d_bmm1_rhs_buffer; - params.d_bmm2_rhs_buffer = d_bmm2_rhs_buffer; - params.d_s_buffer = d_s_buffer; - params.d_bias_buffer = d_bias_buffer; - params.fwd_output_buffer = fwd_output_buffer; - params.bias_buffer = bias_buffer; - params.seqlen_q_buffer = seqlen_q_buffer; - params.seqlen_k_buffer = seqlen_k_buffer; - return params; -} - -absl::Status RunGpuFMHA(const GpufMHAConfig &fmha_config, - se::DeviceMemoryBase lhs_bmm1_buffer, - se::DeviceMemoryBase rhs_bmm1_buffer, - se::DeviceMemoryBase rhs_bmm2_buffer, - se::DeviceMemoryBase output_buffer, - se::DeviceMemoryBase scratch_buffer, - std::optional bias_buffer, - std::optional activation_buffer, - std::optional seqlen_q_buffer, - std::optional seqlen_k_buffer, - se::Stream *stream, RunFusedMHAOptions options) { - TF_ASSIGN_OR_RETURN( - GpufMHAParams params, - GpufMHAParams::For(fmha_config, lhs_bmm1_buffer, rhs_bmm1_buffer, - rhs_bmm2_buffer, output_buffer, bias_buffer, - activation_buffer, seqlen_q_buffer, seqlen_k_buffer)); - PrimitiveType input_primitive_type = fmha_config.input_type; - switch (input_primitive_type) { - case F16: - return RunGpuFMHAImpl( - params, stream, scratch_buffer, options); - case BF16: - return RunGpuFMHAImpl( - params, stream, scratch_buffer, options); - default: - return absl::UnimplementedError(absl::StrFormat( - "Unimplemented fused MHA with %s", ToString(fmha_config))); - } - return absl::OkStatus(); -} - -absl::Status RunGpuFMHABackward( - const GpufMHABackwardConfig &fmha_config, - se::DeviceMemoryBase bmm1_grad_gemm1_rhs_buffer, - se::DeviceMemoryBase bmm1_grad_gemm2_rhs_buffer, - se::DeviceMemoryBase bmm2_grad_gemm1_lhs_buffer, - se::DeviceMemoryBase bmm2_grad_gemm2_rhs_buffer, - se::DeviceMemoryBase d_output_buffer, se::DeviceMemoryBase scratch_buffer, - se::DeviceMemoryBase d_bmm1_lhs_buffer, - se::DeviceMemoryBase d_bmm1_rhs_buffer, - se::DeviceMemoryBase d_bmm2_rhs_buffer, - std::optional d_s_buffer, - std::optional d_bias_buffer, - std::optional fwd_output_buffer, - std::optional bias_buffer, - std::optional seqlen_q_buffer, - std::optional seqlen_k_buffer, se::Stream *stream, - RunFusedMHABackwardOptions options) { - TF_ASSIGN_OR_RETURN( - GpufMHABackwardParams params, - GpufMHABackwardParams::For( - fmha_config, bmm1_grad_gemm1_rhs_buffer, 
bmm1_grad_gemm2_rhs_buffer, - bmm2_grad_gemm1_lhs_buffer, bmm2_grad_gemm2_rhs_buffer, - d_output_buffer, d_bmm1_lhs_buffer, d_bmm1_rhs_buffer, - d_bmm2_rhs_buffer, d_s_buffer, d_bias_buffer, fwd_output_buffer, - bias_buffer, seqlen_q_buffer, seqlen_k_buffer)); - PrimitiveType input_primitive_type = fmha_config.input_type; - switch (input_primitive_type) { - case F16: - return RunGpuFMHABackwardImpl( - params, stream, scratch_buffer, options); - case BF16: - return RunGpuFMHABackwardImpl(params, stream, - scratch_buffer, options); - default: - return Unimplemented("Unimplemented fused MHA backward"); - } - return absl::OkStatus(); -} - -std::string ToString(const GpufMHAConfig &config) { - std::string result = "GpufMHAConfig:\n"; - absl::StrAppend(&result, - "input_type: ", PrimitiveType_Name(config.input_type), ", "); - absl::StrAppend( - &result, "output_type: ", PrimitiveType_Name(config.output_type), ", "); - absl::StrAppend(&result, "Kind: ", CudnnfMHAKindToString(config.kind), ", "); - if (config.fmha_scale) { - absl::StrAppend(&result, "fmha_scale: ", *config.fmha_scale, ", "); - } - if (config.dropout_rate) { - absl::StrAppend(&result, "dropout_rate: ", *config.dropout_rate, ", "); - } - if (config.seed) { - absl::StrAppend(&result, "seed: ", *config.seed, ", "); - } - absl::StrAppend(&result, "Algorithm Desc: ", config.algorithm.ToString(), - "\n"); - absl::StrAppend(&result, "lhs_bmm1: ", config.lhs_bmm1.ToString(), "\n"); - absl::StrAppend(&result, "rhs_bmm1: ", config.rhs_bmm1.ToString(), "\n"); - absl::StrAppend(&result, "rhs_bmm2: ", config.rhs_bmm2.ToString(), "\n"); - absl::StrAppend(&result, "intermediate_lhs_bmm2: ", - config.intermediate_lhs_bmm2.ToString(), "\n"); - absl::StrAppend(&result, "output: ", config.output.ToString(), "\n"); - - if (config.mask) { - absl::StrAppend(&result, "mask: ", (*config.mask).ToString(), "\n"); - } - - if (config.bias) { - absl::StrAppend(&result, "bias: ", (*config.bias).ToString(), "\n"); - } - - return result; -} - -} // namespace gpu -} // namespace xla diff --git a/third_party/xla/xla/service/gpu/gpu_fused_mha_runner.h b/third_party/xla/xla/service/gpu/gpu_fused_mha_runner.h deleted file mode 100644 index d0621cbdff6d74..00000000000000 --- a/third_party/xla/xla/service/gpu/gpu_fused_mha_runner.h +++ /dev/null @@ -1,431 +0,0 @@ -/* Copyright 2023 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_GPU_FUSED_MHA_RUNNER_H_ -#define XLA_SERVICE_GPU_GPU_FUSED_MHA_RUNNER_H_ - -#include -#include -#include -#include -#include -#include - -#include "absl/container/inlined_vector.h" -#include "absl/log/log.h" -#include "absl/status/status.h" -#include "xla/service/gpu/backend_configs.pb.h" -#include "xla/service/gpu/cublas_cudnn.h" -#include "xla/shape.h" -#include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/dnn.h" -#include "xla/stream_executor/lazy_op_runner.h" -#include "xla/stream_executor/stream_executor.h" -#include "xla/xla_data.pb.h" - -namespace xla { -namespace gpu { - -inline absl::StatusOr AsCudnnFmhaMaskKind( - xla::gpu::CudnnfMHABackendConfig_MaskType mask_type) { - switch (mask_type) { - case xla::gpu::CudnnfMHABackendConfig::NO_MASK: - return xla::gpu::CudnnfMHAMaskKind::kNoMask; - case xla::gpu::CudnnfMHABackendConfig::PADDING: - return xla::gpu::CudnnfMHAMaskKind::kPadding; - case xla::gpu::CudnnfMHABackendConfig::CAUSAL: - return xla::gpu::CudnnfMHAMaskKind::kCausal; - case xla::gpu::CudnnfMHABackendConfig::PADDING_CAUSAL: - return xla::gpu::CudnnfMHAMaskKind::kPaddingCausal; - case xla::gpu::CudnnfMHABackendConfig::ALIBI: - return xla::gpu::CudnnfMHAMaskKind::kAlibi; - default: - return xla::Internal("Unknown fmha mask kind."); - } -} - -// This is an interim structure to hold the parameters to construct a -// GpufMHAConfig. -// Struct to describe properties of a FMHA without being tied to specific -// IR. Will be used to help build FMHA thunks from either XLA HLO or -// LHLO GPU dialect in MLIR. -struct GpufMHADescriptor { - CudnnfMHAKind kind; - CudnnfMHABackendConfig backend_config; - CudnnfMHAMaskKind mask_type; - Shape lhs_bmm1_shape; - Shape rhs_bmm1_shape; - Shape rhs_bmm2_shape; - Shape intermediate_lhs_bmm2_shape; - // This will contain both output shape and activation shape - absl::InlinedVector output_shapes; - DotDimensionNumbers bmm1_dnums; - DotDimensionNumbers bmm2_dnums; - - std::optional mask_shape; - std::optional bias_shape; -}; - -struct GpufMHABackwardDescriptor { - CudnnfMHAKind kind; - CudnnfMHABackendConfig backend_config; - CudnnfMHAMaskKind mask_type; - Shape bmm1_grad_gemm1_rhs_shape; - Shape bmm1_grad_gemm2_rhs_shape; - Shape bmm2_grad_gemm1_lhs_shape; - Shape bmm2_grad_gemm2_rhs_shape; - Shape d_output_shape; - Shape d_bmm1_lhs_shape; - Shape d_bmm1_rhs_shape; - Shape d_bmm2_rhs_shape; - DotDimensionNumbers bmm1_grad_gemm1_dnums; - DotDimensionNumbers bmm1_grad_gemm2_dnums; - DotDimensionNumbers bmm2_grad_gemm1_dnums; - DotDimensionNumbers bmm2_grad_gemm2_dnums; - - std::optional d_s_shape; - std::optional fwd_output_shape; - std::optional mask_shape; - std::optional d_bias_shape; - std::optional bias_shape; - bool force_deterministic; -}; - -// Structure to describe static properties of a GPU fused Multi-Headed -// Attention. 
-struct GpufMHAConfig { - static absl::StatusOr For(const GpufMHADescriptor& fmha_desc); - - absl::StatusOr AsDnnFusedMHAOpConfig() const; - - PrimitiveType - input_type; // Capture the primitive type of one of the inputs of BMM1 - PrimitiveType output_type; - CudnnfMHAKind kind; - std::optional fmha_scale; - std::optional dropout_rate; - std::optional seed; - - se::dnn::AlgorithmDesc algorithm; - CudnnfMHAMaskKind mask_type; - // bias -> [1, num_attn_heads, q_seq_len, kv_seq_len] - // mask -> [batch_size, 1, q_seq_len, kv_seq_len] - se::dnn::MatmulTensorDescriptor lhs_bmm1; - se::dnn::MatmulTensorDescriptor rhs_bmm1; - se::dnn::MatmulTensorDescriptor rhs_bmm2; - se::dnn::MatmulTensorDescriptor intermediate_lhs_bmm2; - se::dnn::TensorDescriptor output; - - std::optional activation; - std::optional mask; - std::optional bias; -}; - -// Structure to describe static properties of a GPU fused Multi-Headed -// Attention backward. -struct GpufMHABackwardConfig { - static absl::StatusOr For( - const GpufMHABackwardDescriptor& fmha_desc); - - absl::StatusOr - AsDnnFusedMHABackwardOpConfig() const; - - PrimitiveType - input_type; // Capture the primitive type of one of the inputs of BMM1 - PrimitiveType output_type; - CudnnfMHAKind kind; - std::optional fmha_scale; - std::optional dropout_rate; - std::optional seed; - - se::dnn::AlgorithmDesc algorithm; - CudnnfMHAMaskKind mask_type; - // mask -> [batch_size, 1, q_seq_len, kv_seq_len] - // d_bias -> [1, num_heads, q_seq_len, kv_seq_len] - se::dnn::MatmulTensorDescriptor bmm1_grad_gemm1_rhs; - se::dnn::MatmulTensorDescriptor bmm1_grad_gemm2_rhs; - se::dnn::MatmulTensorDescriptor bmm2_grad_gemm1_lhs; - se::dnn::MatmulTensorDescriptor bmm2_grad_gemm2_rhs; - se::dnn::MatmulTensorDescriptor d_output; - se::dnn::TensorDescriptor d_bmm1_lhs; - se::dnn::TensorDescriptor d_bmm1_rhs; - se::dnn::TensorDescriptor d_bmm2_rhs; - std::optional d_s; - std::optional mask; - std::optional d_bias; - std::optional fwd_output; - std::optional bias; - bool force_deterministic; -}; - -// Implementation struct exposed for debugging and log analysis. 
-struct GpufMHAParams { - static absl::StatusOr For( - const GpufMHAConfig& config, se::DeviceMemoryBase lhs_bmm1_buffer, - se::DeviceMemoryBase rhs_bmm1_buffer, - se::DeviceMemoryBase rhs_bmm2_buffer, se::DeviceMemoryBase output_buffer, - std::optional bias_buffer, - std::optional activation_buffer, - std::optional seqlen_q_buffer, - std::optional seqlen_k_buffer); - - const GpufMHAConfig* config; // Not owned - se::DeviceMemoryBase lhs_bmm1_buffer; - se::DeviceMemoryBase rhs_bmm1_buffer; - se::DeviceMemoryBase rhs_bmm2_buffer; - se::DeviceMemoryBase output_buffer; - std::optional activation_buffer; - std::optional bias_buffer; - std::optional seqlen_q_buffer; - std::optional seqlen_k_buffer; -}; - -struct GpufMHABackwardParams { - static absl::StatusOr For( - const GpufMHABackwardConfig& config, - se::DeviceMemoryBase bmm1_grad_gemm1_rhs_buffer, - se::DeviceMemoryBase bmm1_grad_gemm2_rhs_buffer, - se::DeviceMemoryBase bmm2_grad_gemm1_lhs_buffer, - se::DeviceMemoryBase bmm2_grad_gemm2_rhs_buffer, - se::DeviceMemoryBase d_output_buffer, - se::DeviceMemoryBase d_bmm1_lhs_buffer, - se::DeviceMemoryBase d_bmm1_rhs_buffer, - se::DeviceMemoryBase d_bmm2_rhs_buffer, - std::optional d_s_buffer, - std::optional d_bias_buffer, - std::optional fwd_output_buffer, - std::optional bias_buffer, - std::optional seqlen_q_buffer, - std::optional seqlen_k_buffer); - - const GpufMHABackwardConfig* config; // Not owned - se::DeviceMemoryBase bmm1_grad_gemm1_rhs_buffer; - se::DeviceMemoryBase bmm1_grad_gemm2_rhs_buffer; - se::DeviceMemoryBase bmm2_grad_gemm1_lhs_buffer; - se::DeviceMemoryBase bmm2_grad_gemm2_rhs_buffer; - se::DeviceMemoryBase d_output_buffer; - se::DeviceMemoryBase d_bmm1_lhs_buffer; - se::DeviceMemoryBase d_bmm1_rhs_buffer; - se::DeviceMemoryBase d_bmm2_rhs_buffer; - std::optional d_s_buffer; - std::optional d_bias_buffer; - std::optional fwd_output_buffer; - std::optional bias_buffer; - std::optional seqlen_q_buffer; - std::optional seqlen_k_buffer; -}; - -class FusedMultiHeadedAttentionRunner { - public: - using Repr = - std::variant>>; - - FusedMultiHeadedAttentionRunner() = default; - - explicit FusedMultiHeadedAttentionRunner( - std::unique_ptr> runner) - : repr_(std::move(runner)) {} - - explicit FusedMultiHeadedAttentionRunner(Repr runner) - : repr_(std::move(runner)) {} - - explicit FusedMultiHeadedAttentionRunner(const GpufMHAConfig& config) - : FusedMultiHeadedAttentionRunner(CreateRunner(config)) { - if (std::holds_alternative(repr_)) { - CHECK(false) << "Cannot construct FusedMultiHeadedAttentionRunner with " - "std::monostate"; - } - } - - se::dnn::AlgorithmDesc ToAlgorithmDesc() const { - return std::visit(ToAlgorithmDescVisitor{}, repr_); - } - - se::dnn::LazyOpRunner* AsFusedMHARunner() { - CHECK(std::holds_alternative< - std::unique_ptr>>(repr_)); - return std::get< - std::unique_ptr>>( - repr_) - .get(); - } - - private: - // The CreateRunner function is defined as static because it - // doesn't need access to any non-static member variables of the - // FusedMultiHeadedAttentionRunner class. Defining it static makes it easy to - // use and makes it clear that it is a utility function that doesn't rely on - // the state of any specific instance of the class. 
- static Repr CreateRunner(const GpufMHAConfig& config) { - switch (config.kind) { - case CudnnfMHAKind::kSoftmaxDropout: - case CudnnfMHAKind::kSoftmax: - case CudnnfMHAKind::kScaleBiasSoftmax: - case CudnnfMHAKind::kScaleBiasSoftmaxDropout: - return std::make_unique>( - config.algorithm); - default: - LOG(FATAL) << "Internal error: unsupported CUDNN MHA kind in " - "FusedMultiHeadedAttentionRunner"; - } - } - - struct ToAlgorithmDescVisitor { - template - se::dnn::AlgorithmDesc operator()(const RunnerPtr& runner) { - return runner->ToAlgorithmDesc(); - } - - se::dnn::AlgorithmDesc operator()(const std::monostate&) { - CHECK(false) << "Internal error: uninitialized runner in ToAlgorithmDesc"; - } - }; - - Repr repr_; -}; - -class FusedMultiHeadedAttentionBackwardRunner { - public: - using Repr = std::variant< - std::monostate, // To allow XXX default ctor - std::unique_ptr>>; - - FusedMultiHeadedAttentionBackwardRunner() = default; - - explicit FusedMultiHeadedAttentionBackwardRunner( - std::unique_ptr> - runner) - : repr_(std::move(runner)) {} - - explicit FusedMultiHeadedAttentionBackwardRunner(Repr runner) - : repr_(std::move(runner)) {} - - explicit FusedMultiHeadedAttentionBackwardRunner( - const GpufMHABackwardConfig& config) - : FusedMultiHeadedAttentionBackwardRunner(CreateRunner(config)) { - if (std::holds_alternative(repr_)) { - CHECK(false) - << "Cannot construct FusedMultiHeadedAttentionBackwardRunner with " - "std::monostate"; - } - } - - se::dnn::AlgorithmDesc ToAlgorithmDesc() const { - return std::visit(ToAlgorithmDescVisitor{}, repr_); - } - - se::dnn::LazyOpRunner* - AsFusedMHABackwardRunner() { - CHECK(std::holds_alternative< - std::unique_ptr>>( - repr_)); - return std::get>>(repr_) - .get(); - } - - private: - // The CreateRunner function is defined as static because it - // doesn't need access to any non-static member variables of the - // FusedMultiHeadedAttentionBackwardRunner class. Defining it static makes it - // easy to use and makes it clear that it is a utility function that doesn't - // rely on the state of any specific instance of the class. - static Repr CreateRunner(const GpufMHABackwardConfig& config) { - switch (config.kind) { - case CudnnfMHAKind::kBackwardSoftmaxDropout: - case CudnnfMHAKind::kBackwardSoftmax: - case CudnnfMHAKind::kBackwardScaleBiasSoftmax: - case CudnnfMHAKind::kBackwardScaleBiasSoftmaxDropout: - return std::make_unique< - se::dnn::LazyOpRunner>( - config.algorithm); - default: - LOG(FATAL) << "Internal error: unsupported CUDNN MHA kind in " - "FusedMultiHeadedAttentionBackwardRunner"; - } - } - - struct ToAlgorithmDescVisitor { - template - se::dnn::AlgorithmDesc operator()(const RunnerPtr& runner) { - return runner->ToAlgorithmDesc(); - } - - se::dnn::AlgorithmDesc operator()(const std::monostate&) { - CHECK(false) << "Internal error: uninitialized runner in ToAlgorithmDesc"; - } - }; - - Repr repr_; -}; - -struct RunFusedMHAOptions { - // Nullable output-parameter pointer for profiling results. - // Profile results remain unused for now since cuDNN FMHA has only one - // algorithm for now. - se::dnn::ProfileResult* profile_result = nullptr; - - // Use this runner cache (and its configured algorithm), instead of the one - // from the instruction. - FusedMultiHeadedAttentionRunner* runner_cache; -}; - -struct RunFusedMHABackwardOptions { - // Nullable output-parameter pointer for profiling results. - // Profile results remain unused for now since cuDNN FMHA has only one - // algorithm for now. 
- se::dnn::ProfileResult* profile_result = nullptr; - - // Use this runner cache (and its configured algorithm), instead of the one - // from the instruction. - FusedMultiHeadedAttentionBackwardRunner* runner_cache; -}; - -absl::Status RunGpuFMHA(const GpufMHAConfig& fmha_config, - se::DeviceMemoryBase lhs_bmm1_buffer, - se::DeviceMemoryBase rhs_bmm1_buffer, - se::DeviceMemoryBase rhs_bmm2_buffer, - se::DeviceMemoryBase output_buffer, - se::DeviceMemoryBase scratch_buffer, - std::optional bias_buffer, - std::optional activation_buffer, - std::optional seqlen_q_buffer, - std::optional seqlen_k_buffer, - se::Stream* stream, RunFusedMHAOptions = {}); - -absl::Status RunGpuFMHABackward( - const GpufMHABackwardConfig& fmha_config, - se::DeviceMemoryBase bmm1_grad_gemm1_rhs_buffer, - se::DeviceMemoryBase bmm1_grad_gemm2_rhs_buffer, - se::DeviceMemoryBase bmm2_grad_gemm1_lhs_buffer, - se::DeviceMemoryBase bmm2_grad_gemm2_rhs_buffer, - se::DeviceMemoryBase d_output_buffer, se::DeviceMemoryBase scratch_buffer, - se::DeviceMemoryBase d_bmm1_lhs_buffer, - se::DeviceMemoryBase d_bmm1_rhs_buffer, - se::DeviceMemoryBase d_bmm2_rhs_buffer, - std::optional d_s_buffer, - std::optional d_bias_buffer, - std::optional fwd_output_buffer, - std::optional bias_buffer, - std::optional seqlen_q_buffer, - std::optional seqlen_k_buffer, se::Stream* stream, - RunFusedMHABackwardOptions = {}); - -std::string ToString(const GpufMHAConfig& config); - -} // namespace gpu -} // namespace xla -#endif // XLA_SERVICE_GPU_GPU_FUSED_MHA_RUNNER_H_ diff --git a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc index 03690caeaa8121..b73225a1bd3c56 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc @@ -26,7 +26,6 @@ limitations under the License. #include #include #include -#include #include #include "absl/container/flat_hash_map.h" @@ -99,7 +98,6 @@ limitations under the License. #include "xla/service/gpu/fusions/triton/triton_fusion_emitter.h" #include "xla/service/gpu/gpu_asm_opts_util.h" #include "xla/service/gpu/gpu_conv_runner.h" -#include "xla/service/gpu/gpu_fused_mha_runner.h" #include "xla/service/gpu/gpu_norm_runner.h" #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/ir_emission_utils.h" @@ -122,7 +120,6 @@ limitations under the License. #include "xla/service/gpu/runtime/copy_thunk.h" #include "xla/service/gpu/runtime/custom_call_thunk.h" #include "xla/service/gpu/runtime/fft_thunk.h" -#include "xla/service/gpu/runtime/fused_mha_thunk.h" #include "xla/service/gpu/runtime/gemm_thunk.h" #include "xla/service/gpu/runtime/gpublas_lt_matmul_thunk.h" #include "xla/service/gpu/runtime/infeed_thunk.h" @@ -173,6 +170,7 @@ limitations under the License. 
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #include "xla/service/gpu/runtime/cholesky_thunk.h" #include "xla/service/gpu/runtime/cub_sort_thunk.h" +#include "xla/service/gpu/runtime/cudnn_thunk.h" #include "xla/service/gpu/runtime/triangular_solve_thunk.h" #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM @@ -955,221 +953,17 @@ absl::Status IrEmitterUnnested::EmitNormThunk( return absl::OkStatus(); } -absl::Status IrEmitterUnnested::EmitFusedMHAThunk( +absl::Status IrEmitterUnnested::EmitCuDnnThunk( const HloCustomCallInstruction* instr) { - const HloInstruction* lhs_bmm1 = instr->operand(0); - const HloInstruction* rhs_bmm1 = instr->operand(1); - const HloInstruction* rhs_bmm2 = instr->operand(2); - - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice lhs_bmm1_slice, - GetAllocationSliceForHlo(lhs_bmm1)); - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice rhs_bmm1_slice, - GetAllocationSliceForHlo(rhs_bmm1)); - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice rhs_bmm2_slice, - GetAllocationSliceForHlo(rhs_bmm2)); - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice output_slice, - GetAllocationSliceForHlo(instr, {0})); - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice scratch_slice, - GetAllocationSliceForHlo( - instr, {instr->shape().tuple_shapes_size() - 1})); - BufferAllocation::Slice activation_slice; - bool has_activation = xla::ShapeUtil::TupleElementCount(instr->shape()) == 3; - if (has_activation) { - TF_ASSIGN_OR_RETURN(activation_slice, GetAllocationSliceForHlo(instr, {1})); - } - - TF_ASSIGN_OR_RETURN(const xla::gpu::CudnnfMHAKind kind, - xla::gpu::GetCudnnfMHAKind(instr)); - BufferAllocation::Slice mask_slice, bias_slice; - BufferAllocation::Slice seqlen_q_slice, seqlen_k_slice; - std::optional mask_shape, bias_shape; - { - bool has_bias = kind == CudnnfMHAKind::kScaleBiasSoftmax || - kind == CudnnfMHAKind::kScaleBiasSoftmaxDropout; - - if (has_bias) { - const HloInstruction* bias = instr->operand(3); - TF_ASSIGN_OR_RETURN(bias_slice, GetAllocationSliceForHlo(bias)); - bias_shape = bias->shape(); - } - int64_t seqlen_qk_operand_index = 3 + has_bias; - bool has_seqlen_qk = seqlen_qk_operand_index == instr->operand_count() - 2; - if (has_seqlen_qk) { - const HloInstruction* seqlen_q = instr->operand(seqlen_qk_operand_index); - TF_ASSIGN_OR_RETURN(seqlen_q_slice, GetAllocationSliceForHlo(seqlen_q)); - const HloInstruction* seqlen_k = - instr->operand(seqlen_qk_operand_index + 1); - TF_ASSIGN_OR_RETURN(seqlen_k_slice, GetAllocationSliceForHlo(seqlen_k)); - } - } - - TF_ASSIGN_OR_RETURN(const auto gpu_config, - instr->backend_config()); - const xla::gpu::CudnnfMHABackendConfig& config = - gpu_config.cudnn_fmha_backend_config(); - Shape intermediate_tensor_shape(config.intermediate_tensor_shape()); - absl::InlinedVector output_shapes = { - ShapeUtil::GetSubshape(instr->shape(), {0})}; - if (has_activation) { - output_shapes.push_back(ShapeUtil::GetSubshape(instr->shape(), {1})); - } - TF_ASSIGN_OR_RETURN(const auto mask_type, - AsCudnnFmhaMaskKind(config.mask_type())); - GpufMHADescriptor descriptor = {kind, - config, - mask_type, - lhs_bmm1->shape(), - rhs_bmm1->shape(), - rhs_bmm2->shape(), - intermediate_tensor_shape, - output_shapes, - config.bmm1_dot_dimension_numbers(), - config.bmm2_dot_dimension_numbers(), - mask_shape, - bias_shape}; - - TF_ASSIGN_OR_RETURN(GpufMHAConfig fmha_config, - GpufMHAConfig::For(descriptor)); - AddThunkToThunkSequence(std::make_unique( - Thunk::ThunkInfo::WithProfileAnnotation(instr), std::move(fmha_config), - lhs_bmm1_slice, rhs_bmm1_slice, rhs_bmm2_slice, output_slice, - 
scratch_slice, mask_slice, bias_slice, activation_slice, seqlen_q_slice, - seqlen_k_slice)); - return absl::OkStatus(); -} - -absl::Status IrEmitterUnnested::EmitFusedMHABackwardThunk( - const HloCustomCallInstruction* instr) { - TF_ASSIGN_OR_RETURN(const auto gpu_config, - instr->backend_config()); - const xla::gpu::CudnnfMHABackendConfig& config = - gpu_config.cudnn_fmha_backend_config(); - - int input_index = 0; - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice bmm1_grad_gemm1_rhs_slice, - GetAllocationSliceForHlo(instr->operand(input_index))); - Shape bmm1_grad_gemm1_rhs_shape = instr->operand(input_index++)->shape(); - - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice bmm1_grad_gemm2_rhs_slice, - GetAllocationSliceForHlo(instr->operand(input_index))); - Shape bmm1_grad_gemm2_rhs_shape = instr->operand(input_index++)->shape(); - - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice bmm2_grad_gemm2_rhs_slice, - GetAllocationSliceForHlo(instr->operand(input_index))); - Shape bmm2_grad_gemm2_rhs_shape = instr->operand(input_index++)->shape(); - - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice bmm2_grad_gemm1_lhs_slice, - GetAllocationSliceForHlo(instr->operand(input_index))); - Shape bmm2_grad_gemm1_lhs_shape; - - Shape intermediate_tensor_shape(config.intermediate_tensor_shape()); - bmm2_grad_gemm1_lhs_shape = intermediate_tensor_shape; - input_index++; - - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice d_output_slice, - GetAllocationSliceForHlo(instr->operand(input_index))); - Shape d_output_shape = instr->operand(input_index++)->shape(); - - TF_ASSIGN_OR_RETURN(const CudnnfMHAKind kind, GetCudnnfMHAKind(instr)); - BufferAllocation::Slice mask_slice; - std::optional mask_shape; - - bool has_bias = (kind == CudnnfMHAKind::kBackwardScaleBiasSoftmax || - kind == CudnnfMHAKind::kBackwardScaleBiasSoftmaxDropout); - BufferAllocation::Slice bias_slice; - std::optional bias_shape; - if (has_bias) { - TF_ASSIGN_OR_RETURN(bias_slice, - GetAllocationSliceForHlo(instr->operand(input_index))); - bias_shape = instr->operand(input_index++)->shape(); - } - - BufferAllocation::Slice fwd_output_slice; - std::optional fwd_output_shape; - - TF_ASSIGN_OR_RETURN(fwd_output_slice, - GetAllocationSliceForHlo(instr->operand(input_index))); - fwd_output_shape = instr->operand(input_index++)->shape(); - - BufferAllocation::Slice seqlen_q_slice, seqlen_k_slice; - bool has_seqlen_qk = input_index == instr->operand_count() - 2; - if (has_seqlen_qk) { - const HloInstruction* seqlen_q = instr->operand(input_index); - TF_ASSIGN_OR_RETURN(seqlen_q_slice, GetAllocationSliceForHlo(seqlen_q)); - const HloInstruction* seqlen_k = instr->operand(input_index + 1); - TF_ASSIGN_OR_RETURN(seqlen_k_slice, GetAllocationSliceForHlo(seqlen_k)); - input_index += 2; - } - TF_RET_CHECK(input_index == instr->operand_count()); - - int output_index = 0; - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice d_bmm1_lhs_slice, - GetAllocationSliceForHlo(instr, {output_index})); - Shape d_bmm1_lhs_shape = - ShapeUtil::GetSubshape(instr->shape(), {output_index++}); - - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice d_bmm1_rhs_slice, - GetAllocationSliceForHlo(instr, {output_index})); - Shape d_bmm1_rhs_shape = - ShapeUtil::GetSubshape(instr->shape(), {output_index++}); - - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice d_bmm2_rhs_slice, - GetAllocationSliceForHlo(instr, {output_index})); - Shape d_bmm2_rhs_shape = - ShapeUtil::GetSubshape(instr->shape(), {output_index++}); - - BufferAllocation::Slice d_s_slice; - std::optional d_s_shape; - - bool has_dbias = 
instr->shape().tuple_shapes().size() == 5; - BufferAllocation::Slice d_bias_slice; - std::optional d_bias_shape; - if (has_dbias) { - TF_ASSIGN_OR_RETURN(d_bias_slice, - GetAllocationSliceForHlo(instr, {output_index})); - d_bias_shape = ShapeUtil::GetSubshape(instr->shape(), {output_index++}); - } - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice scratch_slice, - GetAllocationSliceForHlo(instr, {output_index++})); - TF_RET_CHECK(output_index == instr->shape().tuple_shapes().size()); - TF_ASSIGN_OR_RETURN(const auto mask_type, - AsCudnnFmhaMaskKind(config.mask_type())); - bool force_deterministic = config.force_deterministic(); - GpufMHABackwardDescriptor descriptor = { - kind, - config, - mask_type, - bmm1_grad_gemm1_rhs_shape, - bmm1_grad_gemm2_rhs_shape, - bmm2_grad_gemm1_lhs_shape, - bmm2_grad_gemm2_rhs_shape, - d_output_shape, - d_bmm1_lhs_shape, - d_bmm1_rhs_shape, - d_bmm2_rhs_shape, - config.bmm1_grad_gemm1_dot_dimension_numbers(), - config.bmm1_grad_gemm2_dot_dimension_numbers(), - config.bmm2_grad_gemm1_dot_dimension_numbers(), - config.bmm2_grad_gemm2_dot_dimension_numbers(), - d_s_shape, - fwd_output_shape, - mask_shape, - d_bias_shape, - bias_shape, - force_deterministic}; - - TF_ASSIGN_OR_RETURN(GpufMHABackwardConfig fmha_backward_config, - GpufMHABackwardConfig::For(descriptor)); - - AddThunkToThunkSequence(std::make_unique( - Thunk::ThunkInfo::WithProfileAnnotation(instr), - std::move(fmha_backward_config), bmm1_grad_gemm1_rhs_slice, - bmm1_grad_gemm2_rhs_slice, bmm2_grad_gemm1_lhs_slice, - bmm2_grad_gemm2_rhs_slice, d_output_slice, scratch_slice, - d_bmm1_lhs_slice, d_bmm1_rhs_slice, d_bmm2_rhs_slice, d_s_slice, - mask_slice, d_bias_slice, fwd_output_slice, bias_slice, seqlen_q_slice, - seqlen_k_slice)); - + TF_ASSIGN_OR_RETURN( + auto kernel_arguments, + KernelArguments::Create(ir_emitter_context_->buffer_assignment(), instr, + instr->operands())); + TF_ASSIGN_OR_RETURN(const std::string fingerprint, + FingerprintWithBackendConfig(*instr)); + AddThunkToThunkSequence(std::make_unique( + fingerprint, Thunk::ThunkInfo::WithProfileAnnotation(instr), + kernel_arguments.args())); return absl::OkStatus(); } @@ -2921,11 +2715,8 @@ absl::Status IrEmitterUnnested::EmitHloInstruction( if (IsCustomCallToDnnNorm(*instr)) { return EmitNormThunk(custom_call); } - if (IsFwdCustomCallTofMHA(*instr)) { - return EmitFusedMHAThunk(custom_call); - } - if (IsBwdCustomCallTofMHA(*instr)) { - return EmitFusedMHABackwardThunk(custom_call); + if (IsCustomCallTofMHA(*instr)) { + return EmitCuDnnThunk(custom_call); } #endif // GOOGLE_CUDA if (IsCustomCallToTopK(*instr)) { diff --git a/third_party/xla/xla/service/gpu/ir_emitter_unnested.h b/third_party/xla/xla/service/gpu/ir_emitter_unnested.h index f97f106ddfc0df..d19dd5d9c4172c 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_unnested.h +++ b/third_party/xla/xla/service/gpu/ir_emitter_unnested.h @@ -147,8 +147,7 @@ class IrEmitterUnnested : public IrEmitter { absl::Status EmitConvolutionReorderThunk( const HloCustomCallInstruction* instr); absl::Status EmitNormThunk(const HloCustomCallInstruction* instr); - absl::Status EmitFusedMHAThunk(const HloCustomCallInstruction* instr); - absl::Status EmitFusedMHABackwardThunk(const HloCustomCallInstruction* instr); + absl::Status EmitCuDnnThunk(const HloCustomCallInstruction* instr); #endif // GOOGLE_CUDA #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM absl::Status EmitCubDeviceRadixSort(const HloCustomCallInstruction* instr); diff --git a/third_party/xla/xla/service/gpu/nvptx_compiler.cc 
b/third_party/xla/xla/service/gpu/nvptx_compiler.cc index a2d3988fbfb6b6..7925ebc33baa16 100644 --- a/third_party/xla/xla/service/gpu/nvptx_compiler.cc +++ b/third_party/xla/xla/service/gpu/nvptx_compiler.cc @@ -69,6 +69,7 @@ limitations under the License. #include "xla/service/gpu/transforms/conv_padding_legalization.h" #include "xla/service/gpu/transforms/conv_rewriter.h" #include "xla/service/gpu/transforms/cublas_pad_for_gemms.h" +#include "xla/service/gpu/transforms/cudnn_custom_call_compiler.h" #include "xla/service/gpu/transforms/cudnn_fused_conv_rewriter.h" #include "xla/service/gpu/transforms/cudnn_fused_mha_rewriter.h" #include "xla/service/gpu/transforms/cudnn_fused_mha_transpose_fusion.h" @@ -77,7 +78,6 @@ limitations under the License. #include "xla/service/gpu/transforms/cudnn_pad_for_convolutions.h" #include "xla/service/gpu/transforms/cudnn_simplify_padding.h" #include "xla/service/gpu/transforms/cudnn_vectorize_convolutions.h" -#include "xla/service/gpu/transforms/cudnn_workspace_rewriter.h" #include "xla/service/gpu/transforms/dot_sparsity_rewriter.h" #include "xla/service/gpu/transforms/gpusolver_rewriter.h" #include "xla/service/gpu/transforms/sort_rewriter.h" @@ -342,9 +342,6 @@ absl::Status NVPTXCompiler::OptimizeHloPostLayoutAssignment( // Transform TriangularSolve ops into custom-calls, so we can add temp // memory. post_pipeline.AddPass(); - if (stream_exec) { - post_pipeline.AddPass(*stream_exec); - } TF_RETURN_IF_ERROR(post_pipeline.Run(hlo_module).status()); return absl::OkStatus(); @@ -395,15 +392,17 @@ absl::Status NVPTXCompiler::AddCustomKernelReplacementPasses( return absl::OkStatus(); } -absl::Status NVPTXCompiler::RunCudnnFusionCompilerPass( +absl::Status NVPTXCompiler::RunCudnnCompilerPasses( HloModule* module, se::StreamExecutor* stream_exec, BinaryMap* dnn_compiled_graphs) { tsl::profiler::ScopedAnnotation annotation([&] { return absl::StrFormat("XlaCompileCudnnFusion:#module=%s,program_id=%d#", module->name(), module->unique_id()); }); - CuDnnFusionCompiler cudnn_compiler(*stream_exec, *dnn_compiled_graphs); - return cudnn_compiler.Run(module).status(); + CuDnnFusionCompiler fusion_compiler(*stream_exec, *dnn_compiled_graphs); + TF_RETURN_IF_ERROR(fusion_compiler.Run(module).status()); + CuDnnCustomCallCompiler call_compiler(*stream_exec, *dnn_compiled_graphs); + return call_compiler.Run(module).status(); } namespace { diff --git a/third_party/xla/xla/service/gpu/nvptx_compiler.h b/third_party/xla/xla/service/gpu/nvptx_compiler.h index fb74f553d67967..6d84deb4398176 100644 --- a/third_party/xla/xla/service/gpu/nvptx_compiler.h +++ b/third_party/xla/xla/service/gpu/nvptx_compiler.h @@ -84,9 +84,9 @@ class NVPTXCompiler : public GpuCompiler { absl::Status AddCustomKernelReplacementPasses( HloPassPipeline* pipeline, const DebugOptions& debug_options) override; - absl::Status RunCudnnFusionCompilerPass( - HloModule* module, se::StreamExecutor* stream_exec, - BinaryMap* dnn_compiled_graphs) override; + absl::Status RunCudnnCompilerPasses(HloModule* module, + se::StreamExecutor* stream_exec, + BinaryMap* dnn_compiled_graphs) override; HloDataflowAnalysis::CanShareBuffer GetCanShareBuffer() const override; diff --git a/third_party/xla/xla/service/gpu/runtime/BUILD b/third_party/xla/xla/service/gpu/runtime/BUILD index 861fcca6b45eb4..d6b6f1ffa78319 100644 --- a/third_party/xla/xla/service/gpu/runtime/BUILD +++ b/third_party/xla/xla/service/gpu/runtime/BUILD @@ -79,7 +79,6 @@ cc_library( "//xla/service:executable", "//xla/service:global_device_id", 
"//xla/service/gpu:buffer_allocations", - "//xla/service/gpu:gpu_fused_mha_runner", "//xla/service/gpu:launch_dimensions", "//xla/service/gpu:matmul_utils", "//xla/service/gpu:stream_executor_util", @@ -121,7 +120,6 @@ cc_library( ":copy_thunk", ":cudnn_thunk", ":custom_call_thunk", - ":fused_mha_thunk", ":gemm_thunk", ":gpublas_lt_matmul_thunk", ":kernel_thunk", @@ -660,28 +658,6 @@ cc_library( ], ) -cc_library( - name = "fused_mha_thunk", - srcs = ["fused_mha_thunk.cc"], - hdrs = ["fused_mha_thunk.h"], - deps = [ - ":thunk", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/service:buffer_assignment", - "//xla/service/gpu:buffer_allocations", - "//xla/service/gpu:gpu_fused_mha_runner", - "//xla/stream_executor", - "//xla/stream_executor:lazy_op_runner", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/status", - "@com_google_absl//absl/synchronization", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", - ], -) - cc_library( name = "gemm_thunk", srcs = ["gemm_thunk.cc"], diff --git a/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.cc b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.cc index d913871332933d..ce99623c16de6e 100644 --- a/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.cc +++ b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.cc @@ -1167,314 +1167,6 @@ CommandBufferCmd::BufferUsageVector GemmCmd::buffers() { {workspace_, MemoryAccess::kWrite}}; } -//===----------------------------------------------------------------------===// -// FusedMHACmd -//===----------------------------------------------------------------------===// - -FusedMHACmd::FusedMHACmd( - ExecutionStreamId execution_stream_id, GpufMHAConfig config, - BufferAllocation::Slice lhs_bmm1, BufferAllocation::Slice rhs_bmm1, - BufferAllocation::Slice rhs_bmm2, BufferAllocation::Slice output, - BufferAllocation::Slice scratch, BufferAllocation::Slice mask, - BufferAllocation::Slice bias, BufferAllocation::Slice activation, - BufferAllocation::Slice seqlen_q, BufferAllocation::Slice seqlen_k) - : TracedCommandBufferCmd(CommandBufferCmdType::kFusedMHACmd, - execution_stream_id), - config_(std::move(config)), - lhs_bmm1_buffer_(lhs_bmm1), - rhs_bmm1_buffer_(rhs_bmm1), - rhs_bmm2_buffer_(rhs_bmm2), - output_buffer_(output), - scratch_buffer_(scratch), - bias_buffer_(bias), - activation_buffer_(activation), - seqlen_q_buffer_(seqlen_q), - seqlen_k_buffer_(seqlen_k) {} - -FusedMultiHeadedAttentionRunner& FusedMHACmd::GetOrCreateRunner( - const stream_executor::Stream* stream) { - absl::MutexLock lock(&mutex_); - auto it = runner_cache_.find(stream); - if (it == runner_cache_.end()) { - it = runner_cache_ - .insert({stream, std::make_unique( - config_)}) - .first; - } - return *it->second; -} - -absl::Status FusedMHACmd::Initialize(const Thunk::InitializeParams& params, - StateManager& state) { - se::dnn::LazyOpRunner* lazy_runner = - GetOrCreateRunner(params.command_buffer_trace_stream).AsFusedMHARunner(); - TF_ASSIGN_OR_RETURN(auto config, config_.AsDnnFusedMHAOpConfig()); - return lazy_runner - ->GetOrCreateRunner(config, params.command_buffer_trace_stream) - .status(); -} - -absl::Status FusedMHACmd::Record(const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, - se::CommandBuffer* command_buffer) { - se::dnn::LazyOpRunner* lazy_runner = - GetOrCreateRunner(execute_params.command_buffer_trace_stream) - .AsFusedMHARunner(); - CHECK(lazy_runner) << 
"FusedMHA lazy runner cache should have been populated"; - - const auto& buffer_allocations = *execute_params.buffer_allocations; - se::DeviceMemoryBase lhs_bmm1_buffer = - buffer_allocations.GetDeviceAddress(lhs_bmm1_buffer_); - se::DeviceMemoryBase rhs_bmm1_buffer = - buffer_allocations.GetDeviceAddress(rhs_bmm1_buffer_); - se::DeviceMemoryBase rhs_bmm2_buffer = - buffer_allocations.GetDeviceAddress(rhs_bmm2_buffer_); - se::DeviceMemoryBase output_buffer = - buffer_allocations.GetDeviceAddress(output_buffer_); - se::DeviceMemoryBase scratch_buffer = - buffer_allocations.GetDeviceAddress(scratch_buffer_); - - std::optional bias_buffer = - AssignBufferIfNotNull(buffer_allocations, bias_buffer_); - std::optional activation_buffer = - AssignBufferIfNotNull(buffer_allocations, activation_buffer_); - std::optional seqlen_q_buffer = - AssignBufferIfNotNull(buffer_allocations, seqlen_q_buffer_); - std::optional seqlen_k_buffer = - AssignBufferIfNotNull(buffer_allocations, seqlen_k_buffer_); - - ExecutionScopeId execution_scope_id = GetExecutionScope(record_params); - VLOG(5) << "FusedMHACmd with execution_scope_id: " - << execution_scope_id.value(); - VLOG(5) << " lhs_bmm1_buffer: " << lhs_bmm1_buffer_.ToString(); - VLOG(5) << " rhs_bmm1_buffer: " << rhs_bmm1_buffer_.ToString(); - VLOG(5) << " rhs_bmm2_buffer: " << rhs_bmm2_buffer_.ToString(); - VLOG(5) << " output_buffer: " << output_buffer_.ToString(); - VLOG(5) << " scratch_buffer: " << scratch_buffer_.ToString(); - VLOG(5) << " bias_buffer: " << bias_buffer_.ToString(); - VLOG(5) << " activation_buffer: " << activation_buffer_.ToString(); - VLOG(5) << " seqlen_q_buffer: " << seqlen_q_buffer_.ToString(); - VLOG(5) << " seqlen_k_buffer: " << seqlen_k_buffer_.ToString(); - - RunFusedMHAOptions opts; - opts.runner_cache = - &GetOrCreateRunner(execute_params.command_buffer_trace_stream); - return AddTracedCommandBuffer( - execute_params, record_params, command_buffer, [&](se::Stream* stream) { - return RunGpuFMHA(config_, lhs_bmm1_buffer, rhs_bmm1_buffer, - rhs_bmm2_buffer, output_buffer, scratch_buffer, - bias_buffer, activation_buffer, seqlen_q_buffer, - seqlen_k_buffer, stream, opts); - }); -} - -FusedMHACmd::BufferUsageVector FusedMHACmd::buffers() { - BufferUsageVector buffer_usage; - buffer_usage.reserve(9); - buffer_usage.push_back({lhs_bmm1_buffer_, MemoryAccess::kRead}); - buffer_usage.push_back({rhs_bmm1_buffer_, MemoryAccess::kRead}); - buffer_usage.push_back({rhs_bmm2_buffer_, MemoryAccess::kRead}); - buffer_usage.push_back({output_buffer_, MemoryAccess::kWrite}); - buffer_usage.push_back({scratch_buffer_, MemoryAccess::kWrite}); - if (bias_buffer_.allocation() != nullptr) { - buffer_usage.push_back({bias_buffer_, MemoryAccess::kRead}); - } - if (activation_buffer_.allocation() != nullptr) { - buffer_usage.push_back({activation_buffer_, MemoryAccess::kRead}); - } - if (seqlen_q_buffer_.allocation() != nullptr) { - buffer_usage.push_back({seqlen_q_buffer_, MemoryAccess::kRead}); - } - if (seqlen_k_buffer_.allocation() != nullptr) { - buffer_usage.push_back({seqlen_k_buffer_, MemoryAccess::kRead}); - } - return buffer_usage; -} - -//===----------------------------------------------------------------------===// -// FusedMHABackwardCmd -//===----------------------------------------------------------------------===// - -FusedMHABackwardCmd::FusedMHABackwardCmd( - ExecutionStreamId execution_stream_id, GpufMHABackwardConfig config, - BufferAllocation::Slice bmm1_grad_gemm1_rhs, - BufferAllocation::Slice bmm1_grad_gemm2_rhs, - 
BufferAllocation::Slice bmm2_grad_gemm1_lhs, - BufferAllocation::Slice bmm2_grad_gemm2_rhs, - BufferAllocation::Slice d_output, BufferAllocation::Slice scratch, - BufferAllocation::Slice d_bmm1_lhs, BufferAllocation::Slice d_bmm1_rhs, - BufferAllocation::Slice d_bmm2_rhs, BufferAllocation::Slice d_s, - BufferAllocation::Slice d_bias, BufferAllocation::Slice fwd_output, - BufferAllocation::Slice bias, BufferAllocation::Slice seqlen_q, - BufferAllocation::Slice seqlen_k) - : TracedCommandBufferCmd(CommandBufferCmdType::kFusedMHABackwardCmd, - execution_stream_id), - config_(std::move(config)), - bmm1_grad_gemm1_rhs_buffer_(bmm1_grad_gemm1_rhs), - bmm1_grad_gemm2_rhs_buffer_(bmm1_grad_gemm2_rhs), - bmm2_grad_gemm1_lhs_buffer_(bmm2_grad_gemm1_lhs), - bmm2_grad_gemm2_rhs_buffer_(bmm2_grad_gemm2_rhs), - d_output_buffer_(d_output), - scratch_buffer_(scratch), - d_bmm1_lhs_buffer_(d_bmm1_lhs), - d_bmm1_rhs_buffer_(d_bmm1_rhs), - d_bmm2_rhs_buffer_(d_bmm2_rhs), - d_s_buffer_(d_s), - d_bias_buffer_(d_bias), - fwd_output_buffer_(fwd_output), - bias_buffer_(bias), - seqlen_q_buffer_(seqlen_q), - seqlen_k_buffer_(seqlen_k) {} - -FusedMultiHeadedAttentionBackwardRunner& FusedMHABackwardCmd::GetOrCreateRunner( - const stream_executor::Stream* stream) { - absl::MutexLock lock(&mutex_); - auto it = runner_cache_.find(stream); - if (it == runner_cache_.end()) { - it = runner_cache_ - .insert({stream, - std::make_unique( - config_)}) - .first; - } - return *it->second; -} - -absl::Status FusedMHABackwardCmd::Initialize( - const Thunk::InitializeParams& params, StateManager& state) { - se::dnn::LazyOpRunner* lazy_runner = - GetOrCreateRunner(params.command_buffer_trace_stream) - .AsFusedMHABackwardRunner(); - TF_ASSIGN_OR_RETURN(auto config, config_.AsDnnFusedMHABackwardOpConfig()); - return lazy_runner - ->GetOrCreateRunner(config, params.command_buffer_trace_stream) - .status(); -} - -absl::Status FusedMHABackwardCmd::Record( - const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, se::CommandBuffer* command_buffer) { - se::dnn::LazyOpRunner* lazy_runner = - GetOrCreateRunner(execute_params.command_buffer_trace_stream) - .AsFusedMHABackwardRunner(); - CHECK(lazy_runner) - << "FusedMHABackward lazy runner cache should have been populated"; - - const auto& buffer_allocations = *execute_params.buffer_allocations; - se::DeviceMemoryBase bmm1_grad_gemm1_rhs_buffer = - buffer_allocations.GetDeviceAddress(bmm1_grad_gemm1_rhs_buffer_); - - se::DeviceMemoryBase bmm1_grad_gemm2_rhs_buffer = - buffer_allocations.GetDeviceAddress(bmm1_grad_gemm2_rhs_buffer_); - - se::DeviceMemoryBase bmm2_grad_gemm1_lhs_buffer = - buffer_allocations.GetDeviceAddress(bmm2_grad_gemm1_lhs_buffer_); - - se::DeviceMemoryBase bmm2_grad_gemm2_rhs_buffer = - buffer_allocations.GetDeviceAddress(bmm2_grad_gemm2_rhs_buffer_); - - se::DeviceMemoryBase d_output_buffer = - buffer_allocations.GetDeviceAddress(d_output_buffer_); - - se::DeviceMemoryBase scratch_buffer = - buffer_allocations.GetDeviceAddress(scratch_buffer_); - - se::DeviceMemoryBase d_bmm1_lhs_buffer = - buffer_allocations.GetDeviceAddress(d_bmm1_lhs_buffer_); - - se::DeviceMemoryBase d_bmm1_rhs_buffer = - buffer_allocations.GetDeviceAddress(d_bmm1_rhs_buffer_); - - se::DeviceMemoryBase d_bmm2_rhs_buffer = - buffer_allocations.GetDeviceAddress(d_bmm2_rhs_buffer_); - - std::optional d_s_buffer = - AssignBufferIfNotNull(buffer_allocations, d_s_buffer_); - std::optional d_bias_buffer = - AssignBufferIfNotNull(buffer_allocations, d_bias_buffer_); - std::optional 
fwd_output_buffer = - AssignBufferIfNotNull(buffer_allocations, fwd_output_buffer_); - std::optional bias_buffer = - AssignBufferIfNotNull(buffer_allocations, bias_buffer_); - std::optional seqlen_q_buffer = - AssignBufferIfNotNull(buffer_allocations, seqlen_q_buffer_); - std::optional seqlen_k_buffer = - AssignBufferIfNotNull(buffer_allocations, seqlen_k_buffer_); - - ExecutionScopeId execution_scope_id = GetExecutionScope(record_params); - VLOG(5) << "FusedMHABackwardCmd with execution_scope_id: " - << execution_scope_id.value(); - VLOG(5) << "bmm1_grad_gemm1_rhs_buffer" - << bmm1_grad_gemm1_rhs_buffer_.ToString(); - VLOG(5) << "bmm1_grad_gemm2_rhs_buffer" - << bmm1_grad_gemm2_rhs_buffer_.ToString(); - VLOG(5) << "bmm2_grad_gemm1_lhs_buffer" - << bmm2_grad_gemm1_lhs_buffer_.ToString(); - VLOG(5) << "bmm2_grad_gemm2_rhs_buffer" - << bmm2_grad_gemm2_rhs_buffer_.ToString(); - VLOG(5) << "d_output_buffer" << d_output_buffer_.ToString(); - VLOG(5) << "scratch_buffer" << scratch_buffer_.ToString(); - VLOG(5) << "d_bmm1_lhs_buffer" << d_bmm1_lhs_buffer_.ToString(); - VLOG(5) << "d_bmm1_rhs_buffer" << d_bmm1_rhs_buffer_.ToString(); - VLOG(5) << "d_bmm2_rhs_buffer" << d_bmm2_rhs_buffer_.ToString(); - VLOG(5) << "d_s_buffer" << d_s_buffer_.ToString(); - VLOG(5) << "d_bias_buffer" << d_bias_buffer_.ToString(); - VLOG(5) << "fwd_output_buffer" << fwd_output_buffer_.ToString(); - VLOG(5) << "bias_buffer" << bias_buffer_.ToString(); - VLOG(5) << "seqlen_q_buffer" << seqlen_q_buffer_.ToString(); - VLOG(5) << "seqlen_k_buffer" << seqlen_k_buffer_.ToString(); - - RunFusedMHABackwardOptions opts; - opts.runner_cache = - &GetOrCreateRunner(execute_params.command_buffer_trace_stream); - return AddTracedCommandBuffer( - execute_params, record_params, command_buffer, [&](se::Stream* stream) { - return RunGpuFMHABackward( - config_, bmm1_grad_gemm1_rhs_buffer, bmm1_grad_gemm2_rhs_buffer, - bmm2_grad_gemm1_lhs_buffer, bmm2_grad_gemm2_rhs_buffer, - d_output_buffer, scratch_buffer, d_bmm1_lhs_buffer, - d_bmm1_rhs_buffer, d_bmm2_rhs_buffer, d_s_buffer, d_bias_buffer, - fwd_output_buffer, bias_buffer, seqlen_q_buffer, seqlen_k_buffer, - stream, opts); - }); -} - -FusedMHABackwardCmd::BufferUsageVector FusedMHABackwardCmd::buffers() { - BufferUsageVector buffer_usage; - buffer_usage.reserve(15); - - buffer_usage.push_back({bmm1_grad_gemm1_rhs_buffer_, MemoryAccess::kRead}); - buffer_usage.push_back({bmm1_grad_gemm2_rhs_buffer_, MemoryAccess::kRead}); - buffer_usage.push_back({bmm2_grad_gemm1_lhs_buffer_, MemoryAccess::kRead}); - buffer_usage.push_back({bmm2_grad_gemm2_rhs_buffer_, MemoryAccess::kRead}); - buffer_usage.push_back({d_output_buffer_, MemoryAccess::kWrite}); - buffer_usage.push_back({scratch_buffer_, MemoryAccess::kWrite}); - buffer_usage.push_back({d_bmm1_lhs_buffer_, MemoryAccess::kRead}); - buffer_usage.push_back({d_bmm1_rhs_buffer_, MemoryAccess::kRead}); - buffer_usage.push_back({d_bmm2_rhs_buffer_, MemoryAccess::kRead}); - - if (d_s_buffer_.allocation() != nullptr) { - buffer_usage.push_back({d_s_buffer_, MemoryAccess::kRead}); - }; - if (d_bias_buffer_.allocation() != nullptr) { - buffer_usage.push_back({d_bias_buffer_, MemoryAccess::kRead}); - }; - if (fwd_output_buffer_.allocation() != nullptr) { - buffer_usage.push_back({fwd_output_buffer_, MemoryAccess::kRead}); - }; - if (bias_buffer_.allocation() != nullptr) { - buffer_usage.push_back({bias_buffer_, MemoryAccess::kRead}); - }; - if (seqlen_q_buffer_.allocation() != nullptr) { - buffer_usage.push_back({seqlen_q_buffer_, MemoryAccess::kRead}); 
- }; - if (seqlen_k_buffer_.allocation() != nullptr) { - buffer_usage.push_back({seqlen_k_buffer_, MemoryAccess::kRead}); - }; - - return buffer_usage; -} - //===----------------------------------------------------------------------===// // CublasLtCmd //===----------------------------------------------------------------------===// diff --git a/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.h b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.h index b7a077e81a9e4f..27e8fea0d86366 100644 --- a/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.h +++ b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.h @@ -40,7 +40,6 @@ limitations under the License. #include "xla/service/buffer_assignment.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/gpu/buffer_allocations.h" -#include "xla/service/gpu/gpu_fused_mha_runner.h" #include "xla/service/gpu/kernels/custom_kernel.h" #include "xla/service/gpu/launch_dimensions.h" #include "xla/service/gpu/matmul_utils.h" @@ -81,8 +80,6 @@ namespace xla::gpu { V(kReduceScatter, "ReduceScatterCmd") \ V(kAllGatherCmd, "AllGatherCmd") \ V(kCollectiveBroadcastCmd, "CollectiveBroadcastCmd") \ - V(kFusedMHACmd, "FusedMHACmd") \ - V(kFusedMHABackwardCmd, "FusedMHABackwardCmd") \ V(kUnknownCmd, "UnknownCmd") \ // clang-format on @@ -782,112 +779,6 @@ class GemmCmd : public TracedCommandBufferCmd { const bool deterministic_; }; -//===----------------------------------------------------------------------===// -// FusedMHACmd -//===----------------------------------------------------------------------===// - -class FusedMHACmd : public TracedCommandBufferCmd { - public: - FusedMHACmd(ExecutionStreamId execution_stream_id, GpufMHAConfig config, - BufferAllocation::Slice lhs_bmm1, - BufferAllocation::Slice rhs_bmm1, - BufferAllocation::Slice rhs_bmm2, BufferAllocation::Slice output, - BufferAllocation::Slice scratch, BufferAllocation::Slice mask, - BufferAllocation::Slice bias, BufferAllocation::Slice activation, - BufferAllocation::Slice seqlen_q, - BufferAllocation::Slice seqlen_k); - - absl::Status Initialize(const Thunk::InitializeParams& params, - StateManager& state) override; - - absl::Status Record(const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, - se::CommandBuffer* command_buffer) override; - - BufferUsageVector buffers() override; - - bool IsNestedCommandBuffer() const final { return true; } - - private: - FusedMultiHeadedAttentionRunner& GetOrCreateRunner( - const stream_executor::Stream* stream); - - const GpufMHAConfig config_; - BufferAllocation::Slice lhs_bmm1_buffer_; - BufferAllocation::Slice rhs_bmm1_buffer_; - BufferAllocation::Slice rhs_bmm2_buffer_; - BufferAllocation::Slice output_buffer_; - BufferAllocation::Slice scratch_buffer_; - BufferAllocation::Slice bias_buffer_; - BufferAllocation::Slice activation_buffer_; - BufferAllocation::Slice seqlen_q_buffer_; - BufferAllocation::Slice seqlen_k_buffer_; - - // FusedMHA config - absl::Mutex mutex_; - absl::flat_hash_map> - runner_cache_ ABSL_GUARDED_BY(mutex_); -}; - -//===----------------------------------------------------------------------===// -// FusedMHABackwardCmd -//===----------------------------------------------------------------------===// - -class FusedMHABackwardCmd : public TracedCommandBufferCmd { - public: - FusedMHABackwardCmd( - ExecutionStreamId execution_stream_id, GpufMHABackwardConfig config, - BufferAllocation::Slice bmm1_grad_gemm1_rhs, - BufferAllocation::Slice bmm1_grad_gemm2_rhs, - 
BufferAllocation::Slice bmm2_grad_gemm1_lhs, - BufferAllocation::Slice bmm2_grad_gemm2_rhs, - BufferAllocation::Slice d_output, BufferAllocation::Slice scratch, - BufferAllocation::Slice d_bmm1_lhs, BufferAllocation::Slice d_bmm1_rhs, - BufferAllocation::Slice d_bmm2_rhs, BufferAllocation::Slice d_s, - BufferAllocation::Slice d_bias, BufferAllocation::Slice fwd_output, - BufferAllocation::Slice bias, BufferAllocation::Slice seqlen_q, - BufferAllocation::Slice seqlen_k); - - absl::Status Initialize(const Thunk::InitializeParams& params, - StateManager& state) override; - - absl::Status Record(const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, - se::CommandBuffer* command_buffer) override; - - BufferUsageVector buffers() override; - - bool IsNestedCommandBuffer() const final { return true; } - - private: - FusedMultiHeadedAttentionBackwardRunner& GetOrCreateRunner( - const stream_executor::Stream* stream); - - const GpufMHABackwardConfig config_; - BufferAllocation::Slice bmm1_grad_gemm1_rhs_buffer_; - BufferAllocation::Slice bmm1_grad_gemm2_rhs_buffer_; - BufferAllocation::Slice bmm2_grad_gemm1_lhs_buffer_; - BufferAllocation::Slice bmm2_grad_gemm2_rhs_buffer_; - BufferAllocation::Slice d_output_buffer_; - BufferAllocation::Slice scratch_buffer_; - BufferAllocation::Slice d_bmm1_lhs_buffer_; - BufferAllocation::Slice d_bmm1_rhs_buffer_; - BufferAllocation::Slice d_bmm2_rhs_buffer_; - BufferAllocation::Slice d_s_buffer_; - BufferAllocation::Slice d_bias_buffer_; - BufferAllocation::Slice fwd_output_buffer_; - BufferAllocation::Slice bias_buffer_; - BufferAllocation::Slice seqlen_q_buffer_; - BufferAllocation::Slice seqlen_k_buffer_; - - // FusedMHA config - absl::Mutex mutex_; - absl::flat_hash_map> - runner_cache_ ABSL_GUARDED_BY(mutex_); -}; - //===----------------------------------------------------------------------===// // CublasLtCmd //===----------------------------------------------------------------------===// diff --git a/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd_emitter.cc b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd_emitter.cc index 54e01fab8e1109..230d050856fcc2 100644 --- a/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd_emitter.cc +++ b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd_emitter.cc @@ -29,7 +29,6 @@ limitations under the License. 
#include "xla/service/gpu/runtime/copy_thunk.h" #include "xla/service/gpu/runtime/cudnn_thunk.h" #include "xla/service/gpu/runtime/custom_call_thunk.h" -#include "xla/service/gpu/runtime/fused_mha_thunk.h" #include "xla/service/gpu/runtime/gemm_thunk.h" #include "xla/service/gpu/runtime/gpublas_lt_matmul_thunk.h" #include "xla/service/gpu/runtime/kernel_thunk.h" @@ -143,27 +142,6 @@ static absl::StatusOr Convert(const CublasLtMatmulThunk& thunk) { thunk.workspace().value()); } -static absl::StatusOr Convert(const FusedMHAThunk& thunk) { - return std::make_unique( - thunk.execution_stream_id(), thunk.config(), thunk.lhs_bmm1_buffer(), - thunk.rhs_bmm1_buffer(), thunk.rhs_bmm2_buffer(), thunk.output_buffer(), - thunk.scratch_buffer(), BufferAllocation::Slice(), thunk.bias_buffer(), - thunk.activation_buffer(), thunk.seqlen_q_buffer(), - thunk.seqlen_k_buffer()); -} - -static absl::StatusOr Convert(const FusedMHABackwardThunk& thunk) { - return std::make_unique( - thunk.execution_stream_id(), thunk.config(), - thunk.bmm1_grad_gemm1_rhs_buffer(), thunk.bmm1_grad_gemm2_rhs_buffer(), - thunk.bmm2_grad_gemm1_lhs_buffer(), thunk.bmm2_grad_gemm2_rhs_buffer(), - thunk.d_output_buffer(), thunk.scratch_buffer(), - thunk.d_bmm1_lhs_buffer(), thunk.d_bmm1_rhs_buffer(), - thunk.d_bmm2_rhs_buffer(), thunk.d_s_buffer(), thunk.d_bias_buffer(), - thunk.fwd_output_buffer(), thunk.bias_buffer(), thunk.seqlen_q_buffer(), - thunk.seqlen_k_buffer()); -} - static absl::StatusOr Convert( const ConditionalThunk& thunk, CommandBufferCmdSequence::SynchronizationMode synchronization_mode) { @@ -276,10 +254,6 @@ static absl::Status AppendCommands( return append(Convert(thunk)); case Thunk::Kind::kCustomKernel: return append(Convert(thunk)); - case Thunk::Kind::kFusedMHA: - return append(Convert(thunk)); - case Thunk::Kind::kFusedMHABackward: - return append(Convert(thunk)); case Thunk::Kind::kKernel: return append(Convert(thunk)); case Thunk::Kind::kGemm: diff --git a/third_party/xla/xla/service/gpu/transforms/BUILD b/third_party/xla/xla/service/gpu/transforms/BUILD index 85c4df22b8b215..b52668fbee028c 100644 --- a/third_party/xla/xla/service/gpu/transforms/BUILD +++ b/third_party/xla/xla/service/gpu/transforms/BUILD @@ -1124,9 +1124,9 @@ xla_cc_test( # TODO(b/358278858): Currently lacking test coverage. cc_library( - name = "cudnn_workspace_rewriter", - srcs = if_cuda_is_configured(["cudnn_workspace_rewriter.cc"]), - hdrs = if_cuda_is_configured(["cudnn_workspace_rewriter.h"]), + name = "cudnn_custom_call_compiler", + srcs = if_cuda_is_configured(["cudnn_custom_call_compiler.cc"]), + hdrs = if_cuda_is_configured(["cudnn_custom_call_compiler.h"]), deps = if_cuda_is_configured([ "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", @@ -1141,9 +1141,9 @@ cc_library( "//xla:util", "//xla/hlo/ir:hlo", "//xla/service:hlo_pass", + "//xla/service/gpu/runtime:cudnn_thunk", "//xla/service/gpu:backend_configs_cc", "//xla/service/gpu:cublas_cudnn", - "//xla/service/gpu:gpu_fused_mha_runner", "//xla/service/gpu:ir_emission_utils", "//xla/service/gpu:stream_executor_util", "//xla/stream_executor:dnn", diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_custom_call_compiler.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_custom_call_compiler.cc new file mode 100644 index 00000000000000..00b73c9112b40b --- /dev/null +++ b/third_party/xla/xla/service/gpu/transforms/cudnn_custom_call_compiler.cc @@ -0,0 +1,660 @@ +/* Copyright 2024 The TensorFlow Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/transforms/cudnn_custom_call_compiler.h" + +#include +#include + +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_clone_context.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/gpu/cublas_cudnn.h" +#include "xla/service/gpu/ir_emission_utils.h" +#include "xla/service/gpu/stream_executor_util.h" +#include "xla/stream_executor/cuda/cuda_dnn.h" +#include "xla/stream_executor/cuda/cudnn_frontend_helpers.h" +#include "xla/util.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace gpu { + +namespace { + +inline absl::StatusOr AsCudnnFmhaMaskKind( + CudnnfMHABackendConfig_MaskType mask_type) { + switch (mask_type) { + case CudnnfMHABackendConfig::NO_MASK: + return CudnnfMHAMaskKind::kNoMask; + case CudnnfMHABackendConfig::PADDING: + return CudnnfMHAMaskKind::kPadding; + case CudnnfMHABackendConfig::CAUSAL: + return CudnnfMHAMaskKind::kCausal; + case CudnnfMHABackendConfig::PADDING_CAUSAL: + return CudnnfMHAMaskKind::kPaddingCausal; + case CudnnfMHABackendConfig::ALIBI: + return CudnnfMHAMaskKind::kAlibi; + default: + return xla::Internal("Unknown fmha mask kind."); + } +} + +// This is an interim structure to hold the parameters to construct a +// GpufMHAConfig. +// Struct to describe properties of a FMHA without being tied to specific +// IR. Will be used to help build FMHA thunks from either XLA HLO or +// LHLO GPU dialect in MLIR. 
+struct GpufMHADescriptor { + CudnnfMHAKind kind; + CudnnfMHABackendConfig backend_config; + CudnnfMHAMaskKind mask_type; + Shape lhs_bmm1_shape; + Shape rhs_bmm1_shape; + Shape rhs_bmm2_shape; + Shape intermediate_lhs_bmm2_shape; + // This will contain both output shape and activation shape + absl::InlinedVector output_shapes; + DotDimensionNumbers bmm1_dnums; + DotDimensionNumbers bmm2_dnums; + + std::optional mask_shape; + std::optional bias_shape; +}; + +struct GpufMHABackwardDescriptor { + CudnnfMHAKind kind; + CudnnfMHABackendConfig backend_config; + CudnnfMHAMaskKind mask_type; + Shape bmm1_grad_gemm1_rhs_shape; + Shape bmm1_grad_gemm2_rhs_shape; + Shape bmm2_grad_gemm1_lhs_shape; + Shape bmm2_grad_gemm2_rhs_shape; + Shape d_output_shape; + Shape d_bmm1_lhs_shape; + Shape d_bmm1_rhs_shape; + Shape d_bmm2_rhs_shape; + DotDimensionNumbers bmm1_grad_gemm1_dnums; + DotDimensionNumbers bmm1_grad_gemm2_dnums; + DotDimensionNumbers bmm2_grad_gemm1_dnums; + DotDimensionNumbers bmm2_grad_gemm2_dnums; + + std::optional d_s_shape; + std::optional fwd_output_shape; + std::optional mask_shape; + std::optional d_bias_shape; + std::optional bias_shape; + bool force_deterministic; +}; + +// Structure to describe static properties of a GPU fused Multi-Headed +// Attention. +struct GpufMHAConfig { + static absl::StatusOr For(const GpufMHADescriptor &fmha_desc); + PrimitiveType + input_type; // Capture the primitive type of one of the inputs of BMM1 + PrimitiveType output_type; + CudnnfMHAKind kind; + std::optional fmha_scale; + std::optional dropout_rate; + std::optional seed; + + se::dnn::AlgorithmDesc algorithm; + CudnnfMHAMaskKind mask_type; + // bias -> [1, num_attn_heads, q_seq_len, kv_seq_len] + // mask -> [batch_size, 1, q_seq_len, kv_seq_len] + se::dnn::MatmulTensorDescriptor lhs_bmm1; + se::dnn::MatmulTensorDescriptor rhs_bmm1; + se::dnn::MatmulTensorDescriptor rhs_bmm2; + se::dnn::MatmulTensorDescriptor intermediate_lhs_bmm2; + se::dnn::TensorDescriptor output; + + std::optional activation; + std::optional mask; + std::optional bias; +}; + +// Structure to describe static properties of a GPU fused Multi-Headed +// Attention backward. +struct GpufMHABackwardConfig { + static absl::StatusOr For( + const GpufMHABackwardDescriptor &fmha_desc); + PrimitiveType + input_type; // Capture the primitive type of one of the inputs of BMM1 + PrimitiveType output_type; + CudnnfMHAKind kind; + std::optional fmha_scale; + std::optional dropout_rate; + std::optional seed; + + se::dnn::AlgorithmDesc algorithm; + CudnnfMHAMaskKind mask_type; + // mask -> [batch_size, 1, q_seq_len, kv_seq_len] + // d_bias -> [1, num_heads, q_seq_len, kv_seq_len] + se::dnn::MatmulTensorDescriptor bmm1_grad_gemm1_rhs; + se::dnn::MatmulTensorDescriptor bmm1_grad_gemm2_rhs; + se::dnn::MatmulTensorDescriptor bmm2_grad_gemm1_lhs; + se::dnn::MatmulTensorDescriptor bmm2_grad_gemm2_rhs; + se::dnn::MatmulTensorDescriptor d_output; + se::dnn::TensorDescriptor d_bmm1_lhs; + se::dnn::TensorDescriptor d_bmm1_rhs; + se::dnn::TensorDescriptor d_bmm2_rhs; + std::optional d_s; + std::optional mask; + std::optional d_bias; + std::optional fwd_output; + std::optional bias; +}; + +using se::DeviceMemory; +using se::DeviceMemoryBase; +using se::dnn::DataType; +using se::dnn::MatmulTensorDescriptor; +using se::dnn::TensorDescriptor; + +/*static*/ absl::StatusOr GpufMHAConfig::For( + const GpufMHADescriptor &desc) { + // Get shapes from desc. 
+ const Shape &lhs_bmm1_shape = desc.lhs_bmm1_shape; + const Shape &rhs_bmm1_shape = desc.rhs_bmm1_shape; + const Shape &rhs_bmm2_shape = desc.rhs_bmm2_shape; + const Shape &intermediate_lhs_bmm2_shape = desc.intermediate_lhs_bmm2_shape; + const Shape &output_shape = desc.output_shapes[0]; + + // Get DNN dtype from primitive types + TF_ASSIGN_OR_RETURN( + DataType lhs_bmm1_type, + GetDNNDataTypeFromPrimitiveType(lhs_bmm1_shape.element_type())); + TF_ASSIGN_OR_RETURN( + DataType rhs_bmm1_type, + GetDNNDataTypeFromPrimitiveType(rhs_bmm1_shape.element_type())); + + TF_ASSIGN_OR_RETURN( + DataType rhs_bmm2_type, + GetDNNDataTypeFromPrimitiveType(rhs_bmm2_shape.element_type())); + TF_ASSIGN_OR_RETURN(DataType lhs_bmm2_type, + GetDNNDataTypeFromPrimitiveType( + intermediate_lhs_bmm2_shape.element_type())); + TF_ASSIGN_OR_RETURN(DataType output_type, GetDNNDataTypeFromPrimitiveType( + output_shape.element_type())); + GpufMHAConfig config; + config.input_type = lhs_bmm1_shape.element_type(); + config.output_type = output_shape.element_type(); + + // Get MatmulTensorDescriptors for BMM1 + config.lhs_bmm1 = + MatmulTensorDescriptor::For(lhs_bmm1_type, lhs_bmm1_shape.dimensions(), + desc.lhs_bmm1_shape.layout().minor_to_major(), + desc.bmm1_dnums.lhs_batch_dimensions(), + desc.bmm1_dnums.lhs_contracting_dimensions()); + config.rhs_bmm1 = + MatmulTensorDescriptor::For(rhs_bmm1_type, rhs_bmm1_shape.dimensions(), + desc.rhs_bmm1_shape.layout().minor_to_major(), + desc.bmm1_dnums.rhs_batch_dimensions(), + desc.bmm1_dnums.rhs_contracting_dimensions()); + + // Get MatmulTensorDescriptors for BMM2 + config.rhs_bmm2 = + MatmulTensorDescriptor::For(rhs_bmm2_type, rhs_bmm2_shape.dimensions(), + desc.rhs_bmm2_shape.layout().minor_to_major(), + desc.bmm2_dnums.rhs_batch_dimensions(), + desc.bmm2_dnums.rhs_contracting_dimensions()); + + config.intermediate_lhs_bmm2 = MatmulTensorDescriptor::For( + lhs_bmm2_type, intermediate_lhs_bmm2_shape.dimensions(), + desc.intermediate_lhs_bmm2_shape.layout().minor_to_major(), + desc.bmm2_dnums.lhs_batch_dimensions(), + desc.bmm2_dnums.lhs_contracting_dimensions()); + + config.output = TensorDescriptor::For(output_type, output_shape.dimensions(), + output_shape.layout().minor_to_major()); + + if (desc.output_shapes.size() > 1) { + const Shape &activation_shape = desc.output_shapes.back(); + // Generally, activation should have same type as output, but set it + // explicitly just to be safe.
+ TF_ASSIGN_OR_RETURN( + DataType activation_type, + GetDNNDataTypeFromPrimitiveType(activation_shape.element_type())); + config.activation = + TensorDescriptor::For(activation_type, activation_shape.dimensions(), + activation_shape.layout().minor_to_major()); + } + + if (desc.mask_shape) { + const Shape &mask_shape = *desc.mask_shape; + TF_ASSIGN_OR_RETURN(DataType mask_type, GetDNNDataTypeFromPrimitiveType( + mask_shape.element_type())); + config.mask = TensorDescriptor::For(mask_type, mask_shape.dimensions(), + mask_shape.layout().minor_to_major()); + } + + if (desc.bias_shape) { + const Shape &bias_shape = *desc.bias_shape; + TF_ASSIGN_OR_RETURN(DataType bias_type, GetDNNDataTypeFromPrimitiveType( + bias_shape.element_type())); + config.bias = TensorDescriptor::For(bias_type, bias_shape.dimensions(), + bias_shape.layout().minor_to_major()); + } + config.kind = desc.kind; + config.mask_type = desc.mask_type; + const CudnnfMHABackendConfig &backend_config = desc.backend_config; + config.algorithm = se::dnn::AlgorithmDesc(backend_config.algorithm()); + config.fmha_scale.emplace(backend_config.fmha_scale()); + config.dropout_rate.emplace(backend_config.dropout_rate()); + config.seed.emplace(backend_config.seed()); + return config; +} + +/*static*/ absl::StatusOr GpufMHABackwardConfig::For( + const GpufMHABackwardDescriptor &desc) { + // Get shapes from desc. + const Shape &bmm1_grad_gemm1_rhs_shape = desc.bmm1_grad_gemm1_rhs_shape; + const Shape &bmm1_grad_gemm2_rhs_shape = desc.bmm1_grad_gemm2_rhs_shape; + const Shape &bmm2_grad_gemm1_lhs_shape = desc.bmm2_grad_gemm1_lhs_shape; + const Shape &bmm2_grad_gemm2_rhs_shape = desc.bmm2_grad_gemm2_rhs_shape; + const Shape &d_output_shape = desc.d_output_shape; + const Shape &d_bmm1_lhs_shape = desc.d_bmm1_lhs_shape; + const Shape &d_bmm1_rhs_shape = desc.d_bmm1_rhs_shape; + const Shape &d_bmm2_rhs_shape = desc.d_bmm2_rhs_shape; + // Get DNN dtype from primitive types + TF_ASSIGN_OR_RETURN(DataType bmm1_grad_gemm1_rhs_type, + GetDNNDataTypeFromPrimitiveType( + bmm1_grad_gemm1_rhs_shape.element_type())); + + TF_ASSIGN_OR_RETURN(DataType bmm1_grad_gemm2_rhs_type, + GetDNNDataTypeFromPrimitiveType( + bmm1_grad_gemm2_rhs_shape.element_type())); + + TF_ASSIGN_OR_RETURN(DataType bmm2_grad_gemm1_lhs_type, + GetDNNDataTypeFromPrimitiveType( + bmm2_grad_gemm1_lhs_shape.element_type())); + + TF_ASSIGN_OR_RETURN(DataType bmm2_grad_gemm2_rhs_type, + GetDNNDataTypeFromPrimitiveType( + bmm2_grad_gemm2_rhs_shape.element_type())); + + TF_ASSIGN_OR_RETURN( + DataType d_output_type, + GetDNNDataTypeFromPrimitiveType(d_output_shape.element_type())); + + TF_ASSIGN_OR_RETURN( + DataType d_bmm1_lhs_type, + GetDNNDataTypeFromPrimitiveType(d_bmm1_lhs_shape.element_type())); + + TF_ASSIGN_OR_RETURN( + DataType d_bmm1_rhs_type, + GetDNNDataTypeFromPrimitiveType(d_bmm1_rhs_shape.element_type())); + + TF_ASSIGN_OR_RETURN( + DataType d_bmm2_rhs_type, + GetDNNDataTypeFromPrimitiveType(d_bmm2_rhs_shape.element_type())); + + GpufMHABackwardConfig config; + config.input_type = bmm1_grad_gemm1_rhs_shape.element_type(); + config.output_type = d_bmm1_lhs_shape.element_type(); + + // Get MatmulTensorDescriptors for lhs of BMM1 grad GEMM 1 + config.bmm1_grad_gemm1_rhs = MatmulTensorDescriptor::For( + bmm1_grad_gemm1_rhs_type, bmm1_grad_gemm1_rhs_shape.dimensions(), + desc.bmm1_grad_gemm1_rhs_shape.layout().minor_to_major(), + desc.bmm1_grad_gemm1_dnums.rhs_batch_dimensions(), + desc.bmm1_grad_gemm1_dnums.rhs_contracting_dimensions()); + + // Get MatmulTensorDescriptors for rhs of BMM1 grad GEMM 2
+ config.bmm1_grad_gemm2_rhs = MatmulTensorDescriptor::For( + bmm1_grad_gemm2_rhs_type, bmm1_grad_gemm2_rhs_shape.dimensions(), + desc.bmm1_grad_gemm2_rhs_shape.layout().minor_to_major(), + desc.bmm1_grad_gemm2_dnums.rhs_batch_dimensions(), + desc.bmm1_grad_gemm2_dnums.rhs_contracting_dimensions()); + + // Get MatmulTensorDescriptors for BMM2 grad GEMM 1 + config.bmm2_grad_gemm1_lhs = MatmulTensorDescriptor::For( + bmm2_grad_gemm1_lhs_type, bmm2_grad_gemm1_lhs_shape.dimensions(), + desc.bmm2_grad_gemm1_lhs_shape.layout().minor_to_major(), + desc.bmm2_grad_gemm1_dnums.lhs_batch_dimensions(), + desc.bmm2_grad_gemm1_dnums.lhs_contracting_dimensions()); + + config.d_output = MatmulTensorDescriptor::For( + d_output_type, d_output_shape.dimensions(), + desc.d_output_shape.layout().minor_to_major(), + desc.bmm2_grad_gemm1_dnums.rhs_batch_dimensions(), + desc.bmm2_grad_gemm1_dnums.rhs_contracting_dimensions()); + + // Get MatmulTensorDescriptors for BMM2 grad GEMM 2 + config.bmm2_grad_gemm2_rhs = MatmulTensorDescriptor::For( + bmm2_grad_gemm2_rhs_type, bmm2_grad_gemm2_rhs_shape.dimensions(), + desc.bmm2_grad_gemm2_rhs_shape.layout().minor_to_major(), + desc.bmm2_grad_gemm2_dnums.rhs_batch_dimensions(), + desc.bmm2_grad_gemm2_dnums + .rhs_contracting_dimensions()); // FMHA TODO: transpose here? + + config.d_bmm1_lhs = + TensorDescriptor::For(d_bmm1_lhs_type, d_bmm1_lhs_shape.dimensions(), + d_bmm1_lhs_shape.layout().minor_to_major()); + config.d_bmm1_rhs = + TensorDescriptor::For(d_bmm1_rhs_type, d_bmm1_rhs_shape.dimensions(), + d_bmm1_rhs_shape.layout().minor_to_major()); + config.d_bmm2_rhs = + TensorDescriptor::For(d_bmm2_rhs_type, d_bmm2_rhs_shape.dimensions(), + d_bmm2_rhs_shape.layout().minor_to_major()); + config.d_s = TensorDescriptor::For( + bmm2_grad_gemm1_lhs_type, bmm2_grad_gemm1_lhs_shape.dimensions(), + bmm2_grad_gemm1_lhs_shape.layout().minor_to_major()); + + if (desc.d_bias_shape) { + const Shape &d_bias_shape = *desc.d_bias_shape; + // Get DNN dtype from primitive types + TF_ASSIGN_OR_RETURN(DataType d_bias_type, GetDNNDataTypeFromPrimitiveType( + d_bias_shape.element_type())); + config.d_bias = + TensorDescriptor::For(d_bias_type, d_bias_shape.dimensions(), + d_bias_shape.layout().minor_to_major()); + } + + if (desc.mask_shape) { + const Shape &mask_shape = *desc.mask_shape; + TF_ASSIGN_OR_RETURN(DataType mask_type, GetDNNDataTypeFromPrimitiveType( + mask_shape.element_type())); + config.mask = TensorDescriptor::For(mask_type, mask_shape.dimensions(), + mask_shape.layout().minor_to_major()); + } + if (desc.fwd_output_shape) { + const Shape &fwd_output_shape = *desc.fwd_output_shape; + TF_ASSIGN_OR_RETURN( + DataType fwd_output_type, + GetDNNDataTypeFromPrimitiveType(fwd_output_shape.element_type())); + config.fwd_output = + TensorDescriptor::For(fwd_output_type, fwd_output_shape.dimensions(), + fwd_output_shape.layout().minor_to_major()); + } + + if (desc.bias_shape) { + const Shape &bias_shape = *desc.bias_shape; + TF_ASSIGN_OR_RETURN(DataType bias_type, GetDNNDataTypeFromPrimitiveType( + bias_shape.element_type())); + config.bias = TensorDescriptor::For(bias_type, bias_shape.dimensions(), + bias_shape.layout().minor_to_major()); + } + + config.kind = desc.kind; + config.mask_type = desc.mask_type; + const CudnnfMHABackendConfig &backend_config = desc.backend_config; + config.algorithm = se::dnn::AlgorithmDesc(backend_config.algorithm()); + config.fmha_scale.emplace(backend_config.fmha_scale()); + config.dropout_rate.emplace(backend_config.dropout_rate()); +
config.seed.emplace(backend_config.seed()); + return config; +} + +absl::StatusOr HloCustomCallToCuDnnGraph( + se::dnn::DnnSupport &dnn_support, HloCustomCallInstruction *custom_call) { + if (IsFwdCustomCallTofMHA(*custom_call)) { + TF_ASSIGN_OR_RETURN(const xla::gpu::CudnnfMHAKind kind, + xla::gpu::GetCudnnfMHAKind(custom_call)); + std::optional mask_shape, bias_shape; + { + bool has_bias = kind == CudnnfMHAKind::kScaleBiasSoftmax || + kind == CudnnfMHAKind::kScaleBiasSoftmaxDropout; + + if (has_bias) { + const HloInstruction *bias = custom_call->operand(3); + bias_shape = bias->shape(); + } + } + + TF_ASSIGN_OR_RETURN( + const auto gpu_config, + custom_call->backend_config()); + const xla::gpu::CudnnfMHABackendConfig &config = + gpu_config.cudnn_fmha_backend_config(); + Shape intermediate_tensor_shape(config.intermediate_tensor_shape()); + absl::InlinedVector output_shapes = { + ShapeUtil::GetSubshape(custom_call->shape(), {0})}; + + bool has_activation = + xla::ShapeUtil::TupleElementCount(custom_call->shape()) == 3; + if (has_activation) { + output_shapes.push_back( + ShapeUtil::GetSubshape(custom_call->shape(), {1})); + } + + Shape q_shape = custom_call->operand(0)->shape(); + Shape k_shape = custom_call->operand(1)->shape(); + Shape v_shape = custom_call->operand(2)->shape(); + TF_ASSIGN_OR_RETURN(CudnnfMHAMaskKind cudnn_mask_type, + AsCudnnFmhaMaskKind(config.mask_type())); + GpufMHADescriptor descriptor = {kind, + config, + cudnn_mask_type, + q_shape, + k_shape, + v_shape, + intermediate_tensor_shape, + output_shapes, + config.bmm1_dot_dimension_numbers(), + config.bmm2_dot_dimension_numbers(), + mask_shape, + bias_shape}; + + TF_ASSIGN_OR_RETURN(GpufMHAConfig fmha_config, + GpufMHAConfig::For(descriptor)); + TF_ASSIGN_OR_RETURN( + se::dnn::FMHAMaskKind dnn_mask_type, + GetDNNFmhaMaskKindFromCudnnFmhaMaskKind(fmha_config.mask_type)); + TF_ASSIGN_OR_RETURN( + se::gpu::CudnnGraph graph, + se::gpu::GetCudnnFlashAttentionOperationGraph( + dnn_support, fmha_config.lhs_bmm1, fmha_config.rhs_bmm1, + fmha_config.rhs_bmm2, fmha_config.output, fmha_config.bias, + fmha_config.activation, static_cast(*fmha_config.fmha_scale), + fmha_config.dropout_rate && *fmha_config.dropout_rate > 0.0, + fmha_config.dropout_rate, dnn_mask_type)); + return std::move(graph); + } else { + TF_ASSIGN_OR_RETURN( + auto gpu_config, + custom_call->backend_config()); + xla::gpu::CudnnfMHABackendConfig &config = + *gpu_config.mutable_cudnn_fmha_backend_config(); + + int input_index = 0; + Shape bmm1_grad_gemm1_rhs_shape = + custom_call->operand(input_index++)->shape(); + Shape bmm1_grad_gemm2_rhs_shape = + custom_call->operand(input_index++)->shape(); + Shape bmm2_grad_gemm2_rhs_shape = + custom_call->operand(input_index++)->shape(); + Shape bmm2_grad_gemm1_lhs_shape(config.intermediate_tensor_shape()); + input_index++; + Shape d_output_shape = custom_call->operand(input_index++)->shape(); + + TF_ASSIGN_OR_RETURN(const CudnnfMHAKind kind, + GetCudnnfMHAKind(custom_call)); + std::optional mask_shape; + + bool has_bias = (kind == CudnnfMHAKind::kBackwardScaleBiasSoftmax || + kind == CudnnfMHAKind::kBackwardScaleBiasSoftmaxDropout); + std::optional bias_shape; + if (has_bias) { + bias_shape = custom_call->operand(input_index++)->shape(); + } + + std::optional fwd_output_shape = + custom_call->operand(input_index++)->shape(); + if (config.mask_type() == xla::gpu::CudnnfMHABackendConfig::PADDING || + config.mask_type() == + xla::gpu::CudnnfMHABackendConfig::PADDING_CAUSAL) { + // skip q_seqlen and kv_seqlen + input_index 
+= 2; + } + TF_RET_CHECK(input_index == custom_call->operand_count()); + + int output_index = 0; + Shape d_bmm1_lhs_shape = + ShapeUtil::GetSubshape(custom_call->shape(), {output_index++}); + Shape d_bmm1_rhs_shape = + ShapeUtil::GetSubshape(custom_call->shape(), {output_index++}); + Shape d_bmm2_rhs_shape = + ShapeUtil::GetSubshape(custom_call->shape(), {output_index++}); + std::optional d_s_shape; + std::optional d_bias_shape; + bool has_dbias = custom_call->shape().tuple_shapes().size() == 5; + if (has_dbias) { + d_bias_shape = + ShapeUtil::GetSubshape(custom_call->shape(), {output_index++}); + } + // The last one is the workspace. + TF_RET_CHECK(output_index == + custom_call->shape().tuple_shapes().size() - 1); + TF_ASSIGN_OR_RETURN(CudnnfMHAMaskKind cudnn_mask_type, + AsCudnnFmhaMaskKind(config.mask_type())); + + const DebugOptions &debug_options = + custom_call->GetModule()->config().debug_options(); + bool force_deterministic = + debug_options.xla_gpu_deterministic_ops() || + debug_options.xla_gpu_exclude_nondeterministic_ops(); + config.set_force_deterministic(force_deterministic); + TF_RETURN_IF_ERROR(custom_call->set_backend_config(gpu_config)); + + GpufMHABackwardDescriptor descriptor = { + kind, + config, + cudnn_mask_type, + bmm1_grad_gemm1_rhs_shape, + bmm1_grad_gemm2_rhs_shape, + bmm2_grad_gemm1_lhs_shape, + bmm2_grad_gemm2_rhs_shape, + d_output_shape, + d_bmm1_lhs_shape, + d_bmm1_rhs_shape, + d_bmm2_rhs_shape, + config.bmm1_grad_gemm1_dot_dimension_numbers(), + config.bmm1_grad_gemm2_dot_dimension_numbers(), + config.bmm2_grad_gemm1_dot_dimension_numbers(), + config.bmm2_grad_gemm2_dot_dimension_numbers(), + d_s_shape, + fwd_output_shape, + mask_shape, + d_bias_shape, + bias_shape, + force_deterministic}; + + TF_ASSIGN_OR_RETURN(GpufMHABackwardConfig fmha_config, + GpufMHABackwardConfig::For(descriptor)); + TF_ASSIGN_OR_RETURN( + se::dnn::FMHAMaskKind dnn_mask_type, + GetDNNFmhaMaskKindFromCudnnFmhaMaskKind(fmha_config.mask_type)); + + TF_ASSIGN_OR_RETURN( + se::gpu::CudnnGraph graph, + se::gpu::GetCudnnFlashAttentionBackwardOperationGraph( + dnn_support, fmha_config.bmm1_grad_gemm1_rhs, + fmha_config.bmm1_grad_gemm2_rhs, fmha_config.bmm2_grad_gemm1_lhs, + fmha_config.bmm2_grad_gemm2_rhs, fmha_config.d_output, + fmha_config.d_bmm1_lhs, fmha_config.d_bmm1_rhs, + fmha_config.d_bmm2_rhs, fmha_config.bias, fmha_config.dropout_rate, + fmha_config.seed, *fmha_config.fmha_scale, + fmha_config.dropout_rate && *fmha_config.dropout_rate > 0.0, + fmha_config.bias != std::nullopt, dnn_mask_type, + force_deterministic)); + return std::move(graph); + } +} + +class CuDnnCustomCallVisitor : public DfsHloRewriteVisitor { + public: + explicit CuDnnCustomCallVisitor(se::dnn::DnnSupport &dnn_support, + BinaryMap &compilation_results) + : dnn_support_(dnn_support), compilation_results_(compilation_results) {} + + void AddWorkspace(HloInstruction &hlo, int64_t workspace_size) { + if (workspace_size == 0) { + return; + } + VLOG(4) << "Applying workspace size " << workspace_size << " to " + << hlo.ToString(); + Shape *shape = hlo.mutable_shape(); + shape->mutable_tuple_shapes()->back().set_dimensions(0, workspace_size); + } + + absl::Status HandleCustomCall(HloInstruction *hlo) override { + if (!IsCustomCallTofMHA(*hlo)) { + return absl::OkStatus(); + } + + TF_ASSIGN_OR_RETURN(const std::string fingerprint_without_workspace, + FingerprintWithBackendConfig(*hlo)); + auto workspace_size_it = + workspace_sizes_.find(fingerprint_without_workspace); + if (workspace_size_it == workspace_sizes_.cend()) { 
+ TF_ASSIGN_OR_RETURN( + se::gpu::CudnnGraph graph, + HloCustomCallToCuDnnGraph(dnn_support_, + DynCast(hlo))); + + const int64_t workspace_size = graph.Graph().get_workspace_size(); + workspace_sizes_.insert(workspace_size_it, + {fingerprint_without_workspace, workspace_size}); + AddWorkspace(*hlo, workspace_size); + + std::vector serialized_graph; + RETURN_IF_CUDNN_FRONTEND_ERROR(graph.Graph().serialize(serialized_graph)); + // Compute a new fingerprint with a potential workspace for the + // compilation results to match a fingerprint computed by the emitter. + TF_ASSIGN_OR_RETURN(const std::string fingerprint_with_workspace, + FingerprintWithBackendConfig(*hlo)); + compilation_results_[fingerprint_with_workspace] = + std::string(reinterpret_cast(serialized_graph.data()), + serialized_graph.size()); + } else { + VLOG(4) << "Cache hit."; + AddWorkspace(*hlo, workspace_size_it->second); + } + + MarkAsChanged(); + return absl::OkStatus(); + } + + private: + se::dnn::DnnSupport &dnn_support_; + BinaryMap &compilation_results_; + absl::flat_hash_map workspace_sizes_; +}; + +} // namespace + +absl::StatusOr CuDnnCustomCallCompiler::Run( + HloModule *module, + const absl::flat_hash_set &execution_threads) { + XLA_SCOPED_LOGGING_TIMER_LEVEL("cuDNN custom call compiler", 8); + return CuDnnCustomCallVisitor(dnn_support_, compilation_results_) + .RunOnModule(module, execution_threads); +} + +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_workspace_rewriter.h b/third_party/xla/xla/service/gpu/transforms/cudnn_custom_call_compiler.h similarity index 61% rename from third_party/xla/xla/service/gpu/transforms/cudnn_workspace_rewriter.h rename to third_party/xla/xla/service/gpu/transforms/cudnn_custom_call_compiler.h index 962841289b58dc..810286f91b8472 100644 --- a/third_party/xla/xla/service/gpu/transforms/cudnn_workspace_rewriter.h +++ b/third_party/xla/xla/service/gpu/transforms/cudnn_custom_call_compiler.h @@ -13,13 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_TRANSFORMS_CUDNN_WORKSPACE_REWRITER_H_ -#define XLA_SERVICE_GPU_TRANSFORMS_CUDNN_WORKSPACE_REWRITER_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_CUDNN_CUSTOM_CALL_COMPILER_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_CUDNN_CUSTOM_CALL_COMPILER_H_ #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/hlo_pass_interface.h" #include "xla/stream_executor/dnn.h" #include "xla/stream_executor/stream_executor.h" @@ -27,14 +28,18 @@ limitations under the License. namespace xla { namespace gpu { -// Rewrite cuDNN custom call to have correct workspace size by build graph -// and serialize so we can use it later -class CuDnnWorkspaceRewriter : public HloModulePass { +// Compile cuDNN custom calls to binaries and serialize them. +// Also adjust them in HLO to have correct workspace size. 
+class CuDnnCustomCallCompiler : public HloModulePass { public: - explicit CuDnnWorkspaceRewriter(se::StreamExecutor& stream_exec) - : dnn_support_(*stream_exec.AsDnn()) {} + explicit CuDnnCustomCallCompiler(se::StreamExecutor& stream_exec, + BinaryMap& compilation_results) + : dnn_support_(*stream_exec.AsDnn()), + compilation_results_(compilation_results) {} - absl::string_view name() const override { return "cudnn-workspace-rewriter"; } + absl::string_view name() const override { + return "cudnn-custom-call-compiler"; + } using HloPassInterface::Run; absl::StatusOr Run( @@ -43,9 +48,10 @@ class CuDnnWorkspaceRewriter : public HloModulePass { private: se::dnn::DnnSupport& dnn_support_; + BinaryMap& compilation_results_; }; } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_TRANSFORMS_CUDNN_WORKSPACE_REWRITER_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_CUDNN_CUSTOM_CALL_COMPILER_H_ diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_workspace_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_workspace_rewriter.cc deleted file mode 100644 index b5440a8a2af53f..00000000000000 --- a/third_party/xla/xla/service/gpu/transforms/cudnn_workspace_rewriter.cc +++ /dev/null @@ -1,272 +0,0 @@ -/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "xla/service/gpu/transforms/cudnn_workspace_rewriter.h" - -#include -#include -#include - -#include "absl/container/flat_hash_set.h" -#include "absl/container/inlined_vector.h" -#include "absl/log/check.h" -#include "absl/log/log.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" -#include "xla/hlo/ir/hlo_casting_utils.h" -#include "xla/hlo/ir/hlo_clone_context.h" -#include "xla/hlo/ir/hlo_computation.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_instructions.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/service/gpu/backend_configs.pb.h" -#include "xla/service/gpu/cublas_cudnn.h" -#include "xla/service/gpu/gpu_fused_mha_runner.h" -#include "xla/service/gpu/stream_executor_util.h" -#include "xla/shape.h" -#include "xla/shape_util.h" -#include "xla/status_macros.h" -#include "xla/stream_executor/cuda/cuda_dnn.h" -#include "xla/stream_executor/dnn.h" -#include "xla/util.h" -#include "tsl/platform/statusor.h" - -namespace xla { -namespace gpu { - -namespace { - -// create cuDNN graphs from HloCustomCall -absl::StatusOr HloCustomCallToCuDnnGraph( - se::dnn::DnnSupport& dnn_support, HloCustomCallInstruction* custom_call) { - if (IsFwdCustomCallTofMHA(*custom_call)) { - TF_ASSIGN_OR_RETURN(const xla::gpu::CudnnfMHAKind kind, - xla::gpu::GetCudnnfMHAKind(custom_call)); - std::optional mask_shape, bias_shape; - { - bool has_bias = kind == CudnnfMHAKind::kScaleBiasSoftmax || - kind == CudnnfMHAKind::kScaleBiasSoftmaxDropout; - - if (has_bias) { - const HloInstruction* bias = custom_call->operand(3); - bias_shape = bias->shape(); - } - } - - TF_ASSIGN_OR_RETURN( - const auto gpu_config, - custom_call->backend_config()); - const xla::gpu::CudnnfMHABackendConfig& config = - gpu_config.cudnn_fmha_backend_config(); - Shape intermediate_tensor_shape(config.intermediate_tensor_shape()); - absl::InlinedVector output_shapes = { - ShapeUtil::GetSubshape(custom_call->shape(), {0})}; - - bool has_activation = - xla::ShapeUtil::TupleElementCount(custom_call->shape()) == 3; - if (has_activation) { - output_shapes.push_back( - ShapeUtil::GetSubshape(custom_call->shape(), {1})); - } - - Shape q_shape = custom_call->operand(0)->shape(); - Shape k_shape = custom_call->operand(1)->shape(); - Shape v_shape = custom_call->operand(2)->shape(); - TF_ASSIGN_OR_RETURN(CudnnfMHAMaskKind cudnn_mask_type, - AsCudnnFmhaMaskKind(config.mask_type())); - GpufMHADescriptor descriptor = {kind, - config, - cudnn_mask_type, - q_shape, - k_shape, - v_shape, - intermediate_tensor_shape, - output_shapes, - config.bmm1_dot_dimension_numbers(), - config.bmm2_dot_dimension_numbers(), - mask_shape, - bias_shape}; - - TF_ASSIGN_OR_RETURN(GpufMHAConfig fmha_config, - GpufMHAConfig::For(descriptor)); - TF_ASSIGN_OR_RETURN( - se::dnn::FMHAMaskKind dnn_mask_type, - GetDNNFmhaMaskKindFromCudnnFmhaMaskKind(fmha_config.mask_type)); - TF_ASSIGN_OR_RETURN( - se::gpu::CudnnGraph graph, - se::gpu::GetCudnnFlashAttentionOperationGraph( - dnn_support, fmha_config.lhs_bmm1, fmha_config.rhs_bmm1, - fmha_config.rhs_bmm2, fmha_config.output, fmha_config.bias, - fmha_config.activation, static_cast(*fmha_config.fmha_scale), - fmha_config.dropout_rate && *fmha_config.dropout_rate > 0.0, - fmha_config.dropout_rate, dnn_mask_type)); - return std::move(graph); - } else { - TF_ASSIGN_OR_RETURN( - auto gpu_config, - 
custom_call->backend_config()); - xla::gpu::CudnnfMHABackendConfig& config = - *gpu_config.mutable_cudnn_fmha_backend_config(); - - int input_index = 0; - Shape bmm1_grad_gemm1_rhs_shape = - custom_call->operand(input_index++)->shape(); - Shape bmm1_grad_gemm2_rhs_shape = - custom_call->operand(input_index++)->shape(); - Shape bmm2_grad_gemm2_rhs_shape = - custom_call->operand(input_index++)->shape(); - Shape bmm2_grad_gemm1_lhs_shape(config.intermediate_tensor_shape()); - input_index++; - Shape d_output_shape = custom_call->operand(input_index++)->shape(); - - TF_ASSIGN_OR_RETURN(const CudnnfMHAKind kind, - GetCudnnfMHAKind(custom_call)); - std::optional mask_shape; - - bool has_bias = (kind == CudnnfMHAKind::kBackwardScaleBiasSoftmax || - kind == CudnnfMHAKind::kBackwardScaleBiasSoftmaxDropout); - std::optional bias_shape; - if (has_bias) { - bias_shape = custom_call->operand(input_index++)->shape(); - } - - std::optional fwd_output_shape = - custom_call->operand(input_index++)->shape(); - if (config.mask_type() == xla::gpu::CudnnfMHABackendConfig::PADDING || - config.mask_type() == - xla::gpu::CudnnfMHABackendConfig::PADDING_CAUSAL) { - // skip q_seqlen and kv_seqlen - input_index += 2; - } - TF_RET_CHECK(input_index == custom_call->operand_count()); - - int output_index = 0; - Shape d_bmm1_lhs_shape = - ShapeUtil::GetSubshape(custom_call->shape(), {output_index++}); - Shape d_bmm1_rhs_shape = - ShapeUtil::GetSubshape(custom_call->shape(), {output_index++}); - Shape d_bmm2_rhs_shape = - ShapeUtil::GetSubshape(custom_call->shape(), {output_index++}); - std::optional d_s_shape; - std::optional d_bias_shape; - bool has_dbias = custom_call->shape().tuple_shapes().size() == 5; - if (has_dbias) { - d_bias_shape = - ShapeUtil::GetSubshape(custom_call->shape(), {output_index++}); - } - // The last one is the workspace. 
- TF_RET_CHECK(output_index == - custom_call->shape().tuple_shapes().size() - 1); - TF_ASSIGN_OR_RETURN(CudnnfMHAMaskKind cudnn_mask_type, - AsCudnnFmhaMaskKind(config.mask_type())); - - const bool force_deterministic = - RequireDeterminism(custom_call->GetModule()->config()); - // set the correct force_deterministic attribute here - config.set_force_deterministic(force_deterministic); - TF_RETURN_IF_ERROR(custom_call->set_backend_config(gpu_config)); - - GpufMHABackwardDescriptor descriptor = { - kind, - config, - cudnn_mask_type, - bmm1_grad_gemm1_rhs_shape, - bmm1_grad_gemm2_rhs_shape, - bmm2_grad_gemm1_lhs_shape, - bmm2_grad_gemm2_rhs_shape, - d_output_shape, - d_bmm1_lhs_shape, - d_bmm1_rhs_shape, - d_bmm2_rhs_shape, - config.bmm1_grad_gemm1_dot_dimension_numbers(), - config.bmm1_grad_gemm2_dot_dimension_numbers(), - config.bmm2_grad_gemm1_dot_dimension_numbers(), - config.bmm2_grad_gemm2_dot_dimension_numbers(), - d_s_shape, - fwd_output_shape, - mask_shape, - d_bias_shape, - bias_shape, - force_deterministic}; - - TF_ASSIGN_OR_RETURN(GpufMHABackwardConfig fmha_config, - GpufMHABackwardConfig::For(descriptor)); - TF_ASSIGN_OR_RETURN( - se::dnn::FMHAMaskKind dnn_mask_type, - GetDNNFmhaMaskKindFromCudnnFmhaMaskKind(fmha_config.mask_type)); - - TF_ASSIGN_OR_RETURN( - se::gpu::CudnnGraph graph, - se::gpu::GetCudnnFlashAttentionBackwardOperationGraph( - dnn_support, fmha_config.bmm1_grad_gemm1_rhs, - fmha_config.bmm1_grad_gemm2_rhs, fmha_config.bmm2_grad_gemm1_lhs, - fmha_config.bmm2_grad_gemm2_rhs, fmha_config.d_output, - fmha_config.d_bmm1_lhs, fmha_config.d_bmm1_rhs, - fmha_config.d_bmm2_rhs, fmha_config.bias, fmha_config.dropout_rate, - fmha_config.seed, *fmha_config.fmha_scale, - fmha_config.dropout_rate && *fmha_config.dropout_rate > 0.0, - fmha_config.bias != std::nullopt, dnn_mask_type, - force_deterministic)); - return std::move(graph); - } -} - -class CuDnnCustomCallVisitor : public DfsHloRewriteVisitor { - public: - explicit CuDnnCustomCallVisitor(se::dnn::DnnSupport& dnn_support) - : dnn_support_(dnn_support) {} - - absl::Status HandleCustomCall(HloInstruction* hlo) override { - if (!IsCustomCallTofMHA(*hlo)) { - // don't do anything about other cuDNN custom calls - return absl::OkStatus(); - } - TF_ASSIGN_OR_RETURN(auto gpu_config, - hlo->backend_config()); - - TF_ASSIGN_OR_RETURN( - se::gpu::CudnnGraph graph, - HloCustomCallToCuDnnGraph(dnn_support_, - DynCast(hlo))); - auto workspace = graph.Graph().get_workspace_size(); - if (workspace != 0) { - // rewrite custom call to have correct workspace size - VLOG(4) << "Rewriting: " << hlo->ToString(); - Shape* shape = hlo->mutable_shape(); - shape->mutable_tuple_shapes(shape->tuple_shapes_size() - 1) - ->set_dimensions(0, workspace); - MarkAsChanged(); - } - return absl::OkStatus(); - } - - private: - se::dnn::DnnSupport& dnn_support_; -}; - -} // namespace - -absl::StatusOr CuDnnWorkspaceRewriter::Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) { - XLA_SCOPED_LOGGING_TIMER("cuDNN workspace rewriter"); - return CuDnnCustomCallVisitor(dnn_support_) - .RunOnModule(module, execution_threads); -} - -} // namespace gpu -} // namespace xla diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc b/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc index 90640a839dcf23..440f647b84f1ce 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc @@ -40,7 +40,6 @@ limitations under the License. 
#include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" -#include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "absl/time/time.h" @@ -3762,32 +3761,6 @@ absl::StatusOr CreateCudnnTensor( } #if CUDNN_VERSION >= 8800 -enum CudnnfMHAUid { - Q_ID = 400, - K_ID, - V_ID, - P_ID, - O_ID, - dQ_ID, - dK_ID, - dV_ID, - dP_ID, - dO_ID, - dS_ID, - dBIAS_ID, - BIAS_ID, - MASK_ID, - ZERO_VAL_ID, - ONE_VAL_ID, - NEG_INFINITY_ID, - ALPHA_SCALE_ID, - DROPOUT_SCALE_ID, - Q_SEQLEN_ID, - K_SEQLEN_ID, - D_OFFSET_ID, - D_SEED_ID, - VIRTUAL_ID = 34857 -}; absl::StatusOr CreatePwDesc( dnn::DataType dtype, cudnnPointwiseMode_t mode) { @@ -5032,12 +5005,14 @@ absl::StatusOr GetCudnnFlashAttentionOperationGraph( .set_io_data_type(ioDataType) .set_compute_data_type(cudnn_frontend::DataType_t::FLOAT); + auto next_uid = [uid = 0]() mutable -> int { return CuDnnTensorUID(uid++); }; + std::shared_ptr q_tensor = graph.tensor(Tensor_attributes() .set_name("Q") .set_dim(q_descriptor.GetCudnnCompatibleDimensions(true)) .set_stride(q_descriptor.GetCudnnCompatibleStrides(true)) - .set_uid(CudnnfMHAUid::Q_ID)); + .set_uid(next_uid())); auto dim = k_descriptor.GetCudnnCompatibleDimensions(true); std::shared_ptr k_tensor = @@ -5045,13 +5020,13 @@ absl::StatusOr GetCudnnFlashAttentionOperationGraph( .set_name("K") .set_dim(k_descriptor.GetCudnnCompatibleDimensions(true)) .set_stride(k_descriptor.GetCudnnCompatibleStrides(true)) - .set_uid(CudnnfMHAUid::K_ID)); + .set_uid(next_uid())); std::shared_ptr v_tensor = graph.tensor( Tensor_attributes() .set_name("V") .set_dim(v_descriptor.GetCudnnCompatibleDimensions(false)) .set_stride(v_descriptor.GetCudnnCompatibleStrides(false)) - .set_uid(CudnnfMHAUid::V_ID)); + .set_uid(next_uid())); // Setting sdpa, and is_inference bool is_causal = mask_type == dnn::FMHAMaskKind::CAUSAL || @@ -5069,7 +5044,7 @@ absl::StatusOr GetCudnnFlashAttentionOperationGraph( .set_name("bias") .set_dim(bias_descriptor->dimensions()) .set_stride(bias_descriptor->GetLogicalStrides()) - .set_uid(CudnnfMHAUid::BIAS_ID)); + .set_uid(next_uid())); sdpa_options.set_bias(bias_tensor); } // Setting actual seqlen @@ -5083,37 +5058,38 @@ absl::StatusOr GetCudnnFlashAttentionOperationGraph( .set_name("seq_q") .set_dim({b, 1, 1, 1}) .set_stride({1, 1, 1, 1}) - .set_uid(CudnnfMHAUid::Q_SEQLEN_ID) + .set_uid(next_uid()) .set_data_type(cudnn_frontend::DataType_t::INT32)); auto seq_kv_tensor = graph.tensor(Tensor_attributes() .set_name("seq_kv") .set_dim({b, 1, 1, 1}) .set_stride({1, 1, 1, 1}) - .set_uid(CudnnfMHAUid::K_SEQLEN_ID) + .set_uid(next_uid()) .set_data_type(cudnn_frontend::DataType_t::INT32)); sdpa_options.set_padding_mask(true); sdpa_options.set_seq_len_q(seq_q_tensor); sdpa_options.set_seq_len_kv(seq_kv_tensor); } // Setting seed and offset + std::shared_ptr seed_tensor; + std::shared_ptr offset_tensor; if (use_dropout) { - auto seed_tensor = + // Skip setting UIDs: pass by value tensors go at the end. 
+ seed_tensor = graph.tensor(Tensor_attributes() .set_name("seed") .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_data_type(cudnn_frontend::DataType_t::INT64) - .set_is_pass_by_value(true) - .set_uid(CudnnfMHAUid::D_SEED_ID)); - auto offset_tensor = + .set_is_pass_by_value(true)); + offset_tensor = graph.tensor(Tensor_attributes() .set_name("offset") .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_data_type(cudnn_frontend::DataType_t::INT64) - .set_is_pass_by_value(true) - .set_uid(CudnnfMHAUid::D_OFFSET_ID)); + .set_is_pass_by_value(true)); sdpa_options.set_dropout((float)dropout_rate.value(), seed_tensor, offset_tensor); } @@ -5127,7 +5103,7 @@ absl::StatusOr GetCudnnFlashAttentionOperationGraph( .set_output(true) .set_dim(o_descriptor.dimensions()) .set_stride(o_descriptor.GetLogicalStrides()) - .set_uid(CudnnfMHAUid::O_ID); + .set_uid(next_uid()); if (stats_descriptor.has_value()) { cudnn_frontend::DataType_t statsType = ToCudnnFrontendDataType(stats_descriptor->type()); @@ -5140,7 +5116,13 @@ absl::StatusOr GetCudnnFlashAttentionOperationGraph( .set_data_type(statsType) .set_dim(stat_dims) .set_stride(stat_strides) - .set_uid(CudnnfMHAUid::P_ID); + .set_uid(next_uid()); + } + if (seed_tensor != nullptr) { + seed_tensor->set_uid(next_uid()); + } + if (offset_tensor != nullptr) { + offset_tensor->set_uid(next_uid()); } CudnnGraph cudnnGraph(std::move(graph)); TF_RETURN_IF_ERROR(cudnnGraph.Prepare( @@ -5195,71 +5177,66 @@ absl::StatusOr GetCudnnFlashAttentionBackwardOperationGraph( .set_intermediate_data_type(cudnn_frontend::DataType_t::FLOAT) .set_io_data_type(ioDataType); + auto p_dims = p_desc.GetCudnnCompatibleDimensions(false); + auto p_strides = p_desc.GetCudnnCompatibleStrides(false); + std::vector p_reduction_dims(p_dims.begin(), p_dims.end() - 1); + p_reduction_dims.push_back(1); + + // Divide every stride by the last dim value. 
+ std::vector p_reduction_strides; + p_reduction_strides.reserve(p_strides.size()); + int64_t p_reduced_dim_len = p_dims.back(); + for (auto stride : p_strides) { + p_reduction_strides.push_back(stride / p_reduced_dim_len); + } + p_reduction_strides[3] = 1; + bool is_causal = mask_type == dnn::FMHAMaskKind::CAUSAL || + mask_type == dnn::FMHAMaskKind::PADDING_CAUSAL; + auto sdpa_backward_options = + cudnn_frontend::graph::SDPA_backward_attributes() + .set_name("flash_attention_backward") + .set_causal_mask(is_causal) + .set_attn_scale(scale) + .set_compute_data_type(cudnn_frontend::DataType_t::FLOAT); + + auto next_uid = [uid = 0]() mutable -> int { return CuDnnTensorUID(uid++); }; + std::shared_ptr q = graph.tensor(Tensor_attributes() .set_name("Q") .set_dim(q_desc.GetCudnnCompatibleDimensions(false)) .set_stride(q_desc.GetCudnnCompatibleStrides(false)) - .set_uid(CudnnfMHAUid::Q_ID) + .set_uid(next_uid()) .set_data_type(ioDataType)); std::shared_ptr k = graph.tensor(Tensor_attributes() .set_name("K") .set_dim(k_desc.GetCudnnCompatibleDimensions(false)) .set_stride(k_desc.GetCudnnCompatibleStrides(false)) - .set_uid(CudnnfMHAUid::K_ID) + .set_uid(next_uid()) .set_data_type(ioDataType)); std::shared_ptr v = graph.tensor(Tensor_attributes() .set_name("V") .set_dim(v_desc.GetCudnnCompatibleDimensions(true)) .set_stride(v_desc.GetCudnnCompatibleStrides(true)) - .set_uid(CudnnfMHAUid::V_ID) + .set_uid(next_uid()) .set_data_type(ioDataType)); - std::shared_ptr o = + std::shared_ptr stats = graph.tensor(Tensor_attributes() - .set_name("O") - .set_dim(do_desc.GetCudnnCompatibleDimensions(false)) - .set_stride(do_desc.GetCudnnCompatibleStrides(false)) - .set_uid(CudnnfMHAUid::O_ID) - .set_data_type(ioDataType)); + .set_name("stats") + .set_dim(p_reduction_dims) + .set_stride(p_reduction_strides) + .set_uid(next_uid()) + .set_data_type(cudnn_frontend::DataType_t::FLOAT)); std::shared_ptr dO = graph.tensor(Tensor_attributes() .set_name("dO") .set_dim(do_desc.GetCudnnCompatibleDimensions(false)) .set_stride(do_desc.GetCudnnCompatibleStrides(false)) - .set_uid(CudnnfMHAUid::dO_ID) + .set_uid(next_uid()) .set_data_type(ioDataType)); - auto p_dims = p_desc.GetCudnnCompatibleDimensions(false); - auto p_strides = p_desc.GetCudnnCompatibleStrides(false); - std::vector p_reduction_dims(p_dims.begin(), p_dims.end() - 1); - p_reduction_dims.push_back(1); - - // Divide every stride by the last dim value. 
- std::vector p_reduction_strides; - p_reduction_strides.reserve(p_strides.size()); - int64_t p_reduced_dim_len = p_dims.back(); - for (auto stride : p_strides) { - p_reduction_strides.push_back(stride / p_reduced_dim_len); - } - p_reduction_strides[3] = 1; - std::shared_ptr stats = - graph.tensor(Tensor_attributes() - .set_name("stats") - .set_dim(p_reduction_dims) - .set_stride(p_reduction_strides) - .set_uid(CudnnfMHAUid::P_ID) - .set_data_type(cudnn_frontend::DataType_t::FLOAT)); - bool is_causal = mask_type == dnn::FMHAMaskKind::CAUSAL || - mask_type == dnn::FMHAMaskKind::PADDING_CAUSAL; - auto sdpa_backward_options = - cudnn_frontend::graph::SDPA_backward_attributes() - .set_name("flash_attention_backward") - .set_causal_mask(is_causal) - .set_attn_scale(scale) - .set_compute_data_type(cudnn_frontend::DataType_t::FLOAT); - - // Setting bias + std::shared_ptr d_bias_tensor; if (use_bias) { DCHECK(bias_descriptor != std::nullopt); auto bias_dim = bias_descriptor->dimensions(); @@ -5272,21 +5249,29 @@ absl::StatusOr GetCudnnFlashAttentionBackwardOperationGraph( .set_name("bias") .set_dim(bias_descriptor->dimensions()) .set_stride(bias_descriptor->GetLogicalStrides()) - .set_uid(CudnnfMHAUid::BIAS_ID)); + .set_uid(next_uid())); sdpa_backward_options.set_bias(bias_tensor); // shapes [1, 1, s, s], [b, 1, s, s], [b, h, s, s] are not supported for // dbias calculation but they are supported for forward bias calculation + // Set UID later: this is the last output tuple element. if (b == 1 && n == q_n) { - auto d_bias_tensor = + d_bias_tensor = graph.tensor(Tensor_attributes() .set_name("dBias") .set_dim(bias_descriptor->dimensions()) - .set_stride(bias_descriptor->GetLogicalStrides()) - .set_uid(CudnnfMHAUid::dBIAS_ID)); + .set_stride(bias_descriptor->GetLogicalStrides())); sdpa_backward_options.set_dbias(d_bias_tensor); } } + std::shared_ptr o = + graph.tensor(Tensor_attributes() + .set_name("O") + .set_dim(do_desc.GetCudnnCompatibleDimensions(false)) + .set_stride(do_desc.GetCudnnCompatibleStrides(false)) + .set_uid(next_uid()) + .set_data_type(ioDataType)); + // Setting actual seqlen bool is_padding = mask_type == dnn::FMHAMaskKind::PADDING || mask_type == dnn::FMHAMaskKind::PADDING_CAUSAL; @@ -5298,38 +5283,39 @@ absl::StatusOr GetCudnnFlashAttentionBackwardOperationGraph( .set_name("seq_q") .set_dim({b, 1, 1, 1}) .set_stride({1, 1, 1, 1}) - .set_uid(CudnnfMHAUid::Q_SEQLEN_ID) + .set_uid(next_uid()) .set_data_type(cudnn_frontend::DataType_t::INT32)); auto seq_kv_tensor = graph.tensor(Tensor_attributes() .set_name("seq_kv") .set_dim({b, 1, 1, 1}) .set_stride({1, 1, 1, 1}) - .set_uid(CudnnfMHAUid::K_SEQLEN_ID) + .set_uid(next_uid()) .set_data_type(cudnn_frontend::DataType_t::INT32)); sdpa_backward_options.set_padding_mask(true); sdpa_backward_options.set_seq_len_q(seq_q_tensor); sdpa_backward_options.set_seq_len_kv(seq_kv_tensor); } // Setting seed and offset + std::shared_ptr seed_tensor; + std::shared_ptr offset_tensor; if (use_dropout) { DCHECK(dropout_rate != std::nullopt); - auto seed_tensor = + // Skip setting UIDs: pass by value tensors go at the end. 
+ seed_tensor = graph.tensor(Tensor_attributes() .set_name("seed") .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_data_type(cudnn_frontend::DataType_t::INT64) - .set_is_pass_by_value(true) - .set_uid(CudnnfMHAUid::D_SEED_ID)); - auto offset_tensor = + .set_is_pass_by_value(true)); + offset_tensor = graph.tensor(Tensor_attributes() .set_name("offset") .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_data_type(cudnn_frontend::DataType_t::INT64) - .set_is_pass_by_value(true) - .set_uid(CudnnfMHAUid::D_OFFSET_ID)); + .set_is_pass_by_value(true)); sdpa_backward_options.set_dropout((float)dropout_rate.value(), seed_tensor, offset_tensor); } @@ -5344,21 +5330,30 @@ absl::StatusOr GetCudnnFlashAttentionBackwardOperationGraph( dQ->set_output(true) .set_dim(dq_desc.dimensions()) .set_stride(dq_desc.GetLogicalStrides()) + .set_uid(next_uid()) .set_name("dQ") - .set_uid(CudnnfMHAUid::dQ_ID) .set_data_type(ioDataType); dK->set_output(true) .set_dim(dk_desc.dimensions()) .set_stride(dk_desc.GetLogicalStrides()) + .set_uid(next_uid()) .set_name("dK") - .set_uid(CudnnfMHAUid::dK_ID) .set_data_type(ioDataType); dV->set_output(true) .set_dim(dv_desc.dimensions()) .set_stride(dv_desc.GetLogicalStrides()) + .set_uid(next_uid()) .set_name("dV") - .set_uid(CudnnfMHAUid::dV_ID) .set_data_type(ioDataType); + if (d_bias_tensor != nullptr) { + d_bias_tensor->set_uid(next_uid()); + } + if (seed_tensor != nullptr) { + seed_tensor->set_uid(next_uid()); + } + if (offset_tensor != nullptr) { + offset_tensor->set_uid(next_uid()); + } CudnnGraph cudnnGraph(std::move(graph)); TF_RETURN_IF_ERROR( @@ -5696,8 +5691,8 @@ absl::Status CudnnSupport::DoConvolve( } // Utility for dealing with CUDA's type-erased scaling parameters, where some -// sets of parameters expect a void* pointing at a float while others expect it -// to point at a double. +// sets of parameters expect a void* pointing at a float while others expect +// it to point at a double. // // This is rather ugly, but its purpose is to quarantine the corresponding // ugliness that already exists in the CUDA API. @@ -5721,9 +5716,9 @@ class ScalingParam { // // See // https://docs.nvidia.com/deeplearning/cudnn/developer-guide/index.html#scaling-parameters - // for more info; the behavior for int8 result tensors is not described there, - // but is maintained from the existing behavior (namely, using a float scaling - // parameter). + // for more info; the behavior for int8 result tensors is not described + // there, but is maintained from the existing behavior (namely, using a + // float scaling parameter). void* ToVoidPointer(dnn::DataType element_type) { if (element_type == dnn::DataType::kDouble) { return &as_double_; @@ -5795,10 +5790,11 @@ absl::StatusOr> GetDescriptorAttribute( absl::c_transform(result, std::back_inserter(raw_ptrs), [](const BackendDescriptor& ptr) { return ptr.get(); }); - // This API evidently does a deep copy of the descriptors into the pointers in - // the output array, rather than writing pointers to the descriptors into the - // output array. So, this writes the memory behind each BackendDescriptor in - // result, rather than writing the contents of raw_ptrs. + // This API evidently does a deep copy of the descriptors into the pointers + // in the output array, rather than writing pointers to the descriptors into + // the output array. So, this writes the memory behind each + // BackendDescriptor in result, rather than writing the contents of + // raw_ptrs. 
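Stepping back from the hunks above: with the fixed CudnnfMHAUid constants gone, tensor UIDs now simply follow the order in which the graph builders call next_uid(), which mirrors the HLO operand/result order. A small standalone sketch of the resulting forward-graph order is below; the helper name and its vector-of-names output are illustrative only, and optional tensors are skipped when absent.

  // Illustrative sketch, not part of the patch. UID i corresponds to
  // CuDnnTensorUID(i) in the forward FMHA graph built above.
  #include <string>
  #include <vector>

  std::vector<std::string> ForwardFmhaUidOrder(bool has_bias, bool has_padding,
                                               bool has_stats,
                                               bool has_dropout) {
    std::vector<std::string> order = {"Q", "K", "V"};
    if (has_bias) order.push_back("bias");
    if (has_padding) {
      order.push_back("seq_q");
      order.push_back("seq_kv");
    }
    order.push_back("O");
    if (has_stats) order.push_back("stats");
    // Pass-by-value dropout seed/offset deliberately take the last UIDs,
    // after every device-buffer tensor.
    if (has_dropout) {
      order.push_back("seed");
      order.push_back("offset");
    }
    return order;
  }

As far as the backward hunks go, the same convention yields Q, K, V, stats, dO, optional bias, O, optional seq_q/seq_kv, then dQ, dK, dV, optional dBias, and finally the pass-by-value seed and offset.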
RETURN_IF_CUDNN_ERROR(cudnnBackendGetAttribute( desc, name, CUDNN_TYPE_BACKEND_DESCRIPTOR, n, &n, raw_ptrs.data())); @@ -5834,9 +5830,9 @@ absl::StatusOr ExecutionPlanToAlgorithmDesc( cudnnBackendGetAttribute(engines[0].get(), CUDNN_ATTR_ENGINE_GLOBAL_INDEX, CUDNN_TYPE_INT64, 1, &n, &engine_id)); - // Apparently for CUDNN_ATTR_ENGINECFG_KNOB_CHOICES only, trying to query the - // number of elements in the attribute by using an output limit value of 0 - // just returns 0; the only way to find out how many there are is to + // Apparently for CUDNN_ATTR_ENGINECFG_KNOB_CHOICES only, trying to query + // the number of elements in the attribute by using an output limit value of + // 0 just returns 0; the only way to find out how many there are is to // pre-allocate space for every existing knob type (as an upper bound on the // number of knob choices a config can have), and then look back at how many // were filled. @@ -6047,103 +6043,7 @@ class CudnnExecutionPlanRunner std::vector scalar_input_uids_; std::vector scalar_input_values_; }; -#endif // CUDNN_VERSION >= 8100 - -template -class CudnnGraphRunner; -// An OpRunner implemented by a cuDNN frontend graph. -// -// This is the class holding the implementation of ToString, GetWorkspaceSize, -// and operator() for use by the cudnn frontend op runners. -template -class CudnnGraphRunner : public dnn::OpRunner { - private: - using Graph = cudnn_frontend::graph::Graph; - using Tensor_attributes = cudnn_frontend::graph::Tensor_attributes; - - public: - std::string ToString() const override { return graph_.Graph().print(); } - - size_t GetWorkspaceSize() const override { - return graph_.Graph().get_workspace_size(); - } - - absl::StatusOr ToAlgorithmDesc() const override { - return absl::InternalError( - "Unexpected call to CudnnGraphRunner::ToAlgorithmDesc"); - } - - absl::Status operator()(Stream* stream, dnn::ProfileResult* profile_result, - DeviceMemoryBase scratch_memory, - Args... 
inputs) const override { - if (parent_ != stream->parent()) { - return tsl::errors::Internal( - "CudnnExecutionPlanRunner cached across multiple StreamExecutors."); - } - CudnnHandle handle = cudnn_->GetHandle(parent_, stream); - std::unordered_map variant_pack; - std::vector vec = {inputs.opaque()...}; - - // add device buffers to the variant pack - for (int i = 0; i < uids_.size(); ++i) { - if (uids_[i].has_value()) { - variant_pack[*uids_[i]] = vec[i]; - } - } - if (dropout_rng_offset_increment_ > 0) { -#if CUDNN_VERSION >= 8800 - variant_pack[CudnnfMHAUid::D_SEED_ID] = (void*)&dropout_rng_seed_; - current_dropout_rng_offset_ += dropout_rng_offset_increment_; - variant_pack[CudnnfMHAUid::D_OFFSET_ID] = - (void*)¤t_dropout_rng_offset_; -#else - return absl::UnimplementedError( - "Cudnn dropout offset and seed are only supported with Cudnn >= " - "8.8.0"); -#endif // CUDNN_VERSION >= 8800 - } - int workspace = graph_.Graph().get_workspace_size(); - if (workspace > scratch_memory.size()) { - return tsl::errors::Internal( - absl::StrFormat("CuDNN FMHA requires %d workspace, got %d workspace.", - workspace, scratch_memory.size())); - } - RETURN_IF_CUDNN_FRONTEND_ERROR(graph_.Graph().execute( - handle.handle(), variant_pack, scratch_memory.opaque())); - - return absl::OkStatus(); - } - - static absl::StatusOr Create( - GpuExecutor* parent, CudnnAccess* cudnn, CudnnGraph graph, - int64_t dropout_rng_seed, int64_t dropout_rng_offset, - std::vector> uids) { - return CudnnGraphRunner(parent, cudnn, std::move(graph), dropout_rng_seed, - dropout_rng_offset, uids); - } - - private: - CudnnGraphRunner(GpuExecutor* parent, CudnnAccess* cudnn, CudnnGraph graph, - int64_t dropout_rng_seed, int64_t dropout_rng_offset, - std::vector> uids) - : parent_(parent), - cudnn_(cudnn), - graph_(std::move(graph)), - dropout_rng_seed_(dropout_rng_seed), - current_dropout_rng_offset_(0), - dropout_rng_offset_increment_(dropout_rng_offset), - uids_(uids) {} - GpuExecutor* parent_; - CudnnAccess* cudnn_; - Stream* stream_; - CudnnGraph graph_; - int64_t dropout_rng_seed_; - mutable int64_t current_dropout_rng_offset_; - int64_t dropout_rng_offset_increment_; - std::vector> uids_; -}; -#if CUDNN_VERSION >= 8100 namespace { template @@ -6929,7 +6829,8 @@ absl::Status CudnnSupport::GetFusedMatmulRunners( use_fallback, out_exec_plans, /*need_side_input=*/true, numeric_options); #else return tsl::errors::Unimplemented( - "Cudnn execution plans for matmul are only supported with Cudnn >= 8.4."); + "Cudnn execution plans for matmul are only supported with Cudnn >= " + "8.4."); #endif // CUDNN_VERSION >= 8400 } @@ -7131,139 +7032,6 @@ int64_t GetDropoutRngOffset(std::vector& intermediate_shape) { return max_seq_len * max_seq_len / cudnn_mha_num_threads; } -absl::StatusOr> -CudnnSupport::FusedMHARunnerFromDesc( - Stream* stream, const dnn::AlgorithmDesc& algorithm_desc, - const dnn::MatmulTensorDescriptor& bmm1_lhs_descriptor, - const dnn::MatmulTensorDescriptor& bmm1_rhs_descriptor, - const dnn::MatmulTensorDescriptor& bmm2_rhs_descriptor, - const dnn::MatmulTensorDescriptor& intermediate_bmm2_lhs_descriptor, - const dnn::TensorDescriptor& output_descriptor, - std::optional activation_descriptor, - std::optional bias_descriptor, double scale, - std::optional dropout_rate, std::optional seed, - dnn::FMHAMaskKind mask_type) { -#if CUDNN_VERSION >= 90000 - auto cudnn = cudnn_->GetHandle(parent_, stream); - bool use_dropout = dropout_rate && *dropout_rate > 0.0; - std::vector intermediate_shape; - - TF_ASSIGN_OR_RETURN(auto graph, 
- GetCudnnFlashAttentionOperationGraph( - *this, /*q_descriptor=*/bmm1_lhs_descriptor, - /*k_descriptor=*/bmm1_rhs_descriptor, - /*v_descriptor=*/bmm2_rhs_descriptor, - /*o_descriptor=*/output_descriptor, bias_descriptor, - /*stats_descriptor=*/activation_descriptor, - /*scale=*/static_cast(scale), use_dropout, - dropout_rate, mask_type)); - - std::vector intermediate_bmm2_lhs_dims = - intermediate_bmm2_lhs_descriptor.GetCudnnCompatibleDimensions(true); - intermediate_shape = intermediate_bmm2_lhs_dims; - int64_t dropout_rng_offset = GetDropoutRngOffset(intermediate_shape); - int64_t dropout_rng_seed = seed.has_value() ? *seed : 0; - std::vector> uids = { - CudnnfMHAUid::Q_ID, CudnnfMHAUid::K_ID, CudnnfMHAUid::V_ID, - CudnnfMHAUid::O_ID}; - uids.emplace_back(bias_descriptor.has_value() - ? std::optional(CudnnfMHAUid::BIAS_ID) - : std::nullopt); - uids.emplace_back(activation_descriptor.has_value() - ? std::optional(CudnnfMHAUid::P_ID) - : std::nullopt); - bool is_padding = mask_type == dnn::FMHAMaskKind::PADDING || - mask_type == dnn::FMHAMaskKind::PADDING_CAUSAL; - uids.emplace_back(is_padding - ? std::optional(CudnnfMHAUid::Q_SEQLEN_ID) - : std::nullopt); - uids.emplace_back(is_padding - ? std::optional(CudnnfMHAUid::K_SEQLEN_ID) - : std::nullopt); - TF_ASSIGN_OR_RETURN(auto runner, - CudnnGraphRunner::Create( - parent_, cudnn_.get(), std::move(graph), - dropout_rng_seed, dropout_rng_offset, uids)); - - return {std::make_unique>( - std::move(runner))}; -#else - return absl::UnimplementedError( - "Cudnn flash attention are only supported with Cudnn >= 9.0.0"); -#endif // CUDNN_VERSION >= 90000 -} - -absl::StatusOr> -CudnnSupport::FusedMHABackwardRunnerFromDesc( - Stream* stream, const dnn::AlgorithmDesc& algorithm_desc, - const dnn::MatmulTensorDescriptor& bmm1_grad_gemm1_rhs_descriptor, - const dnn::MatmulTensorDescriptor& bmm1_grad_gemm2_rhs_descriptor, - const dnn::MatmulTensorDescriptor& bmm2_grad_gemm1_lhs_descriptor, - const dnn::MatmulTensorDescriptor& bmm2_grad_gemm2_rhs_descriptor, - const dnn::MatmulTensorDescriptor& d_output_descriptor, - const dnn::TensorDescriptor& d_bmm1_lhs_descriptor, - const dnn::TensorDescriptor& d_bmm1_rhs_descriptor, - const dnn::TensorDescriptor& d_bmm2_rhs_descriptor, - std::optional d_s_descriptor, - std::optional d_bias_descriptor, - std::optional fwd_output_descriptor, - std::optional bias_descriptor, double scale, - std::optional dropout_rate, std::optional seed, - dnn::FMHAMaskKind mask_type, bool force_deterministic) { -#if CUDNN_VERSION >= 90000 - auto cudnn = cudnn_->GetHandle(parent_, stream); - - bool use_dropout = dropout_rate && *dropout_rate > 0.0; - std::vector intermediate_shape; - - TF_ASSIGN_OR_RETURN( - auto graph, - GetCudnnFlashAttentionBackwardOperationGraph( - *this, bmm1_grad_gemm1_rhs_descriptor, bmm1_grad_gemm2_rhs_descriptor, - bmm2_grad_gemm1_lhs_descriptor, bmm2_grad_gemm2_rhs_descriptor, - d_output_descriptor, d_bmm1_lhs_descriptor, d_bmm1_rhs_descriptor, - d_bmm2_rhs_descriptor, bias_descriptor, dropout_rate, seed, scale, - use_dropout, bias_descriptor != std::nullopt, mask_type, - force_deterministic)); - - std::vector p_dims = - bmm2_grad_gemm1_lhs_descriptor.GetCudnnCompatibleDimensions(false); - intermediate_shape = p_dims; - int64_t dropout_rng_offset = GetDropoutRngOffset(intermediate_shape); - int64_t dropout_rng_seed = seed.has_value() ? 
*seed : 0; - - std::vector> uids; - uids = {CudnnfMHAUid::Q_ID, CudnnfMHAUid::K_ID, CudnnfMHAUid::P_ID, - CudnnfMHAUid::V_ID, CudnnfMHAUid::dO_ID, CudnnfMHAUid::dQ_ID, - CudnnfMHAUid::dK_ID, CudnnfMHAUid::dV_ID, std::nullopt}; - uids.emplace_back(d_bias_descriptor.has_value() - ? std::optional(CudnnfMHAUid::dBIAS_ID) - : std::nullopt); - uids.push_back(CudnnfMHAUid::O_ID); - uids.emplace_back(bias_descriptor.has_value() - ? std::optional(CudnnfMHAUid::BIAS_ID) - : std::nullopt); - bool is_padding = mask_type == dnn::FMHAMaskKind::PADDING || - mask_type == dnn::FMHAMaskKind::PADDING_CAUSAL; - uids.emplace_back(is_padding - ? std::optional(CudnnfMHAUid::Q_SEQLEN_ID) - : std::nullopt); - uids.emplace_back(is_padding - ? std::optional(CudnnfMHAUid::K_SEQLEN_ID) - : std::nullopt); - TF_ASSIGN_OR_RETURN(auto runner, - CudnnGraphRunner::Create( - parent_, cudnn_.get(), graph, dropout_rng_seed, - dropout_rng_offset, uids)); - return {std::make_unique>( - std::move(runner))}; -#else - return absl::UnimplementedError( - "Cudnn flash attention bwd are only " - "supported with Cudnn >= 9.0.0"); -#endif // CUDNN_VERSION >= 90000 -} - bool CudnnSupport::GetRnnAlgorithms( std::vector* out_algorithms) { PreloadCudnnSubLibs(PreloadCudnnType::Rnn); @@ -8348,15 +8116,30 @@ absl::Status CudnnGraph::Execute(Stream& stream, std::unordered_map tensor_to_ptr_map; absl::Span operands_without_workspace = operands; DeviceMemoryBase workspace; - if (graph_.get_workspace_size() != 0) { + if (graph_.get_workspace_size() > 0) { workspace = operands.back(); CHECK_EQ(graph_.get_workspace_size(), workspace.size()); + } + if (graph_.get_workspace_size() > 0 || operands.back().size() == 0) { operands_without_workspace = operands.first(operands.size() - 1); } - int operand_number = 0; + auto next_uid = [uid = 0]() mutable -> int { return CuDnnTensorUID(uid++); }; for (DeviceMemoryBase operand : operands_without_workspace) { - tensor_to_ptr_map[CuDnnTensorUID(operand_number++)] = operand.opaque(); + tensor_to_ptr_map[next_uid()] = operand.opaque(); } + + if (dropout_rng_offset_increment_ > 0) { +#if CUDNN_VERSION >= 8800 + tensor_to_ptr_map[next_uid()] = (void*)&dropout_rng_seed_; + current_dropout_rng_offset_ += dropout_rng_offset_increment_; + tensor_to_ptr_map[next_uid()] = (void*)¤t_dropout_rng_offset_; +#else + return absl::UnimplementedError( + "Cudnn dropout offset and seed are only supported with Cudnn >= " + "8.8.0"); +#endif // CUDNN_VERSION >= 8800 + } + const CudnnSupport& dnn_support = static_cast(*stream.parent()->AsDnn()); RETURN_IF_CUDNN_FRONTEND_ERROR(graph_.execute( diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_dnn.h b/third_party/xla/xla/stream_executor/cuda/cuda_dnn.h index 083a2b431e62c1..24d84e369cb138 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_dnn.h +++ b/third_party/xla/xla/stream_executor/cuda/cuda_dnn.h @@ -70,6 +70,9 @@ class CudnnGraph : public dnn::DnnGraph { private: cudnn_frontend::graph::Graph graph_; + int64_t dropout_rng_seed_; + mutable int64_t current_dropout_rng_offset_; + int64_t dropout_rng_offset_increment_ = 0; }; #endif // CUDNN_VERSION >= 8100 @@ -335,37 +338,6 @@ class CudnnSupport : public dnn::DnnSupport { std::optional dscale_descriptor, std::optional dbias_descriptor) override; - absl::StatusOr> - FusedMHARunnerFromDesc( - Stream* stream, const dnn::AlgorithmDesc& algorithm_desc, - const dnn::MatmulTensorDescriptor& bmm1_lhs_descriptor, - const dnn::MatmulTensorDescriptor& bmm1_rhs_descriptor, - const dnn::MatmulTensorDescriptor& bmm2_rhs_descriptor, 
- const dnn::MatmulTensorDescriptor& intermediate_bmm2_lhs_descriptor, - const dnn::TensorDescriptor& output_descriptor, - std::optional activation_descriptor, - std::optional bias_descriptor, double scale, - std::optional dropout_rate, std::optional seed, - dnn::FMHAMaskKind mask_type) override; - - absl::StatusOr> - FusedMHABackwardRunnerFromDesc( - Stream* stream, const dnn::AlgorithmDesc& algorithm_desc, - const dnn::MatmulTensorDescriptor& bmm1_grad_gemm1_rhs_descriptor, - const dnn::MatmulTensorDescriptor& bmm1_grad_gemm2_rhs_descriptor, - const dnn::MatmulTensorDescriptor& bmm2_grad_gemm1_lhs_descriptor, - const dnn::MatmulTensorDescriptor& bmm2_grad_gemm2_rhs_descriptor, - const dnn::MatmulTensorDescriptor& d_output_descriptor, - const dnn::TensorDescriptor& d_bmm1_lhs_descriptor, - const dnn::TensorDescriptor& d_bmm1_rhs_descriptor, - const dnn::TensorDescriptor& d_bmm2_rhs_descriptor, - std::optional d_s_descriptor, - std::optional d_bias_descriptor, - std::optional fwd_output_descriptor, - std::optional bias_descriptor, double scale, - std::optional dropout_rate, std::optional seed, - dnn::FMHAMaskKind mask_type, bool force_deterministic); - bool GetRnnAlgorithms( std::vector* out_algorithms) override; diff --git a/third_party/xla/xla/stream_executor/cuda/cudnn_frontend_helpers.h b/third_party/xla/xla/stream_executor/cuda/cudnn_frontend_helpers.h index aa59af500ba7a3..0a30c1af59c0c4 100644 --- a/third_party/xla/xla/stream_executor/cuda/cudnn_frontend_helpers.h +++ b/third_party/xla/xla/stream_executor/cuda/cudnn_frontend_helpers.h @@ -29,6 +29,11 @@ namespace gpu { } \ } while (false) +// UIDs for cuDNN are unique identifiers of tensors within a graph. They are +// assigned during graph construction; then graph execution takes a {uid: +// buffer pointer} map defining the correspondance of buffers to tensors. +// UID assignment scheme can be arbitrary; at the moment for simplicity XLA uses +// a scheme UID = (HLO operand number + 1). 
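A definition consistent with the comment above would be a plain offset shift; the actual body lives in the corresponding .cc file, which this patch does not touch, so the snippet below is an assumption for illustration only.

  // Assumed definition matching the documented scheme
  // UID = (HLO operand number + 1); illustrative only.
  int CuDnnTensorUID(int offset) { return offset + 1; }

  // Typical use at execution time (cf. the CudnnGraph::Execute hunk above):
  //   variant_pack[CuDnnTensorUID(0)] = q_buffer.opaque();
  //   variant_pack[CuDnnTensorUID(1)] = k_buffer.opaque();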
int CuDnnTensorUID(int offset); } // namespace gpu diff --git a/third_party/xla/xla/stream_executor/dnn.cc b/third_party/xla/xla/stream_executor/dnn.cc index b7da7e50eb5eb9..951b2f6e147cd8 100644 --- a/third_party/xla/xla/stream_executor/dnn.cc +++ b/third_party/xla/xla/stream_executor/dnn.cc @@ -249,42 +249,6 @@ DnnSupport::NormRunnerFromDesc( return absl::UnimplementedError("NormRunnerFromDesc not implemented."); } -absl::StatusOr> -DnnSupport::FusedMHARunnerFromDesc( - Stream* stream, const dnn::AlgorithmDesc& algorithm_desc, - const dnn::MatmulTensorDescriptor& bmm1_lhs_descriptor, - const dnn::MatmulTensorDescriptor& bmm1_rhs_descriptor, - const dnn::MatmulTensorDescriptor& bmm2_rhs_descriptor, - const dnn::MatmulTensorDescriptor& intermediate_bmm2_lhs_descriptor, - const dnn::TensorDescriptor& output_descriptor, - std::optional activation_descriptor, - std::optional bias_descriptor, double scale, - std::optional dropout_rate, std::optional seed, - dnn::FMHAMaskKind mask_type) { - return absl::UnimplementedError("FusedMHARunnerFromDesc not implemented."); -} - -absl::StatusOr> -DnnSupport::FusedMHABackwardRunnerFromDesc( - Stream* stream, const dnn::AlgorithmDesc& algorithm_desc, - const MatmulTensorDescriptor& bmm1_grad_gemm1_rhs_descriptor, - const MatmulTensorDescriptor& bmm1_grad_gemm2_rhs_descriptor, - const MatmulTensorDescriptor& bmm2_grad_gemm1_lhs_descriptor, - const MatmulTensorDescriptor& bmm2_grad_gemm2_rhs_descriptor, - const MatmulTensorDescriptor& d_output_descriptor, - const TensorDescriptor& d_bmm1_lhs_descriptor, - const TensorDescriptor& d_bmm1_rhs_descriptor, - const TensorDescriptor& d_bmm2_rhs_descriptor, - std::optional d_s_descriptor, - std::optional d_bias_descriptor, - std::optional fwd_output_descriptor, - std::optional bias_descriptor, double scale, - std::optional dropout_rate, std::optional seed, - dnn::FMHAMaskKind mask_type, bool force_deterministic) { - return absl::UnimplementedError( - "FusedMHABackwardRunnerFromDesc not implemented."); -} - bool DnnSupport::GetMIOpenConvolveAlgorithms( dnn::ConvolutionKind /*kind*/, dnn::DataType /*element_type*/, dnn::DataType /*output_type*/, Stream* /*stream*/, diff --git a/third_party/xla/xla/stream_executor/dnn.h b/third_party/xla/xla/stream_executor/dnn.h index af709946eeb241..a2e1cd629dc2b4 100644 --- a/third_party/xla/xla/stream_executor/dnn.h +++ b/third_party/xla/xla/stream_executor/dnn.h @@ -993,30 +993,6 @@ using FusedMatmulRunner = OpRunner; using NormSignature = void(std::vector); using NormRunner = OpRunner; -using FusedMHASignature = void(DeviceMemoryBase /*BMM1_inputA_data*/, - DeviceMemoryBase /* BMM1_inputB_data */, - DeviceMemoryBase /* BMM2_inputA_data */, - DeviceMemoryBase /* output_data */, - DeviceMemoryBase /* bias_data */, - DeviceMemoryBase /* activation_data */, - DeviceMemoryBase /* seqlen_q_data */, - DeviceMemoryBase /* seqlen_k_data */); -using FusedMHARunner = OpRunner; - -using FusedMHABackwardSignature = void( - DeviceMemoryBase /* BMM1_GRAD_GEMM1_inputA_data */, - DeviceMemoryBase /* BMM1_GRAD_GEMM2_inputB_data */, - DeviceMemoryBase /* BMM2_GRAD_GEMM1_inputA_data */, - DeviceMemoryBase /* BMM2_GRAD_GEMM2_inputB_data */, - DeviceMemoryBase /* d_output_data */, - DeviceMemoryBase /* d_BMM1_inputA_data */, - DeviceMemoryBase /* d_BMM1_inputB_data */, - DeviceMemoryBase /* d_BMM2_inputB_data */, DeviceMemoryBase /* d_S_data */, - DeviceMemoryBase /* d_bias_data */, DeviceMemoryBase /* fwd_output_data */, - DeviceMemoryBase /* bias_data */, DeviceMemoryBase /* seqlen_q_data */, - 
DeviceMemoryBase /* seqlen_k_data */); -using FusedMHABackwardRunner = OpRunner; - // Describes the configuration for the algorithms that will used. // // Arguments: @@ -1731,37 +1707,6 @@ class DnnSupport { return absl::UnimplementedError("Graph support requires cuDNN >= 8.1."); }; - virtual absl::StatusOr> - FusedMHARunnerFromDesc( - Stream* stream, const AlgorithmDesc& algorithm_desc, - const MatmulTensorDescriptor& bmm1_lhs_descriptor, - const MatmulTensorDescriptor& bmm1_rhs_descriptor, - const MatmulTensorDescriptor& bmm2_rhs_descriptor, - const MatmulTensorDescriptor& intermediate_bmm2_lhs_descriptor, - const TensorDescriptor& output_descriptor, - std::optional activation_descriptor, - std::optional bias_descriptor, double scale, - std::optional dropout_rate, std::optional seed, - dnn::FMHAMaskKind mask_type); - - virtual absl::StatusOr> - FusedMHABackwardRunnerFromDesc( - Stream* stream, const AlgorithmDesc& algorithm_desc, - const MatmulTensorDescriptor& bmm1_grad_gemm1_rhs_descriptor, - const MatmulTensorDescriptor& bmm1_grad_gemm2_rhs_descriptor, - const MatmulTensorDescriptor& bmm2_grad_gemm1_lhs_descriptor, - const MatmulTensorDescriptor& bmm2_grad_gemm2_rhs_descriptor, - const MatmulTensorDescriptor& d_output_descriptor, - const TensorDescriptor& d_bmm1_lhs_descriptor, - const TensorDescriptor& d_bmm1_rhs_descriptor, - const TensorDescriptor& d_bmm2_rhs_descriptor, - std::optional d_s_descriptor, - std::optional d_bias_descriptor, - std::optional fwd_output_descriptor, - std::optional bias_descriptor, double scale, - std::optional dropout_rate, std::optional seed, - dnn::FMHAMaskKind mask_type, bool force_deterministic); - virtual bool GetMIOpenConvolveAlgorithms( ConvolutionKind kind, DataType element_type, DataType output_type, Stream* stream, const BatchDescriptor& input_descriptor, diff --git a/third_party/xla/xla/stream_executor/lazy_op_runner.h b/third_party/xla/xla/stream_executor/lazy_op_runner.h index c74a03e1ad5226..bf964e05bbaae6 100644 --- a/third_party/xla/xla/stream_executor/lazy_op_runner.h +++ b/third_party/xla/xla/stream_executor/lazy_op_runner.h @@ -280,76 +280,6 @@ struct FusedMatmulOp { } }; -struct FusedMHAOp { - using Signature = FusedMHASignature; - struct Config { - double scale; - const MatmulTensorDescriptor& bmm1_lhs_descriptor; - const MatmulTensorDescriptor& bmm1_rhs_descriptor; - const MatmulTensorDescriptor& bmm2_rhs_descriptor; - const MatmulTensorDescriptor& intermediate_bmm2_lhs_descriptor; - const TensorDescriptor& output_descriptor; - std::optional bias_descriptor; - std::optional activation_descriptor; - std::optional dropout_rate; - std::optional seed; - FMHAMaskKind mask_type; - }; - - static absl::StatusOr>> - RunnerFromAlgorithmDesc(const AlgorithmDesc& desc, Config config, - Stream* stream) { - TF_ASSIGN_OR_RETURN(auto dnn, internal::GetDnnFromStream(stream)); - return dnn->FusedMHARunnerFromDesc( - stream, desc, config.bmm1_lhs_descriptor, config.bmm1_rhs_descriptor, - config.bmm2_rhs_descriptor, config.intermediate_bmm2_lhs_descriptor, - config.output_descriptor, config.activation_descriptor, - config.bias_descriptor, config.scale, config.dropout_rate, config.seed, - config.mask_type); - } -}; - -struct FusedMHABackwardOp { - using Signature = FusedMHABackwardSignature; - - struct Config { - double scale; - const MatmulTensorDescriptor& bmm1_grad_gemm1_rhs_descriptor; - const MatmulTensorDescriptor& bmm1_grad_gemm2_rhs_descriptor; - const MatmulTensorDescriptor& bmm2_grad_gemm1_lhs_descriptor; - const MatmulTensorDescriptor& 
bmm2_grad_gemm2_rhs_descriptor; - const MatmulTensorDescriptor& d_output_descriptor; - const TensorDescriptor& d_bmm1_lhs_descriptor; - const TensorDescriptor& d_bmm1_rhs_descriptor; - const TensorDescriptor& d_bmm2_rhs_descriptor; - std::optional d_s_descriptor; - std::optional d_bias_descriptor; - std::optional fwd_output_descriptor; - std::optional bias_descriptor; - std::optional dropout_rate; - std::optional seed; - FMHAMaskKind mask_type; - bool force_deterministic; - }; - - static absl::StatusOr< - std::unique_ptr>> - RunnerFromAlgorithmDesc(const AlgorithmDesc& desc, Config config, - Stream* stream) { - TF_ASSIGN_OR_RETURN(auto dnn, internal::GetDnnFromStream(stream)); - return dnn->FusedMHABackwardRunnerFromDesc( - stream, desc, config.bmm1_grad_gemm1_rhs_descriptor, - config.bmm1_grad_gemm2_rhs_descriptor, - config.bmm2_grad_gemm1_lhs_descriptor, - config.bmm2_grad_gemm2_rhs_descriptor, config.d_output_descriptor, - config.d_bmm1_lhs_descriptor, config.d_bmm1_rhs_descriptor, - config.d_bmm2_rhs_descriptor, config.d_s_descriptor, - config.d_bias_descriptor, config.fwd_output_descriptor, - config.bias_descriptor, config.scale, config.dropout_rate, config.seed, - config.mask_type, config.force_deterministic); - } -}; - } // namespace dnn } // namespace stream_executor
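Finally, to summarize the execution-side half of the new convention (the CudnnGraph::Execute hunk in cuda_dnn.cc above), here is a self-contained sketch of how the {UID: buffer} variant pack is assembled; the function name, the plain-pointer parameters, and the inline CuDnnTensorUID stand-in are illustrative assumptions rather than the actual StreamExecutor types.

  // Illustrative sketch, not part of the patch. Mirrors the Execute logic:
  // the trailing operand (the workspace, or a zero-sized placeholder) is
  // excluded from the UID mapping, and the dropout seed / running offset are
  // bound last as pass-by-value host pointers.
  #include <cstdint>
  #include <unordered_map>
  #include <vector>

  // Assumed stand-in for the helper declared in cudnn_frontend_helpers.h.
  int CuDnnTensorUID(int offset) { return offset + 1; }

  std::unordered_map<int64_t, void*> BuildVariantPack(
      const std::vector<void*>& operands, bool drop_trailing_workspace,
      bool has_dropout, int64_t* seed, int64_t* current_offset) {
    std::unordered_map<int64_t, void*> pack;
    const size_t num_buffers =
        drop_trailing_workspace ? operands.size() - 1 : operands.size();
    int uid = 0;
    for (size_t i = 0; i < num_buffers; ++i) {
      pack[CuDnnTensorUID(uid++)] = operands[i];  // device buffers, HLO order
    }
    if (has_dropout) {
      pack[CuDnnTensorUID(uid++)] = seed;            // pass-by-value
      pack[CuDnnTensorUID(uid++)] = current_offset;  // pass-by-value
    }
    return pack;
  }

In the real code the running offset is advanced by the per-graph increment before it is bound, so repeated launches of the same graph see fresh dropout randomness.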