[XLA:CPU] Make an experimental dependency on the MLIR GPU dialects optional.

This code path isn't used by default, and linking in the MLIR GPU dialects increases binary size.

PiperOrigin-RevId: 587828526
hawkinsp authored and tensorflower-gardener committed Dec 4, 2023
1 parent fffb030 commit d699fc4
Showing 4 changed files with 60 additions and 10 deletions.
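
For context, the change converts an unconditional dependency into an opt-in build setting: a bazel_skylib bool_flag plus a config_setting drive a select() that adds both the EXPERIMENTAL_MLIR_GPU=1 define and the GPU-dialect deps, and the C++ pipelines guard the GPU-only passes with #ifdef EXPERIMENTAL_MLIR_GPU. The sketch below condenses that pattern from the four changed files into one place; the target names match the diff, but the deps are trimmed to the GPU ones, and the example bazel invocation in the comment is an assumption based on the usual syntax for setting a Starlark bool_flag, not something taken from this commit.

# Minimal Starlark sketch of the opt-in pattern (illustrative, not the full BUILD files).
load("@bazel_skylib//rules:common_settings.bzl", "bool_flag")

# User-settable flag; off by default, so the MLIR GPU dialects are not linked in.
bool_flag(
    name = "experimental_mlir_gpu",
    build_setting_default = False,
)

# True only when the flag is flipped, e.g. (assumed invocation syntax):
#   bazel build --//xla/service/cpu:experimental_mlir_gpu=True <targets>
config_setting(
    name = "experimental_mlir_gpu_enabled",
    flag_values = {":experimental_mlir_gpu": "True"},
)

# Consumers gate both the preprocessor define and the heavy GPU deps on the setting.
cc_library(
    name = "hlo_xla_runtime_pipeline",
    srcs = ["hlo_xla_runtime_pipeline.cc"],
    local_defines = select({
        ":experimental_mlir_gpu_enabled": ["EXPERIMENTAL_MLIR_GPU=1"],
        "//conditions:default": [],
    }),
    deps = select({
        ":experimental_mlir_gpu_enabled": [
            "@llvm-project//mlir:GPUDialect",
            "@llvm-project//mlir:GPUToNVVMTransforms",
        ],
        "//conditions:default": [],
    }),
)

With the flag at its default, the CHECK added to both pipelines turns any request for the sparse GPU path (xla_cpu_sparse_cuda_threads > 0) into an immediate, clearly worded failure instead of silently running a pipeline that was compiled without GPU support.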
17 changes: 14 additions & 3 deletions third_party/xla/xla/mlir/runtime/transforms/BUILD
@@ -100,6 +100,12 @@ cc_library(
srcs = ["compilation_pipeline_cpu.cc"],
hdrs = ["compilation_pipeline_cpu.h"],
compatible_with = get_compatible_with_portable(),
local_defines = select({
"//xla/service/cpu:experimental_mlir_gpu_enabled": [
"EXPERIMENTAL_MLIR_GPU=1",
],
"//conditions:default": [],
}),
visibility = ["//visibility:public"],
deps = [
":compilation_pipeline_options",
@@ -126,8 +132,6 @@ cc_library(
"@llvm-project//mlir:ControlFlowDialect",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:FuncExtensions",
"@llvm-project//mlir:GPUToGPURuntimeTransforms",
"@llvm-project//mlir:GPUTransforms",
"@llvm-project//mlir:LLVMToLLVMIRTranslation",
"@llvm-project//mlir:LinalgDialect",
"@llvm-project//mlir:LinalgTransforms",
@@ -142,7 +146,14 @@ cc_library(
"@llvm-project//mlir:SparseTensorDialect",
"@llvm-project//mlir:Transforms",
"@llvm-project//mlir:X86VectorToLLVMIRTranslation",
],
"@local_tsl//tsl/platform:logging",
] + select({
"//xla/service/cpu:experimental_mlir_gpu_enabled": [
"@llvm-project//mlir:GPUToGPURuntimeTransforms",
"@llvm-project//mlir:GPUTransforms",
],
"//conditions:default": [],
}),
alwayslink = 1, # has pipeline registration
)

13 changes: 11 additions & 2 deletions third_party/xla/xla/mlir/runtime/transforms/compilation_pipeline_cpu.cc
@@ -20,7 +20,6 @@ limitations under the License.
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h" // from @llvm-project
#include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h" // from @llvm-project
#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" // from @llvm-project
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h" // from @llvm-project
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h" // from @llvm-project
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" // from @llvm-project
#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" // from @llvm-project
@@ -32,7 +31,6 @@ limitations under the License.
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" // from @llvm-project
#include "mlir/Dialect/Func/Extensions/AllExtensions.h" // from @llvm-project
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/Dialect/GPU/Transforms/Passes.h" // from @llvm-project
#include "mlir/Dialect/Linalg/IR/Linalg.h" // from @llvm-project
#include "mlir/Dialect/Linalg/Passes.h" // from @llvm-project
#include "mlir/Dialect/Math/IR/Math.h" // from @llvm-project
@@ -57,6 +55,12 @@ limitations under the License.
#include "xla/mlir/runtime/transforms/compiler.h"
#include "xla/mlir/runtime/transforms/passes.h"
#include "xla/mlir_hlo/transforms/passes.h"
#include "tsl/platform/logging.h"

#ifdef EXPERIMENTAL_MLIR_GPU
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h" // from @llvm-project
#include "mlir/Dialect/GPU/Transforms/Passes.h" // from @llvm-project
#endif // EXPERIMENTAL_MLIR_GPU

namespace xla {
namespace runtime {
@@ -146,6 +150,7 @@ static void CreateXlaCpuCompilationPipeline(mlir::OpPassManager& pm,
llvm_options.enableAvx2 = opts.math_avx2;
pm.addPass(mlir::hlo::createGenericHostToLLVMPass(llvm_options));
const bool gpuCodegen = opts.xla_cpu_sparse_cuda_threads > 0;
#ifdef EXPERIMENTAL_MLIR_GPU
if (gpuCodegen) {
#ifdef MLIR_GPU_TO_CUBIN_PASS_ENABLE
pm.addNestedPass<mlir::gpu::GPUModuleOp>(
@@ -154,6 +159,10 @@ static void CreateXlaCpuCompilationPipeline(mlir::OpPassManager& pm,
#endif
pm.addPass(mlir::createGpuToLLVMConversionPass());
}
#else // EXPERIMENTAL_MLIR_GPU
CHECK(!gpuCodegen)
<< "Experimental MLIR GPU code generation was not enabled at build time";
#endif // EXPERIMENTAL_MLIR_GPU
pm.addPass(mlir::createReconcileUnrealizedCastsPass());

// Prepare module for translation to LLVM.
28 changes: 25 additions & 3 deletions third_party/xla/xla/service/cpu/BUILD
@@ -2,6 +2,7 @@
# LLVM-based CPU backend for XLA.

load("@bazel_skylib//rules:build_test.bzl", "build_test")
load("@bazel_skylib//rules:common_settings.bzl", "bool_flag")
load(
"//xla:xla.bzl",
"ORC_JIT_MEMORY_MAPPER_TARGETS",
@@ -45,6 +46,19 @@ filegroup(
visibility = ["//visibility:public"],
)

bool_flag(
name = "experimental_mlir_gpu",
build_setting_default = False,
)

config_setting(
name = "experimental_mlir_gpu_enabled",
flag_values = {
":experimental_mlir_gpu": "True",
},
visibility = ["//visibility:public"],
)

cc_library(
name = "test_header_helper",
testonly = True,
@@ -468,6 +482,10 @@ cc_library(
name = "hlo_xla_runtime_pipeline",
srcs = ["hlo_xla_runtime_pipeline.cc"],
hdrs = ["hlo_xla_runtime_pipeline.h"],
local_defines = select({
":experimental_mlir_gpu_enabled": ["EXPERIMENTAL_MLIR_GPU=1"],
"//conditions:default": [],
}),
visibility = ["//visibility:public"],
deps = [
"//xla:status",
@@ -483,8 +501,6 @@ cc_library(
"@llvm-project//mlir:ComplexToStandard",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:FuncTransforms",
"@llvm-project//mlir:GPUDialect",
"@llvm-project//mlir:GPUToNVVMTransforms",
"@llvm-project//mlir:LinalgTransforms",
"@llvm-project//mlir:MemRefTransforms",
"@llvm-project//mlir:Pass",
@@ -502,7 +518,13 @@ cc_library(
"@llvm-project//mlir:VectorTransforms",
"@local_tsl//tsl/platform:errors",
"@local_tsl//tsl/platform:logging",
],
] + select({
":experimental_mlir_gpu_enabled": [
"@llvm-project//mlir:GPUDialect",
"@llvm-project//mlir:GPUToNVVMTransforms",
],
"//conditions:default": [],
}),
alwayslink = 1, # has pipeline registration
)

12 changes: 10 additions & 2 deletions third_party/xla/xla/service/cpu/hlo_xla_runtime_pipeline.cc
@@ -20,7 +20,6 @@ limitations under the License.
#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h"
#include "mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h" // from @llvm-project
#include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h" // from @llvm-project
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" // from @llvm-project
#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" // from @llvm-project
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" // from @llvm-project
#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h" // from @llvm-project
@@ -33,7 +32,6 @@ limitations under the License.
#include "mlir/Dialect/Bufferization/Transforms/Passes.h" // from @llvm-project
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/Dialect/Func/Transforms/Passes.h" // from @llvm-project
#include "mlir/Dialect/GPU/IR/GPUDialect.h" // from @llvm-project
#include "mlir/Dialect/Linalg/Passes.h" // from @llvm-project
#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h" // from @llvm-project
#include "mlir/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.h" // from @llvm-project
@@ -56,6 +54,11 @@ limitations under the License.
#include "tsl/platform/errors.h"
#include "tsl/platform/logging.h"

#ifdef EXPERIMENTAL_MLIR_GPU
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" // from @llvm-project
#include "mlir/Dialect/GPU/IR/GPUDialect.h" // from @llvm-project
#endif // EXPERIMENTAL_MLIR_GPU

namespace xla {
namespace cpu {
namespace {
@@ -111,6 +114,7 @@ void AddSparsificationPasses(mlir::OpPassManager& pm, bool new_deallocator,
pm.addNestedPass<mlir::func::FuncOp>(mlir::createCanonicalizerPass());
pm.addNestedPass<mlir::func::FuncOp>(
mlir::bufferization::createFinalizingBufferizePass());
#ifdef EXPERIMENTAL_MLIR_GPU
// Sparse GPU acceleration lowers to GPU dialect.
if (gpu_codegen) {
pm.addPass(
@@ -120,6 +124,10 @@
pm.addNestedPass<mlir::gpu::GPUModuleOp>(
mlir::createConvertGpuOpsToNVVMOps());
}
#else // EXPERIMENTAL_MLIR_GPU
CHECK(!gpu_codegen)
<< "Experimental MLIR GPU code generation was not enabled at build time";
#endif // EXPERIMENTAL_MLIR_GPU
}

void AddSparsificationPassPipeline(mlir::OpPassManager& pm) {
