diff --git a/third_party/xla/xla/mlir/runtime/transforms/BUILD b/third_party/xla/xla/mlir/runtime/transforms/BUILD index d9151d90a5c200..b56d4930bfbd59 100644 --- a/third_party/xla/xla/mlir/runtime/transforms/BUILD +++ b/third_party/xla/xla/mlir/runtime/transforms/BUILD @@ -100,6 +100,12 @@ cc_library( srcs = ["compilation_pipeline_cpu.cc"], hdrs = ["compilation_pipeline_cpu.h"], compatible_with = get_compatible_with_portable(), + local_defines = select({ + "//xla/service/cpu:experimental_mlir_gpu_enabled": [ + "EXPERIMENTAL_MLIR_GPU=1", + ], + "//conditions:default": [], + }), visibility = ["//visibility:public"], deps = [ ":compilation_pipeline_options", @@ -126,8 +132,6 @@ cc_library( "@llvm-project//mlir:ControlFlowDialect", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:FuncExtensions", - "@llvm-project//mlir:GPUToGPURuntimeTransforms", - "@llvm-project//mlir:GPUTransforms", "@llvm-project//mlir:LLVMToLLVMIRTranslation", "@llvm-project//mlir:LinalgDialect", "@llvm-project//mlir:LinalgTransforms", @@ -142,7 +146,14 @@ cc_library( "@llvm-project//mlir:SparseTensorDialect", "@llvm-project//mlir:Transforms", "@llvm-project//mlir:X86VectorToLLVMIRTranslation", - ], + "@local_tsl//tsl/platform:logging", + ] + select({ + "//xla/service/cpu:experimental_mlir_gpu_enabled": [ + "@llvm-project//mlir:GPUToGPURuntimeTransforms", + "@llvm-project//mlir:GPUTransforms", + ], + "//conditions:default": [], + }), alwayslink = 1, # has pipeline registration ) diff --git a/third_party/xla/xla/mlir/runtime/transforms/compilation_pipeline_cpu.cc b/third_party/xla/xla/mlir/runtime/transforms/compilation_pipeline_cpu.cc index 50488cad7f13ff..22920be516e5d4 100644 --- a/third_party/xla/xla/mlir/runtime/transforms/compilation_pipeline_cpu.cc +++ b/third_party/xla/xla/mlir/runtime/transforms/compilation_pipeline_cpu.cc @@ -20,7 +20,6 @@ limitations under the License. #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" // from @llvm-project #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h" // from @llvm-project #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" // from @llvm-project -#include "mlir/Conversion/GPUCommon/GPUCommonPass.h" // from @llvm-project #include "mlir/Conversion/MathToLLVM/MathToLLVM.h" // from @llvm-project #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" // from @llvm-project #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" // from @llvm-project @@ -32,7 +31,6 @@ limitations under the License. #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" // from @llvm-project #include "mlir/Dialect/Func/Extensions/AllExtensions.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/GPU/Transforms/Passes.h" // from @llvm-project #include "mlir/Dialect/Linalg/IR/Linalg.h" // from @llvm-project #include "mlir/Dialect/Linalg/Passes.h" // from @llvm-project #include "mlir/Dialect/Math/IR/Math.h" // from @llvm-project @@ -57,6 +55,12 @@ limitations under the License. #include "xla/mlir/runtime/transforms/compiler.h" #include "xla/mlir/runtime/transforms/passes.h" #include "xla/mlir_hlo/transforms/passes.h" +#include "tsl/platform/logging.h" + +#ifdef EXPERIMENTAL_MLIR_GPU +#include "mlir/Conversion/GPUCommon/GPUCommonPass.h" // from @llvm-project +#include "mlir/Dialect/GPU/Transforms/Passes.h" // from @llvm-project +#endif // EXPERIMENTAL_MLIR_GPU namespace xla { namespace runtime { @@ -146,6 +150,7 @@ static void CreateXlaCpuCompilationPipeline(mlir::OpPassManager& pm, llvm_options.enableAvx2 = opts.math_avx2; pm.addPass(mlir::hlo::createGenericHostToLLVMPass(llvm_options)); const bool gpuCodegen = opts.xla_cpu_sparse_cuda_threads > 0; +#ifdef EXPERIMENTAL_MLIR_GPU if (gpuCodegen) { #ifdef MLIR_GPU_TO_CUBIN_PASS_ENABLE pm.addNestedPass( @@ -154,6 +159,10 @@ static void CreateXlaCpuCompilationPipeline(mlir::OpPassManager& pm, #endif pm.addPass(mlir::createGpuToLLVMConversionPass()); } +#else // EXPERIMENTAL_MLIR_GPU + CHECK(!gpuCodegen) + << "Experimental MLIR GPU code generation was not enabled at build time"; +#endif // EXPERIMENTAL_MLIR_GPU pm.addPass(mlir::createReconcileUnrealizedCastsPass()); // Prepare module for translation to LLVM. diff --git a/third_party/xla/xla/service/cpu/BUILD b/third_party/xla/xla/service/cpu/BUILD index 6f782102a843b6..448ecbc8e1bf6c 100644 --- a/third_party/xla/xla/service/cpu/BUILD +++ b/third_party/xla/xla/service/cpu/BUILD @@ -2,6 +2,7 @@ # LLVM-based CPU backend for XLA. load("@bazel_skylib//rules:build_test.bzl", "build_test") +load("@bazel_skylib//rules:common_settings.bzl", "bool_flag") load( "//xla:xla.bzl", "ORC_JIT_MEMORY_MAPPER_TARGETS", @@ -45,6 +46,19 @@ filegroup( visibility = ["//visibility:public"], ) +bool_flag( + name = "experimental_mlir_gpu", + build_setting_default = False, +) + +config_setting( + name = "experimental_mlir_gpu_enabled", + flag_values = { + ":experimental_mlir_gpu": "True", + }, + visibility = ["//visibility:public"], +) + cc_library( name = "test_header_helper", testonly = True, @@ -468,6 +482,10 @@ cc_library( name = "hlo_xla_runtime_pipeline", srcs = ["hlo_xla_runtime_pipeline.cc"], hdrs = ["hlo_xla_runtime_pipeline.h"], + local_defines = select({ + ":experimental_mlir_gpu_enabled": ["EXPERIMENTAL_MLIR_GPU=1"], + "//conditions:default": [], + }), visibility = ["//visibility:public"], deps = [ "//xla:status", @@ -483,8 +501,6 @@ cc_library( "@llvm-project//mlir:ComplexToStandard", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:FuncTransforms", - "@llvm-project//mlir:GPUDialect", - "@llvm-project//mlir:GPUToNVVMTransforms", "@llvm-project//mlir:LinalgTransforms", "@llvm-project//mlir:MemRefTransforms", "@llvm-project//mlir:Pass", @@ -502,7 +518,13 @@ cc_library( "@llvm-project//mlir:VectorTransforms", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", - ], + ] + select({ + ":experimental_mlir_gpu_enabled": [ + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:GPUToNVVMTransforms", + ], + "//conditions:default": [], + }), alwayslink = 1, # has pipeline registration ) diff --git a/third_party/xla/xla/service/cpu/hlo_xla_runtime_pipeline.cc b/third_party/xla/xla/service/cpu/hlo_xla_runtime_pipeline.cc index caff8ee13e90f7..0dafc4c9235d3e 100644 --- a/third_party/xla/xla/service/cpu/hlo_xla_runtime_pipeline.cc +++ b/third_party/xla/xla/service/cpu/hlo_xla_runtime_pipeline.cc @@ -20,7 +20,6 @@ limitations under the License. #include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h" #include "mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h" // from @llvm-project #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h" // from @llvm-project -#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" // from @llvm-project #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" // from @llvm-project #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" // from @llvm-project #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h" // from @llvm-project @@ -33,7 +32,6 @@ limitations under the License. #include "mlir/Dialect/Bufferization/Transforms/Passes.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/Func/Transforms/Passes.h" // from @llvm-project -#include "mlir/Dialect/GPU/IR/GPUDialect.h" // from @llvm-project #include "mlir/Dialect/Linalg/Passes.h" // from @llvm-project #include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h" // from @llvm-project #include "mlir/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.h" // from @llvm-project @@ -56,6 +54,11 @@ limitations under the License. #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" +#ifdef EXPERIMENTAL_MLIR_GPU +#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" // from @llvm-project +#include "mlir/Dialect/GPU/IR/GPUDialect.h" // from @llvm-project +#endif // EXPERIMENTAL_MLIR_GPU + namespace xla { namespace cpu { namespace { @@ -111,6 +114,7 @@ void AddSparsificationPasses(mlir::OpPassManager& pm, bool new_deallocator, pm.addNestedPass(mlir::createCanonicalizerPass()); pm.addNestedPass( mlir::bufferization::createFinalizingBufferizePass()); +#ifdef EXPERIMENTAL_MLIR_GPU // Sparse GPU acceleration lowers to GPU dialect. if (gpu_codegen) { pm.addPass( @@ -120,6 +124,10 @@ void AddSparsificationPasses(mlir::OpPassManager& pm, bool new_deallocator, pm.addNestedPass( mlir::createConvertGpuOpsToNVVMOps()); } +#else // EXPERIMENTAL_MLIR_GPU + CHECK(!gpu_codegen) + << "Experimental MLIR GPU code generation was not enabled at build time"; +#endif // EXPERIMENTAL_MLIR_GPU } void AddSparsificationPassPipeline(mlir::OpPassManager& pm) {