Lower a block pointer load to address payload create/set/load operations rather than 2DBlockRead (#1209)

Rather than lowering `TritonGEN::Matrix2DBlockLoadOp` to
`GenISA.LSC2DBlockRead` calls, generate calls to (a) create an address
payload, (b) update the payload block coordinates, and (c) load the
memory pointed to by the address payload pointer:
```
    addr_payload = llvm.call @__builtin_IB_subgroup_createBlock2DAddressPayload(to_long(ptr), ...) ... : (i64, i32, i32, i32, i32, i32, i32, i32, i32) -> !llvm.ptr
    llvm.call @__builtin_IB_subgroup_setBlock2DAddressPayloadBlockX(addr_payload, ...) ... : (!llvm.ptr, i32) -> ()
    llvm.call @__builtin_IB_subgroup_setBlock2DAddressPayloadBlockY(addr_payload, ...) ... : (!llvm.ptr, i32) -> ()
    load_A = llvm.call @__builtin_IB_subgroup_block_read_ap_u16_m8k16v1(addr_payload, 0, 0, 0) : (!llvm.ptr, i32, i32, i32) -> vector<8xi16>
```
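
For reference, the suffix of the read builtin encodes the load shape:
`u16` is the element size in bits, `m8`/`k16` the tile height/width,
and `v1` the number of blocks; loads with VNNI transform use the
`transform_` variant (e.g.
`__builtin_IB_subgroup_block_read_ap_transform_u16_m16k16v1`).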
Note: this PR generates the code sequence above in place (where the
`TritonGEN::Matrix2DBlockLoadOp` is). A subsequent PR will hoist
loop-invariant calls out of loops where applicable, as sketched below.
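
For illustration, the hoisted form might look roughly like this (a
hypothetical sketch of the follow-up, not code generated by this PR;
names and loop structure are made up):
```
    // Loop-invariant: create the address payload once, before the loop.
    addr_payload = llvm.call @__builtin_IB_subgroup_createBlock2DAddressPayload(to_long(ptr), ...) ... : (i64, i32, i32, i32, i32, i32, i32, i32, i32) -> !llvm.ptr
    ...
^loop:
    // Per iteration: update only the block coordinates, then load.
    llvm.call @__builtin_IB_subgroup_setBlock2DAddressPayloadBlockX(addr_payload, x) ... : (!llvm.ptr, i32) -> ()
    llvm.call @__builtin_IB_subgroup_setBlock2DAddressPayloadBlockY(addr_payload, y) ... : (!llvm.ptr, i32) -> ()
    load_A = llvm.call @__builtin_IB_subgroup_block_read_ap_u16_m8k16v1(addr_payload, 0, 0, 0) : (!llvm.ptr, i32, i32, i32) -> vector<8xi16>
```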

---------

Signed-off-by: Tiotto, Ettore <ettore.tiotto@intel.com>
etiotto authored May 30, 2024
1 parent a84b0c9 commit aca33d2
Showing 3 changed files with 141 additions and 3 deletions.
1 change: 1 addition & 0 deletions include/triton/Tools/Sys/GetEnv.hpp
@@ -29,6 +29,7 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
    "USE_TTGIR_LOC",
    "NVPTX_ENABLE_DUMP",
    "TRITON_INTEL_ENABLE_BLOCK_PTR",
    "TRITON_INTEL_ENABLE_ADDRESS_PAYLOAD_OPT",
    // clang-format on
};
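
(Registering the new variable in this set ensures that toggling it
invalidates Triton's compilation cache.)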

61 changes: 61 additions & 0 deletions test/TritonGEN/tritongen-addr-payload-opt.mlir
@@ -0,0 +1,61 @@
// RUN: TRITON_INTEL_ENABLE_ADDRESS_PAYLOAD_OPT=1 triton-opt %s --convert-triton-intel-gpu-to-llvm | FileCheck %s --implicit-check-not=llvm.inline_asm

#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 16], warpsPerCTA = [16, 2], order = [1, 0]}>
#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], A = [8, 16], B = [16, 16], C = [8, 16]}>
#dot0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>
#dot1 = #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>

// COM: Test that, instead of 2D block reads, the compiler generates address payload create/set/load builtins.
// CHECK-DAG: llvm.func spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32> attributes {passthrough = ["convergent"]}
// CHECK-DAG: llvm.func spir_funccc @__builtin_IB_subgroup_block_read_ap_transform_u16_m16k16v1(!llvm.ptr, i32, i32, i32) -> vector<8xi32> attributes {passthrough = ["convergent"]}
// CHECK-DAG: llvm.func spir_funccc @__builtin_IB_subgroup_block_read_ap_u16_m8k16v1(!llvm.ptr, i32, i32, i32) -> vector<8xi16> attributes {passthrough = ["convergent"]}
// CHECK-DAG: llvm.func spir_funccc @__builtin_IB_subgroup_setBlock2DAddressPayloadBlockY(!llvm.ptr, i32) attributes {passthrough = ["convergent"]}
// CHECK-DAG: llvm.func spir_funccc @__builtin_IB_subgroup_setBlock2DAddressPayloadBlockX(!llvm.ptr, i32) attributes {passthrough = ["convergent"]}
// CHECK-DAG: llvm.func spir_funccc @__builtin_IB_subgroup_createBlock2DAddressPayload(i64, i32, i32, i32, i32, i32, i32, i32, i32) -> !llvm.ptr attributes {passthrough = ["convergent"]}

module attributes {"triton_gpu.num-warps" = 32 : i32, triton_gpu.shared = 33792 : i32, triton_gpu.target = "xpu:DEVICE_ARCH.PVC", "triton_gpu.threads-per-warp" = 16 : i32} {
  tt.func public @matmul_kernel_with_addr_payload_opt(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64, %arg7: i64) {
    // CHECK-LABEL: @matmul_kernel_with_addr_payload_opt
    // CHECK: [[CMP:%.*]] = llvm.icmp "slt" {{.*}}, %arg4 : i64
    // CHECK: llvm.cond_br [[CMP]], ^bb2, ^bb3
    // CHECK: ^bb2:
    // CHECK: [[PTRTOINT_1:%.*]] = llvm.ptrtoint {{.*}} : !llvm.ptr<1> to i64
    // CHECK: [[ADDR_PAYLOAD_1:%.*]] = llvm.call @__builtin_IB_subgroup_createBlock2DAddressPayload([[PTRTOINT_1]], {{.*}}) {passthrough = ["convergent"]} : (i64, i32, i32, i32, i32, i32, i32, i32, i32) -> !llvm.ptr
    // CHECK-DAG: llvm.call @__builtin_IB_subgroup_setBlock2DAddressPayloadBlockX([[ADDR_PAYLOAD_1]], {{.*}}) {passthrough = ["convergent"]} : (!llvm.ptr, i32) -> ()
    // CHECK-DAG: llvm.call @__builtin_IB_subgroup_setBlock2DAddressPayloadBlockY([[ADDR_PAYLOAD_1]], {{.*}}) {passthrough = ["convergent"]} : (!llvm.ptr, i32) -> ()
    // CHECK: [[ZERO_1:%.*]] = llvm.mlir.constant(0 : i32) : i32
    // CHECK: llvm.call @__builtin_IB_subgroup_block_read_ap_u16_m8k16v1([[ADDR_PAYLOAD_1]], [[ZERO_1]], [[ZERO_1]], [[ZERO_1]]) {passthrough = ["convergent"]} : (!llvm.ptr, i32, i32, i32) -> vector<8xi16>
    //
    // CHECK: [[PTRTOINT_2:%.*]] = llvm.ptrtoint {{.*}} : !llvm.ptr<1> to i64
    // CHECK: [[ADDR_PAYLOAD_2:%.*]] = llvm.call @__builtin_IB_subgroup_createBlock2DAddressPayload([[PTRTOINT_2]], {{.*}}) {passthrough = ["convergent"]} : (i64, i32, i32, i32, i32, i32, i32, i32, i32) -> !llvm.ptr
    // CHECK-DAG: llvm.call @__builtin_IB_subgroup_setBlock2DAddressPayloadBlockX([[ADDR_PAYLOAD_2]], {{.*}}) {passthrough = ["convergent"]} : (!llvm.ptr, i32) -> ()
    // CHECK-DAG: llvm.call @__builtin_IB_subgroup_setBlock2DAddressPayloadBlockY([[ADDR_PAYLOAD_2]], {{.*}}) {passthrough = ["convergent"]} : (!llvm.ptr, i32) -> ()
    // CHECK: [[ZERO_2:%.*]] = llvm.mlir.constant(0 : i32) : i32
    // CHECK: llvm.call @__builtin_IB_subgroup_block_read_ap_transform_u16_m16k16v1([[ADDR_PAYLOAD_2]], [[ZERO_2]], [[ZERO_2]], [[ZERO_2]]) {passthrough = ["convergent"]} : (!llvm.ptr, i32, i32, i32) -> vector<8xi32>
    // CHECK: ^bb3:
    // CHECK: llvm.return

    %cst = arith.constant dense<0.000000e+00> : tensor<8x8xf32, #mma>
    %c32_i32 = arith.constant 32 : i32
    %c32_i64 = arith.constant 32 : i64
    %c0_i32 = arith.constant 0 : i32
    %c0_i64 = arith.constant 0 : i64
    %c1_i64 = arith.constant 1 : i64
    %18 = tt.make_tensor_ptr %arg0, [%arg3, %arg5], [%arg6, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<8x16xf16, #dot0>>
    %22 = tt.make_tensor_ptr %arg1, [%arg5, %arg4], [%arg7, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<16x8xf16, #dot1>>
    cf.br ^bb1(%c0_i64, %cst, %18, %22 : i64, tensor<8x8xf32, #mma>, !tt.ptr<tensor<8x16xf16, #dot0>>, !tt.ptr<tensor<16x8xf16, #dot1>>)
  ^bb1(%23: i64, %24: tensor<8x8xf32, #mma>, %25: !tt.ptr<tensor<8x16xf16, #dot0>>, %26: !tt.ptr<tensor<16x8xf16, #dot1>>): // 2 preds: ^bb0, ^bb2
    %27 = arith.cmpi slt, %23, %arg5 : i64
    cf.cond_br %27, ^bb2, ^bb3
  ^bb2: // pred: ^bb1
    %28 = tt.load %25 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<8x16xf16, #dot0>>
    %29 = tt.load %26 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<16x8xf16, #dot1>>
    %30 = tt.dot %28, %29, %24, inputPrecision = tf32 : tensor<8x16xf16, #dot0> * tensor<16x8xf16, #dot1> -> tensor<8x8xf32, #mma>
    %31 = tt.advance %25, [%c0_i32, %c32_i32] : <tensor<8x16xf16, #dot0>>
    %32 = tt.advance %26, [%c32_i32, %c0_i32] : <tensor<16x8xf16, #dot1>>
    %33 = arith.addi %23, %c32_i64 : i64
    cf.br ^bb1(%33, %30, %31, %32 : i64, tensor<8x8xf32, #mma>, !tt.ptr<tensor<8x16xf16, #dot0>>, !tt.ptr<tensor<16x8xf16, #dot1>>)
  ^bb3: // pred: ^bb1
    tt.return
  }
}
82 changes: 79 additions & 3 deletions third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp
@@ -7,9 +7,11 @@
//===----------------------------------------------------------------------===//

#include "intel/include/TritonGENToLLVM/TritonGENToLLVMPass.h"
#include "TritonGENToLLVM/GenIntrinsicEnum.h"
#include "intel/include/Dialect/TritonGEN/IR/TritonGENDialect.h"
#include "intel/include/TritonGENToLLVM/GenIntrinsics.h"

#include "TritonGENToLLVM/GenIntrinsicEnum.h"

#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
@@ -28,11 +28,11 @@
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/Support/ErrorHandling.h"
#include <string>
#include <type_traits>

#include "intel/include/Dialect/TritonGEN/IR/TritonGENDialect.h"

#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
#include "triton/Tools/Sys/GetEnv.hpp"

namespace mlir::triton {
#define GEN_PASS_DEF_CONVERTTRITONGENTOLLVM
@@ -336,6 +338,73 @@ createGenISA2DBlockRead(TritonGEN::Matrix2DBlockLoadOp op,
  return rewriter.create<LLVM::CallOp>(loc, funcOp, args);
}

// FIXME: This is a temporary solution. Remove once IGC can update the address
// payload.
static LLVM::CallOp
createBlock2DReadWithAddressPayloadUpdate(TritonGEN::Matrix2DBlockLoadOp op,
                                          ConversionPatternRewriter &rewriter) {
  MLIRContext *context = rewriter.getContext();
  Type resType = op->getResultTypes()[0];
  Location loc = op->getLoc();

  auto createBlock2DAddressPayload = [&](TritonGEN::Matrix2DBlockLoadOp op) {
    SmallVector<Type> argTypes{i64_ty, i32_ty, i32_ty, i32_ty, i32_ty,
                               i32_ty, i32_ty, i32_ty, i32_ty};
    Value zero = i32_val(0);
    Value one = i32_val(1);
    // The payload packs the base address, the base width/height/pitch (each
    // encoded minus one), the block X/Y offsets (zero here; updated by the
    // set calls below), and the tile shape.
    SmallVector<Value> args{ptrtoint(i64_ty, op.getPtr()),
                            sub(op.getBaseWidth(), one),
                            sub(op.getBaseHeight(), one),
                            sub(op.getBasePitch(), one),
                            zero,
                            zero,
                            i32_val(op.getTileWidth()),
                            i32_val(op.getTileHeight()),
                            i32_val(op.getVBlocks())};
    LLVM::CallOp callOp = createDeviceFunctionCall(
        rewriter, "__builtin_IB_subgroup_createBlock2DAddressPayload",
        ptr_ty(context), argTypes, args, true /*convergent*/);
    return callOp.getResult();
  };

  auto setBlock2DAddressPayload = [&](Value ptr,
                                      TritonGEN::Matrix2DBlockLoadOp op) {
    assert(isa<LLVM::LLVMPointerType>(ptr.getType()) &&
           "Expecting a pointer type");
    SmallVector<Type> argTypes{ptr.getType(), i32_ty};
    createDeviceFunctionCall(
        rewriter, "__builtin_IB_subgroup_setBlock2DAddressPayloadBlockX",
        LLVM::LLVMVoidType::get(context), argTypes, {ptr, op.getX()},
        true /*convergent*/);
    createDeviceFunctionCall(
        rewriter, "__builtin_IB_subgroup_setBlock2DAddressPayloadBlockY",
        LLVM::LLVMVoidType::get(context), argTypes, {ptr, op.getY()},
        true /*convergent*/);
  };

  auto createBlock2DRead = [&](Value ptr, TritonGEN::Matrix2DBlockLoadOp op) {
    assert(isa<LLVM::LLVMPointerType>(ptr.getType()) &&
           "Expecting a pointer type");

    // Mangle the builtin name from the load shape, e.g.
    // __builtin_IB_subgroup_block_read_ap_u16_m8k16v1 reads one block of
    // 16-bit elements from an 8x16 tile; VNNI loads add "transform_".
    std::string fnName = "__builtin_IB_subgroup_block_read_ap_";
    if (op.getVnniTransform())
      fnName += "transform_";
    fnName += "u" + std::to_string(op.getElemSizeInBits()) + "_m" +
              std::to_string(op.getTileHeight()) + "k" +
              std::to_string(op.getTileWidth()) + "v" +
              std::to_string(op.getVBlocks());
    Value zero = i32_val(0);
    SmallVector<Type> argTypes{ptr.getType(), i32_ty, i32_ty, i32_ty};
    SmallVector<Value> args{ptr, zero, zero, zero};
    return createDeviceFunctionCall(rewriter, fnName, resType, argTypes, args,
                                    true /*convergent*/);
  };

  Value ptr = createBlock2DAddressPayload(op);
  setBlock2DAddressPayload(ptr, op);
  return createBlock2DRead(ptr, op);
}

static LLVM::CallOp
createGenISA2DBlockWrite(TritonGEN::Matrix2DBlockStoreOp op,
                         ConversionPatternRewriter &rewriter) {
@@ -889,6 +958,13 @@ struct TritonMatrix2DBlockLoadLowering
  LogicalResult
  matchAndRewrite(TritonGEN::Matrix2DBlockLoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (tools::getBoolEnv("TRITON_INTEL_ENABLE_ADDRESS_PAYLOAD_OPT")) {
      LLVM::CallOp callOp =
          createBlock2DReadWithAddressPayloadUpdate(op, rewriter);
      rewriter.replaceOp(op, callOp);
      return success();
    }

    LLVM::CallOp callOp = createGenISA2DBlockRead(op, rewriter);
    rewriter.replaceOp(op, callOp);
    return success();
