diff --git a/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h b/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h index d6530b0933..c5c78e6d5b 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h +++ b/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h @@ -25,7 +25,6 @@ namespace triton { constexpr int patternBenefitDefault = 1; constexpr int patternBenefitPrioritizeOverLLVMConversions = 10; constexpr int patternBenefitClampOptimizedPattern = 20; -constexpr int patternBenefitConvertLayoutOptimizedPattern = 20; struct BackendCallbacks { /** @@ -50,7 +49,7 @@ void populateElementwiseOpToLLVMPatterns( // callback receives 1) the current source op, 2) the number of issued LLVM // instructions and 3) their input types. Each MLIR backend can provide a // callback and, thus, handle backend-specific behaviors. -void populateMemoryOpToLLVMPattern( +void populateMemoryOpToLLVMPatterns( LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, RewritePatternSet &patterns, PatternBenefit benefit, std::optional backendCallbacks = std::nullopt); @@ -102,10 +101,6 @@ void populateConvertLayoutOpToLLVMPatterns(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, PatternBenefit benefit); -void populateConvertLayoutOpUsingLinearLayoutsToLLVMPattern( - LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, - RewritePatternSet &patterns, PatternBenefit benefit); - void populateControlFlowOpToLLVMPattern(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, const TargetInfoBase &targetInfo, diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index bd2aca5fb3..0ee9236071 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -720,8 +720,7 @@ bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy, auto ans = mmaLayout.getVersionMajor() == 3 && dotOperandLayout.getOpIdx() == 0 && mmaLayout.getWarpsPerCTA()[1] == 1 && - !cvtNeedsSharedMemory(parentTy, srcTy) && - (elementTypeSize == 16 || elementTypeSize == 8) && + !cvtNeedsSharedMemory(parentTy, srcTy) && elementTypeSize == 8 && dotOperandLayout.getKWidth() == 32 / elementTypeSize; return ans; } diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp index fab60442d6..378d871201 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -16,11 +16,7 @@ namespace { -using ::mlir::LLVM::getMultiDimOffset; -using ::mlir::LLVM::getSharedMemoryObjectFromStruct; -using ::mlir::LLVM::getWrappedMultiDimOffset; -using ::mlir::LLVM::linearize; - +using namespace mlir; using namespace mlir::triton::gpu; // XXX(Keren): A temporary knob to control the use of legacy MMA conversion @@ -105,13 +101,14 @@ struct ConvertLayoutOpConversion // of performance issue observed. 
for (unsigned elemId = 0; elemId < accumSizePerThread; elemId += vec) { SmallVector multiDimOffset = - getMultiDimOffset(layout, loc, rewriter, targetInfo, elemId, type, - multiDimCTAInRepId, shapePerCTATile); - SmallVector multiDimOffsetWrapped = getWrappedMultiDimOffset( - rewriter, loc, multiDimOffset, origRepShape, shapePerCTATile, - shapePerCTA); - Value offset = linearize(rewriter, loc, multiDimOffsetWrapped, - paddedRepShape, outOrd); + LLVM::getMultiDimOffset(layout, loc, rewriter, targetInfo, elemId, + type, multiDimCTAInRepId, shapePerCTATile); + SmallVector multiDimOffsetWrapped = + LLVM::getWrappedMultiDimOffset(rewriter, loc, multiDimOffset, + origRepShape, shapePerCTATile, + shapePerCTA); + Value offset = LLVM::linearize(rewriter, loc, multiDimOffsetWrapped, + paddedRepShape, outOrd); auto elemPtrTy = smemBase.getType(); Value ptr = gep(elemPtrTy, llvmElemTy, smemBase, offset); auto vecTy = vec_ty(llvmElemTy, vec); @@ -267,7 +264,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion // conversions. TODO(jlebar): Eventually we want this to be the only pattern. explicit ConvertLayoutOpUsingLinearLayoutsConversion( LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, - PatternBenefit benefit = 2) + PatternBenefit benefit = 1) : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) { } @@ -395,16 +392,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion if (!layoutIsOK(srcTy.getEncoding()) || !layoutIsOK(dstTy.getEncoding())) { return failure(); } - // FIXME [Dot LL] Remove this once we implement this trick in LLs - if (matchMmaV3AndDotOperandLayout(srcTy, dstTy)) { - return failure(); - } - - // The following check can be removed when generalized warp shuffle - // conversions are ready: - if (matchMFMAAndDotOperandShuffleCase(srcTy, dstTy)) { - return failure(); - } assert(cvtNeedsSharedMemory(srcTy, dstTy)); @@ -666,22 +653,17 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion } // namespace -void mlir::triton::populateConvertLayoutOpUsingLinearLayoutsToLLVMPattern( +void mlir::triton::populateConvertLayoutOpToLLVMPatterns( LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, RewritePatternSet &patterns, PatternBenefit benefit) { + if (useLegacyMMAConversion) { + // Prioritize the legacy MMA conversion over the LinearLayout conversion. + // Only for debugging purposes. + patterns.add(typeConverter, targetInfo, + benefit.getBenefit() + 1); + } patterns.add( typeConverter, targetInfo, benefit); -} - -void mlir::triton::populateConvertLayoutOpToLLVMPatterns( - LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, - RewritePatternSet &patterns, PatternBenefit benefit) { - // We prefer using the linear layout conversion, so it gets a higher benefit. - // Eventually the LL conversion will subsume all of the others and be the only - // one left. 
- mlir::triton::populateConvertLayoutOpUsingLinearLayoutsToLLVMPattern( - typeConverter, targetInfo, patterns, benefit.getBenefit() + 1); patterns.add( typeConverter, targetInfo, benefit); - patterns.add(typeConverter, targetInfo, benefit); } diff --git a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp index 8911b60748..2b7026eaee 100644 --- a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp @@ -121,33 +121,12 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern { // FIXME [Dot LL] // Do for all DotOperandEncodingAttr once we have LLs for all of them - static bool isSupportedDotOpLayout(MemDescType srcTy, - RankedTensorType dstTy) { - auto srcLayout = cast(srcTy.getEncoding()); - auto dstLayout = dstTy.getEncoding(); - auto bitwidth = dstTy.getElementTypeBitWidth(); - auto rank = dstTy.getRank(); + static bool isSupportedLayout(Attribute dstLayout) { + if (isa(dstLayout)) + return true; if (auto dot = dyn_cast(dstLayout)) { - auto vecWidth = 32 / bitwidth; - auto kWidth = dot.getKWidth(); - auto kOrder = dot.getOpIdx() == 0 ? rank - 1 : rank - 2; - if (auto mma = dyn_cast(dot.getParent())) { - auto needTrans = kOrder != srcLayout.getOrder()[0]; - auto canUseLdmatrix = - (bitwidth == 16 || (!needTrans)) && (kWidth == vecWidth); - if (mma.isHopper()) { - // I think we should be able to remove this condition, but it's here - // as the legacy ldmatrix path does not support it - canUseLdmatrix &= srcTy.getElementTypeBitWidth() * kWidth == 32; - } - // If we remove this one, ldmatrix will IMA. It can probably be relaxed - // though - canUseLdmatrix &= - srcTy.getShape()[0] >= 8 && - srcTy.getShape()[1] >= 4 * kWidth & dstTy.getRank() <= 2; - return !canUseLdmatrix; - } - if (isa(dot.getParent())) + if (isa(dot.getParent())) return true; } return false; @@ -156,12 +135,9 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern { LogicalResult matchAndRewrite(LocalLoadOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - MemDescType srcTy = op.getSrc().getType(); RankedTensorType dstTy = op.getType(); Attribute dstLayout = dstTy.getEncoding(); - if (isa(dstLayout) || - isSupportedDotOpLayout(srcTy, dstTy)) { + if (isSupportedLayout(dstLayout)) { return lowerSharedToDistributed(op, adaptor, getTypeConverter(), rewriter); } @@ -198,11 +174,6 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern { auto loc = op.getLoc(); auto srcTy = op.getSrc().getType(); auto dstTy = op.getResult().getType(); - auto dstShape = dstTy.getShape(); - auto srcSharedLayout = cast(srcTy.getEncoding()); - assert((!isa(dstTy.getEncoding()) || - isSupportedDotOpLayout(srcTy, dstTy)) && - "Unexpected rank of ConvertLayout(shared->distributed)"); auto smemObj = LLVM::getSharedMemoryObjectFromStruct( loc, adaptor.getSrc(), @@ -265,7 +236,7 @@ struct LocalStoreOpConversion } // namespace -void mlir::triton::populateMemoryOpToLLVMPattern( +void mlir::triton::populateMemoryOpToLLVMPatterns( LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, RewritePatternSet &patterns, PatternBenefit benefit, std::optional backendCallbacks) { diff --git a/lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp b/lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp index d9fb1d7e17..b6728d22b4 100644 --- a/lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp +++ b/lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp @@ -17,9 +17,11 @@ namespace { // dot(a, b, inputPrecision="tf32x3") -> // let aBig 
= f32ToTF32(a), aSmall = a - aBig; // let bBig = f32ToTF32(b), bSmall = b - bBig; -// dot(aSmall, bBig, inputPrecision="tf32") + -// dot(aBig, bSmall, inputPrecision="tf32") + -// dot(aBig, bBig, inputPrecision="tf32") +// let small = dot(aSmall, bBig, inputPrecision="tf32") + +// dot(aBig, bSmall, inputPrecision="tf32") +// let masked_nans = replaceNansWithZeros(small) +// let big = dot(aBig, bBig, inputPrecision="tf32") +// return big + masked_nans; class TF32x3 : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; @@ -62,6 +64,13 @@ class TF32x3 : public OpRewritePattern { InputPrecision::TF32, dotOp.getMaxNumImpreciseAcc()); }; + auto replaceNansWithZeros = [&](Value value) -> Value { + auto nans = rewriter.create( + dotOp->getLoc(), arith::CmpFPredicate::UNO, value, value); + auto zero = zeroLike(value); + return rewriter.create(dotOp->getLoc(), nans, zero, + value); + }; auto aBig = f32ToTF32(dotOp.getA()); auto aSmall = sub(dotOp.getA(), aBig); @@ -73,7 +82,16 @@ class TF32x3 : public OpRewritePattern { auto dot1 = dot(aSmall, bBig, zero); auto dot2 = dot(aBig, bSmall, dot1); - auto dot3 = dot(aBig, bBig, dot2); + + // If lhs is 1.0, we will have lhs_high = 1.0 and lhs_low = 0.0. + // If rhs is +infinity, we will have: + // +infinity * 1.0 = +infinity + // +infinity * 0.0 = NaN + // We would get the wrong result if we sum these partial products. Instead, + // we must override any accumulated result if the last partial product is + // non-finite. + auto dot2withZeroedNans = replaceNansWithZeros(dot2); + auto dot3 = dot(aBig, bBig, dot2withZeroedNans); auto sum = add(dot3, dotOp.getC()); diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index 01e8acf258..401cf193ef 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -399,7 +399,10 @@ struct MMAV3UseRegOperand dotOp.getContext(), /*opIdx=*/0, srcEnc, /*kWidth=*/kWidth); auto newTy = RankedTensorType::get(srcTy.getShape(), srcTy.getElementType(), dotOperandEnc); - if (!matchMmaV3AndDotOperandLayout(srcTy, newTy)) + // TODO(Keren): relax the condition once + // https://github.com/triton-lang/triton/pull/5419 is merged + if (!cvtReordersRegisters(srcTy, newTy) && + !matchMmaV3AndDotOperandLayout(srcTy, newTy)) return failure(); Value newOperand = diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp index 3ab65c6105..a122415ae1 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp @@ -103,7 +103,8 @@ Operation *mlir::triton::predicateOp(RewriterBase &rewriter, Operation *op, return op; } - assert("don't know how to predicate this op" && false); + op->emitError("pipeliner doesn't know how to predicate this op."); + llvm::report_fatal_error("Fatal pipeliner error"); return op; } diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index a12d0641e4..345178ccf6 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -1690,6 +1690,48 @@ def change_value(X, BLOCK_SIZE: tl.constexpr, sem: tl.constexpr): assert (torch.equal(X, Y)) +@pytest.mark.interpreter +@pytest.mark.skipif((is_cuda() and torch.cuda.get_device_capability()[0] < 9) or is_hip(), + reason="Requires compute capability >= 9 for NV") 
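Note: the F32DotTC.cpp hunk above documents, in comments, why the low-order partial products must have their NaNs zeroed before being folded into dot(aBig, bBig). The short NumPy sketch below is an illustration only, not code from this patch: it collapses the chained dot accumulators into explicit adds, and the f32_to_tf32 truncation helper is an assumed approximation of the hardware TF32 rounding.

# Illustrative sketch (not part of the patch): emulate the tf32x3 split and
# show why NaNs in the low-order partial products must be masked to zero.
import numpy as np

def f32_to_tf32(x):
    # Assumed approximation: TF32 keeps 10 mantissa bits, so drop the low
    # 13 bits of the float32 significand by truncation.
    bits = np.asarray(x, dtype=np.float32).view(np.uint32)
    return (bits & np.uint32(0xFFFFE000)).view(np.float32)

def dot_tf32x3(a, b, c):
    a_big = f32_to_tf32(a); a_small = a - a_big
    b_big = f32_to_tf32(b); b_small = b - b_big
    # Low-order partial products, accumulated together.
    small = a_small @ b_big + a_big @ b_small
    # With an exactly representable lhs (a_small == 0) and an infinite rhs,
    # these products evaluate to 0 * inf = NaN (and b - b_big is inf - inf =
    # NaN), even though a_big @ b_big already carries the correct +/-inf.
    # Zeroing the NaNs keeps the final sum correct.
    small = np.where(np.isnan(small), np.float32(0.0), small)
    return a_big @ b_big + small + c

a = np.array([[1.0, 2.0]], dtype=np.float32)
b = np.array([[np.inf], [3.0]], dtype=np.float32)
c = np.zeros((1, 1), dtype=np.float32)
with np.errstate(invalid="ignore"):
    print(dot_tf32x3(a, b, c))  # [[inf]] rather than [[nan]]

Removing the np.where line reproduces the failure the comment describes: the NaN from the low-order products contaminates the sum and the result becomes NaN instead of +inf.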
+def test_load_scope_sem_coop_grid_cta_not_one(device): + + @triton.jit + def kernel_r(ptrs, BLOCK_SIZE: tl.constexpr): + numel = 512 + offset = tl.program_id(0) * BLOCK_SIZE + index = offset + mask = index < numel + a = tl.load(ptrs, mask=mask) + tl.store(ptrs, a) + + block_size = 128 + data = torch.zeros((128, ), device=device, dtype=torch.float32) + + out = kernel_r[(2, )](data, BLOCK_SIZE=block_size, num_ctas=4, launch_cooperative_grid=True) + out = kernel_r[(2, )](data, BLOCK_SIZE=block_size, num_ctas=4, launch_cooperative_grid=False) + + +@pytest.mark.interpreter +@pytest.mark.skipif(is_hip(), reason="Not implemented for AMD At this moment") +def test_load_scope_sem_coop_grid_cta_one(device): + + @triton.jit + def kernel_r(ptrs, BLOCK_SIZE: tl.constexpr): + numel = 512 + offset = tl.program_id(0) * BLOCK_SIZE + index = offset + mask = index < numel + a = tl.load(ptrs, mask=mask) + tl.store(ptrs, a) + + block_size = 128 + data = torch.zeros((128, ), device=device, dtype=torch.float32) + + # Should do nothing different for num_ctas=1 (with coop launch grid) + out = kernel_r[(2, )](data, BLOCK_SIZE=block_size, num_ctas=1, launch_cooperative_grid=True) + out = kernel_r[(2, )](data, BLOCK_SIZE=block_size, num_ctas=1, launch_cooperative_grid=False) + + # --------------- # test cast # --------------- diff --git a/python/triton/runtime/autotuner.py b/python/triton/runtime/autotuner.py index 573d9d4191..cf32451c5e 100644 --- a/python/triton/runtime/autotuner.py +++ b/python/triton/runtime/autotuner.py @@ -4,7 +4,7 @@ import os import time import inspect -from typing import Dict +from typing import Dict, Tuple, List, Optional from .jit import KernelInterface from .errors import OutOfResources, PTXASError @@ -23,7 +23,7 @@ def __init__( restore_value, pre_hook=None, post_hook=None, - prune_configs_by: Dict = None, + prune_configs_by: Optional[Dict] = None, warmup=None, rep=None, use_cuda_graph=False, @@ -40,7 +40,7 @@ def __init__( else: self.configs = configs self.keys = key - self.cache = {} + self.cache: Dict[Tuple, Config] = {} self.arg_names = arg_names # Reset to zero or restore values @@ -211,7 +211,7 @@ def run(self, *args, **kwargs): self.nargs = None return ret - def prune_configs(self, kwargs): + def prune_configs(self, kwargs: Dict) -> List[Config]: pruned_configs = self.configs if self.early_config_prune: pruned_configs = self.early_config_prune(self.configs, self.nargs, **kwargs) @@ -219,6 +219,10 @@ def prune_configs(self, kwargs): top_k = self.configs_top_k if isinstance(top_k, float) and top_k <= 1.0: top_k = int(len(self.configs) * top_k) + elif not isinstance(top_k, int): + # Slice index must be an integer + raise TypeError(f"Error while pruning configs, top_k must be either 1) a float <= 1.0 or 2) an int") + if len(pruned_configs) > top_k: est_timing = { config: self.perf_model( diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index 70cdb612ae..07b82df414 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -501,7 +501,7 @@ def _call_hook( name = self.fn.__name__ module = self.fn.__module__ arg_reprs = ", ".join([f"{param.name}: {ty}" for param, ty in zip(self.params, key[1])]) - repr = f"{name}[num_warps={options.num_warps}, num_ctas={options.num_ctas}, num_stages={options.num_stages}, enable_fp_fusion={options.enable_fp_fusion}]({arg_reprs})" + repr = f"{name}[num_warps={options.num_warps}, num_ctas={options.num_ctas}, num_stages={options.num_stages}, enable_fp_fusion={options.enable_fp_fusion}, 
launch_cooperative_grid={options.launch_cooperative_grid}]({arg_reprs})" class JitFunctionInfo: @@ -521,6 +521,7 @@ def __init__(self, module, name, jit_function): 'num_ctas': options.num_ctas, 'num_stages': options.num_stages, 'enable_fp_fusion': options.enable_fp_fusion, + 'launch_cooperative_grid': options.launch_cooperative_grid, 'extern_libs': options.extern_libs, 'configs': configs, 'specialization_data': specialization_data, diff --git a/python/tutorials/09-persistent-matmul.py b/python/tutorials/09-persistent-matmul.py index 4b4a08857d..eec0c6248c 100644 --- a/python/tutorials/09-persistent-matmul.py +++ b/python/tutorials/09-persistent-matmul.py @@ -50,8 +50,6 @@ def _matmul_launch_metadata(grid, kernel, args): ret = {} M, N, K = args["M"], args["N"], args["K"] ret["name"] = f"{kernel.name} [M={M}, N={N}, K={K}]" - if "tiles_per_update" in args: - ret["name"] = f"{kernel.name} [M={M}, N={N}, K={K}, tiles_per_update={args['tiles_per_update']:02}]" if "c_ptr" in args: bytes_per_elem = args["c_ptr"].element_size() else: @@ -376,8 +374,7 @@ def matmul_tma_persistent(a, b): @triton.jit(launch_metadata=_matmul_launch_metadata) -def matmul_kernel_descriptor_persistent(tiles_per_update: tl.constexpr, # - a_ptr, b_ptr, c_ptr, # +def matmul_kernel_descriptor_persistent(a_ptr, b_ptr, c_ptr, # M, N, K, # BLOCK_SIZE_M: tl.constexpr, # BLOCK_SIZE_N: tl.constexpr, # @@ -417,7 +414,6 @@ def matmul_kernel_descriptor_persistent(tiles_per_update: tl.constexpr, # tile_id = start_pid - NUM_SMS ki = -1 - ni = -1 pid_m = 0 pid_n = 0 @@ -427,36 +423,10 @@ def matmul_kernel_descriptor_persistent(tiles_per_update: tl.constexpr, # num_pid_in_group = GROUP_SIZE_M * num_pid_n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - # Create an opaque value to prevent the descriptor creation from being - # hoisted out of the loop - zero = tl.inline_asm_elementwise("mov.b32 $0, 0;", "=r", [], dtype=tl.int32, is_pure=True, pack=1) for _ in range(0, k_tiles * tiles_per_SM): ki = tl.where(ki == k_tiles - 1, 0, ki + 1) if ki == 0: - ni += 1 - - # Simulate a grouped gemm - if ni == tiles_per_update: - a_desc = tl._experimental_make_tensor_descriptor( - a_ptr + zero, - shape=[M, K], - strides=[K, 1], - block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K], - ) - b_desc = tl._experimental_make_tensor_descriptor( - b_ptr + zero, - shape=[N, K], - strides=[K, 1], - block_shape=[BLOCK_SIZE_N, BLOCK_SIZE_K], - ) - c_desc = tl._experimental_make_tensor_descriptor( - c_ptr + zero, - shape=[M, N], - strides=[N, 1], - block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N], - ) - ni = 0 tile_id += NUM_SMS group_id = tile_id // num_pid_in_group @@ -482,8 +452,7 @@ def matmul_kernel_descriptor_persistent(tiles_per_update: tl.constexpr, # accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) -def matmul_descriptor_persistent(a, b, tiles_per_update): - # Autotuner does not work with TMA. Use manual config. 
+def matmul_descriptor_persistent(a, b): configs = { torch.float8_e4m3fn: { "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8, "num_stages": 4, @@ -513,7 +482,6 @@ def alloc_fn(size: int, alignment: int, stream: Optional[int]): grid = lambda META: (min(NUM_SMS, triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"])), ) matmul_kernel_descriptor_persistent[grid]( - tiles_per_update, # a, b, c, # M, N, K, # BLOCK_SIZE_M=configs[dtype]["BLOCK_SIZE_M"], # @@ -570,7 +538,7 @@ def bench_fn(reps, warmup_reps, fn, *args): fn(*args) -def bench(K, dtype, tiles_per_update, reps=1000, warmup_reps=10000): +def bench(K, dtype, reps=1000, warmup_reps=10000): M = 8192 N = 8192 a = torch.randn((M, K), device="cuda", dtype=torch.float16).to(dtype) @@ -586,10 +554,10 @@ def bench(K, dtype, tiles_per_update, reps=1000, warmup_reps=10000): bench_fn(reps, warmup_reps, matmul_persistent, a, b.T) if supports_tma(): bench_fn(reps, warmup_reps, matmul_tma_persistent, a, b) - bench_fn(reps, warmup_reps, matmul_descriptor_persistent, a, b, tiles_per_update) + bench_fn(reps, warmup_reps, matmul_descriptor_persistent, a, b) -def validate(M, N, K, dtype, tiles_per_update): +def validate(M, N, K, dtype): a = torch.randn((M, K), device="cuda", dtype=torch.float16).to(dtype) b = torch.randn((K, N), device="cuda", dtype=torch.float16).to(dtype) b = b.T.contiguous() @@ -599,7 +567,7 @@ def validate(M, N, K, dtype, tiles_per_update): naive_result = matmul(a, b.T) persistent_result = matmul_persistent(a, b.T) tma_persistent_result = matmul_tma_persistent(a, b) if supports_tma() else None - descriptor_persistent_result = matmul_descriptor_persistent(a, b, tiles_per_update) if supports_tma() else None + descriptor_persistent_result = matmul_descriptor_persistent(a, b) if supports_tma() else None if torch_result is not None: naive_vs_torch = "✅" if torch.allclose(naive_result.to(torch.float16), torch_result.to(torch.float16), @@ -624,7 +592,7 @@ def validate(M, N, K, dtype, tiles_per_update): if tma_persistent_result is not None: print(f"TMA persistent: {naive_vs_tma_persistent} ", end="") if descriptor_persistent_result is not None: - print(f"Device TMA persistent: {naive_vs_descriptor_persistent} ", end="") + print(f"Tensor descriptor persistent: {naive_vs_descriptor_persistent} ", end="") print() @@ -644,13 +612,6 @@ def show_profile(precision, profile_name): parser.add_argument("-K", type=int, required=False, default=512) parser.add_argument("--K_range", type=int, nargs=2) parser.add_argument("--K_step", type=int, default=512) - parser.add_argument( - "--tiles_per_update", - type=int, - default=1, - help= - "Number of output tiles calculated for each update of the tma descriptor in matmul_descriptor_persistent_kernel", - ) parser.add_argument("--prec", type=str, choices=["fp8", "fp16"], default="fp16") args = parser.parse_args() @@ -666,11 +627,11 @@ def show_profile(precision, profile_name): torch.manual_seed(0) - validate(32, 32, 32, dtype, args.tiles_per_update) - validate(8192, 8192, 512, dtype, args.tiles_per_update) + validate(32, 32, 32, dtype) + validate(8192, 8192, 512, dtype) proton.start("matmul", hook="triton") for K in range(args.K_range[0], args.K_range[1] + 1, args.K_step): - bench(K, dtype, args.tiles_per_update) + bench(K, dtype) proton.finalize() show_profile(args.prec, "matmul") diff --git a/test/Conversion/nvgpu_to_llvm.mlir b/test/Conversion/nvgpu_to_llvm.mlir index beaa4c952d..c8917c2b0f 100644 --- a/test/Conversion/nvgpu_to_llvm.mlir +++ 
b/test/Conversion/nvgpu_to_llvm.mlir @@ -37,15 +37,30 @@ llvm.func @cluster_id() -> i32 { // ----- -// CHECK-LABEL: @st_matrix -llvm.func @st_matrix(%i: i32, %ptr: !llvm.ptr<3>) { +// CHECK-LABEL: @stmatrix +llvm.func @stmatrix(%i: i32, %ptr: !llvm.ptr<3>) { // CHECK: stmatrix.sync.aligned.m8n8.x4.shared.b16 [$0], {$1, $2, $3, $4}; nvgpu.stmatrix %ptr, %i, %i, %i, %i : !llvm.ptr<3>, i32, i32, i32, i32 + // CHECK: stmatrix.sync.aligned.m8n8.x4.trans.shared.b16 [$0], {$1, $2, $3, $4}; + nvgpu.stmatrix %ptr, %i, %i, %i, %i {trans} : !llvm.ptr<3>, i32, i32, i32, i32 llvm.return } // ----- +// CHECK-LABEL: @ldmatrix +llvm.func @ldmatrix(%ptr: !llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> { + // CHECK: ldmatrix.sync.aligned.m8n8.x4.shared.b16 {$0, $1, $2, $3}, [$4]; + %0 = nvgpu.ldmatrix %ptr : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> + // CHECK: ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {$0, $1, $2, $3}, [$4]; + %1 = nvgpu.ldmatrix %ptr {trans} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> + %2 = llvm.extractvalue %1[0] : !llvm.struct<(i32, i32, i32, i32)> + %3 = llvm.insertvalue %2, %0[0] : !llvm.struct<(i32, i32, i32, i32)> + llvm.return %3 : !llvm.struct<(i32, i32, i32, i32)> +} + +// ----- + !struct_128xf32 = !llvm.struct<( f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index ae20408c66..bc3f946387 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -845,10 +845,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { tt.func @convert_dot(%A: tensor<16x16xf16, #blocked0>, %B: tensor<16x16xf16, #blocked0>) { %AA = ttg.local_alloc %A : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #smem> %BB = ttg.local_alloc %B : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #smem> - // CHECK: llvm.inline_asm - // CHECK: ldmatrix.sync.aligned.m8n8.x4 - // CHECK: llvm.inline_asm - // CHECK-SAME: ldmatrix.sync.aligned.m8n8.x4 + // CHECK: nvgpu.ldmatrix + // CHECK: nvgpu.ldmatrix %AA_DOT = ttg.local_load %AA : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_a> %BB_DOT = ttg.local_load %BB : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_b> %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma0> @@ -876,8 +874,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { tt.func @convert_dot_fp8(%A: tensor<16x16xf8E5M2, #blocked0>, %B: tensor<16x16xf8E5M2, #blocked0>) { %AA = ttg.local_alloc %A : (tensor<16x16xf8E5M2, #blocked0>) -> !ttg.memdesc<16x16xf8E5M2, #shared0, #smem> %BB = ttg.local_alloc %B : (tensor<16x16xf8E5M2, #blocked0>) -> !ttg.memdesc<16x16xf8E5M2, #shared0, #smem> - // CHECK: llvm.inline_asm - // CHECK-SAME: ldmatrix.sync.aligned.m8n8.x4 + // CHECK: nvgpu.ldmatrix %AA_DOT = ttg.local_load %AA : !ttg.memdesc<16x16xf8E5M2, #shared0, #smem> -> tensor<16x16xf8E5M2, #dot_operand_a> %BB_DOT = ttg.local_load %BB : !ttg.memdesc<16x16xf8E5M2, #shared0, #smem> -> tensor<16x16xf8E5M2, #dot_operand_b> %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma0> @@ -1177,7 +1174,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { tt.func @matmul_kernel_dot_operand_layout(%ptr:!tt.ptr {tt.divisibility = 16 : i32}, %a:!ttg.memdesc<128x32xf16, 
#shared, #smem>, %b:!ttg.memdesc<32x256xf16, #shared, #smem>) { %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma> - // CHECK: ldmatrix.sync.aligned.m8n8.x4.shared.b16 + // CHECK: nvgpu.ldmatrix %a_mat = ttg.local_load %a : !ttg.memdesc<128x32xf16, #shared, #smem> -> tensor<128x32xf16, #dot_operand_a> %b_mat = ttg.local_load %b : !ttg.memdesc<32x256xf16, #shared, #smem> -> tensor<32x256xf16, #dot_operand_b> @@ -1227,11 +1224,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { tt.func @matmul_tf32dot(%ptr:!tt.ptr {tt.divisibility = 16 : i32}, %a:!ttg.memdesc<32x16xf32, #shared, #smem>, %b:!ttg.memdesc<16x32xf32, #shared, #smem>) { %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> - // CHECK: llvm.inline_asm - // CHECK-SAME: ldmatrix.sync.aligned.m8n8.x4.shared.b16 + // CHECK: nvgpu.ldmatrix // CHECK-SAME: (i32, i32, i32, i32) - // CHECK: llvm.inline_asm - // CHECK-SAME: ldmatrix.sync.aligned.m8n8.x4.shared.b16 + // CHECK: nvgpu.ldmatrix // CHECK-SAME: (i32, i32, i32, i32) %a_mat = ttg.local_load %a : !ttg.memdesc<32x16xf32, #shared, #smem> -> tensor<32x16xf32, #dot_operand_a> %b_mat = ttg.local_load %b : !ttg.memdesc<16x32xf32, #shared, #smem> -> tensor<16x32xf32, #dot_operand_b> @@ -1720,10 +1715,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr %f16_shared = ttg.local_alloc %f16_inp : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #smem> %i16_shared = ttg.local_alloc %i16_inp : (tensor<16x16xi16, #blocked0>) -> !ttg.memdesc<16x16xi16, #shared0, #smem> - // CHECK: llvm.inline_asm - // CHECK-SAME: ldmatrix.sync.aligned.m8n8.x4 - // CHECK: llvm.inline_asm - // CHECK-SAME: ldmatrix.sync.aligned.m8n8.x4 + // CHECK: nvgpu.ldmatrix + // CHECK: nvgpu.ldmatrix %f16_dot = ttg.local_load %f16_shared : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_a> %i16_dot = ttg.local_load %i16_shared : !ttg.memdesc<16x16xi16, #shared0, #smem> -> tensor<16x16xi16, #dot_operand_b> diff --git a/test/TritonGPU/tf32x3-matmul.mlir b/test/TritonGPU/tf32x3-matmul.mlir new file mode 100644 index 0000000000..180a5c6331 --- /dev/null +++ b/test/TritonGPU/tf32x3-matmul.mlir @@ -0,0 +1,14 @@ +// RUN: triton-opt %s -split-input-file -tritongpu-F32DotTC -canonicalize | FileCheck %s --check-prefixes=CHECK + +// CHECK: %[[DOT1:.*]] = tt.dot %[[LHS_LOW:.*]], %[[RHS_HIGH:.*]], %cst, inputPrecision = tf32 : tensor<16x16xf32> * tensor<16x16xf32> -> tensor<16x16xf32> +// CHECK: %[[DOT2:.*]] = tt.dot %[[LHS_HIGH:.*]], %[[RHS_LOW:.*]], %[[DOT1]], inputPrecision = tf32 : tensor<16x16xf32> * tensor<16x16xf32> -> tensor<16x16xf32> +// CHECK: %[[CMP:.*]] = arith.cmpf uno, %[[DOT2]], %[[DOT2]] : tensor<16x16xf32> +// CHECK: %[[MASKED:.*]] = arith.select %[[CMP]], %cst, %[[DOT2]] : tensor<16x16xi1>, tensor<16x16xf32> +// CHECK: %[[RESULT:.*]] = tt.dot %[[LHS_HIGH]], %[[RHS_HIGH]], %[[MASKED]], inputPrecision = tf32 : tensor<16x16xf32> * tensor<16x16xf32> -> tensor<16x16xf32> + +module { + tt.func @dot_test(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32>, %arg2: tensor<16x16xf32>) -> tensor<16x16xf32> { + %4 = tt.dot %arg0, %arg1, %arg2, inputPrecision = tf32x3 : tensor<16x16xf32> * tensor<16x16xf32> -> tensor<16x16xf32> + tt.return %4 : tensor<16x16xf32> + } +} diff --git a/third_party/amd/backend/compiler.py b/third_party/amd/backend/compiler.py index e10c5b1bc1..77a1233dbb 100644 --- a/third_party/amd/backend/compiler.py +++ b/third_party/amd/backend/compiler.py @@ -42,6 +42,9 @@ class 
HIPOptions: default_dot_input_precision: str = "ieee" allowed_dot_input_precisions: Tuple[str] = ("ieee", ) enable_fp_fusion: bool = True + # TODO: Implement cooperative grid launch for AMD: + # See: https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html + launch_cooperative_grid: bool = False matrix_instr_nonkdim: int = 0 kpack: int = 1 allow_flush_denorm: bool = False diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt b/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt index abd86dc033..e2465f17b6 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt @@ -4,6 +4,7 @@ add_triton_library(TritonAMDGPUToLLVM ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp ConvertLayoutOpToLLVM/SharedToDotOperandWMMA.cpp ConvertLayoutOpToLLVM.cpp + MemoryOpToLLVM.cpp DotOpToLLVM/MFMA.cpp DotOpToLLVM/WMMA.cpp DotOpToLLVM.cpp diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 3b61fb8cc4..3d88ea981b 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -4,11 +4,9 @@ #include "triton/Conversion/TritonGPUToLLVM/Utility.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" -using ::mlir::LLVM::getSharedMemoryObjectFromStruct; using ::mlir::triton::gpu::AMDMfmaEncodingAttr; using ::mlir::triton::gpu::AMDWmmaEncodingAttr; using ::mlir::triton::gpu::DotOperandEncodingAttr; -using ::mlir::triton::gpu::getTotalElemsPerThread; using ::mlir::triton::gpu::MemDescType; using ::mlir::triton::gpu::SharedEncodingAttr; @@ -29,92 +27,6 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, } // namespace SharedToDotOperandWMMA namespace { -struct LocalLoadOpConversion - : public ConvertOpToLLVMPattern { -public: - using ConvertOpToLLVMPattern< - triton::gpu::LocalLoadOp>::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(triton::gpu::LocalLoadOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - MemDescType srcTy = op.getSrc().getType(); - RankedTensorType dstTy = op.getType(); - Attribute srcLayout = srcTy.getEncoding(); - Attribute dstLayout = dstTy.getEncoding(); - if (isa(dstLayout) && - isa( - cast(dstLayout).getParent())) { - return lowerSharedToDotOperand(op, adaptor, getTypeConverter(), rewriter); - } - return failure(); - } - -private: - /// Lower ttg.local_load in dot operand layout if the operand parent layout is - /// MFMA or WMMA. - /// - /// \returns value with packed loaded values or empty value if this local_load - /// is not supproted. - Value lowerSharedToDotOperandMMA( - triton::gpu::LocalLoadOp op, triton::gpu::LocalLoadOpAdaptor adaptor, - const LLVMTypeConverter *typeConverter, - ConversionPatternRewriter &rewriter, - const DotOperandEncodingAttr &dotOperandLayout, bool isOuter) const { - auto loc = op.getLoc(); - Value src = op.getSrc(); - Value dst = op.getResult(); - auto llvmElemTy = typeConverter->convertType( - cast(src.getType()).getElementType()); - - auto smemObj = getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), - llvmElemTy, rewriter); - Value res; - auto dopOpParent = dotOperandLayout.getParent(); - if (!isOuter && - isa(dopOpParent)) { - auto sharedToDotConvert = isa(dopOpParent) - ? 
SharedToDotOperandMFMA::convertLayout - : SharedToDotOperandWMMA::convertLayout; - res = sharedToDotConvert(dotOperandLayout.getOpIdx(), rewriter, loc, src, - dotOperandLayout, smemObj, typeConverter, - tid_val()); - } else { - assert(false && "unsupported layout found"); - } - return res; - } - - // shared -> matrix_core_dot_operand - LogicalResult - lowerSharedToDotOperand(triton::gpu::LocalLoadOp op, - triton::gpu::LocalLoadOpAdaptor adaptor, - const LLVMTypeConverter *typeConverter, - ConversionPatternRewriter &rewriter) const { - auto loc = op.getLoc(); - Value src = op.getSrc(); - Value dst = op.getResult(); - auto dstTensorTy = cast(dst.getType()); - auto srcTensorTy = cast(src.getType()); - auto dotOperandLayout = - cast(dstTensorTy.getEncoding()); - auto sharedLayout = cast(srcTensorTy.getEncoding()); - - bool isOuter{}; - int K{}; - if (dotOperandLayout.getOpIdx() == 0) // $a - K = dstTensorTy.getShape()[sharedLayout.getOrder()[0]]; - else // $b - K = dstTensorTy.getShape()[sharedLayout.getOrder()[1]]; - isOuter = K == 1; - Value res = lowerSharedToDotOperandMMA(op, adaptor, typeConverter, rewriter, - dotOperandLayout, isOuter); - if (!res) - return failure(); - rewriter.replaceOp(op, res); - return success(); - } -}; struct ConvertLayoutOpMFMAToDotOpConversion : public ConvertOpToLLVMPattern { @@ -270,13 +182,9 @@ struct ConvertLayoutOpMFMAToDotOpConversion } // namespace -namespace mlir::triton::AMD { -void populateConvertLayoutOpToLLVMPatterns( +void mlir::triton::AMD::populateConvertLayoutOpToLLVMPatterns( LLVMTypeConverter &typeConverter, const TargetInfo &targetInfo, - RewritePatternSet &patterns, int numWarps, - ModuleAxisInfoAnalysis &axisInfoAnalysis, PatternBenefit benefit) { - patterns.add(typeConverter, benefit); + RewritePatternSet &patterns, PatternBenefit benefit) { patterns.add(typeConverter, targetInfo, benefit); } -} // namespace mlir::triton::AMD diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/MemoryOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/MemoryOpToLLVM.cpp new file mode 100644 index 0000000000..57c97f7563 --- /dev/null +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/MemoryOpToLLVM.cpp @@ -0,0 +1,123 @@ +#include "PatternTritonGPUOpToLLVM.h" +#include "Utility.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" + +using ::mlir::triton::gpu::AMDMfmaEncodingAttr; +using ::mlir::triton::gpu::AMDWmmaEncodingAttr; +using ::mlir::triton::gpu::DotOperandEncodingAttr; +using ::mlir::triton::gpu::MemDescType; +using ::mlir::triton::gpu::SharedEncodingAttr; + +namespace SharedToDotOperandMFMA { +Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, + Location loc, Value tensor, + DotOperandEncodingAttr bEncoding, + const SharedMemoryObject &smemObj, + const LLVMTypeConverter *typeConverter, Value thread); +} // namespace SharedToDotOperandMFMA + +namespace SharedToDotOperandWMMA { +Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, + Location loc, Value tensor, + DotOperandEncodingAttr bEncoding, + const SharedMemoryObject &smemObj, + const LLVMTypeConverter *typeConverter, Value thread); +} // namespace SharedToDotOperandWMMA + +namespace { +struct LocalLoadOpConversion + : public ConvertOpToLLVMPattern { +public: + using ConvertOpToLLVMPattern< + triton::gpu::LocalLoadOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::gpu::LocalLoadOp op, OpAdaptor adaptor, + 
ConversionPatternRewriter &rewriter) const override { + MemDescType srcTy = op.getSrc().getType(); + RankedTensorType dstTy = op.getType(); + Attribute srcLayout = srcTy.getEncoding(); + Attribute dstLayout = dstTy.getEncoding(); + if (isa(dstLayout) && + isa( + cast(dstLayout).getParent())) { + return lowerSharedToDotOperand(op, adaptor, getTypeConverter(), rewriter); + } + return failure(); + } + +private: + /// Lower ttg.local_load in dot operand layout if the operand parent layout is + /// MFMA or WMMA. + /// + /// \returns value with packed loaded values or empty value if this local_load + /// is not supproted. + Value lowerSharedToDotOperandMMA( + triton::gpu::LocalLoadOp op, triton::gpu::LocalLoadOpAdaptor adaptor, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter, + const DotOperandEncodingAttr &dotOperandLayout, bool isOuter) const { + auto loc = op.getLoc(); + Value src = op.getSrc(); + Value dst = op.getResult(); + auto llvmElemTy = typeConverter->convertType( + cast(src.getType()).getElementType()); + + auto smemObj = mlir::LLVM::getSharedMemoryObjectFromStruct( + loc, adaptor.getSrc(), llvmElemTy, rewriter); + Value res; + auto dopOpParent = dotOperandLayout.getParent(); + if (!isOuter && + isa(dopOpParent)) { + auto sharedToDotConvert = isa(dopOpParent) + ? SharedToDotOperandMFMA::convertLayout + : SharedToDotOperandWMMA::convertLayout; + res = sharedToDotConvert(dotOperandLayout.getOpIdx(), rewriter, loc, src, + dotOperandLayout, smemObj, typeConverter, + tid_val()); + } else { + assert(false && "unsupported layout found"); + } + return res; + } + + // shared -> matrix_core_dot_operand + LogicalResult + lowerSharedToDotOperand(triton::gpu::LocalLoadOp op, + triton::gpu::LocalLoadOpAdaptor adaptor, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter) const { + auto loc = op.getLoc(); + Value src = op.getSrc(); + Value dst = op.getResult(); + auto dstTensorTy = cast(dst.getType()); + auto srcTensorTy = cast(src.getType()); + auto dotOperandLayout = + cast(dstTensorTy.getEncoding()); + auto sharedLayout = cast(srcTensorTy.getEncoding()); + + bool isOuter{}; + int K{}; + if (dotOperandLayout.getOpIdx() == 0) // $a + K = dstTensorTy.getShape()[sharedLayout.getOrder()[0]]; + else // $b + K = dstTensorTy.getShape()[sharedLayout.getOrder()[1]]; + isOuter = K == 1; + Value res = lowerSharedToDotOperandMMA(op, adaptor, typeConverter, rewriter, + dotOperandLayout, isOuter); + if (!res) + return failure(); + rewriter.replaceOp(op, res); + return success(); + } +}; + +} // namespace + +void mlir::triton::AMD::populateMemoryOpToLLVMPatterns( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + PatternBenefit benefit) { + patterns.add(typeConverter, benefit); +} diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/PatternTritonGPUOpToLLVM.h b/third_party/amd/lib/TritonAMDGPUToLLVM/PatternTritonGPUOpToLLVM.h index b217fc4956..5701fa990e 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/PatternTritonGPUOpToLLVM.h +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/PatternTritonGPUOpToLLVM.h @@ -7,10 +7,14 @@ #include "triton/Analysis/AxisInfo.h" namespace mlir::triton::AMD { -void populateConvertLayoutOpToLLVMPatterns( - LLVMTypeConverter &typeConverter, const TargetInfo &targetInfo, - RewritePatternSet &patterns, int numWarps, - ModuleAxisInfoAnalysis &axisInfoAnalysis, PatternBenefit benefit); +void populateConvertLayoutOpToLLVMPatterns(LLVMTypeConverter &typeConverter, + const TargetInfo &targetInfo, + RewritePatternSet 
&patterns, + PatternBenefit benefit); + +void populateMemoryOpToLLVMPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + PatternBenefit benefit); void populateDotOpToLLVMPatterns(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, int numWarps, diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp index 31df3a8a60..0e29b0c00d 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp @@ -171,8 +171,7 @@ struct ConvertTritonAMDGPUToLLVM }; AMD::populateConvertLayoutOpToLLVMPatterns(typeConverter, targetInfo, - patterns, numWarps, - axisInfoAnalysis, AMDBenefit); + patterns, AMDBenefit); mlir::triton::populateConvertLayoutOpToLLVMPatterns( typeConverter, targetInfo, patterns, commonBenefit); AMD::populateDotOpToLLVMPatterns(typeConverter, patterns, numWarps, @@ -197,7 +196,8 @@ struct ConvertTritonAMDGPUToLLVM mlir::triton::BackendCallbacks callbacks; callbacks.localStoreOpConversion = storeOpConversionCallback; - mlir::triton::populateMemoryOpToLLVMPattern( + AMD::populateMemoryOpToLLVMPatterns(typeConverter, patterns, AMDBenefit); + mlir::triton::populateMemoryOpToLLVMPatterns( typeConverter, targetInfo, patterns, commonBenefit, callbacks); mlir::triton::populateMakeRangeOpToLLVMPattern(typeConverter, targetInfo, patterns, commonBenefit); diff --git a/third_party/intel/backend/compiler.py b/third_party/intel/backend/compiler.py index 5b8821228a..e44a684077 100644 --- a/third_party/intel/backend/compiler.py +++ b/third_party/intel/backend/compiler.py @@ -44,6 +44,7 @@ class XPUOptions: threads_per_warp: int = 32 optimize_epilogue: bool = False enable_fp_fusion: bool = True + launch_cooperative_grid: bool = False supported_fp8_dtypes: Tuple[str] = ("fp8e5", "fp8e4nv", "fp8e4b15") deprecated_fp8_dtypes: Tuple[str] = () default_dot_input_precision: str = "tf32" diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py index 137fef4bd0..ccc52097d8 100644 --- a/third_party/nvidia/backend/compiler.py +++ b/third_party/nvidia/backend/compiler.py @@ -112,6 +112,7 @@ class CUDAOptions: cluster_dims: tuple = (1, 1, 1) ptx_version: int = None enable_fp_fusion: bool = True + launch_cooperative_grid: bool = False supported_fp8_dtypes: Tuple[str] = ("fp8e5", "fp8e4b15") deprecated_fp8_dtypes: Tuple[str] = () default_dot_input_precision: str = "tf32" diff --git a/third_party/nvidia/backend/driver.py b/third_party/nvidia/backend/driver.py index e41b4a1386..8bfd010773 100644 --- a/third_party/nvidia/backend/driver.py +++ b/third_party/nvidia/backend/driver.py @@ -147,7 +147,7 @@ def format_of(ty): }[ty] args_format = ''.join([format_of(_extracted_type(ty)) for ty in signature.values()]) - format = "iiiKKOOOOO" + args_format + format = "iiiKKpOOOOO" + args_format args_list = ', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else '' internal_args_list = [] @@ -234,19 +234,50 @@ def format_of(ty): }} #endif -static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, CUstream stream, CUfunction function, CUdeviceptr global_scratch{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{ +static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, 
CUstream stream, CUfunction function, CUdeviceptr global_scratch{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{ void *params[] = {{ {', '.join(params)} }}; if (gridX*gridY*gridZ > 0) {{ - if (num_ctas == 1) {{ + if ((num_ctas == 1) && (0 == launch_cooperative_grid)) {{ CUDA_CHECK(cuLaunchKernel(function, gridX, gridY, gridZ, 32*num_warps, 1, 1, shared_memory, stream, params, 0)); + }} else if ((num_ctas == 1) && (0 != launch_cooperative_grid)) {{ + CUlaunchAttribute launchAttr[1]; + CUlaunchAttribute coopAttr = {{ .id = CU_LAUNCH_ATTRIBUTE_COOPERATIVE, .value = 1}}; + launchAttr[0] = coopAttr; + + CUlaunchConfig config; + config.gridDimX = gridX; + config.gridDimY = gridY; + config.gridDimZ = gridZ; + config.blockDimX = 32 * num_warps; + config.blockDimY = 1; + config.blockDimZ = 1; + config.sharedMemBytes = shared_memory; + config.hStream = stream; + config.attrs = launchAttr; + config.numAttrs = 1; + + static cuLaunchKernelEx_t cuLaunchKernelExHandle = NULL; + if (cuLaunchKernelExHandle == NULL) {{ + cuLaunchKernelExHandle = getLaunchKernelExHandle(); + }} + CUDA_CHECK(cuLaunchKernelExHandle(&config, function, params, 0)); + }} else {{ - CUlaunchAttribute launchAttr[2]; + CUlaunchAttribute launchAttr[3]; launchAttr[0].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION; launchAttr[0].value.clusterDim.x = clusterDimX; launchAttr[0].value.clusterDim.y = clusterDimY; launchAttr[0].value.clusterDim.z = clusterDimZ; launchAttr[1].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE; launchAttr[1].value.clusterSchedulingPolicyPreference = CU_CLUSTER_SCHEDULING_POLICY_SPREAD; + + unsigned numAttrs = 2; + if (0 != launch_cooperative_grid) {{ + CUlaunchAttribute coopAttr = {{ .id = CU_LAUNCH_ATTRIBUTE_COOPERATIVE, .value = 1}}; + launchAttr[2] = coopAttr; + numAttrs = 3; + }} + CUlaunchConfig config; config.gridDimX = gridX * clusterDimX; config.gridDimY = gridY * clusterDimY; @@ -257,7 +288,7 @@ def format_of(ty): config.sharedMemBytes = shared_memory; config.hStream = stream; config.attrs = launchAttr; - config.numAttrs = 2; + config.numAttrs = numAttrs; static cuLaunchKernelEx_t cuLaunchKernelExHandle = NULL; if (cuLaunchKernelExHandle == NULL) {{ cuLaunchKernelExHandle = getLaunchKernelExHandle(); @@ -382,6 +413,7 @@ def format_of(ty): int gridX, gridY, gridZ; uint64_t _stream; uint64_t _function; + int launch_cooperative_grid; PyObject *launch_enter_hook = NULL; PyObject *launch_exit_hook = NULL; PyObject *kernel_metadata = NULL; @@ -389,7 +421,7 @@ def format_of(ty): PyObject *global_scratch_obj = NULL; {' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])} if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, - &_stream, &_function, &global_scratch_obj, + &_stream, &_function, &launch_cooperative_grid, &global_scratch_obj, &kernel_metadata, &launch_metadata, &launch_enter_hook, &launch_exit_hook{args_list})) {{ return NULL; @@ -423,7 +455,7 @@ def format_of(ty): {"".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])}; {"".join([f"CUtensorMap* tma_ptr{i} = getTmaDesc(_arg{i}); if (!tma_ptr{i}) return NULL;" if ty == "nvTmaDesc" else "" for i, ty in signature.items()])}; Py_BEGIN_ALLOW_THREADS; - _launch(gridX, gridY, gridZ, num_warps, num_ctas, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (CUstream)_stream, (CUfunction)_function, global_scratch{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''}); + 
_launch(gridX, gridY, gridZ, num_warps, num_ctas, launch_cooperative_grid, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (CUstream)_stream, (CUfunction)_function, global_scratch{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''}); Py_END_ALLOW_THREADS; if (PyErr_Occurred()) {{ return NULL; @@ -479,6 +511,7 @@ def __init__(self, src, metadata): self.launch = mod.launch self.global_scratch_size = metadata.global_scratch_size self.global_scratch_align = metadata.global_scratch_align + self.launch_cooperative_grid = metadata.launch_cooperative_grid def __call__(self, gridX, gridY, gridZ, stream, function, *args): if self.global_scratch_size > 0: @@ -487,7 +520,7 @@ def __call__(self, gridX, gridY, gridZ, stream, function, *args): global_scratch = _allocation._allocator(alloc_size, self.global_scratch_align, stream) else: global_scratch = None - self.launch(gridX, gridY, gridZ, stream, function, global_scratch, *args) + self.launch(gridX, gridY, gridZ, stream, function, self.launch_cooperative_grid, global_scratch, *args) class CudaDriver(GPUDriver): diff --git a/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td b/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td index 31b2646db8..ed2a2ec391 100644 --- a/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td +++ b/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td @@ -43,12 +43,10 @@ def NVGPU_WGMMACommitGroupOp : NVGPU_Op<"wgmma_commit_group", []> { let assemblyFormat = "attr-dict"; } -def NVGPU_WGMMAWaitGroupOp : NVGPU_Op<"wgmma_wait_group", -[DeclareOpInterfaceMethods, - AllTypesMatch<["input", "output"]>]> { +def NVGPU_WGMMAWaitGroupOp : NVGPU_Op<"wgmma_wait_group", [DeclareOpInterfaceMethods, + AllTypesMatch<["input", "output"]>]> { let arguments = (ins LLVM_AnyStruct:$input, I32Attr:$pendings); let results = (outs LLVM_AnyStruct:$output); - let assemblyFormat = "attr-dict"; let assemblyFormat = "$input attr-dict `:` type($input)"; } @@ -103,10 +101,23 @@ def NVGPU_ClusterWaitOp : NVGPU_Op<"cluster_wait", []> { } def NVGPU_StoreMatrixOp : NVGPU_Op<"stmatrix", [MemoryEffects<[MemWrite]>]> { - let arguments = (ins LLVM_PointerShared:$addr, Variadic:$datas); + let arguments = ( + ins LLVM_PointerShared:$addr, + Variadic:$vals, + UnitAttr:$trans + ); let assemblyFormat = "operands attr-dict `:` type(operands)"; } +def NVGPU_LoadMatrixOp : NVGPU_Op<"ldmatrix", [MemoryEffects<[MemRead]>]> { + let arguments = ( + ins LLVM_PointerShared:$addr, + UnitAttr:$trans + ); + let results = (outs LLVM_AnyStruct:$result); + let assemblyFormat = "$addr attr-dict `:` functional-type($addr, $result)"; +} + def NVGPU_ClusterCTAIdOp : NVGPU_Op<"cluster_id", [Pure]> { let results = (outs I32:$result); let assemblyFormat = "attr-dict"; diff --git a/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp b/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp index 8de0efefca..8906eb154a 100644 --- a/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp +++ b/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp @@ -23,21 +23,20 @@ using ttn::OperandsAndConstraints; namespace { -const std::string Wgmma_Fence_Op = "wgmma.fence.sync.aligned;"; -const std::string Wgmma_Commit_Group_Op = "wgmma.commit_group.sync.aligned;"; -const std::string Cluster_Wait_Op = "barrier.cluster.wait.aligned;"; -const std::string Fence_Mbarrier_Init_Op = - "fence.mbarrier_init.release.cluster;"; -const std::string Cluster_Cta_Id_Op = "{\n" - ".reg .u32 a<5>; \n" - "mov.u32 a0, %cluster_ctaid.x;\n" // x - "mov.u32 a1, 
%cluster_ctaid.y;\n" // y - "mov.u32 a2, %cluster_ctaid.z;\n" // z - "mov.u32 a3, %cluster_nctaid.x;\n" // nx - "mov.u32 a4, %cluster_nctaid.y;\n" // ny - "mad.lo.u32 a1, a2, a4, a1; \n" - "mad.lo.u32 $0, a1, a3, a0; \n" - "}"; +const std::string kWgmmaFenceOp = "wgmma.fence.sync.aligned;"; +const std::string kWgmmaCommitGroupOp = "wgmma.commit_group.sync.aligned;"; +const std::string kClusterWaitOp = "barrier.cluster.wait.aligned;"; +const std::string kFenceMbarrierInitOp = "fence.mbarrier_init.release.cluster;"; +const std::string kClusterCtaIdOp = "{\n" + ".reg .u32 a<5>; \n" + "mov.u32 a0, %cluster_ctaid.x;\n" // x + "mov.u32 a1, %cluster_ctaid.y;\n" // y + "mov.u32 a2, %cluster_ctaid.z;\n" // z + "mov.u32 a3, %cluster_nctaid.x;\n" // nx + "mov.u32 a4, %cluster_nctaid.y;\n" // ny + "mad.lo.u32 a1, a2, a4, a1; \n" + "mad.lo.u32 $0, a1, a3, a0; \n" + "}"; bool isNumber(const std::string &s) { return !s.empty() && std::find_if(s.begin(), s.end(), [](unsigned char c) { @@ -235,46 +234,138 @@ class ClusterArriveOpPattern : public OpRewritePattern { } }; -class StoreMatrixOpPattern : public OpRewritePattern { +// Base class for Matrix Operation Patterns +template +class MatrixOpPattern : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(ttn::StoreMatrixOp op, + LogicalResult matchAndRewrite(MatrixOpType op, PatternRewriter &rewriter) const override { - return rewriteAsPtxAsm(op, rewriter, getPtxAsm(op), - getOperandsAndConstraints(op)); - } - - OperandsAndConstraints - getOperandsAndConstraints(ttn::StoreMatrixOp op) const { - OperandsAndConstraints operandsAndTypes; - auto addr = op.getAddr(); - auto datas = op.getDatas(); - operandsAndTypes.push_back({addr, "r"}); - for (unsigned i = 0; i < datas.size(); i++) { - operandsAndTypes.push_back({datas[i], "r"}); - } - return operandsAndTypes; + unsigned vecSize = getVectorSize(op); + bool trans = op.getTrans(); + // Template method for PTX assembly generation + std::string ptxAsm = + (llvm::Twine(ConcreteMatrixOpPattern::kOpCode) + + getPtxModifiers(vecSize, trans) + " " + getOperands(op, vecSize) + ";") + .str(); + + OperandsAndConstraints operandAndConstraints = + getOperandsAndConstraints(op, vecSize); + Constraints outputConstraints = getOutputConstraints(op, vecSize); + + return rewriteAsPtxAsm(op, rewriter, ptxAsm, operandAndConstraints, + outputConstraints); } - std::string getPtxAsm(ttn::StoreMatrixOp op) const { - auto datas = op.getDatas(); - std::string ptxAsm; - switch (datas.size()) { +protected: + // Shared helper methods + std::string getPtxModifiers(unsigned vecSize, bool trans) const { + auto ptxAsmBase = llvm::Twine(".sync.aligned.m8n8"); + const std::string suffix = trans ? 
".trans.shared.b16" : ".shared.b16"; + switch (vecSize) { case 1: - ptxAsm = "stmatrix.sync.aligned.m8n8.x1.shared.b16 [$0], {$1};"; - break; + return (ptxAsmBase + ".x1" + suffix).str(); case 2: - ptxAsm = "stmatrix.sync.aligned.m8n8.x2.shared.b16 [$0], {$1, $2};"; - break; + return (ptxAsmBase + ".x2" + suffix).str(); case 4: - ptxAsm = - "stmatrix.sync.aligned.m8n8.x4.shared.b16 [$0], {$1, $2, $3, $4};"; - break; + return (ptxAsmBase + ".x4" + suffix).str(); default: - assert(false && "Invalid size"); + assert(false && "Invalid vector size"); } - return ptxAsm; + } + + std::string getPtxRegOperands(unsigned startIdx, unsigned count) const { + llvm::SmallString<20> regOperands; + llvm::raw_svector_ostream stream(regOperands); + stream << "{"; + for (unsigned i = 0; i < count; i++) { + stream << "$" + llvm::utostr(startIdx + i); + if (i != count - 1) + stream << ", "; + } + stream << "}"; + return std::string(regOperands.str()); + } + + std::string getPtxAddrOperand(unsigned idx) const { + return (llvm::Twine("[$") + llvm::utostr(idx) + "]").str(); + } + + virtual std::string getOperands(MatrixOpType op, unsigned vecSize) const = 0; + virtual OperandsAndConstraints + getOperandsAndConstraints(MatrixOpType op, unsigned vecSize) const = 0; + virtual Constraints getOutputConstraints(MatrixOpType op, + unsigned vecSize) const = 0; + virtual unsigned getVectorSize(MatrixOpType op) const = 0; +}; + +// StoreMatrixOp Pattern +class StoreMatrixOpPattern + : public MatrixOpPattern { +public: + using MatrixOpPattern::MatrixOpPattern; + static constexpr const char *kOpCode = "stmatrix"; + +protected: + unsigned getVectorSize(ttn::StoreMatrixOp op) const override { + return op.getVals().size(); + } + + std::string getOperands(ttn::StoreMatrixOp op, + unsigned vecSize) const override { + return (llvm::Twine(getPtxAddrOperand(0)) + ", " + + getPtxRegOperands(1, vecSize)) + .str(); + } + + OperandsAndConstraints + getOperandsAndConstraints(ttn::StoreMatrixOp op, + unsigned vecSize) const override { + OperandsAndConstraints constraints = {{op.getAddr(), "r"}}; + for (unsigned i = 0; i < vecSize; i++) { + constraints.push_back({op.getVals()[i], "r"}); + } + return constraints; + } + + Constraints getOutputConstraints(ttn::StoreMatrixOp op, + unsigned vecSize) const override { + return {}; // No output constraints for StoreMatrixOp + } +}; + +// LoadMatrixOp Pattern +class LoadMatrixOpPattern + : public MatrixOpPattern { +public: + using MatrixOpPattern::MatrixOpPattern; + static constexpr const char *kOpCode = "ldmatrix"; + +protected: + unsigned getVectorSize(ttn::LoadMatrixOp op) const override { + auto resultType = cast(op.getType()); + return resultType.getBody().size(); + } + + std::string getOperands(ttn::LoadMatrixOp op, + unsigned vecSize) const override { + return (llvm::Twine(getPtxRegOperands(0, vecSize)) + ", " + + getPtxAddrOperand(vecSize)) + .str(); + } + + OperandsAndConstraints + getOperandsAndConstraints(ttn::LoadMatrixOp op, + unsigned vecSize) const override { + return {{op.getAddr(), "r"}}; + } + + Constraints getOutputConstraints(ttn::LoadMatrixOp op, + unsigned vecSize) const override { + return Constraints(vecSize, "=r"); } }; @@ -507,17 +598,16 @@ class ConvertNVGPUToLLVM : public ConvertNVGPUToLLVMBase { #define POPULATE_NVGPU_OP(SRC_OP, ASM) \ patterns.add>(context, ASM, Constraints(), \ Constraints()); - POPULATE_NVGPU_OP(ttn::WGMMAFenceOp, Wgmma_Fence_Op) - POPULATE_NVGPU_OP(ttn::WGMMACommitGroupOp, Wgmma_Commit_Group_Op) - POPULATE_NVGPU_OP(ttn::ClusterWaitOp, 
Cluster_Wait_Op) + POPULATE_NVGPU_OP(ttn::WGMMAFenceOp, kWgmmaFenceOp) + POPULATE_NVGPU_OP(ttn::WGMMACommitGroupOp, kWgmmaCommitGroupOp) + POPULATE_NVGPU_OP(ttn::ClusterWaitOp, kClusterWaitOp) #undef POPULATE_NVGPU_OP patterns.add>( - context, Cluster_Cta_Id_Op, Constraints({"=r"}), Constraints()); + context, kClusterCtaIdOp, Constraints({"=r"}), Constraints()); - patterns - .add( - context); + patterns.add(context); if (applyPatternsAndFoldGreedily(mod, std::move(patterns)).failed()) signalPassFailure(); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/CMakeLists.txt b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/CMakeLists.txt index 91ddfc2700..96727b3571 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/CMakeLists.txt +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/CMakeLists.txt @@ -1,6 +1,7 @@ add_triton_library(TritonNVIDIAGPUToLLVM ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp ConvertLayoutOpToLLVM.cpp + MemoryOpToLLVM.cpp DotOpToLLVM/MMAv2.cpp DotOpToLLVM/WGMMA.cpp DotOpToLLVM.cpp diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 95b8fb6461..30a740f8bc 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -4,125 +4,17 @@ #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/IR/PatternMatch.h" -#include "triton/Analysis/Allocation.h" #include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" -using ::mlir::LLVM::getMultiDimOffset; -using ::mlir::LLVM::getSharedMemoryObjectFromStruct; -using ::mlir::LLVM::getWrappedMultiDimOffset; -using ::mlir::LLVM::linearize; -using ::mlir::triton::gpu::DotOperandEncodingAttr; -using ::mlir::triton::gpu::getOrder; -using ::mlir::triton::gpu::getShapePerCTA; -using ::mlir::triton::gpu::getSizePerThread; -using ::mlir::triton::gpu::getTotalElemsPerThread; -using ::mlir::triton::gpu::SharedEncodingAttr; - -// Forward declarations - -namespace SharedToDotOperandMMAv2OrV3 { -Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, - Location loc, Value tensor, - DotOperandEncodingAttr bEncoding, - const SharedMemoryObject &smemObj, - const LLVMTypeConverter *typeConverter, Value thread); -} // namespace SharedToDotOperandMMAv2OrV3 - namespace { using namespace mlir; using namespace mlir::triton; using namespace mlir::triton::gpu; -struct LocalLoadOpConversion - : public ConvertOpToLLVMPattern { -public: - using ConvertOpToLLVMPattern< - triton::gpu::LocalLoadOp>::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(triton::gpu::LocalLoadOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - MemDescType srcTy = op.getSrc().getType(); - RankedTensorType dstTy = op.getType(); - Attribute srcLayout = srcTy.getEncoding(); - Attribute dstLayout = dstTy.getEncoding(); - if (isa(dstLayout) && - isa( - cast(dstLayout).getParent())) { - return lowerSharedToDotOperand(op, adaptor, getTypeConverter(), rewriter); - } - return failure(); - } - -private: - // shared -> dot_operand if the result layout is mma - Value lowerSharedToDotOperandMMA( - triton::gpu::LocalLoadOp op, triton::gpu::LocalLoadOpAdaptor adaptor, - const LLVMTypeConverter 
*typeConverter, - ConversionPatternRewriter &rewriter, - const NvidiaMmaEncodingAttr &mmaLayout, - const DotOperandEncodingAttr &dotOperandLayout, bool isOuter) const { - auto loc = op.getLoc(); - auto src = op.getSrc(); - auto dst = op.getResult(); - bool isMMA = supportMMA(dst, mmaLayout.getVersionMajor()); - - auto llvmElemTy = - typeConverter->convertType(src.getType().getElementType()); - - auto smemObj = getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), - llvmElemTy, rewriter); - Value res; - - if (isOuter) { - assert(false && "MMA Layout does not support outer product"); - return res; - } - - if (mmaLayout.isHopper() || mmaLayout.isAmpere()) { // tensor core v2 or v3 - if (mmaLayout.isHopper()) - assert(dotOperandLayout.getOpIdx() == 0 && - "MMAv3 can only have operand $b on shared memory"); - - res = SharedToDotOperandMMAv2OrV3::convertLayout( - dotOperandLayout.getOpIdx(), rewriter, loc, src, dotOperandLayout, - smemObj, typeConverter, getThreadId(rewriter, loc)); - } else { - assert(false && "Unsupported mma layout found"); - } - return res; - }; - - // shared -> mma_operand - LogicalResult - lowerSharedToDotOperand(triton::gpu::LocalLoadOp op, - triton::gpu::LocalLoadOpAdaptor adaptor, - const LLVMTypeConverter *typeConverter, - ConversionPatternRewriter &rewriter) const { - auto loc = op.getLoc(); - auto dstEnc = cast(op.getType().getEncoding()); - auto sharedLayout = - cast(op.getSrc().getType().getEncoding()); - - int K; - if (dstEnc.getOpIdx() == 0) // $a - K = op.getType().getShape()[sharedLayout.getOrder()[0]]; - else // $b - K = op.getType().getShape()[sharedLayout.getOrder()[1]]; - bool isOuter = K == 1; - auto mmaLayout = cast(dstEnc.getParent()); - Value res = lowerSharedToDotOperandMMA(op, adaptor, typeConverter, rewriter, - mmaLayout, dstEnc, isOuter); - - rewriter.replaceOp(op, res); - return success(); - } -}; - struct ConvertLayoutOpConversion : public ConvertOpToLLVMPattern { public: @@ -190,7 +82,7 @@ struct ConvertLayoutOpConversion "Unexpected number of indices emitted"); for (unsigned i = 0; i < inIndices.size(); ++i) { - Value offset = linearize(rewriter, loc, inIndices[i], smemShape); + Value offset = LLVM::linearize(rewriter, loc, inIndices[i], smemShape); Value ptr = gep(elemPtrTy, llvmElemTy, smemBase, offset); store(inVals[i], ptr); } @@ -220,9 +112,10 @@ struct ConvertLayoutOpConversion localCoord.push_back(urem(coord[d], srcShapePerCTACache[d])); } - Value remoteCTAId = - linearize(rewriter, loc, multiDimCTAId, srcCTAsPerCGA, srcCTAOrder); - Value localOffset = linearize(rewriter, loc, localCoord, smemShape); + Value remoteCTAId = LLVM::linearize(rewriter, loc, multiDimCTAId, + srcCTAsPerCGA, srcCTAOrder); + Value localOffset = + LLVM::linearize(rewriter, loc, localCoord, smemShape); Value ptr = gep(elemPtrTy, llvmElemTy, smemBase, localOffset); outVals.push_back(targetInfo.loadDShared( @@ -326,10 +219,6 @@ struct ConvertLayoutOpConversion auto srcTy = op.getSrc().getType(); auto dstTy = op.getType(); if (matchMmaV3AndDotOperandLayout(srcTy, dstTy)) { - if (srcTy.getElementType().getIntOrFloatBitWidth() == 16) { - rewriter.replaceOp(op, adaptor.getSrc()); - return success(); - } assert(srcTy.getElementType().getIntOrFloatBitWidth() == 8 && "Unsupported type size."); convertMMAV3To8BitsDotOperand(op, adaptor, rewriter); @@ -342,115 +231,15 @@ struct ConvertLayoutOpConversion const NVIDIA::TargetInfo &targetInfo; }; -struct LocalAllocOpConversion - : public ConvertOpToLLVMPattern { - LocalAllocOpConversion(const LLVMTypeConverter &converter, - 
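// LLVM::linearize, used above, emits LLVM IR for a strided sum over the given
// dimension order, with order[0] being the fastest-varying dimension. A
// plain-integer sketch of that index math (linearizeIndex is an illustrative
// name, not an API from the patch).
#include <cassert>
#include <cstdint>
#include <vector>

static int64_t linearizeIndex(const std::vector<int64_t> &idx,
                              const std::vector<int64_t> &shape,
                              const std::vector<unsigned> &order) {
  int64_t linear = 0, stride = 1;
  for (unsigned d : order) { // fastest dimension first
    linear += idx[d] * stride;
    stride *= shape[d];
  }
  return linear;
}

int main() {
  // A 4x8 tile with order = {1, 0}: dim 1 is contiguous, dim 0 has stride 8.
  assert(linearizeIndex({2, 3}, {4, 8}, {1, 0}) == 2 * 8 + 3);
  return 0;
}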
const NVIDIA::TargetInfo &targetInfo, - PatternBenefit benefit = 1) - : ConvertOpToLLVMPattern(converter, benefit), - targetInfo(targetInfo) {} - - LogicalResult - matchAndRewrite(triton::gpu::LocalAllocOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (!op.getSrc()) - return failure(); - auto mmaEncoding = dyn_cast( - op.getSrc().getType().getEncoding()); - if (!mmaEncoding) - return failure(); - auto sharedLayout = - cast(op.getType().getEncoding()); - if (!sharedLayout.getHasLeadingOffset()) - return failure(); - int swizzleByteSize = 0; - if (sharedLayout.getPerPhase() == 4 && sharedLayout.getMaxPhase() == 2) - swizzleByteSize = 32; - else if (sharedLayout.getPerPhase() == 2 && sharedLayout.getMaxPhase() == 4) - swizzleByteSize = 64; - else if (sharedLayout.getPerPhase() == 1 && sharedLayout.getMaxPhase() == 8) - swizzleByteSize = 128; - else - return failure(); - - auto *ctx = rewriter.getContext(); - Location loc = op->getLoc(); - - RankedTensorType srcTy = op.getSrc().getType(); - SmallVector shape = - convertType(srcTy.getShape()); - auto order = sharedLayout.getOrder(); - if (!targetInfo.canUseStMatrix(srcTy, shape, shape, order, - swizzleByteSize)) { - return failure(); - } - auto layout = chooseStMatrixLayout(rewriter.getContext(), srcTy, shape, - shape, order, swizzleByteSize); - Value smemBase = LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op); - auto smemPtrTy = ptr_ty(ctx, 3); - - auto kRegister = str_attr("register"); - auto kLane = str_attr("lane"); - auto kWarp = str_attr("warp"); - auto kBlock = str_attr("block"); - - Value threadId = getThreadId(rewriter, loc); - Value threadsPerWarp = i32_val(layout.getInDimSize(kLane)); - Value laneId = urem(threadId, threadsPerWarp); - Value warpId = udiv(threadId, threadsPerWarp); - - auto regBase = applyLinearLayout(loc, rewriter, layout, - {{kRegister, i32_val(0)}, - {kLane, laneId}, - {kWarp, warpId}, - {kBlock, i32_val(0)}})[0] - .second; - auto srcVals = unpackLLElements(loc, adaptor.getSrc(), rewriter); - auto srcVec = layout.getNumConsecutiveInOut(); - Type llvmElemTy = typeConverter->convertType(srcTy.getElementType()); - for (int i = 0; i < srcVals.size(); i += srcVec) { - auto regIdx = - layout.apply({{kRegister, i}, {kLane, 0}, {kWarp, 0}, {kBlock, 0}})[0] - .second; - Value offset = xor_(regBase, i32_val(regIdx)); - auto vecAddr = gep(smemPtrTy, llvmElemTy, smemBase, offset); - vecAddr.setInbounds(true); - SmallVector inValsVec; - for (int j = 0; j < srcVec; j++) - inValsVec.push_back(srcVals[i + j]); - Value valsVec = packLLVector(loc, inValsVec, rewriter); - targetInfo.storeMatrixShared(rewriter, loc, vecAddr, valsVec); - } - - auto resultTy = cast(op.getType()); - auto shapePerCTA = getShapePerCTA(sharedLayout, resultTy.getShape()); - auto smemObj = SharedMemoryObject(smemBase, llvmElemTy, shapePerCTA, - sharedLayout, loc, rewriter); - auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter); - rewriter.replaceOp(op, retVal); - return success(); - } - -private: - const NVIDIA::TargetInfo &targetInfo; -}; - } // namespace -void mlir::triton::NVIDIA::populateConvertLayoutOpToLLVMOptimizedPatterns( - LLVMTypeConverter &typeConverter, const TargetInfo &targetInfo, - RewritePatternSet &patterns, PatternBenefit benefit) { - patterns.add(typeConverter, targetInfo, benefit); -} - void mlir::triton::NVIDIA::populateConvertLayoutOpToLLVMPatterns( LLVMTypeConverter &typeConverter, const TargetInfo &targetInfo, RewritePatternSet &patterns, PatternBenefit benefit) { - // 
For now give ConvertLayoutOpConversion higher benefit, I can split before - // merging - patterns.add(typeConverter, targetInfo, benefit); - // Same default benefit - patterns.add(typeConverter, benefit); + // Give this convertLayoutOpConversion a higher benefit as it only matches + // optimized or cross CTA cases + patterns.add(typeConverter, targetInfo, + benefit.getBenefit() + 1); mlir::triton::populateConvertLayoutOpToLLVMPatterns(typeConverter, targetInfo, patterns, benefit); } diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp index b4d84c0aeb..4b82755877 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp @@ -1,6 +1,7 @@ #include "TritonNVIDIAGPUToLLVM/PTXAsmFormat.h" #include "Utility.h" #include "mlir/Support/LLVM.h" +#include "third_party/nvidia/include/Dialect/NVGPU/IR/Dialect.h" using namespace mlir; @@ -339,23 +340,10 @@ MMA16816SmemLoader::loadX4(int batch, int mat0, int mat1, ArrayRef ptrs, if (batch != 0) stridedOffset = add( stridedOffset, mul(i32_val(batch * warpsPerCTA[0]), smemBatchOffset)); - Value readPtr = gep(ptr_ty(ctx, 3), shemTy, ptr, stridedOffset); - - PTXBuilder builder; - // ldmatrix.m8n8.x4 returns 4x2xfp16(that is 4xb32) elements for a - // thread. - auto resArgs = builder.newListOperand(4, "=r"); - auto addrArg = builder.newAddrOperand(readPtr, "r"); - - auto ldmatrix = builder.create("ldmatrix.sync.aligned.m8n8.x4") - ->o("trans", needTrans /*predicate*/) - .o("shared.b16"); - ldmatrix(resArgs, addrArg); - - // The result type is 4xi32, each i32 is composed of 2xf16 - // elements (adjacent two columns in a row) or a single f32 element. 
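// As the removed comment notes, ldmatrix.x4 hands each thread four b32
// registers, each carrying two adjacent 16-bit elements; the extract_val calls
// below unpack those four registers, and the ttn::LoadMatrixOp created in the
// replacement keeps the same contract. A sketch of the packing, treating the
// halves as raw bit patterns rather than real f16 math (packHalves, lowHalf
// and highHalf are illustrative helpers, not from the patch).
#include <cassert>
#include <cstdint>

static uint32_t packHalves(uint16_t lo, uint16_t hi) {
  return static_cast<uint32_t>(lo) | (static_cast<uint32_t>(hi) << 16);
}
static uint16_t lowHalf(uint32_t r) { return static_cast<uint16_t>(r & 0xffff); }
static uint16_t highHalf(uint32_t r) { return static_cast<uint16_t>(r >> 16); }

int main() {
  uint32_t reg = packHalves(0x3c00 /*1.0 in fp16*/, 0x4000 /*2.0 in fp16*/);
  assert(lowHalf(reg) == 0x3c00 && highHalf(reg) == 0x4000);
  return 0;
}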
- Value resV4 = builder.launch(rewriter, loc, resTy); + auto ldMatrixOp = + rewriter.create(loc, resTy, readPtr, needTrans); + auto resV4 = ldMatrixOp.getResult(); return {extract_val(elemTy, resV4, 0), extract_val(elemTy, resV4, 1), extract_val(elemTy, resV4, 2), extract_val(elemTy, resV4, 3)}; } else { diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp new file mode 100644 index 0000000000..45881491c2 --- /dev/null +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp @@ -0,0 +1,228 @@ +#include "PatternTritonGPUOpToLLVM.h" +#include "TargetInfo.h" +#include "Utility.h" +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/IR/PatternMatch.h" +#include "triton/Analysis/Allocation.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" + +namespace SharedToDotOperandMMAv2OrV3 { +Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, + Location loc, Value tensor, + DotOperandEncodingAttr bEncoding, + const SharedMemoryObject &smemObj, + const LLVMTypeConverter *typeConverter, Value thread); +} // namespace SharedToDotOperandMMAv2OrV3 + +namespace { + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::gpu; + +struct LocalLoadOpConversion + : public ConvertOpToLLVMPattern { +public: + using ConvertOpToLLVMPattern< + triton::gpu::LocalLoadOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::gpu::LocalLoadOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + MemDescType srcTy = op.getSrc().getType(); + RankedTensorType dstTy = op.getType(); + Attribute srcLayout = srcTy.getEncoding(); + Attribute dstLayout = dstTy.getEncoding(); + if (isa(dstLayout) && + isa( + cast(dstLayout).getParent())) { + auto dot = cast(dstLayout); + auto mma = cast(dot.getParent()); + auto shared = cast(srcLayout); + auto bitwidth = dstTy.getElementTypeBitWidth(); + auto vecWidth = 32 / bitwidth; + auto kWidth = dot.getKWidth(); + auto rank = dstTy.getRank(); + auto kOrder = dot.getOpIdx() == 0 ? rank - 1 : rank - 2; + auto needTrans = kOrder != shared.getOrder()[0]; + auto canUseLdmatrix = + (bitwidth == 16 || (!needTrans)) && (kWidth == vecWidth); + if (mma.isHopper()) { + // I think we should be able to remove this condition, but it's here + // as the legacy ldmatrix path does not support it + canUseLdmatrix &= srcTy.getElementTypeBitWidth() * kWidth == 32; + } + // If we remove this one, ldmatrix will IMA. 
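// A condensed restatement of the ldmatrix eligibility checks built up above
// and completed by the shape limits just below (plain C++; the real code
// derives these values from the MLIR types, and canUseLdmatrixFor is an
// illustrative name, not from the patch).
#include <cassert>
#include <cstdint>

static bool canUseLdmatrixFor(unsigned bitwidth, unsigned kWidth, bool needTrans,
                              bool isHopper, int64_t rows, int64_t cols,
                              unsigned rank) {
  unsigned vecWidth = 32 / bitwidth; // elements packed per 32-bit register
  bool ok = (bitwidth == 16 || !needTrans) && kWidth == vecWidth;
  if (isHopper)
    ok &= bitwidth * kWidth == 32; // restriction of the legacy ldmatrix path
  // Shape limits; without them ldmatrix would access out of bounds.
  ok &= rows >= 8 && cols >= 4 * kWidth && rank <= 2;
  return ok;
}

int main() {
  // fp16 operand with kWidth matching two halves per register: eligible.
  assert(canUseLdmatrixFor(/*bitwidth=*/16, /*kWidth=*/2, /*needTrans=*/true,
                           /*isHopper=*/false, /*rows=*/64, /*cols=*/32, 2));
  // An i8 operand that needs a transpose is rejected by the first check.
  assert(!canUseLdmatrixFor(/*bitwidth=*/8, /*kWidth=*/4, /*needTrans=*/true,
                            /*isHopper=*/false, /*rows=*/64, /*cols=*/64, 2));
  return 0;
}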
It can probably be relaxed + // though + canUseLdmatrix &= + srcTy.getShape()[0] >= 8 && + srcTy.getShape()[1] >= 4 * kWidth & dstTy.getRank() <= 2; + if (canUseLdmatrix) { + return lowerSharedToDotOperand(op, adaptor, getTypeConverter(), + rewriter); + } + } + return failure(); + } + +private: + // shared -> dot_operand if the result layout is mma + Value lowerSharedToDotOperandMMA( + triton::gpu::LocalLoadOp op, triton::gpu::LocalLoadOpAdaptor adaptor, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter, + const NvidiaMmaEncodingAttr &mmaLayout, + const DotOperandEncodingAttr &dotOperandLayout) const { + auto loc = op.getLoc(); + auto src = op.getSrc(); + auto dst = op.getResult(); + bool isMMA = supportMMA(dst, mmaLayout.getVersionMajor()); + + auto llvmElemTy = + typeConverter->convertType(src.getType().getElementType()); + + auto smemObj = LLVM::getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), + llvmElemTy, rewriter); + Value res; + + if (mmaLayout.isHopper() || mmaLayout.isAmpere()) { // tensor core v2 or v3 + if (mmaLayout.isHopper()) + assert(dotOperandLayout.getOpIdx() == 0 && + "Operand $b in MMAv3 can only be in shared memory"); + + res = SharedToDotOperandMMAv2OrV3::convertLayout( + dotOperandLayout.getOpIdx(), rewriter, loc, src, dotOperandLayout, + smemObj, typeConverter, getThreadId(rewriter, loc)); + } else { + assert(false && "Unsupported mma layout found"); + } + return res; + }; + + // shared -> mma_operand + LogicalResult + lowerSharedToDotOperand(triton::gpu::LocalLoadOp op, + triton::gpu::LocalLoadOpAdaptor adaptor, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter) const { + auto loc = op.getLoc(); + auto dstEnc = cast(op.getType().getEncoding()); + auto sharedLayout = + cast(op.getSrc().getType().getEncoding()); + + auto mmaLayout = cast(dstEnc.getParent()); + Value res = lowerSharedToDotOperandMMA(op, adaptor, typeConverter, rewriter, + mmaLayout, dstEnc); + + rewriter.replaceOp(op, res); + return success(); + } +}; + +struct LocalAllocOpConversion + : public ConvertOpToLLVMPattern { + LocalAllocOpConversion(const LLVMTypeConverter &converter, + const NVIDIA::TargetInfo &targetInfo, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(converter, benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(triton::gpu::LocalAllocOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!op.getSrc()) + return failure(); + auto mmaEncoding = dyn_cast( + op.getSrc().getType().getEncoding()); + if (!mmaEncoding) + return failure(); + auto sharedLayout = + cast(op.getType().getEncoding()); + if (!sharedLayout.getHasLeadingOffset()) + return failure(); + int swizzleByteSize = 0; + if (sharedLayout.getPerPhase() == 4 && sharedLayout.getMaxPhase() == 2) + swizzleByteSize = 32; + else if (sharedLayout.getPerPhase() == 2 && sharedLayout.getMaxPhase() == 4) + swizzleByteSize = 64; + else if (sharedLayout.getPerPhase() == 1 && sharedLayout.getMaxPhase() == 8) + swizzleByteSize = 128; + else + return failure(); + + auto *ctx = rewriter.getContext(); + Location loc = op->getLoc(); + + RankedTensorType srcTy = op.getSrc().getType(); + SmallVector shape = + convertType(srcTy.getShape()); + auto order = sharedLayout.getOrder(); + if (!targetInfo.canUseStMatrix(srcTy, shape, shape, order, + swizzleByteSize)) { + return failure(); + } + auto layout = chooseStMatrixLayout(rewriter.getContext(), srcTy, shape, + shape, order, swizzleByteSize); + Value smemBase = 
LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op); + auto smemPtrTy = ptr_ty(ctx, 3); + + auto kRegister = str_attr("register"); + auto kLane = str_attr("lane"); + auto kWarp = str_attr("warp"); + auto kBlock = str_attr("block"); + + Value threadId = getThreadId(rewriter, loc); + Value threadsPerWarp = i32_val(layout.getInDimSize(kLane)); + Value laneId = urem(threadId, threadsPerWarp); + Value warpId = udiv(threadId, threadsPerWarp); + + auto regBase = applyLinearLayout(loc, rewriter, layout, + {{kRegister, i32_val(0)}, + {kLane, laneId}, + {kWarp, warpId}, + {kBlock, i32_val(0)}})[0] + .second; + auto srcVals = unpackLLElements(loc, adaptor.getSrc(), rewriter); + auto srcVec = layout.getNumConsecutiveInOut(); + Type llvmElemTy = typeConverter->convertType(srcTy.getElementType()); + for (int i = 0; i < srcVals.size(); i += srcVec) { + auto regIdx = + layout.apply({{kRegister, i}, {kLane, 0}, {kWarp, 0}, {kBlock, 0}})[0] + .second; + Value offset = xor_(regBase, i32_val(regIdx)); + auto vecAddr = gep(smemPtrTy, llvmElemTy, smemBase, offset); + vecAddr.setInbounds(true); + SmallVector inValsVec; + for (int j = 0; j < srcVec; j++) + inValsVec.push_back(srcVals[i + j]); + Value valsVec = packLLVector(loc, inValsVec, rewriter); + targetInfo.storeMatrixShared(rewriter, loc, vecAddr, valsVec); + } + + auto resultTy = cast(op.getType()); + auto shapePerCTA = getShapePerCTA(sharedLayout, resultTy.getShape()); + auto smemObj = SharedMemoryObject(smemBase, llvmElemTy, shapePerCTA, + sharedLayout, loc, rewriter); + auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter); + rewriter.replaceOp(op, retVal); + return success(); + } + +private: + const NVIDIA::TargetInfo &targetInfo; +}; +} // namespace + +void mlir::triton::NVIDIA::populateMemoryOpToLLVMPatterns( + LLVMTypeConverter &typeConverter, const TargetInfo &targetInfo, + RewritePatternSet &patterns, PatternBenefit benefit) { + // Backend optimized memory ops get higher benefit + patterns.add(typeConverter, targetInfo, + benefit.getBenefit() + 1); + patterns.add(typeConverter, benefit.getBenefit() + 1); + mlir::triton::populateMemoryOpToLLVMPatterns(typeConverter, targetInfo, + patterns, benefit); +} diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/PatternTritonGPUOpToLLVM.h b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/PatternTritonGPUOpToLLVM.h index 4060378fa4..a5bdacff9f 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/PatternTritonGPUOpToLLVM.h +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/PatternTritonGPUOpToLLVM.h @@ -23,6 +23,11 @@ void populateConvertLayoutOpToLLVMPatterns(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, PatternBenefit benefit); +void populateMemoryOpToLLVMPatterns(LLVMTypeConverter &typeConverter, + const TargetInfo &targetInfo, + RewritePatternSet &patterns, + PatternBenefit benefit); + void populateConvertLayoutOpToLLVMOptimizedPatterns( LLVMTypeConverter &typeConverter, const TargetInfo &targetInfo, RewritePatternSet &patterns, PatternBenefit benefit); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp index d749d44bc4..089e4aaebb 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp @@ -123,9 +123,6 @@ struct ConvertTritonGPUToLLVM RewritePatternSet patterns(context); int benefit = patternBenefitPrioritizeOverLLVMConversions; - 
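// The LocalAllocOpConversion registered above only fires for the three
// swizzling configurations it can express with stmatrix; every other
// (perPhase, maxPhase) pair returns failure() and falls back. This sketch just
// restates that mapping (plain C++; swizzleByteSizeFor is an illustrative
// name, not from the patch).
#include <cassert>

static int swizzleByteSizeFor(unsigned perPhase, unsigned maxPhase) {
  if (perPhase == 4 && maxPhase == 2)
    return 32;
  if (perPhase == 2 && maxPhase == 4)
    return 64;
  if (perPhase == 1 && maxPhase == 8)
    return 128;
  return 0; // unsupported: the pattern does not match
}

int main() {
  assert(swizzleByteSizeFor(1, 8) == 128);
  assert(swizzleByteSizeFor(2, 2) == 0);
  return 0;
}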
mlir::triton::NVIDIA::populateConvertLayoutOpToLLVMOptimizedPatterns( - typeConverter, targetInfo, patterns, - patternBenefitConvertLayoutOptimizedPattern); mlir::triton::NVIDIA::populateConvertLayoutOpToLLVMPatterns( typeConverter, targetInfo, patterns, benefit); mlir::triton::NVIDIA::populateTMAToLLVMPatterns(typeConverter, targetInfo, @@ -171,8 +168,8 @@ struct ConvertTritonGPUToLLVM benefit); mlir::triton::populateAssertOpToLLVMPattern(typeConverter, patterns, targetInfo, benefit); - mlir::triton::populateMemoryOpToLLVMPattern(typeConverter, targetInfo, - patterns, benefit); + mlir::triton::NVIDIA::populateMemoryOpToLLVMPatterns( + typeConverter, targetInfo, patterns, benefit); mlir::triton::populateMakeRangeOpToLLVMPattern(typeConverter, targetInfo, patterns, benefit); mlir::triton::NVIDIA::populateUpcastMXFPToLLVMPatterns( diff --git a/unittest/Tools/LinearLayoutTest.cpp b/unittest/Tools/LinearLayoutTest.cpp index 897172fd6d..129836d541 100644 --- a/unittest/Tools/LinearLayoutTest.cpp +++ b/unittest/Tools/LinearLayoutTest.cpp @@ -530,6 +530,59 @@ TEST_F(LinearLayoutTest, InvertAndCompose_BroadcastedDims2) { EXPECT_EQ(c.compose(b), a.transposeOuts(llvm::to_vector(b.getOutDimNames()))); } +TEST_F(LinearLayoutTest, InvertAndCompose_IdentityInDim) { + SmallVector outDims = {S("dim0"), S("dim1"), S("dim2"), + S("dim3"), S("dim4"), S("dim5"), + S("dim6"), S("dim7"), S("dim8")}; + + LinearLayout src({{S("register"), + { + {0, 0, 0, 0, 0, 0, 0, 0, 1}, + {0, 0, 0, 0, 0, 0, 0, 1, 0}, + }}, + {S("lane"), + { + {0, 0, 0, 0, 0, 0, 1, 0, 0}, + {0, 0, 0, 0, 0, 1, 0, 0, 0}, + {0, 0, 0, 0, 1, 0, 0, 0, 0}, + {0, 0, 0, 1, 0, 0, 0, 0, 0}, + {0, 0, 1, 0, 0, 0, 0, 0, 0}, + }}, + {S("warp"), + { + {0, 1, 0, 0, 0, 0, 0, 0, 0}, + {1, 0, 0, 0, 0, 0, 0, 0, 0}, + }}, + {S("block"), {}}}, + outDims); + LinearLayout dst({{S("register"), + { + {0, 0, 0, 0, 0, 0, 0, 0, 1}, + {0, 0, 0, 0, 0, 0, 0, 1, 0}, + }}, + {S("lane"), + { + {1, 0, 0, 0, 0, 0, 0, 0, 0}, + {0, 1, 0, 0, 0, 0, 0, 0, 0}, + {0, 0, 1, 0, 0, 0, 0, 0, 0}, + {0, 0, 0, 1, 0, 0, 0, 0, 0}, + {0, 0, 0, 0, 1, 0, 0, 0, 0}, + }}, + {S("warp"), + { + {0, 0, 0, 0, 0, 1, 0, 0, 0}, + {0, 0, 0, 0, 0, 0, 1, 0, 0}, + }}, + {S("block"), {}}}, + outDims); + + LinearLayout cvt = dst.invertAndCompose(src); + SmallVector> k = { + {S("register"), 3}, {S("lane"), 0}, {S("warp"), 2}, {S("block"), 0}}; + + EXPECT_EQ(dst.apply(k), src.apply(cvt.apply(k))); +} + TEST_F(LinearLayoutTest, NumConsecutiveInOut) { EXPECT_EQ( 1,
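// The InvertAndCompose_IdentityInDim test above checks that
// dst.apply(k) == src.apply(cvt.apply(k)) for cvt = dst.invertAndCompose(src).
// A LinearLayout applies its basis vectors with XOR over the set bits of each
// input index (linear algebra over GF(2)). This sketch flattens everything to
// a single integer index; the real class tracks named input/output dimensions
// (applyLayout is an illustrative name, not from the patch).
#include <cassert>
#include <cstdint>
#include <vector>

static uint32_t applyLayout(const std::vector<uint32_t> &bases, uint32_t index) {
  uint32_t out = 0;
  for (size_t b = 0; b < bases.size(); ++b)
    if (index & (1u << b))
      out ^= bases[b];
  return out;
}

int main() {
  // Bit 0 -> output bit 1, bit 1 -> output bit 0 (a 2-bit transpose).
  std::vector<uint32_t> layout = {0b10, 0b01};
  assert(applyLayout(layout, 0b01) == 0b10);
  // Because the map is linear over GF(2), 0b11 XORs both basis vectors.
  assert(applyLayout(layout, 0b11) == 0b11);
  return 0;
}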