Skip to content

Commit

Permalink
[GEN] Use OCL builtin for some variants of 2D block read (#1041)
Browse files Browse the repository at this point in the history
```
// 2x ATile (MxK) block read:
ushort2 intel_subgroup_block_read_u8_m1k32v2(
    __global void* base_address,
    int width, int height, int pitch, int2 byte_coord);
ushort4 intel_subgroup_block_read_u8_m2k32v2(
    __global void* base_address,
    int width, int height, int pitch, int2 byte_coord);
ushort8 intel_subgroup_block_read_u8_m4k32v2(
    __global void* base_address,
    int width, int height, int pitch, int2 byte_coord);
ushort16 intel_subgroup_block_read_u8_m8k32v2(
    __global void* base_address,
    int width, int height, int pitch, int2 byte_coord);
ushort2 intel_subgroup_block_read_u16_m1k16v2(
    __global void* base_address,
    int width, int height, int pitch, int2 byte_coord);
ushort4 intel_subgroup_block_read_u16_m2k16v2(
    __global void* base_address,
    int width, int height, int pitch, int2 byte_coord);
ushort8 intel_subgroup_block_read_u16_m4k16v2(
    __global void* base_address,
    int width, int height, int pitch, int2 byte_coord);
ushort16 intel_subgroup_block_read_u16_m8k16v2(
    __global void* base_address,
    int width, int height, int pitch, int2 byte_coord);
```

---------

Signed-off-by: Whitney Tsang <whitney.tsang@intel.com>
  • Loading branch information
whitneywhtsang authored May 6, 2024
1 parent a6fcc09 commit 26f7cf0
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 1 deletion.
16 changes: 16 additions & 0 deletions test/TritonGEN/tritongen-to-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,22 @@ llvm.func @triton_gen.dpas.f32(%c : vector<8xf32>, %a : vector<4xf32>, %b : vect

// -----

// CHECK: llvm.func spir_funccc @intel_subgroup_block_read_u8_m8k32v2(!llvm.ptr<1>, i32, i32, i32, vector<2xi32>) -> vector<16xi16> attributes {passthrough = ["convergent"]}

// Lowering test for a plain (transpose=false, vnni_transform=false) 2D block
// load with 8-bit elements, an 8x32 tile and v_blocks=2. The expected output
// packs %x and %y into a vector<2xi32> coordinate and calls the
// intel_subgroup_block_read_u8_m8k32v2 builtin, marked convergent, instead of
// the GenISA intrinsic.
llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) {
// CHECK: llvm.func @triton_gen.2Dblockload(%arg0: !llvm.ptr<1>, %arg1: i32, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32) {
// CHECK-DAG: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK-DAG: [[ONE:%.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK-DAG: [[UNDEF:%.*]] = llvm.mlir.undef : vector<2xi32>
// CHECK-NEXT: [[COORD0:%.*]] = llvm.insertelement %arg4, [[UNDEF]][[[ZERO]] : i32] : vector<2xi32>
// CHECK-NEXT: [[COORD1:%.*]] = llvm.insertelement %arg5, [[COORD0]][[[ONE]] : i32] : vector<2xi32>
// CHECK-NEXT: llvm.call @intel_subgroup_block_read_u8_m8k32v2(%arg0, %arg1, %arg2, %arg3, [[COORD1]]) {passthrough = ["convergent"]} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>) -> vector<16xi16>
%0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=2, transpose=false, vnni_transform=false} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<16xi16>
llvm.return
}

// -----

// CHECK: llvm.func spir_funccc @llvm.genx.GenISA.LSC2DBlockRead.v8f32(i64, i32, i32, i32, i32, i32, i32, i32, i32, i32, i1, i1, i32) -> vector<8xf32>

llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) {
Expand Down
37 changes: 36 additions & 1 deletion third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -218,14 +218,49 @@ static LLVM::CallOp createGenISADPAS(TritonGEN::MatrixDPASOp op,
return rewriter.create<LLVM::CallOp>(loc, funcOp, args);
}

/// Returns true when \p op can be lowered to one of the OpenCL 2D block read
/// builtins instead of the GenISA intrinsic.
///
/// Only the following builtin families are targeted:
///   intel_subgroup_block_read_u8_m{1,2,4,8}k32v2
///   intel_subgroup_block_read_u16_m{1,2,4,8}k16v2
/// i.e. plain (non-transposed, non-VNNI) reads of two blocks, where the tile
/// width is 32 for 8-bit elements and 16 for 16-bit elements, and the tile
/// height is 1, 2, 4 or 8. The caller builds the builtin name directly from
/// the op attributes, so any other combination must be rejected here;
/// otherwise a call to an undeclared builtin would be emitted.
static bool isOCLBuiltinAvailable(TritonGEN::Matrix2DBlockLoadOp op) {
  // No builtin is used for the transposed or VNNI-transformed variants.
  if (op.getVnniTransform() || op.getTranspose())
    return false;

  // Only the two-block ("v2") variants are covered.
  if (op.getVBlocks() != 2)
    return false;

  // The tile width is tied to the element size: u8 -> k32, u16 -> k16.
  const auto elemSizeInBits = op.getElemSizeInBits();
  const auto tileWidth = op.getTileWidth();
  const bool isSupportedWidth = (elemSizeInBits == 8 && tileWidth == 32) ||
                                (elemSizeInBits == 16 && tileWidth == 16);
  if (!isSupportedWidth)
    return false;

  // Only power-of-two tile heights up to 8 have a matching builtin.
  const auto tileHeight = op.getTileHeight();
  return tileHeight == 1 || tileHeight == 2 || tileHeight == 4 ||
         tileHeight == 8;
}

static LLVM::CallOp
createGenISA2DBlockRead(TritonGEN::Matrix2DBlockLoadOp op,
ConversionPatternRewriter &rewriter) {
auto moduleOp = rewriter.getBlock()->getParent()->getParentOfType<ModuleOp>();
MLIRContext *context = rewriter.getContext();
Type resType = op->getResultTypes()[0];
Location loc = op->getLoc();

// FIXME: Use the OpenCL API also for all other variants.
if (isOCLBuiltinAvailable(op)) {
std::string fnName = "intel_subgroup_block_read_u" +
std::to_string(op.getElemSizeInBits()) + "_m" +
std::to_string(op.getTileHeight()) + "k" +
std::to_string(op.getTileWidth()) + "v" +
std::to_string(op.getVBlocks());
VectorType vecType = vec_ty(i32_ty, 2);
Value byteCoord = insert_element(
vecType, insert_element(vecType, undef(vecType), op.getX(), i32_val(0)),
op.getY(), i32_val(1));
SmallVector<Type> argTypes{ptr_ty(context, 1), i32_ty, i32_ty, i32_ty,
vecType};
SmallVector<Value> args{op.getPtr(), op.getBaseWidth(), op.getBaseHeight(),
op.getBasePitch(), byteCoord};
return createDeviceFunctionCall(rewriter, fnName, resType, argTypes, args,
true /*convergent*/);
}

auto moduleOp = rewriter.getBlock()->getParent()->getParentOfType<ModuleOp>();
Value ptr = op.getPtr();
Value baseWidth = op.getBaseWidth();
Value baseHeight = op.getBaseHeight();
Expand Down

0 comments on commit 26f7cf0

Please sign in to comment.