From ec24ea53661cd8bd7da17df8200b1befd17f9aef Mon Sep 17 00:00:00 2001
From: "Lu, Chengjun"
Date: Mon, 11 Mar 2024 10:00:49 +0000
Subject: [PATCH] Change the 2D load to use dense strides and a larger load
 size to pack the A and B operands. The convert layout and the emit indices
 also need to be changed for the dot operand layouts with DPAS as the parent.
---
 .../TritonIntelGPU/intel-2d-load-to-llvm.mlir |  10 +-
 .../LoadStoreOpToLLVM.cpp                     | 168 +++++++++++++-----
 2 files changed, 128 insertions(+), 50 deletions(-)

diff --git a/test/TritonIntelGPU/intel-2d-load-to-llvm.mlir b/test/TritonIntelGPU/intel-2d-load-to-llvm.mlir
index f28aea7766..d3ae937ebc 100644
--- a/test/TritonIntelGPU/intel-2d-load-to-llvm.mlir
+++ b/test/TritonIntelGPU/intel-2d-load-to-llvm.mlir
@@ -1,8 +1,8 @@
-// RUN: triton-opt %s -split-input-file --intel-allocate-shared-memory --convert-triton-intel-gpu-to-llvm | FileCheck %s
+// RUN: triton-opt %s -split-input-file --intel-allocate-shared-memory --convert-triton-intel-gpu-to-llvm -canonicalize | FileCheck %s
 
-// CHECK: llvm.func spir_funccc @llvm.genx.GenISA.LSC2DBlockRead.v8i32(i64, i32, i32, i32, i32, i32, i32, i32, i32, i32, i1, i1, i32) -> vector<8xi32>
-// CHECK: llvm.func spir_funccc @llvm.genx.GenISA.LSC2DBlockRead.v8i16(i64, i32, i32, i32, i32, i32, i32, i32, i32, i32, i1, i1, i32) -> vector<8xi16>
+// CHECK: llvm.func spir_funccc @llvm.genx.GenISA.LSC2DBlockRead.v32i32(i64, i32, i32, i32, i32, i32, i32, i32, i32, i32, i1, i1, i32) -> vector<32xi32>
+// CHECK: llvm.func spir_funccc @llvm.genx.GenISA.LSC2DBlockRead.v32i16(i64, i32, i32, i32, i32, i32, i32, i32, i32, i32, i1, i1, i32) -> vector<32xi16>
 
 #blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 4], order = [1, 0]}>
 #mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [4, 2], A = [8, 16], B = [16, 16], C = [8, 16]}>
 module attributes {"triton_gpu.compute-capability" = 2 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
@@ -23,9 +23,9 @@ module attributes {"triton_gpu.compute-capability" = 2 : i32, "triton_gpu.num-ct
     %6 = tt.make_tensor_ptr %arg1, [%1, %4], [%5, %c1_i64], [%c0_i32, %c0_i32] {order = array} : >>
     %7 = tt.advance %3, [%c64_i32, %c-32_i32] : >>
     %8 = tt.advance %7, [%c-64_i32, %c32_i32] : >>
-    // CHECK-COUNT-2: llvm.call @llvm.genx.GenISA.LSC2DBlockRead.v8i16({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}) : (i64, i32, i32, i32, i32, i32, i32, i32, i32, i32, i1, i1, i32) -> vector<8xi16>
+    // CHECK: llvm.call @llvm.genx.GenISA.LSC2DBlockRead.v32i16({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}) : (i64, i32, i32, i32, i32, i32, i32, i32, i32, i32, i1, i1, i32) -> vector<32xi16>
    %9 = triton_intel_gpu.load_2d %8 {cache = 1 : i32, evict = 1 : i32, isVolatile = false, padding = 1 : i32} : !tt.ptr>> -> tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>>
-    // CHECK-COUNT-2: llvm.call @llvm.genx.GenISA.LSC2DBlockRead.v8i32({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}) : (i64, i32, i32, i32, i32, i32, i32, i32, i32, i32, i1, i1, i32) -> vector<8xi32>
+    // CHECK: llvm.call @llvm.genx.GenISA.LSC2DBlockRead.v32i32({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}) : (i64, i32, i32, i32, i32, i32, i32, i32, i32, i32, i1, i1, i32) -> vector<32xi32>
    %10 = triton_intel_gpu.load_2d %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false, padding = 1 : i32} : !tt.ptr>> -> tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma}>>
    %11 = tt.dot %9, %10, %cst, inputPrecision = tf32 : tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma}>> -> tensor<64x64xf32, #mma>
    %12 = triton_gpu.convert_layout %11 : tensor<64x64xf32, #mma> -> tensor<64x64xf32, #blocked>
diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp
index 4e7dc8a67f..5a073a1698 100644
--- a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp
+++ b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp
@@ -470,6 +470,8 @@ struct Load2DOpConversion
   using ConvertTritonGPUOpToLLVMPattern<
       triton::gpu::intel::Load2DOp>::ConvertTritonGPUOpToLLVMPattern;
 
+  using ValueTable = std::map<std::pair<unsigned, unsigned>, Value>;
+
   Load2DOpConversion(TritonGPUToLLVMTypeConverter &converter,
                      PatternBenefit benefit)
       : ConvertTritonGPUOpToLLVMPattern(converter,
@@ -523,39 +525,83 @@ struct Load2DOpConversion
     SmallVector<Value> multiDimWarpId =
         delinearize(rewriter, loc, warpId, warpsPerCTA, order);
 
-    Type load2DGenXType;
-    Type unpackType;
-    int64_t elemsPerLane;
+    int64_t numRepOuter = numReps[opIdx];
+    int64_t numRepK = numReps[(opIdx == 0) ? 1 : 0];
+    int64_t opaqueElemPerLane;
+    unsigned tileHeight;
+    unsigned elemsPerLanePerDotOp;
+    unsigned vBlocks = 1;
+    unsigned packedOuterDimPerLoad = 1;
+    unsigned packedKDimPerLoad = 1;
     SmallVector<int64_t> elemsPerInstr;
+    Type packedElemType =
+        opIdx == 0 ? type::i16Ty(ctx) : type::i32Ty(ctx);
     if (opIdx == 0) {
       auto shapeA = dpasLayout.getShapeA();
       elemsPerInstr = {shapeA[0], shapeA[1]};
-      elemsPerLane = product(elemsPerInstr) /
-                     product(getThreadsPerWarp(dpasLayout));
-      unpackType = LLVM::getFixedVectorType(
-          typeConverter->convertType(eltTy), elemsPerLane);
+      elemsPerLanePerDotOp =
+          product(elemsPerInstr) /
+          product(getThreadsPerWarp(dpasLayout));
+
+      unsigned maxPackedOuterDimPerLoad = 32 / elemsPerInstr[0];
+      packedOuterDimPerLoad =
+          std::min<unsigned>(maxPackedOuterDimPerLoad, numRepOuter);
+      // Use the tile height to load multiple replicas of operand A at a time.
+      tileHeight = elemsPerInstr[0] * packedOuterDimPerLoad;
+
+      if (numRepK >= 2) {
+        // Double the block array length to 2 to load operand A.
+        vBlocks = 2;
+        packedKDimPerLoad *= 2;
+      } else {
+        vBlocks = 1;
+      }
 
       // pack scalar to i16.
       auto opsPerChannel = dpasLayout.getOpsPerChannel();
-      elemsPerLane = opsPerChannel == 4 ? elemsPerLane / 2 : elemsPerLane;
-      load2DGenXType =
-          LLVM::getFixedVectorType(type::i16Ty(ctx), elemsPerLane);
-
+      opaqueElemPerLane = opsPerChannel == 4 ? elemsPerLanePerDotOp / 2
+                                             : elemsPerLanePerDotOp;
+      opaqueElemPerLane =
+          opaqueElemPerLane * packedOuterDimPerLoad * packedKDimPerLoad;
     } else {
       auto shapeB = dpasLayout.getShapeB();
       elemsPerInstr = {shapeB[0], shapeB[1]};
-      elemsPerLane = product(elemsPerInstr) /
-                     product(getThreadsPerWarp(dpasLayout));
-      unpackType = LLVM::getFixedVectorType(
-          typeConverter->convertType(eltTy), elemsPerLane);
+      elemsPerLanePerDotOp =
+          product(elemsPerInstr) /
+          product(getThreadsPerWarp(dpasLayout));
+
+      if (numRepOuter >= 2) {
+        // Double the block array length to 2 to load operand B.
+        vBlocks = 2;
+        packedOuterDimPerLoad *= 2;
+      } else {
+        vBlocks = 1;
+      }
 
-      // pack scalar to i32 for load.
+      if (numRepK >= 2) {
+        // Double the tile height to load operand B.
+        tileHeight = elemsPerInstr[0] * 2;
+        packedKDimPerLoad *= 2;
+      } else {
+        tileHeight = elemsPerInstr[0];
+      }
+
+      // pack scalar to i32.
       auto opsPerChannel = dpasLayout.getOpsPerChannel();
-      elemsPerLane = elemsPerLane / opsPerChannel;
-      load2DGenXType =
-          LLVM::getFixedVectorType(type::i32Ty(ctx), elemsPerLane);
+      opaqueElemPerLane = (elemsPerLanePerDotOp / opsPerChannel);
+      opaqueElemPerLane =
+          opaqueElemPerLane * packedOuterDimPerLoad * packedKDimPerLoad;
     }
 
+    Type load2DGenXType =
+        LLVM::getFixedVectorType(packedElemType, opaqueElemPerLane);
+    Type decomposedType = LLVM::getFixedVectorType(
+        packedElemType,
+        opaqueElemPerLane / packedOuterDimPerLoad / packedKDimPerLoad);
+    Type unpackType = LLVM::getFixedVectorType(
+        typeConverter->convertType(eltTy), elemsPerLanePerDotOp);
+
+    // Load the operand.
     // Outer dim, A is the M, B is the N. Inner dim, the K
     int outerDimWarpNum =
         std::min(warpsPerCTA[opIdx],
@@ -570,26 +616,26 @@ struct Load2DOpConversion
              colStride, base) =
         getValuesFromBlockPointerStruct(blockPtr, rewriter);
 
-    // Load the operand.
-    int64_t numRepOuter = numReps[opIdx];
-    int64_t numRepK = numReps[(opIdx == 0) ? 1 : 0];
+    // Dense strides for the replicas.
+    unsigned repOuterStride = elemsPerInstr[opIdx];
+    unsigned warpOuterStride = elemsPerInstr[opIdx] * numRepOuter;
+    unsigned repKStride = elemsPerInstr[opIdx == 0 ? 1 : 0];
 
-    SmallVector<Value> rets;
-    for (int outer = 0; outer < numRepOuter; ++outer) {
-      for (int k = 0; k < numRepK; ++k) {
+    ValueTable loadVals;
+    for (int outer = 0; outer < numRepOuter;
+         outer += packedOuterDimPerLoad) {
+      for (int k = 0; k < numRepK; k += packedKDimPerLoad) {
        Value offsetX, offsetY;
        if (opIdx == 0) {
          // A
-          offsetY = add(
-              mul(outerDimWarpId, i32_val(elemsPerInstr[opIdx])),
-              i32_val(outer * outerDimWarpNum * elemsPerInstr[opIdx]));
-          offsetX = i32_val(k * elemsPerInstr[1]);
+          offsetY = add(mul(outerDimWarpId, i32_val(warpOuterStride)),
+                        i32_val(outer * repOuterStride));
+          offsetX = i32_val(k * repKStride);
        } else {
          // B
-          offsetX = add(
-              mul(outerDimWarpId, i32_val(elemsPerInstr[opIdx])),
-              i32_val(outer * outerDimWarpNum * elemsPerInstr[opIdx]));
-          offsetY = i32_val(k * elemsPerInstr[0]);
+          offsetX = add(mul(outerDimWarpId, i32_val(warpOuterStride)),
+                        i32_val(outer * repOuterStride));
+          offsetY = i32_val(k * repKStride);
        }
        offsetX = add(offsetX, offsetBaseX);
        offsetY = add(offsetY, offsetBaseY);
@@ -616,32 +662,64 @@ struct Load2DOpConversion
                                    elemsPerInstr[1]),
            /*tile_height*/
            mlir::IntegerAttr::get(mlir::IntegerType::get(ctx, 32),
-                                   elemsPerInstr[0]),
+                                   tileHeight),
            /*v_blocks*/
-            mlir::IntegerAttr::get(mlir::IntegerType::get(ctx, 32), 1),
+            mlir::IntegerAttr::get(mlir::IntegerType::get(ctx, 32),
+                                   vBlocks),
            /*transpose*/
            mlir::IntegerAttr::get(mlir::IntegerType::get(ctx, 1), 0),
            /*vnni_transform*/
            mlir::IntegerAttr::get(mlir::IntegerType::get(ctx, 1),
                                   opIdx == 0 ? /*A vnni=false*/ 0
                                              : /*B vnni=true*/ 1));
-        Value loadVal = bitcast(load2dOp, unpackType);
-        rets.push_back(loadVal);
+
+        unsigned packedRowNum =
+            opIdx == 0 ? packedOuterDimPerLoad : packedKDimPerLoad;
+        unsigned packedColNum =
+            opIdx == 0 ? packedKDimPerLoad : packedOuterDimPerLoad;
+        unsigned offset = 0;
+        // The packed load is contiguous along the rows.
+        for (int col = 0; col < packedColNum; col++) {
+          for (int row = 0; row < packedRowNum; row++) {
+
+            Value loadVal = undef(decomposedType);
+            for (int elemIdx = 0;
+                 elemIdx < opaqueElemPerLane / packedOuterDimPerLoad /
+                               packedKDimPerLoad;
+                 elemIdx++) {
+              Value loaded = extract_element(load2dOp, i32_val(offset++));
+              loadVal = insert_element(loadVal, loaded, i32_val(elemIdx));
+            }
+
+            // Save the unpacked values to the map.
+            if (opIdx == 0) {
+              loadVals[{outer + row, k + col}] =
+                  bitcast(loadVal, unpackType);
+            } else {
+              loadVals[{outer + col, k + row}] =
+                  bitcast(loadVal, unpackType);
+            }
+          }
+        }
      }
    }
 
-    SmallVector<Value> loadedVals;
-    for (auto &ret : rets) {
-      VectorType loadTy = unpackType.cast<VectorType>();
-      for (size_t i = 0; i < loadTy.getNumElements(); ++i) {
-        Value loaded = extract_element(ret, i32_val(i));
-        loadedVals.push_back(loaded);
+    SmallVector<Value> unpackedLoadedVals;
+    for (int outer = 0; outer < numRepOuter; ++outer) {
+      for (int k = 0; k < numRepK; ++k) {
+        Value loadVal = loadVals.at({outer, k});
+        VectorType loadTy = loadVal.getType().cast<VectorType>();
+        for (int i = 0; i < loadTy.getNumElements(); ++i) {
+          auto val = extract_element(loadVal, i32_val(i));
+          unpackedLoadedVals.push_back(val);
+        }
      }
    }
 
     Type llvmResultStructTy = typeConverter->convertType(op.getType());
-    Value resultStruct = packLLElements(loc, typeConverter, loadedVals,
-                                        rewriter, llvmResultStructTy);
+    Value resultStruct =
+        packLLElements(loc, typeConverter, unpackedLoadedVals, rewriter,
+                       llvmResultStructTy);
    rewriter.replaceOp(op, {resultStruct});
    return success();
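
Note (not part of the patch): the core of the change is how the packed 2D block-read shape is derived from the DPAS operand tiles. The sketch below mirrors that logic in a standalone helper; `PackedLoadShape` and `getPackedLoadShape` are hypothetical names introduced only for illustration, while the field names follow the variables added in the patch.

// Illustrative sketch only, assuming the same rules as the patch above.
#include <algorithm>
#include <cstdint>

struct PackedLoadShape {
  unsigned tileHeight;            // rows fetched by one 2D block read
  unsigned vBlocks;               // block array length of the read
  unsigned packedOuterDimPerLoad; // M (A) or N (B) replicas packed per read
  unsigned packedKDimPerLoad;     // K replicas packed per read
};

// opIdx: 0 for A, 1 for B. instrRows is elemsPerInstr[0] of the operand;
// numRepOuter / numRepK are the per-warp replica counts along M/N and K.
inline PackedLoadShape getPackedLoadShape(unsigned opIdx, unsigned instrRows,
                                          int64_t numRepOuter,
                                          int64_t numRepK) {
  PackedLoadShape s{instrRows, 1, 1, 1};
  if (opIdx == 0) {
    // Operand A: grow the tile height up to 32 rows to stack M replicas,
    // and use a block array of 2 to fetch two K replicas side by side.
    unsigned maxOuter = 32 / instrRows;
    s.packedOuterDimPerLoad =
        std::min<unsigned>(maxOuter, static_cast<unsigned>(numRepOuter));
    s.tileHeight = instrRows * s.packedOuterDimPerLoad;
    if (numRepK >= 2) {
      s.vBlocks = 2;
      s.packedKDimPerLoad = 2;
    }
  } else {
    // Operand B: the block array packs two N replicas, while doubling the
    // tile height packs two K replicas.
    if (numRepOuter >= 2) {
      s.vBlocks = 2;
      s.packedOuterDimPerLoad = 2;
    }
    if (numRepK >= 2) {
      s.tileHeight = instrRows * 2;
      s.packedKDimPerLoad = 2;
    }
  }
  return s;
}

For the lit test above (A = 64x32 f16, B = 32x64 f16, DPAS tiles A = [8, 16] and B = [16, 16], warpsPerCTA = [4, 2]), both operands end up with numRepOuter = numRepK = 2, so each operand should be fetched by a single packed read carrying four DPAS tiles. That is consistent with the per-lane result type growing from vector<8xi16>/vector<8xi32> to vector<32xi16>/vector<32xi32> in the updated CHECK lines.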
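
A second illustrative note: after a packed read, the per-lane vector has to be scattered back into one value per DPAS replica, keyed by (outer, k), before packLLElements flattens them. The small program below uses hypothetical constants matching the lit test and prints the element ranges that the row/col loops in the patch assign to each replica of operand A (for B, the roles of the outer and K dimensions are swapped).

// Illustrative sketch only; constants are assumptions taken from the lit test.
#include <cstdio>

int main() {
  const unsigned elemsPerRep = 8;   // per-lane packed elements of one DPAS tile
  const unsigned packedRowNum = 2;  // replicas stacked along the tile height
  const unsigned packedColNum = 2;  // replicas stacked along the block array
  unsigned offset = 0;
  // The packed load is contiguous along the rows, so the row index is the
  // inner loop, exactly as in the conversion above.
  for (unsigned col = 0; col < packedColNum; ++col)
    for (unsigned row = 0; row < packedRowNum; ++row) {
      std::printf("operand A replica (outer=%u, k=%u) <- lane elements [%u, %u)\n",
                  row, col, offset, offset + elemsPerRep);
      offset += elemsPerRep;
    }
  return 0;
}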