From aca33d2dbf0c6ff83f84ca34e1d15eab8e62ddd7 Mon Sep 17 00:00:00 2001
From: Ettore Tiotto
Date: Thu, 30 May 2024 10:44:16 -0400
Subject: [PATCH] Lower a block pointer load to address payload
 create/set/load operations rather than 2DBlockRead (#1209)

Rather than lowering `TritonGEN::Matrix2DBlockLoadOp` to a
`GenISA.LSC2DBlockRead` call, generate calls that (a) create an address
payload, (b) set the payload's block coordinates, and (c) load the memory
referenced by the address payload:

```
addr_payload = llvm.call @__builtin_IB_subgroup_createBlock2DAddressPayload(to_long(ptr), ...) : (i64, i32, i32, i32, i32, i32, i32, i32, i32) -> !llvm.ptr
llvm.call @__builtin_IB_subgroup_setBlock2DAddressPayloadBlockX(addr_payload, ...) : (!llvm.ptr, i32) -> ()
llvm.call @__builtin_IB_subgroup_setBlock2DAddressPayloadBlockY(addr_payload, ...) : (!llvm.ptr, i32) -> ()
load_A = llvm.call @__builtin_IB_subgroup_block_read_ap_u16_m8k16v1(addr_payload, 0, 0, 0) : (!llvm.ptr, i32, i32, i32) -> vector<8xi16>
```

Note: this PR generates the code sequence above in place (where the
`TritonGEN::Matrix2DBlockLoadOp` was). A subsequent PR will hoist
loop-invariant calls out of loops where applicable.

---------

Signed-off-by: Tiotto, Ettore
---
 include/triton/Tools/Sys/GetEnv.hpp              |  1 +
 .../TritonGEN/tritongen-addr-payload-opt.mlir    | 61 ++++++++++++++
 .../TritonGENToLLVM/TritonGENToLLVMPass.cpp      | 82 ++++++++++++++++++-
 3 files changed, 141 insertions(+), 3 deletions(-)
 create mode 100644 test/TritonGEN/tritongen-addr-payload-opt.mlir

diff --git a/include/triton/Tools/Sys/GetEnv.hpp b/include/triton/Tools/Sys/GetEnv.hpp
index a350dfb63a..c62a0ab0c8 100644
--- a/include/triton/Tools/Sys/GetEnv.hpp
+++ b/include/triton/Tools/Sys/GetEnv.hpp
@@ -29,6 +29,7 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
     "USE_TTGIR_LOC",
     "NVPTX_ENABLE_DUMP",
     "TRITON_INTEL_ENABLE_BLOCK_PTR",
+    "TRITON_INTEL_ENABLE_ADDRESS_PAYLOAD_OPT",
     // clang-format on
 };

diff --git a/test/TritonGEN/tritongen-addr-payload-opt.mlir b/test/TritonGEN/tritongen-addr-payload-opt.mlir
new file mode 100644
index 0000000000..84eabe17db
--- /dev/null
+++ b/test/TritonGEN/tritongen-addr-payload-opt.mlir
@@ -0,0 +1,61 @@
+// RUN: TRITON_INTEL_ENABLE_ADDRESS_PAYLOAD_OPT=1 triton-opt %s --convert-triton-intel-gpu-to-llvm | FileCheck %s --implicit-check-not=llvm.inline_asm
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 16], warpsPerCTA = [16, 2], order = [1, 0]}>
+#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], A = [8, 16], B = [16, 16], C = [8, 16]}>
+#dot0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>
+#dot1 = #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>
+
+// COM: Test that, instead of 2D block reads, the compiler generates address payload create/set/load builtins.
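+// COM: The builtin declarations are matched with CHECK-DAG (rather than CHECK)
+// COM: because the lowering may emit them in any order.
+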
+// CHECK-DAG: llvm.func spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32> attributes {passthrough = ["convergent"]}
+// CHECK-DAG: llvm.func spir_funccc @__builtin_IB_subgroup_block_read_ap_transform_u16_m16k16v1(!llvm.ptr, i32, i32, i32) -> vector<8xi32> attributes {passthrough = ["convergent"]}
+// CHECK-DAG: llvm.func spir_funccc @__builtin_IB_subgroup_block_read_ap_u16_m8k16v1(!llvm.ptr, i32, i32, i32) -> vector<8xi16> attributes {passthrough = ["convergent"]}
+// CHECK-DAG: llvm.func spir_funccc @__builtin_IB_subgroup_setBlock2DAddressPayloadBlockY(!llvm.ptr, i32) attributes {passthrough = ["convergent"]}
+// CHECK-DAG: llvm.func spir_funccc @__builtin_IB_subgroup_setBlock2DAddressPayloadBlockX(!llvm.ptr, i32) attributes {passthrough = ["convergent"]}
+// CHECK-DAG: llvm.func spir_funccc @__builtin_IB_subgroup_createBlock2DAddressPayload(i64, i32, i32, i32, i32, i32, i32, i32, i32) -> !llvm.ptr attributes {passthrough = ["convergent"]}
+
+module attributes {"triton_gpu.num-warps" = 32 : i32, triton_gpu.shared = 33792 : i32, triton_gpu.target = "xpu:DEVICE_ARCH.PVC", "triton_gpu.threads-per-warp" = 16 : i32} {
+  tt.func public @matmul_kernel_with_addr_payload_opt(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64, %arg7: i64) {
+    // CHECK-LABEL: @matmul_kernel_with_addr_payload_opt
+    // CHECK:     [[CMP:%.*]] = llvm.icmp "slt" {{.*}}, %arg4 : i64
+    // CHECK:     llvm.cond_br [[CMP]], ^bb2, ^bb3
+    // CHECK:   ^bb2:
+    // CHECK:     [[PTRTOINT_1:%.*]] = llvm.ptrtoint {{.*}} : !llvm.ptr<1> to i64
+    // CHECK:     [[ADDR_PAYLOAD_1:%.*]] = llvm.call @__builtin_IB_subgroup_createBlock2DAddressPayload([[PTRTOINT_1]], {{.*}}) {passthrough = ["convergent"]} : (i64, i32, i32, i32, i32, i32, i32, i32, i32) -> !llvm.ptr
+    // CHECK-DAG: llvm.call @__builtin_IB_subgroup_setBlock2DAddressPayloadBlockX([[ADDR_PAYLOAD_1]], {{.*}}) {passthrough = ["convergent"]} : (!llvm.ptr, i32) -> ()
+    // CHECK-DAG: llvm.call @__builtin_IB_subgroup_setBlock2DAddressPayloadBlockY([[ADDR_PAYLOAD_1]], {{.*}}) {passthrough = ["convergent"]} : (!llvm.ptr, i32) -> ()
+    // CHECK:     [[ZERO_1:%.*]] = llvm.mlir.constant(0 : i32) : i32
+    // CHECK:     llvm.call @__builtin_IB_subgroup_block_read_ap_u16_m8k16v1([[ADDR_PAYLOAD_1]], [[ZERO_1]], [[ZERO_1]], [[ZERO_1]]) {passthrough = ["convergent"]} : (!llvm.ptr, i32, i32, i32) -> vector<8xi16>
+    //
+    // CHECK:     [[PTRTOINT_2:%.*]] = llvm.ptrtoint {{.*}} : !llvm.ptr<1> to i64
+    // CHECK:     [[ADDR_PAYLOAD_2:%.*]] = llvm.call @__builtin_IB_subgroup_createBlock2DAddressPayload([[PTRTOINT_2]], {{.*}}) {passthrough = ["convergent"]} : (i64, i32, i32, i32, i32, i32, i32, i32, i32) -> !llvm.ptr
+    // CHECK-DAG: llvm.call @__builtin_IB_subgroup_setBlock2DAddressPayloadBlockX([[ADDR_PAYLOAD_2]], {{.*}}) {passthrough = ["convergent"]} : (!llvm.ptr, i32) -> ()
+    // CHECK-DAG: llvm.call @__builtin_IB_subgroup_setBlock2DAddressPayloadBlockY([[ADDR_PAYLOAD_2]], {{.*}}) {passthrough = ["convergent"]} : (!llvm.ptr, i32) -> ()
+    // CHECK:     [[ZERO_2:%.*]] = llvm.mlir.constant(0 : i32) : i32
+    // CHECK:     llvm.call @__builtin_IB_subgroup_block_read_ap_transform_u16_m16k16v1([[ADDR_PAYLOAD_2]], [[ZERO_2]], [[ZERO_2]], [[ZERO_2]]) {passthrough = ["convergent"]} : (!llvm.ptr, i32, i32, i32) -> vector<8xi32>
+    // CHECK:   ^bb3:
+    // CHECK:     llvm.return
+
+    %cst = arith.constant dense<0.000000e+00> : tensor<8x8xf32, #mma>
+    %c32_i32 = arith.constant 32 : i32
+    %c32_i64 = arith.constant 32 : i64
+    %c0_i32 = arith.constant 0 : i32
+    %c0_i64 = arith.constant 0 : i64
+    %c1_i64 = arith.constant 1 : i64
+    %18 = tt.make_tensor_ptr %arg0, [%arg3, %arg5], [%arg6, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<8x16xf16, #dot0>>
+    %22 = tt.make_tensor_ptr %arg1, [%arg5, %arg4], [%arg7, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<16x8xf16, #dot1>>
+    cf.br ^bb1(%c0_i64, %cst, %18, %22 : i64, tensor<8x8xf32, #mma>, !tt.ptr<tensor<8x16xf16, #dot0>>, !tt.ptr<tensor<16x8xf16, #dot1>>)
+  ^bb1(%23: i64, %24: tensor<8x8xf32, #mma>, %25: !tt.ptr<tensor<8x16xf16, #dot0>>, %26: !tt.ptr<tensor<16x8xf16, #dot1>>):  // 2 preds: ^bb0, ^bb2
+    %27 = arith.cmpi slt, %23, %arg5 : i64
+    cf.cond_br %27, ^bb2, ^bb3
+  ^bb2:  // pred: ^bb1
+    %28 = tt.load %25 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<8x16xf16, #dot0>>
+    %29 = tt.load %26 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<16x8xf16, #dot1>>
+    %30 = tt.dot %28, %29, %24, inputPrecision = tf32 : tensor<8x16xf16, #dot0> * tensor<16x8xf16, #dot1> -> tensor<8x8xf32, #mma>
+    %31 = tt.advance %25, [%c0_i32, %c32_i32] : <tensor<8x16xf16, #dot0>>
+    %32 = tt.advance %26, [%c32_i32, %c0_i32] : <tensor<16x8xf16, #dot1>>
+    %33 = arith.addi %23, %c32_i64 : i64
+    cf.br ^bb1(%33, %30, %31, %32 : i64, tensor<8x8xf32, #mma>, !tt.ptr<tensor<8x16xf16, #dot0>>, !tt.ptr<tensor<16x8xf16, #dot1>>)
+  ^bb3:  // pred: ^bb1
+    tt.return
+  }
+}

diff --git a/third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp b/third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp
index 0531a34cf7..8213300fa1 100644
--- a/third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp
+++ b/third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp
@@ -7,9 +7,11 @@
 //===----------------------------------------------------------------------===//

 #include "intel/include/TritonGENToLLVM/TritonGENToLLVMPass.h"
-#include "TritonGENToLLVM/GenIntrinsicEnum.h"
+#include "intel/include/Dialect/TritonGEN/IR/TritonGENDialect.h"
 #include "intel/include/TritonGENToLLVM/GenIntrinsics.h"

+#include "TritonGENToLLVM/GenIntrinsicEnum.h"
+
 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
@@ -28,11 +30,11 @@
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/Support/ErrorHandling.h"
+#include <string>
 #include <type_traits>

-#include "intel/include/Dialect/TritonGEN/IR/TritonGENDialect.h"
-
 #include "triton/Conversion/TritonGPUToLLVM/Utility.h"
+#include "triton/Tools/Sys/GetEnv.hpp"

 namespace mlir::triton {
 #define GEN_PASS_DEF_CONVERTTRITONGENTOLLVM
@@ -336,6 +338,73 @@ createGenISA2DBlockRead(TritonGEN::Matrix2DBlockLoadOp op,
                         ConversionPatternRewriter &rewriter) {
   return rewriter.create<LLVM::CallOp>(loc, funcOp, args);
 }

+// FIXME: This is a temporary solution. Remove once IGC can update the address
+// payload.
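+//
+// The call sequence below is emitted in place (at the location of the
+// TritonGEN::Matrix2DBlockLoadOp being lowered); hoisting loop-invariant
+// payload creation out of loops is left to a subsequent change.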
+static LLVM::CallOp
+createBlock2DReadWithAddressPayloadUpdate(TritonGEN::Matrix2DBlockLoadOp op,
+                                          ConversionPatternRewriter &rewriter) {
+  MLIRContext *context = rewriter.getContext();
+  Type resType = op->getResultTypes()[0];
+  Location loc = op->getLoc();
+
+  // Create an address payload describing the 2D surface to read from. Note
+  // that the payload encodes the base width, height, and pitch decremented
+  // by one.
+  auto createBlock2DAddressPayload = [&](TritonGEN::Matrix2DBlockLoadOp op) {
+    SmallVector<Type> argTypes{i64_ty, i32_ty, i32_ty, i32_ty, i32_ty,
+                               i32_ty, i32_ty, i32_ty, i32_ty};
+    Value zero = i32_val(0);
+    Value one = i32_val(1);
+    SmallVector<Value> args{ptrtoint(i64_ty, op.getPtr()),
+                            sub(op.getBaseWidth(), one),
+                            sub(op.getBaseHeight(), one),
+                            sub(op.getBasePitch(), one),
+                            zero,
+                            zero,
+                            i32_val(op.getTileWidth()),
+                            i32_val(op.getTileHeight()),
+                            i32_val(op.getVBlocks())};
+    LLVM::CallOp callOp = createDeviceFunctionCall(
+        rewriter, "__builtin_IB_subgroup_createBlock2DAddressPayload",
+        ptr_ty(context), argTypes, args, true /*convergent*/);
+    return callOp.getResult();
+  };
+
+  // Set the block X and Y coordinates in the address payload.
+  auto setBlock2DAddressPayload = [&](Value ptr,
+                                      TritonGEN::Matrix2DBlockLoadOp op) {
+    assert(isa<LLVM::LLVMPointerType>(ptr.getType()) &&
+           "Expecting a pointer type");
+    SmallVector<Type> argTypes{ptr.getType(), i32_ty};
+    createDeviceFunctionCall(
+        rewriter, "__builtin_IB_subgroup_setBlock2DAddressPayloadBlockX",
+        LLVM::LLVMVoidType::get(context), argTypes, {ptr, op.getX()},
+        true /*convergent*/);
+    createDeviceFunctionCall(
+        rewriter, "__builtin_IB_subgroup_setBlock2DAddressPayloadBlockY",
+        LLVM::LLVMVoidType::get(context), argTypes, {ptr, op.getY()},
+        true /*convergent*/);
+  };
+
+  // Load the memory referenced by the address payload.
+  auto createBlock2DRead = [&](Value ptr, TritonGEN::Matrix2DBlockLoadOp op) {
+    assert(isa<LLVM::LLVMPointerType>(ptr.getType()) &&
+           "Expecting a pointer type");
+
+    std::string fnName = "__builtin_IB_subgroup_block_read_ap_";
+    if (op.getVnniTransform())
+      fnName += "transform_";
+    fnName += "u" + std::to_string(op.getElemSizeInBits()) + "_m" +
+              std::to_string(op.getTileHeight()) + "k" +
+              std::to_string(op.getTileWidth()) + "v" +
+              std::to_string(op.getVBlocks());
+    Value zero = i32_val(0);
+    SmallVector<Type> argTypes{ptr.getType(), i32_ty, i32_ty, i32_ty};
+    SmallVector<Value> args{ptr, zero, zero, zero};
+    return createDeviceFunctionCall(rewriter, fnName, resType, argTypes, args,
+                                    true /*convergent*/);
+  };
+
+  Value ptr = createBlock2DAddressPayload(op);
+  setBlock2DAddressPayload(ptr, op);
+  return createBlock2DRead(ptr, op);
+}
+
 static LLVM::CallOp
 createGenISA2DBlockWrite(TritonGEN::Matrix2DBlockStoreOp op,
                          ConversionPatternRewriter &rewriter) {
@@ -889,6 +958,13 @@ struct TritonMatrix2DBlockLoadLowering
   LogicalResult
   matchAndRewrite(TritonGEN::Matrix2DBlockLoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    if (tools::getBoolEnv("TRITON_INTEL_ENABLE_ADDRESS_PAYLOAD_OPT")) {
+      LLVM::CallOp callOp =
+          createBlock2DReadWithAddressPayloadUpdate(op, rewriter);
+      rewriter.replaceOp(op, callOp);
+      return success();
+    }
+
     LLVM::CallOp callOp = createGenISA2DBlockRead(op, rewriter);
     rewriter.replaceOp(op, callOp);
     return success();
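
Note on the `__builtin_IB_subgroup_block_read_ap_*` naming scheme: the standalone C++ sketch below mirrors the string-building logic in `createBlock2DRead` above. The helper name `blockReadBuiltinName` is hypothetical and for illustration only; the two example shapes are the ones exercised by the test in this patch.

```cpp
#include <cassert>
#include <string>

// Sketch of the name mangling performed by createBlock2DRead:
// "__builtin_IB_subgroup_block_read_ap_" [+ "transform_"]
//   + "u" <elemSizeInBits> "_m" <tileHeight> "k" <tileWidth> "v" <vBlocks>
static std::string blockReadBuiltinName(unsigned elemSizeInBits,
                                        unsigned tileHeight, unsigned tileWidth,
                                        unsigned vBlocks, bool vnniTransform) {
  std::string fnName = "__builtin_IB_subgroup_block_read_ap_";
  if (vnniTransform)
    fnName += "transform_";
  fnName += "u" + std::to_string(elemSizeInBits) + "_m" +
            std::to_string(tileHeight) + "k" + std::to_string(tileWidth) +
            "v" + std::to_string(vBlocks);
  return fnName;
}

int main() {
  // The two shapes exercised by the test above: the plain A-operand load and
  // the VNNI-transformed B-operand load.
  assert(blockReadBuiltinName(16, 8, 16, 1, false) ==
         "__builtin_IB_subgroup_block_read_ap_u16_m8k16v1");
  assert(blockReadBuiltinName(16, 16, 16, 1, true) ==
         "__builtin_IB_subgroup_block_read_ap_transform_u16_m16k16v1");
}
```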