From 61514ad415a16a290f35f29a6717f4d52497c552 Mon Sep 17 00:00:00 2001
From: "Lu, Chengjun"
Date: Fri, 23 Feb 2024 13:24:31 +0000
Subject: [PATCH] Add Intel loop pipelining to optimize loop bodies with heavy
 tt.dot computation. For now only non-nested loop bodies are handled.

---
 .../TritonIntelGPU/Transforms/Passes.h        |   6 +
 .../TritonIntelGPU/Transforms/Passes.td       |  30 ++
 test/TritonIntelGPU/loop-pipeline.mlir        | 134 +++++++
 .../TritonIntelGPUTransforms/CMakeLists.txt   |   3 +
 .../Pipeliner/MatmulLoopPipeline.cpp          | 330 ++++++++++++++++++
 .../Pipeliner/Schedule.h                      |  19 +
 .../Pipeliner/SoftwarePipeliner.cpp           |  77 ++++
 third_party/intel/triton_xpu.cc               |  10 +
 8 files changed, 609 insertions(+)
 create mode 100644 test/TritonIntelGPU/loop-pipeline.mlir
 create mode 100644 third_party/intel/lib/TritonIntelGPUTransforms/Pipeliner/MatmulLoopPipeline.cpp
 create mode 100644 third_party/intel/lib/TritonIntelGPUTransforms/Pipeliner/Schedule.h
 create mode 100644 third_party/intel/lib/TritonIntelGPUTransforms/Pipeliner/SoftwarePipeliner.cpp

diff --git a/include/triton/Dialect/TritonIntelGPU/Transforms/Passes.h b/include/triton/Dialect/TritonIntelGPU/Transforms/Passes.h
index 45ea0a4564..6a893d6932 100644
--- a/include/triton/Dialect/TritonIntelGPU/Transforms/Passes.h
+++ b/include/triton/Dialect/TritonIntelGPU/Transforms/Passes.h
@@ -35,6 +35,12 @@ std::unique_ptr<Pass> createTritonIntelGPURewriteTensorPointerPass();
 std::unique_ptr<Pass> createPrefetchBlockPass();
 std::unique_ptr<Pass> createMatchTargetSizePass();
+
+std::unique_ptr<Pass>
+createTritonIntelGPUPipelinePass(int numStages = 2,
+                                 triton::gpu::intel::DeviceArch arch =
+                                     triton::gpu::intel::DeviceArch::UNKNOWN);
+
 } // namespace intel
 } // namespace gpu
 } // namespace triton
diff --git a/include/triton/Dialect/TritonIntelGPU/Transforms/Passes.td b/include/triton/Dialect/TritonIntelGPU/Transforms/Passes.td
index c9e27bb985..688957d396 100644
--- a/include/triton/Dialect/TritonIntelGPU/Transforms/Passes.td
+++ b/include/triton/Dialect/TritonIntelGPU/Transforms/Passes.td
@@ -234,4 +234,34 @@ def TritonIntelGPUMatchTargetSize : Pass<"tritonintelgpu-match-target-size", "mlir::ModuleOp"> {
                              "mlir::triton::gpu::intel::TritonIntelGPUDialect"];
 }
 
+def TritonIntelGPUPipeline : Pass<"tritonintelgpu-pipeline", "mlir::ModuleOp"> {
+  let summary = "Intel GPU loop pipelining";
+
+  let description = [{
+    This pass pipelines loop bodies that contain heavy `tt.dot` operations.
+    On PVC the operands of each `tt.dot` are prefetched into cache; the prefetch ops form the first pipeline stage.
+    The num-stages option controls the distance between the prefetch of an operand and its use.
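+    For example, with `-tritonintelgpu-pipeline="num-stages=3 device-architecture=PVC"` (the configuration
+    used in the lit test below) the operands of a `tt.dot` are prefetched two iterations before the
+    `tt.load` that reads them.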
+  }];
+
+  let constructor = "mlir::triton::gpu::intel::createTritonIntelGPUPipelinePass()";
+
+  let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
+                           "mlir::triton::gpu::intel::TritonIntelGPUDialect",
+                           "mlir::scf::SCFDialect",
+                           "mlir::arith::ArithDialect"];
+
+  let options = [
+    Option<"numStages", "num-stages",
+           "int32_t", /*default*/"2",
+           "number of pipeline stages">,
+    Option<"deviceArch", "device-architecture",
+           "mlir::triton::gpu::intel::DeviceArch", /*default*/"mlir::triton::gpu::intel::DeviceArch::PVC",
+           "device architecture",
+           "llvm::cl::values("
+           "clEnumValN(mlir::triton::gpu::intel::DeviceArch::UNKNOWN, \"UNKNOWN\", \"Unknown arch\"), "
+           "clEnumValN(mlir::triton::gpu::intel::DeviceArch::ATS, \"ATS\", \"ATS arch\"), "
+           "clEnumValN(mlir::triton::gpu::intel::DeviceArch::PVC, \"PVC\", \"PVC arch\"))">
+  ];
+}
+
 #endif // TRITON_INTEL_GPU_PASSES
diff --git a/test/TritonIntelGPU/loop-pipeline.mlir b/test/TritonIntelGPU/loop-pipeline.mlir
new file mode 100644
index 0000000000..10d5988884
--- /dev/null
+++ b/test/TritonIntelGPU/loop-pipeline.mlir
@@ -0,0 +1,134 @@
+// RUN: triton-opt %s -split-input-file -tritonintelgpu-pipeline="num-stages=3 device-architecture=PVC" | FileCheck %s
+
+// CHECK: #[[$BLOCK_0:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 2], order = [1, 0]}>
+// CHECK: #[[$BLOCK_1:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
+// CHECK: #[[$DPAS:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 4], A = [8, 16], B = [16, 16], C = [8, 16]}>
+// CHECK-LABEL: tt.func public @matmul_kernel(
+// COM: The pipelined loop has 3 stages; the distance from the prefetch stage to the load stage is 2. The prologue therefore contains 4 prefetch ops in total: two for operand A and two for operand B.
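+// COM: In the steady state, iteration i prefetches the A and B tiles for iteration i+2 and loads the tiles whose prefetch was issued two iterations earlier.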
+// CHECK: triton_intel_gpu.prefetch %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x32x!tt.ptr, #[[$BLOCK_0]]>
+// CHECK: triton_intel_gpu.prefetch %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x256x!tt.ptr, #[[$BLOCK_1]]>
+// CHECK: triton_intel_gpu.prefetch %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x32x!tt.ptr, #[[$BLOCK_0]]>
+// CHECK: triton_intel_gpu.prefetch %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x256x!tt.ptr, #[[$BLOCK_1]]>
+// CHECK: scf.for %[[VAL_92:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[VAL_93:.*]] = %{{.*}}, %[[VAL_94:.*]] = %{{.*}}, %[[VAL_95:.*]] = %{{.*}}, %[[VAL_96:.*]] = %{{.*}}, %[[VAL_97:.*]] = %{{.*}}) -> (tensor<64x256xf32, #[[$DPAS]]>, tensor<64x32x!tt.ptr, #[[$BLOCK_0]]>, tensor<32x256x!tt.ptr, #[[$BLOCK_1]]>, tensor<64x32x!tt.ptr, #[[$BLOCK_0]]>, tensor<32x256x!tt.ptr, #[[$BLOCK_1]]>) : i32 {
+// CHECK: %[[VAL_106:.*]] = tt.addptr %[[VAL_94]], %{{.*}} : tensor<64x32x!tt.ptr, #[[$BLOCK_0]]>, tensor<64x32xi32, #[[$BLOCK_0]]>
+// CHECK: %[[VAL_107:.*]] = tt.addptr %[[VAL_95]], %{{.*}} : tensor<32x256x!tt.ptr, #[[$BLOCK_1]]>, tensor<32x256xi32, #[[$BLOCK_1]]>
+// CHECK: triton_intel_gpu.prefetch %[[VAL_106]] {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x32x!tt.ptr, #[[$BLOCK_0]]>
+// CHECK: triton_intel_gpu.prefetch %[[VAL_107]] {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x256x!tt.ptr, #[[$BLOCK_1]]>
+// CHECK: %[[VAL_116:.*]] = tt.load %[[VAL_96]], %{{.*}}, %{{.*}} : tensor<64x32x!tt.ptr, #[[$BLOCK_0]]>
+// CHECK: %[[VAL_120:.*]] = tt.load %[[VAL_97]], %{{.*}}, %{{.*}} : tensor<32x256x!tt.ptr, #[[$BLOCK_1]]>
+// CHECK: %[[VAL_121:.*]] = triton_gpu.convert_layout %[[VAL_116]] : tensor<64x32xf16, #[[$BLOCK_0]]> -> tensor<64x32xf16, #{{.*}}<{opIdx = 0, parent = #[[$DPAS]]}>>
+// CHECK: %[[VAL_122:.*]] = triton_gpu.convert_layout %[[VAL_120]] : tensor<32x256xf16, #[[$BLOCK_1]]> -> tensor<32x256xf16, #{{.*}}<{opIdx = 1, parent = #[[$DPAS]]}>>
+// CHECK: %[[VAL_123:.*]] = tt.dot %[[VAL_121]], %[[VAL_122]], %[[VAL_93]], inputPrecision = tf32 : tensor<64x32xf16, #{{.*}}<{opIdx = 0, parent = #[[$DPAS]]}>> * tensor<32x256xf16, #{{.*}}<{opIdx = 1, parent = #[[$DPAS]]}>> -> tensor<64x256xf32, #[[$DPAS]]>
+// COM: The loop-carried dependency distance is extended to span the 3 pipeline stages. The scf.for iter args must match the scf.yield operands.
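+// COM: The yield rotates the pointer iter args: the newly advanced pointers feed the next prefetch, while the previously prefetched pointers move into the load positions, giving the loads a two-iteration lag behind the prefetches.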
+// CHECK: scf.yield %[[VAL_123]], %[[VAL_106]], %[[VAL_107]], %[[VAL_94]], %[[VAL_95]] : tensor<64x256xf32, #[[$DPAS]]>, tensor<64x32x!tt.ptr, #[[$BLOCK_0]]>, tensor<32x256x!tt.ptr, #[[$BLOCK_1]]>, tensor<64x32x!tt.ptr, #[[$BLOCK_0]]>, tensor<32x256x!tt.ptr, #[[$BLOCK_1]]> +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 2], order = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 4], order = [1, 0]}> +#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 4], A = [8, 16], B = [16, 16], C = [8, 16]}> +module attributes {"triton_gpu.compute-capability" = 0 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 16 : i32} { + tt.func public @matmul_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %c8_i32 = arith.constant 8 : i32 + %c64_i32 = arith.constant 64 : i32 + %c256_i32 = arith.constant 256 : i32 + %c32_i32 = arith.constant 32 : i32 + %cst = arith.constant dense<32> : tensor<64x32xi32, #blocked> + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %cst_0 = arith.constant dense<0.000000e+00> : tensor<64x32xf16, #blocked> + %cst_1 = arith.constant dense<0.000000e+00> : tensor<32x256xf16, #blocked1> + %c63_i32 = arith.constant 63 : i32 + %c255_i32 = arith.constant 255 : i32 + %c31_i32 = arith.constant 31 : i32 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<64x256xf32, #mma> + %0 = tt.get_program_id x : i32 + %1 = arith.addi %arg3, %c63_i32 : i32 + %2 = arith.divsi %1, %c64_i32 : i32 + %3 = arith.addi %arg4, %c255_i32 : i32 + %4 = arith.divsi %3, %c256_i32 : i32 + %5 = arith.muli %4, %c8_i32 : i32 + %6 = arith.divsi %0, %5 : i32 + %7 = arith.muli %6, %c8_i32 : i32 + %8 = arith.subi %2, %7 : i32 + %9 = arith.minsi %8, %c8_i32 : i32 + %10 = arith.remsi %0, %9 : i32 + %11 = arith.addi %7, %10 : i32 + %12 = arith.remsi %0, %5 : i32 + %13 = arith.divsi %12, %9 : i32 + %14 = arith.muli %11, %c64_i32 : i32 + %15 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %16 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %17 = tt.splat %14 : i32 -> tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %18 = tt.splat %14 : i32 -> tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %19 = arith.addi %17, %15 : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %20 = arith.addi %18, %16 : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %21 = tt.splat %arg3 : i32 -> tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %22 = arith.remsi %19, %21 : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %23 = arith.muli %13, %c256_i32 : i32 + %24 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %25 = tt.splat %23 : i32 -> tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = 
#blocked1}>> + %26 = arith.addi %25, %24 : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %27 = tt.splat %arg4 : i32 -> tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %28 = arith.remsi %26, %27 : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %29 = tt.expand_dims %22 {axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %30 = tt.splat %arg6 : i32 -> tensor<64x1xi32, #blocked> + %31 = arith.muli %29, %30 : tensor<64x1xi32, #blocked> + %32 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %33 = tt.expand_dims %32 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x32xi32, #blocked> + %34 = tt.broadcast %31 : tensor<64x1xi32, #blocked> -> tensor<64x32xi32, #blocked> + %35 = tt.broadcast %33 : tensor<1x32xi32, #blocked> -> tensor<64x32xi32, #blocked> + %36 = arith.addi %34, %35 : tensor<64x32xi32, #blocked> + %37 = tt.splat %arg0 : !tt.ptr -> tensor<64x32x!tt.ptr, #blocked> + %38 = tt.addptr %37, %36 : tensor<64x32x!tt.ptr, #blocked>, tensor<64x32xi32, #blocked> + %39 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %40 = tt.expand_dims %39 {axis = 1 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<32x1xi32, #blocked1> + %41 = tt.splat %arg7 : i32 -> tensor<32x1xi32, #blocked1> + %42 = arith.muli %40, %41 : tensor<32x1xi32, #blocked1> + %43 = tt.expand_dims %28 {axis = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x256xi32, #blocked1> + %44 = tt.broadcast %42 : tensor<32x1xi32, #blocked1> -> tensor<32x256xi32, #blocked1> + %45 = tt.broadcast %43 : tensor<1x256xi32, #blocked1> -> tensor<32x256xi32, #blocked1> + %46 = arith.addi %44, %45 : tensor<32x256xi32, #blocked1> + %47 = tt.splat %arg1 : !tt.ptr -> tensor<32x256x!tt.ptr, #blocked1> + %48 = tt.addptr %47, %46 : tensor<32x256x!tt.ptr, #blocked1>, tensor<32x256xi32, #blocked1> + %49 = arith.addi %arg5, %c31_i32 : i32 + %50 = arith.divsi %49, %c32_i32 : i32 + %51 = arith.muli %arg7, %c32_i32 : i32 + %52 = tt.splat %51 : i32 -> tensor<32x256xi32, #blocked1> + %53:3 = scf.for %arg9 = %c0_i32 to %50 step %c1_i32 iter_args(%arg10 = %cst_2, %arg11 = %38, %arg12 = %48) -> (tensor<64x256xf32, #mma>, tensor<64x32x!tt.ptr, #blocked>, tensor<32x256x!tt.ptr, #blocked1>) : i32 { + %72 = arith.muli %arg9, %c32_i32 : i32 + %73 = arith.subi %arg5, %72 : i32 + %74 = tt.splat %73 : i32 -> tensor<1x32xi32, #blocked> + %75 = arith.cmpi slt, %33, %74 : tensor<1x32xi32, #blocked> + %76 = tt.broadcast %75 : tensor<1x32xi1, #blocked> -> tensor<64x32xi1, #blocked> + %77 = tt.load %arg11, %76, %cst_0 : tensor<64x32x!tt.ptr, #blocked> + %78 = tt.splat %73 : i32 -> tensor<32x1xi32, #blocked1> + %79 = arith.cmpi slt, %40, %78 : tensor<32x1xi32, #blocked1> + %80 = tt.broadcast %79 : tensor<32x1xi1, #blocked1> -> tensor<32x256xi1, #blocked1> + %81 = tt.load %arg12, %80, %cst_1 : tensor<32x256x!tt.ptr, #blocked1> + %82 = triton_gpu.convert_layout %77 : tensor<64x32xf16, #blocked> -> tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> + %83 = triton_gpu.convert_layout %81 : tensor<32x256xf16, #blocked1> -> tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma}>> + %84 = tt.dot %82, %83, %arg10, inputPrecision = tf32 : tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = 
#mma}>> * tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma}>> -> tensor<64x256xf32, #mma> + %85 = tt.addptr %arg11, %cst : tensor<64x32x!tt.ptr, #blocked>, tensor<64x32xi32, #blocked> + %86 = tt.addptr %arg12, %52 : tensor<32x256x!tt.ptr, #blocked1>, tensor<32x256xi32, #blocked1> + scf.yield %84, %85, %86 : tensor<64x256xf32, #mma>, tensor<64x32x!tt.ptr, #blocked>, tensor<32x256x!tt.ptr, #blocked1> + } + %54 = arith.truncf %53#0 : tensor<64x256xf32, #mma> to tensor<64x256xf16, #mma> + %55 = triton_gpu.convert_layout %54 : tensor<64x256xf16, #mma> -> tensor<64x256xf16, #blocked1> + %56 = tt.expand_dims %20 {axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<64x1xi32, #blocked1> + %57 = tt.splat %arg8 : i32 -> tensor<64x1xi32, #blocked1> + %58 = arith.muli %57, %56 : tensor<64x1xi32, #blocked1> + %59 = tt.splat %arg2 : !tt.ptr -> tensor<64x1x!tt.ptr, #blocked1> + %60 = tt.addptr %59, %58 : tensor<64x1x!tt.ptr, #blocked1>, tensor<64x1xi32, #blocked1> + %61 = tt.expand_dims %26 {axis = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x256xi32, #blocked1> + %62 = tt.broadcast %60 : tensor<64x1x!tt.ptr, #blocked1> -> tensor<64x256x!tt.ptr, #blocked1> + %63 = tt.broadcast %61 : tensor<1x256xi32, #blocked1> -> tensor<64x256xi32, #blocked1> + %64 = tt.addptr %62, %63 : tensor<64x256x!tt.ptr, #blocked1>, tensor<64x256xi32, #blocked1> + %65 = tt.splat %arg3 : i32 -> tensor<64x1xi32, #blocked1> + %66 = arith.cmpi slt, %56, %65 : tensor<64x1xi32, #blocked1> + %67 = tt.splat %arg4 : i32 -> tensor<1x256xi32, #blocked1> + %68 = arith.cmpi slt, %61, %67 : tensor<1x256xi32, #blocked1> + %69 = tt.broadcast %66 : tensor<64x1xi1, #blocked1> -> tensor<64x256xi1, #blocked1> + %70 = tt.broadcast %68 : tensor<1x256xi1, #blocked1> -> tensor<64x256xi1, #blocked1> + %71 = arith.andi %69, %70 : tensor<64x256xi1, #blocked1> + tt.store %64, %55, %71 : tensor<64x256x!tt.ptr, #blocked1> + tt.return + } +} diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/CMakeLists.txt b/third_party/intel/lib/TritonIntelGPUTransforms/CMakeLists.txt index 25f13a6dc1..c5b0379fe3 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/CMakeLists.txt +++ b/third_party/intel/lib/TritonIntelGPUTransforms/CMakeLists.txt @@ -5,6 +5,8 @@ add_triton_library(TritonIntelGPUTransforms PrefetchBlock.cpp RemoveLayoutConversions.cpp RewriteTensorPointer.cpp + Pipeliner/MatmulLoopPipeline.cpp + Pipeliner/SoftwarePipeliner.cpp Utility.cpp DEPENDS @@ -13,6 +15,7 @@ add_triton_library(TritonIntelGPUTransforms LINK_LIBS PUBLIC MLIRTransforms MLIRTransformUtils + MLIRSCFTransforms TritonAnalysis TritonIR TritonGPUIR diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/Pipeliner/MatmulLoopPipeline.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/Pipeliner/MatmulLoopPipeline.cpp new file mode 100644 index 0000000000..e0b7432456 --- /dev/null +++ b/third_party/intel/lib/TritonIntelGPUTransforms/Pipeliner/MatmulLoopPipeline.cpp @@ -0,0 +1,330 @@ +#include "Schedule.h" +#include "mlir/Dialect/SCF/Transforms/Transforms.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include 
"triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonIntelGPU/IR/Dialect.h" +#include "llvm/Support/Debug.h" + +#define int_attr(num) builder.getI64IntegerAttr(num) + +using namespace mlir; +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; +namespace ttgi = mlir::triton::gpu::intel; + +// TODO: We can extra some helpers into common utilities once we add more +// schedules. + +/// Replace the yield with a new one with the given operands appended. +static void appendToYield(scf::ForOp forOp, ArrayRef newOperands) { + // Fix up the yield op. + Operation *yieldOp = forOp.getBody()->getTerminator(); + SmallVector operands(yieldOp->getOperands().begin(), + yieldOp->getOperands().end()); + operands.append(newOperands.begin(), newOperands.end()); + OpBuilder builder(yieldOp); + builder.create(yieldOp->getLoc(), operands); + yieldOp->erase(); +} + +namespace { +struct LoadDotOperand { + LoadDotOperand(tt::LoadOp load, + ttg::DotOperandEncodingAttr dotOperandEncoding, + bool needTrans = false) + : load(load), dotOperandEncoding(dotOperandEncoding) {} + tt::LoadOp load; + ttg::DotOperandEncodingAttr dotOperandEncoding; +}; +} // namespace + +// If all the transitive uses of the given value have are used by a convert to +// the same dot operand encoding, return the encoding. Otherwise return nullptr. +static ttg::DotOperandEncodingAttr allTransitiveUsesHaveDotEncoding(Value val) { + ttg::DotOperandEncodingAttr attr{nullptr}; + for (Operation *user : val.getUsers()) { + if (user->getNumResults() != 1) + return nullptr; + auto tensorType = user->getResult(0).getType().dyn_cast(); + if (!tensorType) + return nullptr; + ttg::DotOperandEncodingAttr tempAttr; + if (tensorType.getEncoding().isa()) { + tempAttr = allTransitiveUsesHaveDotEncoding(user->getResult(0)); + } else if (auto convertLayout = + llvm::dyn_cast(user)) { + auto tensorType = + convertLayout.getResult().getType().dyn_cast(); + if (!tensorType) + return nullptr; + tempAttr = + tensorType.getEncoding().dyn_cast(); + } else if (auto dotOp = llvm::dyn_cast(user)) { + auto tensorType = val.getType().dyn_cast(); + if (!tensorType) + return nullptr; + tempAttr = + tensorType.getEncoding().dyn_cast(); + } else { + return nullptr; + } + if (!tempAttr || (attr != nullptr && attr != tempAttr)) + return nullptr; + attr = tempAttr; + } + return attr; +} + +static void createPrefetchOp(scf::ForOp &forOp, tt::LoadOp loadOp, Value ptr) { + OpBuilder builder(forOp); + // Replace the load with load/prefetch in different stage. + builder.setInsertionPoint(loadOp); + Location loc = loadOp->getLoc(); + auto prefetchOp = builder.create( + loc, ptr, loadOp.getCache(), loadOp.getEvict(), loadOp.getIsVolatile()); +} + +// Return the transitive use of the load which is a dot operand. +static std::optional loadDotOperand(tt::LoadOp loadOp) { + ttg::DotOperandEncodingAttr attr = + allTransitiveUsesHaveDotEncoding(loadOp.getResult()); + if (!attr) + return std::nullopt; + return LoadDotOperand(loadOp, attr); +} + +/// Collect loads to pipeline. Return success if we can pipeline this loop +static void collectOpsToPipeline(scf::ForOp forOp, + SmallVectorImpl &ops) { + ModuleOp moduleOp = forOp->getParentOfType(); + mlir::triton::ModuleAxisInfoAnalysis axisInfoAnalysis(moduleOp); + + // We cannot use forOp.walk(...) here because we only want to visit the + // operations in the loop body block. 
+ for (Operation &op : forOp) { + if (auto loadOp = dyn_cast(&op)) { + bool candidate = false; + if (isLoadFromTensorPtr(loadOp)) { + // 2D load/store. + candidate = true; + } else { + // gather/scatter + candidate = true; + } + if (!candidate) + continue; + std::optional loadWithDotOperand = loadDotOperand(loadOp); + if (!loadWithDotOperand.has_value()) + continue; + ops.push_back(loadWithDotOperand.value()); + } + } +} + +static void createPrefetchOps(scf::ForOp &forOp, ArrayRef loads, + int numStages) { + struct prefetchLoad { + prefetchLoad(tt::LoadOp load, Value ptr) : load(load), ptr(ptr) {} + tt::LoadOp load; + Value ptr; + }; + int numBuffers = numStages - 1; + SmallVector prefetchLoads; + + for (const LoadDotOperand &loadOperand : loads) { + tt::LoadOp loadOp = loadOperand.load; + prefetchLoads.emplace_back(loadOp, loadOp.getPtr()); + } + + for (prefetchLoad &prefetchLoad : prefetchLoads) { + createPrefetchOp(forOp, prefetchLoad.load, prefetchLoad.ptr); + } +} + +// Combine the current mask with the given predicate. +static Value getPredMask(RewriterBase &rewriter, Type typeLike, + Value currentMask, Value pred) { + Type maskType = tt::getI1SameShape(typeLike); + Location loc = pred.getLoc(); + Value mask = pred; + if (maskType.isa()) { + mask = rewriter.create(loc, maskType, pred); + } + if (currentMask) { + mask = rewriter.create(loc, mask, currentMask); + } + return mask; +} + +// Function to mask operations during scheduling. +static Operation *predicateOp(RewriterBase &rewriter, Operation *op, + Value pred) { + OpBuilder::InsertionGuard guard(rewriter); + if (mlir::isMemoryEffectFree(op)) + return op; + if (isa(op)) + return op; + if (auto loadOp = dyn_cast(op)) { + rewriter.setInsertionPoint(loadOp); + Value mask = getPredMask(rewriter, loadOp.getPtr().getType(), + loadOp.getMask(), pred); + loadOp.getMaskMutable().assign(mask); + return op; + } + llvm_unreachable("don't know how to predicate this op for intel"); +} + +/// Helper to recursively add dependencies to the same stage. +static void addDep(Operation *op, DenseSet &deps, + bool includeArg = true, + DenseSet *filter = nullptr) { + if (filter && filter->count(op)) + return; + if (!deps.insert(op).second) + return; + for (Value operand : op->getOperands()) { + Value v = operand; + llvm::SmallDenseSet seen; + while (auto arg = v.dyn_cast()) { + if (!includeArg) + break; + if (!seen.insert(v).second) + break; + if (arg.getArgNumber() > 0 && arg.getOwner() == op->getBlock()) { + auto yieldOp = op->getBlock()->getTerminator(); + v = yieldOp->getOperand(arg.getArgNumber() - 1); + continue; + } + break; + } + Operation *defOp = v.getDefiningOp(); + if (defOp && defOp->getBlock() == op->getBlock()) { + addDep(defOp, deps, includeArg, filter); + } + } +} + +// Add operations to the shedule with the given stage based on the filter +// function. +static void addOps(scf::ForOp forOp, int stage, + std::vector> &schedule, + std::function filter) { + for (Operation &op : forOp.getBody()->without_terminator()) { + if (!filter(&op)) + continue; + schedule.emplace_back(&op, stage); + } +} + +// create the schedule for a matmul loop. This is ad hoc based on how we know +// matmul loops should be pipelined and is not a generic scheduler. +static std::vector> +createSchedule(scf::ForOp forOp, int numStages) { + SmallVector prefetchOps; + SmallVector loadOps; + // Find the prefetch/load ops that will go respectively in stage 0 and stage + // `numStages - 1`. All the other operations will go in stage `numStages - 1`. 
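+  // For the matmul loop in the accompanying lit test (numStages == 3) this
+  // results in roughly:
+  //   stage 0: the triton_intel_gpu.prefetch ops
+  //   stage 1: the tt.addptr ops that advance the prefetched pointers
+  //   stage 2: the tt.load ops (with their masks), the layout conversions,
+  //            the tt.dot, and every remaining op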
+ for (Operation &op : forOp.getBody()->without_terminator()) { + if (isa(op)) + prefetchOps.emplace_back(&op); + if (isa(op)) + loadOps.emplace_back(&op); + } + DenseSet prefetchAndDeps; + for (Operation *op : prefetchOps) { + addDep(op, prefetchAndDeps, false); + } + + // Find depenencies with distance of 1. + SmallVector distanceOneUsers; + for (Operation *op : prefetchAndDeps) { + for (Value operand : op->getOperands()) { + if (auto arg = operand.dyn_cast()) { + if (arg.getArgNumber() > 0 && arg.getOwner() == op->getBlock()) { + auto yieldOp = op->getBlock()->getTerminator(); + Value v = yieldOp->getOperand(arg.getArgNumber() - 1); + Operation *defOp = v.getDefiningOp(); + if (defOp) { + distanceOneUsers.push_back(defOp); + } + } + } + } + } + + // For the rest of the ops we can move then into stage 1 so that they can be + // closer to their uses. + DenseSet stage1deps; + for (Operation *op : distanceOneUsers) { + addDep(op, stage1deps, true, &prefetchAndDeps); + } + + DenseSet loadAndDeps; + for (Operation *op : loadOps) { + addDep(op, loadAndDeps, false, &prefetchAndDeps); + } + std::vector> schedule; + + // Schedule some dependencies with distance of 1 into stage 1 to reduce + // pressure. + addOps(forOp, 1, schedule, + [&](Operation *op) { return stage1deps.count(op); }); + + // Then Schedule stage 0. + addOps(forOp, 0, schedule, + [&](Operation *op) { return prefetchAndDeps.count(op); }); + + // Schedule stage `numStage - 1` first. + // Finally schedule the dot ops in stage `numStage - 1` so that they get + // pre-fetched and play well with pretech pass. + addOps(forOp, numStages - 1, schedule, + [&](Operation *op) { return loadAndDeps.count(op); }); + + addOps(forOp, numStages - 1, schedule, [&](Operation *op) { + return prefetchAndDeps.count(op) == 0 && stage1deps.count(op) == 0 && + loadAndDeps.count(op) == 0; + }); + + return schedule; +} + +bool mlir::triton::gpu::intel::preProcessLoopAndGetScheduleIntel( + scf::ForOp &forOp, int numStages, mlir::scf::PipeliningOption &options) { + // 1. First collect "interesting" operations with a stage where to schedule + // them. This gives a coarse scheduling for the loop. + SmallVector loads; + collectOpsToPipeline(forOp, loads); + if (loads.empty()) + return false; + + // 2. Convert the loads into async loads and create the prefetching. + createPrefetchOps(forOp, loads, numStages); + + // 3. Create the final schedule for the kernel loop. This will dictate the + // stages and order of operations to the pipeline expander. + std::vector> schedule = + createSchedule(forOp, numStages); + + // 4. Fill out the pipeline options. 
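+  // peelEpilogue stays false: iterations that run past the original trip count
+  // are handled by predicateOp, which folds the stage predicate into the
+  // tt.load masks, instead of by a peeled epilogue.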
+ options.getScheduleFn = + [schedule](scf::ForOp forOp, + std::vector> &s) { + s = std::move(schedule); + }; + options.peelEpilogue = false; + options.predicateFn = predicateOp; + options.supportDynamicLoops = true; + options.annotateFn = [](Operation *op, + mlir::scf::PipeliningOption::PipelinerPart part, + unsigned iteration) {}; + + return true; +} diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/Pipeliner/Schedule.h b/third_party/intel/lib/TritonIntelGPUTransforms/Pipeliner/Schedule.h new file mode 100644 index 0000000000..e69b71f964 --- /dev/null +++ b/third_party/intel/lib/TritonIntelGPUTransforms/Pipeliner/Schedule.h @@ -0,0 +1,19 @@ +#ifndef TRITON_TRITONINTELGPU_TRANSFORM_PIPELINE_SCHEDULE_H_ +#define TRITON_TRITONINTELGPU_TRANSFORM_PIPELINE_SCHEDULE_H_ + +#include "mlir/Dialect/SCF/Transforms/Transforms.h" + +namespace mlir { +namespace triton { +namespace gpu { +namespace intel { + +bool preProcessLoopAndGetScheduleIntel(scf::ForOp &forOp, int numStages, + mlir::scf::PipeliningOption &options); + +} // namespace intel +} // namespace gpu +} // namespace triton +} // namespace mlir + +#endif // TRITON_TRITONINTELGPU_TRANSFORM_PIPELINE_SCHEDULE_H_ diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/Pipeliner/SoftwarePipeliner.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/Pipeliner/SoftwarePipeliner.cpp new file mode 100644 index 0000000000..d4c3db359c --- /dev/null +++ b/third_party/intel/lib/TritonIntelGPUTransforms/Pipeliner/SoftwarePipeliner.cpp @@ -0,0 +1,77 @@ +#include "Schedule.h" +#include "mlir/Dialect/SCF/Transforms/Transforms.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonIntelGPU/IR/Dialect.h" +#include "triton/Dialect/TritonIntelGPU/Transforms/Passes.h" + +//===----------------------------------------------------------------------===// +// This file will create a schedule that will be handed over to the pipeline +// expander. +// Software pipeliners are usually separated into two pieces, one that create a +// modulo schedule and an expander that rewrites the loop and emits a prologue +// and epilogue. This pass first calls a helper that will pre-process the IR +// to create async operations and create a modulo schedule. Then we call the +// expander to generate the prologue and new loop. +//===----------------------------------------------------------------------===// + +using namespace mlir; +namespace ttgi = triton::gpu::intel; + +#define GEN_PASS_CLASSES +#include "triton/Dialect/TritonIntelGPU/Transforms/Passes.h.inc" + +static void pipelineLoop(scf::ForOp forOp, int numStages) { + mlir::scf::PipeliningOption options; + // Skip loop with distance > 1 for now. + // TODO: relax the constraint in the expander. 
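+  // A yielded value with no defining op is a block argument carried around
+  // unchanged, i.e. a loop-carried dependency with distance > 1, which the
+  // expander does not handle yet.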
+ if (llvm::any_of(forOp.getBody()->getTerminator()->getOperands(), + [](Value operand) { + Operation *def = operand.getDefiningOp(); + return !def; + })) + return; + + bool foundSchedule = false; + foundSchedule = mlir::triton::gpu::intel::preProcessLoopAndGetScheduleIntel( + forOp, numStages, options); + + if (!foundSchedule) + return; + + IRRewriter rewriter(forOp->getContext()); + rewriter.setInsertionPoint(forOp); + FailureOr newForOp = + mlir::scf::pipelineForLoop(rewriter, forOp, options); +} + +namespace { +struct IntelGPUPipelinePass + : public TritonIntelGPUPipelineBase { + IntelGPUPipelinePass() = default; + IntelGPUPipelinePass(int numStages, ttgi::DeviceArch arch) { + this->numStages = numStages; + this->deviceArch = arch; + } + + void runOnOperation() override { + if (this->numStages <= 1) + return; + // Only the PVC support the prefetching ops. + if (deviceArch != ttgi::DeviceArch::PVC) + return; + SmallVector loops; + getOperation()->walk([&](scf::ForOp forOp) { loops.push_back(forOp); }); + for (scf::ForOp forOp : loops) { + pipelineLoop(forOp, numStages); + } + } +}; +} // anonymous namespace + +std::unique_ptr +ttgi::createTritonIntelGPUPipelinePass(int numStages, ttgi::DeviceArch arch) { + return std::make_unique(numStages, arch); +} diff --git a/third_party/intel/triton_xpu.cc b/third_party/intel/triton_xpu.cc index fe8bb7bc26..eff46c4144 100644 --- a/third_party/intel/triton_xpu.cc +++ b/third_party/intel/triton_xpu.cc @@ -48,6 +48,16 @@ void init_triton_intel_passes_ttgpuir(py::module &&m) { pm.addPass(mlir::triton::gpu::intel:: createTritonIntelGPURemoveLayoutConversionsPass()); }); + m.def( + "add_tritonintelgpu_pipe_line_pass", + [](mlir::PassManager &self, int numStages, + mlir::triton::gpu::intel::DeviceArch arch) { + self.addPass(mlir::triton::gpu::intel::createTritonIntelGPUPipelinePass( + numStages, arch)); + }, + py::arg("pm"), py::arg("numStages"), + py::arg("arch") = mlir::triton::gpu::intel::DeviceArch::UNKNOWN); + } void init_triton_intel_passes_ttnvgpuir(py::module &&m) {