From 91d84c379e5acdb1283dfd8c009f5755d92b1bc1 Mon Sep 17 00:00:00 2001 From: Dewei Date: Mon, 25 Mar 2024 01:08:38 -0700 Subject: [PATCH 01/17] [Conversion] convert triton ops with block ptr to llvm --- include/triton/Tools/Sys/GetEnv.hpp | 1 + .../TritonGPUToLLVM/TypeConverter.cpp | 29 +- .../lib/TritonIntelGPUToLLVM/CMakeLists.txt | 1 + .../PatternTritonGPUOpToLLVM.h | 4 + .../TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp | 16 +- .../TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp | 372 ++++++++++++++++++ 6 files changed, 413 insertions(+), 10 deletions(-) create mode 100644 third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp diff --git a/include/triton/Tools/Sys/GetEnv.hpp b/include/triton/Tools/Sys/GetEnv.hpp index 9caf2c9734..6352643e4f 100644 --- a/include/triton/Tools/Sys/GetEnv.hpp +++ b/include/triton/Tools/Sys/GetEnv.hpp @@ -44,6 +44,7 @@ inline const std::set ENV_VARS = { "TRITON_ENABLE_LLVM_DEBUG", "USE_TTGIR_LOC", "TRITON_INTEL_EMULATE_FP16_ATOMICS", + "INTEL_ENABLE_BLOCK_PTR", }; namespace tools { diff --git a/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp b/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp index ae74a03da0..e2278dd1cc 100644 --- a/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp @@ -17,12 +17,29 @@ TritonGPUToLLVMTypeConverter::TritonGPUToLLVMTypeConverter( MLIRContext *ctx, LowerToLLVMOptions &option, const DataLayoutAnalysis *analysis) : LLVMTypeConverter(ctx, option, analysis) { - addConversion([&](triton::PointerType type) -> std::optional { - return convertTritonPointerType(type); - }); - addConversion([&](RankedTensorType type) -> std::optional { - return convertTritonTensorType(type); - }); + if (mlir::triton::tools::getBoolEnv("INTEL_ENABLE_BLOCK_PTR")) { + // tt::pointer to v2i32 + addConversion([&](PointerType type) -> std::optional { + if (isa(type.getPointeeType())) { + auto i32Type = mlir::IntegerType::get(type.getContext(), 32); + return mlir::VectorType::get(2, i32Type); + } + return LLVM::LLVMPointerType::get(type.getContext(), + type.getAddressSpace()); + }); + // tensor type is flattened and divided by 16(subgroupSize) + addConversion([&](mlir::RankedTensorType type) -> mlir::Type { + return mlir::VectorType::get(type.getNumElements() / 16, + type.getElementType()); + }); + } else { + addConversion([&](triton::PointerType type) -> std::optional { + return convertTritonPointerType(type); + }); + addConversion([&](RankedTensorType type) -> std::optional { + return convertTritonTensorType(type); + }); + } addConversion([&](MemDescType type) -> std::optional { return convertMemDescType(type); }); diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/CMakeLists.txt b/third_party/intel/lib/TritonIntelGPUToLLVM/CMakeLists.txt index 493771c869..81ce866594 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/CMakeLists.txt +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/CMakeLists.txt @@ -18,6 +18,7 @@ add_triton_library(TritonIntelGPUToLLVM ScanOpToLLVM.cpp Utility.cpp TensorPtrOpsToLLVM.cpp + TritonOpsToLLVM.cpp ClusterOpsToLLVM.cpp AllocateSharedMemory.cpp TargetInfo.cpp diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/PatternTritonGPUOpToLLVM.h b/third_party/intel/lib/TritonIntelGPUToLLVM/PatternTritonGPUOpToLLVM.h index 88b0c407de..d2b63fb63e 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/PatternTritonGPUOpToLLVM.h +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/PatternTritonGPUOpToLLVM.h @@ -10,6 +10,10 @@ namespace mlir { namespace triton { namespace intel { +void 
populateTritonOpsToLLVMPatterns( + TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + Target target, PatternBenefit benefit); + void populateBarrierOpToLLVMPatterns( TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns, PatternBenefit benefit); diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp index 3472e22176..5746b4eff8 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp @@ -201,6 +201,8 @@ struct ConvertTritonGPUToLLVM void runOnOperation() override { MLIRContext *context = &getContext(); ModuleOp mod = getOperation(); + auto enableBlockPtr = + mlir::triton::tools::getBoolEnv("INTEL_ENABLE_BLOCK_PTR"); mlir::LowerToLLVMOptions option(context); option.overrideIndexBitwidth(32); @@ -223,8 +225,9 @@ struct ConvertTritonGPUToLLVM RewritePatternSet funcPatterns(context); funcPatterns.add(typeConverter, numWarps, /*benefit=*/1); - mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter, - funcPatterns); + if (!enableBlockPtr) + mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter, + funcPatterns); if (failed( applyPartialConversion(mod, funcTarget, std::move(funcPatterns)))) return signalPassFailure(); @@ -237,6 +240,11 @@ struct ConvertTritonGPUToLLVM mlir::triton::intel::TargetInfo targetInfo(computeCapability); int benefit = 10; using namespace mlir::triton::intel; + if (enableBlockPtr) { + populateTritonOpsToLLVMPatterns(typeConverter, patterns, target, benefit); + populateControlFlowOpToLLVMPattern(typeConverter, patterns, target, + benefit); + } else { populateConvertLayoutOpToLLVMPatterns(typeConverter, patterns, benefit); populateDotOpToLLVMPatterns(typeConverter, patterns, benefit); mlir::triton::intel::populateElementwiseOpToLLVMPatterns( @@ -265,8 +273,8 @@ struct ConvertTritonGPUToLLVM patterns, benefit); mlir::triton::intel::populateMakeRangeOpToLLVMPattern(typeConverter, patterns, benefit); - mlir::triton::intel::populateSPMDOpToLLVMPattern(typeConverter, patterns, - targetInfo, benefit); + } + populateSPMDOpToLLVMPattern(typeConverter, patterns, benefit); // TODO(thomas): this should probably be done in a separate step to not // interfere with our own lowering of arith ops. Add arith/math's patterns // to help convert scalar expression to LLVM. 
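The new TritonOpsToLLVM.cpp added below supplies the patterns registered by populateTritonOpsToLLVMPatterns. They build on the type mapping introduced in TypeConverter.cpp above: a block pointer is lowered to a small offset payload, and a tile-shaped tensor is flattened into one vector per work-item (element count divided by the subgroup size, which this series fixes at 16). As a rough guide, these are the converted types exercised by the lit test added later in the series, assuming that fixed subgroup size of 16:

    !tt.ptr<tensor<32x32xf16>, 1>  ->  vector<2xi32>   // payload holding [offsetX, offsetY]
    tensor<32x32xf16>              ->  vector<64xf16>  // 32*32 elements / 16 lanes
    tensor<8x16xf32>               ->  vector<8xf32>   // 8*16 elements / 16 lanes

A 2D block load of the 32x32 f16 tile then materializes as a vector<64xi16> (or vector<32xi32> when the VNNI transform is requested for the B operand) and is bitcast back to vector<64xf16> before the extract/glue and DPAS patterns consume it.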
diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp new file mode 100644 index 0000000000..43eae213b2 --- /dev/null +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp @@ -0,0 +1,372 @@ +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGEN/IR/TritonGENDialect.h" +#include "triton/Tools/Sys/GetEnv.hpp" + +#include "PatternTritonGPUOpToLLVM.h" +#include "TypeConverter.h" +#include "Utility.h" + +using namespace mlir; +using namespace mlir::triton; + +namespace { + +/// offsetX, offsetY for 2D tensor desc +class MakeTensorPtrOpConversion + : public ConvertTritonGPUOpToLLVMPattern { +public: + using ConvertTritonGPUOpToLLVMPattern< + MakeTensorPtrOp>::ConvertTritonGPUOpToLLVMPattern; + LogicalResult + matchAndRewrite(MakeTensorPtrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto i32Type = rewriter.getI32Type(); + auto i64Type = rewriter.getI64Type(); + auto v2i32 = VectorType::get(2, i32Type); + Value payLoad = rewriter.create(loc, v2i32); + auto createIntConstant = [&](Type type, unsigned value) { + auto attr = rewriter.getIntegerAttr(type, value); + return rewriter.create(loc, type, attr); + }; + // if (rank == 2) { + auto offsetX = op.getOffsets()[1]; + auto offsetY = op.getOffsets()[0]; + auto idx0 = createIntConstant(i32Type, 0); + auto idx1 = createIntConstant(i32Type, 1); + payLoad = + rewriter.create(loc, payLoad, offsetX, idx0); + payLoad = + rewriter.create(loc, payLoad, offsetY, idx1); + rewriter.replaceOp(op, payLoad); + return success(); + } +}; + +class AdvanceOpConversion : public ConvertTritonGPUOpToLLVMPattern { +public: + using ConvertTritonGPUOpToLLVMPattern< + AdvanceOp>::ConvertTritonGPUOpToLLVMPattern; + LogicalResult + matchAndRewrite(AdvanceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto i32Type = rewriter.getI32Type(); + auto offsets = adaptor.getOffsets(); + Value ptr = adaptor.getPtr(); + for (size_t i = 0; i < offsets.size(); i++) { + auto offset = offsets[i]; + if (auto cst = dyn_cast(offset.getDefiningOp())) + if (auto attr = dyn_cast(cst.getValue()); + attr && attr.getInt() == 0) + continue; + auto idx0 = rewriter.create( + loc, i32Type, rewriter.getIntegerAttr(i32Type, 0)); + auto idx1 = rewriter.create( + loc, i32Type, rewriter.getIntegerAttr(i32Type, 1)); + Value idx = i == 0 ? 
idx1 : idx0; + auto oldOffset = rewriter.create(loc, ptr, idx); + auto newOffset = + rewriter.create(loc, i32Type, oldOffset, offset); + ptr = rewriter.create(loc, ptr, newOffset, idx); + } + rewriter.replaceOp(op, ptr); + return success(); + } +}; + +template +class LoadStorePrefetchOpConversion + : public ConvertTritonGPUOpToLLVMPattern { +public: + using ConvertTritonGPUOpToLLVMPattern< + OpType>::ConvertTritonGPUOpToLLVMPattern; + LogicalResult + matchAndRewrite(OpType op, typename OpType::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto ptrType = cast(op.getPtr().getType()); + auto tType = cast(ptrType.getPointeeType()); + auto rank = tType.getRank(); + assert(rank <= 2 && "only support 1d/2d load/store/prefetch for now"); + auto loc = op.getLoc(); + constexpr bool isLoad = std::is_same_v; + constexpr bool isPrefetch = + std::is_same_v; + auto createIntConstant = [&](Type type, unsigned value) { + auto attr = rewriter.getIntegerAttr(type, value); + return rewriter.create(loc, type, attr); + }; + auto i16Type = rewriter.getI16Type(); + auto i32Type = rewriter.getI32Type(); + auto i64Type = rewriter.getI64Type(); + auto v4i64 = VectorType::get(4, i64Type); + auto vnni = false; + auto transpose = false; + if constexpr (isLoad) { + auto idxAttr = op->template getAttrOfType("DotIdx"); + vnni = idxAttr.getInt() == 1 ? true : false; + } + unsigned dataSize = tType.getElementType().getIntOrFloatBitWidth(); + auto blockWidth = tType.getShape()[1]; + auto blockHeight = tType.getShape()[0]; + auto idx0 = createIntConstant(i32Type, 0); + auto idx1 = createIntConstant(i32Type, 1); + Value ptr = op.getPtr(); + if (auto cast = + dyn_cast(ptr.getDefiningOp())) + ptr = cast.getInputs()[0]; + MakeTensorPtrOp ptrOp = getMakeTensorPtrOp(ptr); + Value base = ptrOp.getBase(); + if (auto cast = + dyn_cast(base.getDefiningOp())) + base = cast.getInputs()[0]; + + auto insertPoint = rewriter.saveInsertionPoint(); + rewriter.setInsertionPointAfter(ptrOp); + auto bytes = createIntConstant( + i32Type, tType.getElementType().getIntOrFloatBitWidth() / 8); + auto one = createIntConstant(i32Type, 1); + Value surfaceW = + rewriter.create(loc, i32Type, ptrOp.getShape()[1]); + surfaceW = rewriter.create(loc, surfaceW, bytes); + surfaceW = rewriter.create(loc, surfaceW, one); + Value surfaceH = + rewriter.create(loc, i32Type, ptrOp.getShape()[0]); + surfaceH = rewriter.create(loc, surfaceH, one); + Value surfaceP = + rewriter.create(loc, i32Type, ptrOp.getStrides()[0]); + surfaceP = rewriter.create(loc, surfaceP, bytes); + surfaceP = rewriter.create(loc, surfaceP, one); + rewriter.restoreInsertionPoint(insertPoint); + + auto getIntType = [&](Type type, bool is16Bit = false) { + auto tType = cast(type); + auto elemType = is16Bit ? 
i16Type : i32Type; + auto ratio = + elemType.getIntOrFloatBitWidth() / tType.getElementTypeBitWidth(); + auto num = tType.getNumElements() / 16 / ratio; + return VectorType::get(num, elemType); + }; + auto tensorPtr = adaptor.getPtr(); + auto offsetX = + rewriter.create(loc, tensorPtr, idx0); + auto offsetY = + rewriter.create(loc, tensorPtr, idx1); + if constexpr (isLoad) { + auto resType = + this->getTypeConverter()->convertType(op->getResult(0).getType()); + auto idxAttr = op->template getAttrOfType("DotIdx"); + auto idx = idxAttr.getInt(); + auto intType = getIntType(op->getResult(0).getType(), idx == 0); + auto load = rewriter.create( + loc, intType, base, surfaceW, surfaceH, surfaceP, offsetX, offsetY, + dataSize, blockWidth / 2, blockHeight, 2 /*v_blocks*/, transpose, + vnni); + auto cast = rewriter.create(loc, resType, load); + rewriter.replaceOp(op, cast); + } else if constexpr (isPrefetch) { + auto load = rewriter.create( + loc, base, surfaceW, surfaceH, surfaceP, offsetX, offsetY, dataSize, + blockWidth, blockHeight, 1 /*v_blocks*/, transpose, vnni); + rewriter.eraseOp(op); + } else { + auto intType = getIntType(op.getValue().getType()); + auto cast = + rewriter.create(loc, intType, adaptor.getValue()); + rewriter.create( + loc, base, surfaceW, surfaceH, surfaceP, offsetX, offsetY, dataSize, + blockWidth, blockHeight, 1 /*v_blocks*/, transpose, vnni, cast); + rewriter.eraseOp(op); + } + return success(); + } +}; + +class DotOpConversion : public ConvertTritonGPUOpToLLVMPattern { +public: + using ConvertTritonGPUOpToLLVMPattern::ConvertTritonGPUOpToLLVMPattern; + LogicalResult + matchAndRewrite(DotOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto i16Type = rewriter.getI16Type(); + auto i32Type = rewriter.getI32Type(); + auto encodePrecision = [&](Type type) -> TritonGen::PrecisionType { + if (type == rewriter.getBF16Type()) + return TritonGen::PrecisionType::BF16; // 9; + else if (type == rewriter.getF16Type()) + return TritonGen::PrecisionType::FP16; // 10; + else if (type == rewriter.getTF32Type()) + return TritonGen::PrecisionType::TF32; // 12; + else { + assert(0 && "add more support"); + return TritonGen::PrecisionType::PRECISION_UNUSED; + } + }; + auto preca = encodePrecision(op.getA().getType().getElementType()); + auto precb = encodePrecision(op.getB().getType().getElementType()); + auto precA = + TritonGen::PrecisionTypeAttr::get(rewriter.getContext(), preca); + auto precB = + TritonGen::PrecisionTypeAttr::get(rewriter.getContext(), precb); + auto rc = IntegerAttr::get(i32Type, 8); + // sd dpasW fixed in genx.dpas lowering + auto getIntType = [&](Type type, bool is16Bit = false) { + auto tType = cast(type); + auto elemType = is16Bit ? 
i16Type : i32Type; + auto ratio = + elemType.getIntOrFloatBitWidth() / tType.getElementTypeBitWidth(); + auto num = tType.getNumElements() / 16 / ratio; + return VectorType::get(num, elemType); + }; + auto intTypeA = getIntType(op.getA().getType(), true); + auto castA = + rewriter.create(loc, intTypeA, adaptor.getA()); + auto intTypeB = getIntType(op.getB().getType()); + auto castB = + rewriter.create(loc, intTypeB, adaptor.getB()); + auto dpas = rewriter.create( + loc, adaptor.getC().getType(), adaptor.getC(), castA, castB, precA, + precB, rc); + rewriter.replaceOp(op, dpas); + return success(); + } +}; + +class GlueOpConversion : public ConvertTritonGPUOpToLLVMPattern { +public: + using ConvertTritonGPUOpToLLVMPattern< + GlueOp>::ConvertTritonGPUOpToLLVMPattern; + LogicalResult + matchAndRewrite(GlueOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto operands = adaptor.getOperands(); + auto dstType = + cast(getTypeConverter()->convertType(op.getType())); + auto numElts = dstType.getNumElements(); + SmallVector indices(numElts); + std::iota(indices.begin(), indices.end(), 0); + auto attr = rewriter.getDenseI32ArrayAttr(indices); + auto num = operands.size(); + if (num == 1) { + rewriter.replaceOp(op, operands[0]); + } else if (num == 2) { + rewriter.replaceOpWithNewOp( + op, dstType, operands[0], operands[1], attr); + } else if (num == 4) { + auto subType = VectorType::get(numElts / 2, dstType.getElementType()); + indices.pop_back_n(numElts / 2); + auto attr01 = rewriter.getDenseI32ArrayAttr(indices); + auto shfl01 = rewriter.create( + loc, subType, operands[0], operands[1], attr01); + auto attr23 = rewriter.getDenseI32ArrayAttr(indices); + auto shfl23 = rewriter.create( + loc, subType, operands[2], operands[3], attr23); + auto shfl = rewriter.create(loc, dstType, shfl01, + shfl23, attr); + rewriter.replaceOp(op, shfl); + } else { + assert(0 && "add more support for tt.glue to llvm"); + } + return success(); + } +}; + +class CastOpConversion : public ConvertTritonGPUOpToLLVMPattern { +public: + using ConvertTritonGPUOpToLLVMPattern< + CastOp>::ConvertTritonGPUOpToLLVMPattern; + LogicalResult + matchAndRewrite(CastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto dstType = getTypeConverter()->convertType(op.getType()); + auto cast = + rewriter.create(loc, dstType, adaptor.getSrc()); + rewriter.replaceOp(op, cast); + return success(); + } +}; + +class ExtractOpConversion : public ConvertTritonGPUOpToLLVMPattern { +public: + using ConvertTritonGPUOpToLLVMPattern< + ExtractOp>::ConvertTritonGPUOpToLLVMPattern; + LogicalResult + matchAndRewrite(ExtractOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto base = adaptor.getBase(); + auto idx = op.getIdx(); + auto dstType = + cast(getTypeConverter()->convertType(op.getType())); + auto numElts = dstType.getNumElements(); + SmallVector indices(numElts); + auto start = idx * numElts; + std::iota(indices.begin(), indices.end(), start); + auto attr = rewriter.getDenseI32ArrayAttr(indices); + rewriter.replaceOpWithNewOp(op, dstType, base, base, + attr); + return success(); + } +}; + +// fixme: support it in upstream constantOpLowering +class ArithConstantOpLowering + : public ConvertTritonGPUOpToLLVMPattern { + using ConvertTritonGPUOpToLLVMPattern< + mlir::arith::ConstantOp>::ConvertTritonGPUOpToLLVMPattern; + LogicalResult + matchAndRewrite(mlir::arith::ConstantOp 
op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + auto srcType = dyn_cast(op.getType()); + if (!srcType || srcType.getNumElements() == 1) + return failure(); + + // arith.constant should only have vector or tenor types. + assert((isa(srcType))); + + Type dstType = getTypeConverter()->convertType(srcType); + if (!dstType) + return failure(); + + auto dstElementsAttr = dyn_cast(op.getValue()); + if (!dstElementsAttr) + return failure(); + + ShapedType dstAttrType = dstElementsAttr.getType(); + auto vecType = cast(dstType); + dstAttrType = + VectorType::get(vecType.getNumElements(), vecType.getElementType()); + dstElementsAttr = dstElementsAttr.resizeSplat(dstAttrType); + auto newOp = + rewriter.create(loc, dstType, dstElementsAttr); + rewriter.replaceOp(op, newOp); + return success(); + } +}; + +} // namespace + +void mlir::triton::populateTritonOpsToLLVMPatterns( + TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + Target target, PatternBenefit benefit) { + patterns.add(typeConverter, target, benefit); + patterns.add(typeConverter, target, benefit); + patterns.add(typeConverter, target, benefit); + patterns.add>( + typeConverter, target, benefit); + patterns.add>(typeConverter, target, + benefit); + patterns.add>(typeConverter, target, + benefit); + patterns.add(typeConverter, target, benefit); + patterns.add(typeConverter, target, benefit); + patterns.add(typeConverter, target, benefit); + patterns.add(typeConverter, target, benefit); +} From e3cbf2809b4a01fa0bc1c713c21872bfb64c96dd Mon Sep 17 00:00:00 2001 From: Dewei Wang Date: Wed, 17 Apr 2024 03:02:10 -0700 Subject: [PATCH 02/17] fix rebase --- lib/Analysis/Utility.cpp | 2 +- .../TritonGPUToLLVM/TypeConverter.cpp | 1 + .../PatternTritonGPUOpToLLVM.h | 2 +- .../TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp | 78 +++++++++------- .../TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp | 89 ++++++++----------- 5 files changed, 84 insertions(+), 88 deletions(-) diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index fe99e71950..9fe405e266 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -911,7 +911,7 @@ MakeTensorPtrOp getMakeTensorPtrOp(Value v) { if (auto forOp = dyn_cast(argOwner)) return getMakeTensorPtrOp( forOp.getOperand(argNum + forOp.getNumControlOperands() - 1)); - if (auto funcOp = dyn_cast(argOwner)) { + if (auto funcOp = dyn_cast(argOwner)) { Block *block = arg.getOwner(); Operation *op; int tOrF; diff --git a/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp b/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp index e2278dd1cc..b5e2ee32d6 100644 --- a/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp @@ -2,6 +2,7 @@ #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "triton/Conversion/MLIRTypes.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Tools/Sys/GetEnv.hpp" using namespace mlir; using namespace mlir::triton; diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/PatternTritonGPUOpToLLVM.h b/third_party/intel/lib/TritonIntelGPUToLLVM/PatternTritonGPUOpToLLVM.h index d2b63fb63e..6c6a644ff8 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/PatternTritonGPUOpToLLVM.h +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/PatternTritonGPUOpToLLVM.h @@ -12,7 +12,7 @@ namespace intel { void populateTritonOpsToLLVMPatterns( TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns, - Target target, PatternBenefit benefit); + 
PatternBenefit benefit); void populateBarrierOpToLLVMPatterns( TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns, diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp index 5746b4eff8..91334a1ba9 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp @@ -17,6 +17,7 @@ #include "intel/include/GPUToTritonGEN/GPUToTritonGENPass.h" #include "intel/include/TritonGENToLLVM/TritonGENToLLVMPass.h" +#include "intel/include/TritonIntelGPUToLLVM/Passes.h" #include "triton/Analysis/Allocation.h" #include "triton/Analysis/AxisInfo.h" @@ -26,6 +27,7 @@ #include "triton/Dialect/TritonGEN/IR/TritonGENDialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Tools/Sys/GetEnv.hpp" #include "triton/Tools/Sys/GetPlatform.hpp" #include "PatternTritonGPUOpToLLVM.h" @@ -203,6 +205,10 @@ struct ConvertTritonGPUToLLVM ModuleOp mod = getOperation(); auto enableBlockPtr = mlir::triton::tools::getBoolEnv("INTEL_ENABLE_BLOCK_PTR"); + // fixme: set subgroupSize 16 for now + if (enableBlockPtr) + mod->setAttr("triton_gpu.threads-per-warp", + IntegerAttr::get(IntegerType::get(context, 32), 16)); mlir::LowerToLLVMOptions option(context); option.overrideIndexBitwidth(32); @@ -212,10 +218,12 @@ struct ConvertTritonGPUToLLVM int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(mod); int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); - // Allocate shared memory and set barrier - ModuleAllocation allocation(mod); - ModuleMembarAnalysis membarPass(&allocation); - membarPass.run(); + if (!enableBlockPtr) { + // Allocate shared memory and set barrier + ModuleAllocation allocation(mod); + ModuleMembarAnalysis membarPass(&allocation); + membarPass.run(); + } // Lower functions { @@ -241,40 +249,42 @@ struct ConvertTritonGPUToLLVM int benefit = 10; using namespace mlir::triton::intel; if (enableBlockPtr) { - populateTritonOpsToLLVMPatterns(typeConverter, patterns, target, benefit); - populateControlFlowOpToLLVMPattern(typeConverter, patterns, target, - benefit); + mlir::triton::intel::populateTritonOpsToLLVMPatterns(typeConverter, + patterns, benefit); + mlir::triton::intel::populateControlFlowOpToLLVMPattern( + typeConverter, patterns, benefit); } else { - populateConvertLayoutOpToLLVMPatterns(typeConverter, patterns, benefit); - populateDotOpToLLVMPatterns(typeConverter, patterns, benefit); - mlir::triton::intel::populateElementwiseOpToLLVMPatterns( - typeConverter, patterns, axisInfoAnalysis, computeCapability, - targetInfo, benefit); - populateLoadStoreOpToLLVMPatterns(typeConverter, patterns, axisInfoAnalysis, - benefit); - mlir::triton::intel::populateReduceOpToLLVMPatterns(typeConverter, patterns, + populateConvertLayoutOpToLLVMPatterns(typeConverter, patterns, benefit); + populateDotOpToLLVMPatterns(typeConverter, patterns, benefit); + mlir::triton::intel::populateElementwiseOpToLLVMPatterns( + typeConverter, patterns, axisInfoAnalysis, computeCapability, + targetInfo, benefit); + populateLoadStoreOpToLLVMPatterns(typeConverter, patterns, + axisInfoAnalysis, benefit); + mlir::triton::intel::populateReduceOpToLLVMPatterns( + typeConverter, patterns, targetInfo, benefit); + mlir::triton::intel::populateScanOpToLLVMPatterns(typeConverter, patterns, targetInfo, benefit); - 
mlir::triton::intel::populateScanOpToLLVMPatterns(typeConverter, patterns, - targetInfo, benefit); - mlir::triton::intel::populateViewOpToLLVMPatterns(typeConverter, patterns, - benefit); - - populateTensorPtrOpsToLLVMPatterns(typeConverter, patterns, benefit); - populateClusterOpsToLLVMPatterns(typeConverter, patterns, benefit); - mlir::triton::intel::populateHistogramOpToLLVMPatterns(typeConverter, - patterns, benefit); - mlir::triton::intel::populatePrintOpToLLVMPattern(typeConverter, patterns, - targetInfo, benefit); - mlir::triton::populateAssertOpToLLVMPattern(typeConverter, patterns, - targetInfo, benefit); - mlir::triton::intel::populateMemoryOpToLLVMPattern(typeConverter, patterns, - benefit); - mlir::triton::intel::populateControlFlowOpToLLVMPattern(typeConverter, + mlir::triton::intel::populateViewOpToLLVMPatterns(typeConverter, patterns, + benefit); + + populateTensorPtrOpsToLLVMPatterns(typeConverter, patterns, benefit); + populateClusterOpsToLLVMPatterns(typeConverter, patterns, benefit); + mlir::triton::intel::populateHistogramOpToLLVMPatterns(typeConverter, + patterns, benefit); + mlir::triton::intel::populatePrintOpToLLVMPattern(typeConverter, patterns, + targetInfo, benefit); + mlir::triton::populateAssertOpToLLVMPattern(typeConverter, patterns, + targetInfo, benefit); + mlir::triton::intel::populateMemoryOpToLLVMPattern(typeConverter, + patterns, benefit); + mlir::triton::intel::populateControlFlowOpToLLVMPattern( + typeConverter, patterns, benefit); + mlir::triton::intel::populateMakeRangeOpToLLVMPattern(typeConverter, patterns, benefit); - mlir::triton::intel::populateMakeRangeOpToLLVMPattern(typeConverter, - patterns, benefit); } - populateSPMDOpToLLVMPattern(typeConverter, patterns, benefit); + mlir::triton::intel::populateSPMDOpToLLVMPattern(typeConverter, patterns, + targetInfo, benefit); // TODO(thomas): this should probably be done in a separate step to not // interfere with our own lowering of arith ops. Add arith/math's patterns // to help convert scalar expression to LLVM. diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp index 43eae213b2..c984a27833 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp @@ -6,11 +6,13 @@ #include "triton/Tools/Sys/GetEnv.hpp" #include "PatternTritonGPUOpToLLVM.h" -#include "TypeConverter.h" #include "Utility.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h" using namespace mlir; using namespace mlir::triton; +using namespace mlir::triton::gpu::intel; namespace { @@ -93,8 +95,7 @@ class LoadStorePrefetchOpConversion assert(rank <= 2 && "only support 1d/2d load/store/prefetch for now"); auto loc = op.getLoc(); constexpr bool isLoad = std::is_same_v; - constexpr bool isPrefetch = - std::is_same_v; + constexpr bool isPrefetch = std::is_same_v; auto createIntConstant = [&](Type type, unsigned value) { auto attr = rewriter.getIntegerAttr(type, value); return rewriter.create(loc, type, attr); @@ -111,6 +112,9 @@ class LoadStorePrefetchOpConversion } unsigned dataSize = tType.getElementType().getIntOrFloatBitWidth(); auto blockWidth = tType.getShape()[1]; + assert(blockWidth == 16 || blockWidth == 32 && "only support 16/32 block"); + auto vBlks = blockWidth == 32 ? 
2 : 1; + blockWidth = 16; auto blockHeight = tType.getShape()[0]; auto idx0 = createIntConstant(i32Type, 0); auto idx1 = createIntConstant(i32Type, 1); @@ -161,24 +165,24 @@ class LoadStorePrefetchOpConversion auto idxAttr = op->template getAttrOfType("DotIdx"); auto idx = idxAttr.getInt(); auto intType = getIntType(op->getResult(0).getType(), idx == 0); - auto load = rewriter.create( + auto load = rewriter.create( loc, intType, base, surfaceW, surfaceH, surfaceP, offsetX, offsetY, - dataSize, blockWidth / 2, blockHeight, 2 /*v_blocks*/, transpose, - vnni); + dataSize, blockWidth, blockHeight, vBlks, transpose, vnni); auto cast = rewriter.create(loc, resType, load); rewriter.replaceOp(op, cast); } else if constexpr (isPrefetch) { - auto load = rewriter.create( + auto load = rewriter.create( loc, base, surfaceW, surfaceH, surfaceP, offsetX, offsetY, dataSize, - blockWidth, blockHeight, 1 /*v_blocks*/, transpose, vnni); + blockWidth, blockHeight, vBlks, transpose, vnni, + TritonGEN::PrefetchCacheControl::L1C_L3C); rewriter.eraseOp(op); } else { auto intType = getIntType(op.getValue().getType()); auto cast = rewriter.create(loc, intType, adaptor.getValue()); - rewriter.create( + rewriter.create( loc, base, surfaceW, surfaceH, surfaceP, offsetX, offsetY, dataSize, - blockWidth, blockHeight, 1 /*v_blocks*/, transpose, vnni, cast); + blockWidth, blockHeight, vBlks, transpose, vnni, cast); rewriter.eraseOp(op); } return success(); @@ -194,24 +198,24 @@ class DotOpConversion : public ConvertTritonGPUOpToLLVMPattern { auto loc = op.getLoc(); auto i16Type = rewriter.getI16Type(); auto i32Type = rewriter.getI32Type(); - auto encodePrecision = [&](Type type) -> TritonGen::PrecisionType { + auto encodePrecision = [&](Type type) -> TritonGEN::PrecisionType { if (type == rewriter.getBF16Type()) - return TritonGen::PrecisionType::BF16; // 9; + return TritonGEN::PrecisionType::BF16; else if (type == rewriter.getF16Type()) - return TritonGen::PrecisionType::FP16; // 10; + return TritonGEN::PrecisionType::FP16; else if (type == rewriter.getTF32Type()) - return TritonGen::PrecisionType::TF32; // 12; + return TritonGEN::PrecisionType::TF32; else { assert(0 && "add more support"); - return TritonGen::PrecisionType::PRECISION_UNUSED; + return TritonGEN::PrecisionType::UNUSED; } }; auto preca = encodePrecision(op.getA().getType().getElementType()); auto precb = encodePrecision(op.getB().getType().getElementType()); auto precA = - TritonGen::PrecisionTypeAttr::get(rewriter.getContext(), preca); + TritonGEN::PrecisionTypeAttr::get(rewriter.getContext(), preca); auto precB = - TritonGen::PrecisionTypeAttr::get(rewriter.getContext(), precb); + TritonGEN::PrecisionTypeAttr::get(rewriter.getContext(), precb); auto rc = IntegerAttr::get(i32Type, 8); // sd dpasW fixed in genx.dpas lowering auto getIntType = [&](Type type, bool is16Bit = false) { @@ -228,7 +232,7 @@ class DotOpConversion : public ConvertTritonGPUOpToLLVMPattern { auto intTypeB = getIntType(op.getB().getType()); auto castB = rewriter.create(loc, intTypeB, adaptor.getB()); - auto dpas = rewriter.create( + auto dpas = rewriter.create( loc, adaptor.getC().getType(), adaptor.getC(), castA, castB, precA, precB, rc); rewriter.replaceOp(op, dpas); @@ -270,28 +274,12 @@ class GlueOpConversion : public ConvertTritonGPUOpToLLVMPattern { shfl23, attr); rewriter.replaceOp(op, shfl); } else { - assert(0 && "add more support for tt.glue to llvm"); + assert(0 && "add more support for glue op to llvm"); } return success(); } }; -class CastOpConversion : public 
ConvertTritonGPUOpToLLVMPattern { -public: - using ConvertTritonGPUOpToLLVMPattern< - CastOp>::ConvertTritonGPUOpToLLVMPattern; - LogicalResult - matchAndRewrite(CastOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto dstType = getTypeConverter()->convertType(op.getType()); - auto cast = - rewriter.create(loc, dstType, adaptor.getSrc()); - rewriter.replaceOp(op, cast); - return success(); - } -}; - class ExtractOpConversion : public ConvertTritonGPUOpToLLVMPattern { public: using ConvertTritonGPUOpToLLVMPattern< @@ -301,7 +289,7 @@ class ExtractOpConversion : public ConvertTritonGPUOpToLLVMPattern { ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); auto base = adaptor.getBase(); - auto idx = op.getIdx(); + auto idx = op.getIndex(); auto dstType = cast(getTypeConverter()->convertType(op.getType())); auto numElts = dstType.getNumElements(); @@ -315,7 +303,7 @@ class ExtractOpConversion : public ConvertTritonGPUOpToLLVMPattern { } }; -// fixme: support it in upstream constantOpLowering +// FIXME: support it in upstream constantOpLowering class ArithConstantOpLowering : public ConvertTritonGPUOpToLLVMPattern { using ConvertTritonGPUOpToLLVMPattern< @@ -353,20 +341,17 @@ class ArithConstantOpLowering } // namespace -void mlir::triton::populateTritonOpsToLLVMPatterns( +void mlir::triton::intel::populateTritonOpsToLLVMPatterns( TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns, - Target target, PatternBenefit benefit) { - patterns.add(typeConverter, target, benefit); - patterns.add(typeConverter, target, benefit); - patterns.add(typeConverter, target, benefit); - patterns.add>( - typeConverter, target, benefit); - patterns.add>(typeConverter, target, - benefit); - patterns.add>(typeConverter, target, - benefit); - patterns.add(typeConverter, target, benefit); - patterns.add(typeConverter, target, benefit); - patterns.add(typeConverter, target, benefit); - patterns.add(typeConverter, target, benefit); + PatternBenefit benefit) { + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + patterns.add>(typeConverter, + benefit); + patterns.add>(typeConverter, benefit); + patterns.add>(typeConverter, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); } From 48b85c8b1f550019c09f974bec5a8ba84b87181f Mon Sep 17 00:00:00 2001 From: Dewei Wang Date: Wed, 17 Apr 2024 20:26:53 -0700 Subject: [PATCH 03/17] address review comments --- .../TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp | 4 +- .../TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp | 97 +++++++++++++++---- 2 files changed, 82 insertions(+), 19 deletions(-) diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp index 91334a1ba9..fe5dc30124 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp @@ -205,7 +205,7 @@ struct ConvertTritonGPUToLLVM ModuleOp mod = getOperation(); auto enableBlockPtr = mlir::triton::tools::getBoolEnv("INTEL_ENABLE_BLOCK_PTR"); - // fixme: set subgroupSize 16 for now + // FIXME: set subgroupSize 16 for now if (enableBlockPtr) mod->setAttr("triton_gpu.threads-per-warp", IntegerAttr::get(IntegerType::get(context, 32), 16)); @@ -218,6 +218,7 @@ struct ConvertTritonGPUToLLVM int numCTAs = 
triton::gpu::TritonGPUDialect::getNumCTAs(mod); int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); + // FIXME: handle shared memory with block ptr if (!enableBlockPtr) { // Allocate shared memory and set barrier ModuleAllocation allocation(mod); @@ -248,6 +249,7 @@ struct ConvertTritonGPUToLLVM mlir::triton::intel::TargetInfo targetInfo(computeCapability); int benefit = 10; using namespace mlir::triton::intel; + // ops with block ptr use different ways to convert to llvm if (enableBlockPtr) { mlir::triton::intel::populateTritonOpsToLLVMPatterns(typeConverter, patterns, benefit); diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp index c984a27833..f62e0c97d3 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp @@ -25,20 +25,20 @@ class MakeTensorPtrOpConversion LogicalResult matchAndRewrite(MakeTensorPtrOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto i32Type = rewriter.getI32Type(); - auto i64Type = rewriter.getI64Type(); - auto v2i32 = VectorType::get(2, i32Type); + Location loc = op.getLoc(); + Type i32Type = rewriter.getI32Type(); + Type i64Type = rewriter.getI64Type(); + VectorType v2i32 = VectorType::get(2, i32Type); Value payLoad = rewriter.create(loc, v2i32); auto createIntConstant = [&](Type type, unsigned value) { auto attr = rewriter.getIntegerAttr(type, value); return rewriter.create(loc, type, attr); }; - // if (rank == 2) { - auto offsetX = op.getOffsets()[1]; - auto offsetY = op.getOffsets()[0]; - auto idx0 = createIntConstant(i32Type, 0); - auto idx1 = createIntConstant(i32Type, 1); + // assert(rank == 2 && "add more support for rank != 2"); + Value offsetX = op.getOffsets()[1]; + Value offsetY = op.getOffsets()[0]; + Value idx0 = createIntConstant(i32Type, 0); + Value idx1 = createIntConstant(i32Type, 1); payLoad = rewriter.create(loc, payLoad, offsetX, idx0); payLoad = @@ -55,23 +55,23 @@ class AdvanceOpConversion : public ConvertTritonGPUOpToLLVMPattern { LogicalResult matchAndRewrite(AdvanceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto i32Type = rewriter.getI32Type(); - auto offsets = adaptor.getOffsets(); + Location loc = op.getLoc(); + Type i32Type = rewriter.getI32Type(); + SmallVector offsets = adaptor.getOffsets(); Value ptr = adaptor.getPtr(); for (size_t i = 0; i < offsets.size(); i++) { - auto offset = offsets[i]; + Value offset = offsets[i]; if (auto cst = dyn_cast(offset.getDefiningOp())) if (auto attr = dyn_cast(cst.getValue()); attr && attr.getInt() == 0) continue; - auto idx0 = rewriter.create( + Value idx0 = rewriter.create( loc, i32Type, rewriter.getIntegerAttr(i32Type, 0)); - auto idx1 = rewriter.create( + Value idx1 = rewriter.create( loc, i32Type, rewriter.getIntegerAttr(i32Type, 1)); Value idx = i == 0 ? 
idx1 : idx0; - auto oldOffset = rewriter.create(loc, ptr, idx); - auto newOffset = + Value oldOffset = rewriter.create(loc, ptr, idx); + Value newOffset = rewriter.create(loc, i32Type, oldOffset, offset); ptr = rewriter.create(loc, ptr, newOffset, idx); } @@ -80,6 +80,57 @@ class AdvanceOpConversion : public ConvertTritonGPUOpToLLVMPattern { } }; +// TritonGen 2DBlockLoadOp Desc: LSC 2d block prefetch +// Output: nothing is returned +// Arg 0: flat image base offset +// Arg 1: flat image base width +// Arg 2: flat image base height +// Arg 3: flat image base pitch +// Arg 4: offset x +// Arg 5: offset y +// Arg 6: elemSize +// Arg 7: tile width +// Arg 8: tile height +// Arg 9: V - num blocks (2 for simple 2d block read) +// Arg 10: transpose +// Arg 11: vnni transform (for transpose+transform use transpose only and +// elemSize 32) +// Arg 12: cache controls options (LSC_CACHE_OPTS) + +// TritonGen 2DBlockLoadOp Desc: LSC 2d block read +// Output: +// Arg 0: flat image base offset +// Arg 1: flat image base width +// Arg 2: flat image base height +// Arg 3: flat image base pitch +// Arg 4: offset x +// Arg 5: offset y +// Arg 6: elemSize +// Arg 7: tile width +// Arg 8: tile height +// Arg 9: V - num blocks (2 for simple 2d block read) +// Arg 10: transpose +// Arg 11: vnni transform (for transpose+transform use transpose only and +// elemSize 32) +// Arg 12: cache controls options (LSC_CACHE_OPTS) + +// TritonGen 2DBlockStoreOp Desc: LSC 2d block write +// Output: nothing is returned +// Arg 0: flat image base offset +// Arg 1: flat image base width +// Arg 2: flat image base height +// Arg 3: flat image base pitch +// Arg 4: offset x +// Arg 5: offset y +// Arg 6: elemSize +// Arg 7: tile width +// Arg 8: tile height +// Arg 9: V - num blocks (2 for simple 2d block read) +// Arg 10: transpose +// Arg 11: vnni transform (for transpose+transform use transpose only and +// elemSize 32) +// Arg 12: cache controls options (LSC_CACHE_OPTS) +// Arg 13: stored value template class LoadStorePrefetchOpConversion : public ConvertTritonGPUOpToLLVMPattern { @@ -189,6 +240,16 @@ class LoadStorePrefetchOpConversion } }; +// TritonGen DpasOp Desc: XeHP SDV: dot product accumulate systolic +// Output: dst +// Arg 0: src0(acc) +// Arg 1: src1 +// Arg 2: src2 +// Arg 3: src1's precision +// Arg 4: src2's precision +// Arg 5: systolic depth +// Arg 6: repeat count +// Arg 7: isDpasw class DotOpConversion : public ConvertTritonGPUOpToLLVMPattern { public: using ConvertTritonGPUOpToLLVMPattern::ConvertTritonGPUOpToLLVMPattern; @@ -217,7 +278,6 @@ class DotOpConversion : public ConvertTritonGPUOpToLLVMPattern { auto precB = TritonGEN::PrecisionTypeAttr::get(rewriter.getContext(), precb); auto rc = IntegerAttr::get(i32Type, 8); - // sd dpasW fixed in genx.dpas lowering auto getIntType = [&](Type type, bool is16Bit = false) { auto tType = cast(type); auto elemType = is16Bit ? 
i16Type : i32Type; @@ -232,6 +292,7 @@ class DotOpConversion : public ConvertTritonGPUOpToLLVMPattern { auto intTypeB = getIntType(op.getB().getType()); auto castB = rewriter.create(loc, intTypeB, adaptor.getB()); + // sd dpasW fixed in genx.dpas lowering auto dpas = rewriter.create( loc, adaptor.getC().getType(), adaptor.getC(), castA, castB, precA, precB, rc); From f40bbccc9650cd31a94cc6da5619a37b4c401cbf Mon Sep 17 00:00:00 2001 From: Dewei Wang Date: Fri, 19 Apr 2024 00:59:08 -0700 Subject: [PATCH 04/17] add lit test --- .../tritongpu_to_llvm_intel_block_ptr.mlir | 99 +++++++++++++ .../TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp | 133 +++++++++--------- 2 files changed, 162 insertions(+), 70 deletions(-) create mode 100644 test/Conversion/tritongpu_to_llvm_intel_block_ptr.mlir diff --git a/test/Conversion/tritongpu_to_llvm_intel_block_ptr.mlir b/test/Conversion/tritongpu_to_llvm_intel_block_ptr.mlir new file mode 100644 index 0000000000..259cfe9232 --- /dev/null +++ b/test/Conversion/tritongpu_to_llvm_intel_block_ptr.mlir @@ -0,0 +1,99 @@ +// RUN: INTEL_ENABLE_BLOCK_PTR=1 triton-opt %s -split-input-file --convert-triton-intel-gpu-to-llvm | FileCheck %s +module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 32 : i32, triton_gpu.shared = 0 : i32, "triton_gpu.threads-per-warp" = 1 : i32} { + tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: i32, %arg4: i32, %arg5: i32) attributes {noinline = false} { + %c3_i32 = arith.constant 3 : i32 + %c7_i32 = arith.constant 7 : i32 + %c63_i32 = arith.constant 63 : i32 + %c48_i32 = arith.constant 48 : i32 + %c24_i32 = arith.constant 24 : i32 + %c64_i32 = arith.constant 64 : i32 + %c16_i32 = arith.constant 16 : i32 + %c4096_i32 = arith.constant 4096 : i32 + %c8_i32 = arith.constant 8 : i32 + %c4_i32 = arith.constant 4 : i32 + %c256_i32 = arith.constant 256 : i32 + %c4096_i64 = arith.constant 4096 : i64 + %c1_i64 = arith.constant 1 : i64 + %c0_i32 = arith.constant 0 : i32 + %c32_i32 = arith.constant 32 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<8x16xf32> + %0 = gpu.subgroup_id : index + %1 = arith.index_cast %0 : index to i32 + %2 = tt.get_program_id x : i32 + %3 = arith.divsi %2, %c64_i32 : i32 + %4 = arith.muli %3, %c4_i32 : i32 + %5 = arith.subi %c16_i32, %4 : i32 + %6 = arith.minsi %5, %c4_i32 : i32 + %7 = arith.remsi %2, %6 : i32 + %8 = arith.addi %4, %7 : i32 + %9 = arith.andi %2, %c63_i32 : i32 + %10 = arith.divsi %9, %6 : i32 + %11 = arith.muli %8, %c256_i32 : i32 + %12 = arith.muli %1, %c8_i32 : i32 + %13 = arith.addi %12, %11 : i32 + // CHECK-LABEL: @matmul_kernel_with_block_pointers + // CHECK: [[undef:%.*]] = llvm.mlir.undef : vector<2xi32> + // CHECK: [[zero:%.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: [[one:%.*]] = llvm.mlir.constant(1 : i32) : i32 + // CHECK: [[insert0:%.*]] = llvm.insertelement {{.*}}, [[undef]][[[zero]] : i32] : vector<2xi32> + // CHECK: [[insert1:%.*]] = llvm.insertelement {{.*}}, [[insert0]][[[one]] : i32] : vector<2xi32> + %14 = tt.make_tensor_ptr %arg0, [%c4096_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%13, %c0_i32] {order = array} : , 1> + // CHECK: llvm.call @llvm.genx.GenISA.LSC2DBlockPrefetch.isVoid + triton_intel_gpu.prefetch %14 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr, 1> + %18 = arith.divsi %1, %c4_i32 : i32 + %19 = arith.andi %18, %c7_i32 : i32 + %20 = arith.muli %19, %c32_i32 : i32 + %21 = arith.addi %20, %11 : i32 + %22 = tt.make_tensor_ptr 
%arg0, [%c4096_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%21, %c0_i32] {order = array} : , 1> + %23 = arith.muli %10, %c256_i32 : i32 + %24 = arith.divsi %1, %c8_i32 : i32 + %25 = arith.andi %24, %c3_i32 : i32 + %26 = arith.muli %25, %c8_i32 : i32 + %27 = arith.andi %1, %c7_i32 : i32 + %28 = arith.muli %27, %c32_i32 : i32 + %29 = arith.addi %28, %23 : i32 + %34 = arith.andi %1, %c3_i32 : i32 + %35 = arith.muli %34, %c64_i32 : i32 + %36 = arith.addi %35, %23 : i32 + %37 = tt.make_tensor_ptr %arg1, [%c4096_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%c0_i32, %36] {order = array} : , 1> + %38 = arith.addi %36, %c32_i32 : i32 + %39 = tt.make_tensor_ptr %arg1, [%c4096_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%c0_i32, %38] {order = array} : , 1> + cf.br ^bb1(%c0_i32, %cst, %22, %37, %39 : i32, tensor<8x16xf32>, !tt.ptr, 1>, !tt.ptr, 1>, !tt.ptr, 1>) + ^bb1(%40: i32, %41: tensor<8x16xf32>, %57: !tt.ptr, 1>, %58: !tt.ptr, 1>, %59: !tt.ptr, 1>): // 2 preds: ^bb0, ^bb2 + %62 = arith.cmpi slt, %40, %c4096_i32 : i32 + cf.cond_br %62, ^bb2, ^bb3 + ^bb2: // pred: ^bb1 + // CHECK: [[A:%.*]] = llvm.call @llvm.genx.GenISA.LSC2DBlockRead.v64i16({{.*}} -> vector<64xi16> + // CHECK: [[castA:%.*]] = llvm.bitcast [[A]] : vector<64xi16> to vector<64xf16> + // CHECK: [[B0:%.*]] = llvm.call @llvm.genx.GenISA.LSC2DBlockRead.v32i32({{.*}} -> vector<32xi32> + // CHECK: [[castB:%.*]] = llvm.bitcast [[B0]] : vector<32xi32> to vector<64xf16> + // CHECK: [[B1:%.*]] = llvm.call @llvm.genx.GenISA.LSC2DBlockRead.v32i32({{.*}} -> vector<32xi32> + // CHECK: [[subA:%.*]] = llvm.shufflevector [[castA]], [[castA]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<64xf16> + // CHECK: [[subB:%.*]] = llvm.shufflevector [[castB]], [[castB]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<64xf16> + // CHECK: [[castDotA:%.*]] = llvm.bitcast [[subA]] : vector<8xf16> to vector<8xi16> + // CHECK: [[castDotB:%.*]] = llvm.bitcast [[subB]] : vector<16xf16> to vector<8xi32> + // CHECK: llvm.call @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f([[castDotA]], [[castDotB]], {{.*}} -> vector<8xf32> + %63 = tt.load %57 {DotIdx = 0 : i32, boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr, 1> + %64 = tt.load %58 {DotIdx = 1 : i32, boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr, 1> + %65 = tt.load %59 {DotIdx = 1 : i32, boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr, 1> + %66 = triton_intel_gpu.extract %63[0] : tensor<32x32xf16> -> tensor<8x16xf16> + %67 = triton_intel_gpu.extract %64[0] : tensor<32x32xf16> -> tensor<16x16xf16> + %68 = tt.dot %66, %67, %41 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<8x16xf16> * tensor<16x16xf16> -> tensor<8x16xf32> + %69 = triton_intel_gpu.extract %63[4] : tensor<32x32xf16> -> tensor<8x16xf16> + %70 = triton_intel_gpu.extract %64[1] : tensor<32x32xf16> -> tensor<16x16xf16> + %71 = tt.dot %69, %70, %68 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<8x16xf16> * tensor<16x16xf16> -> tensor<8x16xf32> + // CHECK: [[oldOffset:%.*]] = llvm.extractelement {{.*}} : vector<2xi32> + // CHECK: [[newOffset:%.*]] = llvm.add [[oldOffset]], {{.*}} : i32 + // CHECK: llvm.insertelement [[newOffset]], {{.*}} : vector<2xi32> + %115 = tt.advance %57, [%c0_i32, %c32_i32] : , 1> + %117 = tt.advance %58, [%c32_i32, %c0_i32] : , 1> + %118 = tt.advance %59, [%c32_i32, %c0_i32] : , 1> + %119 = arith.addi %40, %c32_i32 : i32 + cf.br ^bb1(%119, %71, %115, %117, %118 : i32, tensor<8x16xf32>, 
!tt.ptr, 1>, !tt.ptr, 1>, !tt.ptr, 1>) + ^bb3: // pred: ^bb1 + %120 = tt.make_tensor_ptr %arg2, [%c4096_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%21, %36] {order = array} : , 1> + // CHECK: llvm.call @llvm.genx.GenISA.LSC2DBlockWrite.v8i32 + tt.store %120, %41 {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32} : !tt.ptr, 1> + tt.return + } +} diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp index f62e0c97d3..bc9b6b3d65 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp @@ -16,7 +16,7 @@ using namespace mlir::triton::gpu::intel; namespace { -/// offsetX, offsetY for 2D tensor desc +/// v2i32 [offsetX, offsetY] for 2D tensor desc class MakeTensorPtrOpConversion : public ConvertTritonGPUOpToLLVMPattern { public: @@ -48,6 +48,9 @@ class MakeTensorPtrOpConversion } }; +/// %oldOffset = llvm.extract %v2i32, 0/1 +/// %newOffset = llvm.add %oldOffset, %advanceStep +/// offset = llvm.insert %v2i32, 0/1 class AdvanceOpConversion : public ConvertTritonGPUOpToLLVMPattern { public: using ConvertTritonGPUOpToLLVMPattern< @@ -80,25 +83,8 @@ class AdvanceOpConversion : public ConvertTritonGPUOpToLLVMPattern { } }; -// TritonGen 2DBlockLoadOp Desc: LSC 2d block prefetch -// Output: nothing is returned -// Arg 0: flat image base offset -// Arg 1: flat image base width -// Arg 2: flat image base height -// Arg 3: flat image base pitch -// Arg 4: offset x -// Arg 5: offset y -// Arg 6: elemSize -// Arg 7: tile width -// Arg 8: tile height -// Arg 9: V - num blocks (2 for simple 2d block read) -// Arg 10: transpose -// Arg 11: vnni transform (for transpose+transform use transpose only and -// elemSize 32) -// Arg 12: cache controls options (LSC_CACHE_OPTS) - -// TritonGen 2DBlockLoadOp Desc: LSC 2d block read -// Output: +// TritonGen 2DBlock Prefetch/LoadOp Desc: LSC 2d block prefetch/load +// Output: for prefetch, nothing is returned. for load a vector is returned // Arg 0: flat image base offset // Arg 1: flat image base width // Arg 2: flat image base height @@ -142,33 +128,32 @@ class LoadStorePrefetchOpConversion ConversionPatternRewriter &rewriter) const override { auto ptrType = cast(op.getPtr().getType()); auto tType = cast(ptrType.getPointeeType()); - auto rank = tType.getRank(); + unsigned rank = tType.getRank(); assert(rank <= 2 && "only support 1d/2d load/store/prefetch for now"); - auto loc = op.getLoc(); + Location loc = op.getLoc(); constexpr bool isLoad = std::is_same_v; constexpr bool isPrefetch = std::is_same_v; auto createIntConstant = [&](Type type, unsigned value) { auto attr = rewriter.getIntegerAttr(type, value); return rewriter.create(loc, type, attr); }; - auto i16Type = rewriter.getI16Type(); - auto i32Type = rewriter.getI32Type(); - auto i64Type = rewriter.getI64Type(); - auto v4i64 = VectorType::get(4, i64Type); - auto vnni = false; - auto transpose = false; + Type i16Type = rewriter.getI16Type(); + Type i32Type = rewriter.getI32Type(); + Type i64Type = rewriter.getI64Type(); + bool vnni = false; + bool transpose = false; if constexpr (isLoad) { auto idxAttr = op->template getAttrOfType("DotIdx"); vnni = idxAttr.getInt() == 1 ? 
true : false; } unsigned dataSize = tType.getElementType().getIntOrFloatBitWidth(); - auto blockWidth = tType.getShape()[1]; + unsigned blockWidth = tType.getShape()[1]; assert(blockWidth == 16 || blockWidth == 32 && "only support 16/32 block"); - auto vBlks = blockWidth == 32 ? 2 : 1; + unsigned vBlks = blockWidth == 32 ? 2 : 1; blockWidth = 16; - auto blockHeight = tType.getShape()[0]; - auto idx0 = createIntConstant(i32Type, 0); - auto idx1 = createIntConstant(i32Type, 1); + unsigned blockHeight = tType.getShape()[0]; + Value idx0 = createIntConstant(i32Type, 0); + Value idx1 = createIntConstant(i32Type, 1); Value ptr = op.getPtr(); if (auto cast = dyn_cast(ptr.getDefiningOp())) @@ -179,11 +164,11 @@ class LoadStorePrefetchOpConversion dyn_cast(base.getDefiningOp())) base = cast.getInputs()[0]; - auto insertPoint = rewriter.saveInsertionPoint(); + OpBuilder::InsertPoint insertPoint = rewriter.saveInsertionPoint(); rewriter.setInsertionPointAfter(ptrOp); - auto bytes = createIntConstant( + Value bytes = createIntConstant( i32Type, tType.getElementType().getIntOrFloatBitWidth() / 8); - auto one = createIntConstant(i32Type, 1); + Value one = createIntConstant(i32Type, 1); Value surfaceW = rewriter.create(loc, i32Type, ptrOp.getShape()[1]); surfaceW = rewriter.create(loc, surfaceW, bytes); @@ -205,31 +190,31 @@ class LoadStorePrefetchOpConversion auto num = tType.getNumElements() / 16 / ratio; return VectorType::get(num, elemType); }; - auto tensorPtr = adaptor.getPtr(); - auto offsetX = + Value tensorPtr = adaptor.getPtr(); + Value offsetX = rewriter.create(loc, tensorPtr, idx0); - auto offsetY = + Value offsetY = rewriter.create(loc, tensorPtr, idx1); if constexpr (isLoad) { - auto resType = + Type resType = this->getTypeConverter()->convertType(op->getResult(0).getType()); auto idxAttr = op->template getAttrOfType("DotIdx"); - auto idx = idxAttr.getInt(); - auto intType = getIntType(op->getResult(0).getType(), idx == 0); + unsigned idx = idxAttr.getInt(); + Type intType = getIntType(op->getResult(0).getType(), idx == 0); auto load = rewriter.create( loc, intType, base, surfaceW, surfaceH, surfaceP, offsetX, offsetY, dataSize, blockWidth, blockHeight, vBlks, transpose, vnni); auto cast = rewriter.create(loc, resType, load); rewriter.replaceOp(op, cast); } else if constexpr (isPrefetch) { - auto load = rewriter.create( + rewriter.create( loc, base, surfaceW, surfaceH, surfaceP, offsetX, offsetY, dataSize, blockWidth, blockHeight, vBlks, transpose, vnni, TritonGEN::PrefetchCacheControl::L1C_L3C); rewriter.eraseOp(op); } else { - auto intType = getIntType(op.getValue().getType()); - auto cast = + Type intType = getIntType(op.getValue().getType()); + Value cast = rewriter.create(loc, intType, adaptor.getValue()); rewriter.create( loc, base, surfaceW, surfaceH, surfaceP, offsetX, offsetY, dataSize, @@ -256,9 +241,9 @@ class DotOpConversion : public ConvertTritonGPUOpToLLVMPattern { LogicalResult matchAndRewrite(DotOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto i16Type = rewriter.getI16Type(); - auto i32Type = rewriter.getI32Type(); + Location loc = op.getLoc(); + Type i16Type = rewriter.getI16Type(); + Type i32Type = rewriter.getI32Type(); auto encodePrecision = [&](Type type) -> TritonGEN::PrecisionType { if (type == rewriter.getBF16Type()) return TritonGEN::PrecisionType::BF16; @@ -271,8 +256,10 @@ class DotOpConversion : public ConvertTritonGPUOpToLLVMPattern { return TritonGEN::PrecisionType::UNUSED; } }; - auto preca = 
encodePrecision(op.getA().getType().getElementType()); - auto precb = encodePrecision(op.getB().getType().getElementType()); + TritonGEN::PrecisionType preca = + encodePrecision(op.getA().getType().getElementType()); + TritonGEN::PrecisionType precb = + encodePrecision(op.getB().getType().getElementType()); auto precA = TritonGEN::PrecisionTypeAttr::get(rewriter.getContext(), preca); auto precB = @@ -280,17 +267,17 @@ class DotOpConversion : public ConvertTritonGPUOpToLLVMPattern { auto rc = IntegerAttr::get(i32Type, 8); auto getIntType = [&](Type type, bool is16Bit = false) { auto tType = cast(type); - auto elemType = is16Bit ? i16Type : i32Type; - auto ratio = + Type elemType = is16Bit ? i16Type : i32Type; + unsigned ratio = elemType.getIntOrFloatBitWidth() / tType.getElementTypeBitWidth(); - auto num = tType.getNumElements() / 16 / ratio; + unsigned num = tType.getNumElements() / 16 / ratio; return VectorType::get(num, elemType); }; - auto intTypeA = getIntType(op.getA().getType(), true); - auto castA = + Type intTypeA = getIntType(op.getA().getType(), true); + Value castA = rewriter.create(loc, intTypeA, adaptor.getA()); - auto intTypeB = getIntType(op.getB().getType()); - auto castB = + Type intTypeB = getIntType(op.getB().getType()); + Value castB = rewriter.create(loc, intTypeB, adaptor.getB()); // sd dpasW fixed in genx.dpas lowering auto dpas = rewriter.create( @@ -301,6 +288,9 @@ class DotOpConversion : public ConvertTritonGPUOpToLLVMPattern { } }; +/// %glue = ttgi.glue %a, %b : tensor<4xf16>, tensor<4xf16> : tensor<8xf16> +/// is converted to +/// %glue = llvm.shufflevector %a, %b : [0, 1, 2, 3, 4, 5, 6, 7] : vector<8xf16> class GlueOpConversion : public ConvertTritonGPUOpToLLVMPattern { public: using ConvertTritonGPUOpToLLVMPattern< @@ -308,15 +298,15 @@ class GlueOpConversion : public ConvertTritonGPUOpToLLVMPattern { LogicalResult matchAndRewrite(GlueOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto operands = adaptor.getOperands(); + Location loc = op.getLoc(); + SmallVector operands = adaptor.getOperands(); auto dstType = cast(getTypeConverter()->convertType(op.getType())); - auto numElts = dstType.getNumElements(); + unsigned numElts = dstType.getNumElements(); SmallVector indices(numElts); std::iota(indices.begin(), indices.end(), 0); - auto attr = rewriter.getDenseI32ArrayAttr(indices); - auto num = operands.size(); + DenseI32ArrayAttr attr = rewriter.getDenseI32ArrayAttr(indices); + unsigned num = operands.size(); if (num == 1) { rewriter.replaceOp(op, operands[0]); } else if (num == 2) { @@ -325,10 +315,10 @@ class GlueOpConversion : public ConvertTritonGPUOpToLLVMPattern { } else if (num == 4) { auto subType = VectorType::get(numElts / 2, dstType.getElementType()); indices.pop_back_n(numElts / 2); - auto attr01 = rewriter.getDenseI32ArrayAttr(indices); + DenseI32ArrayAttr attr01 = rewriter.getDenseI32ArrayAttr(indices); auto shfl01 = rewriter.create( loc, subType, operands[0], operands[1], attr01); - auto attr23 = rewriter.getDenseI32ArrayAttr(indices); + DenseI32ArrayAttr attr23 = rewriter.getDenseI32ArrayAttr(indices); auto shfl23 = rewriter.create( loc, subType, operands[2], operands[3], attr23); auto shfl = rewriter.create(loc, dstType, shfl01, @@ -341,6 +331,9 @@ class GlueOpConversion : public ConvertTritonGPUOpToLLVMPattern { } }; +/// %extract = ttgi.extract %a[0] : tensor<8xf16> -> tensor<4xf16> +/// is converted to +/// %extract = llvm.shufflevector %a, %a : [0, 1, 2, 3] : vector<4xf16> class 
ExtractOpConversion : public ConvertTritonGPUOpToLLVMPattern { public: using ConvertTritonGPUOpToLLVMPattern< @@ -348,16 +341,16 @@ class ExtractOpConversion : public ConvertTritonGPUOpToLLVMPattern { LogicalResult matchAndRewrite(ExtractOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto base = adaptor.getBase(); - auto idx = op.getIndex(); + Location loc = op.getLoc(); + Value base = adaptor.getBase(); + unsigned idx = op.getIndex(); auto dstType = cast(getTypeConverter()->convertType(op.getType())); - auto numElts = dstType.getNumElements(); + unsigned numElts = dstType.getNumElements(); SmallVector indices(numElts); - auto start = idx * numElts; + unsigned start = idx * numElts; std::iota(indices.begin(), indices.end(), start); - auto attr = rewriter.getDenseI32ArrayAttr(indices); + DenseI32ArrayAttr attr = rewriter.getDenseI32ArrayAttr(indices); rewriter.replaceOpWithNewOp(op, dstType, base, base, attr); return success(); From d8adced34bde5f2b1ae4f45098b3c04aaa2649c6 Mon Sep 17 00:00:00 2001 From: "Tiotto, Ettore" Date: Tue, 23 Apr 2024 16:23:18 +0000 Subject: [PATCH 05/17] Remove unnecessary code Signed-off-by: Tiotto, Ettore --- .../__grp__kernel_nospec.json | 1 - .../kernel_nospec.json | 1 - .../kernel_nospec.llir | 35 ------------------ .../kernel_nospec.spv | Bin 1016 -> 0 bytes .../kernel_nospec.ttgir | 9 ----- .../kernel_nospec.ttir | 9 ----- .../arch_utils.so | Bin 17528 -> 0 bytes 7 files changed, 55 deletions(-) delete mode 100644 python/test/unit/.tmp/a162c3a261d601b95c621d183d3d30955efcd817805e2c2e8cdb9ca50b77efcc/__grp__kernel_nospec.json delete mode 100644 python/test/unit/.tmp/a162c3a261d601b95c621d183d3d30955efcd817805e2c2e8cdb9ca50b77efcc/kernel_nospec.json delete mode 100644 python/test/unit/.tmp/a162c3a261d601b95c621d183d3d30955efcd817805e2c2e8cdb9ca50b77efcc/kernel_nospec.llir delete mode 100644 python/test/unit/.tmp/a162c3a261d601b95c621d183d3d30955efcd817805e2c2e8cdb9ca50b77efcc/kernel_nospec.spv delete mode 100644 python/test/unit/.tmp/a162c3a261d601b95c621d183d3d30955efcd817805e2c2e8cdb9ca50b77efcc/kernel_nospec.ttgir delete mode 100644 python/test/unit/.tmp/a162c3a261d601b95c621d183d3d30955efcd817805e2c2e8cdb9ca50b77efcc/kernel_nospec.ttir delete mode 100644 python/test/unit/.tmp/bdcd280543c195b10bf66ec82aadd9bc/arch_utils.so diff --git a/python/test/unit/.tmp/a162c3a261d601b95c621d183d3d30955efcd817805e2c2e8cdb9ca50b77efcc/__grp__kernel_nospec.json b/python/test/unit/.tmp/a162c3a261d601b95c621d183d3d30955efcd817805e2c2e8cdb9ca50b77efcc/__grp__kernel_nospec.json deleted file mode 100644 index c5c920573a..0000000000 --- a/python/test/unit/.tmp/a162c3a261d601b95c621d183d3d30955efcd817805e2c2e8cdb9ca50b77efcc/__grp__kernel_nospec.json +++ /dev/null @@ -1 +0,0 @@ -{"child_paths": {"kernel_nospec.ttir": ".tmp/a162c3a261d601b95c621d183d3d30955efcd817805e2c2e8cdb9ca50b77efcc/kernel_nospec.ttir", "kernel_nospec.ttgir": ".tmp/a162c3a261d601b95c621d183d3d30955efcd817805e2c2e8cdb9ca50b77efcc/kernel_nospec.ttgir", "kernel_nospec.llir": ".tmp/a162c3a261d601b95c621d183d3d30955efcd817805e2c2e8cdb9ca50b77efcc/kernel_nospec.llir", "kernel_nospec.spv": ".tmp/a162c3a261d601b95c621d183d3d30955efcd817805e2c2e8cdb9ca50b77efcc/kernel_nospec.spv", "kernel_nospec.json": ".tmp/a162c3a261d601b95c621d183d3d30955efcd817805e2c2e8cdb9ca50b77efcc/kernel_nospec.json"}} \ No newline at end of file diff --git a/python/test/unit/.tmp/a162c3a261d601b95c621d183d3d30955efcd817805e2c2e8cdb9ca50b77efcc/kernel_nospec.json 
b/python/test/unit/.tmp/a162c3a261d601b95c621d183d3d30955efcd817805e2c2e8cdb9ca50b77efcc/kernel_nospec.json deleted file mode 100644 index 75052d40f0..0000000000 --- a/python/test/unit/.tmp/a162c3a261d601b95c621d183d3d30955efcd817805e2c2e8cdb9ca50b77efcc/kernel_nospec.json +++ /dev/null @@ -1 +0,0 @@ -{"hash": "a162c3a261d601b95c621d183d3d30955efcd817805e2c2e8cdb9ca50b77efcc", "target": ["xpu", {"dev_type": "gpu", "device_arch": 23, "driver_version": "1.3.28202", "gpu_eu_count": 448, "gpu_subslice_count": 56, "has_fp64": true, "max_compute_units": 448, "max_num_sub_groups": 64, "max_work_group_size": 1024, "name": "Intel(R) Data Center GPU Max 1100", "platform_name": "Intel(R) Level-Zero", "sub_group_sizes": [16, 32], "total_memory": 51539607552, "vendor": "Intel(R) Corporation", "version": "1.3"}, 32], "num_warps": 4, "num_ctas": 1, "num_stages": 2, "cluster_dims": [1, 1, 1], "threads_per_warp": 32, "optimize_epilogue": false, "enable_fp_fusion": true, "default_dot_input_precision": "tf32", "allowed_dot_input_precisions": ["tf32", "tf32x3", "ieee"], "allow_fp8e4nv": true, "allow_fp8e4b15": false, "max_num_imprecise_acc_default": 0, "extern_libs": [["libdevice", "/home/jovyan/intel-xpu-backend-for-triton/python/triton/backends/intel/lib/libsycl-spir64-unknown-unknown.bc"]], "debug": null, "AMDGCN_ENABLE_DUMP": false, "DISABLE_FAST_REDUCTION": false, "DISABLE_LLVM_OPT": false, "DISABLE_MMA_V3": false, "DISABLE_PTXAS_OPT": false, "LLVM_IR_ENABLE_DUMP": false, "MLIR_ENABLE_DIAGNOSTICS": false, "MLIR_ENABLE_DUMP": false, "TRITON_DISABLE_LINE_INFO": true, "TRITON_DISABLE_RESHAPE_ENCODING_INFERENCE": false, "TRITON_ENABLE_LLVM_DEBUG": false, "TRITON_INTEL_EMULATE_FP16_ATOMICS": false, "TRITON_INTEL_ENABLE_BLOCK_PTR": false, "USE_TTGIR_LOC": false, "shared": 0, "name": "kernel_nospec"} \ No newline at end of file diff --git a/python/test/unit/.tmp/a162c3a261d601b95c621d183d3d30955efcd817805e2c2e8cdb9ca50b77efcc/kernel_nospec.llir b/python/test/unit/.tmp/a162c3a261d601b95c621d183d3d30955efcd817805e2c2e8cdb9ca50b77efcc/kernel_nospec.llir deleted file mode 100644 index 0bf10ef3a5..0000000000 --- a/python/test/unit/.tmp/a162c3a261d601b95c621d183d3d30955efcd817805e2c2e8cdb9ca50b77efcc/kernel_nospec.llir +++ /dev/null @@ -1,35 +0,0 @@ -; ModuleID = 'LLVMDialectModule' -source_filename = "LLVMDialectModule" -target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64" -target triple = "spir64-unknown-unknown" - -declare spir_func i64 @_Z12get_local_idj(i32) local_unnamed_addr - -define spir_kernel void @kernel_nospec(ptr addrspace(1) nocapture writeonly %0, i32 %1) local_unnamed_addr !max_work_group_size !5 !intel_reqd_sub_group_size !6 { - %3 = tail call i64 @_Z12get_local_idj(i32 0) - %4 = and i64 %3, 4294967295 - %5 = icmp eq i64 %4, 0 - br i1 %5, label %6, label %9 - -6: ; preds = %2 - %7 = add i32 %1, 3 - %8 = insertelement <1 x i32> poison, i32 %7, i64 0 - store <1 x i32> %8, ptr addrspace(1) %0, align 4 - br label %9 - -9: ; preds = %6, %2 - ret void -} - -!opencl.spir.version = !{!0, !0, !0, !0, !0, !0, !0, !0, !0, !0, !0, !0, !0, !0, !0, !0, !0, !0, !0, !0} -!spirv.Source = !{!1, !1, !1, !1, !1, !1, !1, !1, !1, !1, !1, !1, !1, !1, !1, !1, !1, !1, !1, !1} -!llvm.ident = !{!2, !2, !2, !2, !2, !2, !2, !2, !2, !2, !2, !2, !2, !2, !2, !2, !2, !2, !2, !2} -!llvm.module.flags = !{!3, !4} - -!0 = !{i32 1, i32 2} -!1 = !{i32 4, i32 100000} -!2 = !{!"Intel(R) oneAPI DPC++/C++ Compiler 2024.1.0 (2024.1.0.20240216)"} -!3 = !{i32 1, !"wchar_size", i32 
4}
-!4 = !{i32 7, !"frame-pointer", i32 2}
-!5 = !{i64 128, i64 1, i64 1}
-!6 = !{i64 32}
diff --git a/python/test/unit/.tmp/a162c3a261d601b95c621d183d3d30955efcd817805e2c2e8cdb9ca50b77efcc/kernel_nospec.spv b/python/test/unit/.tmp/a162c3a261d601b95c621d183d3d30955efcd817805e2c2e8cdb9ca50b77efcc/kernel_nospec.spv
deleted file mode 100644
index c7c611b29ea6108cbee4fb1b86fb1850ed31355c..0000000000000000000000000000000000000000
Binary files a/python/test/unit/.tmp/a162c3a261d601b95c621d183d3d30955efcd817805e2c2e8cdb9ca50b77efcc/kernel_nospec.spv and /dev/null differ
diff --git a/python/test/unit/.tmp/a162c3a261d601b95c621d183d3d30955efcd817805e2c2e8cdb9ca50b77efcc/kernel_nospec.ttgir b/python/test/unit/.tmp/a162c3a261d601b95c621d183d3d30955efcd817805e2c2e8cdb9ca50b77efcc/kernel_nospec.ttgir
deleted file mode 100644
--- a/python/test/unit/.tmp/a162c3a261d601b95c621d183d3d30955efcd817805e2c2e8cdb9ca50b77efcc/kernel_nospec.ttgir
+++ /dev/null
@@ -1,9 +0,0 @@
{tt.divisibility = 16 : i32} loc(unknown), %arg1: i32 loc(unknown)) attributes {noinline = false} {
-    %c3_i32 = arith.constant 3 : i32 loc(#loc)
-    %0 = arith.addi %arg1, %c3_i32 : i32 loc(#loc)
-    tt.store %arg0, %0 : !tt.ptr loc(#loc)
-    tt.return loc(#loc)
-  } loc(#loc)
-} loc(#loc)
diff --git a/python/test/unit/.tmp/a162c3a261d601b95c621d183d3d30955efcd817805e2c2e8cdb9ca50b77efcc/kernel_nospec.ttir b/python/test/unit/.tmp/a162c3a261d601b95c621d183d3d30955efcd817805e2c2e8cdb9ca50b77efcc/kernel_nospec.ttir
deleted file mode 100644
index e735d19634..0000000000
--- a/python/test/unit/.tmp/a162c3a261d601b95c621d183d3d30955efcd817805e2c2e8cdb9ca50b77efcc/kernel_nospec.ttir
+++ /dev/null
@@ -1,9 +0,0 @@
-#loc = loc(unknown)
-module {
-  tt.func public @kernel_nospec(%arg0: !tt.ptr {tt.divisibility = 16 : i32} loc(unknown), %arg1: i32 loc(unknown)) attributes {noinline = false} {
-    %c3_i32 = arith.constant 3 : i32 loc(#loc)
-    %0 = arith.addi %arg1, %c3_i32 : i32 loc(#loc)
-    tt.store %arg0, %0 : !tt.ptr loc(#loc)
-    tt.return loc(#loc)
-  } loc(#loc)
-} loc(#loc)
diff --git a/python/test/unit/.tmp/bdcd280543c195b10bf66ec82aadd9bc/arch_utils.so b/python/test/unit/.tmp/bdcd280543c195b10bf66ec82aadd9bc/arch_utils.so
deleted file mode 100644
index d3880dece6ffd08f1dccf2b0dd722b16aa0e2e45..0000000000000000000000000000000000000000
Binary files a/python/test/unit/.tmp/bdcd280543c195b10bf66ec82aadd9bc/arch_utils.so and /dev/null differ
From 344d0c5194e8569ed7053a65c1bd212e7dedfd5b Mon Sep 17 00:00:00 2001
From: "Tiotto, Ettore"
Date: Tue, 23 Apr 2024 20:19:03 +0000
Subject: [PATCH 06/17] Remove unnecessary SSA variable in lit test

Signed-off-by: Tiotto, Ettore
---
 test/Conversion/tritongpu_to_llvm_intel_block_ptr.mlir | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/test/Conversion/tritongpu_to_llvm_intel_block_ptr.mlir b/test/Conversion/tritongpu_to_llvm_intel_block_ptr.mlir
index 6ca8a9cd91..18b15f6af6 100644
--- a/test/Conversion/tritongpu_to_llvm_intel_block_ptr.mlir
+++ b/test/Conversion/tritongpu_to_llvm_intel_block_ptr.mlir
@@ -33,10 +33,10 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c
     %11 = arith.muli %8, %c256_i32 : i32
     %12 = arith.muli %1, %c8_i32 : i32
     %13 = arith.addi %12, %11 : i32
-    // CHECK: [[undef:%.*]] = llvm.mlir.undef : vector<2xi32>
-    // CHECK-DAG: [[zero:%.*]] = llvm.mlir.constant(0 : i32) : i32
-    // CHECK-DAG: [[one:%.*]] = llvm.mlir.constant(1 : i32) : i32
-    // CHECK: [[insert0:%.*]] = llvm.insertelement {{.*}}, [[undef]][[[zero]] : i32] : vector<2xi32>
+    // CHECK: [[undef:%.*]] = llvm.mlir.undef : vector<2xi32>
+    // CHECK-DAG: [[zero:%.*]] = llvm.mlir.constant(0 : i32) : i32
+    // CHECK-DAG: [[one:%.*]] = llvm.mlir.constant(1 : i32) : i32
+    // CHECK-NEXT: [[insert0:%.*]] = llvm.insertelement {{.*}}, [[undef]][[[zero]] : i32] : vector<2xi32>
     // CHECK-NEXT: [[insert1:%.*]] = llvm.insertelement {{.*}}, [[insert0]][[[one]] : i32] : vector<2xi32>
     %14 = tt.make_tensor_ptr %arg0, [%c4096_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%13, %c0_i32] {order = array} : , 1>
     // CHECK: llvm.call @llvm.genx.GenISA.LSC2DBlockPrefetch.isVoid

From a70b13ff08cdeb1435cb7db5feb81dc942fe3ce1 Mon Sep 17 00:00:00 2001
From: "Tiotto, Ettore"
Date: Tue, 23 Apr 2024 20:23:00 +0000
Subject: [PATCH 07/17] 
Simplify code in new transformations Signed-off-by: Tiotto, Ettore --- .../TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp | 246 +++++++++--------- 1 file changed, 119 insertions(+), 127 deletions(-) diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp index bc9b6b3d65..91ffc97951 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp @@ -1,14 +1,7 @@ -#include "mlir/Dialect/GPU/IR/GPUDialect.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "PatternTritonGPUOpToLLVM.h" #include "triton/Analysis/Utility.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGEN/IR/TritonGENDialect.h" -#include "triton/Tools/Sys/GetEnv.hpp" - -#include "PatternTritonGPUOpToLLVM.h" -#include "Utility.h" -#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" -#include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h" using namespace mlir; using namespace mlir::triton; @@ -16,6 +9,20 @@ using namespace mlir::triton::gpu::intel; namespace { +LLVM::ConstantOp createIntConstant(IntegerType type, unsigned value, + ConversionPatternRewriter &rewriter, + Location loc) { + auto attr = rewriter.getIntegerAttr(type, value); + return rewriter.create(loc, type, attr); +} + +VectorType getVectorType(RankedTensorType tensorType, Type elemType) { + unsigned ratio = + elemType.getIntOrFloatBitWidth() / tensorType.getElementTypeBitWidth(); + unsigned num = tensorType.getNumElements() / 16 / ratio; + return VectorType::get(num, elemType); +}; + /// v2i32 [offsetX, offsetY] for 2D tensor desc class MakeTensorPtrOpConversion : public ConvertTritonGPUOpToLLVMPattern { @@ -26,19 +33,14 @@ class MakeTensorPtrOpConversion matchAndRewrite(MakeTensorPtrOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); - Type i32Type = rewriter.getI32Type(); - Type i64Type = rewriter.getI64Type(); + + IntegerType i32Type = rewriter.getI32Type(); VectorType v2i32 = VectorType::get(2, i32Type); - Value payLoad = rewriter.create(loc, v2i32); - auto createIntConstant = [&](Type type, unsigned value) { - auto attr = rewriter.getIntegerAttr(type, value); - return rewriter.create(loc, type, attr); - }; - // assert(rank == 2 && "add more support for rank != 2"); Value offsetX = op.getOffsets()[1]; Value offsetY = op.getOffsets()[0]; - Value idx0 = createIntConstant(i32Type, 0); - Value idx1 = createIntConstant(i32Type, 1); + Value payLoad = rewriter.create(loc, v2i32); + Value idx0 = createIntConstant(i32Type, 0, rewriter, loc); + Value idx1 = createIntConstant(i32Type, 1, rewriter, loc); payLoad = rewriter.create(loc, payLoad, offsetX, idx0); payLoad = @@ -48,8 +50,8 @@ class MakeTensorPtrOpConversion } }; -/// %oldOffset = llvm.extract %v2i32, 0/1 -/// %newOffset = llvm.add %oldOffset, %advanceStep +/// %oldOffset = llvm.extract %v2i32, 0/1 +/// %newOffset = llvm.add %oldOffset, %advanceStep /// offset = llvm.insert %v2i32, 0/1 class AdvanceOpConversion : public ConvertTritonGPUOpToLLVMPattern { public: @@ -59,7 +61,6 @@ class AdvanceOpConversion : public ConvertTritonGPUOpToLLVMPattern { matchAndRewrite(AdvanceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); - Type i32Type = rewriter.getI32Type(); SmallVector offsets = adaptor.getOffsets(); Value ptr = adaptor.getPtr(); for (size_t i = 0; i < offsets.size(); i++) { @@ 
-68,11 +69,13 @@ class AdvanceOpConversion : public ConvertTritonGPUOpToLLVMPattern { if (auto attr = dyn_cast(cst.getValue()); attr && attr.getInt() == 0) continue; - Value idx0 = rewriter.create( - loc, i32Type, rewriter.getIntegerAttr(i32Type, 0)); - Value idx1 = rewriter.create( - loc, i32Type, rewriter.getIntegerAttr(i32Type, 1)); - Value idx = i == 0 ? idx1 : idx0; + + IntegerType i32Type = rewriter.getI32Type(); + Value idx = (i == 0) + ? rewriter.create( + loc, i32Type, rewriter.getIntegerAttr(i32Type, 1)) + : rewriter.create( + loc, i32Type, rewriter.getIntegerAttr(i32Type, 0)); Value oldOffset = rewriter.create(loc, ptr, idx); Value newOffset = rewriter.create(loc, i32Type, oldOffset, offset); @@ -83,40 +86,40 @@ class AdvanceOpConversion : public ConvertTritonGPUOpToLLVMPattern { } }; -// TritonGen 2DBlock Prefetch/LoadOp Desc: LSC 2d block prefetch/load -// Output: for prefetch, nothing is returned. for load a vector is returned -// Arg 0: flat image base offset -// Arg 1: flat image base width -// Arg 2: flat image base height -// Arg 3: flat image base pitch -// Arg 4: offset x -// Arg 5: offset y -// Arg 6: elemSize -// Arg 7: tile width -// Arg 8: tile height -// Arg 9: V - num blocks (2 for simple 2d block read) -// Arg 10: transpose -// Arg 11: vnni transform (for transpose+transform use transpose only and -// elemSize 32) -// Arg 12: cache controls options (LSC_CACHE_OPTS) +/// TritonGen 2DBlock Prefetch/LoadOp Desc: LSC 2d block prefetch/load +/// Output: for prefetch, nothing is returned. for load a vector is returned +/// Arg 0: flat image base offset +/// Arg 1: flat image base width +/// Arg 2: flat image base height +/// Arg 3: flat image base pitch +/// Arg 4: offset x +/// Arg 5: offset y +/// Arg 6: elemSize +/// Arg 7: tile width +/// Arg 8: tile height +/// Arg 9: V - num blocks (2 for simple 2d block read) +/// Arg 10: transpose +/// Arg 11: vnni transform (for transpose+transform use transpose only and +/// elemSize 32) +/// Arg 12: cache controls options (LSC_CACHE_OPTS) -// TritonGen 2DBlockStoreOp Desc: LSC 2d block write -// Output: nothing is returned -// Arg 0: flat image base offset -// Arg 1: flat image base width -// Arg 2: flat image base height -// Arg 3: flat image base pitch -// Arg 4: offset x -// Arg 5: offset y -// Arg 6: elemSize -// Arg 7: tile width -// Arg 8: tile height -// Arg 9: V - num blocks (2 for simple 2d block read) -// Arg 10: transpose -// Arg 11: vnni transform (for transpose+transform use transpose only and -// elemSize 32) -// Arg 12: cache controls options (LSC_CACHE_OPTS) -// Arg 13: stored value +/// TritonGen 2DBlockStoreOp Desc: LSC 2d block write +/// Output: nothing is returned +/// Arg 0: flat image base offset +/// Arg 1: flat image base width +/// Arg 2: flat image base height +/// Arg 3: flat image base pitch +/// Arg 4: offset x +/// Arg 5: offset y +/// Arg 6: elemSize +/// Arg 7: tile width +/// Arg 8: tile height +/// Arg 9: V - num blocks (2 for simple 2d block read) +/// Arg 10: transpose +/// Arg 11: vnni transform (for transpose+transform use transpose only and +/// elemSize 32) +/// Arg 12: cache controls options (LSC_CACHE_OPTS) +/// Arg 13: stored value template class LoadStorePrefetchOpConversion : public ConvertTritonGPUOpToLLVMPattern { @@ -127,37 +130,36 @@ class LoadStorePrefetchOpConversion matchAndRewrite(OpType op, typename OpType::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto ptrType = cast(op.getPtr().getType()); - auto tType = cast(ptrType.getPointeeType()); - 
unsigned rank = tType.getRank(); - assert(rank <= 2 && "only support 1d/2d load/store/prefetch for now"); + auto tensorType = cast(ptrType.getPointeeType()); + assert(tensorType.getRank() <= 2 && + "only support 1d/2d load/store/prefetch for now"); + Location loc = op.getLoc(); constexpr bool isLoad = std::is_same_v; constexpr bool isPrefetch = std::is_same_v; - auto createIntConstant = [&](Type type, unsigned value) { - auto attr = rewriter.getIntegerAttr(type, value); - return rewriter.create(loc, type, attr); - }; - Type i16Type = rewriter.getI16Type(); - Type i32Type = rewriter.getI32Type(); - Type i64Type = rewriter.getI64Type(); - bool vnni = false; - bool transpose = false; + + IntegerType i16Type = rewriter.getI16Type(); + IntegerType i32Type = rewriter.getI32Type(); + IntegerType i64Type = rewriter.getI64Type(); + bool vnni = false, transpose = false; if constexpr (isLoad) { auto idxAttr = op->template getAttrOfType("DotIdx"); vnni = idxAttr.getInt() == 1 ? true : false; } - unsigned dataSize = tType.getElementType().getIntOrFloatBitWidth(); - unsigned blockWidth = tType.getShape()[1]; + + unsigned dataSize = tensorType.getElementType().getIntOrFloatBitWidth(); + unsigned blockWidth = tensorType.getShape()[1]; assert(blockWidth == 16 || blockWidth == 32 && "only support 16/32 block"); unsigned vBlks = blockWidth == 32 ? 2 : 1; blockWidth = 16; - unsigned blockHeight = tType.getShape()[0]; - Value idx0 = createIntConstant(i32Type, 0); - Value idx1 = createIntConstant(i32Type, 1); + unsigned blockHeight = tensorType.getShape()[0]; + Value idx0 = createIntConstant(i32Type, 0, rewriter, loc); + Value idx1 = createIntConstant(i32Type, 1, rewriter, loc); Value ptr = op.getPtr(); if (auto cast = dyn_cast(ptr.getDefiningOp())) ptr = cast.getInputs()[0]; + MakeTensorPtrOp ptrOp = getMakeTensorPtrOp(ptr); Value base = ptrOp.getBase(); if (auto cast = @@ -167,8 +169,9 @@ class LoadStorePrefetchOpConversion OpBuilder::InsertPoint insertPoint = rewriter.saveInsertionPoint(); rewriter.setInsertionPointAfter(ptrOp); Value bytes = createIntConstant( - i32Type, tType.getElementType().getIntOrFloatBitWidth() / 8); - Value one = createIntConstant(i32Type, 1); + i32Type, tensorType.getElementType().getIntOrFloatBitWidth() / 8, + rewriter, loc); + Value one = createIntConstant(i32Type, 1, rewriter, loc); Value surfaceW = rewriter.create(loc, i32Type, ptrOp.getShape()[1]); surfaceW = rewriter.create(loc, surfaceW, bytes); @@ -182,27 +185,22 @@ class LoadStorePrefetchOpConversion surfaceP = rewriter.create(loc, surfaceP, one); rewriter.restoreInsertionPoint(insertPoint); - auto getIntType = [&](Type type, bool is16Bit = false) { - auto tType = cast(type); - auto elemType = is16Bit ? i16Type : i32Type; - auto ratio = - elemType.getIntOrFloatBitWidth() / tType.getElementTypeBitWidth(); - auto num = tType.getNumElements() / 16 / ratio; - return VectorType::get(num, elemType); - }; Value tensorPtr = adaptor.getPtr(); Value offsetX = rewriter.create(loc, tensorPtr, idx0); Value offsetY = rewriter.create(loc, tensorPtr, idx1); + if constexpr (isLoad) { Type resType = this->getTypeConverter()->convertType(op->getResult(0).getType()); auto idxAttr = op->template getAttrOfType("DotIdx"); unsigned idx = idxAttr.getInt(); - Type intType = getIntType(op->getResult(0).getType(), idx == 0); + Type vectorType = + getVectorType(cast(op->getResult(0).getType()), + idx == 0 ? 
i16Type : i32Type); auto load = rewriter.create( - loc, intType, base, surfaceW, surfaceH, surfaceP, offsetX, offsetY, + loc, vectorType, base, surfaceW, surfaceH, surfaceP, offsetX, offsetY, dataSize, blockWidth, blockHeight, vBlks, transpose, vnni); auto cast = rewriter.create(loc, resType, load); rewriter.replaceOp(op, cast); @@ -213,9 +211,10 @@ class LoadStorePrefetchOpConversion TritonGEN::PrefetchCacheControl::L1C_L3C); rewriter.eraseOp(op); } else { - Type intType = getIntType(op.getValue().getType()); + VectorType vectorType = getVectorType( + cast(op.getValue().getType()), i32Type); Value cast = - rewriter.create(loc, intType, adaptor.getValue()); + rewriter.create(loc, vectorType, adaptor.getValue()); rewriter.create( loc, base, surfaceW, surfaceH, surfaceP, offsetX, offsetY, dataSize, blockWidth, blockHeight, vBlks, transpose, vnni, cast); @@ -241,9 +240,6 @@ class DotOpConversion : public ConvertTritonGPUOpToLLVMPattern { LogicalResult matchAndRewrite(DotOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - Type i16Type = rewriter.getI16Type(); - Type i32Type = rewriter.getI32Type(); auto encodePrecision = [&](Type type) -> TritonGEN::PrecisionType { if (type == rewriter.getBF16Type()) return TritonGEN::PrecisionType::BF16; @@ -251,35 +247,30 @@ class DotOpConversion : public ConvertTritonGPUOpToLLVMPattern { return TritonGEN::PrecisionType::FP16; else if (type == rewriter.getTF32Type()) return TritonGEN::PrecisionType::TF32; - else { - assert(0 && "add more support"); - return TritonGEN::PrecisionType::UNUSED; - } + assert(false && "add more support"); + return TritonGEN::PrecisionType::UNUSED; }; - TritonGEN::PrecisionType preca = + + TritonGEN::PrecisionType precATy = encodePrecision(op.getA().getType().getElementType()); - TritonGEN::PrecisionType precb = + TritonGEN::PrecisionType precBTy = encodePrecision(op.getB().getType().getElementType()); auto precA = - TritonGEN::PrecisionTypeAttr::get(rewriter.getContext(), preca); + TritonGEN::PrecisionTypeAttr::get(rewriter.getContext(), precATy); auto precB = - TritonGEN::PrecisionTypeAttr::get(rewriter.getContext(), precb); + TritonGEN::PrecisionTypeAttr::get(rewriter.getContext(), precBTy); + + Location loc = op.getLoc(); + IntegerType i16Type = rewriter.getI16Type(); + IntegerType i32Type = rewriter.getI32Type(); + VectorType typeA = + getVectorType(cast(op.getA().getType()), i16Type); + Value castA = rewriter.create(loc, typeA, adaptor.getA()); + VectorType typeB = + getVectorType(cast(op.getB().getType()), i32Type); + Value castB = rewriter.create(loc, typeB, adaptor.getB()); auto rc = IntegerAttr::get(i32Type, 8); - auto getIntType = [&](Type type, bool is16Bit = false) { - auto tType = cast(type); - Type elemType = is16Bit ? i16Type : i32Type; - unsigned ratio = - elemType.getIntOrFloatBitWidth() / tType.getElementTypeBitWidth(); - unsigned num = tType.getNumElements() / 16 / ratio; - return VectorType::get(num, elemType); - }; - Type intTypeA = getIntType(op.getA().getType(), true); - Value castA = - rewriter.create(loc, intTypeA, adaptor.getA()); - Type intTypeB = getIntType(op.getB().getType()); - Value castB = - rewriter.create(loc, intTypeB, adaptor.getB()); - // sd dpasW fixed in genx.dpas lowering + // sd dpasW fixed in genx.dpas lowering. 
auto dpas = rewriter.create( loc, adaptor.getC().getType(), adaptor.getC(), castA, castB, precA, precB, rc); @@ -306,13 +297,16 @@ class GlueOpConversion : public ConvertTritonGPUOpToLLVMPattern { SmallVector indices(numElts); std::iota(indices.begin(), indices.end(), 0); DenseI32ArrayAttr attr = rewriter.getDenseI32ArrayAttr(indices); - unsigned num = operands.size(); - if (num == 1) { + + switch (operands.size()) { + case 1: rewriter.replaceOp(op, operands[0]); - } else if (num == 2) { + break; + case 2: rewriter.replaceOpWithNewOp( op, dstType, operands[0], operands[1], attr); - } else if (num == 4) { + break; + case 4: { auto subType = VectorType::get(numElts / 2, dstType.getElementType()); indices.pop_back_n(numElts / 2); DenseI32ArrayAttr attr01 = rewriter.getDenseI32ArrayAttr(indices); @@ -324,8 +318,9 @@ class GlueOpConversion : public ConvertTritonGPUOpToLLVMPattern { auto shfl = rewriter.create(loc, dstType, shfl01, shfl23, attr); rewriter.replaceOp(op, shfl); - } else { - assert(0 && "add more support for glue op to llvm"); + } break; + default: + llvm_unreachable("add more support for glue op to llvm"); } return success(); } @@ -341,14 +336,12 @@ class ExtractOpConversion : public ConvertTritonGPUOpToLLVMPattern { LogicalResult matchAndRewrite(ExtractOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - Location loc = op.getLoc(); Value base = adaptor.getBase(); - unsigned idx = op.getIndex(); auto dstType = cast(getTypeConverter()->convertType(op.getType())); unsigned numElts = dstType.getNumElements(); SmallVector indices(numElts); - unsigned start = idx * numElts; + unsigned start = op.getIndex() * numElts; std::iota(indices.begin(), indices.end(), start); DenseI32ArrayAttr attr = rewriter.getDenseI32ArrayAttr(indices); rewriter.replaceOpWithNewOp(op, dstType, base, base, @@ -381,9 +374,8 @@ class ArithConstantOpLowering if (!dstElementsAttr) return failure(); - ShapedType dstAttrType = dstElementsAttr.getType(); auto vecType = cast(dstType); - dstAttrType = + ShapedType dstAttrType = VectorType::get(vecType.getNumElements(), vecType.getElementType()); dstElementsAttr = dstElementsAttr.resizeSplat(dstAttrType); auto newOp = From 0d28d247fd158803b776174d17f22722aab7e64e Mon Sep 17 00:00:00 2001 From: "Tiotto, Ettore" Date: Tue, 23 Apr 2024 20:28:03 +0000 Subject: [PATCH 08/17] Create TritonIntelGPUToLLVMTypeConverter Signed-off-by: Tiotto, Ettore --- .../TritonGPUToLLVM/TypeConverter.cpp | 30 +++----------- .../TritonIntelGPUToLLVM/TypeConverter.h | 15 +++++++ .../lib/TritonIntelGPUToLLVM/CMakeLists.txt | 25 +++++------ .../TritonIntelGPUToLLVM/ClusterOpsToLLVM.cpp | 4 +- .../lib/TritonIntelGPUToLLVM/DotOpToLLVM.cpp | 8 ++-- .../TritonIntelGPUToLLVM/DotOpToLLVM/DPAS.cpp | 6 +-- .../TritonIntelGPUToLLVM/DotOpToLLVM/FMA.cpp | 4 +- .../LoadStoreOpToLLVM.cpp | 11 ++--- .../PatternTritonGPUOpToLLVM.h | 41 ++++++++++--------- .../TritonIntelGPUToLLVM/PipelineManager.h | 12 +++--- .../TritonIntelGPUToLLVM/PrintOpToLLVM.cpp | 5 ++- .../TensorPtrOpsToLLVM.cpp | 4 +- .../TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp | 5 +-- .../TritonGPUToLLVMBase.h | 7 ++-- .../TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp | 5 ++- .../TritonIntelGPUToLLVM/TypeConverter.cpp | 30 ++++++++++++++ 16 files changed, 122 insertions(+), 90 deletions(-) create mode 100644 third_party/intel/include/TritonIntelGPUToLLVM/TypeConverter.h create mode 100644 third_party/intel/lib/TritonIntelGPUToLLVM/TypeConverter.cpp diff --git a/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp 
b/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp index 0d3c220b39..616179084e 100644 --- a/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp @@ -2,7 +2,6 @@ #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "triton/Conversion/MLIRTypes.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" -#include "triton/Tools/Sys/GetEnv.hpp" using namespace mlir; using namespace mlir::triton; @@ -18,29 +17,12 @@ TritonGPUToLLVMTypeConverter::TritonGPUToLLVMTypeConverter( MLIRContext *ctx, LowerToLLVMOptions &option, const DataLayoutAnalysis *analysis) : LLVMTypeConverter(ctx, option, analysis) { - if (mlir::triton::tools::getBoolEnv("TRITON_INTEL_ENABLE_BLOCK_PTR")) { - // tt::pointer to v2i32 - addConversion([&](PointerType type) -> std::optional { - if (isa(type.getPointeeType())) { - auto i32Type = mlir::IntegerType::get(type.getContext(), 32); - return mlir::VectorType::get(2, i32Type); - } - return LLVM::LLVMPointerType::get(type.getContext(), - type.getAddressSpace()); - }); - // tensor type is flattened and divided by 16(subgroupSize) - addConversion([&](mlir::RankedTensorType type) -> mlir::Type { - return mlir::VectorType::get(type.getNumElements() / 16, - type.getElementType()); - }); - } else { - addConversion([&](triton::PointerType type) -> std::optional { - return convertTritonPointerType(type); - }); - addConversion([&](RankedTensorType type) -> std::optional { - return convertTritonTensorType(type); - }); - } + addConversion([&](triton::PointerType type) -> std::optional { + return convertTritonPointerType(type); + }); + addConversion([&](RankedTensorType type) -> std::optional { + return convertTritonTensorType(type); + }); addConversion([&](MemDescType type) -> std::optional { return convertMemDescType(type); }); diff --git a/third_party/intel/include/TritonIntelGPUToLLVM/TypeConverter.h b/third_party/intel/include/TritonIntelGPUToLLVM/TypeConverter.h new file mode 100644 index 0000000000..b1269c04e4 --- /dev/null +++ b/third_party/intel/include/TritonIntelGPUToLLVM/TypeConverter.h @@ -0,0 +1,15 @@ +#ifndef TRITON_CONVERSION_TRITONINTELGPUTOLLVM_TYPECONVERTER_H +#define TRITON_CONVERSION_TRITONINTELGPUTOLLVM_TYPECONVERTER_H + +#include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h" + +class TritonIntelGPUToLLVMTypeConverter : public TritonGPUToLLVMTypeConverter { +public: + using TypeConverter::convertType; + + TritonIntelGPUToLLVMTypeConverter( + MLIRContext *ctx, LowerToLLVMOptions &option, + const DataLayoutAnalysis *analysis = nullptr); +}; + +#endif // TRITON_CONVERSION_TRITONINTELGPUTOLLVM_TYPECONVERTER_H diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/CMakeLists.txt b/third_party/intel/lib/TritonIntelGPUToLLVM/CMakeLists.txt index 052a0ee4e2..b6f123b2b6 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/CMakeLists.txt +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/CMakeLists.txt @@ -1,27 +1,28 @@ add_triton_library(TritonIntelGPUToLLVM + AllocateSharedMemory.cpp + ClusterOpsToLLVM.cpp + ControlFlowOpToLLVM.cpp ConvertLayoutOpToLLVM/SharedToDotOperandDPAS.cpp ConvertLayoutOpToLLVM.cpp + DecomposeUnsupportedConversions.cpp DotOpToLLVM/DPAS.cpp DotOpToLLVM/FMA.cpp DotOpToLLVM.cpp + ElementwiseOpToLLVM.cpp HistogramOpToLLVM.cpp - ElementwiseOpToLLVM.cpp LoadStoreOpToLLVM.cpp - TritonGPUToLLVM.cpp - DecomposeUnsupportedConversions.cpp - PrintOpToLLVM.cpp - MemoryOpToLLVM.cpp - ControlFlowOpToLLVM.cpp MakeRangeOpToLLVM.cpp - SPMDOpToLLVM.cpp - ReduceOpToLLVM.cpp + MemoryOpToLLVM.cpp + PrintOpToLLVM.cpp 
+ ReduceOpToLLVM.cpp + TritonGPUToLLVM.cpp ScanOpToLLVM.cpp - Utility.cpp + SPMDOpToLLVM.cpp + TargetInfo.cpp TensorPtrOpsToLLVM.cpp TritonOpsToLLVM.cpp - ClusterOpsToLLVM.cpp - AllocateSharedMemory.cpp - TargetInfo.cpp + TypeConverter.cpp + Utility.cpp ViewOpToLLVM.cpp DEPENDS diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/ClusterOpsToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/ClusterOpsToLLVM.cpp index 67500a8dc7..768bfec0f8 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/ClusterOpsToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/ClusterOpsToLLVM.cpp @@ -55,8 +55,8 @@ struct ClusterWaitOpConversion : public ConvertTritonGPUOpToLLVMPattern< } // namespace void mlir::triton::intel::populateClusterOpsToLLVMPatterns( - TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns, - PatternBenefit benefit) { + TritonIntelGPUToLLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, PatternBenefit benefit) { patterns.add(typeConverter, benefit); patterns.add(typeConverter, benefit); return; diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM.cpp index 3e8696437e..63d546eac5 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM.cpp @@ -13,11 +13,11 @@ using ::mlir::triton::gpu::intel::DpasEncodingAttr; namespace fma_details { LogicalResult convertFMADot(triton::DotOp op, triton::DotOp::Adaptor adaptor, - TritonGPUToLLVMTypeConverter *typeConverter, + TritonIntelGPUToLLVMTypeConverter *typeConverter, ConversionPatternRewriter &rewriter); LogicalResult convertDPAS(triton::DotOp op, triton::DotOp::Adaptor adaptor, - TritonGPUToLLVMTypeConverter *typeConverter, + TritonIntelGPUToLLVMTypeConverter *typeConverter, ConversionPatternRewriter &rewriter); } // namespace fma_details @@ -62,7 +62,7 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern { } // namespace void mlir::triton::intel::populateDotOpToLLVMPatterns( - TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns, - PatternBenefit benefit) { + TritonIntelGPUToLLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, PatternBenefit benefit) { patterns.add(typeConverter, benefit); } diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM/DPAS.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM/DPAS.cpp index 04c2454283..5dbb923f7a 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM/DPAS.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM/DPAS.cpp @@ -25,7 +25,7 @@ class DotOpDPASConversionHelper { DotOpDPASConversionHelper(DpasEncodingAttr dpasLayout, ConversionPatternRewriter &rewriter, - TritonGPUToLLVMTypeConverter *typeConverter, + TritonIntelGPUToLLVMTypeConverter *typeConverter, Location loc) : dpasLayout(dpasLayout), rewriter(rewriter), typeConverter(typeConverter), loc(loc), ctx(dpasLayout.getContext()) {} @@ -302,7 +302,7 @@ class DotOpDPASConversionHelper { DpasEncodingAttr dpasLayout; ConversionPatternRewriter &rewriter; - TritonGPUToLLVMTypeConverter *typeConverter; + TritonIntelGPUToLLVMTypeConverter *typeConverter; Location loc; MLIRContext *ctx; }; @@ -311,7 +311,7 @@ class DotOpDPASConversionHelper { namespace fma_details { LogicalResult convertDPAS(triton::DotOp op, triton::DotOp::Adaptor adaptor, - TritonGPUToLLVMTypeConverter *typeConverter, + TritonIntelGPUToLLVMTypeConverter *typeConverter, ConversionPatternRewriter 
&rewriter) { LLVM_DEBUG({ auto module = op->getParentOfType(); diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM/FMA.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM/FMA.cpp index 32bc890b9f..368f9dc03e 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM/FMA.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM/FMA.cpp @@ -15,7 +15,7 @@ namespace { static ValueTableFMA getValueTableFromStructFMA( Value val, int K, int n0, int shapePerCTATile, int sizePerThread, ConversionPatternRewriter &rewriter, Location loc, - TritonGPUToLLVMTypeConverter *typeConverter, Type type) { + TritonIntelGPUToLLVMTypeConverter *typeConverter, Type type) { ValueTableFMA res; auto elems = unpackLLElements(loc, val, rewriter); int index = 0; @@ -31,7 +31,7 @@ static ValueTableFMA getValueTableFromStructFMA( namespace fma_details { LogicalResult convertFMADot(triton::DotOp op, triton::DotOp::Adaptor adaptor, - TritonGPUToLLVMTypeConverter *typeConverter, + TritonIntelGPUToLLVMTypeConverter *typeConverter, ConversionPatternRewriter &rewriter) { auto *ctx = rewriter.getContext(); auto loc = op.getLoc(); diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp index d413240843..58239cbcfc 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -135,7 +135,7 @@ struct LoadOpConversion using ConvertTritonGPUOpToLLVMPattern< triton::LoadOp>::ConvertTritonGPUOpToLLVMPattern; - LoadOpConversion(TritonGPUToLLVMTypeConverter &converter, + LoadOpConversion(TritonIntelGPUToLLVMTypeConverter &converter, ModuleAxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit) : ConvertTritonGPUOpToLLVMPattern(converter, benefit), @@ -300,7 +300,7 @@ struct StoreOpConversion using ConvertTritonGPUOpToLLVMPattern< triton::StoreOp>::ConvertTritonGPUOpToLLVMPattern; - StoreOpConversion(TritonGPUToLLVMTypeConverter &converter, + StoreOpConversion(TritonIntelGPUToLLVMTypeConverter &converter, ModuleAxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit) : ConvertTritonGPUOpToLLVMPattern(converter, benefit), @@ -418,7 +418,7 @@ struct AtomicCASOpConversion using ConvertTritonGPUOpToLLVMPattern< triton::AtomicCASOp>::ConvertTritonGPUOpToLLVMPattern; - AtomicCASOpConversion(TritonGPUToLLVMTypeConverter &converter, + AtomicCASOpConversion(TritonIntelGPUToLLVMTypeConverter &converter, ModuleAxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit) : ConvertTritonGPUOpToLLVMPattern(converter, @@ -785,8 +785,9 @@ struct AtomicRMWOpConversion } // namespace void mlir::triton::intel::populateLoadStoreOpToLLVMPatterns( - TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns, - ModuleAxisInfoAnalysis &axisInfoAnalysis, PatternBenefit benefit) { + TritonIntelGPUToLLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, ModuleAxisInfoAnalysis &axisInfoAnalysis, + PatternBenefit benefit) { patterns.add(typeConverter, axisInfoAnalysis, benefit); patterns.add(typeConverter, axisInfoAnalysis, benefit); patterns.add(typeConverter, axisInfoAnalysis, benefit); diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/PatternTritonGPUOpToLLVM.h b/third_party/intel/lib/TritonIntelGPUToLLVM/PatternTritonGPUOpToLLVM.h index 8446211869..2d254bafda 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/PatternTritonGPUOpToLLVM.h +++ 
b/third_party/intel/lib/TritonIntelGPUToLLVM/PatternTritonGPUOpToLLVM.h @@ -3,24 +3,24 @@ #include "TargetInfo.h" #include "TritonGPUToLLVMBase.h" +#include "intel/include/TritonIntelGPUToLLVM/TypeConverter.h" #include "triton/Analysis/AxisInfo.h" -#include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h" namespace mlir { namespace triton { namespace intel { void populateTritonOpsToLLVMPatterns( - TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns, - PatternBenefit benefit); + TritonIntelGPUToLLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, PatternBenefit benefit); void populateBarrierOpToLLVMPatterns( - TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns, - PatternBenefit benefit); + TritonIntelGPUToLLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, PatternBenefit benefit); void populateClusterOpsToLLVMPatterns( - TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns, - PatternBenefit benefit); + TritonIntelGPUToLLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, PatternBenefit benefit); void populateConvertLayoutOpToLLVMPatterns(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, @@ -30,9 +30,9 @@ void populateViewOpToLLVMPatterns(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, PatternBenefit benefit); -void populateDotOpToLLVMPatterns(TritonGPUToLLVMTypeConverter &typeConverter, - RewritePatternSet &patterns, - PatternBenefit benefit); +void populateDotOpToLLVMPatterns( + TritonIntelGPUToLLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, PatternBenefit benefit); void populateElementwiseOpToLLVMPatterns( LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, @@ -44,8 +44,9 @@ void populateHistogramOpToLLVMPatterns(LLVMTypeConverter &typeConverter, PatternBenefit benefit); void populateLoadStoreOpToLLVMPatterns( - TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns, - ModuleAxisInfoAnalysis &axisInfoAnalysis, PatternBenefit benefit); + TritonIntelGPUToLLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, ModuleAxisInfoAnalysis &axisInfoAnalysis, + PatternBenefit benefit); void populateReduceOpToLLVMPatterns(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, @@ -57,17 +58,17 @@ void populateScanOpToLLVMPatterns(LLVMTypeConverter &typeConverter, PatternBenefit benefit); void populateTensorPtrOpsToLLVMPatterns( - TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns, - PatternBenefit benefit); + TritonIntelGPUToLLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, PatternBenefit benefit); void populateTritonGPUToLLVMPatterns( - TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns, - PatternBenefit benefit); + TritonIntelGPUToLLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, PatternBenefit benefit); -void populatePrintOpToLLVMPattern(TritonGPUToLLVMTypeConverter &typeConverter, - RewritePatternSet &patterns, - const TargetInfoBase &targetInfo, - PatternBenefit benefit); +void populatePrintOpToLLVMPattern( + TritonIntelGPUToLLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, const TargetInfoBase &targetInfo, + PatternBenefit benefit); void populateMemoryOpToLLVMPattern(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/PipelineManager.h b/third_party/intel/lib/TritonIntelGPUToLLVM/PipelineManager.h index b4a4b88746..50aded5d68 100644 --- 
a/third_party/intel/lib/TritonIntelGPUToLLVM/PipelineManager.h +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/PipelineManager.h @@ -23,7 +23,6 @@ #include "intel/include/TritonGENToLLVM/TritonGENToLLVMPass.h" #include "triton/Analysis/AxisInfo.h" #include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" -#include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h" #include "triton/Tools/Sys/GetEnv.hpp" #include "PatternTritonGPUOpToLLVM.h" @@ -162,7 +161,7 @@ class TritonGPUToLLVMPipelineManager { /// Populate the conversion pipeline for function operations. void populateFunctionConversionPatterns( RewritePatternSet &funcPatterns, - TritonGPUToLLVMTypeConverter &typeConverter, int numWarps) const { + TritonIntelGPUToLLVMTypeConverter &typeConverter, int numWarps) const { funcPatterns.add(typeConverter, numWarps, /*benefit=*/1); if (!blockPtrPathIsEnabled) @@ -171,10 +170,11 @@ class TritonGPUToLLVMPipelineManager { } /// Populate the conversion pipeline for various operations. - void populateConversionPatterns(RewritePatternSet &patterns, - ModuleAxisInfoAnalysis &axisInfoAnalysis, - TritonGPUToLLVMTypeConverter &typeConverter, - TargetInfo &targetInfo, int benefit) const { + void + populateConversionPatterns(RewritePatternSet &patterns, + ModuleAxisInfoAnalysis &axisInfoAnalysis, + TritonIntelGPUToLLVMTypeConverter &typeConverter, + TargetInfo &targetInfo, int benefit) const { using namespace mlir; using namespace mlir::triton; diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/PrintOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/PrintOpToLLVM.cpp index 6f9c4ac37a..f0f1d3ceaf 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/PrintOpToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/PrintOpToLLVM.cpp @@ -228,7 +228,8 @@ struct PrintOpConversion } // namespace void mlir::triton::intel::populatePrintOpToLLVMPattern( - TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns, - const TargetInfoBase &targetInfo, PatternBenefit benefit) { + TritonIntelGPUToLLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, const TargetInfoBase &targetInfo, + PatternBenefit benefit) { patterns.add(typeConverter, targetInfo, benefit); } diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/TensorPtrOpsToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/TensorPtrOpsToLLVM.cpp index 15bc79fff5..8f4884ec60 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/TensorPtrOpsToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/TensorPtrOpsToLLVM.cpp @@ -96,8 +96,8 @@ struct AdvanceOpConversion } // namespace void mlir::triton::intel::populateTensorPtrOpsToLLVMPatterns( - TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns, - PatternBenefit benefit) { + TritonIntelGPUToLLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, PatternBenefit benefit) { patterns.add(typeConverter, benefit); patterns.add(typeConverter, benefit); return; diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp index 459cc7e3b6..90db713e26 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp @@ -23,7 +23,6 @@ #include "PipelineManager.h" #include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" -#include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h" namespace mlir { namespace triton { @@ -84,7 +83,7 @@ struct ConvertTritonGPUToLLVM 
mlir::LowerToLLVMOptions option(context); option.overrideIndexBitwidth(32); - TritonGPUToLLVMTypeConverter typeConverter(context, option); + TritonIntelGPUToLLVMTypeConverter typeConverter(context, option); TritonLLVMConversionTarget convTarget(*context); int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod); int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(mod); @@ -100,7 +99,7 @@ struct ConvertTritonGPUToLLVM // Lower functions { mlir::LowerToLLVMOptions option(context); - TritonGPUToLLVMTypeConverter typeConverter(context, option); + TritonIntelGPUToLLVMTypeConverter typeConverter(context, option); TritonLLVMFunctionConversionTarget funcTarget(*context); RewritePatternSet funcPatterns(context); pipelineManager.populateFunctionConversionPatterns( diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/TritonGPUToLLVMBase.h b/third_party/intel/lib/TritonIntelGPUToLLVM/TritonGPUToLLVMBase.h index 0f802b32a9..ce91881a11 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/TritonGPUToLLVMBase.h +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/TritonGPUToLLVMBase.h @@ -6,7 +6,8 @@ // and #include "triton/Analysis/Allocation.h" -#include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h" +#include "intel/include/TritonIntelGPUToLLVM/TypeConverter.h" + // #include "TritonIntelGPUToLLVM/Passes.h" #include "Utility.h" @@ -56,10 +57,10 @@ class ConvertTritonGPUOpToLLVMPattern ConvertTritonGPUOpToLLVMPatternBase(typeConverter) {} protected: - TritonGPUToLLVMTypeConverter *getTypeConverter() const { + TritonIntelGPUToLLVMTypeConverter *getTypeConverter() const { LLVMTypeConverter *ret = ((ConvertTritonGPUOpToLLVMPatternBase *)this)->getTypeConverter(); - return (TritonGPUToLLVMTypeConverter *)ret; + return (TritonIntelGPUToLLVMTypeConverter *)ret; } }; diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp index 91ffc97951..a64ca9b887 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp @@ -1,5 +1,6 @@ #include "PatternTritonGPUOpToLLVM.h" #include "triton/Analysis/Utility.h" + #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGEN/IR/TritonGENDialect.h" @@ -388,8 +389,8 @@ class ArithConstantOpLowering } // namespace void mlir::triton::intel::populateTritonOpsToLLVMPatterns( - TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns, - PatternBenefit benefit) { + TritonIntelGPUToLLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, PatternBenefit benefit) { patterns.add(typeConverter, benefit); patterns.add(typeConverter, benefit); patterns.add(typeConverter, benefit); diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/TypeConverter.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/TypeConverter.cpp new file mode 100644 index 0000000000..f3c1c659da --- /dev/null +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/TypeConverter.cpp @@ -0,0 +1,30 @@ +#include "intel/include/TritonIntelGPUToLLVM/TypeConverter.h" +#include "triton/Tools/Sys/GetEnv.hpp" + +using namespace mlir; +using namespace mlir::triton; + +TritonIntelGPUToLLVMTypeConverter::TritonIntelGPUToLLVMTypeConverter( + MLIRContext *ctx, LowerToLLVMOptions &option, + const DataLayoutAnalysis *analysis) + : TritonGPUToLLVMTypeConverter(ctx, option, analysis) { + // The following type conversions have been registered by the base class (in + // the constructor) + if 
(mlir::triton::tools::getBoolEnv("TRITON_INTEL_ENABLE_BLOCK_PTR")) { + // tt::pointer to v2i32 + addConversion([&](PointerType type) -> std::optional { + if (isa(type.getPointeeType())) { + auto i32Type = mlir::IntegerType::get(type.getContext(), 32); + return mlir::VectorType::get(2, i32Type); + } + return LLVM::LLVMPointerType::get(type.getContext(), + type.getAddressSpace()); + }); + + // tensor type is flattened and divided by 16(subgroupSize) + addConversion([&](mlir::RankedTensorType type) -> mlir::Type { + return mlir::VectorType::get(type.getNumElements() / 16, + type.getElementType()); + }); + } +} From 8b9bc218a7970e410db78cab8c9375d7e82ca8d2 Mon Sep 17 00:00:00 2001 From: "Tiotto, Ettore" Date: Tue, 23 Apr 2024 20:28:41 +0000 Subject: [PATCH 09/17] [NFC]: Fix precommit Signed-off-by: Tiotto, Ettore --- .../tritongpu_to_llvm_intel_block_ptr.mlir | 4 ++-- .../lib/TritonIntelGPUToLLVM/CMakeLists.txt | 20 +++++++++---------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/test/Conversion/tritongpu_to_llvm_intel_block_ptr.mlir b/test/Conversion/tritongpu_to_llvm_intel_block_ptr.mlir index 18b15f6af6..36c789b3aa 100644 --- a/test/Conversion/tritongpu_to_llvm_intel_block_ptr.mlir +++ b/test/Conversion/tritongpu_to_llvm_intel_block_ptr.mlir @@ -2,7 +2,7 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 32 : i32, triton_gpu.shared = 0 : i32, "triton_gpu.threads-per-warp" = 1 : i32} { tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: i32, %arg4: i32, %arg5: i32) { - // CHECK-LABEL: @matmul_kernel_with_block_pointers + // CHECK-LABEL: @matmul_kernel_with_block_pointers %c3_i32 = arith.constant 3 : i32 %c7_i32 = arith.constant 7 : i32 %c63_i32 = arith.constant 63 : i32 @@ -33,7 +33,7 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c %11 = arith.muli %8, %c256_i32 : i32 %12 = arith.muli %1, %c8_i32 : i32 %13 = arith.addi %12, %11 : i32 - // CHECK: [[undef:%.*]] = llvm.mlir.undef : vector<2xi32> + // CHECK: [[undef:%.*]] = llvm.mlir.undef : vector<2xi32> // CHECK-DAG: [[zero:%.*]] = llvm.mlir.constant(0 : i32) : i32 // CHECK-DAG: [[one:%.*]] = llvm.mlir.constant(1 : i32) : i32 // CHECK-NEXT: [[insert0:%.*]] = llvm.insertelement {{.*}}, [[undef]][[[zero]] : i32] : vector<2xi32> diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/CMakeLists.txt b/third_party/intel/lib/TritonIntelGPUToLLVM/CMakeLists.txt index b6f123b2b6..d20e1f5ddd 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/CMakeLists.txt +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/CMakeLists.txt @@ -1,28 +1,28 @@ add_triton_library(TritonIntelGPUToLLVM AllocateSharedMemory.cpp - ClusterOpsToLLVM.cpp - ControlFlowOpToLLVM.cpp + ClusterOpsToLLVM.cpp + ControlFlowOpToLLVM.cpp ConvertLayoutOpToLLVM/SharedToDotOperandDPAS.cpp ConvertLayoutOpToLLVM.cpp - DecomposeUnsupportedConversions.cpp + DecomposeUnsupportedConversions.cpp DotOpToLLVM/DPAS.cpp DotOpToLLVM/FMA.cpp DotOpToLLVM.cpp - ElementwiseOpToLLVM.cpp + ElementwiseOpToLLVM.cpp HistogramOpToLLVM.cpp LoadStoreOpToLLVM.cpp MakeRangeOpToLLVM.cpp - MemoryOpToLLVM.cpp - PrintOpToLLVM.cpp - ReduceOpToLLVM.cpp + MemoryOpToLLVM.cpp + PrintOpToLLVM.cpp + ReduceOpToLLVM.cpp TritonGPUToLLVM.cpp ScanOpToLLVM.cpp SPMDOpToLLVM.cpp - TargetInfo.cpp + TargetInfo.cpp TensorPtrOpsToLLVM.cpp TritonOpsToLLVM.cpp - TypeConverter.cpp - Utility.cpp + TypeConverter.cpp + Utility.cpp ViewOpToLLVM.cpp DEPENDS From 
683ec3e0903110a0cf2fb85c4811e3a922400b0e Mon Sep 17 00:00:00 2001 From: "Tiotto, Ettore" Date: Wed, 24 Apr 2024 18:21:20 +0000 Subject: [PATCH 10/17] Final code cleanup Signed-off-by: Tiotto, Ettore --- .../tritongpu_to_llvm_intel_block_ptr.mlir | 2 +- .../TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp | 1 - .../TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp | 145 ++++++++---------- 3 files changed, 61 insertions(+), 87 deletions(-) diff --git a/test/Conversion/tritongpu_to_llvm_intel_block_ptr.mlir b/test/Conversion/tritongpu_to_llvm_intel_block_ptr.mlir index 36c789b3aa..85d4b0a058 100644 --- a/test/Conversion/tritongpu_to_llvm_intel_block_ptr.mlir +++ b/test/Conversion/tritongpu_to_llvm_intel_block_ptr.mlir @@ -35,8 +35,8 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c %13 = arith.addi %12, %11 : i32 // CHECK: [[undef:%.*]] = llvm.mlir.undef : vector<2xi32> // CHECK-DAG: [[zero:%.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK-DAG: [[one:%.*]] = llvm.mlir.constant(1 : i32) : i32 // CHECK-NEXT: [[insert0:%.*]] = llvm.insertelement {{.*}}, [[undef]][[[zero]] : i32] : vector<2xi32> + // CHECK-NEXT: [[one:%.*]] = llvm.mlir.constant(1 : i32) : i32 // CHECK-NEXT: [[insert1:%.*]] = llvm.insertelement {{.*}}, [[insert0]][[[one]] : i32] : vector<2xi32> %14 = tt.make_tensor_ptr %arg0, [%c4096_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%13, %c0_i32] {order = array} : , 1> // CHECK: llvm.call @llvm.genx.GenISA.LSC2DBlockPrefetch.isVoid diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp index 90db713e26..7012047a62 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp @@ -80,7 +80,6 @@ struct ConvertTritonGPUToLLVM MLIRContext *context = &getContext(); ModuleOp mod = getOperation(); intel::TritonGPUToLLVMPipelineManager pipelineManager(mod, context); - mlir::LowerToLLVMOptions option(context); option.overrideIndexBitwidth(32); TritonIntelGPUToLLVMTypeConverter typeConverter(context, option); diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp index 8f2dbf4e3e..c652e448d7 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp @@ -13,11 +13,11 @@ namespace { VectorType getVectorType(RankedTensorType tensorType, Type elemType) { unsigned ratio = elemType.getIntOrFloatBitWidth() / tensorType.getElementTypeBitWidth(); - unsigned num = tensorType.getNumElements() / 16 / ratio; + unsigned num = (tensorType.getNumElements() / 16) / ratio; return vec_ty(elemType, num); }; -/// v2i32 [offsetX, offsetY] for 2D tensor desc +/// v2i32 [offsetX, offsetY] for 2D tensor desc. 
class MakeTensorPtrOpConversion : public ConvertTritonGPUOpToLLVMPattern { public: @@ -27,16 +27,12 @@ class MakeTensorPtrOpConversion matchAndRewrite(MakeTensorPtrOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); - - IntegerType i32Type = rewriter.getI32Type(); - VectorType v2i32 = vec_ty(i32Type, 2); + VectorType v2i32 = vec_ty(i32_ty, 2); Value offsetX = op.getOffsets()[1]; Value offsetY = op.getOffsets()[0]; Value payLoad = undef(v2i32); - Value idx0 = i32_val(0); - Value idx1 = i32_val(1); - payLoad = insert_element(payLoad, offsetX, idx0); - payLoad = insert_element(payLoad, offsetY, idx1); + payLoad = insert_element(payLoad, offsetX, i32_val(0)); + payLoad = insert_element(payLoad, offsetY, i32_val(1)); rewriter.replaceOp(op, payLoad); return success(); } @@ -53,8 +49,9 @@ class AdvanceOpConversion : public ConvertTritonGPUOpToLLVMPattern { matchAndRewrite(AdvanceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); - SmallVector offsets = adaptor.getOffsets(); + ValueRange offsets = adaptor.getOffsets(); Value ptr = adaptor.getPtr(); + for (size_t i = 0; i < offsets.size(); ++i) { Value offset = offsets[i]; if (auto cst = dyn_cast(offset.getDefiningOp())) @@ -62,16 +59,12 @@ class AdvanceOpConversion : public ConvertTritonGPUOpToLLVMPattern { attr && attr.getInt() == 0) continue; - IntegerType i32Type = rewriter.getI32Type(); - Value idx = (i == 0) - ? rewriter.create( - loc, i32Type, rewriter.getIntegerAttr(i32Type, 1)) - : rewriter.create( - loc, i32Type, rewriter.getIntegerAttr(i32Type, 0)); + Value idx = i32_val(!i); Value oldOffset = extract_element(ptr, idx); - Value newOffset = add(i32Type, oldOffset, offset); + Value newOffset = add(i32_ty, oldOffset, offset); ptr = insert_element(ptr, newOffset, idx); } + rewriter.replaceOp(op, ptr); return success(); } @@ -111,7 +104,8 @@ class AdvanceOpConversion : public ConvertTritonGPUOpToLLVMPattern { /// elemSize 32) /// Arg 12: cache controls options (LSC_CACHE_OPTS) /// Arg 13: stored value -template +template ::value>> class LoadStorePrefetchOpConversion : public ConvertTritonGPUOpToLLVMPattern { public: @@ -129,9 +123,6 @@ class LoadStorePrefetchOpConversion constexpr bool isLoad = std::is_same_v; constexpr bool isPrefetch = std::is_same_v; - IntegerType i16Type = rewriter.getI16Type(); - IntegerType i32Type = rewriter.getI32Type(); - IntegerType i64Type = rewriter.getI64Type(); bool vnni = false, transpose = false; if constexpr (isLoad) { auto idxAttr = op->template getAttrOfType("DotIdx"); @@ -144,8 +135,6 @@ class LoadStorePrefetchOpConversion unsigned vBlks = blockWidth == 32 ? 
2 : 1; blockWidth = 16; unsigned blockHeight = tensorType.getShape()[0]; - Value idx0 = i32_val(0); - Value idx1 = i32_val(1); Value ptr = op.getPtr(); if (auto cast = dyn_cast(ptr.getDefiningOp())) @@ -162,22 +151,14 @@ class LoadStorePrefetchOpConversion Value bytes = i32_val(tensorType.getElementType().getIntOrFloatBitWidth() / 8); Value one = i32_val(1); - Value surfaceW = - rewriter.create(loc, i32Type, ptrOp.getShape()[1]); - surfaceW = rewriter.create(loc, surfaceW, bytes); - surfaceW = rewriter.create(loc, surfaceW, one); - Value surfaceH = - rewriter.create(loc, i32Type, ptrOp.getShape()[0]); - surfaceH = rewriter.create(loc, surfaceH, one); - Value surfaceP = - rewriter.create(loc, i32Type, ptrOp.getStrides()[0]); - surfaceP = rewriter.create(loc, surfaceP, bytes); - surfaceP = rewriter.create(loc, surfaceP, one); + Value surfaceW = sub(mul(trunc(i32_ty, ptrOp.getShape()[1]), bytes), one); + Value surfaceH = sub(trunc(i32_ty, ptrOp.getShape()[0]), one); + Value surfaceP = sub(mul(trunc(i32_ty, ptrOp.getStrides()[0]), bytes), one); rewriter.restoreInsertionPoint(insertPoint); Value tensorPtr = adaptor.getPtr(); - Value offsetX = extract_element(tensorPtr, idx0); - Value offsetY = extract_element(tensorPtr, idx1); + Value offsetX = extract_element(tensorPtr, i32_val(0)); + Value offsetY = extract_element(tensorPtr, i32_val(1)); if constexpr (isLoad) { Type resType = @@ -185,12 +166,12 @@ class LoadStorePrefetchOpConversion auto idxAttr = op->template getAttrOfType("DotIdx"); unsigned idx = idxAttr.getInt(); Type vectorType = - getVectorType(cast(op->getResult(0).getType()), - idx == 0 ? i16Type : i32Type); + getVectorType(cast(op.getResult().getType()), + idx == 0 ? i16_ty : i32_ty); auto load = rewriter.create( loc, vectorType, base, surfaceW, surfaceH, surfaceP, offsetX, offsetY, dataSize, blockWidth, blockHeight, vBlks, transpose, vnni); - auto cast = rewriter.create(loc, resType, load); + auto cast = bitcast(load, resType); rewriter.replaceOp(op, cast); } else if constexpr (isPrefetch) { rewriter.create( @@ -200,28 +181,28 @@ class LoadStorePrefetchOpConversion rewriter.eraseOp(op); } else { VectorType vectorType = getVectorType( - cast(op.getValue().getType()), i32Type); - Value cast = - rewriter.create(loc, vectorType, adaptor.getValue()); + cast(op.getValue().getType()), i32_ty); + Value cast = bitcast(adaptor.getValue(), vectorType); rewriter.create( loc, base, surfaceW, surfaceH, surfaceP, offsetX, offsetY, dataSize, blockWidth, blockHeight, vBlks, transpose, vnni, cast); rewriter.eraseOp(op); } + return success(); } }; -// TritonGen DpasOp Desc: XeHP SDV: dot product accumulate systolic -// Output: dst -// Arg 0: src0(acc) -// Arg 1: src1 -// Arg 2: src2 -// Arg 3: src1's precision -// Arg 4: src2's precision -// Arg 5: systolic depth -// Arg 6: repeat count -// Arg 7: isDpasw +/// TritonGen DpasOp Desc: XeHP SDV: dot product accumulate systolic +/// Output: dst +/// Arg 0: src0(acc) +/// Arg 1: src1 +/// Arg 2: src2 +/// Arg 3: src1's precision +/// Arg 4: src2's precision +/// Arg 5: systolic depth +/// Arg 6: repeat count +/// Arg 7: isDpasw class DotOpConversion : public ConvertTritonGPUOpToLLVMPattern { public: using ConvertTritonGPUOpToLLVMPattern::ConvertTritonGPUOpToLLVMPattern; @@ -229,9 +210,9 @@ class DotOpConversion : public ConvertTritonGPUOpToLLVMPattern { matchAndRewrite(DotOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto encodePrecision = [&](Type type) -> TritonGEN::PrecisionType { - if (type == rewriter.getBF16Type()) 
+ if (type == bf16_ty) return TritonGEN::PrecisionType::BF16; - else if (type == rewriter.getF16Type()) + else if (type == f16_ty) return TritonGEN::PrecisionType::FP16; else if (type == rewriter.getTF32Type()) return TritonGEN::PrecisionType::TF32; @@ -249,26 +230,23 @@ class DotOpConversion : public ConvertTritonGPUOpToLLVMPattern { TritonGEN::PrecisionTypeAttr::get(rewriter.getContext(), precBTy); Location loc = op.getLoc(); - IntegerType i16Type = rewriter.getI16Type(); - IntegerType i32Type = rewriter.getI32Type(); - VectorType typeA = - getVectorType(cast(op.getA().getType()), i16Type); - Value castA = rewriter.create(loc, typeA, adaptor.getA()); + Type typeA = + getVectorType(cast(op.getA().getType()), i16_ty); + Value castA = bitcast(adaptor.getA(), typeA); VectorType typeB = - getVectorType(cast(op.getB().getType()), i32Type); - Value castB = rewriter.create(loc, typeB, adaptor.getB()); - auto rc = IntegerAttr::get(i32Type, 8); + getVectorType(cast(op.getB().getType()), i32_ty); + Value castB = bitcast(adaptor.getB(), typeB); + auto rc = IntegerAttr::get(i32_ty, 8); // sd dpasW fixed in genx.dpas lowering. - auto dpas = rewriter.create( - loc, adaptor.getC().getType(), adaptor.getC(), castA, castB, precA, + rewriter.replaceOpWithNewOp( + op, adaptor.getC().getType(), adaptor.getC(), castA, castB, precA, precB, rc); - rewriter.replaceOp(op, dpas); return success(); } }; /// %glue = ttgi.glue %a, %b : tensor<4xf16>, tensor<4xf16> : tensor<8xf16> -/// is converted to +/// is converted to: /// %glue = llvm.shufflevector %a, %b : [0, 1, 2, 3, 4, 5, 6, 7] : vector<8xf16> class GlueOpConversion : public ConvertTritonGPUOpToLLVMPattern { public: @@ -303,13 +281,13 @@ class GlueOpConversion : public ConvertTritonGPUOpToLLVMPattern { DenseI32ArrayAttr attr23 = rewriter.getDenseI32ArrayAttr(indices); auto shfl23 = rewriter.create( loc, subType, operands[2], operands[3], attr23); - auto shfl = rewriter.create(loc, dstType, shfl01, + rewriter.replaceOpWithNewOp(op, dstType, shfl01, shfl23, attr); - rewriter.replaceOp(op, shfl); } break; default: llvm_unreachable("add more support for glue op to llvm"); } + return success(); } }; @@ -351,25 +329,22 @@ class ArithConstantOpLowering if (!srcType || srcType.getNumElements() == 1) return failure(); - // arith.constant should only have vector or tenor types. 
- assert((isa(srcType))); + assert((isa(srcType)) && + "arith.constant should only have vector or tensor type"); - Type dstType = getTypeConverter()->convertType(srcType); - if (!dstType) - return failure(); - - auto dstElementsAttr = dyn_cast(op.getValue()); - if (!dstElementsAttr) - return failure(); + if (Type dstType = getTypeConverter()->convertType(srcType)) { + if (auto dstElementsAttr = dyn_cast(op.getValue())) { + auto vecType = cast(dstType); + VectorType dstAttrType = + vec_ty(vecType.getElementType(), vecType.getNumElements()); + dstElementsAttr = dstElementsAttr.resizeSplat(dstAttrType); + rewriter.replaceOpWithNewOp(op, vecType, + dstElementsAttr); + return success(); + } + } - auto vecType = cast(dstType); - ShapedType dstAttrType = - vec_ty(vecType.getElementType(), vecType.getNumElements()); - dstElementsAttr = dstElementsAttr.resizeSplat(dstAttrType); - auto newOp = - rewriter.create(loc, dstType, dstElementsAttr); - rewriter.replaceOp(op, newOp); - return success(); + return failure(); } }; From 6604da51d1b7375fa41540afd964309f0b064ab6 Mon Sep 17 00:00:00 2001 From: "Tiotto, Ettore" Date: Wed, 24 Apr 2024 18:21:51 +0000 Subject: [PATCH 11/17] Fix precommit Signed-off-by: Tiotto, Ettore --- test/Conversion/tritongpu_to_llvm_intel_block_ptr.mlir | 2 +- third_party/intel/lib/TritonIntelGPUToLLVM/CMakeLists.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/Conversion/tritongpu_to_llvm_intel_block_ptr.mlir b/test/Conversion/tritongpu_to_llvm_intel_block_ptr.mlir index 85d4b0a058..4fefab1733 100644 --- a/test/Conversion/tritongpu_to_llvm_intel_block_ptr.mlir +++ b/test/Conversion/tritongpu_to_llvm_intel_block_ptr.mlir @@ -36,7 +36,7 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c // CHECK: [[undef:%.*]] = llvm.mlir.undef : vector<2xi32> // CHECK-DAG: [[zero:%.*]] = llvm.mlir.constant(0 : i32) : i32 // CHECK-NEXT: [[insert0:%.*]] = llvm.insertelement {{.*}}, [[undef]][[[zero]] : i32] : vector<2xi32> - // CHECK-NEXT: [[one:%.*]] = llvm.mlir.constant(1 : i32) : i32 + // CHECK-NEXT: [[one:%.*]] = llvm.mlir.constant(1 : i32) : i32 // CHECK-NEXT: [[insert1:%.*]] = llvm.insertelement {{.*}}, [[insert0]][[[one]] : i32] : vector<2xi32> %14 = tt.make_tensor_ptr %arg0, [%c4096_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%13, %c0_i32] {order = array} : , 1> // CHECK: llvm.call @llvm.genx.GenISA.LSC2DBlockPrefetch.isVoid diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/CMakeLists.txt b/third_party/intel/lib/TritonIntelGPUToLLVM/CMakeLists.txt index b0dea3563f..0efdc17af4 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/CMakeLists.txt +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/CMakeLists.txt @@ -20,7 +20,7 @@ add_triton_library(TritonIntelGPUToLLVM SPMDOpToLLVM.cpp TargetInfo.cpp TensorPtrOpsToLLVM.cpp - TritonGPUToLLVM.cpp + TritonGPUToLLVM.cpp TritonOpsToLLVM.cpp TypeConverter.cpp Utility.cpp From 70a2d7b27b9d8f9b72fc473537f61c7bc06297b8 Mon Sep 17 00:00:00 2001 From: "Tiotto, Ettore" Date: Wed, 24 Apr 2024 18:34:34 +0000 Subject: [PATCH 12/17] Address code review comments Signed-off-by: Tiotto, Ettore --- test/Conversion/tritongpu_to_llvm_intel_block_ptr.mlir | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/Conversion/tritongpu_to_llvm_intel_block_ptr.mlir b/test/Conversion/tritongpu_to_llvm_intel_block_ptr.mlir index 4fefab1733..8aca7dc465 100644 --- a/test/Conversion/tritongpu_to_llvm_intel_block_ptr.mlir +++ b/test/Conversion/tritongpu_to_llvm_intel_block_ptr.mlir @@ -1,4 +1,4 @@ 
-// RUN: TRITON_INTEL_ENABLE_BLOCK_PTR=1 triton-opt %s -split-input-file --convert-triton-intel-gpu-to-llvm | FileCheck %s +// RUN: TRITON_INTEL_ENABLE_BLOCK_PTR=1 triton-opt %s --convert-triton-intel-gpu-to-llvm | FileCheck %s module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 32 : i32, triton_gpu.shared = 0 : i32, "triton_gpu.threads-per-warp" = 1 : i32} { tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: i32, %arg4: i32, %arg5: i32) { From 8509d8a290e500d1aca8d5683cc0bdadc9eba232 Mon Sep 17 00:00:00 2001 From: "Tiotto, Ettore" Date: Wed, 24 Apr 2024 20:40:21 +0000 Subject: [PATCH 13/17] Address code review comments Signed-off-by: Tiotto, Ettore --- .../tritongpu_to_llvm_intel_block_ptr.mlir | 46 +++++++++++++------ .../TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp | 2 +- .../TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp | 31 +++++-------- 3 files changed, 43 insertions(+), 36 deletions(-) diff --git a/test/Conversion/tritongpu_to_llvm_intel_block_ptr.mlir b/test/Conversion/tritongpu_to_llvm_intel_block_ptr.mlir index 8aca7dc465..cc315e6195 100644 --- a/test/Conversion/tritongpu_to_llvm_intel_block_ptr.mlir +++ b/test/Conversion/tritongpu_to_llvm_intel_block_ptr.mlir @@ -1,6 +1,12 @@ // RUN: TRITON_INTEL_ENABLE_BLOCK_PTR=1 triton-opt %s --convert-triton-intel-gpu-to-llvm | FileCheck %s -module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 32 : i32, triton_gpu.shared = 0 : i32, "triton_gpu.threads-per-warp" = 1 : i32} { +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 32 : i32, triton_gpu.shared = 0 : i32, "triton_gpu.threads-per-warp" = 1 : i32} { + // CHECK-DAG: llvm.func spir_funccc @llvm.genx.GenISA.LSC2DBlockWrite.v8i32(i64, i32, i32, i32, i32, i32, i32, i32, i32, i32, i1, i1, i32, vector<8xi32>) + // CHECK-DAG: llvm.func spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32> attributes {passthrough = ["convergent"]} + // CHECK-DAG: llvm.func spir_funccc @llvm.genx.GenISA.LSC2DBlockRead.v32i32(i64, i32, i32, i32, i32, i32, i32, i32, i32, i32, i1, i1, i32) -> vector<32xi32> + // CHECK-DAG: llvm.func spir_funccc @llvm.genx.GenISA.LSC2DBlockRead.v64i16(i64, i32, i32, i32, i32, i32, i32, i32, i32, i32, i1, i1, i32) -> vector<64xi16> + // CHECK-DAG: llvm.func spir_funccc @llvm.genx.GenISA.LSC2DBlockPrefetch.isVoid(i64, i32, i32, i32, i32, i32, i32, i32, i32, i32, i1, i1, i32) + tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: i32, %arg4: i32, %arg5: i32) { // CHECK-LABEL: @matmul_kernel_with_block_pointers %c3_i32 = arith.constant 3 : i32 @@ -33,13 +39,15 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c %11 = arith.muli %8, %c256_i32 : i32 %12 = arith.muli %1, %c8_i32 : i32 %13 = arith.addi %12, %11 : i32 - // CHECK: [[undef:%.*]] = llvm.mlir.undef : vector<2xi32> - // CHECK-DAG: [[zero:%.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK-NEXT: [[insert0:%.*]] = llvm.insertelement {{.*}}, [[undef]][[[zero]] : i32] : vector<2xi32> - // CHECK-NEXT: [[one:%.*]] = llvm.mlir.constant(1 : i32) : i32 - // CHECK-NEXT: [[insert1:%.*]] = llvm.insertelement {{.*}}, [[insert0]][[[one]] : i32] : vector<2xi32> + // CHECK: [[UNDEF:%.*]] = llvm.mlir.undef : vector<2xi32> + // CHECK-NEXT: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) 
: i32 + // CHECK-NEXT: [[INSERT0:%.*]] = llvm.insertelement {{.*}}, [[UNDEF]][[[ZERO]] : i32] : vector<2xi32> + // CHECK-NEXT: [[ONE:%.*]] = llvm.mlir.constant(1 : i32) : i32 + // CHECK-NEXT: [[INSERT1:%.*]] = llvm.insertelement {{.*}}, [[INSERT0]][[[ONE]] : i32] : vector<2xi32> %14 = tt.make_tensor_ptr %arg0, [%c4096_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%13, %c0_i32] {order = array} : , 1> - // CHECK: llvm.call @llvm.genx.GenISA.LSC2DBlockPrefetch.isVoid + + // CHECK: [[PTR:%.*]] = llvm.ptrtoint %arg0 : !llvm.ptr<1> to i64 + // CHECK: llvm.call @llvm.genx.GenISA.LSC2DBlockPrefetch.isVoid([[PTR]], {{.*}}) triton_intel_gpu.prefetch %14 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr, 1> %18 = arith.divsi %1, %c4_i32 : i32 %19 = arith.andi %18, %c7_i32 : i32 @@ -58,16 +66,23 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c %62 = arith.cmpi slt, %40, %c4096_i32 : i32 cf.cond_br %62, ^bb2, ^bb3 ^bb2: - // CHECK: [[A:%.*]] = llvm.call @llvm.genx.GenISA.LSC2DBlockRead.v64i16({{.*}} -> vector<64xi16> + // CHECK: [[A_PTR:%.*]] = llvm.ptrtoint %arg0 : !llvm.ptr<1> to i64 + // CHECK: [[A:%.*]] = llvm.call @llvm.genx.GenISA.LSC2DBlockRead.v64i16([[A_PTR]], {{.*}} -> vector<64xi16> // CHECK-NEXT: [[castA:%.*]] = llvm.bitcast [[A]] : vector<64xi16> to vector<64xf16> - // CHECK: [[B0:%.*]] = llvm.call @llvm.genx.GenISA.LSC2DBlockRead.v32i32({{.*}} -> vector<32xi32> + // CHECK: [[B_PTR:%.*]] = llvm.ptrtoint %arg1 : !llvm.ptr<1> to i64 + // CHECK: [[B0:%.*]] = llvm.call @llvm.genx.GenISA.LSC2DBlockRead.v32i32([[B_PTR]], {{.*}} -> vector<32xi32> // CHECK-NEXT: [[castB:%.*]] = llvm.bitcast [[B0]] : vector<32xi32> to vector<64xf16> // CHECK: [[B1:%.*]] = llvm.call @llvm.genx.GenISA.LSC2DBlockRead.v32i32({{.*}} -> vector<32xi32> - // CHECK: [[subA:%.*]] = llvm.shufflevector [[castA]], [[castA]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<64xf16> - // CHECK: [[subB:%.*]] = llvm.shufflevector [[castB]], [[castB]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<64xf16> - // CHECK-NEXT: [[castDotA:%.*]] = llvm.bitcast [[subA]] : vector<8xf16> to vector<8xi16> - // CHECK-NEXT: [[castDotB:%.*]] = llvm.bitcast [[subB]] : vector<16xf16> to vector<8xi32> - // CHECK: llvm.call @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f([[castDotA]], [[castDotB]], {{.*}} -> vector<8xf32> + // CHECK: [[subA1:%.*]] = llvm.shufflevector [[castA]], [[castA]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<64xf16> + // CHECK: [[subB1:%.*]] = llvm.shufflevector [[castB]], [[castB]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<64xf16> + // CHECK-NEXT: [[castDotA1:%.*]] = llvm.bitcast [[subA1]] : vector<8xf16> to vector<8xi16> + // CHECK-NEXT: [[castDotB1:%.*]] = llvm.bitcast [[subB1]] : vector<16xf16> to vector<8xi32> + // CHECK: llvm.call @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f([[castDotA1]], [[castDotB1]], {{.*}} -> vector<8xf32> + // CHECK: [[subA2:%.*]] = llvm.shufflevector [[castA]], [[castA]] [32, 33, 34, 35, 36, 37, 38, 39] : vector<64xf16> + // CHECK: [[subB2:%.*]] = llvm.shufflevector [[castB]], [[castB]] [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<64xf16> + // CHECK-NEXT: [[castDotA2:%.*]] = llvm.bitcast [[subA2]] : vector<8xf16> to vector<8xi16> + // CHECK-NEXT: [[castDotB2:%.*]] = llvm.bitcast [[subB2]] : vector<16xf16> to vector<8xi32> + // CHECK: llvm.call @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f([[castDotA2]], [[castDotB2]], {{.*}} -> vector<8xf32> %63 = tt.load %57 {DotIdx = 0 : 
i32, boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr, 1> %64 = tt.load %58 {DotIdx = 1 : i32, boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr, 1> %65 = tt.load %59 {DotIdx = 1 : i32, boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr, 1> @@ -87,7 +102,8 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c cf.br ^bb1(%119, %71, %115, %117, %118 : i32, tensor<8x16xf32>, !tt.ptr, 1>, !tt.ptr, 1>, !tt.ptr, 1>) ^bb3: %120 = tt.make_tensor_ptr %arg2, [%c4096_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%21, %36] {order = array} : , 1> - // CHECK: llvm.call @llvm.genx.GenISA.LSC2DBlockWrite.v8i32 + // CHECK: [[RES_PTR:%.*]] = llvm.ptrtoint %arg2 : !llvm.ptr<1> to i64 + // CHECK: llvm.call @llvm.genx.GenISA.LSC2DBlockWrite.v8i32([[RES_PTR]], {{.*}} tt.store %120, %41 {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32} : !tt.ptr, 1> tt.return } diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp index 7012047a62..79a362d5f3 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp @@ -10,7 +10,6 @@ #include "intel/include/GPUToTritonGEN/GPUToTritonGENPass.h" #include "intel/include/TritonGENToLLVM/TritonGENToLLVMPass.h" -#include "intel/include/TritonIntelGPUToLLVM/Passes.h" #include "triton/Analysis/Allocation.h" #include "triton/Analysis/AxisInfo.h" @@ -79,6 +78,7 @@ struct ConvertTritonGPUToLLVM void runOnOperation() override { MLIRContext *context = &getContext(); ModuleOp mod = getOperation(); + intel::TritonGPUToLLVMPipelineManager pipelineManager(mod, context); mlir::LowerToLLVMOptions option(context); option.overrideIndexBitwidth(32); diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp index c652e448d7..3c53bfbcdc 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp @@ -119,16 +119,6 @@ class LoadStorePrefetchOpConversion assert(tensorType.getRank() <= 2 && "only support 1d/2d load/store/prefetch for now"); - Location loc = op.getLoc(); - constexpr bool isLoad = std::is_same_v; - constexpr bool isPrefetch = std::is_same_v; - - bool vnni = false, transpose = false; - if constexpr (isLoad) { - auto idxAttr = op->template getAttrOfType("DotIdx"); - vnni = idxAttr.getInt() == 1 ? 
true : false; - } - unsigned dataSize = tensorType.getElementType().getIntOrFloatBitWidth(); unsigned blockWidth = tensorType.getShape()[1]; assert(blockWidth == 16 || blockWidth == 32 && "only support 16/32 block"); @@ -148,6 +138,7 @@ class LoadStorePrefetchOpConversion OpBuilder::InsertPoint insertPoint = rewriter.saveInsertionPoint(); rewriter.setInsertionPointAfter(ptrOp); + Location loc = op.getLoc(); Value bytes = i32_val(tensorType.getElementType().getIntOrFloatBitWidth() / 8); Value one = i32_val(1); @@ -160,32 +151,32 @@ class LoadStorePrefetchOpConversion Value offsetX = extract_element(tensorPtr, i32_val(0)); Value offsetY = extract_element(tensorPtr, i32_val(1)); - if constexpr (isLoad) { - Type resType = - this->getTypeConverter()->convertType(op->getResult(0).getType()); + if constexpr (std::is_same_v) { auto idxAttr = op->template getAttrOfType("DotIdx"); unsigned idx = idxAttr.getInt(); + Type resType = + this->getTypeConverter()->convertType(op->getResult(0).getType()); Type vectorType = getVectorType(cast(op.getResult().getType()), idx == 0 ? i16_ty : i32_ty); + bool vnni = (idx == 1) && dataSize <= 32; auto load = rewriter.create( loc, vectorType, base, surfaceW, surfaceH, surfaceP, offsetX, offsetY, - dataSize, blockWidth, blockHeight, vBlks, transpose, vnni); - auto cast = bitcast(load, resType); - rewriter.replaceOp(op, cast); - } else if constexpr (isPrefetch) { + dataSize, blockWidth, blockHeight, vBlks, false /* transpose*/, vnni); + rewriter.replaceOp(op, bitcast(load, resType)); + } else if constexpr (std::is_same_v) { rewriter.create( loc, base, surfaceW, surfaceH, surfaceP, offsetX, offsetY, dataSize, - blockWidth, blockHeight, vBlks, transpose, vnni, + blockWidth, blockHeight, vBlks, false /*transpose*/, false /*vnni*/, TritonGEN::PrefetchCacheControl::L1C_L3C); rewriter.eraseOp(op); } else { VectorType vectorType = getVectorType( cast(op.getValue().getType()), i32_ty); - Value cast = bitcast(adaptor.getValue(), vectorType); rewriter.create( loc, base, surfaceW, surfaceH, surfaceP, offsetX, offsetY, dataSize, - blockWidth, blockHeight, vBlks, transpose, vnni, cast); + blockWidth, blockHeight, vBlks, false /*transpose*/, false /*vnni*/, + bitcast(adaptor.getValue(), vectorType)); rewriter.eraseOp(op); } From be9c8980030fe596e6e6b02e6c5bbe26a568487b Mon Sep 17 00:00:00 2001 From: "Tiotto, Ettore" Date: Wed, 24 Apr 2024 21:06:27 +0000 Subject: [PATCH 14/17] Address code review comments Signed-off-by: Tiotto, Ettore --- .../TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp | 33 ------------------- 1 file changed, 33 deletions(-) diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp index 3c53bfbcdc..10bb805a0f 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp @@ -307,38 +307,6 @@ class ExtractOpConversion : public ConvertTritonGPUOpToLLVMPattern { } }; -// FIXME: support it in upstream constantOpLowering -class ArithConstantOpLowering - : public ConvertTritonGPUOpToLLVMPattern { - using ConvertTritonGPUOpToLLVMPattern< - mlir::arith::ConstantOp>::ConvertTritonGPUOpToLLVMPattern; - LogicalResult - matchAndRewrite(mlir::arith::ConstantOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Location loc = op->getLoc(); - auto srcType = dyn_cast(op.getType()); - if (!srcType || srcType.getNumElements() == 1) - return failure(); - - assert((isa(srcType)) && - 
"arith.constant should only have vector or tensor type"); - - if (Type dstType = getTypeConverter()->convertType(srcType)) { - if (auto dstElementsAttr = dyn_cast(op.getValue())) { - auto vecType = cast(dstType); - VectorType dstAttrType = - vec_ty(vecType.getElementType(), vecType.getNumElements()); - dstElementsAttr = dstElementsAttr.resizeSplat(dstAttrType); - rewriter.replaceOpWithNewOp(op, vecType, - dstElementsAttr); - return success(); - } - } - - return failure(); - } -}; - } // namespace void mlir::triton::intel::populateTritonOpsToLLVMPatterns( @@ -353,5 +321,4 @@ void mlir::triton::intel::populateTritonOpsToLLVMPatterns( patterns.add>(typeConverter, benefit); patterns.add(typeConverter, benefit); patterns.add(typeConverter, benefit); - patterns.add(typeConverter, benefit); } From cc10968915d297efacd7a0295f523660f94438ab Mon Sep 17 00:00:00 2001 From: "Tiotto, Ettore" Date: Wed, 24 Apr 2024 21:41:31 +0000 Subject: [PATCH 15/17] Fix precommit Signed-off-by: Tiotto, Ettore --- test/Conversion/tritongpu_to_llvm_intel_block_ptr.mlir | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/Conversion/tritongpu_to_llvm_intel_block_ptr.mlir b/test/Conversion/tritongpu_to_llvm_intel_block_ptr.mlir index cc315e6195..03c82d71c8 100644 --- a/test/Conversion/tritongpu_to_llvm_intel_block_ptr.mlir +++ b/test/Conversion/tritongpu_to_llvm_intel_block_ptr.mlir @@ -66,7 +66,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 32 %62 = arith.cmpi slt, %40, %c4096_i32 : i32 cf.cond_br %62, ^bb2, ^bb3 ^bb2: - // CHECK: [[A_PTR:%.*]] = llvm.ptrtoint %arg0 : !llvm.ptr<1> to i64 + // CHECK: [[A_PTR:%.*]] = llvm.ptrtoint %arg0 : !llvm.ptr<1> to i64 // CHECK: [[A:%.*]] = llvm.call @llvm.genx.GenISA.LSC2DBlockRead.v64i16([[A_PTR]], {{.*}} -> vector<64xi16> // CHECK-NEXT: [[castA:%.*]] = llvm.bitcast [[A]] : vector<64xi16> to vector<64xf16> // CHECK: [[B_PTR:%.*]] = llvm.ptrtoint %arg1 : !llvm.ptr<1> to i64 From 1bd65870b4f2709f935ea71ad91e805c1da8c3a8 Mon Sep 17 00:00:00 2001 From: David <110815347+Dewei-Wang-sh@users.noreply.github.com> Date: Thu, 25 Apr 2024 09:22:43 +0800 Subject: [PATCH 16/17] Update third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp --- third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp index 10bb805a0f..ee3d19b62f 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp @@ -207,7 +207,7 @@ class DotOpConversion : public ConvertTritonGPUOpToLLVMPattern { return TritonGEN::PrecisionType::FP16; else if (type == rewriter.getTF32Type()) return TritonGEN::PrecisionType::TF32; - assert(false && "add more support"); + llvm_unreachable("add more support for PrecisionType"); return TritonGEN::PrecisionType::UNUSED; }; From e3f2e4986df2f695f246cb23c59bf46aa239ae86 Mon Sep 17 00:00:00 2001 From: "Tiotto, Ettore" Date: Thu, 25 Apr 2024 17:32:31 +0000 Subject: [PATCH 17/17] Cleanup Signed-off-by: Tiotto, Ettore --- .../include/TritonIntelGPUToLLVM/AsmFormat.h | 30 -- .../TritonIntelGPUToLLVM/PTXAsmFormat.h | 341 ------------------ .../TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp | 16 +- 3 files changed, 12 insertions(+), 375 deletions(-) delete mode 100644 third_party/intel/include/TritonIntelGPUToLLVM/AsmFormat.h delete mode 100644 
third_party/intel/include/TritonIntelGPUToLLVM/PTXAsmFormat.h diff --git a/third_party/intel/include/TritonIntelGPUToLLVM/AsmFormat.h b/third_party/intel/include/TritonIntelGPUToLLVM/AsmFormat.h deleted file mode 100644 index acbc59a5cd..0000000000 --- a/third_party/intel/include/TritonIntelGPUToLLVM/AsmFormat.h +++ /dev/null @@ -1,30 +0,0 @@ -#ifndef TRITON_CONVERSION_TRITON_GPU_TO_LLVM_ASM_FORMAT_H_ -#define TRITON_CONVERSION_TRITON_GPU_TO_LLVM_ASM_FORMAT_H_ - -#include "mlir/IR/Value.h" -#include "triton/Dialect/Triton/IR/Dialect.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/StringExtras.h" -#include "llvm/ADT/StringRef.h" -#include -#include - -namespace mlir { -class ConversionPatternRewriter; -class Location; - -namespace triton { -using llvm::StringRef; - -namespace intel { - -inline std::string strJoin(llvm::ArrayRef strs, - llvm::StringRef delimiter) { - return llvm::join(strs.begin(), strs.end(), delimiter); -} - -} // namespace intel -} // namespace triton -} // namespace mlir - -#endif // TRITON_CONVERSION_TRITON_GPU_TO_LLVM_ASM_FORMAT_H_ diff --git a/third_party/intel/include/TritonIntelGPUToLLVM/PTXAsmFormat.h b/third_party/intel/include/TritonIntelGPUToLLVM/PTXAsmFormat.h deleted file mode 100644 index 8e6cd9488f..0000000000 --- a/third_party/intel/include/TritonIntelGPUToLLVM/PTXAsmFormat.h +++ /dev/null @@ -1,341 +0,0 @@ -#ifndef TRITON_CONVERSION_TRITON_INTEL_GPU_TO_LLVM_PTX_ASM_FORMAT_H_ -#define TRITON_CONVERSION_TRITON_INTEL_GPU_TO_LLVM_PTX_ASM_FORMAT_H_ - -#include "mlir/IR/Value.h" -#include "triton/Dialect/Triton/IR/Dialect.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/StringRef.h" -#include -#include - -namespace mlir { -class ConversionPatternRewriter; -class Location; - -namespace triton { -using llvm::StringRef; - -namespace intel { - -struct PTXInstr; -struct PTXInstrCommon; -struct PTXInstrExecution; - -// PTXBuilder helps to manage a PTX asm program consists of one or multiple -// instructions. -// -// A helper for building an ASM program, the objective of PTXBuilder is to give -// a thin encapsulation and make the ASM code for MLIR LLVM Dialect more clear. -// Currently, several factors are introduced to reduce the need for mixing -// string and C++ if-else code. -// -// Usage: -// To build: @$3 asm("@%3 add.s32 %0, %1, %2;" : "=r"(i) : "r"(j), "r"(k), -// "b"(p)); -// -// PTXBuilder builder; -// auto& add = builder.create<>(); -// add.predicate(pVal).o("lo").o("u32"); // add any suffix -// // predicate here binds %0 to pVal, pVal is a mlir::Value -// -// auto* iOpr = builder.newOperand(iVal, "r"); // %1 bind to iVal -// auto* jOpr = builder.newOperand(jVal, "r"); // %2 bind to jVal -// auto* kOpr = builder.newOperand(kVal, "r"); // %3 bind to kVal -// add(iOpr, jOpr, kOpr).predicate(predVal); // set operands and predicate -// -// To get the asm code: -// builder.dump() -// -// To get all the mlir::Value used in the PTX code, -// -// builder.getAllMlirArgs() // get {pVal, iVal, jVal, kVal} -// -// To get the string containing all the constraints with "," separated, -// builder.getConstraints() // get "=r,r,k" -// -// PTXBuilder can build a PTX asm with multiple instructions, sample code: -// -// PTXBuilder builder; -// auto& mov = builder.create("mov"); -// auto& cp = builder.create("cp"); -// mov(...); -// cp(...); -// This will get a PTX code with two instructions. -// -// Similar to a C function, a declared PTXInstr instance can be launched -// multiple times with different operands, e.g. 
-// -// auto& mov = builder.create("mov"); -// mov(... some operands ...); -// mov(... some different operands ...); -// -// Finally, we will get a PTX code with two mov instructions. -// -// There are several derived instruction type for typical instructions, for -// example, the PtxIOInstr for ld and st instructions. -struct PTXBuilder { - struct Operand { - std::string constraint; - Value value; - int idx{-1}; - llvm::SmallVector list; - std::function repr; - - // for list - Operand() = default; - Operand(const Operation &) = delete; - Operand(Value value, StringRef constraint) - : constraint(constraint), value(value) {} - - bool isList() const { return !value && constraint.empty(); } - - Operand *listAppend(Operand *arg) { - list.push_back(arg); - return this; - } - - Operand *listGet(size_t nth) const { - assert(nth < list.size()); - return list[nth]; - } - - std::string dump() const; - }; - - template - INSTR *create(Args &&...args) { - instrs.emplace_back(std::make_unique(this, args...)); - return static_cast(instrs.back().get()); - } - - // Create a list of operands. - Operand *newListOperand() { return newOperand(); } - - Operand *newListOperand(ArrayRef> items) { - auto *list = newOperand(); - for (auto &item : items) { - list->listAppend(newOperand(item.first, item.second)); - } - return list; - } - - Operand *newListOperand(unsigned count, mlir::Value val, - const std::string &constraint) { - auto *list = newOperand(); - for (unsigned i = 0; i < count; ++i) { - list->listAppend(newOperand(val, constraint)); - } - return list; - } - - Operand *newListOperand(unsigned count, const std::string &constraint) { - auto *list = newOperand(); - for (unsigned i = 0; i < count; ++i) { - list->listAppend(newOperand(constraint)); - } - return list; - } - - // Create a new operand. It will not add to operand list. - // @value: the MLIR value bind to this operand. - // @constraint: ASM operand constraint, .e.g. "=r" - // @formatter: extra format to represent this operand in ASM code, default is - // "%{0}".format(operand.idx). - Operand *newOperand(mlir::Value value, StringRef constraint, - std::function formatter = nullptr); - - // Create a new operand which is written to, that is, the constraint starts - // with "=", e.g. "=r". - // If the operand will be used in predicated execution, - // users may want to initialize it before use. - // Otherwise if the register is only used in the true branch or the false - // branch but not both, the register is undefined and ptxas can perform - // aggressive optimizations that may lead to incorrect results. - Operand *newOperand(StringRef constraint, bool init = false); - - // Create a new operand that is tied to a previous operand. In this case the - // asm would be permitted to write to an input register. Instead of providing - // constraint code for this operand, the constraint code of the tied operand - // is used. - Operand *newOperand(unsigned operandIndex); - - // Create a constant integer operand. - Operand *newConstantOperand(int64_t v); - // Create a constant operand with explicit code specified. 
- Operand *newConstantOperand(const std::string &v); - - Operand *newAddrOperand(mlir::Value addr, StringRef constraint, int off = 0); - - llvm::SmallVector getAllArgs() const; - - llvm::SmallVector getAllMLIRArgs() const; - - std::string getConstraints() const; - - std::string dump() const; - - mlir::Value launch(OpBuilder &rewriter, Location loc, Type resTy, - bool hasSideEffect = true, bool isAlignStack = false, - ArrayRef attrs = {}) const; - -private: - Operand *newOperand() { - argArchive.emplace_back(std::make_unique()); - return argArchive.back().get(); - } - - void initOperand(Operand *opr); - - // Make the operands in argArchive follow the provided \param order. - void reorderArgArchive(ArrayRef order) { - assert(order.size() == argArchive.size()); - // The order in argArchive is unnecessary when onlyAttachMLIRArgs=false, but - // it does necessary when onlyAttachMLIRArgs is true for the $0, $1... are - // determined by PTX code snippet passed from external. - sort(argArchive.begin(), argArchive.end(), - [&](std::unique_ptr &a, std::unique_ptr &b) { - auto ida = std::find(order.begin(), order.end(), a.get()); - auto idb = std::find(order.begin(), order.end(), b.get()); - assert(ida != order.end()); - assert(idb != order.end()); - return ida < idb; - }); - } - - friend struct PTXInstr; - friend struct PTXInstrCommon; - -protected: - llvm::SmallVector, 6> argArchive; - llvm::SmallVector, 2> instrs; - llvm::SmallVector, 4> executions; - int oprCounter{}; -}; - -// PTX instruction common interface. -// Put the generic logic for all the instructions here. -struct PTXInstrCommon { - explicit PTXInstrCommon(PTXBuilder *builder) : builder(builder) {} - - using Operand = PTXBuilder::Operand; - - // clang-format off - PTXInstrExecution& operator()() { return call({}); } - PTXInstrExecution& operator()(Operand* a) { return call({a}); } - PTXInstrExecution& operator()(Operand* a, Operand* b) { return call({a, b}); } - PTXInstrExecution& operator()(Operand* a, Operand* b, Operand* c) { return call({a, b, c}); } - PTXInstrExecution& operator()(Operand* a, Operand* b, Operand* c, Operand* d) { return call({a, b, c, d}); } - PTXInstrExecution& operator()(Operand* a, Operand* b, Operand* c, Operand* d, Operand * e) { return call({a, b, c, d, e}); } - PTXInstrExecution& operator()(Operand* a, Operand* b, Operand* c, Operand* d, Operand * e, Operand* f) { return call({a, b, c, d, e, f}); } - PTXInstrExecution& operator()(Operand* a, Operand* b, Operand* c, Operand* d, Operand * e, Operand* f, Operand* g) { return call({a, b, c, d, e, f, g}); } - // clang-format on - - // Set operands of this instruction. - PTXInstrExecution &operator()(llvm::ArrayRef oprs, - bool onlyAttachMLIRArgs = false); - -protected: - // "Call" the instruction with operands. - // \param oprs The operands of this instruction. - // \param onlyAttachMLIRArgs Indicate that it simply attach the MLIR Arguments - // to the inline Asm without generating the operand ids(such as $0, $1) in PTX - // code. - PTXInstrExecution &call(llvm::ArrayRef oprs, - bool onlyAttachMLIRArgs = false); - - PTXBuilder *builder{}; - llvm::SmallVector instrParts; - - friend struct PTXInstrExecution; -}; - -template struct PTXInstrBase : public PTXInstrCommon { - using Operand = PTXBuilder::Operand; - - explicit PTXInstrBase(PTXBuilder *builder, const std::string &name) - : PTXInstrCommon(builder) { - o(name); - } - - // Append a suffix to the instruction. - // e.g. PTXInstr("add").o("s32") get a add.s32. 
- // A predicate is used to tell whether to apply the suffix, so that no if-else - // code needed. e.g. `PTXInstr("add").o("s32", isS32).o("u32", !isS32);` will - // get a `add.s32` if isS32 is true. - ConcreteT &o(const std::string &suffix, bool predicate = true) { - if (predicate) - instrParts.push_back(suffix); - return *static_cast(this); - } -}; - -struct PTXInstr : public PTXInstrBase { - using PTXInstrBase::PTXInstrBase; - - // Append a ".global" to the instruction. - PTXInstr &global(); - - // Append a ".shared" to the instruction. - PTXInstr &shared(); - - // Append a ".v[0-9]+" to the instruction - PTXInstr &v(int vecWidth, bool predicate = true); - - // Append a".b[0-9]+" to the instruction - PTXInstr &b(int width); -}; - -// Record the operands and context for "launching" a PtxInstr. -struct PTXInstrExecution { - using Operand = PTXBuilder::Operand; - - llvm::SmallVector argsInOrder; - - PTXInstrExecution() = default; - explicit PTXInstrExecution(PTXInstrCommon *instr, - llvm::ArrayRef oprs, - bool onlyAttachMLIRArgs) - : argsInOrder(oprs.begin(), oprs.end()), instr(instr), - onlyAttachMLIRArgs(onlyAttachMLIRArgs) {} - - // Prefix a predicate to the instruction. - PTXInstrExecution &predicate(mlir::Value value, StringRef constraint = "b") { - pred = instr->builder->newOperand(value, constraint); - return *this; - } - - // Prefix a !predicate to the instruction. - PTXInstrExecution &predicateNot(mlir::Value value, StringRef constraint) { - pred = instr->builder->newOperand(value, constraint); - pred->repr = [](int idx) { return "@!$" + std::to_string(idx); }; - return *this; - } - - std::string dump() const; - - SmallVector getArgList() const; - - PTXInstrCommon *instr{}; - Operand *pred{}; - bool onlyAttachMLIRArgs{}; -}; - -/// ====== Some instruction wrappers ====== -// We add the wrappers to make the usage more intuitive by avoiding mixing the -// PTX code with some trivial C++ code. - -struct PTXCpAsyncLoadInstr : PTXInstrBase { - explicit PTXCpAsyncLoadInstr(PTXBuilder *builder, - triton::CacheModifier modifier) - : PTXInstrBase(builder, "cp.async") { - o(triton::stringifyCacheModifier(modifier).str()); - o("shared"); - o("global"); - } -}; - -} // namespace intel -} // namespace triton -} // namespace mlir - -#endif diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp index ee3d19b62f..fa485ca650 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp @@ -141,10 +141,18 @@ class LoadStorePrefetchOpConversion Location loc = op.getLoc(); Value bytes = i32_val(tensorType.getElementType().getIntOrFloatBitWidth() / 8); - Value one = i32_val(1); - Value surfaceW = sub(mul(trunc(i32_ty, ptrOp.getShape()[1]), bytes), one); - Value surfaceH = sub(trunc(i32_ty, ptrOp.getShape()[0]), one); - Value surfaceP = sub(mul(trunc(i32_ty, ptrOp.getStrides()[0]), bytes), one); + + auto calculateSurface = [&](Value shape, bool multiplyBytes) { + Value truncatedShape = trunc(i32_ty, shape); + if (multiplyBytes) + truncatedShape = mul(truncatedShape, bytes); + return sub(truncatedShape, i32_val(1)); + }; + + Value surfaceW = calculateSurface(ptrOp.getShape()[1], true); + Value surfaceH = calculateSurface(ptrOp.getShape()[0], false); + Value surfaceP = calculateSurface(ptrOp.getStrides()[0], true); + rewriter.restoreInsertionPoint(insertPoint); Value tensorPtr = adaptor.getPtr();