Commit
Change the 2D load to use a dense stride and a larger load size for the packed A and B operands.

The convert-layout lowering and index emission still need to be updated for the dot operand layout with a DPAS parent.
chengjunlu committed Apr 29, 2024
1 parent 90e1152 commit ec24ea5
Showing 2 changed files with 128 additions and 50 deletions.
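
For orientation, here is a minimal sketch (not part of the commit) of the packing arithmetic the new lowering uses: a single enlarged 2D block read covers several DPAS replicas by growing the tile height and, where a second replica exists along the other dimension, doubling the block array length (vBlocks). The struct and function names below are illustrative assumptions; the 32-row cap for operand A and the branch structure mirror the changes in LoadStoreOpToLLVM.cpp below.

// Illustrative sketch only; names are assumptions, logic mirrors the patch below.
#include <algorithm>
#include <cstdio>

struct PackedLoadShape {
  unsigned tileHeight;          // rows fetched by one 2D block read
  unsigned vBlocks;             // block array length of the read
  unsigned packedOuterPerLoad;  // outer-dim (M or N) replicas covered per read
  unsigned packedKPerLoad;      // K-dim replicas covered per read
};

// opIdx == 0 -> operand A, opIdx == 1 -> operand B.
PackedLoadShape packShape(unsigned opIdx, unsigned instrRows,
                          unsigned numRepOuter, unsigned numRepK) {
  PackedLoadShape s{instrRows, 1, 1, 1};
  if (opIdx == 0) {
    // A: stack replicas along M by growing the tile height (32 rows max),
    // and use vBlocks = 2 to fetch two K replicas per read.
    s.packedOuterPerLoad = std::min(32u / instrRows, numRepOuter);
    s.tileHeight = instrRows * s.packedOuterPerLoad;
    if (numRepK >= 2) { s.vBlocks = 2; s.packedKPerLoad = 2; }
  } else {
    // B: vBlocks stacks replicas along N; doubling the tile height stacks
    // two replicas along K.
    if (numRepOuter >= 2) { s.vBlocks = 2; s.packedOuterPerLoad = 2; }
    if (numRepK >= 2) { s.tileHeight = instrRows * 2; s.packedKPerLoad = 2; }
  }
  return s;
}

int main() {
  // Shapes from the test below: the A tile is 8x16, 2 outer replicas, 2 K replicas.
  PackedLoadShape a = packShape(0, 8, 2, 2);
  // Prints tileHeight=16 vBlocks=2; with 8 packed i16 per lane per replica this
  // gives the 32-element payload checked as vector<32xi16>.
  std::printf("A: tileHeight=%u vBlocks=%u\n", a.tileHeight, a.vBlocks);
  return 0;
}

As the test changes show, the intent is to replace several vector<8x...> block reads with fewer, larger vector<32x...> ones.
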
10 changes: 5 additions & 5 deletions test/TritonIntelGPU/intel-2d-load-to-llvm.mlir
@@ -1,8 +1,8 @@
// RUN: triton-opt %s -split-input-file --intel-allocate-shared-memory --convert-triton-intel-gpu-to-llvm | FileCheck %s
// RUN: triton-opt %s -split-input-file --intel-allocate-shared-memory --convert-triton-intel-gpu-to-llvm -canonicalize | FileCheck %s


// CHECK: llvm.func spir_funccc @llvm.genx.GenISA.LSC2DBlockRead.v8i32(i64, i32, i32, i32, i32, i32, i32, i32, i32, i32, i1, i1, i32) -> vector<8xi32>
// CHECK: llvm.func spir_funccc @llvm.genx.GenISA.LSC2DBlockRead.v8i16(i64, i32, i32, i32, i32, i32, i32, i32, i32, i32, i1, i1, i32) -> vector<8xi16>
// CHECK: llvm.func spir_funccc @llvm.genx.GenISA.LSC2DBlockRead.v32i32(i64, i32, i32, i32, i32, i32, i32, i32, i32, i32, i1, i1, i32) -> vector<32xi32>
// CHECK: llvm.func spir_funccc @llvm.genx.GenISA.LSC2DBlockRead.v32i16(i64, i32, i32, i32, i32, i32, i32, i32, i32, i32, i1, i1, i32) -> vector<32xi16>
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 4], order = [1, 0]}>
#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [4, 2], A = [8, 16], B = [16, 16], C = [8, 16]}>
module attributes {"triton_gpu.compute-capability" = 2 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
@@ -23,9 +23,9 @@ module attributes {"triton_gpu.compute-capability" = 2 : i32, "triton_gpu.num-ct
%6 = tt.make_tensor_ptr %arg1, [%1, %4], [%5, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma}>>>
%7 = tt.advance %3, [%c64_i32, %c-32_i32] : <tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>>>
%8 = tt.advance %7, [%c-64_i32, %c32_i32] : <tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>>>
// CHECK-COUNT-2: llvm.call @llvm.genx.GenISA.LSC2DBlockRead.v8i16({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}) : (i64, i32, i32, i32, i32, i32, i32, i32, i32, i32, i1, i1, i32) -> vector<8xi16>
// CHECK-COUNT: llvm.call @llvm.genx.GenISA.LSC2DBlockRead.v32i16({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}) : (i64, i32, i32, i32, i32, i32, i32, i32, i32, i32, i1, i1, i32) -> vector<32xi16>
%9 = triton_intel_gpu.load_2d %8 {cache = 1 : i32, evict = 1 : i32, isVolatile = false, padding = 1 : i32} : !tt.ptr<tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>>> -> tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>>
// CHECK-COUNT-2: llvm.call @llvm.genx.GenISA.LSC2DBlockRead.v8i32({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}) : (i64, i32, i32, i32, i32, i32, i32, i32, i32, i32, i1, i1, i32) -> vector<8xi32>
// CHECK-COUNT: llvm.call @llvm.genx.GenISA.LSC2DBlockRead.v32i32({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}) : (i64, i32, i32, i32, i32, i32, i32, i32, i32, i32, i1, i1, i32) -> vector<32xi32>
%10 = triton_intel_gpu.load_2d %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false, padding = 1 : i32} : !tt.ptr<tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma}>>> -> tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma}>>
%11 = tt.dot %9, %10, %cst, inputPrecision = tf32 : tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma}>> -> tensor<64x64xf32, #mma>
%12 = triton_gpu.convert_layout %11 : tensor<64x64xf32, #mma> -> tensor<64x64xf32, #blocked>
168 changes: 123 additions & 45 deletions third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp
@@ -470,6 +470,8 @@ struct Load2DOpConversion
using ConvertTritonGPUOpToLLVMPattern<
triton::gpu::intel::Load2DOp>::ConvertTritonGPUOpToLLVMPattern;

using ValueTable = std::map<std::pair<int, int>, Value>;

Load2DOpConversion(TritonGPUToLLVMTypeConverter &converter,
PatternBenefit benefit)
: ConvertTritonGPUOpToLLVMPattern<triton::gpu::intel::Load2DOp>(converter,
@@ -523,39 +525,83 @@ struct Load2DOpConversion
SmallVector<Value> multiDimWarpId =
delinearize(rewriter, loc, warpId, warpsPerCTA, order);

Type load2DGenXType;
Type unpackType;
int64_t elemsPerLane;
int64_t numRepOuter = numReps[opIdx];
int64_t numRepK = numReps[(opIdx == 0) ? 1 : 0];
int64_t opaqueElemPerLane;
unsigned tileHeight;
unsigned elemsPerLanePerDotOp;
unsigned vBlocks = 1;
unsigned packedOuterDimPerLoad = 1;
unsigned packedKDimPerLoad = 1;
SmallVector<int64_t> elemsPerInstr;
Type packedElemType =
opIdx == 0 ? type::i16Ty(ctx) : type::i32Ty(ctx);
if (opIdx == 0) {
auto shapeA = dpasLayout.getShapeA();
elemsPerInstr = {shapeA[0], shapeA[1]};
elemsPerLane = product<int64_t>(elemsPerInstr) /
product<unsigned>(getThreadsPerWarp(dpasLayout));
unpackType = LLVM::getFixedVectorType(
typeConverter->convertType(eltTy), elemsPerLane);
elemsPerLanePerDotOp =
product<int64_t>(elemsPerInstr) /
product<unsigned>(getThreadsPerWarp(dpasLayout));

unsigned maxPackedOuterDimPerLoad = 32 / elemsPerInstr[0];
packedOuterDimPerLoad =
std::min<unsigned>(maxPackedOuterDimPerLoad, numRepOuter);
        // Use the tileHeight to load multiple operand A tiles in one load.
tileHeight = elemsPerInstr[0] * packedOuterDimPerLoad;

if (numRepK >= 2) {
          // Double the block array length to 2 to load operand A.
vBlocks = 2;
packedKDimPerLoad *= 2;
} else {
vBlocks = 1;
}

// pack scalar to i16.
auto opsPerChannel = dpasLayout.getOpsPerChannel();
elemsPerLane = opsPerChannel == 4 ? elemsPerLane / 2 : elemsPerLane;
load2DGenXType =
LLVM::getFixedVectorType(type::i16Ty(ctx), elemsPerLane);

opaqueElemPerLane = opsPerChannel == 4 ? elemsPerLanePerDotOp / 2
: elemsPerLanePerDotOp;
opaqueElemPerLane =
opaqueElemPerLane * packedOuterDimPerLoad * packedKDimPerLoad;
} else {
auto shapeB = dpasLayout.getShapeB();
elemsPerInstr = {shapeB[0], shapeB[1]};
elemsPerLane = product<int64_t>(elemsPerInstr) /
product<unsigned>(getThreadsPerWarp(dpasLayout));
unpackType = LLVM::getFixedVectorType(
typeConverter->convertType(eltTy), elemsPerLane);
elemsPerLanePerDotOp =
product<int64_t>(elemsPerInstr) /
product<unsigned>(getThreadsPerWarp(dpasLayout));

if (numRepOuter >= 2) {
// Double the block array length to 2 to load operand B.
vBlocks = 2;
packedOuterDimPerLoad *= 2;
} else {
vBlocks = 1;
}

// pack scalar to i32 for load.
if (numRepK >= 2) {
// Double tileHeight to load operand B.
tileHeight = elemsPerInstr[0] * 2;
packedKDimPerLoad *= 2;
} else {
tileHeight = elemsPerInstr[0];
}

// pack scalar to i32.
auto opsPerChannel = dpasLayout.getOpsPerChannel();
elemsPerLane = elemsPerLane / opsPerChannel;
load2DGenXType =
LLVM::getFixedVectorType(type::i32Ty(ctx), elemsPerLane);
opaqueElemPerLane = (elemsPerLanePerDotOp / opsPerChannel);
opaqueElemPerLane =
opaqueElemPerLane * packedOuterDimPerLoad * packedKDimPerLoad;
}

Type load2DGenXType =
LLVM::getFixedVectorType(packedElemType, opaqueElemPerLane);
Type decomposedType = LLVM::getFixedVectorType(
packedElemType,
opaqueElemPerLane / packedOuterDimPerLoad / packedKDimPerLoad);
Type unpackType = LLVM::getFixedVectorType(
typeConverter->convertType(eltTy), elemsPerLanePerDotOp);

// Load the operand.
    // Outer dim: A is M, B is N. Inner dim: K.
int outerDimWarpNum =
std::min<int>(warpsPerCTA[opIdx],
@@ -570,26 +616,26 @@
colStride, base) =
getValuesFromBlockPointerStruct(blockPtr, rewriter);

// Load the operand.
int64_t numRepOuter = numReps[opIdx];
int64_t numRepK = numReps[(opIdx == 0) ? 1 : 0];
    // Dense strides for the replicas.
unsigned repOuterStride = elemsPerInstr[opIdx];
unsigned warpOuterStride = elemsPerInstr[opIdx] * numRepOuter;
unsigned repKStride = elemsPerInstr[opIdx == 0 ? 1 : 0];

SmallVector<Value> rets;
for (int outer = 0; outer < numRepOuter; ++outer) {
for (int k = 0; k < numRepK; ++k) {
ValueTable loadVals;
for (int outer = 0; outer < numRepOuter;
outer += packedOuterDimPerLoad) {
for (int k = 0; k < numRepK; k += packedKDimPerLoad) {
Value offsetX, offsetY;
if (opIdx == 0) {
// A
offsetY = add(
mul(outerDimWarpId, i32_val(elemsPerInstr[opIdx])),
i32_val(outer * outerDimWarpNum * elemsPerInstr[opIdx]));
offsetX = i32_val(k * elemsPerInstr[1]);
offsetY = add(mul(outerDimWarpId, i32_val(warpOuterStride)),
i32_val(outer * repOuterStride));
offsetX = i32_val(k * repKStride);
} else {
// B
offsetX = add(
mul(outerDimWarpId, i32_val(elemsPerInstr[opIdx])),
i32_val(outer * outerDimWarpNum * elemsPerInstr[opIdx]));
offsetY = i32_val(k * elemsPerInstr[0]);
offsetX = add(mul(outerDimWarpId, i32_val(warpOuterStride)),
i32_val(outer * repOuterStride));
offsetY = i32_val(k * repKStride);
}
offsetX = add(offsetX, offsetBaseX);
offsetY = add(offsetY, offsetBaseY);
Expand All @@ -616,32 +662,64 @@ struct Load2DOpConversion
elemsPerInstr[1]),
/*tile_height*/
mlir::IntegerAttr::get(mlir::IntegerType::get(ctx, 32),
elemsPerInstr[0]),
tileHeight),
/*v_blocks*/
mlir::IntegerAttr::get(mlir::IntegerType::get(ctx, 32), 1),
mlir::IntegerAttr::get(mlir::IntegerType::get(ctx, 32),
vBlocks),
/*transpose*/
mlir::IntegerAttr::get(mlir::IntegerType::get(ctx, 1), 0),
/*vnni_transform*/
mlir::IntegerAttr::get(mlir::IntegerType::get(ctx, 1),
opIdx == 0 ? /*A vnni=false*/ 0
: /*B vnni=true*/ 1));
Value loadVal = bitcast(load2dOp, unpackType);
rets.push_back(loadVal);

unsigned packedRowNum =
opIdx == 0 ? packedOuterDimPerLoad : packedKDimPerLoad;
unsigned packedColNum =
opIdx == 0 ? packedKDimPerLoad : packedOuterDimPerLoad;
unsigned offset = 0;
// The packed load is contiguous on the row.
for (int col = 0; col < packedColNum; col++) {
for (int row = 0; row < packedRowNum; row++) {

Value loadVal = undef(decomposedType);
for (int elemIdx = 0;
elemIdx < opaqueElemPerLane / packedOuterDimPerLoad /
packedKDimPerLoad;
elemIdx++) {
Value loaded = extract_element(load2dOp, i32_val(offset++));
loadVal = insert_element(loadVal, loaded, i32_val(elemIdx));
}

            // Save the unpacked values to the map.
if (opIdx == 0) {
loadVals[{outer + row, k + col}] =
bitcast(loadVal, unpackType);
} else {
loadVals[{outer + col, k + row}] =
bitcast(loadVal, unpackType);
}
}
}
}
}

SmallVector<Value> loadedVals;
for (auto &ret : rets) {
VectorType loadTy = unpackType.cast<VectorType>();
for (size_t i = 0; i < loadTy.getNumElements(); ++i) {
Value loaded = extract_element(ret, i32_val(i));
loadedVals.push_back(loaded);
SmallVector<Value> unpackedLoadedVals;
for (int outer = 0; outer < numRepOuter; ++outer) {
for (int k = 0; k < numRepK; ++k) {
Value loadVal = loadVals.at({outer, k});
VectorType loadTy = loadVal.getType().cast<VectorType>();
for (int i = 0; i < loadTy.getNumElements(); ++i) {
auto val = extract_element(loadVal, i32_val(i));
unpackedLoadedVals.push_back(val);
}
}
}

Type llvmResultStructTy = typeConverter->convertType(op.getType());
Value resultStruct = packLLElements(loc, typeConverter, loadedVals,
rewriter, llvmResultStructTy);
Value resultStruct =
packLLElements(loc, typeConverter, unpackedLoadedVals, rewriter,
llvmResultStructTy);
rewriter.replaceOp(op, {resultStruct});

return success();
