From 59539b0203aced8240f12a79074b747f7ea903a8 Mon Sep 17 00:00:00 2001
From: HyoungWook Nam
Date: Tue, 20 Aug 2024 01:28:27 +0000
Subject: [PATCH] LinearLayout conversion interface (#1875)

---
 .../TritonGPU/IR/LinearLayoutConversions.cpp  |   3 +-
 .../tritonintlgpu-nested-layout.mlir          |  58 +++-----
 .../Dialect/TritonIntelGPU/IR/Attributes.h    |   2 +
 .../IR/LinearLayoutConversions.h              |   5 +-
 .../lib/Dialect/TritonIntelGPU/IR/Dialect.cpp |   6 +
 .../IR/LinearLayoutConversions.cpp            |  49 +++----
 .../ConvertLayoutOpToLLVM.cpp                 |  15 +--
 .../lib/TritonIntelGPUToLLVM/Utility.cpp      | 125 ------------------
 .../intel/lib/TritonIntelGPUToLLVM/Utility.h  |  38 ------
 .../TritonGPU/DPAStoLinearLayoutTest.cpp      |  17 ++-
 10 files changed, 75 insertions(+), 243 deletions(-)

diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
index a65b9e64e2..9bd9ae199d 100644
--- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
+++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
@@ -632,8 +632,7 @@ SliceEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
   std::optional<LinearLayout> parentLL =
       triton::gpu::toLinearLayout(parentShape, getParent());
   if (!parentLL.has_value())
-    llvm::report_fatal_error(
-        "Failed to compute parent layout for slice layout.");
+    return std::nullopt;
 
   // Remove dimension getDim() from the parent layout.
   //
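Note: the hunk above demotes a hard `report_fatal_error` to an empty optional, so a slice whose parent encoding has no LinearLayout equivalent now reports "no conversion" and lets the caller fall back. A minimal caller-side sketch of that pattern (hypothetical helper names, not part of this patch):

  std::optional<LinearLayout> ll =
      triton::gpu::toLinearLayout(shape, tensorTy.getEncoding());
  if (!ll.has_value()) {
    // No LinearLayout equivalent; take the legacy lowering path instead of
    // aborting compilation.
    return lowerWithLegacyEmitter(op);   // hypothetical fallback
  }
  return lowerWithLinearLayout(op, *ll); // hypothetical fast path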
diff --git a/test/TritonIntelGPU/tritonintlgpu-nested-layout.mlir b/test/TritonIntelGPU/tritonintlgpu-nested-layout.mlir
index 2936383c9c..2bb504d76f 100644
--- a/test/TritonIntelGPU/tritonintlgpu-nested-layout.mlir
+++ b/test/TritonIntelGPU/tritonintlgpu-nested-layout.mlir
@@ -11,46 +11,32 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32
   // CHECK-DAG: %[[CST_4:.*]] = llvm.mlir.constant(4 : i32) : i32
   // CHECK-DAG: %[[CST_8:.*]] = llvm.mlir.constant(8 : i32) : i32
   // CHECK-DAG: %[[CST_16:.*]] = llvm.mlir.constant(16 : i32) : i32
-  // CHECK-DAG: %[[CST_32:.*]] = llvm.mlir.constant(32 : i32) : i32
   // CHECK-DAG: %[[CST_6:.*]] = llvm.mlir.constant(6 : i32) : i32
   // CHECK-DAG: %[[CST_18:.*]] = llvm.mlir.constant(18 : i32) : i32
   // CHECK-DAG: %[[CST_20:.*]] = llvm.mlir.constant(20 : i32) : i32
   // CHECK-DAG: %[[CST_22:.*]] = llvm.mlir.constant(22 : i32) : i32
-
-  // CHECK: %[[THREAD_ID:.*]] = llvm.call spir_funccc @_Z12get_local_idj(%[[CST_0]])
-  // CHECK: %[[THREAD_ID_32:.*]] = llvm.trunc %[[THREAD_ID]] : i64 to i32
-  // CHECK: %[[WARP_ID:.*]] = llvm.udiv %[[THREAD_ID_32]], %[[CST_16]] : i32
-  // CHECK: %[[LANE_ID:.*]] = llvm.urem %[[THREAD_ID_32]], %[[CST_16]] : i32
-  // CHECK: %[[WARP_ID_Y:.*]] = llvm.urem %[[WARP_ID]], %[[CST_2]] : i32
-  // CHECK: %[[VAL_23:.*]] = llvm.udiv %[[WARP_ID]], %[[CST_2]] : i32
-  // CHECK: %[[WARP_ID_X:.*]] = llvm.urem %[[VAL_23]], %[[CST_2]] : i32
-  // CHECK: %[[ROUNDED_WARP_ID_X:.*]] = llvm.urem %[[WARP_ID_X]], %[[CST_4]] : i32
-  // CHECK: %[[ROUNDED_WARP_ID_Y:.*]] = llvm.urem %[[WARP_ID_Y]], %[[CST_4]] : i32
-  // CHECK: %[[WARP_OFFSET_X:.*]] = llvm.mul %[[ROUNDED_WARP_ID_X]], %[[CST_8]] : i32
-  // CHECK: %[[WARP_OFFSET_Y:.*]] = llvm.mul %[[ROUNDED_WARP_ID_Y]], %[[CST_8]] : i32
-  // CHECK: %[[LANE_OFFSET_X:.*]] = llvm.udiv %[[LANE_ID]], %[[CST_8]] : i32
-  // CHECK: %[[OFFSET_X:.*]] = llvm.add %[[LANE_OFFSET_X]], %[[WARP_OFFSET_X]] : i32
-  // CHECK: %[[LANE_OFFSET_Y:.*]] = llvm.urem %[[LANE_ID]], %[[CST_8]] : i32
-  // CHECK: %[[OFFSET_Y:.*]] = llvm.add %[[LANE_OFFSET_Y]], %[[WARP_OFFSET_Y]] : i32
-  // CHECK: %[[VAL_33:.*]] = llvm.urem %[[CST_0]], %[[CST_1]] : i32
-  // CHECK: %[[VAL_34:.*]] = llvm.udiv %[[CST_0]], %[[CST_1]] : i32
-  // CHECK: %[[VAL_35:.*]] = llvm.urem %[[VAL_34]], %[[CST_1]] : i32
-  // CHECK: %[[VAL_36:.*]] = llvm.urem %[[VAL_35]], %[[CST_1]] : i32
-  // CHECK: %[[VAL_37:.*]] = llvm.urem %[[VAL_33]], %[[CST_1]] : i32
-  // CHECK: %[[CTA_OFFSET_Y:.*]] = llvm.mul %[[VAL_36]], %[[CST_32]] : i32
-  // CHECK: %[[CTA_OFFSET_X:.*]] = llvm.mul %[[VAL_37]], %[[CST_32]] : i32
-  // CHECK: %[[VAL_40:.*]] = llvm.add %[[OFFSET_X]], %[[CTA_OFFSET_Y]] : i32
-  // CHECK: %[[VAL_41:.*]] = llvm.add %[[OFFSET_Y]], %[[CTA_OFFSET_X]] : i32
-  // CHECK: %[[OFFSET_X_0:.*]] = llvm.add %[[VAL_40]], %[[CST_0]] : i32
-  // CHECK: %[[OFFSET_Y_0:.*]] = llvm.add %[[VAL_41]], %[[CST_0]] : i32
-  // CHECK: %[[OFFSET_X_1:.*]] = llvm.add %[[VAL_40]], %[[CST_2]] : i32
-  // CHECK: %[[OFFSET_X_2:.*]] = llvm.add %[[VAL_40]], %[[CST_4]] : i32
-  // CHECK: %[[OFFSET_X_3:.*]] = llvm.add %[[VAL_40]], %[[CST_6]] : i32
-  // CHECK: %[[OFFSET_Y_1:.*]] = llvm.add %[[VAL_41]], %[[CST_16]] : i32
-  // CHECK: %[[OFFSET_X_4:.*]] = llvm.add %[[VAL_40]], %[[CST_16]] : i32
-  // CHECK: %[[OFFSET_X_5:.*]] = llvm.add %[[VAL_40]], %[[CST_18]] : i32
-  // CHECK: %[[OFFSET_X_6:.*]] = llvm.add %[[VAL_40]], %[[CST_20]] : i32
-  // CHECK: %[[OFFSET_X_7:.*]] = llvm.add %[[VAL_40]], %[[CST_22]] : i32
+  // CHECK-DAG: %[[THREAD_ID:.*]] = llvm.call spir_funccc @_Z12get_local_idj(%[[CST_0]])
+  // CHECK-DAG: %[[THREAD_ID_32:.*]] = llvm.trunc %[[THREAD_ID]] : i64 to i32
+  // CHECK-DAG: %[[WARP_ID:.*]] = llvm.udiv %[[THREAD_ID_32]], %[[CST_16]] : i32
+  // CHECK-DAG: %[[LANE_ID:.*]] = llvm.urem %[[THREAD_ID_32]], %[[CST_16]] : i32
+  // CHECK: %[[VAL_37:.*]] = llvm.and %[[WARP_ID]], %[[CST_1]] : i32
+  // CHECK: %[[VAL_38:.*]] = llvm.icmp "eq" %[[VAL_37]], %[[CST_0]] : i32
+  // CHECK: %[[VAL_39:.*]] = llvm.select %[[VAL_38]], %[[CST_0]], %[[CST_8]] : i1, i32
+  // CHECK: %[[VAL_40:.*]] = llvm.xor %{{.*}}, %[[VAL_39]] : i32
+  // CHECK: %[[VAL_41:.*]] = llvm.and %[[WARP_ID]], %[[CST_2]] : i32
+  // CHECK: %[[VAL_42:.*]] = llvm.icmp "eq" %[[VAL_41]], %[[CST_0]] : i32
+  // CHECK: %[[VAL_43:.*]] = llvm.select %[[VAL_42]], %[[CST_0]], %[[CST_8]] : i1, i32
+  // CHECK: %[[VAL_44:.*]] = llvm.xor %{{.*}}, %[[VAL_43]] : i32
+  // CHECK: %[[OFFSET_X_0:.*]] = llvm.xor %[[VAL_44]], %[[CST_0]] : i32
+  // CHECK: %[[OFFSET_Y_0:.*]] = llvm.xor %[[VAL_40]], %[[CST_0]] : i32
+  // CHECK: %[[OFFSET_X_1:.*]] = llvm.xor %[[VAL_44]], %[[CST_2]] : i32
+  // CHECK: %[[OFFSET_X_2:.*]] = llvm.xor %[[VAL_44]], %[[CST_4]] : i32
+  // CHECK: %[[OFFSET_X_3:.*]] = llvm.xor %[[VAL_44]], %[[CST_6]] : i32
+  // CHECK: %[[OFFSET_Y_1:.*]] = llvm.xor %[[VAL_40]], %[[CST_16]] : i32
+  // CHECK: %[[OFFSET_X_4:.*]] = llvm.xor %[[VAL_44]], %[[CST_16]] : i32
+  // CHECK: %[[OFFSET_X_5:.*]] = llvm.xor %[[VAL_44]], %[[CST_18]] : i32
+  // CHECK: %[[OFFSET_X_6:.*]] = llvm.xor %[[VAL_44]], %[[CST_20]] : i32
+  // CHECK: %[[OFFSET_X_7:.*]] = llvm.xor %[[VAL_44]], %[[CST_22]] : i32
   // CHECK: llvm.call @_Z18__spirv_ocl_printf({{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[OFFSET_X_0]], %[[OFFSET_Y_0]], {{.*}}, {{.*}})
   // CHECK: llvm.call @_Z18__spirv_ocl_printf({{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[OFFSET_X_1]], %[[OFFSET_Y_0]], {{.*}}, {{.*}})
   // CHECK: llvm.call @_Z18__spirv_ocl_printf({{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[OFFSET_X_2]], %[[OFFSET_Y_0]], {{.*}}, {{.*}})
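Note: the updated CHECK lines reflect that the LinearLayout-based lowering composes offsets with `llvm.xor` instead of `llvm.add`. Because each basis vector of a linear layout contributes a distinct power-of-two bit, xor and add coincide on these operands; a tiny illustration (values chosen for this example only):

  // 8 (0b1000) and 5 (0b0101) have disjoint bits, so xor equals add.
  static_assert((8 ^ 5) == 8 + 5);   // 13
  static_assert((16 ^ 6) == 16 + 6); // 22, cf. OFFSET_X_7 above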
diff --git a/third_party/intel/include/Dialect/TritonIntelGPU/IR/Attributes.h b/third_party/intel/include/Dialect/TritonIntelGPU/IR/Attributes.h
index db98072fba..b159017e34 100644
--- a/third_party/intel/include/Dialect/TritonIntelGPU/IR/Attributes.h
+++ b/third_party/intel/include/Dialect/TritonIntelGPU/IR/Attributes.h
@@ -1,6 +1,8 @@
 #ifndef TRITON_DIALECT_TRITON_INTEL_GPU_IR_ATTRIBUTES_H
 #define TRITON_DIALECT_TRITON_INTEL_GPU_IR_ATTRIBUTES_H
 
+#include "triton/Dialect/TritonGPU/IR/Attributes.h"
+
 #define GET_ATTRDEF_CLASSES
 #include "intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.h.inc"
 
diff --git a/third_party/intel/include/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.h b/third_party/intel/include/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.h
index 8153393a1b..2758e6341a 100644
--- a/third_party/intel/include/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.h
+++ b/third_party/intel/include/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.h
@@ -15,9 +15,8 @@ namespace mlir::triton::gpu {
 // DPAS operand B: opIdx=1
 // DPAS operand C (default): opIdx=2
 // Operand A and B conversion are not used yet
-std::optional<LinearLayout> DPAStoLinearLayout(ArrayRef<int64_t> shape,
-                                               Attribute layout,
-                                               unsigned opIdx = 2);
+LinearLayout DPAStoLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
+                                unsigned opIdx = 2);
 
 } // namespace mlir::triton::gpu
 
diff --git a/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp b/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp
index ef8cb1c055..c6f314aa50 100644
--- a/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp
+++ b/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp
@@ -2,6 +2,7 @@
 
 #include <numeric>
 
+#include "intel/include/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/OpImplementation.h"
 #include "triton/Dialect/Triton/IR/Utility.h"
@@ -415,6 +416,11 @@ void DpasEncodingAttr::print(AsmPrinter &printer) const {
         << "}>";
 }
 
+std::optional<LinearLayout>
+DpasEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
+  return DPAStoLinearLayout(shape, *this);
+}
+
 //===----------------------------------------------------------------------===//
 // WarpEncodingAttr
 //===----------------------------------------------------------------------===//
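Note: with the `DpasEncodingAttr::toLinearLayout` override above, generic code can query any encoding through one interface instead of special-casing DPAS. A usage sketch (hypothetical surrounding code; assumes Triton's LinearLayout headers, with `S` abbreviating a `StringAttr::get` helper as in the unit tests):

  // Uniform path: dispatches to DPAStoLinearLayout for DPAS encodings.
  std::optional<LinearLayout> ll =
      triton::gpu::toLinearLayout(shape, tensorTy.getEncoding());

  // Direct path for matrix operands (opIdx 0 = A, 1 = B, 2 = C/accumulator):
  LinearLayout aLayout = DPAStoLinearLayout(shape, dpasEncoding, /*opIdx=*/0);
  LinearLayout cLayout = DPAStoLinearLayout(shape, dpasEncoding); // opIdx = 2
  int32_t regsPerThread = cLayout.getInDimSize(S("register"));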
diff --git a/third_party/intel/lib/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.cpp b/third_party/intel/lib/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.cpp
index 5ac8b9656b..9b5b01ee3e 100644
--- a/third_party/intel/lib/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.cpp
+++ b/third_party/intel/lib/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.cpp
@@ -483,8 +483,8 @@ DPASLaneBasesC(int repeatCount, int executionSize, int threadsPerWarp) {
   return laneBases;
 }
 
-std::optional<LinearLayout>
-DPAStoLinearLayout(ArrayRef<int64_t> shape, Attribute layout, unsigned opIdx) {
+LinearLayout DPAStoLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
+                                unsigned opIdx) {
   assert(opIdx < 3 && "opIdx must be 0, 1, or 2");
   auto dpas = dyn_cast<DpasEncodingAttr>(layout);
   assert(dpas && "Must be DPAS layout");
@@ -497,13 +497,12 @@ DPAStoLinearLayout(ArrayRef<int64_t> shape, Attribute layout, unsigned opIdx) {
 
   StringAttr kRegister = S("register");
   StringAttr kLane = S("lane");
+  StringAttr kWarp = S("warp");
 
   const SmallVector<unsigned> warpsPerCTA = dpas.getWarpsPerCTA();
   int threadsPerWarp = triton::gpu::getWarpSize(dpas);
   unsigned opsPerChannel = dpas.getOpsPerChannel();
   auto repCluster = dpas.getRepCluster();
-  SmallVector<int64_t> numReps = dpas.getDPASRepetitions(shape, opIdx);
-
   auto tileLayout = LinearLayout::empty();
   int systolicDepth = dpas.getSystolicDepth();
   int repeatCount = dpas.getRepeatCount();
@@ -520,8 +519,14 @@ DPAStoLinearLayout(ArrayRef<int64_t> shape, Attribute layout, unsigned opIdx) {
     // A only repeats by repCluster[0]
     tileLayout *=
         LinearLayout::identity1D(repCluster[0], kRegister, outDimNames[0]);
+
     nonKDim = 0;
     KDim = 1;
+    // The K dimension is shared among warps
+    tileLayout *= LinearLayout::zeros1D(warpsPerCTA[1], kWarp, outDimNames[1]);
+    tileLayout *=
+        LinearLayout::identity1D(warpsPerCTA[0], kWarp, outDimNames[0]);
+
   } else if (opIdx == 1) { // Operand B
     auto regBasesB = DPASRegBasesB(opsPerChannel, executionSize,
                                    threadsPerWarp, systolicDepth);
@@ -532,8 +537,14 @@ DPAStoLinearLayout(ArrayRef<int64_t> shape, Attribute layout, unsigned opIdx) {
     // B only repeats by repCluster[1]
     tileLayout *=
         LinearLayout::identity1D(repCluster[1], kRegister, outDimNames[1]);
+
     nonKDim = 1;
     KDim = 0;
+
+    // The K dimension is shared among warps
+    tileLayout *=
+        LinearLayout::identity1D(warpsPerCTA[1], kWarp, outDimNames[1]);
+    tileLayout *= LinearLayout::zeros1D(warpsPerCTA[0], kWarp, outDimNames[0]);
   } else { // opIdx=2 -> Operand C
     auto regBasesC = DPASRegBasesC(repeatCount, executionSize, threadsPerWarp);
     auto laneBasesC =
@@ -547,36 +558,26 @@ DPAStoLinearLayout(ArrayRef<int64_t> shape, Attribute layout, unsigned opIdx) {
         LinearLayout::identity1D(repCluster[1], kRegister, outDimNames[1]);
     tileLayout *=
         LinearLayout::identity1D(repCluster[0], kRegister, outDimNames[0]);
+
+    // The same tile layout is replicated across all warps
+    tileLayout *=
+        LinearLayout::identity1D(warpsPerCTA[1], kWarp, outDimNames[1]);
+    tileLayout *=
+        LinearLayout::identity1D(warpsPerCTA[0], kWarp, outDimNames[0]);
     nonKDim = 0;
     KDim = 1;
   }
 
+  // Lastly, the layout repeats to match the shape.
   // Operand A/B repeats through the K-dimension first, then repeats
-  // through non-K dimension.
+  // through the non-K dimension.
+  SmallVector<int64_t> numReps = dpas.getDPASRepetitions(shape, opIdx);
   tileLayout *=
       LinearLayout::identity1D(numReps[KDim], kRegister, outDimNames[KDim]);
   tileLayout *=
      LinearLayout::identity1D(numReps[nonKDim], kRegister, outDimNames[nonKDim]);
-
-  // For Operand C, warps split the tensor identically.
-  // For Operand A and B, warps in the K-dimension share the same data.
-  // In these cases, the warp hops for K-dimensions are zero.
-  LinearLayout warpLayout = LinearLayout::empty();
-  StringAttr kWarp = S("warp");
-  if (opIdx == 0) {
-    warpLayout =
-        LinearLayout::identity1D(warpsPerCTA[0], kWarp, outDimNames[0]);
-    warpLayout *= LinearLayout::zeros1D(warpsPerCTA[1], kWarp, outDimNames[1]);
-  } else if (opIdx == 1) {
-    warpLayout = LinearLayout::zeros1D(warpsPerCTA[0], kWarp, outDimNames[0]);
-    warpLayout *=
-        LinearLayout::identity1D(warpsPerCTA[1], kWarp, outDimNames[1]);
-  } else { /* opIdx == 2 */
-    warpLayout = identityND(kWarp, warpsPerCTA, {0, 1}, outDimNames);
-  }
-  LinearLayout ctaLayout = tileLayout * warpLayout;
-
-  return combineCtaCgaWithShape(std::move(ctaLayout),
+  return combineCtaCgaWithShape(tileLayout,
                                 CTALayoutAttr::getDefault(ctx, rank), shape);
 }
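Note: the branches above encode the warp distribution directly into `tileLayout`. An `identity1D` factor gives each warp along a dimension its own slice, while a `zeros1D` factor maps every warp along that dimension to offset 0, i.e. those warps broadcast the same data. A toy sketch of the distinction (assumes `triton/Tools/LinearLayout.h`; sizes and names are illustrative only):

  StringAttr kWarp = S("warp"), dim0 = S("dim0"), dim1 = S("dim1");
  // Operand A style: warps split dim0 but share (broadcast) the K slice in
  // dim1. The left factor of operator* occupies the low-order input bits.
  LinearLayout warpsA = LinearLayout::zeros1D(2, kWarp, dim1) *
                        LinearLayout::identity1D(2, kWarp, dim0);
  // warpsA maps warp 1 to (0, 0): both warps along dim1 read the same data,
  // while warp 2 maps to (1, 0), a distinct dim0 slice.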
diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp
index b5c54abfb0..bd3998ad6d 100644
--- a/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp
+++ b/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp
@@ -459,25 +459,16 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
     const auto &shape = op.getType().getShape();
     std::optional<LinearLayout> srcLayout;
     auto srcTy = op.getSrc().getType();
-
-    if (auto dpasLayout = dyn_cast<DpasEncodingAttr>(srcTy.getEncoding())) {
-      srcLayout = gpu::DPAStoLinearLayout(shape, dpasLayout);
-    } else {
-      srcLayout = gpu::toLinearLayout(shape, srcTy.getEncoding());
-    }
+    srcLayout = gpu::toLinearLayout(shape, srcTy.getEncoding());
     std::optional<LinearLayout> dstLayout;
     auto dstTy = op.getType();
-    if (auto dpasLayout = dyn_cast<DpasEncodingAttr>(dstTy.getEncoding())) {
-      dstLayout = gpu::DPAStoLinearLayout(shape, dpasLayout);
-    } else {
-      dstLayout = gpu::toLinearLayout(shape, dstTy.getEncoding());
-    }
+
+    dstLayout = gpu::toLinearLayout(shape, dstTy.getEncoding());
     if (!srcLayout.has_value() || !dstLayout.has_value()) {
       return failure();
     }
-
     // There are four cases to handle.
     //
     // 1. Transfer between values in the same thread, in which case we simply
diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/Utility.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/Utility.cpp
index 0e6073a403..fdf9d7cdce 100644
--- a/third_party/intel/lib/TritonIntelGPUToLLVM/Utility.cpp
+++ b/third_party/intel/lib/TritonIntelGPUToLLVM/Utility.cpp
@@ -113,128 +113,3 @@ LLVM::LLVMFuncOp getSpirvPrintfDeclaration(RewriterBase &rewriter) {
 }
 
 } // namespace mlir::LLVM::intel
-
-namespace mlir::triton::intel {
-bool emitTransferBetweenDPASAndShared(
-    RankedTensorType registerTy, MemDescType sharedTy, Type elemLlvmTy,
-    std::optional<int32_t> maxVecElems, Value shmemBase,
-    ArrayRef<Value> shmemStrides, Location loc, RewriterBase &rewriter,
-    const TargetInfoBase &target,
-    std::function<void(VectorType, Value)> perVectorCallback) {
-  MLIRContext *ctx = rewriter.getContext();
-
-  auto shape = registerTy.getShape();
-  int rank = shape.size();
-
-  StringAttr kBlock = str_attr("block");
-  StringAttr kRegister = str_attr("register");
-  StringAttr kLane = str_attr("lane");
-  StringAttr kWarp = str_attr("warp");
-
-  std::optional<LinearLayout> regLayout;
-  if (auto dpas = dyn_cast<DpasEncodingAttr>(registerTy.getEncoding())) {
-    // Default is operandC (opidx == 2)
-    regLayout = triton::gpu::DPAStoLinearLayout(shape, dpas);
-  } else {
-    regLayout = triton::gpu::toLinearLayout(shape, registerTy.getEncoding());
-  }
-
-  std::optional<LinearLayout> sharedLayout;
-  if (auto dpas = dyn_cast<DpasEncodingAttr>(sharedTy.getEncoding())) {
-    sharedLayout = triton::gpu::DPAStoLinearLayout(shape, dpas);
-  } else {
-    sharedLayout = triton::gpu::toLinearLayout(
-        shape, sharedTy.getEncoding(), elemLlvmTy.getIntOrFloatBitWidth());
-  }
-
-  if (!regLayout.has_value() || !sharedLayout.has_value()) {
-    return false;
-  }
-  auto sharedOrder = triton::gpu::getOrder(sharedTy.getEncoding());
-
-  // sharedLayout's in-dims are currently (offset, block). Reshape to
-  // (offsetX1, offsetX2, ..., block) so that we can apply the N-dimensional
-  // shmem strides. (The offsetX's appear in minor-to-major order.)
-  auto sharedLegacy =
-      cast<triton::gpu::SharedEncodingAttr>(sharedTy.getEncoding());
-  SmallVector<std::pair<StringAttr, int64_t>> multiDimSharedSize;
-  for (int i = 0; i < rank; i++) {
-    int dim = sharedOrder[i];
-    int64_t size = std::max(
-        int64_t{1},
-        shape[dim] / sharedLegacy.getCTALayout().getCTASplitNum()[dim]);
-    multiDimSharedSize.push_back(
-        {str_attr("offset" + std::to_string(dim)), size});
-  }
-  multiDimSharedSize.push_back({kBlock, sharedLayout->getInDimSize(kBlock)});
-  sharedLayout = sharedLayout->reshapeIns(multiDimSharedSize);
-
-  // regToSharedLayout maps from (register, lane, warp, block) to (offsetX1,
-  // ..., offsetXN, block), where the offsetX's are in minor-to-major order.
-  LinearLayout regToSharedLayout = regLayout->invertAndCompose(*sharedLayout);
-
-  // TODO(jlebar): We don't currently support loading from shared memory in a
-  // different CTA. We'd need to emit `mapa.shared::cluster` instructions.
-  for (int inBlock = 1; inBlock < regToSharedLayout.getInDimSize(kBlock);
-       inBlock *= 2) {
-    auto idx = llvm::to_vector(llvm::make_second_range(regToSharedLayout.apply(
-        {{kRegister, 0}, {kLane, 0}, {kWarp, 0}, {kBlock, inBlock}})));
-    // offsetX1, ..., offsetXN must all be 0.
-    if (!llvm::all_of(ArrayRef(idx).drop_back(1),
-                      [&](auto offset) { return offset == 0; })) {
-      return false;
-    }
-    int32_t outBlock = idx.back();
-    if (outBlock != inBlock) {
-      return false;
-    }
-  }
-
-  // Determine how many consecutive registers map to consecutive shmem elements
-  // in out-dimension offsetN. This is our load instruction's vector width.
-  //
-  // It's OK if the vector width we choose here is wider than the hardware
-  // supports; LLVM will legalize it.
-  //
-  // TODO(jlebar): shmemStrides are Values, but most of them are usually integer
-  // constants. We could add those constant strides to the LL, and then before
-  // calling getNumConsecutiveInOut(), we could flatten consecutive out-dims
-  // which have known strides. This would allow us to vectorize across multiple
-  // shmem out dimensions where possible.
-  const int vecElems =
-      std::min(regToSharedLayout.getNumConsecutiveInOut(),
-               maxVecElems.value_or(std::numeric_limits<int>::max()));
-
-  Value threadId = getThreadId(rewriter, loc);
-  Value threadsPerWarp = i32_val(regToSharedLayout.getInDimSize(kLane));
-  Value laneId = urem(threadId, threadsPerWarp);
-  Value warpId = udiv(threadId, threadsPerWarp);
-
-  int numElems = regToSharedLayout.getInDimSize(kRegister);
-  auto vecTy = vec_ty(elemLlvmTy, vecElems);
-  auto ptrTy = ptr_ty(ctx, /*addressSpace=*/3);
-  Value zero = i32_val(0);
-  SmallVector<Value> ret;
-  for (int i = 0; i < numElems / vecElems; i++) {
-    // Get the address to load/store. The multi-dim address is (offsetX1, ...,
-    // offsetXN, block), where the offsets appear in minor-to-major order, and
-    // we drop_end to drop block, which we know from above will be 0.
-    auto multiDimShmemOffset =
-        llvm::to_vector(llvm::drop_end(llvm::make_second_range(
-            applyLinearLayout(loc, rewriter, regToSharedLayout,
-                              {{kRegister, i32_val(i * vecElems)},
-                               {kLane, laneId},
-                               {kWarp, warpId},
-                               {kBlock, zero}}))));
-
-    // Reorder strides according to `order`. This way they match the
-    // multi-dimensional offsets in regToSharedLayout.
-    Value shmemOffset = dot(rewriter, loc, multiDimShmemOffset,
-                            applyPermutation(shmemStrides, sharedOrder));
-    auto vecAddr = gep(ptrTy, elemLlvmTy, shmemBase, shmemOffset);
-    vecAddr.setInbounds(true);
-    perVectorCallback(vecTy, vecAddr);
-  }
-  return true;
-}
-} // namespace mlir::triton::intel
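Note: everything removed above survives in the generic `emitTransferBetweenRegistersAndShared`, which the DPAS path can now reach because DPAS encodings implement `toLinearLayout`. The core composition that function relies on, in sketch form (method names from `triton/Tools/LinearLayout.h`; the surrounding variables are assumed):

  // regLayout   : (register, lane, warp, block) -> tensor coordinates
  // sharedLayout: shmem offsets                 -> tensor coordinates
  // Composing one with the inverse of the other maps registers to offsets.
  LinearLayout regToShared = regLayout.invertAndCompose(sharedLayout);
  // Consecutive registers hitting consecutive offsets give the vector width.
  int vecWidth = regToShared.getNumConsecutiveInOut();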
diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/Utility.h b/third_party/intel/lib/TritonIntelGPUToLLVM/Utility.h
index 49533a1ba1..6c85f4c6e3 100644
--- a/third_party/intel/lib/TritonIntelGPUToLLVM/Utility.h
+++ b/third_party/intel/lib/TritonIntelGPUToLLVM/Utility.h
@@ -740,33 +740,12 @@ inline DenseMap<unsigned, Value> getSwizzledSharedPtrs(
   return ret;
 }
 
-[[nodiscard]] bool emitTransferBetweenDPASAndShared(
-    RankedTensorType registerTy, MemDescType sharedTy, Type elemLlvmTy,
-    std::optional<int32_t> maxVecElems, Value shmemBase,
-    ArrayRef<Value> shmemStrides, Location loc, RewriterBase &rewriter,
-    const TargetInfoBase &target,
-    std::function<void(VectorType, Value)> perVectorCallback);
-
 inline SmallVector<Value> loadSharedToDistributed(
     RankedTensorType dstTy, MemDescType srcTy, Type elemLlvmTy,
     SharedMemoryObject &memObj, Location loc, RewriterBase &rewriter,
     const TargetInfoBase &target) {
   SmallVector<Value> ret;
-  if (isa<DpasEncodingAttr>(dstTy.getEncoding())) {
-    if (emitTransferBetweenDPASAndShared(
-            dstTy, srcTy, elemLlvmTy, /*maxVecElems=*/std::nullopt,
-            memObj.getBase(), memObj.getStrides(), loc, rewriter, target,
-            [&](VectorType vecTy, Value vecAddr) {
-              auto vecVal = load(vecTy, vecAddr);
-              vecVal.setAlignment(vecTy.getNumElements() *
-                                  elemLlvmTy.getIntOrFloatBitWidth() / 8);
-              for (int v = 0; v < vecTy.getNumElements(); v++) {
-                ret.push_back(extract_element(elemLlvmTy, vecVal, i32_val(v)));
-              }
-            }))
-      return ret;
-  }
   bool success = emitTransferBetweenRegistersAndShared(
       dstTy, srcTy, elemLlvmTy, /*maxVecElems=*/std::nullopt, memObj.getBase(),
       memObj.getStrides(), loc, rewriter, target,
@@ -789,23 +768,6 @@ inline void storeDistributedToShared(MemDescType dstTy, RankedTensorType srcTy,
                                      Value smemBase, ArrayRef<Value> dstStrides,
                                      Location loc, RewriterBase &rewriter,
                                      const TargetInfoBase &target) {
-  if (isa<DpasEncodingAttr>(srcTy.getEncoding())) {
-    if (emitTransferBetweenDPASAndShared(
-            srcTy, dstTy, elemLlvmTy, /*maxVecElems=*/std::nullopt, smemBase,
-            dstStrides, loc, rewriter, target,
-            [&](VectorType vecTy, Value vecAddr) {
-              ArrayRef<Value> vals = srcVals.take_front(vecTy.getNumElements());
-              srcVals = srcVals.drop_front(vecTy.getNumElements());
-              Value vec = undef(vecTy);
-              for (int i = 0; i < vals.size(); i++) {
-                vec = insert_element(vec, vals[i], i32_val(i));
-              }
-              store(vec, vecAddr)
-                  .setAlignment(vecTy.getNumElements() *
-                                elemLlvmTy.getIntOrFloatBitWidth() / 8);
-            }))
-      return;
-  }
   bool success = emitTransferBetweenRegistersAndShared(
       srcTy, dstTy, elemLlvmTy, /*maxVecElems=*/std::nullopt, smemBase,
       dstStrides, loc, rewriter, target, [&](VectorType vecTy, Value vecAddr) {
diff --git a/unittest/Dialect/TritonGPU/DPAStoLinearLayoutTest.cpp b/unittest/Dialect/TritonGPU/DPAStoLinearLayoutTest.cpp
index caf374e3f9..6d42c9948a 100644
--- a/unittest/Dialect/TritonGPU/DPAStoLinearLayoutTest.cpp
+++ b/unittest/Dialect/TritonGPU/DPAStoLinearLayoutTest.cpp
@@ -142,7 +142,7 @@ TEST_F(DPAStoLinearLayoutTest, DPAS_withWarp) {
           {
              {S("register"), {{2, 0}, {4, 0}, {0, 16}, {8, 0}, {16, 0}}},
              {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {1, 0}}},
-             {S("warp"), {{32, 0}, {0, 32}}},
+             {S("warp"), {{0, 32}, {32, 0}}},
              {S("block"), {}},
           },
           {S("dim0"), S("dim1")}));
@@ -156,7 +156,7 @@ TEST_F(DPAStoLinearLayoutTest, DPAS_withWarpOperandA) {
          {
              {S("register"),
               {{0, 1}, {4, 0}, {8, 0}, {16, 0}, {0, 16}, {0, 32}}},
              {S("lane"), {{0, 2}, {0, 4}, {0, 8}, {1, 0}, {2, 0}}},
-             {S("warp"), {{32, 0}, {0, 0}}},
+             {S("warp"), {{0, 0}, {32, 0}}},
              {S("block"), {}},
          },
          {S("dim0"), S("dim1")}));
@@ -170,7 +170,7 @@ TEST_F(DPAStoLinearLayoutTest, DPAS_withWarpOperandB) {
          {
              {S("register"),
               {{1, 0}, {4, 0}, {8, 0}, {0, 16}, {16, 0}, {32, 0}}},
              {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {2, 0}}},
-             {S("warp"), {{0, 0}, {0, 32}}},
+             {S("warp"), {{0, 32}, {0, 0}}},
              {S("block"), {}},
          },
          {S("dim0"), S("dim1")}));
@@ -187,6 +187,17 @@ TEST_F(DPAStoLinearLayoutTest, DPAS_withDPASRepetitions) {
              {S("block"), {}},
          },
          {S("dim0"), S("dim1")}));
+  EXPECT_EQ(
+      DPAStoLinearLayout({128, 128}, dpas({2, 2}, 8, 8, 16, 2, {2, 2}, 32)),
+      LinearLayout(
+          {
+              {S("register"),
+               {{2, 0}, {4, 0}, {0, 16}, {8, 0}, {0, 64}, {32, 0}, {64, 0}}},
+              {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {1, 0}}},
+              {S("warp"), {{0, 32}, {16, 0}}},
+              {S("block"), {}},
+          },
+          {S("dim0"), S("dim1")}));
 }
 
 } // anonymous namespace
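Note on the updated expectations: the in-branch composition now multiplies the dim1 warp factor before the dim0 factor, and LinearLayout equality is sensitive to basis order, so warp bit 0 now advances along dim1. A sketch of the order dependence (hypothetical 2x2 warp grid; in the tests above the bases are scaled to the 32x32 warp tile):

  StringAttr kWarp = S("warp"), dim0 = S("dim0"), dim1 = S("dim1");
  LinearLayout newOrder = LinearLayout::identity1D(2, kWarp, dim1) *
                          LinearLayout::identity1D(2, kWarp, dim0);
  // newOrder warp bases: {{0, 1}, {1, 0}}. Warp bit 0 walks dim1 first,
  // which, scaled up, matches the expected {S("warp"), {{0, 32}, {32, 0}}}.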