diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py
index 88ed9a0070..87844f001f 100644
--- a/python/test/unit/language/test_core.py
+++ b/python/test/unit/language/test_core.py
@@ -4526,6 +4526,9 @@ def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device):
     # skip even if scratch buffer equal to lds_size, because real scratch buffer is typically larger due to padding
     if scratch_shape[0] * scratch_shape[1] * int32_size >= lds_size:
         pytest.skip("Scratch buffer is too large")
+    if is_xpu() and M == 128 and N == 128 and interm_layout and (dst_layout.sz_per_thread == [1, 8]
+                                                                 or dst_layout.sz_per_thread == [4, 4]):
+        pytest.skip("FIXME: out of resource: shared memory")
 
     layouts = f"""
     #src = {src_layout}
diff --git a/test/Conversion/tritongpu_to_gen.mlir b/test/Conversion/tritongpu_to_gen.mlir
index a16ce93884..b4ebe39150 100644
--- a/test/Conversion/tritongpu_to_gen.mlir
+++ b/test/Conversion/tritongpu_to_gen.mlir
@@ -321,8 +321,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
   tt.func @basic_alloc_tensor() {
     // CHECK-NEXT: llvm.mlir.constant
     // CHECK-NEXT: llvm.getelementptr
-    // CHECK-NEXT: llvm.bitcast
-    %0 = triton_gpu.alloc_tensor : tensor<16x16xf16, #shared0>
+    %0 = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #shared0>
     tt.return
   }
 }
@@ -331,8 +330,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
 
 #shared0 = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
-  // CHECK-LABEL: basic_extract_slice(%arg0: !llvm.ptr<3>)
-  tt.func @basic_extract_slice() {
+  // CHECK-LABEL: basic_subview(%arg0: !llvm.ptr<3>)
+  tt.func @basic_subview() {
     // CHECK: llvm.extractvalue
     // CHECK-NEXT: llvm.extractvalue
     // CHECK-NEXT: llvm.extractvalue
@@ -340,11 +339,6 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
     // CHECK-NEXT: llvm.extractvalue
     // CHECK-NEXT: llvm.extractvalue
     // CHECK-NEXT: llvm.extractvalue
-    // CHECK-NEXT: llvm.add
-    // CHECK-NEXT: llvm.mlir.constant(0 : i32) : i32
-    // CHECK-NEXT: llvm.add
-    // CHECK-NEXT: llvm.mlir.constant(0 : i32) : i32
-    // CHECK-NEXT: llvm.add
     // CHECK-NEXT: llvm.mlir.constant(0 : i32) : i32
     // CHECK-NEXT: llvm.mul
     // CHECK-NEXT: llvm.add
@@ -354,8 +348,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
     // CHECK-NEXT: llvm.add
     // CHECK-NEXT: llvm.getelementptr
     %index = arith.constant 1 : i32
-    %0 = triton_gpu.alloc_tensor : tensor<128x16x32xf32, #shared0>
-    %1 = triton_gpu.extract_slice %0[%index, 0, 0][1, 16, 32][1, 1, 1] : tensor<128x16x32xf32, #shared0> to tensor<16x32xf32, #shared0>
+    %zero = arith.constant 0 : i32
+    %0 = triton_gpu.local_alloc : () -> !tt.memdesc<128x16x32xf32, #shared0>
+    %1 = triton_gpu.memdesc_subview %0[%index, %zero, %zero] : !tt.memdesc<128x16x32xf32, #shared0> -> !tt.memdesc<16x32xf32, #shared0>
     tt.return
   }
 }
@@ -659,7 +654,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 :
     // CHECK-SAME: vector<8xf32>, !llvm.ptr<3>
     // CHECK: llvm.store
     // CHECK-SAME: vector<8xf32>, !llvm.ptr<3>
-    %0 = triton_gpu.convert_layout %arg0 : tensor<128x32xf32, #blocked0> -> tensor<128x32xf32, #shared0>
+    %0 = triton_gpu.local_alloc %arg0 : (tensor<128x32xf32, #blocked0>) -> !tt.memdesc<128x32xf32, #shared0>
     tt.return
   }
 }
@@ -938,11 +933,12 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32} {
   // CHECK-LABEL: test_base_index_cache
   tt.func @test_base_index_cache(%arg0: tensor<128x32xf32, #blocked0>) {
+    // CHECK: llvm.mlir.constant(0 : i32) : i32
     // CHECK: llvm.mlir.constant(0 : i32) : i32
     // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
     // CHECK-NEXT: llvm.call @_Z12get_local_idj([[ZERO]]) : (i32) -> i64
-    %0 = triton_gpu.convert_layout %arg0 : tensor<128x32xf32, #blocked0> -> tensor<128x32xf32, #shared0>
-    %1 = triton_gpu.convert_layout %arg0 : tensor<128x32xf32, #blocked0> -> tensor<128x32xf32, #shared0>
+    %0 = triton_gpu.local_alloc %arg0 : (tensor<128x32xf32, #blocked0>) -> !tt.memdesc<128x32xf32, #shared0>
+    %1 = triton_gpu.local_alloc %arg0 : (tensor<128x32xf32, #blocked0>) -> !tt.memdesc<128x32xf32, #shared0>
     tt.return
   }
 }
@@ -953,13 +949,14 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 :
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32} {
   // CHECK-LABEL: test_index_cache_different_block
   tt.func @test_index_cache_different_block(%arg0: tensor<128x32xf32, #blocked0>, %arg1: i1) {
+    // CHECK: llvm.mlir.constant(0 : i32) : i32
     // CHECK: llvm.mlir.constant(0 : i32) : i32
     // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
     // CHECK-NEXT: llvm.call @_Z12get_local_idj([[ZERO]]) : (i32) -> i64
-    %0 = triton_gpu.convert_layout %arg0 : tensor<128x32xf32, #blocked0> -> tensor<128x32xf32, #shared0>
+    %0 = triton_gpu.local_alloc %arg0 : (tensor<128x32xf32, #blocked0>) -> !tt.memdesc<128x32xf32, #shared0>
     cf.cond_br %arg1, ^bb1, ^bb2
   ^bb1: // pred: ^bb0
-    %1 = triton_gpu.convert_layout %arg0 : tensor<128x32xf32, #blocked0> -> tensor<128x32xf32, #shared0>
+    %1 = triton_gpu.local_alloc %arg0 : (tensor<128x32xf32, #blocked0>) -> !tt.memdesc<128x32xf32, #shared0>
     cf.br ^bb2
   ^bb2: // 2 preds: ^bb0, ^bb1
     tt.return
diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp
index 23ae05bc84..b78de69dc5 100644
--- a/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp
+++ b/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp
@@ -30,12 +30,131 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter,
                     Location loc, Value tensor,
                     DotOperandEncodingAttr bEncoding,
                     const SharedMemoryObject &smemObj,
-                    TritonGPUToLLVMTypeConverter *typeConverter, Value thread);
+                    const LLVMTypeConverter *typeConverter, Value thread);
 } // namespace intel
 } // namespace SharedToDotOperandDPAS
 
 namespace {
+
+// shared -> dot_operand if the result layout is dpas
+Value lowerSharedToDotOperandDPAS(
+    triton::gpu::LocalLoadOp op, triton::gpu::LocalLoadOpAdaptor adaptor,
+    const LLVMTypeConverter *typeConverter, ConversionPatternRewriter &rewriter,
+    const DpasEncodingAttr &dpasLayout,
+    const DotOperandEncodingAttr &dotOperandLayout, bool isOuter) {
+  auto loc = op.getLoc();
+  Value src = op.getSrc();
+  Value dst = op.getResult();
+
+  auto llvmElemTy = typeConverter->convertType(
+      src.getType().cast<MemDescType>().getElementType());
+
+  auto smemObj = getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(),
+                                                 llvmElemTy, rewriter);
+  Value res;
+  if (!isOuter) {
+    res = SharedToDotOperandDPAS::intel::convertLayout(
+        dotOperandLayout.getOpIdx(), rewriter, loc, src, dotOperandLayout,
+        smemObj, typeConverter, tid_val());
+  } else {
+    assert(false && "unsupported DPAS layout found");
+  }
+  return res;
+}
+// shared -> dpas_operand
+LogicalResult lowerSharedToDotOperand(triton::gpu::LocalLoadOp op,
+                                      triton::gpu::LocalLoadOpAdaptor adaptor,
+                                      const LLVMTypeConverter *typeConverter,
+                                      ConversionPatternRewriter &rewriter) {
+  auto loc = op.getLoc();
+  auto dstEnc = op.getType().getEncoding().cast<DotOperandEncodingAttr>();
+  auto sharedLayout =
+      op.getSrc().getType().getEncoding().cast<SharedEncodingAttr>();
+
+  int K;
+  if (dstEnc.getOpIdx() == 0) // $a
+    K = op.getType().getShape()[sharedLayout.getOrder()[0]];
+  else // $b
+    K = op.getType().getShape()[sharedLayout.getOrder()[1]];
+  bool isOuter = K == 1;
+
+  Value res;
+  if (auto dpasLayout =
+          dstEnc.getParent().dyn_cast_or_null<DpasEncodingAttr>()) {
+    res = lowerSharedToDotOperandDPAS(op, adaptor, typeConverter, rewriter,
+                                      dpasLayout, dstEnc, isOuter);
+  } else if (auto blockedLayout =
+                 dstEnc.getParent().dyn_cast_or_null<BlockedEncodingAttr>()) {
+    auto thread = getThreadId(rewriter, loc);
+    res = SharedToDotOperandFMA::convertLayout(
+        dstEnc.getOpIdx(), op.getSrc(), adaptor.getSrc(), blockedLayout, thread,
+        loc, typeConverter, rewriter);
+  } else {
+    assert(false && "Unsupported dot operand layout found");
+  }
+
+  rewriter.replaceOp(op, res);
+  return success();
+}
+
+LogicalResult lowerSharedToDistributed(triton::gpu::LocalLoadOp op,
+                                       triton::gpu::LocalLoadOpAdaptor adaptor,
+                                       const LLVMTypeConverter *typeConverter,
+                                       ConversionPatternRewriter &rewriter) {
+  auto loc = op.getLoc();
+  auto srcTy = op.getSrc().getType();
+  auto dstTy = op.getResult().getType();
+  auto dstShape = dstTy.getShape();
+  assert(dstShape.size() <= 2 &&
+         "Unexpected rank of ConvertLayout(shared->blocked)");
+  auto srcSharedLayout = srcTy.getEncoding().cast<SharedEncodingAttr>();
+  auto dstLayout = dstTy.getEncoding();
+  auto inOrd = getOrder(srcSharedLayout);
+
+  auto smemObj = getSharedMemoryObjectFromStruct(
+      loc, adaptor.getSrc(), typeConverter->convertType(srcTy.getElementType()),
+      rewriter);
+  auto elemTy = typeConverter->convertType(dstTy.getElementType());
+
+  auto srcStrides =
+      getStridesFromShapeAndOrder(srcTy.getShape(), inOrd, loc, rewriter);
+  auto dstIndices = emitIndices(loc, rewriter, dstLayout, dstTy, true);
+
+  SmallVector<Value> outVals = loadSharedToDistributed(
+      op.getResult(), dstIndices, op.getSrc(), smemObj, elemTy, loc, rewriter);
+
+  Value result = packLLElements(loc, typeConverter, outVals, rewriter, dstTy);
+  rewriter.replaceOp(op, result);
+
+  return success();
+}
+
+struct LocalLoadOpConversion
+    : public ConvertOpToLLVMPattern<triton::gpu::LocalLoadOp> {
+public:
+  using ConvertOpToLLVMPattern<
+      triton::gpu::LocalLoadOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(triton::gpu::LocalLoadOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    MemDescType srcTy = op.getSrc().getType();
+    RankedTensorType dstTy = op.getType();
+    Attribute srcLayout = srcTy.getEncoding();
+    Attribute dstLayout = dstTy.getEncoding();
+    if (dstLayout.isa<DotOperandEncodingAttr>()) {
+      return lowerSharedToDotOperand(op, adaptor, getTypeConverter(), rewriter);
+    }
+    if (srcLayout.isa<SharedEncodingAttr>() &&
+        isaDistributedLayout(dstLayout)) {
+      return lowerSharedToDistributed(op, adaptor, getTypeConverter(),
+                                      rewriter);
+    }
+    return failure();
+  }
+};
+
 struct ConvertLayoutOpConversion
     : public ConvertTritonGPUOpToLLVMPattern<triton::gpu::ConvertLayoutOp> {
 public:
@@ -49,21 +168,9 @@ struct ConvertLayoutOpConversion
     RankedTensorType dstTy = op.getType();
     Attribute srcLayout = srcTy.getEncoding();
     Attribute dstLayout = dstTy.getEncoding();
-    if (isaDistributedLayout(srcLayout) &&
-        dstLayout.isa<SharedEncodingAttr>()) {
-      return lowerDistributedToShared(op, adaptor, rewriter);
rewriter); - } - if (srcLayout.isa() && - dstLayout.isa()) { - return lowerSharedToDotOperand(op, adaptor, rewriter); - } if (isaDistributedLayout(srcLayout) && isaDistributedLayout(dstLayout)) { return lowerDistributedToDistributed(op, adaptor, rewriter); } - if (srcLayout.isa() && - isaDistributedLayout(dstLayout)) { - return lowerSharedToDistributed(op, adaptor, rewriter); - } // TODO: to be implemented llvm_unreachable("unsupported layout conversion"); return failure(); @@ -446,39 +553,6 @@ struct ConvertLayoutOpConversion return success(); } - LogicalResult - lowerSharedToDistributed(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - auto loc = op.getLoc(); - auto srcTy = op.getSrc().getType(); - auto dstTy = op.getResult().getType(); - auto dstShape = dstTy.getShape(); - assert(dstShape.size() <= 2 && - "Unexpected rank of ConvertLayout(shared->blocked)"); - auto srcSharedLayout = srcTy.getEncoding().cast(); - auto dstLayout = dstTy.getEncoding(); - auto inOrd = getOrder(srcSharedLayout); - - auto smemObj = getSharedMemoryObjectFromStruct( - loc, adaptor.getSrc(), - getTypeConverter()->convertType(srcTy.getElementType()), rewriter); - auto elemTy = getTypeConverter()->convertType(dstTy.getElementType()); - - auto srcStrides = - getStridesFromShapeAndOrder(srcTy.getShape(), inOrd, loc, rewriter); - auto dstIndices = emitIndices(loc, rewriter, dstLayout, dstTy, true); - - SmallVector outVals = - loadSharedToDistributed(op.getResult(), dstIndices, op.getSrc(), - smemObj, elemTy, loc, rewriter); - - Value result = - packLLElements(loc, getTypeConverter(), outVals, rewriter, dstTy); - rewriter.replaceOp(op, result); - - return success(); - } - Value computeStMatrixAddr(Value laneId, int matStride, Location loc, ConversionPatternRewriter &rewriter) const { Value rowInMat = urem(laneId, i32_val(8)); // row in the 8x8 matrix @@ -514,102 +588,6 @@ struct ConvertLayoutOpConversion getTypeConverter()->convertType(elemTy), smemBase, offset); rewriter.create(loc, addr, inputs); } - - // blocked -> shared. - // Swizzling in shared memory to avoid bank conflict. Normally used for - // A/B operands of dots. 
-  LogicalResult
-  lowerDistributedToShared(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
-                           ConversionPatternRewriter &rewriter) const {
-    auto loc = op.getLoc();
-    auto srcTy = op.getSrc().getType();
-    auto dstTy = op.getType();
-    auto dstShapePerCTA = triton::gpu::getShapePerCTA(dstTy);
-    auto srcLayout = srcTy.getEncoding();
-    auto outOrd = dstTy.getEncoding().cast<SharedEncodingAttr>().getOrder();
-    assert(srcTy.getShape().size() == 2 ||
-           (srcTy.getShape().size() <= 3 && outOrd[2] == 0) &&
-               "Unexpected rank of ConvertLayout(blocked->shared)");
-    Value smemBase =
-        LLVM::utils::getSharedMemoryBase(loc, rewriter, op.getOperation());
-    auto elemTy = getTypeConverter()->convertType(srcTy.getElementType());
-    auto elemPtrTy = ptr_ty(rewriter.getContext(), 3);
-    smemBase = bitcast(smemBase, elemPtrTy);
-
-    int32_t elemSize = elemTy.getIntOrFloatBitWidth();
-    unsigned numElems = triton::gpu::getTotalElemsPerThread(srcTy);
-    auto dstStrides =
-        getStridesFromShapeAndOrder(dstShapePerCTA, outOrd, loc, rewriter);
-    auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcTy, false);
-    auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter);
-    storeDistributedToShared(op.getSrc(), inVals, dstStrides, srcIndices,
-                             op.getResult(), smemBase, elemTy, loc, rewriter);
-    auto smemObj = SharedMemoryObject(smemBase, elemTy, dstShapePerCTA, outOrd,
-                                      loc, rewriter);
-    auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
-    rewriter.replaceOp(op, retVal);
-    return success();
-  }
-
-  // shared -> dpas_operand
-  LogicalResult
-  lowerSharedToDotOperand(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
-                          ConversionPatternRewriter &rewriter) const {
-    auto loc = op.getLoc();
-    auto dstEnc = op.getType().getEncoding().cast<DotOperandEncodingAttr>();
-    auto sharedLayout =
-        op.getSrc().getType().getEncoding().cast<SharedEncodingAttr>();
-
-    int K;
-    if (dstEnc.getOpIdx() == 0) // $a
-      K = op.getType().getShape()[sharedLayout.getOrder()[0]];
-    else // $b
-      K = op.getType().getShape()[sharedLayout.getOrder()[1]];
-    bool isOuter = K == 1;
-
-    Value res;
-    if (auto dpasLayout =
-            dstEnc.getParent().dyn_cast_or_null<DpasEncodingAttr>()) {
-      res = lowerSharedToDotOperandDPAS(op, adaptor, rewriter, dpasLayout,
-                                        dstEnc, isOuter);
-    } else if (auto blockedLayout =
-                   dstEnc.getParent().dyn_cast_or_null<BlockedEncodingAttr>()) {
-      auto thread = getThreadId(rewriter, loc);
-      res = SharedToDotOperandFMA::convertLayout(
-          dstEnc.getOpIdx(), op.getSrc(), adaptor.getSrc(), blockedLayout,
-          thread, loc, getTypeConverter(), rewriter);
-    } else {
-      assert(false && "Unsupported dot operand layout found");
-    }
-
-    rewriter.replaceOp(op, res);
-    return success();
-  }
-
-  // shared -> dot_operand if the result layout is dpas
-  Value lowerSharedToDotOperandDPAS(
-      triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
-      ConversionPatternRewriter &rewriter, const DpasEncodingAttr &dpasLayout,
-      const DotOperandEncodingAttr &dotOperandLayout, bool isOuter) const {
-    auto loc = op.getLoc();
-    Value src = op.getSrc();
-    Value dst = op.getResult();
-
-    auto llvmElemTy = getTypeConverter()->convertType(
-        src.getType().cast<RankedTensorType>().getElementType());
-
-    auto smemObj = getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(),
-                                                   llvmElemTy, rewriter);
-    Value res;
-    if (!isOuter) {
-      res = SharedToDotOperandDPAS::intel::convertLayout(
-          dotOperandLayout.getOpIdx(), rewriter, loc, src, dotOperandLayout,
-          smemObj, getTypeConverter(), tid_val());
-    } else {
-      assert(false && "unsupported DPAS layout found");
-    }
-    return res;
-  }
 };
 
 } // namespace
@@ -617,4 +595,5 @@ void mlir::triton::intel::populateConvertLayoutOpToLLVMPatterns(
     TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
     PatternBenefit benefit) {
   patterns.add<ConvertLayoutOpConversion>(typeConverter, benefit);
+  patterns.add<LocalLoadOpConversion>(typeConverter, benefit);
 }
diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandDPAS.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandDPAS.cpp
index de208ede0d..06fd1d7371 100644
--- a/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandDPAS.cpp
+++ b/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandDPAS.cpp
@@ -17,7 +17,7 @@ template <int opIdx> class DpasMatmulLoader {
                    unsigned warpsPerTile, ArrayRef<Value> smemStrides,
                    SmallVector instrShape,
                    ConversionPatternRewriter &rewriter,
-                   TritonGPUToLLVMTypeConverter *typeConverter, Location loc)
+                   const LLVMTypeConverter *typeConverter, Location loc)
       : dpasLayout(dpasLayout), tensorTy(tensorTy), smemStrides(smemStrides),
         rewriter(rewriter), loc(loc) {
     static_assert(opIdx == 0 || opIdx == 1);
@@ -190,7 +190,7 @@ Value DpasMatmulLoader<opIdx>::loadMatrix(int repOuter, int repInner,
 
 Value composeValuesToDotOperandLayoutStruct(
     const ValueTable &vals, int n0, int n1,
-    TritonGPUToLLVMTypeConverter *typeConverter, Location loc,
+    const LLVMTypeConverter *typeConverter, Location loc,
     ConversionPatternRewriter &rewriter) {
   std::vector<Value> elems;
   for (int m = 0; m < n0; ++m) {
@@ -234,7 +234,7 @@ getLoadMatrixFn(Value tensor, const SharedMemoryObject &smemObj,
                 DpasEncodingAttr dpasLayout, unsigned warpsPerTile,
                 SmallVector instrShape, Value warpId, Value outerWarpDim,
                 Value laneId, ValueTable &vals,
-                TritonGPUToLLVMTypeConverter *typeConverter,
+                const LLVMTypeConverter *typeConverter,
                 ConversionPatternRewriter &rewriter, Location loc) {
   static_assert(opIdx == 0 || opIdx == 1);
 
@@ -282,7 +282,7 @@ template <int opIdx>
 Value loadOperand(ConversionPatternRewriter &rewriter, Location loc,
                   Value threadId, DotOperandEncodingAttr encoding,
-                  TritonGPUToLLVMTypeConverter *typeConverter, Value tensor,
+                  const LLVMTypeConverter *typeConverter, Value tensor,
                   const SharedMemoryObject &smemObj) {
   static_assert(opIdx == 0 || opIdx == 1);
 
@@ -343,8 +343,7 @@ namespace intel {
 Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter,
                     Location loc, Value tensor, DotOperandEncodingAttr encoding,
                     const SharedMemoryObject &smemObj,
-                    TritonGPUToLLVMTypeConverter *typeConverter,
-                    Value threadId) {
+                    const LLVMTypeConverter *typeConverter, Value threadId) {
   switch (opIdx) {
   case 0:
     return loadOperand<0>(rewriter, loc, threadId, encoding, typeConverter,
diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/DecomposeUnsupportedConversions.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/DecomposeUnsupportedConversions.cpp
index 335f06590b..8bb2052398 100644
--- a/third_party/intel/lib/TritonIntelGPUToLLVM/DecomposeUnsupportedConversions.cpp
+++ b/third_party/intel/lib/TritonIntelGPUToLLVM/DecomposeUnsupportedConversions.cpp
@@ -128,7 +128,7 @@ struct DecomposeUnsupportedConversions
       auto dstDotOp =
           dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
       if (srcBlocked && dstDotOp) {
-        auto tmpType = RankedTensorType::get(
+        auto tmpType = MemDescType::get(
             dstType.getShape(), dstType.getElementType(),
             triton::gpu::SharedEncodingAttr::get(
                 mod.getContext(), dstDotOp, srcType.getShape(),
diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp
index d62dbcb1a3..202356d78e 100644
--- a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp
+++ b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp
@@ -686,30 +686,31 @@ struct AtomicRMWOpConversion
   }
 };
 
-struct InsertSliceAsyncOpConversion
-    : public ConvertTritonGPUOpToLLVMPattern<triton::gpu::InsertSliceAsyncOp>,
+struct AsyncCopyGlobalToLocalOpConversion
+    : public ConvertTritonGPUOpToLLVMPattern<
+          triton::gpu::AsyncCopyGlobalToLocalOp>,
       public LoadStoreConversionBase {
   using ConvertTritonGPUOpToLLVMPattern<
-      triton::gpu::InsertSliceAsyncOp>::ConvertTritonGPUOpToLLVMPattern;
+      triton::gpu::AsyncCopyGlobalToLocalOp>::ConvertTritonGPUOpToLLVMPattern;
 
-  InsertSliceAsyncOpConversion(TritonGPUToLLVMTypeConverter &converter,
-                               ModuleAxisInfoAnalysis &axisAnalysisPass,
-                               PatternBenefit benefit)
-      : ConvertTritonGPUOpToLLVMPattern<triton::gpu::InsertSliceAsyncOp>(
+  AsyncCopyGlobalToLocalOpConversion(TritonGPUToLLVMTypeConverter &converter,
+                                     ModuleAxisInfoAnalysis &axisAnalysisPass,
+                                     PatternBenefit benefit)
+      : ConvertTritonGPUOpToLLVMPattern<triton::gpu::AsyncCopyGlobalToLocalOp>(
            converter, benefit),
        LoadStoreConversionBase(axisAnalysisPass) {}
 
  LogicalResult
-  matchAndRewrite(triton::gpu::InsertSliceAsyncOp op, OpAdaptor adaptor,
+  matchAndRewrite(triton::gpu::AsyncCopyGlobalToLocalOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
     // This function should not be called on the genx target since all
-    // InsertSliceAsyncOps would be decomposed into InsertSliceOps by the
-    // decomposeInsertSliceAsyncOp function.
+    // AsyncCopyGlobalToLocalOps would be decomposed into InsertSliceOps by the
+    // decomposeAsyncCopyGlobalToLocalOp function.
     // FIXME: remove this assertion once a suitable replacement instruction
     // exists for the generated PTX in this function (cp.async.cg.shared.global)
     assert(false &&
-           "InsertSliceAsyncOpConversion: genx target not supported yet");
+           "AsyncCopyGlobalToLocalOpConversion: genx target not supported yet");
 
     // insert_slice_async %src, %dst, %index, %mask, %other
     auto loc = op.getLoc();
@@ -723,7 +724,7 @@ struct InsertSliceAsyncOpConversion
     auto resElemTy = getTypeConverter()->convertType(dstTy.getElementType());
     auto srcLayout = srcTy.getEncoding();
     assert((srcLayout.isa() &&
-            "Unexpected srcLayout in InsertSliceAsyncOpConversion"));
+            "Unexpected srcLayout in AsyncCopyGlobalToLocalOpConversion"));
     auto resSharedLayout = dstTy.getEncoding().cast<SharedEncodingAttr>();
     auto srcShape = srcTy.getShape();
     assert((srcShape.size() <= 3) &&
@@ -733,30 +734,13 @@ struct InsertSliceAsyncOpConversion
     Value llDst = adaptor.getDst();
     Value llSrc = adaptor.getSrc();
     Value llMask = adaptor.getMask();
     Value llOther = adaptor.getOther();
-    Value llIndex = adaptor.getIndex();
 
     // %src
     auto srcElems = unpackLLElements(loc, llSrc, rewriter);
 
     // %dst
     auto smemObj = getSharedMemoryObjectFromStruct(loc, llDst, resElemTy, rewriter);
-    auto axis = op->getAttrOfType<IntegerAttr>("axis").getInt();
-    SmallVector<Value> offsetVals;
-    SmallVector<Value> srcStrides;
-    for (auto i = 0; i < dstTy.getShape().size(); ++i) {
-      if (i == axis) {
-        offsetVals.emplace_back(llIndex);
-      } else {
-        offsetVals.emplace_back(i32_val(0));
-        srcStrides.emplace_back(smemObj.strides[i]);
-      }
-    }
-    // Compute the offset based on the original dimensions of the shared
-    // memory object
-    auto dstOffset = dot(rewriter, loc, offsetVals, smemObj.strides);
-    auto dstPtrTy = ptr_ty(rewriter.getContext(), 3);
-    Value dstPtrBase = gep(dstPtrTy, resElemTy, smemObj.base, dstOffset);
 
     // %mask
     SmallVector<Value> maskElems;
@@ -787,9 +771,10 @@ struct InsertSliceAsyncOpConversion
     unsigned numElems = getTotalElemsPerThread(srcTy);
     unsigned perPhase = resSharedLayout.getPerPhase();
     unsigned maxPhase = resSharedLayout.getMaxPhase();
+    SmallVector<Value> offsetVals = {smemObj.strides.size(), i32_val(0)};
     DenseMap<unsigned, Value> sharedPtrs =
         getSwizzledSharedPtrs(loc, inVec, srcTy, resSharedLayout, resElemTy,
-                              smemObj, rewriter, offsetVals, srcStrides);
+                              smemObj, rewriter, offsetVals, smemObj.strides);
 
     // A sharedLayout encoding has a "vec" parameter.
     // On the column dimension, if inVec > outVec, it means we have to divide
@@ -852,67 +837,11 @@ struct InsertSliceAsyncOpConversion
       }
     }
 
-    rewriter.replaceOp(op, llDst);
-    return success();
-  }
-};
-
-struct ExtractSliceOpConversion
-    : public ConvertTritonGPUOpToLLVMPattern<triton::gpu::ExtractSliceOp> {
-  using ConvertTritonGPUOpToLLVMPattern<
-      triton::gpu::ExtractSliceOp>::ConvertTritonGPUOpToLLVMPattern;
-
-  LogicalResult
-  matchAndRewrite(triton::gpu::ExtractSliceOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    // %dst = extract_slice %src[%offsets]
-    Location loc = op->getLoc();
-    auto srcTy = op.getSrc().getType();
-    auto srcLayout = srcTy.getEncoding().dyn_cast<SharedEncodingAttr>();
-    assert(srcLayout && "Unexpected resultLayout in ExtractSliceOpConversion");
-    assert(op.hasUnitStride() &&
-           "Only unit stride supported by ExtractSliceOpConversion");
-
-    auto typeConverter = getTypeConverter();
-    auto llvmElemTy = typeConverter->convertType(srcTy.getElementType());
-
-    // newBase = base + offset
-    // Triton supports either static and dynamic offsets
-    auto smemObj = LLVM::utils::getSharedMemoryObjectFromStruct(
-        loc, adaptor.getSrc(), llvmElemTy, rewriter);
-    SmallVector<Value> opOffsetVals;
-    SmallVector<Value> offsetVals;
-    auto mixedOffsets = op.getMixedOffsets();
-    for (auto i = 0, j = 0; i < mixedOffsets.size(); ++i) {
-      if (op.isDynamicOffset(i)) {
-        // adaptor.getOffsets() returns list of variable offsets. the size of
-        // the list may not be the same as mixedOffsets
-        opOffsetVals.emplace_back(adaptor.getOffsets()[j]);
-        ++j;
-      } else
-        opOffsetVals.emplace_back(i32_val(op.getStaticOffset(i)));
-      offsetVals.emplace_back(add(smemObj.offsets[i], opOffsetVals[i]));
-    }
-    // Compute the offset based on the original strides of the shared memory
-    // object
-    auto offset = dot(rewriter, loc, opOffsetVals, smemObj.strides);
-    // newShape = rank_reduce(shape)
-    // Triton only supports static tensor sizes
-    SmallVector<Value> strideVals;
-    for (auto i = 0; i < op.getStaticSizes().size(); ++i) {
-      if (op.getStaticSize(i) == 1) {
-        offsetVals.erase(offsetVals.begin() + i);
-      } else {
-        strideVals.emplace_back(smemObj.strides[i]);
-      }
-    }
-
-    auto elemPtrTy = ptr_ty(rewriter.getContext(), 3);
-    smemObj =
-        SharedMemoryObject(gep(elemPtrTy, llvmElemTy, smemObj.base, offset),
-                           llvmElemTy, strideVals, offsetVals);
-    auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
-    rewriter.replaceOp(op, retVal);
+    // Drop the result token.
+    Value zero = rewriter.create<LLVM::ConstantOp>(
+        op.getLoc(), IntegerType::get(op.getContext(), 32),
+        rewriter.getI32IntegerAttr(0));
+    rewriter.replaceOp(op, zero);
     return success();
   }
 };
@@ -1011,9 +940,8 @@ void mlir::triton::intel::populateLoadStoreOpToLLVMPatterns(
   patterns.add(typeConverter, axisInfoAnalysis, benefit);
   patterns.add(typeConverter, axisInfoAnalysis, benefit);
   patterns.add(typeConverter, axisInfoAnalysis, benefit);
-  patterns.add<InsertSliceAsyncOpConversion>(typeConverter, axisInfoAnalysis,
-                                             benefit);
-  patterns.add<ExtractSliceOpConversion>(typeConverter, benefit);
+  patterns.add<AsyncCopyGlobalToLocalOpConversion>(typeConverter,
+                                                   axisInfoAnalysis, benefit);
   patterns.add(typeConverter, benefit);
   patterns.add(typeConverter, benefit);
   patterns.add(typeConverter, benefit);
diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/MemoryOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/MemoryOpToLLVM.cpp
index f9edbb3598..54b8f85816 100644
--- a/third_party/intel/lib/TritonIntelGPUToLLVM/MemoryOpToLLVM.cpp
+++ b/third_party/intel/lib/TritonIntelGPUToLLVM/MemoryOpToLLVM.cpp
@@ -5,22 +5,51 @@
 namespace {
 using namespace mlir;
 using namespace mlir::triton;
+using namespace mlir::triton::gpu;
 
-struct AllocTensorOpConversion
-    : public ConvertTritonGPUOpToLLVMPattern<triton::gpu::AllocTensorOp> {
+// blocked -> shared.
+// Swizzling in shared memory to avoid bank conflict. Normally used for
+// A/B operands of dots.
+void lowerDistributedToShared(LocalAllocOp op, LocalAllocOpAdaptor adaptor,
+                              const LLVMTypeConverter *typeConverter,
+                              ConversionPatternRewriter &rewriter) {
+  auto loc = op.getLoc();
+  auto srcTy = op.getInit().getType();
+  auto dstTy = op.getType();
+  auto dstShapePerCTA = triton::gpu::getShapePerCTA(dstTy);
+  auto srcLayout = srcTy.getEncoding();
+  auto outOrd = dstTy.getEncoding().cast<SharedEncodingAttr>().getOrder();
+  assert(srcTy.getShape().size() == 2 ||
+         (srcTy.getShape().size() <= 3 && outOrd[2] == 0) &&
+             "Unexpected rank of ConvertLayout(blocked->shared)");
+  Value smemBase =
+      LLVM::utils::getSharedMemoryBase(loc, rewriter, op.getOperation());
+  auto elemTy = typeConverter->convertType(srcTy.getElementType());
+
+  int32_t elemSize = elemTy.getIntOrFloatBitWidth();
+  auto mmaLayout = srcLayout.dyn_cast<NvidiaMmaEncodingAttr>();
+  unsigned numElems = triton::gpu::getTotalElemsPerThread(srcTy);
+  auto dstStrides = LLVM::utils::getStridesFromShapeAndOrder(
+      dstShapePerCTA, outOrd, loc, rewriter);
+  auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcTy, false);
+  auto inVals = unpackLLElements(loc, adaptor.getInit(), rewriter);
+  storeDistributedToShared(op.getInit(), inVals, dstStrides, srcIndices,
+                           op.getResult(), smemBase, elemTy, loc, rewriter);
+}
+
+struct LocalAllocOpConversion
+    : public ConvertTritonGPUOpToLLVMPattern<triton::gpu::LocalAllocOp> {
   using ConvertTritonGPUOpToLLVMPattern<
-      triton::gpu::AllocTensorOp>::ConvertTritonGPUOpToLLVMPattern;
+      triton::gpu::LocalAllocOp>::ConvertTritonGPUOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(triton::gpu::AllocTensorOp op, OpAdaptor adaptor,
+  matchAndRewrite(triton::gpu::LocalAllocOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op->getLoc();
     Value smemBase =
         LLVM::utils::getSharedMemoryBase(loc, rewriter, op.getOperation());
-    auto resultTy = op.getType().dyn_cast<RankedTensorType>();
-    auto elemPtrTy = ptr_ty(rewriter.getContext(), 3);
+    auto resultTy = op.getType().cast<MemDescType>();
     auto typeConverter = getTypeConverter();
-    smemBase = bitcast(smemBase, elemPtrTy);
 
     auto sharedLayout = resultTy.getEncoding().cast<SharedEncodingAttr>();
     auto order = sharedLayout.getOrder();
@@ -36,6 +65,11 @@ struct AllocTensorOpConversion
       newOrder = SmallVector<unsigned>(order.begin(), order.end());
     }
 
+    // If there is an initial tensor, store it into the shared memory.
+    if (op.getInit()) {
+      lowerDistributedToShared(op, adaptor, typeConverter, rewriter);
+    }
+
     auto llvmElemTy = typeConverter->convertType(resultTy.getElementType());
     auto shapePerCTA = getShapePerCTA(sharedLayout, resultTy.getShape());
     auto smemObj = SharedMemoryObject(smemBase, llvmElemTy, shapePerCTA,
@@ -46,13 +80,13 @@ struct AllocTensorOpConversion
   }
 };
 
-struct DeallocTensorOpConversion
-    : public ConvertTritonGPUOpToLLVMPattern<triton::gpu::DeallocTensorOp> {
+struct LocalDeallocOpConversion
+    : public ConvertTritonGPUOpToLLVMPattern<triton::gpu::LocalDeallocOp> {
   using ConvertTritonGPUOpToLLVMPattern<
-      triton::gpu::DeallocTensorOp>::ConvertTritonGPUOpToLLVMPattern;
+      triton::gpu::LocalDeallocOp>::ConvertTritonGPUOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(triton::gpu::DeallocTensorOp op, OpAdaptor adaptor,
+  matchAndRewrite(triton::gpu::LocalDeallocOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     rewriter.eraseOp(op);
     return success();
@@ -64,6 +98,6 @@ struct DeallocTensorOpConversion
 void mlir::triton::intel::populateMemoryOpToLLVMPattern(
     TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
     PatternBenefit benefit) {
-  patterns.add<AllocTensorOpConversion>(typeConverter, benefit);
-  patterns.add<DeallocTensorOpConversion>(typeConverter, benefit);
+  patterns.add<LocalAllocOpConversion>(typeConverter, benefit);
+  patterns.add<LocalDeallocOpConversion>(typeConverter, benefit);
 }
diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/Utility.h b/third_party/intel/lib/TritonIntelGPUToLLVM/Utility.h
index 2fef64e2a2..8fca99e21c 100644
--- a/third_party/intel/lib/TritonIntelGPUToLLVM/Utility.h
+++ b/third_party/intel/lib/TritonIntelGPUToLLVM/Utility.h
@@ -1237,7 +1237,7 @@ loadSharedToDistributed(Value dst, ArrayRef<SmallVector<Value>> dstIndices,
   auto dstTy = dst.getType().cast<RankedTensorType>();
   auto dstShape = dstTy.getShape();
   assert(dstShape.size() <= 2 && "Unexpected rank of loadSharedToDistributed");
-  auto srcTy = src.getType().cast<RankedTensorType>();
+  auto srcTy = src.getType().cast<MemDescType>();
   auto dstDistributedLayout = dstTy.getEncoding();
   if (auto mmaLayout = dstDistributedLayout.dyn_cast<NvidiaMmaEncodingAttr>()) {
     assert((!mmaLayout.isVolta()) &&
@@ -1289,7 +1289,7 @@ static void storeDistributedToShared(Value src, ArrayRef<Value> inVals,
   auto rank = srcShape.size();
   assert(rank == 2 ||
          rank == 3 && "Unexpected rank of storeDistributedToShared");
-  auto dstTy = dst.getType().cast<RankedTensorType>();
+  auto dstTy = dst.getType().cast<MemDescType>();
   auto srcDistributedLayout = srcTy.getEncoding();
   if (auto mmaLayout = srcDistributedLayout.dyn_cast<NvidiaMmaEncodingAttr>()) {
     assert((!mmaLayout.isVolta()) &&