diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index dfd1a25a13..6c93538a24 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -106,7 +106,7 @@ jobs: run: | if [ x"${{ github.repository }}" == x"triton-lang/triton" ]; then echo '::set-output name=matrix-CUDA::[["a100-runner-set"], ["h100-runner-set"]]' - echo '::set-output name=matrix-HIP::[["self-hosted", "gfx90a"]]' + echo '::set-output name=matrix-HIP::[["self-hosted", "gfx90a"], ["self-hosted", "gfx942"]]' echo '::set-output name=matrix-MACOS::[["macos-latest"]]' else echo '::set-output name=matrix-CUDA::["ubuntu-latest"]' diff --git a/.github/workflows/integration-tests.yml.in b/.github/workflows/integration-tests.yml.in index bfed9fcd47..1b4c46a26c 100644 --- a/.github/workflows/integration-tests.yml.in +++ b/.github/workflows/integration-tests.yml.in @@ -115,7 +115,7 @@ jobs: run: | if [ x"${{ github.repository }}" == x"triton-lang/triton" ]; then echo '::set-output name=matrix-CUDA::[["a100-runner-set"], ["h100-runner-set"]]' - echo '::set-output name=matrix-HIP::[["self-hosted", "gfx90a"]]' + echo '::set-output name=matrix-HIP::[["self-hosted", "gfx90a"], ["self-hosted", "gfx942"]]' echo '::set-output name=matrix-MACOS::[["macos-latest"]]' else echo '::set-output name=matrix-CUDA::["ubuntu-latest"]' diff --git a/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/include/triton/Conversion/TritonGPUToLLVM/Utility.h index 8ea01d2f1f..89e0b23e4c 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/Utility.h +++ b/include/triton/Conversion/TritonGPUToLLVM/Utility.h @@ -1299,9 +1299,8 @@ inline DenseMap getSwizzledSharedPtrs( idxCol = urem(idxCol, numElemsPerSwizzlingRowVal); strideRow = numElemsPerSwizzlingRowVal; } - if (auto add = dyn_cast_or_null(idxCol.getDefiningOp())) { - if (auto _cst = dyn_cast_or_null( - add.getRhs().getDefiningOp())) { + if (auto add = idxCol.getDefiningOp()) { + if (auto _cst = add.getRhs().getDefiningOp()) { unsigned cst = cast(_cst.getValue()).getValue().getSExtValue(); unsigned key = cst % (outVec * maxPhase); @@ -1310,9 +1309,8 @@ inline DenseMap getSwizzledSharedPtrs( immedateOffCol = cst / (outVec * maxPhase) * (outVec * maxPhase); } } - if (auto add = dyn_cast_or_null(idxRow.getDefiningOp())) { - if (auto _cst = dyn_cast_or_null( - add.getRhs().getDefiningOp())) { + if (auto add = idxRow.getDefiningOp()) { + if (auto _cst = add.getRhs().getDefiningOp()) { unsigned cst = mlir::cast(_cst.getValue()).getValue().getSExtValue(); unsigned key = cst % (perPhase * maxPhase); diff --git a/lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp index 6172c614aa..ae3aa63b85 100644 --- a/lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp @@ -67,6 +67,42 @@ struct FuncOpConversion : public ConvertOpToLLVMPattern { return amendedFuncOp; } + // Map the MLIR attribute `tt.nv_tma_desc` to the appropriate LLVM and NVVM + // attributes. 
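+  // For each kernel argument tagged with `tt.nv_tma_desc = 1`, the helper
+  // below adds `llvm.byval = !llvm.array<128 x i8>`, `nvvm.grid_constant`,
+  // and `llvm.align = 64`, so the 128-byte CUtensorMap is passed by value
+  // as a grid constant rather than through a global-memory pointer.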
+ static void handleByvalTmaDescArgs(LLVM::LLVMFuncOp &llvmFuncOp) { + const bool isKernel = LLVM::isKernel(llvmFuncOp); + for (unsigned i = 0; i < llvmFuncOp.getNumArguments(); ++i) { + const auto attrs = llvmFuncOp.getArgAttrDict(i); + if (!attrs) { + continue; + } + + for (const auto &attr : attrs) { + if (attr.getName() == "tt.nv_tma_desc") { + const auto i32_type = + mlir::IntegerType::get(llvmFuncOp.getContext(), 32); + assert(attr.getValue() == mlir::IntegerAttr::get(i32_type, 1)); + assert(isKernel && + "tt.nv_tma_desc is not supported for device functions"); + + // See + // https://github.com/google/jax/blob/main/jaxlib/mosaic/gpu/passes.cc + mlir::BlockArgument arg = llvmFuncOp.getArgument(i); + const auto byteType = + mlir::IntegerType::get(llvmFuncOp.getContext(), 8); + const auto arrayType = mlir::LLVM::LLVMArrayType::get( + llvmFuncOp.getContext(), byteType, 128); + llvmFuncOp.setArgAttr(i, "llvm.byval", + mlir::TypeAttr::get(arrayType)); + llvmFuncOp.setArgAttr(i, "nvvm.grid_constant", + mlir::UnitAttr::get(llvmFuncOp.getContext())); + llvmFuncOp.setArgAttr(i, "llvm.align", + mlir::IntegerAttr::get(i32_type, 64)); + } + } + } + } + LogicalResult matchAndRewrite(triton::FuncOp funcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { @@ -105,6 +141,10 @@ struct FuncOpConversion : public ConvertOpToLLVMPattern { newFuncOp->setAttr("nvvm.reqntid", rewriter.getDenseI32ArrayAttr(32 * numWarps)); rewriter.eraseOp(funcOp); + + // Add attributes for by-value TMA descriptor args (nvidia) + handleByvalTmaDescArgs(newFuncOp); + return success(); } diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp index 8058f9d708..458a1ed9d9 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp @@ -164,7 +164,7 @@ applyLinearLayout(Location loc, RewriterBase &rewriter, // Manually constant-fold the layout where possible. 
SmallVector> constantIns; for (auto [inDimName, idx] : indices) { - if (auto constant = dyn_cast(idx.getDefiningOp())) { + if (auto constant = idx.getDefiningOp()) { constantIns.push_back( {inDimName, cast(constant.getValue()).getInt()}); } else { @@ -184,7 +184,7 @@ applyLinearLayout(Location loc, RewriterBase &rewriter, } for (auto [inDimName, idx] : indices) { - if (isa(idx.getDefiningOp())) { + if (idx.getDefiningOp()) { continue; } diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp index e96cd89f65..c72b92171f 100644 --- a/lib/Dialect/Triton/IR/Ops.cpp +++ b/lib/Dialect/Triton/IR/Ops.cpp @@ -91,8 +91,7 @@ struct CanonicalizeMaskedLoadPattern : public OpRewritePattern { if (!mask) return failure(); - auto constantMask = - llvm::dyn_cast_or_null(mask.getDefiningOp()); + auto constantMask = mask.getDefiningOp(); if (!constantMask) return failure(); @@ -159,8 +158,7 @@ struct CanonicalizeMaskedStorePattern : public OpRewritePattern { if (!mask) return failure(); - auto constantMask = - llvm::dyn_cast_or_null(mask.getDefiningOp()); + auto constantMask = mask.getDefiningOp(); if (!constantMask) return failure(); diff --git a/lib/Dialect/Triton/Transforms/Combine.cpp b/lib/Dialect/Triton/Transforms/Combine.cpp index c5d638754e..33c4516b47 100644 --- a/lib/Dialect/Triton/Transforms/Combine.cpp +++ b/lib/Dialect/Triton/Transforms/Combine.cpp @@ -113,8 +113,7 @@ class CombineSelectMaskedLoadPattern : public RewritePattern { Value falseValue = selectOp.getFalseValue(); Value condSelect = selectOp.getCondition(); - auto *loadOpCandidate = trueValue.getDefiningOp(); - auto loadOp = llvm::dyn_cast_or_null(loadOpCandidate); + auto loadOp = trueValue.getDefiningOp(); if (!loadOp) return failure(); @@ -122,8 +121,7 @@ class CombineSelectMaskedLoadPattern : public RewritePattern { if (!mask) return failure(); - auto *splatOpCandidate = mask.getDefiningOp(); - auto splatOp = llvm::dyn_cast_or_null(splatOpCandidate); + auto splatOp = mask.getDefiningOp(); if (!splatOp) return failure(); @@ -175,26 +173,21 @@ class CombineBroadcastMulReducePattern : public RewritePattern { if (!isReduceAdd) return failure(); // operand of reduce has to be mul - auto mulOp = llvm::dyn_cast_or_null( - reduceOp.getOperand(0).getDefiningOp()); + auto mulOp = reduceOp.getOperand(0).getDefiningOp(); if (!mulOp) return failure(); // mul operand has to be broadcast - auto broadcastLhsOp = llvm::dyn_cast_or_null( - mulOp.getOperand(0).getDefiningOp()); + auto broadcastLhsOp = mulOp.getOperand(0).getDefiningOp(); if (!broadcastLhsOp) return failure(); - auto broadcastRhsOp = llvm::dyn_cast_or_null( - mulOp.getOperand(1).getDefiningOp()); + auto broadcastRhsOp = mulOp.getOperand(1).getDefiningOp(); if (!broadcastRhsOp) return failure(); // broadcast operand is expand dims - auto expandLhsOp = llvm::dyn_cast_or_null( - broadcastLhsOp.getSrc().getDefiningOp()); + auto expandLhsOp = broadcastLhsOp.getSrc().getDefiningOp(); if (!expandLhsOp) return failure(); - auto expandRhsOp = llvm::dyn_cast_or_null( - broadcastRhsOp.getSrc().getDefiningOp()); + auto expandRhsOp = broadcastRhsOp.getSrc().getDefiningOp(); if (!expandRhsOp) return failure(); // get not-broadcast dimensions diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeThreadLocality.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeThreadLocality.cpp index 3775b4f7d8..48e5480916 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeThreadLocality.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeThreadLocality.cpp @@ -147,7 +147,7 @@ class 
TritonGPUOptimizeThreadLocalityPass return; auto argNum = yieldOpOperand.getOperandNumber(); auto oldAccum = forOp.getInitArgs()[argNum]; - auto cstOp = dyn_cast(oldAccum.getDefiningOp()); + auto cstOp = oldAccum.getDefiningOp(); if (!cstOp) return; reduceOps.insert(reduce); diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp index f835247ab1..07ef6f3f40 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp @@ -1397,16 +1397,18 @@ static std::optional dotCanBeProperlyAsync(ttng::WarpGroupDotOp dotOp, while (isa_and_nonnull( transitiveOperand.getDefiningOp()) || isa(transitiveOperand)) { - if (auto blockArg = dyn_cast(transitiveOperand)) { - assert(blockArg.getOwner() == forOp.getBody()); + auto blockArg = dyn_cast(transitiveOperand); + if (blockArg && blockArg.getOwner() == forOp.getBody()) { transitiveOperand = cast(blockArg.getOwner()->getTerminator()) .getOperand(blockArg.getArgNumber() - 1); } - transitiveOperand = transitiveOperand.getDefiningOp()->getOperand(0); + if (Operation *def = transitiveOperand.getDefiningOp()) { + transitiveOperand = def->getOperand(0); + } } return forOp.isDefinedOutsideOfLoop(transitiveOperand) || - isa(transitiveOperand.getDefiningOp()); + transitiveOperand.getDefiningOp(); }; // We don't have to call checkOperand on getC() because it's always in diff --git a/python/test/unit/hopper/test_experimental_tma.py b/python/test/unit/hopper/test_experimental_tma.py index c0228fb54a..6c687d0891 100644 --- a/python/test/unit/hopper/test_experimental_tma.py +++ b/python/test/unit/hopper/test_experimental_tma.py @@ -1,56 +1,27 @@ import pytest import torch -import tempfile import triton import triton.language as tl from triton.tools.experimental_descriptor import create_1d_tma_descriptor, create_2d_tma_descriptor -def test_descriptor_load_ttgir(): - if not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] == 9: - pytest.skip("Test requires Hopper target.") - return - device = "cuda" - SIZE = 128 +def create_tma_desc_gmem_ptr(ptr, dims, block_dims, element_size): + cpu_desc = torch.empty(128, device="cpu") + if len(dims) == 1: + triton.runtime.driver.active.utils.fill_1d_tma_descriptor(ptr, dims[0], block_dims[0], element_size, + cpu_desc.data_ptr()) + else: + triton.runtime.driver.active.utils.fill_2d_tma_descriptor(ptr, dims[0], dims[1], block_dims[0], block_dims[1], + element_size, cpu_desc.data_ptr()) + return cpu_desc.cuda() - x = torch.randn(SIZE, dtype=torch.float32, device=device) - desc = create_1d_tma_descriptor(x.data_ptr(), SIZE, SIZE, x.element_size()) - size_in_bytes = SIZE * x.element_size() - - ir = f""" - #blocked = #triton_gpu.blocked<{{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}}> - #shared = #triton_gpu.shared<{{vec = 1, perPhase = 1, maxPhase = 1, order = [0], hasLeadingOffset = false}}> - module attributes {{"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32}} {{ - tt.func public @kernel(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) attributes {{noinline = false}} {{ - %c0_i32 = arith.constant 0 : i32 - %0 = tt.make_range {{end = {SIZE} : i32, start = 0 : i32}} : tensor<{SIZE}xi32, #blocked> - %1 = triton_gpu.local_alloc : () -> !tt.memdesc<{SIZE}xf32, #shared, 
#triton_gpu.shared_memory, mutable> - %2 = triton_gpu.local_alloc : () -> !tt.memdesc<1xi64, #shared, #triton_gpu.shared_memory, mutable> - triton_nvidia_gpu.init_barrier %2, 1 : <1xi64, #shared, #triton_gpu.shared_memory, mutable> - %true = arith.constant 1 : i1 - triton_nvidia_gpu.barrier_expect %2, {size_in_bytes}, %true : <1xi64, #shared, #triton_gpu.shared_memory, mutable> - triton_nvidia_gpu.async_tma_copy_global_to_local %arg1[%c0_i32] %1, %2, %true : , <1xi64, #shared, #triton_gpu.shared_memory, mutable> -> <{SIZE}xf32, #shared, #triton_gpu.shared_memory, mutable> - triton_nvidia_gpu.wait_barrier %2, %c0_i32 : <1xi64, #shared, #triton_gpu.shared_memory, mutable> - %3 = triton_gpu.local_load %1 : !tt.memdesc<{SIZE}xf32, #shared, #triton_gpu.shared_memory, mutable> -> tensor<{SIZE}xf32, #blocked> - %4 = tt.splat %arg0 : !tt.ptr -> tensor<{SIZE}x!tt.ptr, #blocked> - %5 = tt.addptr %4, %0 : tensor<{SIZE}x!tt.ptr, #blocked>, tensor<{SIZE}xi32, #blocked> - tt.store %5, %3 : tensor<{SIZE}x!tt.ptr, #blocked> - tt.return - }} - }} - """ - with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: - f.write(ir) - f.flush() - kernel = triton.compile(f.name) - z_tri = torch.empty_like(x) - kernel[(1, 1, 1)](z_tri, desc) - assert torch.equal(x, z_tri) +TMA_FENCE_ASM: tl.constexpr = "fence.proxy.tensormap::generic.acquire.gpu [$1], 128; // $0 dummy reg" -def test_experimetal_descriptor_load(): +@pytest.mark.parametrize("byval_tma", [True, False]) +def test_experimetal_descriptor_load(byval_tma): if not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] == 9: pytest.skip("Test requires Hopper target.") return @@ -58,29 +29,34 @@ def test_experimetal_descriptor_load(): SIZE = 128 @triton.jit - def kernel(Z, desc, SIZE: tl.constexpr): + def kernel(Z, desc, SIZE: tl.constexpr, BYVAL_TMA: tl.constexpr): + if not BYVAL_TMA: + tl.inline_asm_elementwise(TMA_FENCE_ASM, "=r, l", [desc], dtype=tl.int32, is_pure=False, pack=1) off_desc = 0 off = tl.arange(0, SIZE) x = tl._experimental_descriptor_load(desc, [off_desc], [SIZE], Z.dtype.element_ty) tl.store(Z + off, x) x = torch.randn(SIZE, dtype=torch.float32, device=device) - desc = create_1d_tma_descriptor(x.data_ptr(), SIZE, SIZE, x.element_size()) + if byval_tma: + desc = create_1d_tma_descriptor(x.data_ptr(), SIZE, SIZE, x.element_size()) + else: + desc = create_tma_desc_gmem_ptr(x.data_ptr(), [SIZE], [SIZE], x.element_size()) z_tri = torch.empty_like(x) - kernel[(1, )](z_tri, desc, SIZE=SIZE, num_warps=4) + compiled_kernel = kernel[(1, )](z_tri, desc, SIZE=SIZE, BYVAL_TMA=byval_tma, num_warps=4) assert torch.equal(x, z_tri) + if byval_tma: + assert ".param .align 64 .b8" in compiled_kernel.asm["ptx"] @triton.jit def matmul_kernel_tma(a_desc_ptr, b_desc_ptr, c_desc_ptr, # - M, N, K, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr): - # TODO(embg) remove TMA fence after __grid_constant__ lands - tl.inline_asm_elementwise("fence.proxy.tensormap::generic.acquire.gpu [$1], 128; // $0 dummy reg", "=r, l", - [a_desc_ptr], dtype=tl.int32, is_pure=False, pack=1) - tl.inline_asm_elementwise("fence.proxy.tensormap::generic.acquire.gpu [$1], 128; // $0 dummy reg", "=r, l", - [b_desc_ptr], dtype=tl.int32, is_pure=False, pack=1) - tl.inline_asm_elementwise("fence.proxy.tensormap::generic.acquire.gpu [$1], 128; // $0 dummy reg", "=r, l", - [c_desc_ptr], dtype=tl.int32, is_pure=False, pack=1) + M, N, K, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + BYVAL_TMA: 
tl.constexpr): + if not BYVAL_TMA: + tl.inline_asm_elementwise(TMA_FENCE_ASM, "=r, l", [a_desc_ptr], dtype=tl.int32, is_pure=False, pack=1) + tl.inline_asm_elementwise(TMA_FENCE_ASM, "=r, l", [b_desc_ptr], dtype=tl.int32, is_pure=False, pack=1) + tl.inline_asm_elementwise(TMA_FENCE_ASM, "=r, l", [c_desc_ptr], dtype=tl.int32, is_pure=False, pack=1) pid = tl.program_id(axis=0) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) @@ -101,7 +77,8 @@ def matmul_kernel_tma(a_desc_ptr, b_desc_ptr, c_desc_ptr, # @pytest.mark.parametrize("num_stages", [1, 4]) @pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(32, 32, 32), (128, 64, 64), (128, 128, 64), (128, 256, 64)]) -def test_experimental_tma_matmul(num_stages, BLOCK_M, BLOCK_N, BLOCK_K): +@pytest.mark.parametrize("byval_tma", [True, False]) +def test_experimental_tma_matmul(num_stages, BLOCK_M, BLOCK_N, BLOCK_K, byval_tma): if not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] == 9: pytest.skip("Test requires Hopper target.") return @@ -111,13 +88,20 @@ def test_experimental_tma_matmul(num_stages, BLOCK_M, BLOCK_N, BLOCK_K): A = torch.randn((M, K), dtype=torch.float16, device=device) B = torch.randn((K, N), dtype=torch.float16, device=device) C = torch.empty((M, N), dtype=torch.float16, device=device) - desc_a = create_2d_tma_descriptor(A.data_ptr(), M, K, BLOCK_M, BLOCK_K, A.element_size()) - desc_b = create_2d_tma_descriptor(B.data_ptr(), K, N, BLOCK_K, BLOCK_N, B.element_size()) - desc_c = create_2d_tma_descriptor(C.data_ptr(), M, N, BLOCK_M, BLOCK_N, C.element_size()) + if byval_tma: + desc_a = create_2d_tma_descriptor(A.data_ptr(), M, K, BLOCK_M, BLOCK_K, A.element_size()) + desc_b = create_2d_tma_descriptor(B.data_ptr(), K, N, BLOCK_K, BLOCK_N, B.element_size()) + desc_c = create_2d_tma_descriptor(C.data_ptr(), M, N, BLOCK_M, BLOCK_N, C.element_size()) + else: + desc_a = create_tma_desc_gmem_ptr(A.data_ptr(), [M, K], [BLOCK_M, BLOCK_K], A.element_size()) + desc_b = create_tma_desc_gmem_ptr(B.data_ptr(), [K, N], [BLOCK_K, BLOCK_N], B.element_size()) + desc_c = create_tma_desc_gmem_ptr(C.data_ptr(), [M, N], [BLOCK_M, BLOCK_N], C.element_size()) kernel = matmul_kernel_tma[(triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1, - 1)](desc_a, desc_b, desc_c, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, num_warps=8, - num_stages=num_stages) + 1)](desc_a, desc_b, desc_c, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, BYVAL_TMA=byval_tma, + num_warps=8, num_stages=num_stages) ref_out = torch.matmul(A.to(torch.float32), B.to(torch.float32)).to(torch.float16) torch.testing.assert_close(ref_out, C, rtol=1e-3, atol=1e-3) if BLOCK_M >= 64 and BLOCK_N >= 64: assert "stmatrix.sync.aligned.m8n8.x4.shared.b16" in kernel.asm["ptx"] + if byval_tma: + assert ".param .align 64 .b8" in kernel.asm["ptx"] diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 419907da93..b05fdda8fa 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -1578,6 +1578,10 @@ def serialized_add(data, Lock, SEM: tl.constexpr): tl.store(ptrs, tl.load(ptrs) + 1.0) + # insert barrier to set a fence between tl.store and + # tl.atomic_xchg in a block. 
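+    # Otherwise the unlocking tl.atomic_xchg below could be reordered ahead of
+    # the data store, and another block could acquire the lock and read stale data.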
+ tl.debug_barrier() + # release lock tl.atomic_xchg(Lock, 0) diff --git a/python/test/unit/runtime/test_cache.py b/python/test/unit/runtime/test_cache.py index b15f58df0c..7baf4b3c39 100644 --- a/python/test/unit/runtime/test_cache.py +++ b/python/test/unit/runtime/test_cache.py @@ -543,3 +543,38 @@ def inc_counter(*args, **kwargs): # test that we can't preload a mismatched kernel with pytest.raises(RuntimeError, match="Specialization data is for"): kernel_sub.preload(specialization_data) + + +def test_hooks(fresh_triton_cache) -> None: + + @triton.jit + def kernel_add(a, b, o, N: tl.constexpr, type: tl.constexpr): + idx = tl.arange(0, N) + tl.device_assert(idx < 32, "idx < 32") + tl.store(o + idx, tl.load(a + idx) + tl.load(b + idx)) + + # get the serialized specialization data + specialization_data = None + is_warmup = False + key = 0 + + def cache_hook(*args, **kwargs): + nonlocal specialization_data + specialization_data = kwargs["compile"]["specialization_data"] + nonlocal is_warmup + is_warmup = kwargs["compile"]["is_warmup"] + nonlocal key + key = kwargs["compile"]["key"] + + specialization_data_compiled = None + + def compiled_hook(*args, **kwargs): + nonlocal specialization_data_compiled + specialization_data_compiled = kwargs["compile"]["specialization_data"] + + JITFunction.cache_hook = cache_hook + JITFunction.compiled_hook = compiled_hook + kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, tl.float32, grid=(1, )) + assert specialization_data is not None and specialization_data_compiled == specialization_data + assert is_warmup is True + assert key in kernel_add.cache[torch.xpu.current_device()] diff --git a/python/triton/compiler/code_generator.py b/python/triton/compiler/code_generator.py index 0ae66922f1..96b7346ac5 100644 --- a/python/triton/compiler/code_generator.py +++ b/python/triton/compiler/code_generator.py @@ -9,7 +9,7 @@ from .. 
import language from .._C.libtriton import ir from ..language import constexpr, tensor, str_to_ty -from ..language.core import _unwrap_if_constexpr +from ..language.core import _unwrap_if_constexpr, nv_tma_desc_type from ..runtime.jit import _normalize_ty, get_jit_fn_file_line # ideally we wouldn't need any runtime component from ..runtime import JITFunction @@ -409,6 +409,11 @@ def visit_FunctionDef(self, node): if i in self.attributes: for name, value in self.attributes[i]: self.fn.set_arg_attr(idx, name, value) + + # Mark this argument as a pass-by-value TMA descriptor (nvidia) + if isinstance(self.prototype.param_types[idx], nv_tma_desc_type): + self.fn.set_arg_attr(idx, "tt.nv_tma_desc", 1) + arg_values.append(tensor(self.fn.args(idx), self.prototype.param_types[idx])) idx += 1 diff --git a/python/triton/language/__init__.py b/python/triton/language/__init__.py index 6e8803638e..0a84bd86a5 100644 --- a/python/triton/language/__init__.py +++ b/python/triton/language/__init__.py @@ -84,6 +84,7 @@ permute, pi32_t, pointer_type, + nv_tma_desc_type, program_id, range, reduce, @@ -207,6 +208,7 @@ "philox_impl", "pi32_t", "pointer_type", + "nv_tma_desc_type", "program_id", "rand", "rand4x", @@ -259,6 +261,10 @@ def str_to_ty(name): const = True ty = str_to_ty(name) return pointer_type(element_ty=ty, const=const) + + if name == "nvTmaDesc": + return nv_tma_desc_type() + tys = { "fp8e4nv": float8e4nv, "fp8e4b8": float8e4b8, diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 971d8a9f3a..cf86e9296a 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -568,7 +568,7 @@ def __init__(self, element_ty: dtype, address_space: int = 1, const: bool = Fals self.name = f'pointer<{element_ty}>' if not const else f'const_pointer<{element_ty}>' def to_ir(self, builder: ir.builder) -> ir.pointer_type: - return builder.get_ptr_ty(self.element_ty.to_ir(builder), 1) + return builder.get_ptr_ty(self.element_ty.to_ir(builder), self.address_space) def __str__(self): return self.name @@ -595,6 +595,13 @@ def scalar(self): return self +class nv_tma_desc_type(pointer_type): + + def __init__(self): + super().__init__(uint8, const=True, address_space=0) + self.name = 'nv_tma_desc_type' + + class block_type(dtype): def __init__(self, element_ty: dtype, shape: List): diff --git a/python/triton/runtime/build.py b/python/triton/runtime/build.py index 30670f7404..ce5d91c237 100644 --- a/python/triton/runtime/build.py +++ b/python/triton/runtime/build.py @@ -63,11 +63,11 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries): cc_cmd = [cxx] if icpx is not None: cc_cmd += ["-fsycl"] - cc_cmd += ["-O3"] else: - cc_cmd = [cc, "-O3"] + cc_cmd = [cc] - cc_cmd += [src, "-shared", "-fPIC", "-o", so] + # for -Wno-psabi, see https://gcc.gnu.org/bugzilla/show_bug.cgi?id=111047 + cc_cmd += [src, "-O3", "-shared", "-fPIC", "-Wno-psabi", "-o", so] cc_cmd += [f'-l{lib}' for lib in libraries] cc_cmd += [f"-L{dir}" for dir in library_dirs] cc_cmd += [f"-I{dir}" for dir in include_dirs if dir is not None] diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index afdb41d0b3..d65510624f 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -306,6 +306,8 @@ def mangle_type(arg, is_const=False): return "i64" elif isinstance(arg, float): return "fp32" + elif hasattr(arg, "tma_desc_cpu_ptr"): + return "nvTmaDesc" else: # dtypes are hashable so we can memoize this mapping: dsk = (arg.dtype, is_const) @@ -440,6 +442,9 @@ def 
create_function_from_signature(sig, kparams): class JITFunction(KernelInterface[T]): # Hook for inspecting compiled functions and modules cache_hook = None + # Hook to signal that a kernel is done compiling and inspect compiled function. + # cache_hook will always be called before compilation and compiled_hook after. + compiled_hook = None divisibility = 16 @staticmethod @@ -523,8 +528,11 @@ def _call_hook( constants, options, configs, + is_warmup, + before, ): - if JITFunction.cache_hook is None: + hook = JITFunction.cache_hook if before else JITFunction.compiled_hook + if hook is None: return False name = self.fn.__name__ @@ -553,14 +561,15 @@ def __init__(self, module, name, jit_function): 'extern_libs': options.extern_libs, 'configs': configs, 'specialization_data': specialization_data, + 'is_warmup': is_warmup, } - return JITFunction.cache_hook( + return hook( key=key, repr=repr, fn=JitFunctionInfo(module, name, self), compile={"key": key, **kwargs}, - is_manual_warmup=False, + is_manual_warmup=is_warmup, already_compiled=False, ) @@ -641,7 +650,7 @@ def run(self, *args, grid, warmup, **kwargs): if callable(arg): raise TypeError(f"Callable constexpr at index {i} is not supported") - if self._call_hook(key, signature, device, constants, options, configs): + if self._call_hook(key, signature, device, constants, options, configs, warmup, before=True): return None # compile the kernel src = self.ASTSource(self, signature, constants, configs[0]) @@ -651,6 +660,7 @@ def run(self, *args, grid, warmup, **kwargs): options=options.__dict__, ) self.cache[device][key] = kernel + self._call_hook(key, signature, device, constants, options, configs, warmup, before=False) # Check that used global values have not changed. not_present = object() diff --git a/python/triton/tools/experimental_descriptor.py b/python/triton/tools/experimental_descriptor.py index c1265ba04b..fba3366c0c 100644 --- a/python/triton/tools/experimental_descriptor.py +++ b/python/triton/tools/experimental_descriptor.py @@ -3,26 +3,30 @@ import triton -# Constructs a 1D TMA descriptor in mutable GPU memory. -# -# Note: on the first use of a new descriptor, each SM must invalidate the descriptor's -# address in TMA cache via fence.proxy.tensormap::generic.acquire.gpu. +class TmaDescKernelParam: + TMA_DESC_SIZE = 128 + + def __init__(self, ptr, dims, block_dims, element_size): + self.desc = torch.empty(self.TMA_DESC_SIZE, dtype=torch.int8, device="cpu") + assert len(dims) == len(block_dims) + assert 1 <= len(dims) <= 2 + assert self.desc.data_ptr() % 64 == 0 + + if len(dims) == 1: + triton.runtime.driver.active.utils.fill_1d_tma_descriptor(ptr, dims[0], block_dims[0], element_size, + self.desc.data_ptr()) + else: + triton.runtime.driver.active.utils.fill_2d_tma_descriptor(ptr, dims[0], dims[1], block_dims[0], + block_dims[1], element_size, self.desc.data_ptr()) + + # Return a CUtensorMap* pointer in host memory + def tma_desc_cpu_ptr(self): + return self.desc.data_ptr() + + def create_1d_tma_descriptor(ptr, dim, block_dim, element_size): - TMA_SIZE = 128 - desc = torch.empty(TMA_SIZE, dtype=torch.int8) - triton.runtime.driver.active.utils.fill_1d_tma_descriptor(ptr, dim, block_dim, element_size, desc.data_ptr()) - gpu_desc = desc.cuda() - return gpu_desc + return TmaDescKernelParam(ptr, [dim], [block_dim], element_size) -# Constructs a 2D TMA descriptor in mutable GPU memory. 
-# -# Note: on the first use of a new descriptor, each SM must invalidate the descriptor's -# address in TMA cache via fence.proxy.tensormap::generic.acquire.gpu. def create_2d_tma_descriptor(ptr, dim1, dim0, block_dim1, block_dim0, element_size): - TMA_SIZE = 128 - desc = torch.empty(TMA_SIZE, dtype=torch.int8) - triton.runtime.driver.active.utils.fill_2d_tma_descriptor(ptr, dim1, dim0, block_dim1, block_dim0, element_size, - desc.data_ptr()) - gpu_desc = desc.cuda() - return gpu_desc + return TmaDescKernelParam(ptr, [dim1, dim0], [block_dim1, block_dim0], element_size) diff --git a/python/tutorials/09-persistent-matmul.py b/python/tutorials/09-persistent-matmul.py index 460c374d7f..fdbdbfecfb 100644 --- a/python/tutorials/09-persistent-matmul.py +++ b/python/tutorials/09-persistent-matmul.py @@ -259,14 +259,6 @@ def matmul_kernel_tma_persistent(a_desc_ptr, b_desc_ptr, c_desc_ptr, # GROUP_SIZE_M: tl.constexpr, # FP8_OUTPUT: tl.constexpr, # NUM_SMS: tl.constexpr): # - # TODO(embg) remove TMA fence after __grid_constant__ lands - tl.inline_asm_elementwise("fence.proxy.tensormap::generic.acquire.gpu [$1], 128; // $0 dummy reg", "=r, l", - [a_desc_ptr], dtype=tl.int32, is_pure=False, pack=1) - tl.inline_asm_elementwise("fence.proxy.tensormap::generic.acquire.gpu [$1], 128; // $0 dummy reg", "=r, l", - [b_desc_ptr], dtype=tl.int32, is_pure=False, pack=1) - tl.inline_asm_elementwise("fence.proxy.tensormap::generic.acquire.gpu [$1], 128; // $0 dummy reg", "=r, l", - [c_desc_ptr], dtype=tl.int32, is_pure=False, pack=1) - dtype = tl.float8e4nv if FP8_OUTPUT else tl.float16 start_pid = tl.program_id(axis=0) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) diff --git a/test/Conversion/tritonnvidiagpu_to_llvm.mlir b/test/Conversion/tritonnvidiagpu_to_llvm.mlir index c5dc86abd8..511f72bcd9 100644 --- a/test/Conversion/tritonnvidiagpu_to_llvm.mlir +++ b/test/Conversion/tritonnvidiagpu_to_llvm.mlir @@ -79,3 +79,15 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : tt.return } } + +// ----- + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { + // CHECK-LABEL: byval_tma_desc + // CHECK: llvm.align = 64 + // CHECK: llvm.byval = !llvm.array<128 x i8> + // CHECK: nvvm.grid_constant + tt.func @byval_tma_desc(%desc: !tt.ptr {tt.nv_tma_desc = 1 : i32}) { + tt.return + } +} diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/OptimizeEpilogue.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/OptimizeEpilogue.cpp index 7d24bf7581..f2818297f5 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/OptimizeEpilogue.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/OptimizeEpilogue.cpp @@ -112,7 +112,7 @@ class BypassEpilogueSMEM : public mlir::RewritePattern { chainedOps.push_back(chainedOp); } - auto cvtOp = dyn_cast(val.getDefiningOp()); + auto cvtOp = val.getDefiningOp(); if (!cvtOp) return mlir::failure(); diff --git a/third_party/nvidia/backend/driver.py b/third_party/nvidia/backend/driver.py index 90f71138bc..bf1f066d55 100644 --- a/third_party/nvidia/backend/driver.py +++ b/third_party/nvidia/backend/driver.py @@ -110,6 +110,7 @@ def ty_to_cpp(ty): "fp32": "float", "f32": "float", "fp64": "double", + "nvTmaDesc": "CUtensorMap", }[ty] @@ -121,6 +122,9 @@ def make_launcher(constants, signature, ids): def _extracted_type(ty): if ty[0] == '*': return "PyObject*" + if ty == "nvTmaDesc": + return "PyObject*" + return ty_to_cpp(ty) def format_of(ty): @@ -143,6 +147,16 @@ def format_of(ty): format = "iiiKKOOOO" + args_format args_list = ', ' + ', 
'.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else '' + internal_args_list = [] + for i, ty in signature.items(): + if ty[0] == "*": + internal_args_list.append(f"ptr_info{i}.dev_ptr") + elif ty == "nvTmaDesc": + # Note: we have to dereference the pointer + internal_args_list.append(f"*tma_ptr{i}") + else: + internal_args_list.append(f"_arg{i}") + # generate glue code params = [i for i in signature.keys() if i not in constants] src = f""" @@ -271,6 +285,52 @@ def format_of(ty): return ptr_info; }} +static inline CUtensorMap* getTmaDesc(PyObject *obj) {{ + if (sizeof(CUtensorMap*) != 8) {{ + PyErr_SetString(PyExc_SystemError, "getTmaDesc() requires 64-bit compilation"); + return NULL; + }} + + PyObject *method_handle = PyObject_GetAttrString(obj, "tma_desc_cpu_ptr"); + if (!method_handle) {{ + PyErr_SetString(PyExc_TypeError, "tma_desc_cpu_ptr() method does not exist"); + return NULL; + }} + + PyObject *empty_tuple = PyTuple_New(0); + if (!empty_tuple) {{ + Py_DECREF(method_handle); + PyErr_SetString(PyExc_SystemError, "Internal Python error!"); + return NULL; + }} + PyObject *method_ret = PyObject_Call(method_handle, empty_tuple, NULL); + Py_DECREF(empty_tuple); + Py_DECREF(method_handle); + if (!method_ret) {{ + PyErr_SetString(PyExc_SystemError, "Internal Python error!"); + return NULL; + }} + + if (!PyLong_Check(method_ret)) {{ + PyErr_SetString(PyExc_TypeError, "tma_desc_cpu_ptr() must return 64-bit int"); + Py_DECREF(method_ret); + return NULL; + }} + + uint64_t ptr_as_uint = PyLong_AsUnsignedLongLong(method_ret); + Py_DECREF(method_ret); + if (!ptr_as_uint) {{ + PyErr_SetString(PyExc_ValueError, "received NULL ptr from tma_desc_cpu_ptr()"); + return NULL; + }} + if (ptr_as_uint % 64 != 0) {{ + PyErr_SetString(PyExc_ValueError, "tma_desc_cpu_ptr() must be 64-byte aligned"); + return NULL; + }} + + return (CUtensorMap*)(ptr_as_uint); +}} + static PyObject* launch(PyObject* self, PyObject* args) {{ int gridX, gridY, gridZ; uint64_t _stream; @@ -302,9 +362,10 @@ def format_of(ty): }} // raise exception asap - {"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])}; + {"".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])}; + {"".join([f"CUtensorMap* tma_ptr{i} = getTmaDesc(_arg{i}); if (!tma_ptr{i}) return NULL;" if ty == "nvTmaDesc" else "" for i, ty in signature.items()])}; Py_BEGIN_ALLOW_THREADS; - _launch(gridX, gridY, gridZ, num_warps, num_ctas, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (CUstream)_stream, (CUfunction)_function{', ' + ', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"_arg{i}"for i, ty in signature.items()) if len(signature) > 0 else ''}); + _launch(gridX, gridY, gridZ, num_warps, num_ctas, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (CUstream)_stream, (CUfunction)_function{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''}); Py_END_ALLOW_THREADS; if (PyErr_Occurred()) {{ return NULL;