diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD
index 96388e0c2b7864..65ec5409801709 100644
--- a/third_party/xla/xla/service/gpu/BUILD
+++ b/third_party/xla/xla/service/gpu/BUILD
@@ -313,7 +313,6 @@ cc_library(
         ":execution_stream_assignment",
         ":gpu_asm_opts_util",
         ":gpu_conv_runner",
-        ":gpu_fused_mha_runner",
         ":gpu_norm_runner",
         ":hlo_fusion_analysis",
         ":ir_emission_utils",
@@ -356,9 +355,9 @@ cc_library(
         "//xla/service/gpu/runtime:conditional_thunk",
         "//xla/service/gpu/runtime:convolution_thunk",
         "//xla/service/gpu/runtime:copy_thunk",
+        "//xla/service/gpu/runtime:cudnn_thunk",
         "//xla/service/gpu/runtime:custom_call_thunk",
         "//xla/service/gpu/runtime:fft_thunk",
-        "//xla/service/gpu/runtime:fused_mha_thunk",
         "//xla/service/gpu/runtime:gemm_thunk",
         "//xla/service/gpu/runtime:gpublas_lt_matmul_thunk",
         "//xla/service/gpu/runtime:infeed_thunk",
@@ -1045,31 +1044,6 @@ cc_library(
     ]),
 )
 
-cc_library(
-    name = "gpu_fused_mha_runner",
-    srcs = ["gpu_fused_mha_runner.cc"],
-    hdrs = ["gpu_fused_mha_runner.h"],
-    deps = [
-        ":backend_configs_cc",
-        ":cublas_cudnn",
-        ":stream_executor_util",
-        "//xla:shape_util",
-        "//xla:util",
-        "//xla:xla_data_proto_cc",
-        "//xla/stream_executor",
-        "//xla/stream_executor:dnn",
-        "//xla/stream_executor:lazy_op_runner",
-        "@com_google_absl//absl/container:inlined_vector",
-        "@com_google_absl//absl/log",
-        "@com_google_absl//absl/status",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings",
-        "@com_google_absl//absl/strings:str_format",
-        "@eigen_archive//:eigen3",
-        "@local_tsl//tsl/platform:statusor",
-    ],
-)
-
 cc_library(
     name = "cusolver_context",
     srcs = if_gpu_is_configured(["cusolver_context.cc"]),
@@ -1779,6 +1753,7 @@ cc_library(
         "//xla/service/gpu/transforms:conv_padding_legalization",
         "//xla/service/gpu/transforms:conv_rewriter",
         "//xla/service/gpu/transforms:cublas_pad_for_gemms",
+        "//xla/service/gpu/transforms:cudnn_custom_call_compiler",
         "//xla/service/gpu/transforms:cudnn_fused_conv_rewriter",
         "//xla/service/gpu/transforms:cudnn_fused_mha_rewriter",
         "//xla/service/gpu/transforms:cudnn_fused_mha_transpose_fusion",
@@ -1787,7 +1762,6 @@ cc_library(
         "//xla/service/gpu/transforms:cudnn_pad_for_convolutions",
         "//xla/service/gpu/transforms:cudnn_simplify_padding",
         "//xla/service/gpu/transforms:cudnn_vectorize_convolutions",
-        "//xla/service/gpu/transforms:cudnn_workspace_rewriter",
         "//xla/service/gpu/transforms:dot_sparsity_rewriter",
         "//xla/service/gpu/transforms:gpusolver_rewriter",
         "//xla/service/gpu/transforms:sort_rewriter",
diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.cc b/third_party/xla/xla/service/gpu/gpu_compiler.cc
index d367419d3b4776..46eacbbe358869 100644
--- a/third_party/xla/xla/service/gpu/gpu_compiler.cc
+++ b/third_party/xla/xla/service/gpu/gpu_compiler.cc
@@ -2150,8 +2150,8 @@ absl::StatusOr> GpuCompiler::RunBackend(
   }};
   BinaryMap dnn_compiled_graphs;
   if (stream_exec) {
-    TF_RETURN_IF_ERROR(RunCudnnFusionCompilerPass(module.get(), stream_exec,
-                                                  &dnn_compiled_graphs));
+    TF_RETURN_IF_ERROR(RunCudnnCompilerPasses(module.get(), stream_exec,
+                                              &dnn_compiled_graphs));
   }
 
   const DebugOptions& debug_opts = module->config().debug_options();
diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.h b/third_party/xla/xla/service/gpu/gpu_compiler.h
index aa22bfcf3ba338..456e6755b0d83a 100644
--- a/third_party/xla/xla/service/gpu/gpu_compiler.h
+++ b/third_party/xla/xla/service/gpu/gpu_compiler.h
@@ -171,10 +171,10 @@ class GpuCompiler : public LLVMCompiler {
return absl::OkStatus(); } - // Runs cuDNN fusion compiler pass. - virtual absl::Status RunCudnnFusionCompilerPass( - HloModule* module, se::StreamExecutor* stream_exec, - BinaryMap* dnn_compiled_graphs) { + // Runs cuDNN fusion and custom call compiler passes. + virtual absl::Status RunCudnnCompilerPasses(HloModule* module, + se::StreamExecutor* stream_exec, + BinaryMap* dnn_compiled_graphs) { return absl::OkStatus(); } diff --git a/third_party/xla/xla/service/gpu/gpu_fused_mha_runner.cc b/third_party/xla/xla/service/gpu/gpu_fused_mha_runner.cc deleted file mode 100644 index 566c0068f5dbba..00000000000000 --- a/third_party/xla/xla/service/gpu/gpu_fused_mha_runner.cc +++ /dev/null @@ -1,719 +0,0 @@ -/* Copyright 2023 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/gpu_fused_mha_runner.h" - -#include -#include -#include - -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_format.h" -#include "Eigen/Core" -#include "xla/service/gpu/backend_configs.pb.h" -#include "xla/service/gpu/cublas_cudnn.h" -#include "xla/service/gpu/stream_executor_util.h" -#include "xla/shape.h" -#include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/dnn.h" -#include "xla/stream_executor/lazy_op_runner.h" -#include "xla/stream_executor/stream.h" -#include "xla/util.h" -#include "tsl/platform/statusor.h" - -namespace xla { -namespace gpu { - -namespace { -using se::DeviceMemory; -using se::DeviceMemoryBase; -using se::dnn::DataType; -using se::dnn::MatmulTensorDescriptor; -using se::dnn::TensorDescriptor; - -template -absl::Status RunFusedMHA(GpufMHAParams params, se::Stream *stream, - RunFusedMHAOptions options, - DeviceMemory lhs_bmm1_buffer, - DeviceMemory rhs_bmm1_buffer, - DeviceMemory rhs_bmm2_buffer, - DeviceMemory output_buffer, - DeviceMemoryBase bias_buffer, - DeviceMemoryBase scratch_memory, - DeviceMemoryBase activation_output, - DeviceMemoryBase seqlen_q, DeviceMemoryBase seqlen_k) { - se::dnn::LazyOpRunner *lazy_runner = - options.runner_cache->AsFusedMHARunner(); - std::optional> local_runner; - if (!lazy_runner) { - local_runner.emplace(params.config->algorithm); - lazy_runner = &*local_runner; - } - std::optional dropout_rate; - if (params.config->dropout_rate) { - dropout_rate = *params.config->dropout_rate; - } - - std::optional seed; - if (params.config->seed) { - seed = *params.config->seed; - } - - TF_ASSIGN_OR_RETURN(se::dnn::FusedMHAOp::Config config, - params.config->AsDnnFusedMHAOpConfig()); - TF_ASSIGN_OR_RETURN(auto *runner, - lazy_runner->GetOrCreateRunner(config, stream)); - return (*runner)(stream, options.profile_result, scratch_memory, - lhs_bmm1_buffer, rhs_bmm1_buffer, rhs_bmm2_buffer, - output_buffer, bias_buffer, activation_output, seqlen_q, - seqlen_k); -} - -template -absl::Status RunGpuFMHAImpl(const GpufMHAParams ¶ms, se::Stream *stream, - se::DeviceMemoryBase 
scratch_memory, - RunFusedMHAOptions options) { - auto lhs_bmm1_buffer = se::DeviceMemory(params.lhs_bmm1_buffer); - auto rhs_bmm1_buffer = se::DeviceMemory(params.rhs_bmm1_buffer); - auto rhs_bmm2_buffer = se::DeviceMemory(params.rhs_bmm2_buffer); - auto output_buffer = se::DeviceMemory(params.output_buffer); - auto activation_buffer = - params.activation_buffer.has_value() - ? se::DeviceMemory(*params.activation_buffer) - : se::DeviceMemoryBase(); - auto bias_buffer = params.bias_buffer.has_value() - ? se::DeviceMemory(*params.bias_buffer) - : se::DeviceMemoryBase(); - auto seqlen_q_buffer = - params.seqlen_q_buffer.has_value() - ? se::DeviceMemory(*params.seqlen_q_buffer) - : se::DeviceMemoryBase(); - auto seqlen_k_buffer = - params.seqlen_k_buffer.has_value() - ? se::DeviceMemory(*params.seqlen_k_buffer) - : se::DeviceMemoryBase(); - se::dnn::AlgorithmDesc algorithm = params.config->algorithm; - if (options.runner_cache) { - algorithm = options.runner_cache->ToAlgorithmDesc(); - } - - absl::Status run_status = absl::OkStatus(); - switch (params.config->kind) { - case CudnnfMHAKind::kSoftmaxDropout: - case CudnnfMHAKind::kSoftmax: - case CudnnfMHAKind::kScaleBiasSoftmax: - case CudnnfMHAKind::kScaleBiasSoftmaxDropout: - run_status = RunFusedMHA( - params, stream, options, lhs_bmm1_buffer, rhs_bmm1_buffer, - rhs_bmm2_buffer, output_buffer, bias_buffer, scratch_memory, - activation_buffer, seqlen_q_buffer, seqlen_k_buffer); - break; - default: - return Internal("Invalid cuDNN fMHA kind"); - } - - if (!run_status.ok()) { - return run_status; - } - - if (!stream->ok()) { - return Internal("Unable to launch FMHA with type %s and algorithm %s", - CudnnfMHAKindToString(params.config->kind), - algorithm.ToString()); - } - - return absl::OkStatus(); -} - -template -absl::Status RunFusedMHABackward( - GpufMHABackwardParams params, se::Stream *stream, - RunFusedMHABackwardOptions options, - DeviceMemory bmm1_grad_gemm1_rhs_buffer, - DeviceMemory bmm1_grad_gemm2_rhs_buffer, - DeviceMemory bmm2_grad_gemm1_lhs_buffer, - DeviceMemory bmm2_grad_gemm2_rhs_buffer, - DeviceMemory d_output_buffer, - DeviceMemory d_bmm1_lhs_buffer, - DeviceMemory d_bmm1_rhs_buffer, - DeviceMemory d_bmm2_rhs_buffer, DeviceMemoryBase d_s_buffer, - DeviceMemoryBase d_bias_buffer, DeviceMemoryBase fwd_output_buffer, - DeviceMemoryBase bias_buffer, DeviceMemoryBase scratch_memory, - DeviceMemoryBase seqlen_q, DeviceMemoryBase seqlen_k) { - se::dnn::LazyOpRunner *lazy_runner = - options.runner_cache->AsFusedMHABackwardRunner(); - std::optional> - local_runner; - if (!lazy_runner) { - local_runner.emplace(params.config->algorithm); - lazy_runner = &*local_runner; - } - std::optional dropout_rate; - if (params.config->dropout_rate) { - dropout_rate = *params.config->dropout_rate; - } - - std::optional seed; - if (params.config->seed) { - seed = *params.config->seed; - } - - TF_ASSIGN_OR_RETURN(se::dnn::FusedMHABackwardOp::Config config, - params.config->AsDnnFusedMHABackwardOpConfig()); - TF_ASSIGN_OR_RETURN(auto *runner, - lazy_runner->GetOrCreateRunner(config, stream)); - // TODO: pass in real softmax_sum, dQ_accum, fwd_output - return (*runner)(stream, options.profile_result, scratch_memory, - bmm1_grad_gemm1_rhs_buffer, bmm1_grad_gemm2_rhs_buffer, - bmm2_grad_gemm1_lhs_buffer, bmm2_grad_gemm2_rhs_buffer, - d_output_buffer, d_bmm1_lhs_buffer, d_bmm1_rhs_buffer, - d_bmm2_rhs_buffer, d_s_buffer, d_bias_buffer, - fwd_output_buffer, bias_buffer, seqlen_q, seqlen_k); - return absl::OkStatus(); -} - -template -absl::Status 
RunGpuFMHABackwardImpl(const GpufMHABackwardParams ¶ms, - se::Stream *stream, - se::DeviceMemoryBase scratch_memory, - RunFusedMHABackwardOptions options) { - auto bmm1_grad_gemm1_rhs_buffer = - se::DeviceMemory(params.bmm1_grad_gemm1_rhs_buffer); - auto bmm1_grad_gemm2_rhs_buffer = - se::DeviceMemory(params.bmm1_grad_gemm2_rhs_buffer); - auto bmm2_grad_gemm1_lhs_buffer = - se::DeviceMemory(params.bmm2_grad_gemm1_lhs_buffer); - auto bmm2_grad_gemm2_rhs_buffer = - se::DeviceMemory(params.bmm2_grad_gemm2_rhs_buffer); - auto d_output_buffer = se::DeviceMemory(params.d_output_buffer); - auto d_bmm1_lhs_buffer = - se::DeviceMemory(params.d_bmm1_lhs_buffer); - auto d_bmm1_rhs_buffer = - se::DeviceMemory(params.d_bmm1_rhs_buffer); - auto d_bmm2_rhs_buffer = - se::DeviceMemory(params.d_bmm2_rhs_buffer); - - // optional buffers - auto d_s_buffer = params.d_s_buffer.has_value() - ? se::DeviceMemory(*params.d_s_buffer) - : se::DeviceMemoryBase(); - - auto d_bias_buffer = params.d_bias_buffer.has_value() - ? se::DeviceMemory(*params.d_bias_buffer) - : se::DeviceMemoryBase(); - - auto fwd_output_buffer = - params.fwd_output_buffer.has_value() - ? se::DeviceMemory(*params.fwd_output_buffer) - : se::DeviceMemoryBase(); - - auto bias_buffer = params.bias_buffer.has_value() - ? se::DeviceMemory(*params.bias_buffer) - : se::DeviceMemoryBase(); - - auto seqlen_q_buffer = - params.seqlen_q_buffer.has_value() - ? se::DeviceMemory(*params.seqlen_q_buffer) - : se::DeviceMemoryBase(); - - auto seqlen_k_buffer = - params.seqlen_k_buffer.has_value() - ? se::DeviceMemory(*params.seqlen_k_buffer) - : se::DeviceMemoryBase(); - - se::dnn::AlgorithmDesc algorithm = params.config->algorithm; - if (options.runner_cache) { - algorithm = options.runner_cache->ToAlgorithmDesc(); - } - - absl::Status run_status = absl::OkStatus(); - switch (params.config->kind) { - case CudnnfMHAKind::kBackwardSoftmaxDropout: - case CudnnfMHAKind::kBackwardSoftmax: - case CudnnfMHAKind::kBackwardScaleBiasSoftmax: - case CudnnfMHAKind::kBackwardScaleBiasSoftmaxDropout: - run_status = RunFusedMHABackward( - params, stream, options, bmm1_grad_gemm1_rhs_buffer, - bmm1_grad_gemm2_rhs_buffer, bmm2_grad_gemm1_lhs_buffer, - bmm2_grad_gemm2_rhs_buffer, d_output_buffer, d_bmm1_lhs_buffer, - d_bmm1_rhs_buffer, d_bmm2_rhs_buffer, d_s_buffer, d_bias_buffer, - fwd_output_buffer, bias_buffer, scratch_memory, seqlen_q_buffer, - seqlen_k_buffer); - break; - default: - return Internal("Invalid cuDNN fMHA kind"); - } - - if (!run_status.ok()) { - return run_status; - } - - if (!stream->ok()) { - return Internal("Unable to launch FMHA with type %s and algorithm %s", - CudnnfMHAKindToString(params.config->kind), - algorithm.ToString()); - } - - return run_status; -} -} // namespace - -/*static*/ absl::StatusOr GpufMHAConfig::For( - const GpufMHADescriptor &desc) { - // Get shapes from desc. 
- const Shape &lhs_bmm1_shape = desc.lhs_bmm1_shape; - const Shape &rhs_bmm1_shape = desc.rhs_bmm1_shape; - const Shape &rhs_bmm2_shape = desc.rhs_bmm2_shape; - const Shape &intermediate_lhs_bmm2_shape = desc.intermediate_lhs_bmm2_shape; - const Shape &output_shape = desc.output_shapes[0]; - - // Get DNN dtype from primtive types - TF_ASSIGN_OR_RETURN( - DataType lhs_bmm1_type, - GetDNNDataTypeFromPrimitiveType(lhs_bmm1_shape.element_type())); - TF_ASSIGN_OR_RETURN( - DataType rhs_bmm1_type, - GetDNNDataTypeFromPrimitiveType(rhs_bmm1_shape.element_type())); - - TF_ASSIGN_OR_RETURN( - DataType rhs_bmm2_type, - GetDNNDataTypeFromPrimitiveType(rhs_bmm2_shape.element_type())); - TF_ASSIGN_OR_RETURN(DataType lhs_bmm2_type, - GetDNNDataTypeFromPrimitiveType( - intermediate_lhs_bmm2_shape.element_type())); - TF_ASSIGN_OR_RETURN(DataType output_type, GetDNNDataTypeFromPrimitiveType( - output_shape.element_type())); - GpufMHAConfig config; - config.input_type = lhs_bmm1_shape.element_type(); - config.output_type = output_shape.element_type(); - - // Get MatmulTensorDescriptors for BMM1 - config.lhs_bmm1 = - MatmulTensorDescriptor::For(lhs_bmm1_type, lhs_bmm1_shape.dimensions(), - desc.lhs_bmm1_shape.layout().minor_to_major(), - desc.bmm1_dnums.lhs_batch_dimensions(), - desc.bmm1_dnums.lhs_contracting_dimensions()); - config.rhs_bmm1 = - MatmulTensorDescriptor::For(rhs_bmm1_type, rhs_bmm1_shape.dimensions(), - desc.rhs_bmm1_shape.layout().minor_to_major(), - desc.bmm1_dnums.rhs_batch_dimensions(), - desc.bmm1_dnums.rhs_contracting_dimensions()); - - // Get MatmulTensorDescriptors for BMM2 - config.rhs_bmm2 = - MatmulTensorDescriptor::For(rhs_bmm2_type, rhs_bmm2_shape.dimensions(), - desc.rhs_bmm2_shape.layout().minor_to_major(), - desc.bmm2_dnums.rhs_batch_dimensions(), - desc.bmm2_dnums.rhs_contracting_dimensions()); - - config.intermediate_lhs_bmm2 = MatmulTensorDescriptor::For( - lhs_bmm2_type, intermediate_lhs_bmm2_shape.dimensions(), - desc.intermediate_lhs_bmm2_shape.layout().minor_to_major(), - desc.bmm2_dnums.lhs_batch_dimensions(), - desc.bmm2_dnums.lhs_contracting_dimensions()); - - config.output = TensorDescriptor::For(output_type, output_shape.dimensions(), - output_shape.layout().minor_to_major()); - - if (desc.output_shapes.size() > 1) { - const Shape &activation_shape = desc.output_shapes.back(); - // Generally, activation should have same type as output, but set it - // explicityly just to be safe. 
- TF_ASSIGN_OR_RETURN( - DataType activation_type, - GetDNNDataTypeFromPrimitiveType(activation_shape.element_type())); - config.activation = - TensorDescriptor::For(activation_type, activation_shape.dimensions(), - activation_shape.layout().minor_to_major()); - } - - if (desc.mask_shape) { - const Shape &mask_shape = *desc.mask_shape; - TF_ASSIGN_OR_RETURN(DataType mask_type, GetDNNDataTypeFromPrimitiveType( - mask_shape.element_type())); - config.mask = TensorDescriptor::For(mask_type, mask_shape.dimensions(), - mask_shape.layout().minor_to_major()); - } - - if (desc.bias_shape) { - const Shape &bias_shape = *desc.bias_shape; - TF_ASSIGN_OR_RETURN(DataType bias_type, GetDNNDataTypeFromPrimitiveType( - bias_shape.element_type())); - config.bias = TensorDescriptor::For(bias_type, bias_shape.dimensions(), - bias_shape.layout().minor_to_major()); - } - config.kind = desc.kind; - config.mask_type = desc.mask_type; - const CudnnfMHABackendConfig &backend_config = desc.backend_config; - config.algorithm = se::dnn::AlgorithmDesc(backend_config.algorithm()); - config.fmha_scale.emplace(backend_config.fmha_scale()); - config.dropout_rate.emplace(backend_config.dropout_rate()); - config.seed.emplace(backend_config.seed()); - return config; -} - -absl::StatusOr -GpufMHAConfig::AsDnnFusedMHAOpConfig() const { - double scale = 1.0; - if (fmha_scale.has_value()) { - scale = *fmha_scale; - } - TF_ASSIGN_OR_RETURN(se::dnn::FMHAMaskKind mask_type, - GetDNNFmhaMaskKindFromCudnnFmhaMaskKind(mask_type)); - - return se::dnn::FusedMHAOp::Config{ - scale, lhs_bmm1, rhs_bmm1, rhs_bmm2, intermediate_lhs_bmm2, - output, bias, activation, dropout_rate, seed, - mask_type}; -} - -/*static*/ absl::StatusOr GpufMHABackwardConfig::For( - const GpufMHABackwardDescriptor &desc) { - // Get shapes from desc. 
- - const Shape &bmm1_grad_gemm1_rhs_shape = desc.bmm1_grad_gemm1_rhs_shape; - const Shape &bmm1_grad_gemm2_rhs_shape = desc.bmm1_grad_gemm2_rhs_shape; - const Shape &bmm2_grad_gemm1_lhs_shape = desc.bmm2_grad_gemm1_lhs_shape; - const Shape &bmm2_grad_gemm2_rhs_shape = desc.bmm2_grad_gemm2_rhs_shape; - const Shape &d_output_shape = desc.d_output_shape; - const Shape &d_bmm1_lhs_shape = desc.d_bmm1_lhs_shape; - const Shape &d_bmm1_rhs_shape = desc.d_bmm1_rhs_shape; - const Shape &d_bmm2_rhs_shape = desc.d_bmm2_rhs_shape; - // Get DNN dtype from primtive types - TF_ASSIGN_OR_RETURN(DataType bmm1_grad_gemm1_rhs_type, - GetDNNDataTypeFromPrimitiveType( - bmm1_grad_gemm1_rhs_shape.element_type())); - - TF_ASSIGN_OR_RETURN(DataType bmm1_grad_gemm2_rhs_type, - GetDNNDataTypeFromPrimitiveType( - bmm1_grad_gemm2_rhs_shape.element_type())); - - TF_ASSIGN_OR_RETURN(DataType bmm2_grad_gemm1_lhs_type, - GetDNNDataTypeFromPrimitiveType( - bmm2_grad_gemm1_lhs_shape.element_type())); - - TF_ASSIGN_OR_RETURN(DataType bmm2_grad_gemm2_rhs_type, - GetDNNDataTypeFromPrimitiveType( - bmm2_grad_gemm2_rhs_shape.element_type())); - - TF_ASSIGN_OR_RETURN( - DataType d_output_type, - GetDNNDataTypeFromPrimitiveType(d_output_shape.element_type())); - - TF_ASSIGN_OR_RETURN( - DataType d_bmm1_lhs_type, - GetDNNDataTypeFromPrimitiveType(d_bmm1_lhs_shape.element_type())); - - TF_ASSIGN_OR_RETURN( - DataType d_bmm1_rhs_type, - GetDNNDataTypeFromPrimitiveType(d_bmm1_rhs_shape.element_type())); - - TF_ASSIGN_OR_RETURN( - DataType d_bmm2_rhs_type, - GetDNNDataTypeFromPrimitiveType(d_bmm2_rhs_shape.element_type())); - - GpufMHABackwardConfig config; - config.input_type = bmm1_grad_gemm1_rhs_shape.element_type(); - config.output_type = d_bmm1_lhs_shape.element_type(); - - // Get MatmulTensorDescriptors for lhs of BMM1 grad GEMM 1 - config.bmm1_grad_gemm1_rhs = MatmulTensorDescriptor::For( - bmm1_grad_gemm1_rhs_type, bmm1_grad_gemm1_rhs_shape.dimensions(), - desc.bmm1_grad_gemm1_rhs_shape.layout().minor_to_major(), - desc.bmm1_grad_gemm1_dnums.rhs_batch_dimensions(), - desc.bmm1_grad_gemm1_dnums.rhs_contracting_dimensions()); - - // Get MatmulTensorDescriptors for rhs of BMM1 grad GEMM 2 - config.bmm1_grad_gemm2_rhs = MatmulTensorDescriptor::For( - bmm1_grad_gemm2_rhs_type, bmm1_grad_gemm2_rhs_shape.dimensions(), - desc.bmm1_grad_gemm2_rhs_shape.layout().minor_to_major(), - desc.bmm1_grad_gemm2_dnums.rhs_batch_dimensions(), - desc.bmm1_grad_gemm2_dnums.rhs_contracting_dimensions()); - - // Get MatmulTensorDescriptors for BMM2 grad GEMM 1 - config.bmm2_grad_gemm1_lhs = MatmulTensorDescriptor::For( - bmm2_grad_gemm1_lhs_type, bmm2_grad_gemm1_lhs_shape.dimensions(), - desc.bmm2_grad_gemm1_lhs_shape.layout().minor_to_major(), - desc.bmm2_grad_gemm1_dnums.lhs_batch_dimensions(), - desc.bmm2_grad_gemm1_dnums.lhs_contracting_dimensions()); - - config.d_output = MatmulTensorDescriptor::For( - d_output_type, d_output_shape.dimensions(), - desc.d_output_shape.layout().minor_to_major(), - desc.bmm2_grad_gemm1_dnums.rhs_batch_dimensions(), - desc.bmm2_grad_gemm1_dnums.rhs_contracting_dimensions()); - - // Get MatmulTensorDescriptors for BMM2 grad GEMM 2 - config.bmm2_grad_gemm2_rhs = MatmulTensorDescriptor::For( - bmm2_grad_gemm2_rhs_type, bmm2_grad_gemm2_rhs_shape.dimensions(), - desc.bmm2_grad_gemm2_rhs_shape.layout().minor_to_major(), - desc.bmm2_grad_gemm2_dnums.rhs_batch_dimensions(), - desc.bmm2_grad_gemm2_dnums - .rhs_contracting_dimensions()); // FMHA TODO: transpose here? 
- - config.d_bmm1_lhs = - TensorDescriptor::For(d_bmm1_lhs_type, d_bmm1_lhs_shape.dimensions(), - d_bmm1_lhs_shape.layout().minor_to_major()); - config.d_bmm1_rhs = - TensorDescriptor::For(d_bmm1_rhs_type, d_bmm1_rhs_shape.dimensions(), - d_bmm1_rhs_shape.layout().minor_to_major()); - config.d_bmm2_rhs = - TensorDescriptor::For(d_bmm2_rhs_type, d_bmm2_rhs_shape.dimensions(), - d_bmm2_rhs_shape.layout().minor_to_major()); - config.d_s = TensorDescriptor::For( - bmm2_grad_gemm1_lhs_type, bmm2_grad_gemm1_lhs_shape.dimensions(), - bmm2_grad_gemm1_lhs_shape.layout().minor_to_major()); - - if (desc.d_bias_shape) { - const Shape &d_bias_shape = *desc.d_bias_shape; - // Get DNN dtype from primtive types - TF_ASSIGN_OR_RETURN(DataType d_bias_type, GetDNNDataTypeFromPrimitiveType( - d_bias_shape.element_type())); - config.d_bias = - TensorDescriptor::For(d_bias_type, d_bias_shape.dimensions(), - d_bias_shape.layout().minor_to_major()); - } - - if (desc.mask_shape) { - const Shape &mask_shape = *desc.mask_shape; - TF_ASSIGN_OR_RETURN(DataType mask_type, GetDNNDataTypeFromPrimitiveType( - mask_shape.element_type())); - config.mask = TensorDescriptor::For(mask_type, mask_shape.dimensions(), - mask_shape.layout().minor_to_major()); - } - if (desc.fwd_output_shape) { - const Shape &fwd_output_shape = *desc.fwd_output_shape; - TF_ASSIGN_OR_RETURN( - DataType fwd_output_type, - GetDNNDataTypeFromPrimitiveType(fwd_output_shape.element_type())); - config.fwd_output = - TensorDescriptor::For(fwd_output_type, fwd_output_shape.dimensions(), - fwd_output_shape.layout().minor_to_major()); - } - - if (desc.bias_shape) { - const Shape &bias_shape = *desc.bias_shape; - TF_ASSIGN_OR_RETURN(DataType bias_type, GetDNNDataTypeFromPrimitiveType( - bias_shape.element_type())); - config.bias = TensorDescriptor::For(bias_type, bias_shape.dimensions(), - bias_shape.layout().minor_to_major()); - } - - config.kind = desc.kind; - config.mask_type = desc.mask_type; - config.force_deterministic = desc.force_deterministic; - const CudnnfMHABackendConfig &backend_config = desc.backend_config; - config.algorithm = se::dnn::AlgorithmDesc(backend_config.algorithm()); - config.fmha_scale.emplace(backend_config.fmha_scale()); - config.dropout_rate.emplace(backend_config.dropout_rate()); - config.seed.emplace(backend_config.seed()); - return config; -} - -absl::StatusOr -GpufMHABackwardConfig::AsDnnFusedMHABackwardOpConfig() const { - double scale = 1.0; - if (fmha_scale.has_value()) { - scale = *fmha_scale; - } - TF_ASSIGN_OR_RETURN(se::dnn::FMHAMaskKind mask_type, - GetDNNFmhaMaskKindFromCudnnFmhaMaskKind(mask_type)); - - return se::dnn::FusedMHABackwardOp::Config{scale, - bmm1_grad_gemm1_rhs, - bmm1_grad_gemm2_rhs, - bmm2_grad_gemm1_lhs, - bmm2_grad_gemm2_rhs, - d_output, - d_bmm1_lhs, - d_bmm1_rhs, - d_bmm2_rhs, - d_s, - d_bias, - fwd_output, - bias, - dropout_rate, - seed, - mask_type, - force_deterministic}; -} - -/*static*/ absl::StatusOr GpufMHAParams::For( - const GpufMHAConfig &config, se::DeviceMemoryBase lhs_bmm1_buffer, - se::DeviceMemoryBase rhs_bmm1_buffer, se::DeviceMemoryBase rhs_bmm2_buffer, - se::DeviceMemoryBase output_buffer, - std::optional bias_buffer, - std::optional activation_buffer, - std::optional seqlen_q_buffer, - std::optional seqlen_k_buffer) { - GpufMHAParams params; - params.config = &config; - params.lhs_bmm1_buffer = lhs_bmm1_buffer; - params.rhs_bmm1_buffer = rhs_bmm1_buffer; - params.rhs_bmm2_buffer = rhs_bmm2_buffer; - params.output_buffer = output_buffer; - params.activation_buffer = 
activation_buffer; - params.bias_buffer = bias_buffer; - params.seqlen_q_buffer = seqlen_q_buffer; - params.seqlen_k_buffer = seqlen_k_buffer; - return params; -} - -/*static*/ absl::StatusOr GpufMHABackwardParams::For( - const GpufMHABackwardConfig &config, - se::DeviceMemoryBase bmm1_grad_gemm1_rhs_buffer, - se::DeviceMemoryBase bmm1_grad_gemm2_rhs_buffer, - se::DeviceMemoryBase bmm2_grad_gemm1_lhs_buffer, - se::DeviceMemoryBase bmm2_grad_gemm2_rhs_buffer, - se::DeviceMemoryBase d_output_buffer, - se::DeviceMemoryBase d_bmm1_lhs_buffer, - se::DeviceMemoryBase d_bmm1_rhs_buffer, - se::DeviceMemoryBase d_bmm2_rhs_buffer, - std::optional d_s_buffer, - std::optional d_bias_buffer, - std::optional fwd_output_buffer, - std::optional bias_buffer, - std::optional seqlen_q_buffer, - std::optional seqlen_k_buffer) { - GpufMHABackwardParams params; - params.config = &config; - params.bmm1_grad_gemm1_rhs_buffer = bmm1_grad_gemm1_rhs_buffer; - params.bmm1_grad_gemm2_rhs_buffer = bmm1_grad_gemm2_rhs_buffer; - params.bmm2_grad_gemm1_lhs_buffer = bmm2_grad_gemm1_lhs_buffer; - params.bmm2_grad_gemm2_rhs_buffer = bmm2_grad_gemm2_rhs_buffer; - params.d_output_buffer = d_output_buffer; - params.d_bmm1_lhs_buffer = d_bmm1_lhs_buffer; - params.d_bmm1_rhs_buffer = d_bmm1_rhs_buffer; - params.d_bmm2_rhs_buffer = d_bmm2_rhs_buffer; - params.d_s_buffer = d_s_buffer; - params.d_bias_buffer = d_bias_buffer; - params.fwd_output_buffer = fwd_output_buffer; - params.bias_buffer = bias_buffer; - params.seqlen_q_buffer = seqlen_q_buffer; - params.seqlen_k_buffer = seqlen_k_buffer; - return params; -} - -absl::Status RunGpuFMHA(const GpufMHAConfig &fmha_config, - se::DeviceMemoryBase lhs_bmm1_buffer, - se::DeviceMemoryBase rhs_bmm1_buffer, - se::DeviceMemoryBase rhs_bmm2_buffer, - se::DeviceMemoryBase output_buffer, - se::DeviceMemoryBase scratch_buffer, - std::optional bias_buffer, - std::optional activation_buffer, - std::optional seqlen_q_buffer, - std::optional seqlen_k_buffer, - se::Stream *stream, RunFusedMHAOptions options) { - TF_ASSIGN_OR_RETURN( - GpufMHAParams params, - GpufMHAParams::For(fmha_config, lhs_bmm1_buffer, rhs_bmm1_buffer, - rhs_bmm2_buffer, output_buffer, bias_buffer, - activation_buffer, seqlen_q_buffer, seqlen_k_buffer)); - PrimitiveType input_primitive_type = fmha_config.input_type; - switch (input_primitive_type) { - case F16: - return RunGpuFMHAImpl( - params, stream, scratch_buffer, options); - case BF16: - return RunGpuFMHAImpl( - params, stream, scratch_buffer, options); - default: - return absl::UnimplementedError(absl::StrFormat( - "Unimplemented fused MHA with %s", ToString(fmha_config))); - } - return absl::OkStatus(); -} - -absl::Status RunGpuFMHABackward( - const GpufMHABackwardConfig &fmha_config, - se::DeviceMemoryBase bmm1_grad_gemm1_rhs_buffer, - se::DeviceMemoryBase bmm1_grad_gemm2_rhs_buffer, - se::DeviceMemoryBase bmm2_grad_gemm1_lhs_buffer, - se::DeviceMemoryBase bmm2_grad_gemm2_rhs_buffer, - se::DeviceMemoryBase d_output_buffer, se::DeviceMemoryBase scratch_buffer, - se::DeviceMemoryBase d_bmm1_lhs_buffer, - se::DeviceMemoryBase d_bmm1_rhs_buffer, - se::DeviceMemoryBase d_bmm2_rhs_buffer, - std::optional d_s_buffer, - std::optional d_bias_buffer, - std::optional fwd_output_buffer, - std::optional bias_buffer, - std::optional seqlen_q_buffer, - std::optional seqlen_k_buffer, se::Stream *stream, - RunFusedMHABackwardOptions options) { - TF_ASSIGN_OR_RETURN( - GpufMHABackwardParams params, - GpufMHABackwardParams::For( - fmha_config, bmm1_grad_gemm1_rhs_buffer, 
bmm1_grad_gemm2_rhs_buffer, - bmm2_grad_gemm1_lhs_buffer, bmm2_grad_gemm2_rhs_buffer, - d_output_buffer, d_bmm1_lhs_buffer, d_bmm1_rhs_buffer, - d_bmm2_rhs_buffer, d_s_buffer, d_bias_buffer, fwd_output_buffer, - bias_buffer, seqlen_q_buffer, seqlen_k_buffer)); - PrimitiveType input_primitive_type = fmha_config.input_type; - switch (input_primitive_type) { - case F16: - return RunGpuFMHABackwardImpl( - params, stream, scratch_buffer, options); - case BF16: - return RunGpuFMHABackwardImpl(params, stream, - scratch_buffer, options); - default: - return Unimplemented("Unimplemented fused MHA backward"); - } - return absl::OkStatus(); -} - -std::string ToString(const GpufMHAConfig &config) { - std::string result = "GpufMHAConfig:\n"; - absl::StrAppend(&result, - "input_type: ", PrimitiveType_Name(config.input_type), ", "); - absl::StrAppend( - &result, "output_type: ", PrimitiveType_Name(config.output_type), ", "); - absl::StrAppend(&result, "Kind: ", CudnnfMHAKindToString(config.kind), ", "); - if (config.fmha_scale) { - absl::StrAppend(&result, "fmha_scale: ", *config.fmha_scale, ", "); - } - if (config.dropout_rate) { - absl::StrAppend(&result, "dropout_rate: ", *config.dropout_rate, ", "); - } - if (config.seed) { - absl::StrAppend(&result, "seed: ", *config.seed, ", "); - } - absl::StrAppend(&result, "Algorithm Desc: ", config.algorithm.ToString(), - "\n"); - absl::StrAppend(&result, "lhs_bmm1: ", config.lhs_bmm1.ToString(), "\n"); - absl::StrAppend(&result, "rhs_bmm1: ", config.rhs_bmm1.ToString(), "\n"); - absl::StrAppend(&result, "rhs_bmm2: ", config.rhs_bmm2.ToString(), "\n"); - absl::StrAppend(&result, "intermediate_lhs_bmm2: ", - config.intermediate_lhs_bmm2.ToString(), "\n"); - absl::StrAppend(&result, "output: ", config.output.ToString(), "\n"); - - if (config.mask) { - absl::StrAppend(&result, "mask: ", (*config.mask).ToString(), "\n"); - } - - if (config.bias) { - absl::StrAppend(&result, "bias: ", (*config.bias).ToString(), "\n"); - } - - return result; -} - -} // namespace gpu -} // namespace xla diff --git a/third_party/xla/xla/service/gpu/gpu_fused_mha_runner.h b/third_party/xla/xla/service/gpu/gpu_fused_mha_runner.h deleted file mode 100644 index d0621cbdff6d74..00000000000000 --- a/third_party/xla/xla/service/gpu/gpu_fused_mha_runner.h +++ /dev/null @@ -1,431 +0,0 @@ -/* Copyright 2023 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_GPU_FUSED_MHA_RUNNER_H_ -#define XLA_SERVICE_GPU_GPU_FUSED_MHA_RUNNER_H_ - -#include -#include -#include -#include -#include -#include - -#include "absl/container/inlined_vector.h" -#include "absl/log/log.h" -#include "absl/status/status.h" -#include "xla/service/gpu/backend_configs.pb.h" -#include "xla/service/gpu/cublas_cudnn.h" -#include "xla/shape.h" -#include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/dnn.h" -#include "xla/stream_executor/lazy_op_runner.h" -#include "xla/stream_executor/stream_executor.h" -#include "xla/xla_data.pb.h" - -namespace xla { -namespace gpu { - -inline absl::StatusOr AsCudnnFmhaMaskKind( - xla::gpu::CudnnfMHABackendConfig_MaskType mask_type) { - switch (mask_type) { - case xla::gpu::CudnnfMHABackendConfig::NO_MASK: - return xla::gpu::CudnnfMHAMaskKind::kNoMask; - case xla::gpu::CudnnfMHABackendConfig::PADDING: - return xla::gpu::CudnnfMHAMaskKind::kPadding; - case xla::gpu::CudnnfMHABackendConfig::CAUSAL: - return xla::gpu::CudnnfMHAMaskKind::kCausal; - case xla::gpu::CudnnfMHABackendConfig::PADDING_CAUSAL: - return xla::gpu::CudnnfMHAMaskKind::kPaddingCausal; - case xla::gpu::CudnnfMHABackendConfig::ALIBI: - return xla::gpu::CudnnfMHAMaskKind::kAlibi; - default: - return xla::Internal("Unknown fmha mask kind."); - } -} - -// This is an interim structure to hold the parameters to construct a -// GpufMHAConfig. -// Struct to describe properties of a FMHA without being tied to specific -// IR. Will be used to help build FMHA thunks from either XLA HLO or -// LHLO GPU dialect in MLIR. -struct GpufMHADescriptor { - CudnnfMHAKind kind; - CudnnfMHABackendConfig backend_config; - CudnnfMHAMaskKind mask_type; - Shape lhs_bmm1_shape; - Shape rhs_bmm1_shape; - Shape rhs_bmm2_shape; - Shape intermediate_lhs_bmm2_shape; - // This will contain both output shape and activation shape - absl::InlinedVector output_shapes; - DotDimensionNumbers bmm1_dnums; - DotDimensionNumbers bmm2_dnums; - - std::optional mask_shape; - std::optional bias_shape; -}; - -struct GpufMHABackwardDescriptor { - CudnnfMHAKind kind; - CudnnfMHABackendConfig backend_config; - CudnnfMHAMaskKind mask_type; - Shape bmm1_grad_gemm1_rhs_shape; - Shape bmm1_grad_gemm2_rhs_shape; - Shape bmm2_grad_gemm1_lhs_shape; - Shape bmm2_grad_gemm2_rhs_shape; - Shape d_output_shape; - Shape d_bmm1_lhs_shape; - Shape d_bmm1_rhs_shape; - Shape d_bmm2_rhs_shape; - DotDimensionNumbers bmm1_grad_gemm1_dnums; - DotDimensionNumbers bmm1_grad_gemm2_dnums; - DotDimensionNumbers bmm2_grad_gemm1_dnums; - DotDimensionNumbers bmm2_grad_gemm2_dnums; - - std::optional d_s_shape; - std::optional fwd_output_shape; - std::optional mask_shape; - std::optional d_bias_shape; - std::optional bias_shape; - bool force_deterministic; -}; - -// Structure to describe static properties of a GPU fused Multi-Headed -// Attention. 
-struct GpufMHAConfig { - static absl::StatusOr For(const GpufMHADescriptor& fmha_desc); - - absl::StatusOr AsDnnFusedMHAOpConfig() const; - - PrimitiveType - input_type; // Capture the primitive type of one of the inputs of BMM1 - PrimitiveType output_type; - CudnnfMHAKind kind; - std::optional fmha_scale; - std::optional dropout_rate; - std::optional seed; - - se::dnn::AlgorithmDesc algorithm; - CudnnfMHAMaskKind mask_type; - // bias -> [1, num_attn_heads, q_seq_len, kv_seq_len] - // mask -> [batch_size, 1, q_seq_len, kv_seq_len] - se::dnn::MatmulTensorDescriptor lhs_bmm1; - se::dnn::MatmulTensorDescriptor rhs_bmm1; - se::dnn::MatmulTensorDescriptor rhs_bmm2; - se::dnn::MatmulTensorDescriptor intermediate_lhs_bmm2; - se::dnn::TensorDescriptor output; - - std::optional activation; - std::optional mask; - std::optional bias; -}; - -// Structure to describe static properties of a GPU fused Multi-Headed -// Attention backward. -struct GpufMHABackwardConfig { - static absl::StatusOr For( - const GpufMHABackwardDescriptor& fmha_desc); - - absl::StatusOr - AsDnnFusedMHABackwardOpConfig() const; - - PrimitiveType - input_type; // Capture the primitive type of one of the inputs of BMM1 - PrimitiveType output_type; - CudnnfMHAKind kind; - std::optional fmha_scale; - std::optional dropout_rate; - std::optional seed; - - se::dnn::AlgorithmDesc algorithm; - CudnnfMHAMaskKind mask_type; - // mask -> [batch_size, 1, q_seq_len, kv_seq_len] - // d_bias -> [1, num_heads, q_seq_len, kv_seq_len] - se::dnn::MatmulTensorDescriptor bmm1_grad_gemm1_rhs; - se::dnn::MatmulTensorDescriptor bmm1_grad_gemm2_rhs; - se::dnn::MatmulTensorDescriptor bmm2_grad_gemm1_lhs; - se::dnn::MatmulTensorDescriptor bmm2_grad_gemm2_rhs; - se::dnn::MatmulTensorDescriptor d_output; - se::dnn::TensorDescriptor d_bmm1_lhs; - se::dnn::TensorDescriptor d_bmm1_rhs; - se::dnn::TensorDescriptor d_bmm2_rhs; - std::optional d_s; - std::optional mask; - std::optional d_bias; - std::optional fwd_output; - std::optional bias; - bool force_deterministic; -}; - -// Implementation struct exposed for debugging and log analysis. 
-struct GpufMHAParams { - static absl::StatusOr For( - const GpufMHAConfig& config, se::DeviceMemoryBase lhs_bmm1_buffer, - se::DeviceMemoryBase rhs_bmm1_buffer, - se::DeviceMemoryBase rhs_bmm2_buffer, se::DeviceMemoryBase output_buffer, - std::optional bias_buffer, - std::optional activation_buffer, - std::optional seqlen_q_buffer, - std::optional seqlen_k_buffer); - - const GpufMHAConfig* config; // Not owned - se::DeviceMemoryBase lhs_bmm1_buffer; - se::DeviceMemoryBase rhs_bmm1_buffer; - se::DeviceMemoryBase rhs_bmm2_buffer; - se::DeviceMemoryBase output_buffer; - std::optional activation_buffer; - std::optional bias_buffer; - std::optional seqlen_q_buffer; - std::optional seqlen_k_buffer; -}; - -struct GpufMHABackwardParams { - static absl::StatusOr For( - const GpufMHABackwardConfig& config, - se::DeviceMemoryBase bmm1_grad_gemm1_rhs_buffer, - se::DeviceMemoryBase bmm1_grad_gemm2_rhs_buffer, - se::DeviceMemoryBase bmm2_grad_gemm1_lhs_buffer, - se::DeviceMemoryBase bmm2_grad_gemm2_rhs_buffer, - se::DeviceMemoryBase d_output_buffer, - se::DeviceMemoryBase d_bmm1_lhs_buffer, - se::DeviceMemoryBase d_bmm1_rhs_buffer, - se::DeviceMemoryBase d_bmm2_rhs_buffer, - std::optional d_s_buffer, - std::optional d_bias_buffer, - std::optional fwd_output_buffer, - std::optional bias_buffer, - std::optional seqlen_q_buffer, - std::optional seqlen_k_buffer); - - const GpufMHABackwardConfig* config; // Not owned - se::DeviceMemoryBase bmm1_grad_gemm1_rhs_buffer; - se::DeviceMemoryBase bmm1_grad_gemm2_rhs_buffer; - se::DeviceMemoryBase bmm2_grad_gemm1_lhs_buffer; - se::DeviceMemoryBase bmm2_grad_gemm2_rhs_buffer; - se::DeviceMemoryBase d_output_buffer; - se::DeviceMemoryBase d_bmm1_lhs_buffer; - se::DeviceMemoryBase d_bmm1_rhs_buffer; - se::DeviceMemoryBase d_bmm2_rhs_buffer; - std::optional d_s_buffer; - std::optional d_bias_buffer; - std::optional fwd_output_buffer; - std::optional bias_buffer; - std::optional seqlen_q_buffer; - std::optional seqlen_k_buffer; -}; - -class FusedMultiHeadedAttentionRunner { - public: - using Repr = - std::variant>>; - - FusedMultiHeadedAttentionRunner() = default; - - explicit FusedMultiHeadedAttentionRunner( - std::unique_ptr> runner) - : repr_(std::move(runner)) {} - - explicit FusedMultiHeadedAttentionRunner(Repr runner) - : repr_(std::move(runner)) {} - - explicit FusedMultiHeadedAttentionRunner(const GpufMHAConfig& config) - : FusedMultiHeadedAttentionRunner(CreateRunner(config)) { - if (std::holds_alternative(repr_)) { - CHECK(false) << "Cannot construct FusedMultiHeadedAttentionRunner with " - "std::monostate"; - } - } - - se::dnn::AlgorithmDesc ToAlgorithmDesc() const { - return std::visit(ToAlgorithmDescVisitor{}, repr_); - } - - se::dnn::LazyOpRunner* AsFusedMHARunner() { - CHECK(std::holds_alternative< - std::unique_ptr>>(repr_)); - return std::get< - std::unique_ptr>>( - repr_) - .get(); - } - - private: - // The CreateRunner function is defined as static because it - // doesn't need access to any non-static member variables of the - // FusedMultiHeadedAttentionRunner class. Defining it static makes it easy to - // use and makes it clear that it is a utility function that doesn't rely on - // the state of any specific instance of the class. 
- static Repr CreateRunner(const GpufMHAConfig& config) { - switch (config.kind) { - case CudnnfMHAKind::kSoftmaxDropout: - case CudnnfMHAKind::kSoftmax: - case CudnnfMHAKind::kScaleBiasSoftmax: - case CudnnfMHAKind::kScaleBiasSoftmaxDropout: - return std::make_unique>( - config.algorithm); - default: - LOG(FATAL) << "Internal error: unsupported CUDNN MHA kind in " - "FusedMultiHeadedAttentionRunner"; - } - } - - struct ToAlgorithmDescVisitor { - template - se::dnn::AlgorithmDesc operator()(const RunnerPtr& runner) { - return runner->ToAlgorithmDesc(); - } - - se::dnn::AlgorithmDesc operator()(const std::monostate&) { - CHECK(false) << "Internal error: uninitialized runner in ToAlgorithmDesc"; - } - }; - - Repr repr_; -}; - -class FusedMultiHeadedAttentionBackwardRunner { - public: - using Repr = std::variant< - std::monostate, // To allow XXX default ctor - std::unique_ptr>>; - - FusedMultiHeadedAttentionBackwardRunner() = default; - - explicit FusedMultiHeadedAttentionBackwardRunner( - std::unique_ptr> - runner) - : repr_(std::move(runner)) {} - - explicit FusedMultiHeadedAttentionBackwardRunner(Repr runner) - : repr_(std::move(runner)) {} - - explicit FusedMultiHeadedAttentionBackwardRunner( - const GpufMHABackwardConfig& config) - : FusedMultiHeadedAttentionBackwardRunner(CreateRunner(config)) { - if (std::holds_alternative(repr_)) { - CHECK(false) - << "Cannot construct FusedMultiHeadedAttentionBackwardRunner with " - "std::monostate"; - } - } - - se::dnn::AlgorithmDesc ToAlgorithmDesc() const { - return std::visit(ToAlgorithmDescVisitor{}, repr_); - } - - se::dnn::LazyOpRunner* - AsFusedMHABackwardRunner() { - CHECK(std::holds_alternative< - std::unique_ptr>>( - repr_)); - return std::get>>(repr_) - .get(); - } - - private: - // The CreateRunner function is defined as static because it - // doesn't need access to any non-static member variables of the - // FusedMultiHeadedAttentionBackwardRunner class. Defining it static makes it - // easy to use and makes it clear that it is a utility function that doesn't - // rely on the state of any specific instance of the class. - static Repr CreateRunner(const GpufMHABackwardConfig& config) { - switch (config.kind) { - case CudnnfMHAKind::kBackwardSoftmaxDropout: - case CudnnfMHAKind::kBackwardSoftmax: - case CudnnfMHAKind::kBackwardScaleBiasSoftmax: - case CudnnfMHAKind::kBackwardScaleBiasSoftmaxDropout: - return std::make_unique< - se::dnn::LazyOpRunner>( - config.algorithm); - default: - LOG(FATAL) << "Internal error: unsupported CUDNN MHA kind in " - "FusedMultiHeadedAttentionBackwardRunner"; - } - } - - struct ToAlgorithmDescVisitor { - template - se::dnn::AlgorithmDesc operator()(const RunnerPtr& runner) { - return runner->ToAlgorithmDesc(); - } - - se::dnn::AlgorithmDesc operator()(const std::monostate&) { - CHECK(false) << "Internal error: uninitialized runner in ToAlgorithmDesc"; - } - }; - - Repr repr_; -}; - -struct RunFusedMHAOptions { - // Nullable output-parameter pointer for profiling results. - // Profile results remain unused for now since cuDNN FMHA has only one - // algorithm for now. - se::dnn::ProfileResult* profile_result = nullptr; - - // Use this runner cache (and its configured algorithm), instead of the one - // from the instruction. - FusedMultiHeadedAttentionRunner* runner_cache; -}; - -struct RunFusedMHABackwardOptions { - // Nullable output-parameter pointer for profiling results. - // Profile results remain unused for now since cuDNN FMHA has only one - // algorithm for now. 
- se::dnn::ProfileResult* profile_result = nullptr; - - // Use this runner cache (and its configured algorithm), instead of the one - // from the instruction. - FusedMultiHeadedAttentionBackwardRunner* runner_cache; -}; - -absl::Status RunGpuFMHA(const GpufMHAConfig& fmha_config, - se::DeviceMemoryBase lhs_bmm1_buffer, - se::DeviceMemoryBase rhs_bmm1_buffer, - se::DeviceMemoryBase rhs_bmm2_buffer, - se::DeviceMemoryBase output_buffer, - se::DeviceMemoryBase scratch_buffer, - std::optional bias_buffer, - std::optional activation_buffer, - std::optional seqlen_q_buffer, - std::optional seqlen_k_buffer, - se::Stream* stream, RunFusedMHAOptions = {}); - -absl::Status RunGpuFMHABackward( - const GpufMHABackwardConfig& fmha_config, - se::DeviceMemoryBase bmm1_grad_gemm1_rhs_buffer, - se::DeviceMemoryBase bmm1_grad_gemm2_rhs_buffer, - se::DeviceMemoryBase bmm2_grad_gemm1_lhs_buffer, - se::DeviceMemoryBase bmm2_grad_gemm2_rhs_buffer, - se::DeviceMemoryBase d_output_buffer, se::DeviceMemoryBase scratch_buffer, - se::DeviceMemoryBase d_bmm1_lhs_buffer, - se::DeviceMemoryBase d_bmm1_rhs_buffer, - se::DeviceMemoryBase d_bmm2_rhs_buffer, - std::optional d_s_buffer, - std::optional d_bias_buffer, - std::optional fwd_output_buffer, - std::optional bias_buffer, - std::optional seqlen_q_buffer, - std::optional seqlen_k_buffer, se::Stream* stream, - RunFusedMHABackwardOptions = {}); - -std::string ToString(const GpufMHAConfig& config); - -} // namespace gpu -} // namespace xla -#endif // XLA_SERVICE_GPU_GPU_FUSED_MHA_RUNNER_H_ diff --git a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc index 03690caeaa8121..b73225a1bd3c56 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc @@ -26,7 +26,6 @@ limitations under the License. #include #include #include -#include #include #include "absl/container/flat_hash_map.h" @@ -99,7 +98,6 @@ limitations under the License. #include "xla/service/gpu/fusions/triton/triton_fusion_emitter.h" #include "xla/service/gpu/gpu_asm_opts_util.h" #include "xla/service/gpu/gpu_conv_runner.h" -#include "xla/service/gpu/gpu_fused_mha_runner.h" #include "xla/service/gpu/gpu_norm_runner.h" #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/ir_emission_utils.h" @@ -122,7 +120,6 @@ limitations under the License. #include "xla/service/gpu/runtime/copy_thunk.h" #include "xla/service/gpu/runtime/custom_call_thunk.h" #include "xla/service/gpu/runtime/fft_thunk.h" -#include "xla/service/gpu/runtime/fused_mha_thunk.h" #include "xla/service/gpu/runtime/gemm_thunk.h" #include "xla/service/gpu/runtime/gpublas_lt_matmul_thunk.h" #include "xla/service/gpu/runtime/infeed_thunk.h" @@ -173,6 +170,7 @@ limitations under the License. 
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #include "xla/service/gpu/runtime/cholesky_thunk.h" #include "xla/service/gpu/runtime/cub_sort_thunk.h" +#include "xla/service/gpu/runtime/cudnn_thunk.h" #include "xla/service/gpu/runtime/triangular_solve_thunk.h" #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM @@ -955,221 +953,17 @@ absl::Status IrEmitterUnnested::EmitNormThunk( return absl::OkStatus(); } -absl::Status IrEmitterUnnested::EmitFusedMHAThunk( +absl::Status IrEmitterUnnested::EmitCuDnnThunk( const HloCustomCallInstruction* instr) { - const HloInstruction* lhs_bmm1 = instr->operand(0); - const HloInstruction* rhs_bmm1 = instr->operand(1); - const HloInstruction* rhs_bmm2 = instr->operand(2); - - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice lhs_bmm1_slice, - GetAllocationSliceForHlo(lhs_bmm1)); - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice rhs_bmm1_slice, - GetAllocationSliceForHlo(rhs_bmm1)); - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice rhs_bmm2_slice, - GetAllocationSliceForHlo(rhs_bmm2)); - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice output_slice, - GetAllocationSliceForHlo(instr, {0})); - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice scratch_slice, - GetAllocationSliceForHlo( - instr, {instr->shape().tuple_shapes_size() - 1})); - BufferAllocation::Slice activation_slice; - bool has_activation = xla::ShapeUtil::TupleElementCount(instr->shape()) == 3; - if (has_activation) { - TF_ASSIGN_OR_RETURN(activation_slice, GetAllocationSliceForHlo(instr, {1})); - } - - TF_ASSIGN_OR_RETURN(const xla::gpu::CudnnfMHAKind kind, - xla::gpu::GetCudnnfMHAKind(instr)); - BufferAllocation::Slice mask_slice, bias_slice; - BufferAllocation::Slice seqlen_q_slice, seqlen_k_slice; - std::optional mask_shape, bias_shape; - { - bool has_bias = kind == CudnnfMHAKind::kScaleBiasSoftmax || - kind == CudnnfMHAKind::kScaleBiasSoftmaxDropout; - - if (has_bias) { - const HloInstruction* bias = instr->operand(3); - TF_ASSIGN_OR_RETURN(bias_slice, GetAllocationSliceForHlo(bias)); - bias_shape = bias->shape(); - } - int64_t seqlen_qk_operand_index = 3 + has_bias; - bool has_seqlen_qk = seqlen_qk_operand_index == instr->operand_count() - 2; - if (has_seqlen_qk) { - const HloInstruction* seqlen_q = instr->operand(seqlen_qk_operand_index); - TF_ASSIGN_OR_RETURN(seqlen_q_slice, GetAllocationSliceForHlo(seqlen_q)); - const HloInstruction* seqlen_k = - instr->operand(seqlen_qk_operand_index + 1); - TF_ASSIGN_OR_RETURN(seqlen_k_slice, GetAllocationSliceForHlo(seqlen_k)); - } - } - - TF_ASSIGN_OR_RETURN(const auto gpu_config, - instr->backend_config()); - const xla::gpu::CudnnfMHABackendConfig& config = - gpu_config.cudnn_fmha_backend_config(); - Shape intermediate_tensor_shape(config.intermediate_tensor_shape()); - absl::InlinedVector output_shapes = { - ShapeUtil::GetSubshape(instr->shape(), {0})}; - if (has_activation) { - output_shapes.push_back(ShapeUtil::GetSubshape(instr->shape(), {1})); - } - TF_ASSIGN_OR_RETURN(const auto mask_type, - AsCudnnFmhaMaskKind(config.mask_type())); - GpufMHADescriptor descriptor = {kind, - config, - mask_type, - lhs_bmm1->shape(), - rhs_bmm1->shape(), - rhs_bmm2->shape(), - intermediate_tensor_shape, - output_shapes, - config.bmm1_dot_dimension_numbers(), - config.bmm2_dot_dimension_numbers(), - mask_shape, - bias_shape}; - - TF_ASSIGN_OR_RETURN(GpufMHAConfig fmha_config, - GpufMHAConfig::For(descriptor)); - AddThunkToThunkSequence(std::make_unique( - Thunk::ThunkInfo::WithProfileAnnotation(instr), std::move(fmha_config), - lhs_bmm1_slice, rhs_bmm1_slice, rhs_bmm2_slice, output_slice, - 
scratch_slice, mask_slice, bias_slice, activation_slice, seqlen_q_slice, - seqlen_k_slice)); - return absl::OkStatus(); -} - -absl::Status IrEmitterUnnested::EmitFusedMHABackwardThunk( - const HloCustomCallInstruction* instr) { - TF_ASSIGN_OR_RETURN(const auto gpu_config, - instr->backend_config()); - const xla::gpu::CudnnfMHABackendConfig& config = - gpu_config.cudnn_fmha_backend_config(); - - int input_index = 0; - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice bmm1_grad_gemm1_rhs_slice, - GetAllocationSliceForHlo(instr->operand(input_index))); - Shape bmm1_grad_gemm1_rhs_shape = instr->operand(input_index++)->shape(); - - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice bmm1_grad_gemm2_rhs_slice, - GetAllocationSliceForHlo(instr->operand(input_index))); - Shape bmm1_grad_gemm2_rhs_shape = instr->operand(input_index++)->shape(); - - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice bmm2_grad_gemm2_rhs_slice, - GetAllocationSliceForHlo(instr->operand(input_index))); - Shape bmm2_grad_gemm2_rhs_shape = instr->operand(input_index++)->shape(); - - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice bmm2_grad_gemm1_lhs_slice, - GetAllocationSliceForHlo(instr->operand(input_index))); - Shape bmm2_grad_gemm1_lhs_shape; - - Shape intermediate_tensor_shape(config.intermediate_tensor_shape()); - bmm2_grad_gemm1_lhs_shape = intermediate_tensor_shape; - input_index++; - - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice d_output_slice, - GetAllocationSliceForHlo(instr->operand(input_index))); - Shape d_output_shape = instr->operand(input_index++)->shape(); - - TF_ASSIGN_OR_RETURN(const CudnnfMHAKind kind, GetCudnnfMHAKind(instr)); - BufferAllocation::Slice mask_slice; - std::optional mask_shape; - - bool has_bias = (kind == CudnnfMHAKind::kBackwardScaleBiasSoftmax || - kind == CudnnfMHAKind::kBackwardScaleBiasSoftmaxDropout); - BufferAllocation::Slice bias_slice; - std::optional bias_shape; - if (has_bias) { - TF_ASSIGN_OR_RETURN(bias_slice, - GetAllocationSliceForHlo(instr->operand(input_index))); - bias_shape = instr->operand(input_index++)->shape(); - } - - BufferAllocation::Slice fwd_output_slice; - std::optional fwd_output_shape; - - TF_ASSIGN_OR_RETURN(fwd_output_slice, - GetAllocationSliceForHlo(instr->operand(input_index))); - fwd_output_shape = instr->operand(input_index++)->shape(); - - BufferAllocation::Slice seqlen_q_slice, seqlen_k_slice; - bool has_seqlen_qk = input_index == instr->operand_count() - 2; - if (has_seqlen_qk) { - const HloInstruction* seqlen_q = instr->operand(input_index); - TF_ASSIGN_OR_RETURN(seqlen_q_slice, GetAllocationSliceForHlo(seqlen_q)); - const HloInstruction* seqlen_k = instr->operand(input_index + 1); - TF_ASSIGN_OR_RETURN(seqlen_k_slice, GetAllocationSliceForHlo(seqlen_k)); - input_index += 2; - } - TF_RET_CHECK(input_index == instr->operand_count()); - - int output_index = 0; - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice d_bmm1_lhs_slice, - GetAllocationSliceForHlo(instr, {output_index})); - Shape d_bmm1_lhs_shape = - ShapeUtil::GetSubshape(instr->shape(), {output_index++}); - - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice d_bmm1_rhs_slice, - GetAllocationSliceForHlo(instr, {output_index})); - Shape d_bmm1_rhs_shape = - ShapeUtil::GetSubshape(instr->shape(), {output_index++}); - - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice d_bmm2_rhs_slice, - GetAllocationSliceForHlo(instr, {output_index})); - Shape d_bmm2_rhs_shape = - ShapeUtil::GetSubshape(instr->shape(), {output_index++}); - - BufferAllocation::Slice d_s_slice; - std::optional d_s_shape; - - bool has_dbias = 
instr->shape().tuple_shapes().size() == 5; - BufferAllocation::Slice d_bias_slice; - std::optional d_bias_shape; - if (has_dbias) { - TF_ASSIGN_OR_RETURN(d_bias_slice, - GetAllocationSliceForHlo(instr, {output_index})); - d_bias_shape = ShapeUtil::GetSubshape(instr->shape(), {output_index++}); - } - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice scratch_slice, - GetAllocationSliceForHlo(instr, {output_index++})); - TF_RET_CHECK(output_index == instr->shape().tuple_shapes().size()); - TF_ASSIGN_OR_RETURN(const auto mask_type, - AsCudnnFmhaMaskKind(config.mask_type())); - bool force_deterministic = config.force_deterministic(); - GpufMHABackwardDescriptor descriptor = { - kind, - config, - mask_type, - bmm1_grad_gemm1_rhs_shape, - bmm1_grad_gemm2_rhs_shape, - bmm2_grad_gemm1_lhs_shape, - bmm2_grad_gemm2_rhs_shape, - d_output_shape, - d_bmm1_lhs_shape, - d_bmm1_rhs_shape, - d_bmm2_rhs_shape, - config.bmm1_grad_gemm1_dot_dimension_numbers(), - config.bmm1_grad_gemm2_dot_dimension_numbers(), - config.bmm2_grad_gemm1_dot_dimension_numbers(), - config.bmm2_grad_gemm2_dot_dimension_numbers(), - d_s_shape, - fwd_output_shape, - mask_shape, - d_bias_shape, - bias_shape, - force_deterministic}; - - TF_ASSIGN_OR_RETURN(GpufMHABackwardConfig fmha_backward_config, - GpufMHABackwardConfig::For(descriptor)); - - AddThunkToThunkSequence(std::make_unique( - Thunk::ThunkInfo::WithProfileAnnotation(instr), - std::move(fmha_backward_config), bmm1_grad_gemm1_rhs_slice, - bmm1_grad_gemm2_rhs_slice, bmm2_grad_gemm1_lhs_slice, - bmm2_grad_gemm2_rhs_slice, d_output_slice, scratch_slice, - d_bmm1_lhs_slice, d_bmm1_rhs_slice, d_bmm2_rhs_slice, d_s_slice, - mask_slice, d_bias_slice, fwd_output_slice, bias_slice, seqlen_q_slice, - seqlen_k_slice)); - + TF_ASSIGN_OR_RETURN( + auto kernel_arguments, + KernelArguments::Create(ir_emitter_context_->buffer_assignment(), instr, + instr->operands())); + TF_ASSIGN_OR_RETURN(const std::string fingerprint, + FingerprintWithBackendConfig(*instr)); + AddThunkToThunkSequence(std::make_unique( + fingerprint, Thunk::ThunkInfo::WithProfileAnnotation(instr), + kernel_arguments.args())); return absl::OkStatus(); } @@ -2921,11 +2715,8 @@ absl::Status IrEmitterUnnested::EmitHloInstruction( if (IsCustomCallToDnnNorm(*instr)) { return EmitNormThunk(custom_call); } - if (IsFwdCustomCallTofMHA(*instr)) { - return EmitFusedMHAThunk(custom_call); - } - if (IsBwdCustomCallTofMHA(*instr)) { - return EmitFusedMHABackwardThunk(custom_call); + if (IsCustomCallTofMHA(*instr)) { + return EmitCuDnnThunk(custom_call); } #endif // GOOGLE_CUDA if (IsCustomCallToTopK(*instr)) { diff --git a/third_party/xla/xla/service/gpu/ir_emitter_unnested.h b/third_party/xla/xla/service/gpu/ir_emitter_unnested.h index f97f106ddfc0df..d19dd5d9c4172c 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_unnested.h +++ b/third_party/xla/xla/service/gpu/ir_emitter_unnested.h @@ -147,8 +147,7 @@ class IrEmitterUnnested : public IrEmitter { absl::Status EmitConvolutionReorderThunk( const HloCustomCallInstruction* instr); absl::Status EmitNormThunk(const HloCustomCallInstruction* instr); - absl::Status EmitFusedMHAThunk(const HloCustomCallInstruction* instr); - absl::Status EmitFusedMHABackwardThunk(const HloCustomCallInstruction* instr); + absl::Status EmitCuDnnThunk(const HloCustomCallInstruction* instr); #endif // GOOGLE_CUDA #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM absl::Status EmitCubDeviceRadixSort(const HloCustomCallInstruction* instr); diff --git a/third_party/xla/xla/service/gpu/nvptx_compiler.cc 
b/third_party/xla/xla/service/gpu/nvptx_compiler.cc
index a2d3988fbfb6b6..7925ebc33baa16 100644
--- a/third_party/xla/xla/service/gpu/nvptx_compiler.cc
+++ b/third_party/xla/xla/service/gpu/nvptx_compiler.cc
@@ -69,6 +69,7 @@ limitations under the License.
 #include "xla/service/gpu/transforms/conv_padding_legalization.h"
 #include "xla/service/gpu/transforms/conv_rewriter.h"
 #include "xla/service/gpu/transforms/cublas_pad_for_gemms.h"
+#include "xla/service/gpu/transforms/cudnn_custom_call_compiler.h"
 #include "xla/service/gpu/transforms/cudnn_fused_conv_rewriter.h"
 #include "xla/service/gpu/transforms/cudnn_fused_mha_rewriter.h"
 #include "xla/service/gpu/transforms/cudnn_fused_mha_transpose_fusion.h"
@@ -77,7 +78,6 @@ limitations under the License.
 #include "xla/service/gpu/transforms/cudnn_pad_for_convolutions.h"
 #include "xla/service/gpu/transforms/cudnn_simplify_padding.h"
 #include "xla/service/gpu/transforms/cudnn_vectorize_convolutions.h"
-#include "xla/service/gpu/transforms/cudnn_workspace_rewriter.h"
 #include "xla/service/gpu/transforms/dot_sparsity_rewriter.h"
 #include "xla/service/gpu/transforms/gpusolver_rewriter.h"
 #include "xla/service/gpu/transforms/sort_rewriter.h"
@@ -342,9 +342,6 @@ absl::Status NVPTXCompiler::OptimizeHloPostLayoutAssignment(
   // Transform TriangularSolve ops into custom-calls, so we can add temp
   // memory.
   post_pipeline.AddPass();
-  if (stream_exec) {
-    post_pipeline.AddPass(*stream_exec);
-  }
   TF_RETURN_IF_ERROR(post_pipeline.Run(hlo_module).status());
 
   return absl::OkStatus();
@@ -395,15 +392,17 @@ absl::Status NVPTXCompiler::AddCustomKernelReplacementPasses(
   return absl::OkStatus();
 }
 
-absl::Status NVPTXCompiler::RunCudnnFusionCompilerPass(
+absl::Status NVPTXCompiler::RunCudnnCompilerPasses(
     HloModule* module, se::StreamExecutor* stream_exec,
     BinaryMap* dnn_compiled_graphs) {
   tsl::profiler::ScopedAnnotation annotation([&] {
     return absl::StrFormat("XlaCompileCudnnFusion:#module=%s,program_id=%d#",
                            module->name(), module->unique_id());
   });
-  CuDnnFusionCompiler cudnn_compiler(*stream_exec, *dnn_compiled_graphs);
-  return cudnn_compiler.Run(module).status();
+  CuDnnFusionCompiler fusion_compiler(*stream_exec, *dnn_compiled_graphs);
+  TF_RETURN_IF_ERROR(fusion_compiler.Run(module).status());
+  CuDnnCustomCallCompiler call_compiler(*stream_exec, *dnn_compiled_graphs);
+  return call_compiler.Run(module).status();
 }
 
 namespace {
diff --git a/third_party/xla/xla/service/gpu/nvptx_compiler.h b/third_party/xla/xla/service/gpu/nvptx_compiler.h
index fb74f553d67967..6d84deb4398176 100644
--- a/third_party/xla/xla/service/gpu/nvptx_compiler.h
+++ b/third_party/xla/xla/service/gpu/nvptx_compiler.h
@@ -84,9 +84,9 @@ class NVPTXCompiler : public GpuCompiler {
   absl::Status AddCustomKernelReplacementPasses(
       HloPassPipeline* pipeline, const DebugOptions& debug_options) override;
 
-  absl::Status RunCudnnFusionCompilerPass(
-      HloModule* module, se::StreamExecutor* stream_exec,
-      BinaryMap* dnn_compiled_graphs) override;
+  absl::Status RunCudnnCompilerPasses(HloModule* module,
+                                      se::StreamExecutor* stream_exec,
+                                      BinaryMap* dnn_compiled_graphs) override;
 
   HloDataflowAnalysis::CanShareBuffer GetCanShareBuffer() const override;
 
diff --git a/third_party/xla/xla/service/gpu/runtime/BUILD b/third_party/xla/xla/service/gpu/runtime/BUILD
index 861fcca6b45eb4..d6b6f1ffa78319 100644
--- a/third_party/xla/xla/service/gpu/runtime/BUILD
+++ b/third_party/xla/xla/service/gpu/runtime/BUILD
@@ -79,7 +79,6 @@ cc_library(
         "//xla/service:executable",
         "//xla/service:global_device_id",
"//xla/service/gpu:buffer_allocations", - "//xla/service/gpu:gpu_fused_mha_runner", "//xla/service/gpu:launch_dimensions", "//xla/service/gpu:matmul_utils", "//xla/service/gpu:stream_executor_util", @@ -121,7 +120,6 @@ cc_library( ":copy_thunk", ":cudnn_thunk", ":custom_call_thunk", - ":fused_mha_thunk", ":gemm_thunk", ":gpublas_lt_matmul_thunk", ":kernel_thunk", @@ -660,28 +658,6 @@ cc_library( ], ) -cc_library( - name = "fused_mha_thunk", - srcs = ["fused_mha_thunk.cc"], - hdrs = ["fused_mha_thunk.h"], - deps = [ - ":thunk", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/service:buffer_assignment", - "//xla/service/gpu:buffer_allocations", - "//xla/service/gpu:gpu_fused_mha_runner", - "//xla/stream_executor", - "//xla/stream_executor:lazy_op_runner", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/status", - "@com_google_absl//absl/synchronization", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", - ], -) - cc_library( name = "gemm_thunk", srcs = ["gemm_thunk.cc"], diff --git a/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.cc b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.cc index d913871332933d..ce99623c16de6e 100644 --- a/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.cc +++ b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.cc @@ -1167,314 +1167,6 @@ CommandBufferCmd::BufferUsageVector GemmCmd::buffers() { {workspace_, MemoryAccess::kWrite}}; } -//===----------------------------------------------------------------------===// -// FusedMHACmd -//===----------------------------------------------------------------------===// - -FusedMHACmd::FusedMHACmd( - ExecutionStreamId execution_stream_id, GpufMHAConfig config, - BufferAllocation::Slice lhs_bmm1, BufferAllocation::Slice rhs_bmm1, - BufferAllocation::Slice rhs_bmm2, BufferAllocation::Slice output, - BufferAllocation::Slice scratch, BufferAllocation::Slice mask, - BufferAllocation::Slice bias, BufferAllocation::Slice activation, - BufferAllocation::Slice seqlen_q, BufferAllocation::Slice seqlen_k) - : TracedCommandBufferCmd(CommandBufferCmdType::kFusedMHACmd, - execution_stream_id), - config_(std::move(config)), - lhs_bmm1_buffer_(lhs_bmm1), - rhs_bmm1_buffer_(rhs_bmm1), - rhs_bmm2_buffer_(rhs_bmm2), - output_buffer_(output), - scratch_buffer_(scratch), - bias_buffer_(bias), - activation_buffer_(activation), - seqlen_q_buffer_(seqlen_q), - seqlen_k_buffer_(seqlen_k) {} - -FusedMultiHeadedAttentionRunner& FusedMHACmd::GetOrCreateRunner( - const stream_executor::Stream* stream) { - absl::MutexLock lock(&mutex_); - auto it = runner_cache_.find(stream); - if (it == runner_cache_.end()) { - it = runner_cache_ - .insert({stream, std::make_unique( - config_)}) - .first; - } - return *it->second; -} - -absl::Status FusedMHACmd::Initialize(const Thunk::InitializeParams& params, - StateManager& state) { - se::dnn::LazyOpRunner* lazy_runner = - GetOrCreateRunner(params.command_buffer_trace_stream).AsFusedMHARunner(); - TF_ASSIGN_OR_RETURN(auto config, config_.AsDnnFusedMHAOpConfig()); - return lazy_runner - ->GetOrCreateRunner(config, params.command_buffer_trace_stream) - .status(); -} - -absl::Status FusedMHACmd::Record(const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, - se::CommandBuffer* command_buffer) { - se::dnn::LazyOpRunner* lazy_runner = - GetOrCreateRunner(execute_params.command_buffer_trace_stream) - .AsFusedMHARunner(); - CHECK(lazy_runner) << 
"FusedMHA lazy runner cache should have been populated"; - - const auto& buffer_allocations = *execute_params.buffer_allocations; - se::DeviceMemoryBase lhs_bmm1_buffer = - buffer_allocations.GetDeviceAddress(lhs_bmm1_buffer_); - se::DeviceMemoryBase rhs_bmm1_buffer = - buffer_allocations.GetDeviceAddress(rhs_bmm1_buffer_); - se::DeviceMemoryBase rhs_bmm2_buffer = - buffer_allocations.GetDeviceAddress(rhs_bmm2_buffer_); - se::DeviceMemoryBase output_buffer = - buffer_allocations.GetDeviceAddress(output_buffer_); - se::DeviceMemoryBase scratch_buffer = - buffer_allocations.GetDeviceAddress(scratch_buffer_); - - std::optional bias_buffer = - AssignBufferIfNotNull(buffer_allocations, bias_buffer_); - std::optional activation_buffer = - AssignBufferIfNotNull(buffer_allocations, activation_buffer_); - std::optional seqlen_q_buffer = - AssignBufferIfNotNull(buffer_allocations, seqlen_q_buffer_); - std::optional seqlen_k_buffer = - AssignBufferIfNotNull(buffer_allocations, seqlen_k_buffer_); - - ExecutionScopeId execution_scope_id = GetExecutionScope(record_params); - VLOG(5) << "FusedMHACmd with execution_scope_id: " - << execution_scope_id.value(); - VLOG(5) << " lhs_bmm1_buffer: " << lhs_bmm1_buffer_.ToString(); - VLOG(5) << " rhs_bmm1_buffer: " << rhs_bmm1_buffer_.ToString(); - VLOG(5) << " rhs_bmm2_buffer: " << rhs_bmm2_buffer_.ToString(); - VLOG(5) << " output_buffer: " << output_buffer_.ToString(); - VLOG(5) << " scratch_buffer: " << scratch_buffer_.ToString(); - VLOG(5) << " bias_buffer: " << bias_buffer_.ToString(); - VLOG(5) << " activation_buffer: " << activation_buffer_.ToString(); - VLOG(5) << " seqlen_q_buffer: " << seqlen_q_buffer_.ToString(); - VLOG(5) << " seqlen_k_buffer: " << seqlen_k_buffer_.ToString(); - - RunFusedMHAOptions opts; - opts.runner_cache = - &GetOrCreateRunner(execute_params.command_buffer_trace_stream); - return AddTracedCommandBuffer( - execute_params, record_params, command_buffer, [&](se::Stream* stream) { - return RunGpuFMHA(config_, lhs_bmm1_buffer, rhs_bmm1_buffer, - rhs_bmm2_buffer, output_buffer, scratch_buffer, - bias_buffer, activation_buffer, seqlen_q_buffer, - seqlen_k_buffer, stream, opts); - }); -} - -FusedMHACmd::BufferUsageVector FusedMHACmd::buffers() { - BufferUsageVector buffer_usage; - buffer_usage.reserve(9); - buffer_usage.push_back({lhs_bmm1_buffer_, MemoryAccess::kRead}); - buffer_usage.push_back({rhs_bmm1_buffer_, MemoryAccess::kRead}); - buffer_usage.push_back({rhs_bmm2_buffer_, MemoryAccess::kRead}); - buffer_usage.push_back({output_buffer_, MemoryAccess::kWrite}); - buffer_usage.push_back({scratch_buffer_, MemoryAccess::kWrite}); - if (bias_buffer_.allocation() != nullptr) { - buffer_usage.push_back({bias_buffer_, MemoryAccess::kRead}); - } - if (activation_buffer_.allocation() != nullptr) { - buffer_usage.push_back({activation_buffer_, MemoryAccess::kRead}); - } - if (seqlen_q_buffer_.allocation() != nullptr) { - buffer_usage.push_back({seqlen_q_buffer_, MemoryAccess::kRead}); - } - if (seqlen_k_buffer_.allocation() != nullptr) { - buffer_usage.push_back({seqlen_k_buffer_, MemoryAccess::kRead}); - } - return buffer_usage; -} - -//===----------------------------------------------------------------------===// -// FusedMHABackwardCmd -//===----------------------------------------------------------------------===// - -FusedMHABackwardCmd::FusedMHABackwardCmd( - ExecutionStreamId execution_stream_id, GpufMHABackwardConfig config, - BufferAllocation::Slice bmm1_grad_gemm1_rhs, - BufferAllocation::Slice bmm1_grad_gemm2_rhs, - 
BufferAllocation::Slice bmm2_grad_gemm1_lhs, - BufferAllocation::Slice bmm2_grad_gemm2_rhs, - BufferAllocation::Slice d_output, BufferAllocation::Slice scratch, - BufferAllocation::Slice d_bmm1_lhs, BufferAllocation::Slice d_bmm1_rhs, - BufferAllocation::Slice d_bmm2_rhs, BufferAllocation::Slice d_s, - BufferAllocation::Slice d_bias, BufferAllocation::Slice fwd_output, - BufferAllocation::Slice bias, BufferAllocation::Slice seqlen_q, - BufferAllocation::Slice seqlen_k) - : TracedCommandBufferCmd(CommandBufferCmdType::kFusedMHABackwardCmd, - execution_stream_id), - config_(std::move(config)), - bmm1_grad_gemm1_rhs_buffer_(bmm1_grad_gemm1_rhs), - bmm1_grad_gemm2_rhs_buffer_(bmm1_grad_gemm2_rhs), - bmm2_grad_gemm1_lhs_buffer_(bmm2_grad_gemm1_lhs), - bmm2_grad_gemm2_rhs_buffer_(bmm2_grad_gemm2_rhs), - d_output_buffer_(d_output), - scratch_buffer_(scratch), - d_bmm1_lhs_buffer_(d_bmm1_lhs), - d_bmm1_rhs_buffer_(d_bmm1_rhs), - d_bmm2_rhs_buffer_(d_bmm2_rhs), - d_s_buffer_(d_s), - d_bias_buffer_(d_bias), - fwd_output_buffer_(fwd_output), - bias_buffer_(bias), - seqlen_q_buffer_(seqlen_q), - seqlen_k_buffer_(seqlen_k) {} - -FusedMultiHeadedAttentionBackwardRunner& FusedMHABackwardCmd::GetOrCreateRunner( - const stream_executor::Stream* stream) { - absl::MutexLock lock(&mutex_); - auto it = runner_cache_.find(stream); - if (it == runner_cache_.end()) { - it = runner_cache_ - .insert({stream, - std::make_unique( - config_)}) - .first; - } - return *it->second; -} - -absl::Status FusedMHABackwardCmd::Initialize( - const Thunk::InitializeParams& params, StateManager& state) { - se::dnn::LazyOpRunner* lazy_runner = - GetOrCreateRunner(params.command_buffer_trace_stream) - .AsFusedMHABackwardRunner(); - TF_ASSIGN_OR_RETURN(auto config, config_.AsDnnFusedMHABackwardOpConfig()); - return lazy_runner - ->GetOrCreateRunner(config, params.command_buffer_trace_stream) - .status(); -} - -absl::Status FusedMHABackwardCmd::Record( - const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, se::CommandBuffer* command_buffer) { - se::dnn::LazyOpRunner* lazy_runner = - GetOrCreateRunner(execute_params.command_buffer_trace_stream) - .AsFusedMHABackwardRunner(); - CHECK(lazy_runner) - << "FusedMHABackward lazy runner cache should have been populated"; - - const auto& buffer_allocations = *execute_params.buffer_allocations; - se::DeviceMemoryBase bmm1_grad_gemm1_rhs_buffer = - buffer_allocations.GetDeviceAddress(bmm1_grad_gemm1_rhs_buffer_); - - se::DeviceMemoryBase bmm1_grad_gemm2_rhs_buffer = - buffer_allocations.GetDeviceAddress(bmm1_grad_gemm2_rhs_buffer_); - - se::DeviceMemoryBase bmm2_grad_gemm1_lhs_buffer = - buffer_allocations.GetDeviceAddress(bmm2_grad_gemm1_lhs_buffer_); - - se::DeviceMemoryBase bmm2_grad_gemm2_rhs_buffer = - buffer_allocations.GetDeviceAddress(bmm2_grad_gemm2_rhs_buffer_); - - se::DeviceMemoryBase d_output_buffer = - buffer_allocations.GetDeviceAddress(d_output_buffer_); - - se::DeviceMemoryBase scratch_buffer = - buffer_allocations.GetDeviceAddress(scratch_buffer_); - - se::DeviceMemoryBase d_bmm1_lhs_buffer = - buffer_allocations.GetDeviceAddress(d_bmm1_lhs_buffer_); - - se::DeviceMemoryBase d_bmm1_rhs_buffer = - buffer_allocations.GetDeviceAddress(d_bmm1_rhs_buffer_); - - se::DeviceMemoryBase d_bmm2_rhs_buffer = - buffer_allocations.GetDeviceAddress(d_bmm2_rhs_buffer_); - - std::optional d_s_buffer = - AssignBufferIfNotNull(buffer_allocations, d_s_buffer_); - std::optional d_bias_buffer = - AssignBufferIfNotNull(buffer_allocations, d_bias_buffer_); - std::optional 
fwd_output_buffer = - AssignBufferIfNotNull(buffer_allocations, fwd_output_buffer_); - std::optional bias_buffer = - AssignBufferIfNotNull(buffer_allocations, bias_buffer_); - std::optional seqlen_q_buffer = - AssignBufferIfNotNull(buffer_allocations, seqlen_q_buffer_); - std::optional seqlen_k_buffer = - AssignBufferIfNotNull(buffer_allocations, seqlen_k_buffer_); - - ExecutionScopeId execution_scope_id = GetExecutionScope(record_params); - VLOG(5) << "FusedMHABackwardCmd with execution_scope_id: " - << execution_scope_id.value(); - VLOG(5) << "bmm1_grad_gemm1_rhs_buffer" - << bmm1_grad_gemm1_rhs_buffer_.ToString(); - VLOG(5) << "bmm1_grad_gemm2_rhs_buffer" - << bmm1_grad_gemm2_rhs_buffer_.ToString(); - VLOG(5) << "bmm2_grad_gemm1_lhs_buffer" - << bmm2_grad_gemm1_lhs_buffer_.ToString(); - VLOG(5) << "bmm2_grad_gemm2_rhs_buffer" - << bmm2_grad_gemm2_rhs_buffer_.ToString(); - VLOG(5) << "d_output_buffer" << d_output_buffer_.ToString(); - VLOG(5) << "scratch_buffer" << scratch_buffer_.ToString(); - VLOG(5) << "d_bmm1_lhs_buffer" << d_bmm1_lhs_buffer_.ToString(); - VLOG(5) << "d_bmm1_rhs_buffer" << d_bmm1_rhs_buffer_.ToString(); - VLOG(5) << "d_bmm2_rhs_buffer" << d_bmm2_rhs_buffer_.ToString(); - VLOG(5) << "d_s_buffer" << d_s_buffer_.ToString(); - VLOG(5) << "d_bias_buffer" << d_bias_buffer_.ToString(); - VLOG(5) << "fwd_output_buffer" << fwd_output_buffer_.ToString(); - VLOG(5) << "bias_buffer" << bias_buffer_.ToString(); - VLOG(5) << "seqlen_q_buffer" << seqlen_q_buffer_.ToString(); - VLOG(5) << "seqlen_k_buffer" << seqlen_k_buffer_.ToString(); - - RunFusedMHABackwardOptions opts; - opts.runner_cache = - &GetOrCreateRunner(execute_params.command_buffer_trace_stream); - return AddTracedCommandBuffer( - execute_params, record_params, command_buffer, [&](se::Stream* stream) { - return RunGpuFMHABackward( - config_, bmm1_grad_gemm1_rhs_buffer, bmm1_grad_gemm2_rhs_buffer, - bmm2_grad_gemm1_lhs_buffer, bmm2_grad_gemm2_rhs_buffer, - d_output_buffer, scratch_buffer, d_bmm1_lhs_buffer, - d_bmm1_rhs_buffer, d_bmm2_rhs_buffer, d_s_buffer, d_bias_buffer, - fwd_output_buffer, bias_buffer, seqlen_q_buffer, seqlen_k_buffer, - stream, opts); - }); -} - -FusedMHABackwardCmd::BufferUsageVector FusedMHABackwardCmd::buffers() { - BufferUsageVector buffer_usage; - buffer_usage.reserve(15); - - buffer_usage.push_back({bmm1_grad_gemm1_rhs_buffer_, MemoryAccess::kRead}); - buffer_usage.push_back({bmm1_grad_gemm2_rhs_buffer_, MemoryAccess::kRead}); - buffer_usage.push_back({bmm2_grad_gemm1_lhs_buffer_, MemoryAccess::kRead}); - buffer_usage.push_back({bmm2_grad_gemm2_rhs_buffer_, MemoryAccess::kRead}); - buffer_usage.push_back({d_output_buffer_, MemoryAccess::kWrite}); - buffer_usage.push_back({scratch_buffer_, MemoryAccess::kWrite}); - buffer_usage.push_back({d_bmm1_lhs_buffer_, MemoryAccess::kRead}); - buffer_usage.push_back({d_bmm1_rhs_buffer_, MemoryAccess::kRead}); - buffer_usage.push_back({d_bmm2_rhs_buffer_, MemoryAccess::kRead}); - - if (d_s_buffer_.allocation() != nullptr) { - buffer_usage.push_back({d_s_buffer_, MemoryAccess::kRead}); - }; - if (d_bias_buffer_.allocation() != nullptr) { - buffer_usage.push_back({d_bias_buffer_, MemoryAccess::kRead}); - }; - if (fwd_output_buffer_.allocation() != nullptr) { - buffer_usage.push_back({fwd_output_buffer_, MemoryAccess::kRead}); - }; - if (bias_buffer_.allocation() != nullptr) { - buffer_usage.push_back({bias_buffer_, MemoryAccess::kRead}); - }; - if (seqlen_q_buffer_.allocation() != nullptr) { - buffer_usage.push_back({seqlen_q_buffer_, MemoryAccess::kRead}); 
- }; - if (seqlen_k_buffer_.allocation() != nullptr) { - buffer_usage.push_back({seqlen_k_buffer_, MemoryAccess::kRead}); - }; - - return buffer_usage; -} - //===----------------------------------------------------------------------===// // CublasLtCmd //===----------------------------------------------------------------------===// diff --git a/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.h b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.h index b7a077e81a9e4f..27e8fea0d86366 100644 --- a/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.h +++ b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.h @@ -40,7 +40,6 @@ limitations under the License. #include "xla/service/buffer_assignment.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/gpu/buffer_allocations.h" -#include "xla/service/gpu/gpu_fused_mha_runner.h" #include "xla/service/gpu/kernels/custom_kernel.h" #include "xla/service/gpu/launch_dimensions.h" #include "xla/service/gpu/matmul_utils.h" @@ -81,8 +80,6 @@ namespace xla::gpu { V(kReduceScatter, "ReduceScatterCmd") \ V(kAllGatherCmd, "AllGatherCmd") \ V(kCollectiveBroadcastCmd, "CollectiveBroadcastCmd") \ - V(kFusedMHACmd, "FusedMHACmd") \ - V(kFusedMHABackwardCmd, "FusedMHABackwardCmd") \ V(kUnknownCmd, "UnknownCmd") \ // clang-format on @@ -782,112 +779,6 @@ class GemmCmd : public TracedCommandBufferCmd { const bool deterministic_; }; -//===----------------------------------------------------------------------===// -// FusedMHACmd -//===----------------------------------------------------------------------===// - -class FusedMHACmd : public TracedCommandBufferCmd { - public: - FusedMHACmd(ExecutionStreamId execution_stream_id, GpufMHAConfig config, - BufferAllocation::Slice lhs_bmm1, - BufferAllocation::Slice rhs_bmm1, - BufferAllocation::Slice rhs_bmm2, BufferAllocation::Slice output, - BufferAllocation::Slice scratch, BufferAllocation::Slice mask, - BufferAllocation::Slice bias, BufferAllocation::Slice activation, - BufferAllocation::Slice seqlen_q, - BufferAllocation::Slice seqlen_k); - - absl::Status Initialize(const Thunk::InitializeParams& params, - StateManager& state) override; - - absl::Status Record(const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, - se::CommandBuffer* command_buffer) override; - - BufferUsageVector buffers() override; - - bool IsNestedCommandBuffer() const final { return true; } - - private: - FusedMultiHeadedAttentionRunner& GetOrCreateRunner( - const stream_executor::Stream* stream); - - const GpufMHAConfig config_; - BufferAllocation::Slice lhs_bmm1_buffer_; - BufferAllocation::Slice rhs_bmm1_buffer_; - BufferAllocation::Slice rhs_bmm2_buffer_; - BufferAllocation::Slice output_buffer_; - BufferAllocation::Slice scratch_buffer_; - BufferAllocation::Slice bias_buffer_; - BufferAllocation::Slice activation_buffer_; - BufferAllocation::Slice seqlen_q_buffer_; - BufferAllocation::Slice seqlen_k_buffer_; - - // FusedMHA config - absl::Mutex mutex_; - absl::flat_hash_map> - runner_cache_ ABSL_GUARDED_BY(mutex_); -}; - -//===----------------------------------------------------------------------===// -// FusedMHABackwardCmd -//===----------------------------------------------------------------------===// - -class FusedMHABackwardCmd : public TracedCommandBufferCmd { - public: - FusedMHABackwardCmd( - ExecutionStreamId execution_stream_id, GpufMHABackwardConfig config, - BufferAllocation::Slice bmm1_grad_gemm1_rhs, - BufferAllocation::Slice bmm1_grad_gemm2_rhs, - 
BufferAllocation::Slice bmm2_grad_gemm1_lhs, - BufferAllocation::Slice bmm2_grad_gemm2_rhs, - BufferAllocation::Slice d_output, BufferAllocation::Slice scratch, - BufferAllocation::Slice d_bmm1_lhs, BufferAllocation::Slice d_bmm1_rhs, - BufferAllocation::Slice d_bmm2_rhs, BufferAllocation::Slice d_s, - BufferAllocation::Slice d_bias, BufferAllocation::Slice fwd_output, - BufferAllocation::Slice bias, BufferAllocation::Slice seqlen_q, - BufferAllocation::Slice seqlen_k); - - absl::Status Initialize(const Thunk::InitializeParams& params, - StateManager& state) override; - - absl::Status Record(const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, - se::CommandBuffer* command_buffer) override; - - BufferUsageVector buffers() override; - - bool IsNestedCommandBuffer() const final { return true; } - - private: - FusedMultiHeadedAttentionBackwardRunner& GetOrCreateRunner( - const stream_executor::Stream* stream); - - const GpufMHABackwardConfig config_; - BufferAllocation::Slice bmm1_grad_gemm1_rhs_buffer_; - BufferAllocation::Slice bmm1_grad_gemm2_rhs_buffer_; - BufferAllocation::Slice bmm2_grad_gemm1_lhs_buffer_; - BufferAllocation::Slice bmm2_grad_gemm2_rhs_buffer_; - BufferAllocation::Slice d_output_buffer_; - BufferAllocation::Slice scratch_buffer_; - BufferAllocation::Slice d_bmm1_lhs_buffer_; - BufferAllocation::Slice d_bmm1_rhs_buffer_; - BufferAllocation::Slice d_bmm2_rhs_buffer_; - BufferAllocation::Slice d_s_buffer_; - BufferAllocation::Slice d_bias_buffer_; - BufferAllocation::Slice fwd_output_buffer_; - BufferAllocation::Slice bias_buffer_; - BufferAllocation::Slice seqlen_q_buffer_; - BufferAllocation::Slice seqlen_k_buffer_; - - // FusedMHA config - absl::Mutex mutex_; - absl::flat_hash_map> - runner_cache_ ABSL_GUARDED_BY(mutex_); -}; - //===----------------------------------------------------------------------===// // CublasLtCmd //===----------------------------------------------------------------------===// diff --git a/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd_emitter.cc b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd_emitter.cc index 54e01fab8e1109..230d050856fcc2 100644 --- a/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd_emitter.cc +++ b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd_emitter.cc @@ -29,7 +29,6 @@ limitations under the License. 
#include "xla/service/gpu/runtime/copy_thunk.h" #include "xla/service/gpu/runtime/cudnn_thunk.h" #include "xla/service/gpu/runtime/custom_call_thunk.h" -#include "xla/service/gpu/runtime/fused_mha_thunk.h" #include "xla/service/gpu/runtime/gemm_thunk.h" #include "xla/service/gpu/runtime/gpublas_lt_matmul_thunk.h" #include "xla/service/gpu/runtime/kernel_thunk.h" @@ -143,27 +142,6 @@ static absl::StatusOr Convert(const CublasLtMatmulThunk& thunk) { thunk.workspace().value()); } -static absl::StatusOr Convert(const FusedMHAThunk& thunk) { - return std::make_unique( - thunk.execution_stream_id(), thunk.config(), thunk.lhs_bmm1_buffer(), - thunk.rhs_bmm1_buffer(), thunk.rhs_bmm2_buffer(), thunk.output_buffer(), - thunk.scratch_buffer(), BufferAllocation::Slice(), thunk.bias_buffer(), - thunk.activation_buffer(), thunk.seqlen_q_buffer(), - thunk.seqlen_k_buffer()); -} - -static absl::StatusOr Convert(const FusedMHABackwardThunk& thunk) { - return std::make_unique( - thunk.execution_stream_id(), thunk.config(), - thunk.bmm1_grad_gemm1_rhs_buffer(), thunk.bmm1_grad_gemm2_rhs_buffer(), - thunk.bmm2_grad_gemm1_lhs_buffer(), thunk.bmm2_grad_gemm2_rhs_buffer(), - thunk.d_output_buffer(), thunk.scratch_buffer(), - thunk.d_bmm1_lhs_buffer(), thunk.d_bmm1_rhs_buffer(), - thunk.d_bmm2_rhs_buffer(), thunk.d_s_buffer(), thunk.d_bias_buffer(), - thunk.fwd_output_buffer(), thunk.bias_buffer(), thunk.seqlen_q_buffer(), - thunk.seqlen_k_buffer()); -} - static absl::StatusOr Convert( const ConditionalThunk& thunk, CommandBufferCmdSequence::SynchronizationMode synchronization_mode) { @@ -276,10 +254,6 @@ static absl::Status AppendCommands( return append(Convert(thunk)); case Thunk::Kind::kCustomKernel: return append(Convert(thunk)); - case Thunk::Kind::kFusedMHA: - return append(Convert(thunk)); - case Thunk::Kind::kFusedMHABackward: - return append(Convert(thunk)); case Thunk::Kind::kKernel: return append(Convert(thunk)); case Thunk::Kind::kGemm: diff --git a/third_party/xla/xla/service/gpu/transforms/BUILD b/third_party/xla/xla/service/gpu/transforms/BUILD index 85c4df22b8b215..b52668fbee028c 100644 --- a/third_party/xla/xla/service/gpu/transforms/BUILD +++ b/third_party/xla/xla/service/gpu/transforms/BUILD @@ -1124,9 +1124,9 @@ xla_cc_test( # TODO(b/358278858): Currently lacking test coverage. cc_library( - name = "cudnn_workspace_rewriter", - srcs = if_cuda_is_configured(["cudnn_workspace_rewriter.cc"]), - hdrs = if_cuda_is_configured(["cudnn_workspace_rewriter.h"]), + name = "cudnn_custom_call_compiler", + srcs = if_cuda_is_configured(["cudnn_custom_call_compiler.cc"]), + hdrs = if_cuda_is_configured(["cudnn_custom_call_compiler.h"]), deps = if_cuda_is_configured([ "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", @@ -1141,9 +1141,9 @@ cc_library( "//xla:util", "//xla/hlo/ir:hlo", "//xla/service:hlo_pass", + "//xla/service/gpu/runtime:cudnn_thunk", "//xla/service/gpu:backend_configs_cc", "//xla/service/gpu:cublas_cudnn", - "//xla/service/gpu:gpu_fused_mha_runner", "//xla/service/gpu:ir_emission_utils", "//xla/service/gpu:stream_executor_util", "//xla/stream_executor:dnn", diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_custom_call_compiler.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_custom_call_compiler.cc new file mode 100644 index 00000000000000..00b73c9112b40b --- /dev/null +++ b/third_party/xla/xla/service/gpu/transforms/cudnn_custom_call_compiler.cc @@ -0,0 +1,660 @@ +/* Copyright 2024 The TensorFlow Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/transforms/cudnn_custom_call_compiler.h" + +#include +#include + +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_clone_context.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/gpu/cublas_cudnn.h" +#include "xla/service/gpu/ir_emission_utils.h" +#include "xla/service/gpu/stream_executor_util.h" +#include "xla/stream_executor/cuda/cuda_dnn.h" +#include "xla/stream_executor/cuda/cudnn_frontend_helpers.h" +#include "xla/util.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace gpu { + +namespace { + +inline absl::StatusOr AsCudnnFmhaMaskKind( + CudnnfMHABackendConfig_MaskType mask_type) { + switch (mask_type) { + case CudnnfMHABackendConfig::NO_MASK: + return CudnnfMHAMaskKind::kNoMask; + case CudnnfMHABackendConfig::PADDING: + return CudnnfMHAMaskKind::kPadding; + case CudnnfMHABackendConfig::CAUSAL: + return CudnnfMHAMaskKind::kCausal; + case CudnnfMHABackendConfig::PADDING_CAUSAL: + return CudnnfMHAMaskKind::kPaddingCausal; + case CudnnfMHABackendConfig::ALIBI: + return CudnnfMHAMaskKind::kAlibi; + default: + return xla::Internal("Unknown fmha mask kind."); + } +} + +// This is an interim structure to hold the parameters to construct a +// GpufMHAConfig. +// Struct to describe properties of a FMHA without being tied to specific +// IR. It is used here only to build the cuDNN graph for the corresponding +// fMHA custom call.
+struct GpufMHADescriptor { + CudnnfMHAKind kind; + CudnnfMHABackendConfig backend_config; + CudnnfMHAMaskKind mask_type; + Shape lhs_bmm1_shape; + Shape rhs_bmm1_shape; + Shape rhs_bmm2_shape; + Shape intermediate_lhs_bmm2_shape; + // This will contain both output shape and activation shape + absl::InlinedVector output_shapes; + DotDimensionNumbers bmm1_dnums; + DotDimensionNumbers bmm2_dnums; + + std::optional mask_shape; + std::optional bias_shape; +}; + +struct GpufMHABackwardDescriptor { + CudnnfMHAKind kind; + CudnnfMHABackendConfig backend_config; + CudnnfMHAMaskKind mask_type; + Shape bmm1_grad_gemm1_rhs_shape; + Shape bmm1_grad_gemm2_rhs_shape; + Shape bmm2_grad_gemm1_lhs_shape; + Shape bmm2_grad_gemm2_rhs_shape; + Shape d_output_shape; + Shape d_bmm1_lhs_shape; + Shape d_bmm1_rhs_shape; + Shape d_bmm2_rhs_shape; + DotDimensionNumbers bmm1_grad_gemm1_dnums; + DotDimensionNumbers bmm1_grad_gemm2_dnums; + DotDimensionNumbers bmm2_grad_gemm1_dnums; + DotDimensionNumbers bmm2_grad_gemm2_dnums; + + std::optional d_s_shape; + std::optional fwd_output_shape; + std::optional mask_shape; + std::optional d_bias_shape; + std::optional bias_shape; + bool force_deterministic; +}; + +// Structure to describe static properties of a GPU fused Multi-Headed +// Attention. +struct GpufMHAConfig { + static absl::StatusOr For(const GpufMHADescriptor &fmha_desc); + PrimitiveType + input_type; // Capture the primitive type of one of the inputs of BMM1 + PrimitiveType output_type; + CudnnfMHAKind kind; + std::optional fmha_scale; + std::optional dropout_rate; + std::optional seed; + + se::dnn::AlgorithmDesc algorithm; + CudnnfMHAMaskKind mask_type; + // bias -> [1, num_attn_heads, q_seq_len, kv_seq_len] + // mask -> [batch_size, 1, q_seq_len, kv_seq_len] + se::dnn::MatmulTensorDescriptor lhs_bmm1; + se::dnn::MatmulTensorDescriptor rhs_bmm1; + se::dnn::MatmulTensorDescriptor rhs_bmm2; + se::dnn::MatmulTensorDescriptor intermediate_lhs_bmm2; + se::dnn::TensorDescriptor output; + + std::optional activation; + std::optional mask; + std::optional bias; +}; + +// Structure to describe static properties of a GPU fused Multi-Headed +// Attention backward. +struct GpufMHABackwardConfig { + static absl::StatusOr For( + const GpufMHABackwardDescriptor &fmha_desc); + PrimitiveType + input_type; // Capture the primitive type of one of the inputs of BMM1 + PrimitiveType output_type; + CudnnfMHAKind kind; + std::optional fmha_scale; + std::optional dropout_rate; + std::optional seed; + + se::dnn::AlgorithmDesc algorithm; + CudnnfMHAMaskKind mask_type; + // mask -> [batch_size, 1, q_seq_len, kv_seq_len] + // d_bias -> [1, num_heads, q_seq_len, kv_seq_len] + se::dnn::MatmulTensorDescriptor bmm1_grad_gemm1_rhs; + se::dnn::MatmulTensorDescriptor bmm1_grad_gemm2_rhs; + se::dnn::MatmulTensorDescriptor bmm2_grad_gemm1_lhs; + se::dnn::MatmulTensorDescriptor bmm2_grad_gemm2_rhs; + se::dnn::MatmulTensorDescriptor d_output; + se::dnn::TensorDescriptor d_bmm1_lhs; + se::dnn::TensorDescriptor d_bmm1_rhs; + se::dnn::TensorDescriptor d_bmm2_rhs; + std::optional d_s; + std::optional mask; + std::optional d_bias; + std::optional fwd_output; + std::optional bias; +}; + +using se::DeviceMemory; +using se::DeviceMemoryBase; +using se::dnn::DataType; +using se::dnn::MatmulTensorDescriptor; +using se::dnn::TensorDescriptor; + +/*static*/ absl::StatusOr GpufMHAConfig::For( + const GpufMHADescriptor &desc) { + // Get shapes from desc. 
+ const Shape &lhs_bmm1_shape = desc.lhs_bmm1_shape; + const Shape &rhs_bmm1_shape = desc.rhs_bmm1_shape; + const Shape &rhs_bmm2_shape = desc.rhs_bmm2_shape; + const Shape &intermediate_lhs_bmm2_shape = desc.intermediate_lhs_bmm2_shape; + const Shape &output_shape = desc.output_shapes[0]; + + // Get DNN dtype from primtive types + TF_ASSIGN_OR_RETURN( + DataType lhs_bmm1_type, + GetDNNDataTypeFromPrimitiveType(lhs_bmm1_shape.element_type())); + TF_ASSIGN_OR_RETURN( + DataType rhs_bmm1_type, + GetDNNDataTypeFromPrimitiveType(rhs_bmm1_shape.element_type())); + + TF_ASSIGN_OR_RETURN( + DataType rhs_bmm2_type, + GetDNNDataTypeFromPrimitiveType(rhs_bmm2_shape.element_type())); + TF_ASSIGN_OR_RETURN(DataType lhs_bmm2_type, + GetDNNDataTypeFromPrimitiveType( + intermediate_lhs_bmm2_shape.element_type())); + TF_ASSIGN_OR_RETURN(DataType output_type, GetDNNDataTypeFromPrimitiveType( + output_shape.element_type())); + GpufMHAConfig config; + config.input_type = lhs_bmm1_shape.element_type(); + config.output_type = output_shape.element_type(); + + // Get MatmulTensorDescriptors for BMM1 + config.lhs_bmm1 = + MatmulTensorDescriptor::For(lhs_bmm1_type, lhs_bmm1_shape.dimensions(), + desc.lhs_bmm1_shape.layout().minor_to_major(), + desc.bmm1_dnums.lhs_batch_dimensions(), + desc.bmm1_dnums.lhs_contracting_dimensions()); + config.rhs_bmm1 = + MatmulTensorDescriptor::For(rhs_bmm1_type, rhs_bmm1_shape.dimensions(), + desc.rhs_bmm1_shape.layout().minor_to_major(), + desc.bmm1_dnums.rhs_batch_dimensions(), + desc.bmm1_dnums.rhs_contracting_dimensions()); + + // Get MatmulTensorDescriptors for BMM2 + config.rhs_bmm2 = + MatmulTensorDescriptor::For(rhs_bmm2_type, rhs_bmm2_shape.dimensions(), + desc.rhs_bmm2_shape.layout().minor_to_major(), + desc.bmm2_dnums.rhs_batch_dimensions(), + desc.bmm2_dnums.rhs_contracting_dimensions()); + + config.intermediate_lhs_bmm2 = MatmulTensorDescriptor::For( + lhs_bmm2_type, intermediate_lhs_bmm2_shape.dimensions(), + desc.intermediate_lhs_bmm2_shape.layout().minor_to_major(), + desc.bmm2_dnums.lhs_batch_dimensions(), + desc.bmm2_dnums.lhs_contracting_dimensions()); + + config.output = TensorDescriptor::For(output_type, output_shape.dimensions(), + output_shape.layout().minor_to_major()); + + if (desc.output_shapes.size() > 1) { + const Shape &activation_shape = desc.output_shapes.back(); + // Generally, activation should have same type as output, but set it + // explicityly just to be safe. 
+ TF_ASSIGN_OR_RETURN( + DataType activation_type, + GetDNNDataTypeFromPrimitiveType(activation_shape.element_type())); + config.activation = + TensorDescriptor::For(activation_type, activation_shape.dimensions(), + activation_shape.layout().minor_to_major()); + } + + if (desc.mask_shape) { + const Shape &mask_shape = *desc.mask_shape; + TF_ASSIGN_OR_RETURN(DataType mask_type, GetDNNDataTypeFromPrimitiveType( + mask_shape.element_type())); + config.mask = TensorDescriptor::For(mask_type, mask_shape.dimensions(), + mask_shape.layout().minor_to_major()); + } + + if (desc.bias_shape) { + const Shape &bias_shape = *desc.bias_shape; + TF_ASSIGN_OR_RETURN(DataType bias_type, GetDNNDataTypeFromPrimitiveType( + bias_shape.element_type())); + config.bias = TensorDescriptor::For(bias_type, bias_shape.dimensions(), + bias_shape.layout().minor_to_major()); + } + config.kind = desc.kind; + config.mask_type = desc.mask_type; + const CudnnfMHABackendConfig &backend_config = desc.backend_config; + config.algorithm = se::dnn::AlgorithmDesc(backend_config.algorithm()); + config.fmha_scale.emplace(backend_config.fmha_scale()); + config.dropout_rate.emplace(backend_config.dropout_rate()); + config.seed.emplace(backend_config.seed()); + return config; +} + +/*static*/ absl::StatusOr GpufMHABackwardConfig::For( + const GpufMHABackwardDescriptor &desc) { + // Get shapes from desc. + const Shape &bmm1_grad_gemm1_rhs_shape = desc.bmm1_grad_gemm1_rhs_shape; + const Shape &bmm1_grad_gemm2_rhs_shape = desc.bmm1_grad_gemm2_rhs_shape; + const Shape &bmm2_grad_gemm1_lhs_shape = desc.bmm2_grad_gemm1_lhs_shape; + const Shape &bmm2_grad_gemm2_rhs_shape = desc.bmm2_grad_gemm2_rhs_shape; + const Shape &d_output_shape = desc.d_output_shape; + const Shape &d_bmm1_lhs_shape = desc.d_bmm1_lhs_shape; + const Shape &d_bmm1_rhs_shape = desc.d_bmm1_rhs_shape; + const Shape &d_bmm2_rhs_shape = desc.d_bmm2_rhs_shape; + // Get DNN dtype from primtive types + TF_ASSIGN_OR_RETURN(DataType bmm1_grad_gemm1_rhs_type, + GetDNNDataTypeFromPrimitiveType( + bmm1_grad_gemm1_rhs_shape.element_type())); + + TF_ASSIGN_OR_RETURN(DataType bmm1_grad_gemm2_rhs_type, + GetDNNDataTypeFromPrimitiveType( + bmm1_grad_gemm2_rhs_shape.element_type())); + + TF_ASSIGN_OR_RETURN(DataType bmm2_grad_gemm1_lhs_type, + GetDNNDataTypeFromPrimitiveType( + bmm2_grad_gemm1_lhs_shape.element_type())); + + TF_ASSIGN_OR_RETURN(DataType bmm2_grad_gemm2_rhs_type, + GetDNNDataTypeFromPrimitiveType( + bmm2_grad_gemm2_rhs_shape.element_type())); + + TF_ASSIGN_OR_RETURN( + DataType d_output_type, + GetDNNDataTypeFromPrimitiveType(d_output_shape.element_type())); + + TF_ASSIGN_OR_RETURN( + DataType d_bmm1_lhs_type, + GetDNNDataTypeFromPrimitiveType(d_bmm1_lhs_shape.element_type())); + + TF_ASSIGN_OR_RETURN( + DataType d_bmm1_rhs_type, + GetDNNDataTypeFromPrimitiveType(d_bmm1_rhs_shape.element_type())); + + TF_ASSIGN_OR_RETURN( + DataType d_bmm2_rhs_type, + GetDNNDataTypeFromPrimitiveType(d_bmm2_rhs_shape.element_type())); + + GpufMHABackwardConfig config; + config.input_type = bmm1_grad_gemm1_rhs_shape.element_type(); + config.output_type = d_bmm1_lhs_shape.element_type(); + + // Get MatmulTensorDescriptors for lhs of BMM1 grad GEMM 1 + config.bmm1_grad_gemm1_rhs = MatmulTensorDescriptor::For( + bmm1_grad_gemm1_rhs_type, bmm1_grad_gemm1_rhs_shape.dimensions(), + desc.bmm1_grad_gemm1_rhs_shape.layout().minor_to_major(), + desc.bmm1_grad_gemm1_dnums.rhs_batch_dimensions(), + desc.bmm1_grad_gemm1_dnums.rhs_contracting_dimensions()); + + // Get MatmulTensorDescriptors for rhs of BMM1 
grad GEMM 2 + config.bmm1_grad_gemm2_rhs = MatmulTensorDescriptor::For( + bmm1_grad_gemm2_rhs_type, bmm1_grad_gemm2_rhs_shape.dimensions(), + desc.bmm1_grad_gemm2_rhs_shape.layout().minor_to_major(), + desc.bmm1_grad_gemm2_dnums.rhs_batch_dimensions(), + desc.bmm1_grad_gemm2_dnums.rhs_contracting_dimensions()); + + // Get MatmulTensorDescriptors for BMM2 grad GEMM 1 + config.bmm2_grad_gemm1_lhs = MatmulTensorDescriptor::For( + bmm2_grad_gemm1_lhs_type, bmm2_grad_gemm1_lhs_shape.dimensions(), + desc.bmm2_grad_gemm1_lhs_shape.layout().minor_to_major(), + desc.bmm2_grad_gemm1_dnums.lhs_batch_dimensions(), + desc.bmm2_grad_gemm1_dnums.lhs_contracting_dimensions()); + + config.d_output = MatmulTensorDescriptor::For( + d_output_type, d_output_shape.dimensions(), + desc.d_output_shape.layout().minor_to_major(), + desc.bmm2_grad_gemm1_dnums.rhs_batch_dimensions(), + desc.bmm2_grad_gemm1_dnums.rhs_contracting_dimensions()); + + // Get MatmulTensorDescriptors for BMM2 grad GEMM 2 + config.bmm2_grad_gemm2_rhs = MatmulTensorDescriptor::For( + bmm2_grad_gemm2_rhs_type, bmm2_grad_gemm2_rhs_shape.dimensions(), + desc.bmm2_grad_gemm2_rhs_shape.layout().minor_to_major(), + desc.bmm2_grad_gemm2_dnums.rhs_batch_dimensions(), + desc.bmm2_grad_gemm2_dnums + .rhs_contracting_dimensions()); // FMHA TODO: transpose here? + + config.d_bmm1_lhs = + TensorDescriptor::For(d_bmm1_lhs_type, d_bmm1_lhs_shape.dimensions(), + d_bmm1_lhs_shape.layout().minor_to_major()); + config.d_bmm1_rhs = + TensorDescriptor::For(d_bmm1_rhs_type, d_bmm1_rhs_shape.dimensions(), + d_bmm1_rhs_shape.layout().minor_to_major()); + config.d_bmm2_rhs = + TensorDescriptor::For(d_bmm2_rhs_type, d_bmm2_rhs_shape.dimensions(), + d_bmm2_rhs_shape.layout().minor_to_major()); + config.d_s = TensorDescriptor::For( + bmm2_grad_gemm1_lhs_type, bmm2_grad_gemm1_lhs_shape.dimensions(), + bmm2_grad_gemm1_lhs_shape.layout().minor_to_major()); + + if (desc.d_bias_shape) { + const Shape &d_bias_shape = *desc.d_bias_shape; + // Get DNN dtype from primtive types + TF_ASSIGN_OR_RETURN(DataType d_bias_type, GetDNNDataTypeFromPrimitiveType( + d_bias_shape.element_type())); + config.d_bias = + TensorDescriptor::For(d_bias_type, d_bias_shape.dimensions(), + d_bias_shape.layout().minor_to_major()); + } + + if (desc.mask_shape) { + const Shape &mask_shape = *desc.mask_shape; + TF_ASSIGN_OR_RETURN(DataType mask_type, GetDNNDataTypeFromPrimitiveType( + mask_shape.element_type())); + config.mask = TensorDescriptor::For(mask_type, mask_shape.dimensions(), + mask_shape.layout().minor_to_major()); + } + if (desc.fwd_output_shape) { + const Shape &fwd_output_shape = *desc.fwd_output_shape; + TF_ASSIGN_OR_RETURN( + DataType fwd_output_type, + GetDNNDataTypeFromPrimitiveType(fwd_output_shape.element_type())); + config.fwd_output = + TensorDescriptor::For(fwd_output_type, fwd_output_shape.dimensions(), + fwd_output_shape.layout().minor_to_major()); + } + + if (desc.bias_shape) { + const Shape &bias_shape = *desc.bias_shape; + TF_ASSIGN_OR_RETURN(DataType bias_type, GetDNNDataTypeFromPrimitiveType( + bias_shape.element_type())); + config.bias = TensorDescriptor::For(bias_type, bias_shape.dimensions(), + bias_shape.layout().minor_to_major()); + } + + config.kind = desc.kind; + config.mask_type = desc.mask_type; + const CudnnfMHABackendConfig &backend_config = desc.backend_config; + config.algorithm = se::dnn::AlgorithmDesc(backend_config.algorithm()); + config.fmha_scale.emplace(backend_config.fmha_scale()); + config.dropout_rate.emplace(backend_config.dropout_rate()); + 
config.seed.emplace(backend_config.seed()); + return config; +} + +absl::StatusOr HloCustomCallToCuDnnGraph( + se::dnn::DnnSupport &dnn_support, HloCustomCallInstruction *custom_call) { + if (IsFwdCustomCallTofMHA(*custom_call)) { + TF_ASSIGN_OR_RETURN(const xla::gpu::CudnnfMHAKind kind, + xla::gpu::GetCudnnfMHAKind(custom_call)); + std::optional mask_shape, bias_shape; + { + bool has_bias = kind == CudnnfMHAKind::kScaleBiasSoftmax || + kind == CudnnfMHAKind::kScaleBiasSoftmaxDropout; + + if (has_bias) { + const HloInstruction *bias = custom_call->operand(3); + bias_shape = bias->shape(); + } + } + + TF_ASSIGN_OR_RETURN( + const auto gpu_config, + custom_call->backend_config()); + const xla::gpu::CudnnfMHABackendConfig &config = + gpu_config.cudnn_fmha_backend_config(); + Shape intermediate_tensor_shape(config.intermediate_tensor_shape()); + absl::InlinedVector output_shapes = { + ShapeUtil::GetSubshape(custom_call->shape(), {0})}; + + bool has_activation = + xla::ShapeUtil::TupleElementCount(custom_call->shape()) == 3; + if (has_activation) { + output_shapes.push_back( + ShapeUtil::GetSubshape(custom_call->shape(), {1})); + } + + Shape q_shape = custom_call->operand(0)->shape(); + Shape k_shape = custom_call->operand(1)->shape(); + Shape v_shape = custom_call->operand(2)->shape(); + TF_ASSIGN_OR_RETURN(CudnnfMHAMaskKind cudnn_mask_type, + AsCudnnFmhaMaskKind(config.mask_type())); + GpufMHADescriptor descriptor = {kind, + config, + cudnn_mask_type, + q_shape, + k_shape, + v_shape, + intermediate_tensor_shape, + output_shapes, + config.bmm1_dot_dimension_numbers(), + config.bmm2_dot_dimension_numbers(), + mask_shape, + bias_shape}; + + TF_ASSIGN_OR_RETURN(GpufMHAConfig fmha_config, + GpufMHAConfig::For(descriptor)); + TF_ASSIGN_OR_RETURN( + se::dnn::FMHAMaskKind dnn_mask_type, + GetDNNFmhaMaskKindFromCudnnFmhaMaskKind(fmha_config.mask_type)); + TF_ASSIGN_OR_RETURN( + se::gpu::CudnnGraph graph, + se::gpu::GetCudnnFlashAttentionOperationGraph( + dnn_support, fmha_config.lhs_bmm1, fmha_config.rhs_bmm1, + fmha_config.rhs_bmm2, fmha_config.output, fmha_config.bias, + fmha_config.activation, static_cast(*fmha_config.fmha_scale), + fmha_config.dropout_rate && *fmha_config.dropout_rate > 0.0, + fmha_config.dropout_rate, dnn_mask_type)); + return std::move(graph); + } else { + TF_ASSIGN_OR_RETURN( + auto gpu_config, + custom_call->backend_config()); + xla::gpu::CudnnfMHABackendConfig &config = + *gpu_config.mutable_cudnn_fmha_backend_config(); + + int input_index = 0; + Shape bmm1_grad_gemm1_rhs_shape = + custom_call->operand(input_index++)->shape(); + Shape bmm1_grad_gemm2_rhs_shape = + custom_call->operand(input_index++)->shape(); + Shape bmm2_grad_gemm2_rhs_shape = + custom_call->operand(input_index++)->shape(); + Shape bmm2_grad_gemm1_lhs_shape(config.intermediate_tensor_shape()); + input_index++; + Shape d_output_shape = custom_call->operand(input_index++)->shape(); + + TF_ASSIGN_OR_RETURN(const CudnnfMHAKind kind, + GetCudnnfMHAKind(custom_call)); + std::optional mask_shape; + + bool has_bias = (kind == CudnnfMHAKind::kBackwardScaleBiasSoftmax || + kind == CudnnfMHAKind::kBackwardScaleBiasSoftmaxDropout); + std::optional bias_shape; + if (has_bias) { + bias_shape = custom_call->operand(input_index++)->shape(); + } + + std::optional fwd_output_shape = + custom_call->operand(input_index++)->shape(); + if (config.mask_type() == xla::gpu::CudnnfMHABackendConfig::PADDING || + config.mask_type() == + xla::gpu::CudnnfMHABackendConfig::PADDING_CAUSAL) { + // skip q_seqlen and kv_seqlen + input_index 
+= 2; + } + TF_RET_CHECK(input_index == custom_call->operand_count()); + + int output_index = 0; + Shape d_bmm1_lhs_shape = + ShapeUtil::GetSubshape(custom_call->shape(), {output_index++}); + Shape d_bmm1_rhs_shape = + ShapeUtil::GetSubshape(custom_call->shape(), {output_index++}); + Shape d_bmm2_rhs_shape = + ShapeUtil::GetSubshape(custom_call->shape(), {output_index++}); + std::optional d_s_shape; + std::optional d_bias_shape; + bool has_dbias = custom_call->shape().tuple_shapes().size() == 5; + if (has_dbias) { + d_bias_shape = + ShapeUtil::GetSubshape(custom_call->shape(), {output_index++}); + } + // The last one is the workspace. + TF_RET_CHECK(output_index == + custom_call->shape().tuple_shapes().size() - 1); + TF_ASSIGN_OR_RETURN(CudnnfMHAMaskKind cudnn_mask_type, + AsCudnnFmhaMaskKind(config.mask_type())); + + const DebugOptions &debug_options = + custom_call->GetModule()->config().debug_options(); + bool force_deterministic = + debug_options.xla_gpu_deterministic_ops() || + debug_options.xla_gpu_exclude_nondeterministic_ops(); + config.set_force_deterministic(force_deterministic); + TF_RETURN_IF_ERROR(custom_call->set_backend_config(gpu_config)); + + GpufMHABackwardDescriptor descriptor = { + kind, + config, + cudnn_mask_type, + bmm1_grad_gemm1_rhs_shape, + bmm1_grad_gemm2_rhs_shape, + bmm2_grad_gemm1_lhs_shape, + bmm2_grad_gemm2_rhs_shape, + d_output_shape, + d_bmm1_lhs_shape, + d_bmm1_rhs_shape, + d_bmm2_rhs_shape, + config.bmm1_grad_gemm1_dot_dimension_numbers(), + config.bmm1_grad_gemm2_dot_dimension_numbers(), + config.bmm2_grad_gemm1_dot_dimension_numbers(), + config.bmm2_grad_gemm2_dot_dimension_numbers(), + d_s_shape, + fwd_output_shape, + mask_shape, + d_bias_shape, + bias_shape, + force_deterministic}; + + TF_ASSIGN_OR_RETURN(GpufMHABackwardConfig fmha_config, + GpufMHABackwardConfig::For(descriptor)); + TF_ASSIGN_OR_RETURN( + se::dnn::FMHAMaskKind dnn_mask_type, + GetDNNFmhaMaskKindFromCudnnFmhaMaskKind(fmha_config.mask_type)); + + TF_ASSIGN_OR_RETURN( + se::gpu::CudnnGraph graph, + se::gpu::GetCudnnFlashAttentionBackwardOperationGraph( + dnn_support, fmha_config.bmm1_grad_gemm1_rhs, + fmha_config.bmm1_grad_gemm2_rhs, fmha_config.bmm2_grad_gemm1_lhs, + fmha_config.bmm2_grad_gemm2_rhs, fmha_config.d_output, + fmha_config.d_bmm1_lhs, fmha_config.d_bmm1_rhs, + fmha_config.d_bmm2_rhs, fmha_config.bias, fmha_config.dropout_rate, + fmha_config.seed, *fmha_config.fmha_scale, + fmha_config.dropout_rate && *fmha_config.dropout_rate > 0.0, + fmha_config.bias != std::nullopt, dnn_mask_type, + force_deterministic)); + return std::move(graph); + } +} + +class CuDnnCustomCallVisitor : public DfsHloRewriteVisitor { + public: + explicit CuDnnCustomCallVisitor(se::dnn::DnnSupport &dnn_support, + BinaryMap &compilation_results) + : dnn_support_(dnn_support), compilation_results_(compilation_results) {} + + void AddWorkspace(HloInstruction &hlo, int64_t workspace_size) { + if (workspace_size == 0) { + return; + } + VLOG(4) << "Applying workspace size " << workspace_size << " to " + << hlo.ToString(); + Shape *shape = hlo.mutable_shape(); + shape->mutable_tuple_shapes()->back().set_dimensions(0, workspace_size); + } + + absl::Status HandleCustomCall(HloInstruction *hlo) override { + if (!IsCustomCallTofMHA(*hlo)) { + return absl::OkStatus(); + } + + TF_ASSIGN_OR_RETURN(const std::string fingerprint_without_workspace, + FingerprintWithBackendConfig(*hlo)); + auto workspace_size_it = + workspace_sizes_.find(fingerprint_without_workspace); + if (workspace_size_it == workspace_sizes_.cend()) { 
+ TF_ASSIGN_OR_RETURN( + se::gpu::CudnnGraph graph, + HloCustomCallToCuDnnGraph(dnn_support_, + DynCast(hlo))); + + const int64_t workspace_size = graph.Graph().get_workspace_size(); + workspace_sizes_.insert(workspace_size_it, + {fingerprint_without_workspace, workspace_size}); + AddWorkspace(*hlo, workspace_size); + + std::vector serialized_graph; + RETURN_IF_CUDNN_FRONTEND_ERROR(graph.Graph().serialize(serialized_graph)); + // Compute a new fingerprint with a potential workspace for the + // compilation results to match a fingerprint computed by the emitter. + TF_ASSIGN_OR_RETURN(const std::string fingerprint_with_workspace, + FingerprintWithBackendConfig(*hlo)); + compilation_results_[fingerprint_with_workspace] = + std::string(reinterpret_cast(serialized_graph.data()), + serialized_graph.size()); + } else { + VLOG(4) << "Cache hit."; + AddWorkspace(*hlo, workspace_size_it->second); + } + + MarkAsChanged(); + return absl::OkStatus(); + } + + private: + se::dnn::DnnSupport &dnn_support_; + BinaryMap &compilation_results_; + absl::flat_hash_map workspace_sizes_; +}; + +} // namespace + +absl::StatusOr CuDnnCustomCallCompiler::Run( + HloModule *module, + const absl::flat_hash_set &execution_threads) { + XLA_SCOPED_LOGGING_TIMER_LEVEL("cuDNN custom call compiler", 8); + return CuDnnCustomCallVisitor(dnn_support_, compilation_results_) + .RunOnModule(module, execution_threads); +} + +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_workspace_rewriter.h b/third_party/xla/xla/service/gpu/transforms/cudnn_custom_call_compiler.h similarity index 61% rename from third_party/xla/xla/service/gpu/transforms/cudnn_workspace_rewriter.h rename to third_party/xla/xla/service/gpu/transforms/cudnn_custom_call_compiler.h index 962841289b58dc..810286f91b8472 100644 --- a/third_party/xla/xla/service/gpu/transforms/cudnn_workspace_rewriter.h +++ b/third_party/xla/xla/service/gpu/transforms/cudnn_custom_call_compiler.h @@ -13,13 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_TRANSFORMS_CUDNN_WORKSPACE_REWRITER_H_ -#define XLA_SERVICE_GPU_TRANSFORMS_CUDNN_WORKSPACE_REWRITER_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_CUDNN_CUSTOM_CALL_COMPILER_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_CUDNN_CUSTOM_CALL_COMPILER_H_ #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/hlo_pass_interface.h" #include "xla/stream_executor/dnn.h" #include "xla/stream_executor/stream_executor.h" @@ -27,14 +28,18 @@ limitations under the License. namespace xla { namespace gpu { -// Rewrite cuDNN custom call to have correct workspace size by build graph -// and serialize so we can use it later -class CuDnnWorkspaceRewriter : public HloModulePass { +// Compile cuDNN custom calls to binaries and serialize them. +// Also adjust them in HLO to have correct workspace size. 
+class CuDnnCustomCallCompiler : public HloModulePass { public: - explicit CuDnnWorkspaceRewriter(se::StreamExecutor& stream_exec) - : dnn_support_(*stream_exec.AsDnn()) {} + explicit CuDnnCustomCallCompiler(se::StreamExecutor& stream_exec, + BinaryMap& compilation_results) + : dnn_support_(*stream_exec.AsDnn()), + compilation_results_(compilation_results) {} - absl::string_view name() const override { return "cudnn-workspace-rewriter"; } + absl::string_view name() const override { + return "cudnn-custom-call-compiler"; + } using HloPassInterface::Run; absl::StatusOr Run( @@ -43,9 +48,10 @@ class CuDnnWorkspaceRewriter : public HloModulePass { private: se::dnn::DnnSupport& dnn_support_; + BinaryMap& compilation_results_; }; } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_TRANSFORMS_CUDNN_WORKSPACE_REWRITER_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_CUDNN_CUSTOM_CALL_COMPILER_H_ diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_workspace_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_workspace_rewriter.cc deleted file mode 100644 index b5440a8a2af53f..00000000000000 --- a/third_party/xla/xla/service/gpu/transforms/cudnn_workspace_rewriter.cc +++ /dev/null @@ -1,272 +0,0 @@ -/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "xla/service/gpu/transforms/cudnn_workspace_rewriter.h" - -#include -#include -#include - -#include "absl/container/flat_hash_set.h" -#include "absl/container/inlined_vector.h" -#include "absl/log/check.h" -#include "absl/log/log.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" -#include "xla/hlo/ir/hlo_casting_utils.h" -#include "xla/hlo/ir/hlo_clone_context.h" -#include "xla/hlo/ir/hlo_computation.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_instructions.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/service/gpu/backend_configs.pb.h" -#include "xla/service/gpu/cublas_cudnn.h" -#include "xla/service/gpu/gpu_fused_mha_runner.h" -#include "xla/service/gpu/stream_executor_util.h" -#include "xla/shape.h" -#include "xla/shape_util.h" -#include "xla/status_macros.h" -#include "xla/stream_executor/cuda/cuda_dnn.h" -#include "xla/stream_executor/dnn.h" -#include "xla/util.h" -#include "tsl/platform/statusor.h" - -namespace xla { -namespace gpu { - -namespace { - -// create cuDNN graphs from HloCustomCall -absl::StatusOr HloCustomCallToCuDnnGraph( - se::dnn::DnnSupport& dnn_support, HloCustomCallInstruction* custom_call) { - if (IsFwdCustomCallTofMHA(*custom_call)) { - TF_ASSIGN_OR_RETURN(const xla::gpu::CudnnfMHAKind kind, - xla::gpu::GetCudnnfMHAKind(custom_call)); - std::optional mask_shape, bias_shape; - { - bool has_bias = kind == CudnnfMHAKind::kScaleBiasSoftmax || - kind == CudnnfMHAKind::kScaleBiasSoftmaxDropout; - - if (has_bias) { - const HloInstruction* bias = custom_call->operand(3); - bias_shape = bias->shape(); - } - } - - TF_ASSIGN_OR_RETURN( - const auto gpu_config, - custom_call->backend_config()); - const xla::gpu::CudnnfMHABackendConfig& config = - gpu_config.cudnn_fmha_backend_config(); - Shape intermediate_tensor_shape(config.intermediate_tensor_shape()); - absl::InlinedVector output_shapes = { - ShapeUtil::GetSubshape(custom_call->shape(), {0})}; - - bool has_activation = - xla::ShapeUtil::TupleElementCount(custom_call->shape()) == 3; - if (has_activation) { - output_shapes.push_back( - ShapeUtil::GetSubshape(custom_call->shape(), {1})); - } - - Shape q_shape = custom_call->operand(0)->shape(); - Shape k_shape = custom_call->operand(1)->shape(); - Shape v_shape = custom_call->operand(2)->shape(); - TF_ASSIGN_OR_RETURN(CudnnfMHAMaskKind cudnn_mask_type, - AsCudnnFmhaMaskKind(config.mask_type())); - GpufMHADescriptor descriptor = {kind, - config, - cudnn_mask_type, - q_shape, - k_shape, - v_shape, - intermediate_tensor_shape, - output_shapes, - config.bmm1_dot_dimension_numbers(), - config.bmm2_dot_dimension_numbers(), - mask_shape, - bias_shape}; - - TF_ASSIGN_OR_RETURN(GpufMHAConfig fmha_config, - GpufMHAConfig::For(descriptor)); - TF_ASSIGN_OR_RETURN( - se::dnn::FMHAMaskKind dnn_mask_type, - GetDNNFmhaMaskKindFromCudnnFmhaMaskKind(fmha_config.mask_type)); - TF_ASSIGN_OR_RETURN( - se::gpu::CudnnGraph graph, - se::gpu::GetCudnnFlashAttentionOperationGraph( - dnn_support, fmha_config.lhs_bmm1, fmha_config.rhs_bmm1, - fmha_config.rhs_bmm2, fmha_config.output, fmha_config.bias, - fmha_config.activation, static_cast(*fmha_config.fmha_scale), - fmha_config.dropout_rate && *fmha_config.dropout_rate > 0.0, - fmha_config.dropout_rate, dnn_mask_type)); - return std::move(graph); - } else { - TF_ASSIGN_OR_RETURN( - auto gpu_config, - 
custom_call->backend_config()); - xla::gpu::CudnnfMHABackendConfig& config = - *gpu_config.mutable_cudnn_fmha_backend_config(); - - int input_index = 0; - Shape bmm1_grad_gemm1_rhs_shape = - custom_call->operand(input_index++)->shape(); - Shape bmm1_grad_gemm2_rhs_shape = - custom_call->operand(input_index++)->shape(); - Shape bmm2_grad_gemm2_rhs_shape = - custom_call->operand(input_index++)->shape(); - Shape bmm2_grad_gemm1_lhs_shape(config.intermediate_tensor_shape()); - input_index++; - Shape d_output_shape = custom_call->operand(input_index++)->shape(); - - TF_ASSIGN_OR_RETURN(const CudnnfMHAKind kind, - GetCudnnfMHAKind(custom_call)); - std::optional mask_shape; - - bool has_bias = (kind == CudnnfMHAKind::kBackwardScaleBiasSoftmax || - kind == CudnnfMHAKind::kBackwardScaleBiasSoftmaxDropout); - std::optional bias_shape; - if (has_bias) { - bias_shape = custom_call->operand(input_index++)->shape(); - } - - std::optional fwd_output_shape = - custom_call->operand(input_index++)->shape(); - if (config.mask_type() == xla::gpu::CudnnfMHABackendConfig::PADDING || - config.mask_type() == - xla::gpu::CudnnfMHABackendConfig::PADDING_CAUSAL) { - // skip q_seqlen and kv_seqlen - input_index += 2; - } - TF_RET_CHECK(input_index == custom_call->operand_count()); - - int output_index = 0; - Shape d_bmm1_lhs_shape = - ShapeUtil::GetSubshape(custom_call->shape(), {output_index++}); - Shape d_bmm1_rhs_shape = - ShapeUtil::GetSubshape(custom_call->shape(), {output_index++}); - Shape d_bmm2_rhs_shape = - ShapeUtil::GetSubshape(custom_call->shape(), {output_index++}); - std::optional d_s_shape; - std::optional d_bias_shape; - bool has_dbias = custom_call->shape().tuple_shapes().size() == 5; - if (has_dbias) { - d_bias_shape = - ShapeUtil::GetSubshape(custom_call->shape(), {output_index++}); - } - // The last one is the workspace. 
- TF_RET_CHECK(output_index == - custom_call->shape().tuple_shapes().size() - 1); - TF_ASSIGN_OR_RETURN(CudnnfMHAMaskKind cudnn_mask_type, - AsCudnnFmhaMaskKind(config.mask_type())); - - const bool force_deterministic = - RequireDeterminism(custom_call->GetModule()->config()); - // set the correct force_deterministic attribute here - config.set_force_deterministic(force_deterministic); - TF_RETURN_IF_ERROR(custom_call->set_backend_config(gpu_config)); - - GpufMHABackwardDescriptor descriptor = { - kind, - config, - cudnn_mask_type, - bmm1_grad_gemm1_rhs_shape, - bmm1_grad_gemm2_rhs_shape, - bmm2_grad_gemm1_lhs_shape, - bmm2_grad_gemm2_rhs_shape, - d_output_shape, - d_bmm1_lhs_shape, - d_bmm1_rhs_shape, - d_bmm2_rhs_shape, - config.bmm1_grad_gemm1_dot_dimension_numbers(), - config.bmm1_grad_gemm2_dot_dimension_numbers(), - config.bmm2_grad_gemm1_dot_dimension_numbers(), - config.bmm2_grad_gemm2_dot_dimension_numbers(), - d_s_shape, - fwd_output_shape, - mask_shape, - d_bias_shape, - bias_shape, - force_deterministic}; - - TF_ASSIGN_OR_RETURN(GpufMHABackwardConfig fmha_config, - GpufMHABackwardConfig::For(descriptor)); - TF_ASSIGN_OR_RETURN( - se::dnn::FMHAMaskKind dnn_mask_type, - GetDNNFmhaMaskKindFromCudnnFmhaMaskKind(fmha_config.mask_type)); - - TF_ASSIGN_OR_RETURN( - se::gpu::CudnnGraph graph, - se::gpu::GetCudnnFlashAttentionBackwardOperationGraph( - dnn_support, fmha_config.bmm1_grad_gemm1_rhs, - fmha_config.bmm1_grad_gemm2_rhs, fmha_config.bmm2_grad_gemm1_lhs, - fmha_config.bmm2_grad_gemm2_rhs, fmha_config.d_output, - fmha_config.d_bmm1_lhs, fmha_config.d_bmm1_rhs, - fmha_config.d_bmm2_rhs, fmha_config.bias, fmha_config.dropout_rate, - fmha_config.seed, *fmha_config.fmha_scale, - fmha_config.dropout_rate && *fmha_config.dropout_rate > 0.0, - fmha_config.bias != std::nullopt, dnn_mask_type, - force_deterministic)); - return std::move(graph); - } -} - -class CuDnnCustomCallVisitor : public DfsHloRewriteVisitor { - public: - explicit CuDnnCustomCallVisitor(se::dnn::DnnSupport& dnn_support) - : dnn_support_(dnn_support) {} - - absl::Status HandleCustomCall(HloInstruction* hlo) override { - if (!IsCustomCallTofMHA(*hlo)) { - // don't do anything about other cuDNN custom calls - return absl::OkStatus(); - } - TF_ASSIGN_OR_RETURN(auto gpu_config, - hlo->backend_config()); - - TF_ASSIGN_OR_RETURN( - se::gpu::CudnnGraph graph, - HloCustomCallToCuDnnGraph(dnn_support_, - DynCast(hlo))); - auto workspace = graph.Graph().get_workspace_size(); - if (workspace != 0) { - // rewrite custom call to have correct workspace size - VLOG(4) << "Rewriting: " << hlo->ToString(); - Shape* shape = hlo->mutable_shape(); - shape->mutable_tuple_shapes(shape->tuple_shapes_size() - 1) - ->set_dimensions(0, workspace); - MarkAsChanged(); - } - return absl::OkStatus(); - } - - private: - se::dnn::DnnSupport& dnn_support_; -}; - -} // namespace - -absl::StatusOr CuDnnWorkspaceRewriter::Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) { - XLA_SCOPED_LOGGING_TIMER("cuDNN workspace rewriter"); - return CuDnnCustomCallVisitor(dnn_support_) - .RunOnModule(module, execution_threads); -} - -} // namespace gpu -} // namespace xla diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc b/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc index 90640a839dcf23..440f647b84f1ce 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc @@ -40,7 +40,6 @@ limitations under the License. 
#include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" -#include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "absl/time/time.h" @@ -3762,32 +3761,6 @@ absl::StatusOr CreateCudnnTensor( } #if CUDNN_VERSION >= 8800 -enum CudnnfMHAUid { - Q_ID = 400, - K_ID, - V_ID, - P_ID, - O_ID, - dQ_ID, - dK_ID, - dV_ID, - dP_ID, - dO_ID, - dS_ID, - dBIAS_ID, - BIAS_ID, - MASK_ID, - ZERO_VAL_ID, - ONE_VAL_ID, - NEG_INFINITY_ID, - ALPHA_SCALE_ID, - DROPOUT_SCALE_ID, - Q_SEQLEN_ID, - K_SEQLEN_ID, - D_OFFSET_ID, - D_SEED_ID, - VIRTUAL_ID = 34857 -}; absl::StatusOr CreatePwDesc( dnn::DataType dtype, cudnnPointwiseMode_t mode) { @@ -5032,12 +5005,14 @@ absl::StatusOr GetCudnnFlashAttentionOperationGraph( .set_io_data_type(ioDataType) .set_compute_data_type(cudnn_frontend::DataType_t::FLOAT); + auto next_uid = [uid = 0]() mutable -> int { return CuDnnTensorUID(uid++); }; + std::shared_ptr q_tensor = graph.tensor(Tensor_attributes() .set_name("Q") .set_dim(q_descriptor.GetCudnnCompatibleDimensions(true)) .set_stride(q_descriptor.GetCudnnCompatibleStrides(true)) - .set_uid(CudnnfMHAUid::Q_ID)); + .set_uid(next_uid())); auto dim = k_descriptor.GetCudnnCompatibleDimensions(true); std::shared_ptr k_tensor = @@ -5045,13 +5020,13 @@ absl::StatusOr GetCudnnFlashAttentionOperationGraph( .set_name("K") .set_dim(k_descriptor.GetCudnnCompatibleDimensions(true)) .set_stride(k_descriptor.GetCudnnCompatibleStrides(true)) - .set_uid(CudnnfMHAUid::K_ID)); + .set_uid(next_uid())); std::shared_ptr v_tensor = graph.tensor( Tensor_attributes() .set_name("V") .set_dim(v_descriptor.GetCudnnCompatibleDimensions(false)) .set_stride(v_descriptor.GetCudnnCompatibleStrides(false)) - .set_uid(CudnnfMHAUid::V_ID)); + .set_uid(next_uid())); // Setting sdpa, and is_inference bool is_causal = mask_type == dnn::FMHAMaskKind::CAUSAL || @@ -5069,7 +5044,7 @@ absl::StatusOr GetCudnnFlashAttentionOperationGraph( .set_name("bias") .set_dim(bias_descriptor->dimensions()) .set_stride(bias_descriptor->GetLogicalStrides()) - .set_uid(CudnnfMHAUid::BIAS_ID)); + .set_uid(next_uid())); sdpa_options.set_bias(bias_tensor); } // Setting actual seqlen @@ -5083,37 +5058,38 @@ absl::StatusOr GetCudnnFlashAttentionOperationGraph( .set_name("seq_q") .set_dim({b, 1, 1, 1}) .set_stride({1, 1, 1, 1}) - .set_uid(CudnnfMHAUid::Q_SEQLEN_ID) + .set_uid(next_uid()) .set_data_type(cudnn_frontend::DataType_t::INT32)); auto seq_kv_tensor = graph.tensor(Tensor_attributes() .set_name("seq_kv") .set_dim({b, 1, 1, 1}) .set_stride({1, 1, 1, 1}) - .set_uid(CudnnfMHAUid::K_SEQLEN_ID) + .set_uid(next_uid()) .set_data_type(cudnn_frontend::DataType_t::INT32)); sdpa_options.set_padding_mask(true); sdpa_options.set_seq_len_q(seq_q_tensor); sdpa_options.set_seq_len_kv(seq_kv_tensor); } // Setting seed and offset + std::shared_ptr seed_tensor; + std::shared_ptr offset_tensor; if (use_dropout) { - auto seed_tensor = + // Skip setting UIDs: pass by value tensors go at the end. 
+ seed_tensor = graph.tensor(Tensor_attributes() .set_name("seed") .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_data_type(cudnn_frontend::DataType_t::INT64) - .set_is_pass_by_value(true) - .set_uid(CudnnfMHAUid::D_SEED_ID)); - auto offset_tensor = + .set_is_pass_by_value(true)); + offset_tensor = graph.tensor(Tensor_attributes() .set_name("offset") .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_data_type(cudnn_frontend::DataType_t::INT64) - .set_is_pass_by_value(true) - .set_uid(CudnnfMHAUid::D_OFFSET_ID)); + .set_is_pass_by_value(true)); sdpa_options.set_dropout((float)dropout_rate.value(), seed_tensor, offset_tensor); } @@ -5127,7 +5103,7 @@ absl::StatusOr GetCudnnFlashAttentionOperationGraph( .set_output(true) .set_dim(o_descriptor.dimensions()) .set_stride(o_descriptor.GetLogicalStrides()) - .set_uid(CudnnfMHAUid::O_ID); + .set_uid(next_uid()); if (stats_descriptor.has_value()) { cudnn_frontend::DataType_t statsType = ToCudnnFrontendDataType(stats_descriptor->type()); @@ -5140,7 +5116,13 @@ absl::StatusOr GetCudnnFlashAttentionOperationGraph( .set_data_type(statsType) .set_dim(stat_dims) .set_stride(stat_strides) - .set_uid(CudnnfMHAUid::P_ID); + .set_uid(next_uid()); + } + if (seed_tensor != nullptr) { + seed_tensor->set_uid(next_uid()); + } + if (offset_tensor != nullptr) { + offset_tensor->set_uid(next_uid()); } CudnnGraph cudnnGraph(std::move(graph)); TF_RETURN_IF_ERROR(cudnnGraph.Prepare( @@ -5195,71 +5177,66 @@ absl::StatusOr GetCudnnFlashAttentionBackwardOperationGraph( .set_intermediate_data_type(cudnn_frontend::DataType_t::FLOAT) .set_io_data_type(ioDataType); + auto p_dims = p_desc.GetCudnnCompatibleDimensions(false); + auto p_strides = p_desc.GetCudnnCompatibleStrides(false); + std::vector p_reduction_dims(p_dims.begin(), p_dims.end() - 1); + p_reduction_dims.push_back(1); + + // Divide every stride by the last dim value. 
+ std::vector p_reduction_strides; + p_reduction_strides.reserve(p_strides.size()); + int64_t p_reduced_dim_len = p_dims.back(); + for (auto stride : p_strides) { + p_reduction_strides.push_back(stride / p_reduced_dim_len); + } + p_reduction_strides[3] = 1; + bool is_causal = mask_type == dnn::FMHAMaskKind::CAUSAL || + mask_type == dnn::FMHAMaskKind::PADDING_CAUSAL; + auto sdpa_backward_options = + cudnn_frontend::graph::SDPA_backward_attributes() + .set_name("flash_attention_backward") + .set_causal_mask(is_causal) + .set_attn_scale(scale) + .set_compute_data_type(cudnn_frontend::DataType_t::FLOAT); + + auto next_uid = [uid = 0]() mutable -> int { return CuDnnTensorUID(uid++); }; + std::shared_ptr q = graph.tensor(Tensor_attributes() .set_name("Q") .set_dim(q_desc.GetCudnnCompatibleDimensions(false)) .set_stride(q_desc.GetCudnnCompatibleStrides(false)) - .set_uid(CudnnfMHAUid::Q_ID) + .set_uid(next_uid()) .set_data_type(ioDataType)); std::shared_ptr k = graph.tensor(Tensor_attributes() .set_name("K") .set_dim(k_desc.GetCudnnCompatibleDimensions(false)) .set_stride(k_desc.GetCudnnCompatibleStrides(false)) - .set_uid(CudnnfMHAUid::K_ID) + .set_uid(next_uid()) .set_data_type(ioDataType)); std::shared_ptr v = graph.tensor(Tensor_attributes() .set_name("V") .set_dim(v_desc.GetCudnnCompatibleDimensions(true)) .set_stride(v_desc.GetCudnnCompatibleStrides(true)) - .set_uid(CudnnfMHAUid::V_ID) + .set_uid(next_uid()) .set_data_type(ioDataType)); - std::shared_ptr o = + std::shared_ptr stats = graph.tensor(Tensor_attributes() - .set_name("O") - .set_dim(do_desc.GetCudnnCompatibleDimensions(false)) - .set_stride(do_desc.GetCudnnCompatibleStrides(false)) - .set_uid(CudnnfMHAUid::O_ID) - .set_data_type(ioDataType)); + .set_name("stats") + .set_dim(p_reduction_dims) + .set_stride(p_reduction_strides) + .set_uid(next_uid()) + .set_data_type(cudnn_frontend::DataType_t::FLOAT)); std::shared_ptr dO = graph.tensor(Tensor_attributes() .set_name("dO") .set_dim(do_desc.GetCudnnCompatibleDimensions(false)) .set_stride(do_desc.GetCudnnCompatibleStrides(false)) - .set_uid(CudnnfMHAUid::dO_ID) + .set_uid(next_uid()) .set_data_type(ioDataType)); - auto p_dims = p_desc.GetCudnnCompatibleDimensions(false); - auto p_strides = p_desc.GetCudnnCompatibleStrides(false); - std::vector p_reduction_dims(p_dims.begin(), p_dims.end() - 1); - p_reduction_dims.push_back(1); - - // Divide every stride by the last dim value. 
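A worked example of the p_reduction_strides computation introduced above (the concrete sizes are assumptions for illustration): for a P tensor with dims [b, h, s_q, s_kv] = [2, 4, 128, 128] and contiguous strides [65536, 16384, 128, 1], reducing over the last dimension yields stats dims [2, 4, 128, 1] and strides [512, 128, 1, 1] — each stride divided by p_dims.back(), with the final stride forced to 1.

#include <cstdint>
#include <vector>

int main() {
  std::vector<int64_t> p_dims = {2, 4, 128, 128};          // [b, h, s_q, s_kv]
  std::vector<int64_t> p_strides = {65536, 16384, 128, 1};  // contiguous layout
  std::vector<int64_t> stats_dims(p_dims.begin(), p_dims.end() - 1);
  stats_dims.push_back(1);                                  // {2, 4, 128, 1}
  std::vector<int64_t> stats_strides;
  stats_strides.reserve(p_strides.size());
  for (int64_t s : p_strides) stats_strides.push_back(s / p_dims.back());
  stats_strides[3] = 1;                                     // {512, 128, 1, 1}
  return 0;
}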
- std::vector p_reduction_strides; - p_reduction_strides.reserve(p_strides.size()); - int64_t p_reduced_dim_len = p_dims.back(); - for (auto stride : p_strides) { - p_reduction_strides.push_back(stride / p_reduced_dim_len); - } - p_reduction_strides[3] = 1; - std::shared_ptr stats = - graph.tensor(Tensor_attributes() - .set_name("stats") - .set_dim(p_reduction_dims) - .set_stride(p_reduction_strides) - .set_uid(CudnnfMHAUid::P_ID) - .set_data_type(cudnn_frontend::DataType_t::FLOAT)); - bool is_causal = mask_type == dnn::FMHAMaskKind::CAUSAL || - mask_type == dnn::FMHAMaskKind::PADDING_CAUSAL; - auto sdpa_backward_options = - cudnn_frontend::graph::SDPA_backward_attributes() - .set_name("flash_attention_backward") - .set_causal_mask(is_causal) - .set_attn_scale(scale) - .set_compute_data_type(cudnn_frontend::DataType_t::FLOAT); - - // Setting bias + std::shared_ptr d_bias_tensor; if (use_bias) { DCHECK(bias_descriptor != std::nullopt); auto bias_dim = bias_descriptor->dimensions(); @@ -5272,21 +5249,29 @@ absl::StatusOr GetCudnnFlashAttentionBackwardOperationGraph( .set_name("bias") .set_dim(bias_descriptor->dimensions()) .set_stride(bias_descriptor->GetLogicalStrides()) - .set_uid(CudnnfMHAUid::BIAS_ID)); + .set_uid(next_uid())); sdpa_backward_options.set_bias(bias_tensor); // shapes [1, 1, s, s], [b, 1, s, s], [b, h, s, s] are not supported for // dbias calculation but they are supported for forward bias calculation + // Set UID later: this is the last output tuple element. if (b == 1 && n == q_n) { - auto d_bias_tensor = + d_bias_tensor = graph.tensor(Tensor_attributes() .set_name("dBias") .set_dim(bias_descriptor->dimensions()) - .set_stride(bias_descriptor->GetLogicalStrides()) - .set_uid(CudnnfMHAUid::dBIAS_ID)); + .set_stride(bias_descriptor->GetLogicalStrides())); sdpa_backward_options.set_dbias(d_bias_tensor); } } + std::shared_ptr o = + graph.tensor(Tensor_attributes() + .set_name("O") + .set_dim(do_desc.GetCudnnCompatibleDimensions(false)) + .set_stride(do_desc.GetCudnnCompatibleStrides(false)) + .set_uid(next_uid()) + .set_data_type(ioDataType)); + // Setting actual seqlen bool is_padding = mask_type == dnn::FMHAMaskKind::PADDING || mask_type == dnn::FMHAMaskKind::PADDING_CAUSAL; @@ -5298,38 +5283,39 @@ absl::StatusOr GetCudnnFlashAttentionBackwardOperationGraph( .set_name("seq_q") .set_dim({b, 1, 1, 1}) .set_stride({1, 1, 1, 1}) - .set_uid(CudnnfMHAUid::Q_SEQLEN_ID) + .set_uid(next_uid()) .set_data_type(cudnn_frontend::DataType_t::INT32)); auto seq_kv_tensor = graph.tensor(Tensor_attributes() .set_name("seq_kv") .set_dim({b, 1, 1, 1}) .set_stride({1, 1, 1, 1}) - .set_uid(CudnnfMHAUid::K_SEQLEN_ID) + .set_uid(next_uid()) .set_data_type(cudnn_frontend::DataType_t::INT32)); sdpa_backward_options.set_padding_mask(true); sdpa_backward_options.set_seq_len_q(seq_q_tensor); sdpa_backward_options.set_seq_len_kv(seq_kv_tensor); } // Setting seed and offset + std::shared_ptr seed_tensor; + std::shared_ptr offset_tensor; if (use_dropout) { DCHECK(dropout_rate != std::nullopt); - auto seed_tensor = + // Skip setting UIDs: pass by value tensors go at the end. 
+ seed_tensor = graph.tensor(Tensor_attributes() .set_name("seed") .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_data_type(cudnn_frontend::DataType_t::INT64) - .set_is_pass_by_value(true) - .set_uid(CudnnfMHAUid::D_SEED_ID)); - auto offset_tensor = + .set_is_pass_by_value(true)); + offset_tensor = graph.tensor(Tensor_attributes() .set_name("offset") .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_data_type(cudnn_frontend::DataType_t::INT64) - .set_is_pass_by_value(true) - .set_uid(CudnnfMHAUid::D_OFFSET_ID)); + .set_is_pass_by_value(true)); sdpa_backward_options.set_dropout((float)dropout_rate.value(), seed_tensor, offset_tensor); } @@ -5344,21 +5330,30 @@ absl::StatusOr GetCudnnFlashAttentionBackwardOperationGraph( dQ->set_output(true) .set_dim(dq_desc.dimensions()) .set_stride(dq_desc.GetLogicalStrides()) + .set_uid(next_uid()) .set_name("dQ") - .set_uid(CudnnfMHAUid::dQ_ID) .set_data_type(ioDataType); dK->set_output(true) .set_dim(dk_desc.dimensions()) .set_stride(dk_desc.GetLogicalStrides()) + .set_uid(next_uid()) .set_name("dK") - .set_uid(CudnnfMHAUid::dK_ID) .set_data_type(ioDataType); dV->set_output(true) .set_dim(dv_desc.dimensions()) .set_stride(dv_desc.GetLogicalStrides()) + .set_uid(next_uid()) .set_name("dV") - .set_uid(CudnnfMHAUid::dV_ID) .set_data_type(ioDataType); + if (d_bias_tensor != nullptr) { + d_bias_tensor->set_uid(next_uid()); + } + if (seed_tensor != nullptr) { + seed_tensor->set_uid(next_uid()); + } + if (offset_tensor != nullptr) { + offset_tensor->set_uid(next_uid()); + } CudnnGraph cudnnGraph(std::move(graph)); TF_RETURN_IF_ERROR( @@ -5696,8 +5691,8 @@ absl::Status CudnnSupport::DoConvolve( } // Utility for dealing with CUDA's type-erased scaling parameters, where some -// sets of parameters expect a void* pointing at a float while others expect it -// to point at a double. +// sets of parameters expect a void* pointing at a float while others expect +// it to point at a double. // // This is rather ugly, but its purpose is to quarantine the corresponding // ugliness that already exists in the CUDA API. @@ -5721,9 +5716,9 @@ class ScalingParam { // // See // https://docs.nvidia.com/deeplearning/cudnn/developer-guide/index.html#scaling-parameters - // for more info; the behavior for int8 result tensors is not described there, - // but is maintained from the existing behavior (namely, using a float scaling - // parameter). + // for more info; the behavior for int8 result tensors is not described + // there, but is maintained from the existing behavior (namely, using a + // float scaling parameter). void* ToVoidPointer(dnn::DataType element_type) { if (element_type == dnn::DataType::kDouble) { return &as_double_; @@ -5795,10 +5790,11 @@ absl::StatusOr> GetDescriptorAttribute( absl::c_transform(result, std::back_inserter(raw_ptrs), [](const BackendDescriptor& ptr) { return ptr.get(); }); - // This API evidently does a deep copy of the descriptors into the pointers in - // the output array, rather than writing pointers to the descriptors into the - // output array. So, this writes the memory behind each BackendDescriptor in - // result, rather than writing the contents of raw_ptrs. + // This API evidently does a deep copy of the descriptors into the pointers + // in the output array, rather than writing pointers to the descriptors into + // the output array. So, this writes the memory behind each + // BackendDescriptor in result, rather than writing the contents of + // raw_ptrs. 
RETURN_IF_CUDNN_ERROR(cudnnBackendGetAttribute( desc, name, CUDNN_TYPE_BACKEND_DESCRIPTOR, n, &n, raw_ptrs.data())); @@ -5834,9 +5830,9 @@ absl::StatusOr ExecutionPlanToAlgorithmDesc( cudnnBackendGetAttribute(engines[0].get(), CUDNN_ATTR_ENGINE_GLOBAL_INDEX, CUDNN_TYPE_INT64, 1, &n, &engine_id)); - // Apparently for CUDNN_ATTR_ENGINECFG_KNOB_CHOICES only, trying to query the - // number of elements in the attribute by using an output limit value of 0 - // just returns 0; the only way to find out how many there are is to + // Apparently for CUDNN_ATTR_ENGINECFG_KNOB_CHOICES only, trying to query + // the number of elements in the attribute by using an output limit value of + // 0 just returns 0; the only way to find out how many there are is to // pre-allocate space for every existing knob type (as an upper bound on the // number of knob choices a config can have), and then look back at how many // were filled. @@ -6047,103 +6043,7 @@ class CudnnExecutionPlanRunner std::vector scalar_input_uids_; std::vector scalar_input_values_; }; -#endif // CUDNN_VERSION >= 8100 - -template -class CudnnGraphRunner; -// An OpRunner implemented by a cuDNN frontend graph. -// -// This is the class holding the implementation of ToString, GetWorkspaceSize, -// and operator() for use by the cudnn frontend op runners. -template -class CudnnGraphRunner : public dnn::OpRunner { - private: - using Graph = cudnn_frontend::graph::Graph; - using Tensor_attributes = cudnn_frontend::graph::Tensor_attributes; - - public: - std::string ToString() const override { return graph_.Graph().print(); } - - size_t GetWorkspaceSize() const override { - return graph_.Graph().get_workspace_size(); - } - - absl::StatusOr ToAlgorithmDesc() const override { - return absl::InternalError( - "Unexpected call to CudnnGraphRunner::ToAlgorithmDesc"); - } - - absl::Status operator()(Stream* stream, dnn::ProfileResult* profile_result, - DeviceMemoryBase scratch_memory, - Args... 
inputs) const override { - if (parent_ != stream->parent()) { - return tsl::errors::Internal( - "CudnnExecutionPlanRunner cached across multiple StreamExecutors."); - } - CudnnHandle handle = cudnn_->GetHandle(parent_, stream); - std::unordered_map variant_pack; - std::vector vec = {inputs.opaque()...}; - - // add device buffers to the variant pack - for (int i = 0; i < uids_.size(); ++i) { - if (uids_[i].has_value()) { - variant_pack[*uids_[i]] = vec[i]; - } - } - if (dropout_rng_offset_increment_ > 0) { -#if CUDNN_VERSION >= 8800 - variant_pack[CudnnfMHAUid::D_SEED_ID] = (void*)&dropout_rng_seed_; - current_dropout_rng_offset_ += dropout_rng_offset_increment_; - variant_pack[CudnnfMHAUid::D_OFFSET_ID] = - (void*)¤t_dropout_rng_offset_; -#else - return absl::UnimplementedError( - "Cudnn dropout offset and seed are only supported with Cudnn >= " - "8.8.0"); -#endif // CUDNN_VERSION >= 8800 - } - int workspace = graph_.Graph().get_workspace_size(); - if (workspace > scratch_memory.size()) { - return tsl::errors::Internal( - absl::StrFormat("CuDNN FMHA requires %d workspace, got %d workspace.", - workspace, scratch_memory.size())); - } - RETURN_IF_CUDNN_FRONTEND_ERROR(graph_.Graph().execute( - handle.handle(), variant_pack, scratch_memory.opaque())); - - return absl::OkStatus(); - } - - static absl::StatusOr Create( - GpuExecutor* parent, CudnnAccess* cudnn, CudnnGraph graph, - int64_t dropout_rng_seed, int64_t dropout_rng_offset, - std::vector> uids) { - return CudnnGraphRunner(parent, cudnn, std::move(graph), dropout_rng_seed, - dropout_rng_offset, uids); - } - - private: - CudnnGraphRunner(GpuExecutor* parent, CudnnAccess* cudnn, CudnnGraph graph, - int64_t dropout_rng_seed, int64_t dropout_rng_offset, - std::vector> uids) - : parent_(parent), - cudnn_(cudnn), - graph_(std::move(graph)), - dropout_rng_seed_(dropout_rng_seed), - current_dropout_rng_offset_(0), - dropout_rng_offset_increment_(dropout_rng_offset), - uids_(uids) {} - GpuExecutor* parent_; - CudnnAccess* cudnn_; - Stream* stream_; - CudnnGraph graph_; - int64_t dropout_rng_seed_; - mutable int64_t current_dropout_rng_offset_; - int64_t dropout_rng_offset_increment_; - std::vector> uids_; -}; -#if CUDNN_VERSION >= 8100 namespace { template @@ -6929,7 +6829,8 @@ absl::Status CudnnSupport::GetFusedMatmulRunners( use_fallback, out_exec_plans, /*need_side_input=*/true, numeric_options); #else return tsl::errors::Unimplemented( - "Cudnn execution plans for matmul are only supported with Cudnn >= 8.4."); + "Cudnn execution plans for matmul are only supported with Cudnn >= " + "8.4."); #endif // CUDNN_VERSION >= 8400 } @@ -7131,139 +7032,6 @@ int64_t GetDropoutRngOffset(std::vector& intermediate_shape) { return max_seq_len * max_seq_len / cudnn_mha_num_threads; } -absl::StatusOr> -CudnnSupport::FusedMHARunnerFromDesc( - Stream* stream, const dnn::AlgorithmDesc& algorithm_desc, - const dnn::MatmulTensorDescriptor& bmm1_lhs_descriptor, - const dnn::MatmulTensorDescriptor& bmm1_rhs_descriptor, - const dnn::MatmulTensorDescriptor& bmm2_rhs_descriptor, - const dnn::MatmulTensorDescriptor& intermediate_bmm2_lhs_descriptor, - const dnn::TensorDescriptor& output_descriptor, - std::optional activation_descriptor, - std::optional bias_descriptor, double scale, - std::optional dropout_rate, std::optional seed, - dnn::FMHAMaskKind mask_type) { -#if CUDNN_VERSION >= 90000 - auto cudnn = cudnn_->GetHandle(parent_, stream); - bool use_dropout = dropout_rate && *dropout_rate > 0.0; - std::vector intermediate_shape; - - TF_ASSIGN_OR_RETURN(auto graph, 
- GetCudnnFlashAttentionOperationGraph( - *this, /*q_descriptor=*/bmm1_lhs_descriptor, - /*k_descriptor=*/bmm1_rhs_descriptor, - /*v_descriptor=*/bmm2_rhs_descriptor, - /*o_descriptor=*/output_descriptor, bias_descriptor, - /*stats_descriptor=*/activation_descriptor, - /*scale=*/static_cast(scale), use_dropout, - dropout_rate, mask_type)); - - std::vector intermediate_bmm2_lhs_dims = - intermediate_bmm2_lhs_descriptor.GetCudnnCompatibleDimensions(true); - intermediate_shape = intermediate_bmm2_lhs_dims; - int64_t dropout_rng_offset = GetDropoutRngOffset(intermediate_shape); - int64_t dropout_rng_seed = seed.has_value() ? *seed : 0; - std::vector> uids = { - CudnnfMHAUid::Q_ID, CudnnfMHAUid::K_ID, CudnnfMHAUid::V_ID, - CudnnfMHAUid::O_ID}; - uids.emplace_back(bias_descriptor.has_value() - ? std::optional(CudnnfMHAUid::BIAS_ID) - : std::nullopt); - uids.emplace_back(activation_descriptor.has_value() - ? std::optional(CudnnfMHAUid::P_ID) - : std::nullopt); - bool is_padding = mask_type == dnn::FMHAMaskKind::PADDING || - mask_type == dnn::FMHAMaskKind::PADDING_CAUSAL; - uids.emplace_back(is_padding - ? std::optional(CudnnfMHAUid::Q_SEQLEN_ID) - : std::nullopt); - uids.emplace_back(is_padding - ? std::optional(CudnnfMHAUid::K_SEQLEN_ID) - : std::nullopt); - TF_ASSIGN_OR_RETURN(auto runner, - CudnnGraphRunner::Create( - parent_, cudnn_.get(), std::move(graph), - dropout_rng_seed, dropout_rng_offset, uids)); - - return {std::make_unique>( - std::move(runner))}; -#else - return absl::UnimplementedError( - "Cudnn flash attention are only supported with Cudnn >= 9.0.0"); -#endif // CUDNN_VERSION >= 90000 -} - -absl::StatusOr> -CudnnSupport::FusedMHABackwardRunnerFromDesc( - Stream* stream, const dnn::AlgorithmDesc& algorithm_desc, - const dnn::MatmulTensorDescriptor& bmm1_grad_gemm1_rhs_descriptor, - const dnn::MatmulTensorDescriptor& bmm1_grad_gemm2_rhs_descriptor, - const dnn::MatmulTensorDescriptor& bmm2_grad_gemm1_lhs_descriptor, - const dnn::MatmulTensorDescriptor& bmm2_grad_gemm2_rhs_descriptor, - const dnn::MatmulTensorDescriptor& d_output_descriptor, - const dnn::TensorDescriptor& d_bmm1_lhs_descriptor, - const dnn::TensorDescriptor& d_bmm1_rhs_descriptor, - const dnn::TensorDescriptor& d_bmm2_rhs_descriptor, - std::optional d_s_descriptor, - std::optional d_bias_descriptor, - std::optional fwd_output_descriptor, - std::optional bias_descriptor, double scale, - std::optional dropout_rate, std::optional seed, - dnn::FMHAMaskKind mask_type, bool force_deterministic) { -#if CUDNN_VERSION >= 90000 - auto cudnn = cudnn_->GetHandle(parent_, stream); - - bool use_dropout = dropout_rate && *dropout_rate > 0.0; - std::vector intermediate_shape; - - TF_ASSIGN_OR_RETURN( - auto graph, - GetCudnnFlashAttentionBackwardOperationGraph( - *this, bmm1_grad_gemm1_rhs_descriptor, bmm1_grad_gemm2_rhs_descriptor, - bmm2_grad_gemm1_lhs_descriptor, bmm2_grad_gemm2_rhs_descriptor, - d_output_descriptor, d_bmm1_lhs_descriptor, d_bmm1_rhs_descriptor, - d_bmm2_rhs_descriptor, bias_descriptor, dropout_rate, seed, scale, - use_dropout, bias_descriptor != std::nullopt, mask_type, - force_deterministic)); - - std::vector p_dims = - bmm2_grad_gemm1_lhs_descriptor.GetCudnnCompatibleDimensions(false); - intermediate_shape = p_dims; - int64_t dropout_rng_offset = GetDropoutRngOffset(intermediate_shape); - int64_t dropout_rng_seed = seed.has_value() ? 
*seed : 0; - - std::vector> uids; - uids = {CudnnfMHAUid::Q_ID, CudnnfMHAUid::K_ID, CudnnfMHAUid::P_ID, - CudnnfMHAUid::V_ID, CudnnfMHAUid::dO_ID, CudnnfMHAUid::dQ_ID, - CudnnfMHAUid::dK_ID, CudnnfMHAUid::dV_ID, std::nullopt}; - uids.emplace_back(d_bias_descriptor.has_value() - ? std::optional(CudnnfMHAUid::dBIAS_ID) - : std::nullopt); - uids.push_back(CudnnfMHAUid::O_ID); - uids.emplace_back(bias_descriptor.has_value() - ? std::optional(CudnnfMHAUid::BIAS_ID) - : std::nullopt); - bool is_padding = mask_type == dnn::FMHAMaskKind::PADDING || - mask_type == dnn::FMHAMaskKind::PADDING_CAUSAL; - uids.emplace_back(is_padding - ? std::optional(CudnnfMHAUid::Q_SEQLEN_ID) - : std::nullopt); - uids.emplace_back(is_padding - ? std::optional(CudnnfMHAUid::K_SEQLEN_ID) - : std::nullopt); - TF_ASSIGN_OR_RETURN(auto runner, - CudnnGraphRunner::Create( - parent_, cudnn_.get(), graph, dropout_rng_seed, - dropout_rng_offset, uids)); - return {std::make_unique>( - std::move(runner))}; -#else - return absl::UnimplementedError( - "Cudnn flash attention bwd are only " - "supported with Cudnn >= 9.0.0"); -#endif // CUDNN_VERSION >= 90000 -} - bool CudnnSupport::GetRnnAlgorithms( std::vector* out_algorithms) { PreloadCudnnSubLibs(PreloadCudnnType::Rnn); @@ -8348,15 +8116,30 @@ absl::Status CudnnGraph::Execute(Stream& stream, std::unordered_map tensor_to_ptr_map; absl::Span operands_without_workspace = operands; DeviceMemoryBase workspace; - if (graph_.get_workspace_size() != 0) { + if (graph_.get_workspace_size() > 0) { workspace = operands.back(); CHECK_EQ(graph_.get_workspace_size(), workspace.size()); + } + if (graph_.get_workspace_size() > 0 || operands.back().size() == 0) { operands_without_workspace = operands.first(operands.size() - 1); } - int operand_number = 0; + auto next_uid = [uid = 0]() mutable -> int { return CuDnnTensorUID(uid++); }; for (DeviceMemoryBase operand : operands_without_workspace) { - tensor_to_ptr_map[CuDnnTensorUID(operand_number++)] = operand.opaque(); + tensor_to_ptr_map[next_uid()] = operand.opaque(); } + + if (dropout_rng_offset_increment_ > 0) { +#if CUDNN_VERSION >= 8800 + tensor_to_ptr_map[next_uid()] = (void*)&dropout_rng_seed_; + current_dropout_rng_offset_ += dropout_rng_offset_increment_; + tensor_to_ptr_map[next_uid()] = (void*)¤t_dropout_rng_offset_; +#else + return absl::UnimplementedError( + "Cudnn dropout offset and seed are only supported with Cudnn >= " + "8.8.0"); +#endif // CUDNN_VERSION >= 8800 + } + const CudnnSupport& dnn_support = static_cast(*stream.parent()->AsDnn()); RETURN_IF_CUDNN_FRONTEND_ERROR(graph_.execute( diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_dnn.h b/third_party/xla/xla/stream_executor/cuda/cuda_dnn.h index 083a2b431e62c1..24d84e369cb138 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_dnn.h +++ b/third_party/xla/xla/stream_executor/cuda/cuda_dnn.h @@ -70,6 +70,9 @@ class CudnnGraph : public dnn::DnnGraph { private: cudnn_frontend::graph::Graph graph_; + int64_t dropout_rng_seed_; + mutable int64_t current_dropout_rng_offset_; + int64_t dropout_rng_offset_increment_ = 0; }; #endif // CUDNN_VERSION >= 8100 @@ -335,37 +338,6 @@ class CudnnSupport : public dnn::DnnSupport { std::optional dscale_descriptor, std::optional dbias_descriptor) override; - absl::StatusOr> - FusedMHARunnerFromDesc( - Stream* stream, const dnn::AlgorithmDesc& algorithm_desc, - const dnn::MatmulTensorDescriptor& bmm1_lhs_descriptor, - const dnn::MatmulTensorDescriptor& bmm1_rhs_descriptor, - const dnn::MatmulTensorDescriptor& bmm2_rhs_descriptor, 
- const dnn::MatmulTensorDescriptor& intermediate_bmm2_lhs_descriptor, - const dnn::TensorDescriptor& output_descriptor, - std::optional activation_descriptor, - std::optional bias_descriptor, double scale, - std::optional dropout_rate, std::optional seed, - dnn::FMHAMaskKind mask_type) override; - - absl::StatusOr> - FusedMHABackwardRunnerFromDesc( - Stream* stream, const dnn::AlgorithmDesc& algorithm_desc, - const dnn::MatmulTensorDescriptor& bmm1_grad_gemm1_rhs_descriptor, - const dnn::MatmulTensorDescriptor& bmm1_grad_gemm2_rhs_descriptor, - const dnn::MatmulTensorDescriptor& bmm2_grad_gemm1_lhs_descriptor, - const dnn::MatmulTensorDescriptor& bmm2_grad_gemm2_rhs_descriptor, - const dnn::MatmulTensorDescriptor& d_output_descriptor, - const dnn::TensorDescriptor& d_bmm1_lhs_descriptor, - const dnn::TensorDescriptor& d_bmm1_rhs_descriptor, - const dnn::TensorDescriptor& d_bmm2_rhs_descriptor, - std::optional d_s_descriptor, - std::optional d_bias_descriptor, - std::optional fwd_output_descriptor, - std::optional bias_descriptor, double scale, - std::optional dropout_rate, std::optional seed, - dnn::FMHAMaskKind mask_type, bool force_deterministic); - bool GetRnnAlgorithms( std::vector* out_algorithms) override; diff --git a/third_party/xla/xla/stream_executor/cuda/cudnn_frontend_helpers.h b/third_party/xla/xla/stream_executor/cuda/cudnn_frontend_helpers.h index aa59af500ba7a3..0a30c1af59c0c4 100644 --- a/third_party/xla/xla/stream_executor/cuda/cudnn_frontend_helpers.h +++ b/third_party/xla/xla/stream_executor/cuda/cudnn_frontend_helpers.h @@ -29,6 +29,11 @@ namespace gpu { } \ } while (false) +// UIDs for cuDNN are unique identifiers of tensors within a graph. They are +// assigned during graph construction; then graph execution takes a {uid: +// buffer pointer} map defining the correspondence of buffers to tensors. +// The UID assignment scheme can be arbitrary; at the moment, for simplicity, XLA uses +// the scheme UID = (HLO operand number + 1).
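A hedged reading of the scheme described in the comment above — the actual definition of CuDnnTensorUID lives outside this patch, so the body below is an assumption for illustration only:

// Assumed implementation: zero-based operand index -> UID.
int CuDnnTensorUID(int offset) { return offset + 1; }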
int CuDnnTensorUID(int offset); } // namespace gpu diff --git a/third_party/xla/xla/stream_executor/dnn.cc b/third_party/xla/xla/stream_executor/dnn.cc index b7da7e50eb5eb9..951b2f6e147cd8 100644 --- a/third_party/xla/xla/stream_executor/dnn.cc +++ b/third_party/xla/xla/stream_executor/dnn.cc @@ -249,42 +249,6 @@ DnnSupport::NormRunnerFromDesc( return absl::UnimplementedError("NormRunnerFromDesc not implemented."); } -absl::StatusOr> -DnnSupport::FusedMHARunnerFromDesc( - Stream* stream, const dnn::AlgorithmDesc& algorithm_desc, - const dnn::MatmulTensorDescriptor& bmm1_lhs_descriptor, - const dnn::MatmulTensorDescriptor& bmm1_rhs_descriptor, - const dnn::MatmulTensorDescriptor& bmm2_rhs_descriptor, - const dnn::MatmulTensorDescriptor& intermediate_bmm2_lhs_descriptor, - const dnn::TensorDescriptor& output_descriptor, - std::optional activation_descriptor, - std::optional bias_descriptor, double scale, - std::optional dropout_rate, std::optional seed, - dnn::FMHAMaskKind mask_type) { - return absl::UnimplementedError("FusedMHARunnerFromDesc not implemented."); -} - -absl::StatusOr> -DnnSupport::FusedMHABackwardRunnerFromDesc( - Stream* stream, const dnn::AlgorithmDesc& algorithm_desc, - const MatmulTensorDescriptor& bmm1_grad_gemm1_rhs_descriptor, - const MatmulTensorDescriptor& bmm1_grad_gemm2_rhs_descriptor, - const MatmulTensorDescriptor& bmm2_grad_gemm1_lhs_descriptor, - const MatmulTensorDescriptor& bmm2_grad_gemm2_rhs_descriptor, - const MatmulTensorDescriptor& d_output_descriptor, - const TensorDescriptor& d_bmm1_lhs_descriptor, - const TensorDescriptor& d_bmm1_rhs_descriptor, - const TensorDescriptor& d_bmm2_rhs_descriptor, - std::optional d_s_descriptor, - std::optional d_bias_descriptor, - std::optional fwd_output_descriptor, - std::optional bias_descriptor, double scale, - std::optional dropout_rate, std::optional seed, - dnn::FMHAMaskKind mask_type, bool force_deterministic) { - return absl::UnimplementedError( - "FusedMHABackwardRunnerFromDesc not implemented."); -} - bool DnnSupport::GetMIOpenConvolveAlgorithms( dnn::ConvolutionKind /*kind*/, dnn::DataType /*element_type*/, dnn::DataType /*output_type*/, Stream* /*stream*/, diff --git a/third_party/xla/xla/stream_executor/dnn.h b/third_party/xla/xla/stream_executor/dnn.h index af709946eeb241..a2e1cd629dc2b4 100644 --- a/third_party/xla/xla/stream_executor/dnn.h +++ b/third_party/xla/xla/stream_executor/dnn.h @@ -993,30 +993,6 @@ using FusedMatmulRunner = OpRunner; using NormSignature = void(std::vector); using NormRunner = OpRunner; -using FusedMHASignature = void(DeviceMemoryBase /*BMM1_inputA_data*/, - DeviceMemoryBase /* BMM1_inputB_data */, - DeviceMemoryBase /* BMM2_inputA_data */, - DeviceMemoryBase /* output_data */, - DeviceMemoryBase /* bias_data */, - DeviceMemoryBase /* activation_data */, - DeviceMemoryBase /* seqlen_q_data */, - DeviceMemoryBase /* seqlen_k_data */); -using FusedMHARunner = OpRunner; - -using FusedMHABackwardSignature = void( - DeviceMemoryBase /* BMM1_GRAD_GEMM1_inputA_data */, - DeviceMemoryBase /* BMM1_GRAD_GEMM2_inputB_data */, - DeviceMemoryBase /* BMM2_GRAD_GEMM1_inputA_data */, - DeviceMemoryBase /* BMM2_GRAD_GEMM2_inputB_data */, - DeviceMemoryBase /* d_output_data */, - DeviceMemoryBase /* d_BMM1_inputA_data */, - DeviceMemoryBase /* d_BMM1_inputB_data */, - DeviceMemoryBase /* d_BMM2_inputB_data */, DeviceMemoryBase /* d_S_data */, - DeviceMemoryBase /* d_bias_data */, DeviceMemoryBase /* fwd_output_data */, - DeviceMemoryBase /* bias_data */, DeviceMemoryBase /* seqlen_q_data */, - 
DeviceMemoryBase /* seqlen_k_data */); -using FusedMHABackwardRunner = OpRunner; - // Describes the configuration for the algorithms that will used. // // Arguments: @@ -1731,37 +1707,6 @@ class DnnSupport { return absl::UnimplementedError("Graph support requires cuDNN >= 8.1."); }; - virtual absl::StatusOr> - FusedMHARunnerFromDesc( - Stream* stream, const AlgorithmDesc& algorithm_desc, - const MatmulTensorDescriptor& bmm1_lhs_descriptor, - const MatmulTensorDescriptor& bmm1_rhs_descriptor, - const MatmulTensorDescriptor& bmm2_rhs_descriptor, - const MatmulTensorDescriptor& intermediate_bmm2_lhs_descriptor, - const TensorDescriptor& output_descriptor, - std::optional activation_descriptor, - std::optional bias_descriptor, double scale, - std::optional dropout_rate, std::optional seed, - dnn::FMHAMaskKind mask_type); - - virtual absl::StatusOr> - FusedMHABackwardRunnerFromDesc( - Stream* stream, const AlgorithmDesc& algorithm_desc, - const MatmulTensorDescriptor& bmm1_grad_gemm1_rhs_descriptor, - const MatmulTensorDescriptor& bmm1_grad_gemm2_rhs_descriptor, - const MatmulTensorDescriptor& bmm2_grad_gemm1_lhs_descriptor, - const MatmulTensorDescriptor& bmm2_grad_gemm2_rhs_descriptor, - const MatmulTensorDescriptor& d_output_descriptor, - const TensorDescriptor& d_bmm1_lhs_descriptor, - const TensorDescriptor& d_bmm1_rhs_descriptor, - const TensorDescriptor& d_bmm2_rhs_descriptor, - std::optional d_s_descriptor, - std::optional d_bias_descriptor, - std::optional fwd_output_descriptor, - std::optional bias_descriptor, double scale, - std::optional dropout_rate, std::optional seed, - dnn::FMHAMaskKind mask_type, bool force_deterministic); - virtual bool GetMIOpenConvolveAlgorithms( ConvolutionKind kind, DataType element_type, DataType output_type, Stream* stream, const BatchDescriptor& input_descriptor, diff --git a/third_party/xla/xla/stream_executor/lazy_op_runner.h b/third_party/xla/xla/stream_executor/lazy_op_runner.h index c74a03e1ad5226..bf964e05bbaae6 100644 --- a/third_party/xla/xla/stream_executor/lazy_op_runner.h +++ b/third_party/xla/xla/stream_executor/lazy_op_runner.h @@ -280,76 +280,6 @@ struct FusedMatmulOp { } }; -struct FusedMHAOp { - using Signature = FusedMHASignature; - struct Config { - double scale; - const MatmulTensorDescriptor& bmm1_lhs_descriptor; - const MatmulTensorDescriptor& bmm1_rhs_descriptor; - const MatmulTensorDescriptor& bmm2_rhs_descriptor; - const MatmulTensorDescriptor& intermediate_bmm2_lhs_descriptor; - const TensorDescriptor& output_descriptor; - std::optional bias_descriptor; - std::optional activation_descriptor; - std::optional dropout_rate; - std::optional seed; - FMHAMaskKind mask_type; - }; - - static absl::StatusOr>> - RunnerFromAlgorithmDesc(const AlgorithmDesc& desc, Config config, - Stream* stream) { - TF_ASSIGN_OR_RETURN(auto dnn, internal::GetDnnFromStream(stream)); - return dnn->FusedMHARunnerFromDesc( - stream, desc, config.bmm1_lhs_descriptor, config.bmm1_rhs_descriptor, - config.bmm2_rhs_descriptor, config.intermediate_bmm2_lhs_descriptor, - config.output_descriptor, config.activation_descriptor, - config.bias_descriptor, config.scale, config.dropout_rate, config.seed, - config.mask_type); - } -}; - -struct FusedMHABackwardOp { - using Signature = FusedMHABackwardSignature; - - struct Config { - double scale; - const MatmulTensorDescriptor& bmm1_grad_gemm1_rhs_descriptor; - const MatmulTensorDescriptor& bmm1_grad_gemm2_rhs_descriptor; - const MatmulTensorDescriptor& bmm2_grad_gemm1_lhs_descriptor; - const MatmulTensorDescriptor& 
bmm2_grad_gemm2_rhs_descriptor; - const MatmulTensorDescriptor& d_output_descriptor; - const TensorDescriptor& d_bmm1_lhs_descriptor; - const TensorDescriptor& d_bmm1_rhs_descriptor; - const TensorDescriptor& d_bmm2_rhs_descriptor; - std::optional d_s_descriptor; - std::optional d_bias_descriptor; - std::optional fwd_output_descriptor; - std::optional bias_descriptor; - std::optional dropout_rate; - std::optional seed; - FMHAMaskKind mask_type; - bool force_deterministic; - }; - - static absl::StatusOr< - std::unique_ptr>> - RunnerFromAlgorithmDesc(const AlgorithmDesc& desc, Config config, - Stream* stream) { - TF_ASSIGN_OR_RETURN(auto dnn, internal::GetDnnFromStream(stream)); - return dnn->FusedMHABackwardRunnerFromDesc( - stream, desc, config.bmm1_grad_gemm1_rhs_descriptor, - config.bmm1_grad_gemm2_rhs_descriptor, - config.bmm2_grad_gemm1_lhs_descriptor, - config.bmm2_grad_gemm2_rhs_descriptor, config.d_output_descriptor, - config.d_bmm1_lhs_descriptor, config.d_bmm1_rhs_descriptor, - config.d_bmm2_rhs_descriptor, config.d_s_descriptor, - config.d_bias_descriptor, config.fwd_output_descriptor, - config.bias_descriptor, config.scale, config.dropout_rate, config.seed, - config.mask_type, config.force_deterministic); - } -}; - } // namespace dnn } // namespace stream_executor
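To tie the pieces together, here is a simplified, editorial model of the operand binding that the CudnnGraph::Execute change above performs: buffer-backed operands are bound to consecutive UIDs in order, the trailing workspace buffer (if any, or a zero-sized placeholder) is excluded, and when a dropout offset increment is configured, host pointers to the RNG seed and the running offset are bound under the next two UIDs. All names below other than CuDnnTensorUID are hypothetical, and the CuDnnTensorUID body is an assumption.

#include <cstdint>
#include <unordered_map>
#include <vector>

namespace {

int CuDnnTensorUID(int offset) { return offset + 1; }  // assumed body

struct FakeBuffer {
  const void* ptr;
  size_t size;
};

// Simplified model of the variant-pack construction in CudnnGraph::Execute.
std::unordered_map<int64_t, const void*> BuildVariantPack(
    std::vector<FakeBuffer> operands, size_t workspace_size,
    int64_t dropout_offset_increment, const int64_t* rng_seed,
    int64_t* rng_offset) {
  // Split off the trailing workspace buffer when one is expected, or when a
  // zero-sized placeholder was appended in its place.
  if (workspace_size > 0 || (!operands.empty() && operands.back().size == 0)) {
    operands.pop_back();
  }
  std::unordered_map<int64_t, const void*> variant_pack;
  auto next_uid = [uid = 0]() mutable { return CuDnnTensorUID(uid++); };
  for (const FakeBuffer& operand : operands) {
    variant_pack[next_uid()] = operand.ptr;
  }
  if (dropout_offset_increment > 0) {
    // Pass-by-value tensors: bind host pointers to the seed and the running
    // offset, advancing the offset on every execution.
    *rng_offset += dropout_offset_increment;
    variant_pack[next_uid()] = rng_seed;
    variant_pack[next_uid()] = rng_offset;
  }
  return variant_pack;
}

}  // namespace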