
Commit 57a0e6c

Merge commit '20dc69842d0141b12b87b9def19f2269d3ec9cdc'

whitneywhtsang committed Aug 20, 2024
2 parents 59539b0 + 20dc698

Showing 22 changed files with 281 additions and 130 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/integration-tests.yml
@@ -106,7 +106,7 @@ jobs:
run: |
if [ x"${{ github.repository }}" == x"triton-lang/triton" ]; then
echo '::set-output name=matrix-CUDA::[["a100-runner-set"], ["h100-runner-set"]]'
echo '::set-output name=matrix-HIP::[["self-hosted", "gfx90a"]]'
echo '::set-output name=matrix-HIP::[["self-hosted", "gfx90a"], ["self-hosted", "gfx942"]]'
echo '::set-output name=matrix-MACOS::[["macos-latest"]]'
else
echo '::set-output name=matrix-CUDA::["ubuntu-latest"]'
2 changes: 1 addition & 1 deletion .github/workflows/integration-tests.yml.in
@@ -115,7 +115,7 @@ jobs:
run: |
if [ x"${{ github.repository }}" == x"triton-lang/triton" ]; then
echo '::set-output name=matrix-CUDA::[["a100-runner-set"], ["h100-runner-set"]]'
echo '::set-output name=matrix-HIP::[["self-hosted", "gfx90a"]]'
echo '::set-output name=matrix-HIP::[["self-hosted", "gfx90a"], ["self-hosted", "gfx942"]]'
echo '::set-output name=matrix-MACOS::[["macos-latest"]]'
else
echo '::set-output name=matrix-CUDA::["ubuntu-latest"]'
10 changes: 4 additions & 6 deletions include/triton/Conversion/TritonGPUToLLVM/Utility.h
@@ -1299,9 +1299,8 @@ inline DenseMap<unsigned, Value> getSwizzledSharedPtrs(
idxCol = urem(idxCol, numElemsPerSwizzlingRowVal);
strideRow = numElemsPerSwizzlingRowVal;
}
if (auto add = dyn_cast_or_null<LLVM::AddOp>(idxCol.getDefiningOp())) {
if (auto _cst = dyn_cast_or_null<LLVM::ConstantOp>(
add.getRhs().getDefiningOp())) {
if (auto add = idxCol.getDefiningOp<LLVM::AddOp>()) {
if (auto _cst = add.getRhs().getDefiningOp<LLVM::ConstantOp>()) {
unsigned cst =
cast<IntegerAttr>(_cst.getValue()).getValue().getSExtValue();
unsigned key = cst % (outVec * maxPhase);
@@ -1310,9 +1309,8 @@ inline DenseMap<unsigned, Value> getSwizzledSharedPtrs(
immedateOffCol = cst / (outVec * maxPhase) * (outVec * maxPhase);
}
}
if (auto add = dyn_cast_or_null<LLVM::AddOp>(idxRow.getDefiningOp())) {
if (auto _cst = dyn_cast_or_null<LLVM::ConstantOp>(
add.getRhs().getDefiningOp())) {
if (auto add = idxRow.getDefiningOp<LLVM::AddOp>()) {
if (auto _cst = add.getRhs().getDefiningOp<LLVM::ConstantOp>()) {
unsigned cst =
mlir::cast<IntegerAttr>(_cst.getValue()).getValue().getSExtValue();
unsigned key = cst % (perPhase * maxPhase);
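
The two hunks above, and the matching ones in Utility.cpp, Ops.cpp, Combine.cpp, and the TritonGPU passes further down, all replace dyn_cast_or_null<OpTy>(value.getDefiningOp()) with value.getDefiningOp<OpTy>(), the MLIR shorthand that folds the null check and the cast into a single call and also returns a null op when the value is a block argument. A minimal sketch of the idiom, assuming only the MLIR headers already used by this file (the isAddOfIntConstant helper is illustrative, not part of the change):

#include "mlir/Dialect/LLVMIR/LLVMDialect.h"

// Returns true if `v` is produced by an llvm.add whose right-hand side is an
// integer constant. getDefiningOp<OpTy>() yields a null op for block arguments
// and for values defined by some other op, so no separate null check is needed.
static bool isAddOfIntConstant(mlir::Value v) {
  if (auto add = v.getDefiningOp<mlir::LLVM::AddOp>())
    if (auto cst = add.getRhs().getDefiningOp<mlir::LLVM::ConstantOp>())
      return mlir::isa<mlir::IntegerAttr>(cst.getValue());
  return false;
}
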
40 changes: 40 additions & 0 deletions lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp
@@ -67,6 +67,42 @@ struct FuncOpConversion : public ConvertOpToLLVMPattern<triton::FuncOp> {
return amendedFuncOp;
}

// Map the MLIR attribute `tt.nv_tma_desc` to the appropriate LLVM and NVVM
// attributes.
static void handleByvalTmaDescArgs(LLVM::LLVMFuncOp &llvmFuncOp) {
const bool isKernel = LLVM::isKernel(llvmFuncOp);
for (unsigned i = 0; i < llvmFuncOp.getNumArguments(); ++i) {
const auto attrs = llvmFuncOp.getArgAttrDict(i);
if (!attrs) {
continue;
}

for (const auto &attr : attrs) {
if (attr.getName() == "tt.nv_tma_desc") {
const auto i32_type =
mlir::IntegerType::get(llvmFuncOp.getContext(), 32);
assert(attr.getValue() == mlir::IntegerAttr::get(i32_type, 1));
assert(isKernel &&
"tt.nv_tma_desc is not supported for device functions");

// See
// https://github.com/google/jax/blob/main/jaxlib/mosaic/gpu/passes.cc
mlir::BlockArgument arg = llvmFuncOp.getArgument(i);
const auto byteType =
mlir::IntegerType::get(llvmFuncOp.getContext(), 8);
const auto arrayType = mlir::LLVM::LLVMArrayType::get(
llvmFuncOp.getContext(), byteType, 128);
llvmFuncOp.setArgAttr(i, "llvm.byval",
mlir::TypeAttr::get(arrayType));
llvmFuncOp.setArgAttr(i, "nvvm.grid_constant",
mlir::UnitAttr::get(llvmFuncOp.getContext()));
llvmFuncOp.setArgAttr(i, "llvm.align",
mlir::IntegerAttr::get(i32_type, 64));
}
}
}
}

LogicalResult
matchAndRewrite(triton::FuncOp funcOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
@@ -105,6 +141,10 @@ struct FuncOpConversion : public ConvertOpToLLVMPattern<triton::FuncOp> {
newFuncOp->setAttr("nvvm.reqntid",
rewriter.getDenseI32ArrayAttr(32 * numWarps));
rewriter.eraseOp(funcOp);

// Add attributes for by-value TMA descriptor args (nvidia)
handleByvalTmaDescArgs(newFuncOp);

return success();
}
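
The three argument attributes set in handleByvalTmaDescArgs (llvm.byval with a 128-byte array type, nvvm.grid_constant, and llvm.align = 64) tell the NVVM/NVPTX path to pass the TMA descriptor by value as a grid-constant kernel parameter, which is what the new byval_tma tests below look for via the ".param .align 64 .b8" string in the generated PTX. A small sketch of how the annotation could be queried after lowering; the isByvalTmaDescArg helper is hypothetical and uses only the interface methods already used above:

#include "mlir/Dialect/LLVMIR/LLVMDialect.h"

// Hypothetical check: does argument `i` of `fn` carry the by-value TMA
// descriptor annotation added by handleByvalTmaDescArgs?
static bool isByvalTmaDescArg(mlir::LLVM::LLVMFuncOp fn, unsigned i) {
  mlir::DictionaryAttr attrs = fn.getArgAttrDict(i);
  if (!attrs)
    return false;
  auto byval = attrs.getAs<mlir::TypeAttr>("llvm.byval");
  auto align = attrs.getAs<mlir::IntegerAttr>("llvm.align");
  if (!byval || !align || !attrs.get("nvvm.grid_constant"))
    return false;
  auto arrayTy = mlir::dyn_cast<mlir::LLVM::LLVMArrayType>(byval.getValue());
  return arrayTy && arrayTy.getNumElements() == 128 && align.getInt() == 64;
}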

4 changes: 2 additions & 2 deletions lib/Conversion/TritonGPUToLLVM/Utility.cpp
@@ -164,7 +164,7 @@ applyLinearLayout(Location loc, RewriterBase &rewriter,
// Manually constant-fold the layout where possible.
SmallVector<std::pair<StringAttr, int32_t>> constantIns;
for (auto [inDimName, idx] : indices) {
if (auto constant = dyn_cast<LLVM::ConstantOp>(idx.getDefiningOp())) {
if (auto constant = idx.getDefiningOp<LLVM::ConstantOp>()) {
constantIns.push_back(
{inDimName, cast<IntegerAttr>(constant.getValue()).getInt()});
} else {
@@ -184,7 +184,7 @@
}

for (auto [inDimName, idx] : indices) {
if (isa<LLVM::ConstantOp>(idx.getDefiningOp())) {
if (idx.getDefiningOp<LLVM::ConstantOp>()) {
continue;
}

6 changes: 2 additions & 4 deletions lib/Dialect/Triton/IR/Ops.cpp
@@ -91,8 +91,7 @@ struct CanonicalizeMaskedLoadPattern : public OpRewritePattern<LoadOp> {
if (!mask)
return failure();

auto constantMask =
llvm::dyn_cast_or_null<arith::ConstantOp>(mask.getDefiningOp());
auto constantMask = mask.getDefiningOp<arith::ConstantOp>();
if (!constantMask)
return failure();

@@ -159,8 +158,7 @@ struct CanonicalizeMaskedStorePattern : public OpRewritePattern<StoreOp> {
if (!mask)
return failure();

auto constantMask =
llvm::dyn_cast_or_null<arith::ConstantOp>(mask.getDefiningOp());
auto constantMask = mask.getDefiningOp<arith::ConstantOp>();
if (!constantMask)
return failure();

21 changes: 7 additions & 14 deletions lib/Dialect/Triton/Transforms/Combine.cpp
@@ -113,17 +113,15 @@ class CombineSelectMaskedLoadPattern : public RewritePattern {
Value falseValue = selectOp.getFalseValue();
Value condSelect = selectOp.getCondition();

auto *loadOpCandidate = trueValue.getDefiningOp();
auto loadOp = llvm::dyn_cast_or_null<LoadOp>(loadOpCandidate);
auto loadOp = trueValue.getDefiningOp<LoadOp>();
if (!loadOp)
return failure();

Value mask = loadOp.getMask();
if (!mask)
return failure();

auto *splatOpCandidate = mask.getDefiningOp();
auto splatOp = llvm::dyn_cast_or_null<SplatOp>(splatOpCandidate);
auto splatOp = mask.getDefiningOp<SplatOp>();
if (!splatOp)
return failure();

@@ -175,26 +173,21 @@ class CombineBroadcastMulReducePattern : public RewritePattern {
if (!isReduceAdd)
return failure();
// operand of reduce has to be mul
auto mulOp = llvm::dyn_cast_or_null<arith::MulFOp>(
reduceOp.getOperand(0).getDefiningOp());
auto mulOp = reduceOp.getOperand(0).getDefiningOp<arith::MulFOp>();
if (!mulOp)
return failure();
// mul operand has to be broadcast
auto broadcastLhsOp = llvm::dyn_cast_or_null<BroadcastOp>(
mulOp.getOperand(0).getDefiningOp());
auto broadcastLhsOp = mulOp.getOperand(0).getDefiningOp<BroadcastOp>();
if (!broadcastLhsOp)
return failure();
auto broadcastRhsOp = llvm::dyn_cast_or_null<BroadcastOp>(
mulOp.getOperand(1).getDefiningOp());
auto broadcastRhsOp = mulOp.getOperand(1).getDefiningOp<BroadcastOp>();
if (!broadcastRhsOp)
return failure();
// broadcast operand is expand dims
auto expandLhsOp = llvm::dyn_cast_or_null<ExpandDimsOp>(
broadcastLhsOp.getSrc().getDefiningOp());
auto expandLhsOp = broadcastLhsOp.getSrc().getDefiningOp<ExpandDimsOp>();
if (!expandLhsOp)
return failure();
auto expandRhsOp = llvm::dyn_cast_or_null<ExpandDimsOp>(
broadcastRhsOp.getSrc().getDefiningOp());
auto expandRhsOp = broadcastRhsOp.getSrc().getDefiningOp<ExpandDimsOp>();
if (!expandRhsOp)
return failure();
// get not-broadcast dimensions
@@ -147,7 +147,7 @@ class TritonGPUOptimizeThreadLocalityPass
return;
auto argNum = yieldOpOperand.getOperandNumber();
auto oldAccum = forOp.getInitArgs()[argNum];
auto cstOp = dyn_cast<arith::ConstantOp>(oldAccum.getDefiningOp());
auto cstOp = oldAccum.getDefiningOp<arith::ConstantOp>();
if (!cstOp)
return;
reduceOps.insert(reduce);
@@ -1397,16 +1397,18 @@ static std::optional<int> dotCanBeProperlyAsync(ttng::WarpGroupDotOp dotOp,
while (isa_and_nonnull<ttg::ConvertLayoutOp, tt::TransOp>(
transitiveOperand.getDefiningOp()) ||
isa<BlockArgument>(transitiveOperand)) {
if (auto blockArg = dyn_cast<BlockArgument>(transitiveOperand)) {
assert(blockArg.getOwner() == forOp.getBody());
auto blockArg = dyn_cast<BlockArgument>(transitiveOperand);
if (blockArg && blockArg.getOwner() == forOp.getBody()) {
transitiveOperand =
cast<scf::YieldOp>(blockArg.getOwner()->getTerminator())
.getOperand(blockArg.getArgNumber() - 1);
}
transitiveOperand = transitiveOperand.getDefiningOp()->getOperand(0);
if (Operation *def = transitiveOperand.getDefiningOp()) {
transitiveOperand = def->getOperand(0);
}
}
return forOp.isDefinedOutsideOfLoop(transitiveOperand) ||
isa<ttg::MemDescSubviewOp>(transitiveOperand.getDefiningOp());
transitiveOperand.getDefiningOp<ttg::MemDescSubviewOp>();
};

// We don't have to call checkOperand on getC() because it's always in
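
The change above makes the backward walk in dotCanBeProperlyAsync defensive: the block-argument case is now an ordinary branch instead of an assert, and getDefiningOp() is checked for null before its operand is followed. A standalone sketch of the same pattern, assuming scf::ForOp loop-carried values and treating any single-operand op as a pass-through (both simplifications for illustration, not the exact upstream filter on ConvertLayoutOp/TransOp):

#include "mlir/Dialect/SCF/IR/SCF.h"

// Trace `v` backwards through single-operand pass-through ops and through
// iter_args of `forOp`, stopping at the first value that is neither.
static mlir::Value traceThroughWrappers(mlir::Value v, mlir::scf::ForOp forOp) {
  while (true) {
    if (auto arg = mlir::dyn_cast<mlir::BlockArgument>(v)) {
      // Follow only iter_args of this loop; arg 0 is the induction variable.
      if (arg.getOwner() != forOp.getBody() || arg.getArgNumber() == 0)
        return v;
      auto yield =
          mlir::cast<mlir::scf::YieldOp>(forOp.getBody()->getTerminator());
      v = yield.getOperand(arg.getArgNumber() - 1);
      continue;
    }
    mlir::Operation *def = v.getDefiningOp();
    if (!def || def->getNumOperands() != 1)
      return v; // no defining op (e.g. a function argument) or not a wrapper
    v = def->getOperand(0);
  }
}
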
100 changes: 42 additions & 58 deletions python/test/unit/hopper/test_experimental_tma.py
@@ -1,86 +1,62 @@
import pytest
import torch
import tempfile

import triton
import triton.language as tl
from triton.tools.experimental_descriptor import create_1d_tma_descriptor, create_2d_tma_descriptor


def test_descriptor_load_ttgir():
if not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] == 9:
pytest.skip("Test requires Hopper target.")
return
device = "cuda"
SIZE = 128
def create_tma_desc_gmem_ptr(ptr, dims, block_dims, element_size):
cpu_desc = torch.empty(128, device="cpu")
if len(dims) == 1:
triton.runtime.driver.active.utils.fill_1d_tma_descriptor(ptr, dims[0], block_dims[0], element_size,
cpu_desc.data_ptr())
else:
triton.runtime.driver.active.utils.fill_2d_tma_descriptor(ptr, dims[0], dims[1], block_dims[0], block_dims[1],
element_size, cpu_desc.data_ptr())
return cpu_desc.cuda()

x = torch.randn(SIZE, dtype=torch.float32, device=device)
desc = create_1d_tma_descriptor(x.data_ptr(), SIZE, SIZE, x.element_size())
size_in_bytes = SIZE * x.element_size()

ir = f"""
#blocked = #triton_gpu.blocked<{{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}}>
#shared = #triton_gpu.shared<{{vec = 1, perPhase = 1, maxPhase = 1, order = [0], hasLeadingOffset = false}}>
module attributes {{"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32}} {{
tt.func public @kernel(%arg0: !tt.ptr<f32> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<i8> {{tt.divisibility = 16 : i32}}) attributes {{noinline = false}} {{
%c0_i32 = arith.constant 0 : i32
%0 = tt.make_range {{end = {SIZE} : i32, start = 0 : i32}} : tensor<{SIZE}xi32, #blocked>
%1 = triton_gpu.local_alloc : () -> !tt.memdesc<{SIZE}xf32, #shared, #triton_gpu.shared_memory, mutable>
%2 = triton_gpu.local_alloc : () -> !tt.memdesc<1xi64, #shared, #triton_gpu.shared_memory, mutable>
triton_nvidia_gpu.init_barrier %2, 1 : <1xi64, #shared, #triton_gpu.shared_memory, mutable>
%true = arith.constant 1 : i1
triton_nvidia_gpu.barrier_expect %2, {size_in_bytes}, %true : <1xi64, #shared, #triton_gpu.shared_memory, mutable>
triton_nvidia_gpu.async_tma_copy_global_to_local %arg1[%c0_i32] %1, %2, %true : <i8>, <1xi64, #shared, #triton_gpu.shared_memory, mutable> -> <{SIZE}xf32, #shared, #triton_gpu.shared_memory, mutable>
triton_nvidia_gpu.wait_barrier %2, %c0_i32 : <1xi64, #shared, #triton_gpu.shared_memory, mutable>
%3 = triton_gpu.local_load %1 : !tt.memdesc<{SIZE}xf32, #shared, #triton_gpu.shared_memory, mutable> -> tensor<{SIZE}xf32, #blocked>
%4 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<{SIZE}x!tt.ptr<f32>, #blocked>
%5 = tt.addptr %4, %0 : tensor<{SIZE}x!tt.ptr<f32>, #blocked>, tensor<{SIZE}xi32, #blocked>
tt.store %5, %3 : tensor<{SIZE}x!tt.ptr<f32>, #blocked>
tt.return
}}
}}
"""
with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
f.write(ir)
f.flush()
kernel = triton.compile(f.name)

z_tri = torch.empty_like(x)
kernel[(1, 1, 1)](z_tri, desc)
assert torch.equal(x, z_tri)
TMA_FENCE_ASM: tl.constexpr = "fence.proxy.tensormap::generic.acquire.gpu [$1], 128; // $0 dummy reg"


def test_experimetal_descriptor_load():
@pytest.mark.parametrize("byval_tma", [True, False])
def test_experimetal_descriptor_load(byval_tma):
if not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] == 9:
pytest.skip("Test requires Hopper target.")
return
device = "cuda"
SIZE = 128

@triton.jit
def kernel(Z, desc, SIZE: tl.constexpr):
def kernel(Z, desc, SIZE: tl.constexpr, BYVAL_TMA: tl.constexpr):
if not BYVAL_TMA:
tl.inline_asm_elementwise(TMA_FENCE_ASM, "=r, l", [desc], dtype=tl.int32, is_pure=False, pack=1)
off_desc = 0
off = tl.arange(0, SIZE)
x = tl._experimental_descriptor_load(desc, [off_desc], [SIZE], Z.dtype.element_ty)
tl.store(Z + off, x)

x = torch.randn(SIZE, dtype=torch.float32, device=device)
desc = create_1d_tma_descriptor(x.data_ptr(), SIZE, SIZE, x.element_size())
if byval_tma:
desc = create_1d_tma_descriptor(x.data_ptr(), SIZE, SIZE, x.element_size())
else:
desc = create_tma_desc_gmem_ptr(x.data_ptr(), [SIZE], [SIZE], x.element_size())
z_tri = torch.empty_like(x)
kernel[(1, )](z_tri, desc, SIZE=SIZE, num_warps=4)
compiled_kernel = kernel[(1, )](z_tri, desc, SIZE=SIZE, BYVAL_TMA=byval_tma, num_warps=4)
assert torch.equal(x, z_tri)
if byval_tma:
assert ".param .align 64 .b8" in compiled_kernel.asm["ptx"]


@triton.jit
def matmul_kernel_tma(a_desc_ptr, b_desc_ptr, c_desc_ptr, #
M, N, K, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr):
# TODO(embg) remove TMA fence after __grid_constant__ lands
tl.inline_asm_elementwise("fence.proxy.tensormap::generic.acquire.gpu [$1], 128; // $0 dummy reg", "=r, l",
[a_desc_ptr], dtype=tl.int32, is_pure=False, pack=1)
tl.inline_asm_elementwise("fence.proxy.tensormap::generic.acquire.gpu [$1], 128; // $0 dummy reg", "=r, l",
[b_desc_ptr], dtype=tl.int32, is_pure=False, pack=1)
tl.inline_asm_elementwise("fence.proxy.tensormap::generic.acquire.gpu [$1], 128; // $0 dummy reg", "=r, l",
[c_desc_ptr], dtype=tl.int32, is_pure=False, pack=1)
M, N, K, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
BYVAL_TMA: tl.constexpr):
if not BYVAL_TMA:
tl.inline_asm_elementwise(TMA_FENCE_ASM, "=r, l", [a_desc_ptr], dtype=tl.int32, is_pure=False, pack=1)
tl.inline_asm_elementwise(TMA_FENCE_ASM, "=r, l", [b_desc_ptr], dtype=tl.int32, is_pure=False, pack=1)
tl.inline_asm_elementwise(TMA_FENCE_ASM, "=r, l", [c_desc_ptr], dtype=tl.int32, is_pure=False, pack=1)

pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
@@ -101,7 +77,8 @@ def matmul_kernel_tma(a_desc_ptr, b_desc_ptr, c_desc_ptr, #

@pytest.mark.parametrize("num_stages", [1, 4])
@pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(32, 32, 32), (128, 64, 64), (128, 128, 64), (128, 256, 64)])
def test_experimental_tma_matmul(num_stages, BLOCK_M, BLOCK_N, BLOCK_K):
@pytest.mark.parametrize("byval_tma", [True, False])
def test_experimental_tma_matmul(num_stages, BLOCK_M, BLOCK_N, BLOCK_K, byval_tma):
if not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] == 9:
pytest.skip("Test requires Hopper target.")
return
@@ -111,13 +88,20 @@ def test_experimental_tma_matmul(num_stages, BLOCK_M, BLOCK_N, BLOCK_K):
A = torch.randn((M, K), dtype=torch.float16, device=device)
B = torch.randn((K, N), dtype=torch.float16, device=device)
C = torch.empty((M, N), dtype=torch.float16, device=device)
desc_a = create_2d_tma_descriptor(A.data_ptr(), M, K, BLOCK_M, BLOCK_K, A.element_size())
desc_b = create_2d_tma_descriptor(B.data_ptr(), K, N, BLOCK_K, BLOCK_N, B.element_size())
desc_c = create_2d_tma_descriptor(C.data_ptr(), M, N, BLOCK_M, BLOCK_N, C.element_size())
if byval_tma:
desc_a = create_2d_tma_descriptor(A.data_ptr(), M, K, BLOCK_M, BLOCK_K, A.element_size())
desc_b = create_2d_tma_descriptor(B.data_ptr(), K, N, BLOCK_K, BLOCK_N, B.element_size())
desc_c = create_2d_tma_descriptor(C.data_ptr(), M, N, BLOCK_M, BLOCK_N, C.element_size())
else:
desc_a = create_tma_desc_gmem_ptr(A.data_ptr(), [M, K], [BLOCK_M, BLOCK_K], A.element_size())
desc_b = create_tma_desc_gmem_ptr(B.data_ptr(), [K, N], [BLOCK_K, BLOCK_N], B.element_size())
desc_c = create_tma_desc_gmem_ptr(C.data_ptr(), [M, N], [BLOCK_M, BLOCK_N], C.element_size())
kernel = matmul_kernel_tma[(triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1,
1)](desc_a, desc_b, desc_c, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, num_warps=8,
num_stages=num_stages)
1)](desc_a, desc_b, desc_c, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, BYVAL_TMA=byval_tma,
num_warps=8, num_stages=num_stages)
ref_out = torch.matmul(A.to(torch.float32), B.to(torch.float32)).to(torch.float16)
torch.testing.assert_close(ref_out, C, rtol=1e-3, atol=1e-3)
if BLOCK_M >= 64 and BLOCK_N >= 64:
assert "stmatrix.sync.aligned.m8n8.x4.shared.b16" in kernel.asm["ptx"]
if byval_tma:
assert ".param .align 64 .b8" in kernel.asm["ptx"]
4 changes: 4 additions & 0 deletions python/test/unit/language/test_core.py
@@ -1578,6 +1578,10 @@ def serialized_add(data, Lock, SEM: tl.constexpr):

tl.store(ptrs, tl.load(ptrs) + 1.0)

# insert barrier to set a fence between tl.store and
# tl.atomic_xchg in a block.
tl.debug_barrier()

# release lock
tl.atomic_xchg(Lock, 0)
