diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td index 2ee39e28db..3912191f4f 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -525,6 +525,10 @@ For the Threads Per Warp and Values Per Thread level, the linear id distribution InterfaceMethod<"Gets the number of contiguous elements per thread.", "SmallVector", "getContigPerThread">, + InterfaceMethod<"Convert to LinearLayout.", + "std::optional", + "toLinearLayout", + (ins "ArrayRef":$shape)> ]; } @@ -576,6 +580,8 @@ L(T) = [ {0,8} , {1,9} , {2,10}, {3,11}, {0,8} , {1, 9} , {2, 10}, {3, 11}, SmallVector getSizePerThread() const; SmallVector getShapePerCTATile(ArrayRef tensorShape = ArrayRef()) const; + + std::optional toLinearLayout(ArrayRef shape) const; }]; } diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h b/include/triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h index 0ee2cfeca0..9cf2876d2c 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h @@ -1,6 +1,6 @@ #ifndef TRITON_GPU_DIALECT_INTERFACES_H #define TRITON_GPU_DIALECT_INTERFACES_H - +#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" #include "triton/Dialect/TritonGPU/IR/TritonGPUAttrInterfaces.h.inc" #endif // TRITON_GPU_DIALECT_INTERFACES_H diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp index 58eb031d2d..a65b9e64e2 100644 --- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp @@ -268,24 +268,6 @@ LinearLayout combineCtaCgaWithShape(LinearLayout ctaLayout, return ret; } -LinearLayout blockedToLinearLayout(ArrayRef shape, - BlockedEncodingAttr blocked) { - assert(shape.size() == blocked.getOrder().size()); - - int rank = shape.size(); - MLIRContext *ctx = blocked.getContext(); - SmallVector outDimNames = standardOutDimNames(ctx, rank); - - const auto &order = blocked.getOrder(); - LinearLayout ctaLayout = - identityND(S("register"), blocked.getSizePerThread(), order, - outDimNames) * - identityND(S("lane"), blocked.getThreadsPerWarp(), order, outDimNames) * - identityND(S("warp"), blocked.getWarpsPerCTA(), order, outDimNames); - - return combineCtaCgaWithShape(ctaLayout, blocked.getCTALayout(), shape); -} - LinearLayout ampereMmaToLinearLayout(ArrayRef shape, NvidiaMmaEncodingAttr mma) { int rank = shape.size(); @@ -350,25 +332,147 @@ LinearLayout hopperMmaToLinearLayout(ArrayRef shape, return combineCtaCgaWithShape(ctaLayout, mma.getCTALayout(), shape); } -LinearLayout mfmaToLinearLayout(ArrayRef shape, - AMDMfmaEncodingAttr mfma) { +LinearLayout sharedToLinearLayoutNoLeadingOffset(ArrayRef shape, + SharedEncodingAttr shared) { + assert(!shared.getHasLeadingOffset()); + + MLIRContext *ctx = shared.getContext(); + int rank = shape.size(); + if (rank == 1) { + return combineCtaCgaWithShape( + LinearLayout::identity1D(shape[0], S("offset"), S("dim0")), + shared.getCTALayout(), shape); + } + + auto outDimNames = standardOutDimNames(ctx, rank); + + // Construct bases for the 2 most minor dimensions of the layout. These are + // the dims that get swizzled. + assert(shape.size() >= 2); + int colDim = shared.getOrder()[0]; + int rowDim = shared.getOrder()[1]; + int numCols = shape[colDim]; + int numRows = shape[rowDim]; + StringAttr colDimName = outDimNames[colDim]; + StringAttr rowDimName = outDimNames[rowDim]; + + std::vector> bases2D; + for (int logCol = 0; logCol < llvm::Log2_32(numCols); logCol++) { + bases2D.push_back({0, 1 << logCol}); + } + for (int logRow = 0; logRow < llvm::Log2_32(numRows); logRow++) { + int row = 1 << logRow; + int vec = shared.getVec(); + int perPhase = shared.getPerPhase(); + int maxPhase = shared.getMaxPhase(); + bases2D.push_back({row, (vec * ((row / perPhase) % maxPhase)) % numCols}); + } + LinearLayout ctaLayout = + LinearLayout({{S("offset"), bases2D}}, {rowDimName, colDimName}); + + // Add the remaining dimensions. + for (int i = 2; i < rank; i++) { + int dim = shared.getOrder()[i]; + ctaLayout *= + LinearLayout::identity1D(shape[dim], S("offset"), outDimNames[dim]); + } + + return combineCtaCgaWithShape(ctaLayout, shared.getCTALayout(), shape); +} + +LinearLayout sharedToLinearLayoutLeadingOffset(ArrayRef shape, + SharedEncodingAttr shared, + int32_t elemBitWidth) { + assert(shared.getHasLeadingOffset()); + + MLIRContext *ctx = shared.getContext(); + int rank = shape.size(); + if (rank == 1) { + // TODO: Not sure if this is correct. + return combineCtaCgaWithShape( + LinearLayout::identity1D(shape[0], S("offset"), S("dim0")), + shared.getCTALayout(), shape); + } + + int tileWidthBytes; + if (shared.getPerPhase() == 4 && shared.getMaxPhase() == 2) { + tileWidthBytes = 32; + } else if (shared.getPerPhase() == 2 && shared.getMaxPhase() == 4) { + tileWidthBytes = 64; + } else if (shared.getPerPhase() == 1 && shared.getMaxPhase() == 8) { + tileWidthBytes = 128; + } else { + llvm::errs() + << "Illegal shared encoding. If hasLeadingOffset is true, " + "then (perPhase, maxPhase) must be either (4,2), (2,4), or (1,8): " + << shared << "\n"; + llvm_unreachable("Illegal shared encoding"); + } + + auto outDimNames = standardOutDimNames(ctx, rank); + + // Construct bases for a the layout's 2-dimensional tile. + assert(shape.size() >= 2); + int colDim = shared.getOrder()[0]; + int rowDim = shared.getOrder()[1]; + + int tileRows = 8; + int tileCols = 8 * tileWidthBytes / elemBitWidth; + + int vec = 8 * 16 / elemBitWidth; + if (vec != shared.getVec()) { + llvm::errs() << "Illegal shared layout; expected `vec` to be " << vec + << ": " << shared << "\n"; + llvm::report_fatal_error("Illegal shared layout"); + } + + StringAttr colDimName = outDimNames[colDim]; + StringAttr rowDimName = outDimNames[rowDim]; + + std::vector> bases2D; + for (int logCol = 0; logCol < llvm::Log2_32(tileCols); logCol++) { + bases2D.push_back({0, 1 << logCol}); + } + for (int logRow = 0; logRow < llvm::Log2_32(tileRows); logRow++) { + int row = 1 << logRow; + int perPhase = shared.getPerPhase(); + int maxPhase = shared.getMaxPhase(); + bases2D.push_back({row, vec * ((row / perPhase) % maxPhase)}); + } + LinearLayout tileLayout = + LinearLayout({{S("offset"), bases2D}}, {rowDimName, colDimName}); + + // Add the remaining dimensions. + for (int i = 2; i < rank; i++) { + int dim = shared.getOrder()[i]; + tileLayout *= + LinearLayout::identity1D(shape[dim], S("offset"), outDimNames[dim]); + } + + return combineCtaCgaWithShape(tileLayout, shared.getCTALayout(), shape); +} + +} // anonymous namespace + +std::optional +AMDMfmaEncodingAttr::toLinearLayout(ArrayRef shape) const { int rank = shape.size(); - assert(rank == mfma.getWarpsPerCTA().size()); + assert(rank == getWarpsPerCTA().size()); bool hasBatchDim = rank == 3; int mIndex = 0 + hasBatchDim; int nIndex = 1 + hasBatchDim; (void)mIndex, (void)nIndex; - assert(((shape[mIndex] == 1 || shape[mIndex] >= mfma.getMDim()) && - (shape[nIndex] == 1 || shape[nIndex] >= mfma.getNDim())) && + assert(((shape[mIndex] == 1 || shape[mIndex] >= getMDim()) && + (shape[nIndex] == 1 || shape[nIndex] >= getNDim())) && "Unsupported tensor shape for given mfma layout"); - assert(((mfma.getMDim() == 32 && mfma.getNDim() == 32) || - (mfma.getMDim() == 16 && mfma.getNDim() == 16)) && + assert(((getMDim() == 32 && getNDim() == 32) || + (getMDim() == 16 && getNDim() == 16)) && "Unsupported mfma type"); - MLIRContext *ctx = mfma.getContext(); + MLIRContext *ctx = getContext(); SmallVector outDimNames = standardOutDimNames(ctx, rank); StringAttr kRegister = S("register"); @@ -379,10 +483,10 @@ LinearLayout mfmaToLinearLayout(ArrayRef shape, // We use the order from fastest varying to slowest varying. So each base // vector is a tuple of values mapping to matrix C's (N, M[, B]) indices. - SmallVector order = triton::gpu::getOrder(mfma); + SmallVector order = triton::gpu::getOrder(*this); auto tileLayout = LinearLayout::empty(); - if (mfma.getMDim() == 32) { + if (getMDim() == 32) { // For mfma with 32x32 output, each of the 64 threads holds 16 elements. // // For the register (i.e., element) dimension, these 16 elements are along @@ -397,7 +501,7 @@ LinearLayout mfmaToLinearLayout(ArrayRef shape, {kLane, {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}, /*gap*/ {0, 4}}}}, {outDimNames[order[0]], outDimNames[order[1]]}); } else { - assert(mfma.getMDim() == 16); + assert(getMDim() == 16); // For mfma with 16x16 output, each of the 64 threads holds 4 elements. // // For the register (i.e., element) dimension, these 4 elements are along @@ -422,23 +526,23 @@ LinearLayout mfmaToLinearLayout(ArrayRef shape, // And each warp takes the same register and lane sub-layout. So mulitply with // an identity layout for the warp. LinearLayout warpLayout = - identityND(S("warp"), mfma.getWarpsPerCTA(), order, outDimNames); + identityND(S("warp"), getWarpsPerCTA(), order, outDimNames); LinearLayout ctaLayout = tileLayout * warpLayout; - return combineCtaCgaWithShape(ctaLayout, mfma.getCTALayout(), shape); + return combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape); } -LinearLayout wmmaToLinearLayout(ArrayRef shape, - AMDWmmaEncodingAttr wmma) { +std::optional +AMDWmmaEncodingAttr::toLinearLayout(ArrayRef shape) const { int rank = shape.size(); - assert(rank == wmma.getWarpsPerCTA().size()); + assert(rank == getWarpsPerCTA().size()); bool hasBatchDim = rank == 3; int mIndex = 0 + hasBatchDim; int nIndex = 1 + hasBatchDim; (void)mIndex, (void)nIndex; - SmallVector mnkDim = wmma.getMNKDimPerInstr(); + SmallVector mnkDim = getMNKDimPerInstr(); unsigned mDim = mnkDim[0], nDim = mnkDim[1]; (void)mDim, (void)nDim; @@ -446,7 +550,7 @@ LinearLayout wmmaToLinearLayout(ArrayRef shape, (shape[nIndex] == 1 || shape[nIndex] >= nDim)) && "Unsupported tensor shape for given wmma layout"); - MLIRContext *ctx = wmma.getContext(); + MLIRContext *ctx = getContext(); SmallVector outDimNames = standardOutDimNames(ctx, rank); StringAttr kRegister = S("register"); @@ -457,7 +561,7 @@ LinearLayout wmmaToLinearLayout(ArrayRef shape, // We use the order from fastest varying to slowest varying. So each base // vector is a tuple of values mapping to matrix C's (N, M[, B]) indices. - SmallVector order = triton::gpu::getOrder(wmma); + SmallVector order = triton::gpu::getOrder(*this); // For wmma with 16x16 output, each of the 32 threads holds 8 elements. // @@ -484,26 +588,54 @@ LinearLayout wmmaToLinearLayout(ArrayRef shape, // And each warp takes the same register and lane sub-layout. So mulitply with // an identity layout for the warp. LinearLayout warpLayout = - identityND(S("warp"), wmma.getWarpsPerCTA(), order, outDimNames); + identityND(S("warp"), getWarpsPerCTA(), order, outDimNames); LinearLayout ctaLayout = tileLayout * warpLayout; - return combineCtaCgaWithShape(ctaLayout, wmma.getCTALayout(), shape); + return combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape); +} + +std::optional +BlockedEncodingAttr::toLinearLayout(ArrayRef shape) const { + assert(shape.size() == getOrder().size()); + + int rank = shape.size(); + MLIRContext *ctx = getContext(); + SmallVector outDimNames = standardOutDimNames(ctx, rank); + + const auto &order = getOrder(); + LinearLayout ctaLayout = + identityND(S("register"), getSizePerThread(), order, outDimNames) * + identityND(S("lane"), getThreadsPerWarp(), order, outDimNames) * + identityND(S("warp"), getWarpsPerCTA(), order, outDimNames); + + return combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape); +} + +std::optional +NvidiaMmaEncodingAttr::toLinearLayout(ArrayRef shape) const { + if (isAmpere()) { + return ampereMmaToLinearLayout(shape, *this); + } + if (isHopper()) { + return hopperMmaToLinearLayout(shape, *this); + } + return std::nullopt; } -LinearLayout sliceToLinearLayout(ArrayRef shape, - SliceEncodingAttr slice) { - MLIRContext *ctx = slice.getContext(); +std::optional +SliceEncodingAttr::toLinearLayout(ArrayRef shape) const { + MLIRContext *ctx = getContext(); // First compute the linear layout for this layout's parent. SmallVector parentShape(shape); - parentShape.insert(parentShape.begin() + slice.getDim(), 1); + parentShape.insert(parentShape.begin() + getDim(), 1); std::optional parentLL = - triton::gpu::toLinearLayout(parentShape, slice.getParent()); + triton::gpu::toLinearLayout(parentShape, getParent()); if (!parentLL.has_value()) llvm::report_fatal_error( "Failed to compute parent layout for slice layout."); - // Remove dimension slice.getDim() from the parent layout. + // Remove dimension getDim() from the parent layout. // // 1. Construct a layout `transform` from parent-out-dims to slice-out-dims // that removes the relevant out-dim. @@ -513,15 +645,15 @@ LinearLayout sliceToLinearLayout(ArrayRef shape, auto outDimNames = standardOutDimNames(ctx, shape.size() + 1); LinearLayout transform = LinearLayout::empty(); for (auto [idx, outDim] : llvm::enumerate(parentLL->getOutDimNames())) { - if (idx == slice.getDim()) { + if (idx == getDim()) { // Because we're multiplying by all zeros, we could replace outDimNames[0] // with any other valid out-dim; the layout will be the same. transform *= LinearLayout::zeros1D(parentLL->getOutDimSize(outDim), outDim, outDimNames[0]); } else { - transform *= LinearLayout::identity1D( - parentLL->getOutDimSize(outDim), outDim, - outDimNames[idx - (idx < slice.getDim() ? 0 : 1)]); + transform *= + LinearLayout::identity1D(parentLL->getOutDimSize(outDim), outDim, + outDimNames[idx - (idx < getDim() ? 0 : 1)]); } } LinearLayout sliceLL = parentLL->compose(transform); @@ -545,8 +677,9 @@ LinearLayout sliceToLinearLayout(ArrayRef shape, // // TODO(jlebar): Once getTotalElemsPerThread uses LLs instead of the existing // legacy code, I think we can remove this. - int expectedNumRegisters = getTotalElemsPerThread(RankedTensorType::get( - shape, IntegerType::get(ctx, 32) /*dummy type*/, slice)); + int expectedNumRegisters = + triton::gpu::getTotalElemsPerThread(RankedTensorType::get( + shape, IntegerType::get(ctx, 32) /*dummy type*/, *this)); if (ret.getInDimSize(S("register")) != expectedNumRegisters) { int extraZeros = expectedNumRegisters / ret.getInDimSize(S("register")); // Our use of "dim0" here is arbitrary; because we're adding zeros, any @@ -556,150 +689,17 @@ LinearLayout sliceToLinearLayout(ArrayRef shape, return ret; } -LinearLayout sharedToLinearLayoutNoLeadingOffset(ArrayRef shape, - SharedEncodingAttr shared) { - assert(!shared.getHasLeadingOffset()); - - MLIRContext *ctx = shared.getContext(); - int rank = shape.size(); - if (rank == 1) { - return combineCtaCgaWithShape( - LinearLayout::identity1D(shape[0], S("offset"), S("dim0")), - shared.getCTALayout(), shape); - } - - auto outDimNames = standardOutDimNames(ctx, rank); - - // Construct bases for the 2 most minor dimensions of the layout. These are - // the dims that get swizzled. - assert(shape.size() >= 2); - int colDim = shared.getOrder()[0]; - int rowDim = shared.getOrder()[1]; - int numCols = shape[colDim]; - int numRows = shape[rowDim]; - StringAttr colDimName = outDimNames[colDim]; - StringAttr rowDimName = outDimNames[rowDim]; - - std::vector> bases2D; - for (int logCol = 0; logCol < llvm::Log2_32(numCols); logCol++) { - bases2D.push_back({0, 1 << logCol}); - } - for (int logRow = 0; logRow < llvm::Log2_32(numRows); logRow++) { - int row = 1 << logRow; - int vec = shared.getVec(); - int perPhase = shared.getPerPhase(); - int maxPhase = shared.getMaxPhase(); - bases2D.push_back({row, (vec * ((row / perPhase) % maxPhase)) % numCols}); - } - LinearLayout ctaLayout = - LinearLayout({{S("offset"), bases2D}}, {rowDimName, colDimName}); - - // Add the remaining dimensions. - for (int i = 2; i < rank; i++) { - int dim = shared.getOrder()[i]; - ctaLayout *= - LinearLayout::identity1D(shape[dim], S("offset"), outDimNames[dim]); - } - - return combineCtaCgaWithShape(ctaLayout, shared.getCTALayout(), shape); -} - -LinearLayout sharedToLinearLayoutLeadingOffset(ArrayRef shape, - SharedEncodingAttr shared, - int32_t elemBitWidth) { - assert(shared.getHasLeadingOffset()); - - MLIRContext *ctx = shared.getContext(); - int rank = shape.size(); - if (rank == 1) { - // TODO: Not sure if this is correct. - return combineCtaCgaWithShape( - LinearLayout::identity1D(shape[0], S("offset"), S("dim0")), - shared.getCTALayout(), shape); - } - - int tileWidthBytes; - if (shared.getPerPhase() == 4 && shared.getMaxPhase() == 2) { - tileWidthBytes = 32; - } else if (shared.getPerPhase() == 2 && shared.getMaxPhase() == 4) { - tileWidthBytes = 64; - } else if (shared.getPerPhase() == 1 && shared.getMaxPhase() == 8) { - tileWidthBytes = 128; - } else { - llvm::errs() - << "Illegal shared encoding. If hasLeadingOffset is true, " - "then (perPhase, maxPhase) must be either (4,2), (2,4), or (1,8): " - << shared << "\n"; - llvm_unreachable("Illegal shared encoding"); - } - - auto outDimNames = standardOutDimNames(ctx, rank); - - // Construct bases for a the layout's 2-dimensional tile. - assert(shape.size() >= 2); - int colDim = shared.getOrder()[0]; - int rowDim = shared.getOrder()[1]; - - int tileRows = 8; - int tileCols = 8 * tileWidthBytes / elemBitWidth; - - int vec = 8 * 16 / elemBitWidth; - if (vec != shared.getVec()) { - llvm::errs() << "Illegal shared layout; expected `vec` to be " << vec - << ": " << shared << "\n"; - llvm::report_fatal_error("Illegal shared layout"); - } - - StringAttr colDimName = outDimNames[colDim]; - StringAttr rowDimName = outDimNames[rowDim]; - - std::vector> bases2D; - for (int logCol = 0; logCol < llvm::Log2_32(tileCols); logCol++) { - bases2D.push_back({0, 1 << logCol}); - } - for (int logRow = 0; logRow < llvm::Log2_32(tileRows); logRow++) { - int row = 1 << logRow; - int perPhase = shared.getPerPhase(); - int maxPhase = shared.getMaxPhase(); - bases2D.push_back({row, vec * ((row / perPhase) % maxPhase)}); - } - LinearLayout tileLayout = - LinearLayout({{S("offset"), bases2D}}, {rowDimName, colDimName}); - - // Add the remaining dimensions. - for (int i = 2; i < rank; i++) { - int dim = shared.getOrder()[i]; - tileLayout *= - LinearLayout::identity1D(shape[dim], S("offset"), outDimNames[dim]); - } - - return combineCtaCgaWithShape(tileLayout, shared.getCTALayout(), shape); +// TODO: DotOperandEncoding doesn't support LinearLayout conversion yet. +std::optional +DotOperandEncodingAttr::toLinearLayout(ArrayRef shape) const { + return std::nullopt; } -} // anonymous namespace - std::optional toLinearLayout(ArrayRef shape, Attribute layout, std::optional elemBitWidth /*= std::nullopt*/) { - if (auto blocked = dyn_cast(layout)) { - return blockedToLinearLayout(shape, blocked); - } - if (auto mfma = dyn_cast(layout)) { - return mfmaToLinearLayout(shape, mfma); - } - if (auto wmma = dyn_cast(layout)) { - return wmmaToLinearLayout(shape, wmma); - } - if (auto mma = dyn_cast(layout)) { - if (mma.isAmpere()) { - return ampereMmaToLinearLayout(shape, mma); - } - if (mma.isHopper()) { - return hopperMmaToLinearLayout(shape, mma); - } - } - if (auto slice = dyn_cast(layout)) { - return sliceToLinearLayout(shape, slice); + if (auto distributed = dyn_cast(layout)) { + return distributed.toLinearLayout(shape); } if (auto shared = dyn_cast(layout)) { if (shared.getHasLeadingOffset()) {