From ffc18faffefdb1cde7c6427a3f99ccc623b73dc6 Mon Sep 17 00:00:00 2001 From: "Lu, Chengjun" Date: Mon, 22 Apr 2024 14:24:13 +0000 Subject: [PATCH] Add the nested layout with dot op layout. Such as `#triton_gpu.slice<{dim = 1, parent = #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>}>` --- .../Dialect/TritonGPU/IR/TritonGPUAttrDefs.td | 17 ++ .../IR/TritonIntelGPUAttrDefs.td | 1 + lib/Dialect/TritonGPU/IR/Dialect.cpp | 3 + lib/Dialect/TritonIntelGPU/IR/Dialect.cpp | 9 + .../intel/lib/TritonIntelGPUToLLVM/Utility.h | 169 ++++++++++++++++++ 5 files changed, 199 insertions(+) diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td index 47513d3c0f..863b0db106 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -773,6 +773,11 @@ def MmaEncodingTrait : AttrInterface<"MmaEncodingTrait"> { InterfaceMethod<"Return size per thread for dot operands.", "SmallVector", "getSizePerThreadForOperands", (ins "unsigned":$opIdx)>, + + InterfaceMethod<"Return element sizes per thread for dot operands.", "SmallVector", + "getElemsPerThreadForOperands", (ins "ArrayRef":$tensorShape, + "Type":$eltTy, + "unsigned":$opIdx)>, ]; } @@ -904,6 +909,10 @@ V [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129, return contigPerThread; }; + SmallVector getElemsPerThreadForOperands(ArrayRef shape, Type eltTy, unsigned opIdx) const { + llvm_unreachable("getElemsPerThreadForOperands is not supported."); + }; + }]; let hasCustomAssemblyFormat = 1; @@ -968,6 +977,10 @@ Suppose we have a tensor with shape [32, 48], `warpsPerCTA` set to [2, 3]. SmallVector contigPerThread(rank, 1); return contigPerThread; }; + + SmallVector getElemsPerThreadForOperands(ArrayRef shape, Type eltTy, unsigned opIdx) const { + llvm_unreachable("getElemsPerThreadForOperands is not supported."); + }; }]; } @@ -1169,6 +1182,10 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is: return contigPerThread; }; + SmallVector getElemsPerThreadForOperands(ArrayRef shape, Type eltTy, unsigned opIdx) const { + llvm_unreachable("getElemsPerThreadForOperands is not supported."); + }; + }]; let hasCustomAssemblyFormat = 1; diff --git a/include/triton/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td b/include/triton/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td index 81592ffcc8..53356701e0 100644 --- a/include/triton/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td +++ b/include/triton/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td @@ -79,6 +79,7 @@ along the row (resp. col) dimension. SmallVector getShapeC() const; SmallVector getDPASRepetitions(ArrayRef shape, int opIdx) const; SmallVector getSizePerThreadForOperands(unsigned opIdx) const; + SmallVector getElemsPerThreadForOperands(ArrayRef shape, Type eltTy, unsigned opIdx) const; SmallVector getShapePerCTATileForDotOperands(ArrayRef shape, int opIdx) const; unsigned getTotalElemsPerThreadForOperands(ArrayRef shape, Type eltTy, int kWidth, int opIdx) const; diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index eb27b7bfd5..6cb610c206 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -908,6 +908,9 @@ unsigned SharedEncodingAttr::getTotalElemsPerThread(ArrayRef shape, SmallVector DotOperandEncodingAttr::getElemsPerThread(ArrayRef shape, Type eltTy) const { + if (auto mmaParent = getParent().dyn_cast()) { + return mmaParent.getElemsPerThreadForOperands(shape, eltTy, getOpIdx()); + } llvm_unreachable("getElemsPerThread is not supported for dot operand"); return SmallVector(); } diff --git a/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp b/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp index 6ce5c79ce6..6d62fdb339 100644 --- a/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp @@ -266,6 +266,15 @@ DpasEncodingAttr::getSizePerThreadForOperands(unsigned opIdx) const { } } +SmallVector DpasEncodingAttr::getElemsPerThreadForOperands( + ArrayRef shape, Type eltTy, unsigned opIdx) const { + SmallVector sizePerThread = getSizePerThreadForOperands(opIdx); + SmallVector repetitions = getDPASRepetitions(shape, opIdx); + + return SmallVector{(unsigned)(sizePerThread[0] * repetitions[0]), + (unsigned)(sizePerThread[1] * repetitions[1])}; +}; + SmallVector DpasEncodingAttr::getContigPerThread() { unsigned threadsPerWarp = getSubGroupSize(); auto shapeC = getShapeC(); diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/Utility.h b/third_party/intel/lib/TritonIntelGPUToLLVM/Utility.h index e8f431c502..997f590041 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/Utility.h +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/Utility.h @@ -178,6 +178,87 @@ emitOffsetForDpasLayoutPerCTA(const DpasEncodingAttr &dpasLayout, } } +static SmallVector> +emitOffsetForDotOpLayout(const DotOperandEncodingAttr &dotLayout, + RankedTensorType type) { + + if (auto dpasLayout = dotLayout.getParent().dyn_cast()) { + ArrayRef shape = type.getShape(); + SmallVector> offsets; + auto shapePerCTA = triton::gpu::getShapePerCTA(type); + + auto opIdx = dotLayout.getOpIdx(); + SmallVector numReps = + dpasLayout.getDPASRepetitions(shapePerCTA, opIdx); + + SmallVector warpShape; + if (opIdx == 0) { + warpShape = dpasLayout.getShapeA(); + } else { + warpShape = dpasLayout.getShapeB(); + } + + unsigned warpSize = triton::gpu::getWarpSize(dpasLayout); + unsigned numElemPerInstPerThread = product(warpShape) / warpSize; + + unsigned systolicDepth = dpasLayout.getSystolicDepth(); + unsigned repeatCount = dpasLayout.getRepeatCount(); + unsigned executionSize = dpasLayout.getExecutionSize(); + unsigned opsPerChannel = dpasLayout.getOpsPerChannel(); + + unsigned rowsPerWarp, numElemPerInstPerRowPerThread; + switch (opIdx) { + case 0: { + assert((opsPerChannel == 4 || opsPerChannel == 2 || opsPerChannel == 1) && + "invalid opsPerChannel number."); + SmallVector shapeA = dpasLayout.getShapeA(); + // Unlike the operand B, to pack the value to i16 for scalar bit width + // <=16. + unsigned packedOpsPerLane = opsPerChannel == 4 ? 2 : 1; + unsigned packedColNum = shapeA[1] / packedOpsPerLane; + if (warpSize < packedColNum) { + llvm::report_fatal_error( + "DpasEncodingAttr sub-group size could not " + "be smaller than the threads required per row for A operand."); + } + rowsPerWarp = warpSize / packedColNum; + numElemPerInstPerRowPerThread = packedOpsPerLane; + } break; + case 1: { + if (warpSize < executionSize) { + llvm::report_fatal_error( + "DpasEncodingAttr sub-group size could not " + "be smaller than the execution size for B operand."); + } + rowsPerWarp = warpSize / executionSize; + rowsPerWarp = rowsPerWarp * opsPerChannel; + numElemPerInstPerRowPerThread = 1; + } break; + } + + auto shapePerCTATile = triton::gpu::getShapePerCTATile(dotLayout); + int64_t numRepOuter = numReps[opIdx]; + int64_t numRepK = numReps[(opIdx == 0) ? 1 : 0]; + for (int dimOuter = 0; dimOuter < numRepOuter; ++dimOuter) + for (int k = 0; k < numRepK; ++k) + for (unsigned elemId = 0; elemId < numElemPerInstPerThread; ++elemId) { + uint32_t repRowIndex = + shapePerCTATile[0] * (opIdx == 0 ? dimOuter : k); + uint32_t repColIndex = + shapePerCTATile[1] * (opIdx == 0 ? k : dimOuter); + uint32_t elemRowIndex = + (elemId / numElemPerInstPerRowPerThread) * rowsPerWarp; + uint32_t elemColIndex = elemId % numElemPerInstPerRowPerThread; + offsets.push_back( + {repRowIndex + elemRowIndex, repColIndex + elemColIndex}); + } + + return offsets; + } else { + llvm_unreachable("unsupported parent layout in emitOffsetForDotOpLayout"); + } +} + static SmallVector> emitOffsetForDpasLayout(const DpasEncodingAttr &dpasLayout, RankedTensorType type) { @@ -197,6 +278,89 @@ emitOffsetForDpasLayout(const DpasEncodingAttr &dpasLayout, // ----------------------------------------------------------------------- // Dpas layout indices // ----------------------------------------------------------------------- +static SmallVector +emitBaseIndexForDotOpLayout(Location loc, RewriterBase &rewriter, + const DotOperandEncodingAttr &dotLayout, + RankedTensorType type) { + + if (auto dpasLayout = dotLayout.getParent().dyn_cast()) { + Value threadId = getThreadId(rewriter, loc); + unsigned warpSize = triton::gpu::getWarpSize(dpasLayout); + Value warpId = udiv(threadId, i32_val(warpSize)); + Value laneId = urem(threadId, i32_val(warpSize)); + + const SmallVector warpsPerCTA = dpasLayout.getWarpsPerCTA(); + SmallVector order = triton::gpu::getOrder(dpasLayout); + auto shapePerCTA = triton::gpu::getShapePerCTA(type); + + SmallVector warpShape; + auto opIdx = dotLayout.getOpIdx(); + if (opIdx == 0) { + warpShape = dpasLayout.getShapeA(); + } else { + warpShape = dpasLayout.getShapeB(); + } + SmallVector numReps = + dpasLayout.getDPASRepetitions(shapePerCTA, opIdx); + + SmallVector multiDimWarpId = + mlir::LLVM::delinearize(rewriter, loc, warpId, warpsPerCTA, order); + + Value rowWarpId = + urem(multiDimWarpId[0], + i32_val(mlir::ceil(shapePerCTA[0], warpShape[0]))); + Value colWarpId = + urem(multiDimWarpId[1], + i32_val(mlir::ceil(shapePerCTA[1], warpShape[1]))); + Value rowWarpOffset = mul(rowWarpId, i32_val(warpShape[0])); + Value colWarpOffset = mul(colWarpId, i32_val(warpShape[1])); + + // Compute the 2-dim coordinates of the first element in the warp operated + // own by this thread. + unsigned systolicDepth = dpasLayout.getSystolicDepth(); + unsigned repeatCount = dpasLayout.getRepeatCount(); + unsigned executionSize = dpasLayout.getExecutionSize(); + unsigned opsPerChannel = dpasLayout.getOpsPerChannel(); + + Value laneRowIndex, laneColIndex; + switch (opIdx) { + case 0: { + assert((opsPerChannel == 4 || opsPerChannel == 2 || opsPerChannel == 1) && + "invalid opsPerChannel number."); + SmallVector shapeA = dpasLayout.getShapeA(); + // Unlike the operand B, to pack the value to i16 for scalar bit width + // <=16. + unsigned packedOpsPerLane = opsPerChannel == 4 ? 2 : 1; + unsigned packedColNum = shapeA[1] / packedOpsPerLane; + if (warpSize < packedColNum) { + llvm::report_fatal_error( + "DpasEncodingAttr sub-group size could not " + "be smaller than the threads required per row for A operand."); + } + laneRowIndex = udiv(laneId, i32_val(packedColNum)); + laneColIndex = urem(laneId, i32_val(packedColNum)); + laneColIndex = mul(laneColIndex, i32_val(packedOpsPerLane)); + } break; + case 1: { + if (warpSize < executionSize) { + llvm::report_fatal_error( + "DpasEncodingAttr sub-group size could not " + "be smaller than the execution size for B operand."); + } + laneRowIndex = udiv(laneId, i32_val(executionSize)); + laneRowIndex = mul(laneRowIndex, i32_val(opsPerChannel)); + laneColIndex = urem(laneId, i32_val(executionSize)); + } break; + } + + SmallVector multiDimBase = {add(laneRowIndex, rowWarpOffset), + add(laneColIndex, colWarpOffset)}; + return multiDimBase; + } else { + llvm_unreachable( + "unsupported parent layout in emitBaseIndexForDotOpLayout"); + } +} static SmallVector emitBaseIndexForDpasLayout(Location loc, RewriterBase &rewriter, @@ -330,6 +494,8 @@ emitBaseIndexForLayout(Location loc, RewriterBase &rewriter, Attribute layout, result.erase(result.begin() + sliceLayout.getDim()); // CTAOffset has been added in emitBaseIndexForLayout of parentLayout return result; + } else if (auto dotLayout = layout.dyn_cast()) { + result = emitBaseIndexForDotOpLayout(loc, rewriter, dotLayout, type); } else { llvm_unreachable("unsupported emitBaseIndexForLayout"); } @@ -348,6 +514,9 @@ emitOffsetForLayout(Attribute layout, RankedTensorType type) { if (auto dpasLayout = layout.dyn_cast()) { return emitOffsetForDpasLayout(dpasLayout, type); } + if (auto dotLayout = layout.dyn_cast()) { + return emitOffsetForDotOpLayout(dotLayout, type); + } if (auto sliceLayout = layout.dyn_cast()) return ::intel::emitOffsetForSliceLayout(sliceLayout, type); return ::emitOffsetForLayout(layout, type);