Skip to content

Commit

Permalink
Merge commit '0ac0d2a43d2c420b7933eeb63de4b919a065e45c'
Browse files Browse the repository at this point in the history
  • Loading branch information
whitneywhtsang committed Jul 10, 2024
2 parents 3f071e6 + 0ac0d2a commit ac86467
Show file tree
Hide file tree
Showing 17 changed files with 185 additions and 44 deletions.
1 change: 1 addition & 0 deletions include/triton/Dialect/Triton/IR/TritonOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,7 @@ def TT_BroadcastOp : TT_Op<"broadcast", [Pure,

let hasCanonicalizeMethod = 1;
let hasFolder = 1;
let hasVerifier = 1;
}

// cat is not `pure` because it may reorder elements
Expand Down
20 changes: 20 additions & 0 deletions lib/Dialect/Triton/IR/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -762,6 +762,26 @@ OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
return {};
}

LogicalResult BroadcastOp::verify() {
auto src = getSrc();
auto srcTensorType = cast<RankedTensorType>(src.getType());
auto srcShape = srcTensorType.getShape();
auto result = getResult();
auto resultTensorType = cast<RankedTensorType>(result.getType());
auto resultShape = resultTensorType.getShape();
if (srcShape.size() != resultShape.size()) {
return emitError("rank of source must be same as rank of result");
}
for (int i = 0; i < srcShape.size(); i++) {
if (srcShape[i] != 1 && srcShape[i] != resultShape[i]) {
return emitError("Different dimensions at index ")
<< i << " between source and result. "
<< "Broadcast requires the source dimension to be 1.";
}
}
return success();
}

//-- MakeTensorPtrOp --
void MakeTensorPtrOp::build(OpBuilder &builder, OperationState &state,
Value base, ValueRange shape, ValueRange strides,
Expand Down
2 changes: 1 addition & 1 deletion lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ LinearLayout hopperMmaToLinearLayout(ArrayRef<int64_t> shape,
int n = mma.getInstrShape()[1];
int k = mma.getInstrShape()[2];
assert(m == 16);
assert(n == 16 || n == 32 || n == 64 || n == 128 || n == 256);
assert(n == 8 || n == 16 || n == 32 || n == 64 || n == 128 || n == 256);
assert(k == 8 || k == 16 || k == 32);

MLIRContext *ctx = mma.getContext();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -757,7 +757,11 @@ scheduleRemainingToLastStage(scf::ForOp forOp, tt::CoarseSchedule &schedule,
for (auto user : op->getUsers()) {
if (opToCluster.count(user)) {
tt::CoarseSchedule::Cluster userCluster = opToCluster[user];
tt::CoarseSchedule::Cluster opCluster = schedule[op].second;
tt::CoarseSchedule::Cluster opCluster;
if (schedule.count(op))
opCluster = schedule[op].second;
else
opCluster = opToCluster[op];
if (*userCluster < *opCluster) {
opToCluster[user] = opCluster;
queue.push_back(user);
Expand Down
8 changes: 6 additions & 2 deletions python/triton/runtime/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,14 +120,18 @@ def put(self, data, filename, binary=True) -> str:
rnd_id = str(uuid.uuid4())
# we use the PID in case a bunch of these around so we can see what PID made it
pid = os.getpid()
# use tempfile to be robust against program interruptions
temp_path = f"{filepath}.tmp.pid_{pid}_{rnd_id}"
# use temp dir to be robust against program interruptions
temp_dir = os.path.join(self.cache_dir, f"tmp.pid_{pid}_{rnd_id}")
os.makedirs(temp_dir, exist_ok=True)
temp_path = os.path.join(temp_dir, filename)

mode = "wb" if binary else "w"
with open(temp_path, mode) as f:
f.write(data)
# Replace is guaranteed to be atomic on POSIX systems if it succeeds
# so filepath cannot see a partial write
os.replace(temp_path, filepath)
os.removedirs(temp_dir)
return filepath


Expand Down
17 changes: 8 additions & 9 deletions python/tutorials/09-persistent-matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,10 @@ def matmul_kernel(a_ptr, b_ptr, c_ptr, #
start_m = pid_m * BLOCK_SIZE_M
start_n = pid_n * BLOCK_SIZE_N

offs_am = tl.arange(0, BLOCK_SIZE_M)
offs_bn = tl.arange(0, BLOCK_SIZE_N)

offs_am = tl.where(offs_am < M - start_m, offs_am, 0)
offs_bn = tl.where(offs_bn < N - start_n, offs_bn, 0)
offs_am = start_m + tl.arange(0, BLOCK_SIZE_M)
offs_bn = start_n + tl.arange(0, BLOCK_SIZE_N)
offs_am = tl.where(offs_am < M, offs_am, 0)
offs_bn = tl.where(offs_bn < N, offs_bn, 0)

offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M)
offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N)
Expand Down Expand Up @@ -186,10 +185,10 @@ def matmul_kernel_persistent(a_ptr, b_ptr, c_ptr, #

start_m = pid_m * BLOCK_SIZE_M
start_n = pid_n * BLOCK_SIZE_N
offs_am = tl.arange(0, BLOCK_SIZE_M)
offs_bn = tl.arange(0, BLOCK_SIZE_N)
offs_am = tl.where(offs_am < M - start_m, offs_am, 0)
offs_bn = tl.where(offs_bn < N - start_n, offs_bn, 0)
offs_am = start_m + tl.arange(0, BLOCK_SIZE_M)
offs_bn = start_n + tl.arange(0, BLOCK_SIZE_N)
offs_am = tl.where(offs_am < M, offs_am, 0)
offs_bn = tl.where(offs_bn < N, offs_bn, 0)
offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M)
offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N)
offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
Expand Down
24 changes: 24 additions & 0 deletions test/Conversion/tritongpu_to_llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1643,3 +1643,27 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
tt.return
}
}

// -----

#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
tt.func @int32_to_bf16(%arg0: tensor<256xi32, #blocked>) attributes {noinline = false} {
// CHECK-LABEL: @int32_to_bf16
// CHECK: llvm.sitofp %{{.*}} : i32 to bf16
%a = arith.sitofp %arg0 : tensor<256xi32, #blocked> to tensor<256xbf16, #blocked>
tt.return
}
}

// -----

#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
tt.func @bf16_to_int32(%arg0: tensor<256xbf16, #blocked>) attributes {noinline = false} {
// CHECK-LABEL: @bf16_to_int32
// CHECK: llvm.fptosi %{{.*}} : bf16 to i32
%a = arith.fptosi %arg0 : tensor<256xbf16, #blocked> to tensor<256xi32, #blocked>
tt.return
}
}
4 changes: 2 additions & 2 deletions test/Triton/combine.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -211,8 +211,8 @@ tt.func @test_combine_select_masked_load_fail_pattern(%ptr: tensor<8x!tt.ptr<f32
// CHECK-LABEL: @test_combine_broadcast_constant_pattern
tt.func @test_combine_broadcast_constant_pattern(%cst : f32) -> tensor<8x2xf32> {
// CHECK: %[[cst:.*]] = arith.constant dense<1.000000e+00> : tensor<8x2xf32>
%const = arith.constant dense<1.0> : tensor<8xf32>
%bst_out = tt.broadcast %const : tensor<8xf32> -> tensor<8x2xf32>
%const = arith.constant dense<1.0> : tensor<8x1xf32>
%bst_out = tt.broadcast %const : tensor<8x1xf32> -> tensor<8x2xf32>

// CHECK-NEXT: tt.return %[[cst]] : tensor<8x2xf32>
tt.return %bst_out : tensor<8x2xf32>
Expand Down
18 changes: 18 additions & 0 deletions test/Triton/invalid.mlir
Original file line number Diff line number Diff line change
@@ -1,5 +1,23 @@
// RUN: triton-opt --split-input-file %s --verify-diagnostics

tt.func @fn(%v: i32) {
%b = tt.splat %v : i32 -> tensor<128xi32>
// expected-error @+1 {{rank of source must be same as rank of result}}
%c = tt.broadcast %b : tensor<128xi32> -> tensor<128x32xi32>
tt.return
}

// -----

tt.func @fn(%v: i32) {
%b = tt.splat %v : i32 -> tensor<2x32xi32>
// expected-error @+1 {{Different dimensions at index 0 between source and result. Broadcast requires the source dimension to be 1.}}
%c = tt.broadcast %b : tensor<2x32xi32> -> tensor<128x32xi32>
tt.return
}

// -----

tt.func public @fn(%arg0: tensor<128xf32>) {
// expected-error @+1 {{packed_element}}
%a = tt.elementwise_inline_asm ""
Expand Down
12 changes: 6 additions & 6 deletions test/Triton/raise-block-pointer.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -197,9 +197,9 @@ tt.func @test_addptr_broadcast(%arg0 : !tt.ptr<f32>) -> tensor<2x128xf32> {
// CHECK: %[[VAL_4:.*]] = tt.load %[[VAL_3]] : !tt.ptr<tensor<2x128xf32>>
// CHECK: tt.return %[[VAL_4]] : tensor<2x128xf32>
tt.func @test_addptr_broadcast_rank(%arg0 : !tt.ptr<f32>) -> tensor<2x128xf32> {
%cst = arith.constant dense<1> : tensor<128xi32>
%cst = arith.constant dense<1> : tensor<1x128xi32>
%0 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<2x128x!tt.ptr<f32>>
%1 = tt.broadcast %cst : tensor<128xi32> -> tensor<2x128xi32>
%1 = tt.broadcast %cst : tensor<1x128xi32> -> tensor<2x128xi32>
%2 = tt.addptr %0, %1 : tensor<2x128x!tt.ptr<f32>>, tensor<2x128xi32>
%3 = tt.load %2 : tensor<2x128x!tt.ptr<f32>>
tt.return %3 : tensor<2x128xf32>
Expand All @@ -213,9 +213,9 @@ tt.func @test_addptr_broadcast_rank(%arg0 : !tt.ptr<f32>) -> tensor<2x128xf32> {
// CHECK: %[[VAL_4:.*]] = tt.load %[[VAL_3]] : !tt.ptr<tensor<128x2x128xf32>>
// CHECK: tt.return %[[VAL_4]] : tensor<128x2x128xf32>
tt.func @test_addptr_broadcast_rank_2(%arg0 : !tt.ptr<f32>) -> tensor<128x2x128xf32> {
%cst = arith.constant dense<1> : tensor<128x128xi32>
%cst = arith.constant dense<1> : tensor<128x1x128xi32>
%0 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<128x2x128x!tt.ptr<f32>>
%1 = tt.broadcast %cst : tensor<128x128xi32> -> tensor<128x2x128xi32>
%1 = tt.broadcast %cst : tensor<128x1x128xi32> -> tensor<128x2x128xi32>
%2 = tt.addptr %0, %1 : tensor<128x2x128x!tt.ptr<f32>>, tensor<128x2x128xi32>
%3 = tt.load %2 : tensor<128x2x128x!tt.ptr<f32>>
tt.return %3 : tensor<128x2x128xf32>
Expand All @@ -229,9 +229,9 @@ tt.func @test_addptr_broadcast_rank_2(%arg0 : !tt.ptr<f32>) -> tensor<128x2x128x
// CHECK: %[[VAL_4:.*]] = tt.load %[[VAL_3]] : !tt.ptr<tensor<128x2x128xf32>>
// CHECK: tt.return %[[VAL_4]] : tensor<128x2x128xf32>
tt.func @test_addptr_broadcast_rank_3(%arg0 : !tt.ptr<f32>) -> tensor<128x2x128xf32> {
%cst = arith.constant dense<1> : tensor<128xi32>
%cst = arith.constant dense<1> : tensor<128x1x1xi32>
%0 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<128x2x128x!tt.ptr<f32>>
%1 = tt.broadcast %cst : tensor<128xi32> -> tensor<128x2x128xi32>
%1 = tt.broadcast %cst : tensor<128x1x1xi32> -> tensor<128x2x128xi32>
%2 = tt.addptr %0, %1 : tensor<128x2x128x!tt.ptr<f32>>, tensor<128x2x128xi32>
%3 = tt.load %2 : tensor<128x2x128x!tt.ptr<f32>>
tt.return %3 : tensor<128x2x128xf32>
Expand Down
56 changes: 56 additions & 0 deletions test/TritonGPU/loop-pipeline-hopper.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -710,3 +710,59 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
tt.return
}
}


// -----

#blocked = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 8], order = [0, 1]}>
#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 128, 32]}>
#shared = #triton_gpu.shared<{vec = 16, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}>
#shared1 = #triton_gpu.shared<{vec = 16, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} {
// CHECK-LABEL: _kernel_matmul_dependency
tt.func public @_kernel_matmul_dependency(%arg0: tensor<128x128x!tt.ptr<f8E4M3FNUZ>, #blocked>, %arg1: !tt.ptr<f8E4M3FNUZ> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32, %arg4: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg5: tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) attributes {noinline = false} {
%cst = arith.constant dense<0> : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
%cst_0 = arith.constant 1.000000e+00 : f32
%c8_i32 = arith.constant 8 : i32
%cst_1 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma>
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%1 = tt.splat %arg1 : !tt.ptr<f8E4M3FNUZ> -> tensor<128x128x!tt.ptr<f8E4M3FNUZ>, #blocked1>
%2:4 = scf.for %arg6 = %c8_i32 to %arg3 step %c8_i32 iter_args(%arg7 = %c8_i32, %arg8 = %c8_i32, %arg9 = %cst_1, %arg10 = %arg5) -> (i32, i32, tensor<128x128xf32, #mma>, tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) : i32 {
%3 = arith.addi %arg7, %c8_i32 : i32
%4 = arith.cmpi eq, %3, %c8_i32 : i32
%5:2 = scf.if %4 -> (i32, tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) {
%21 = arith.addi %arg8, %c8_i32 : i32
scf.yield %21, %arg5 : i32, tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
} else {
scf.yield %arg8, %arg10 : i32, tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
}
%6 = arith.cmpi eq, %3, %c8_i32 : i32
%7 = scf.if %6 -> (f32) {
scf.yield %cst_0 : f32
} else {
%21 = tt.load %arg4 : !tt.ptr<f32>
scf.yield %21 : f32
}
%8 = tt.splat %3 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%9 = arith.addi %8, %0 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%10 = tt.expand_dims %9 {axis = 1 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1>
%11 = tt.broadcast %10 : tensor<128x1xi32, #blocked1> -> tensor<128x128xi32, #blocked1>
%12 = tt.addptr %1, %11 : tensor<128x128x!tt.ptr<f8E4M3FNUZ>, #blocked1>, tensor<128x128xi32, #blocked1>
%13 = tt.load %arg0 : tensor<128x128x!tt.ptr<f8E4M3FNUZ>, #blocked>
%14 = triton_gpu.local_alloc %13 : (tensor<128x128xf8E4M3FNUZ, #blocked>) -> !tt.memdesc<128x128xf8E4M3FNUZ, #shared>
%15 = tt.load %12 : tensor<128x128x!tt.ptr<f8E4M3FNUZ>, #blocked1>
%16 = triton_gpu.local_alloc %15 : (tensor<128x128xf8E4M3FNUZ, #blocked1>) -> !tt.memdesc<128x128xf8E4M3FNUZ, #shared1>
%17 = triton_nvidia_gpu.warp_group_dot %14, %16, %arg9 {inputPrecision = 0 : i32, maxNumImpreciseAcc = 1073741824 : i32} : !tt.memdesc<128x128xf8E4M3FNUZ, #shared> * !tt.memdesc<128x128xf8E4M3FNUZ, #shared1> -> tensor<128x128xf32, #mma>
%18 = tt.splat %7 : f32 -> tensor<128x128xf32, #mma>
%19 = arith.mulf %17, %18 : tensor<128x128xf32, #mma>
%20 = scf.if %6 -> (tensor<128x128xf32, #mma>) {
scf.yield %cst_1 : tensor<128x128xf32, #mma>
} else {
scf.yield %19 : tensor<128x128xf32, #mma>
}
scf.yield %3, %5#0, %20, %5#1 : i32, i32, tensor<128x128xf32, #mma>, tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
}
tt.return
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -664,10 +664,6 @@ struct SIToFPOpConversion
auto outVals = cvtFunc(loc, rewriter, inVals);
assert(outVals.size() == 4);
return outVals;
} else if (outElemTy.isBF16()) {
auto value = rewriter.create<LLVM::SIToFPOp>(loc, f32_ty, operands[0][0]);
return {FpToFpOpConversion::convertFp32ToBf16(loc, rewriter, value,
RoundingMode::RTNE)};
} else {
return {rewriter.create<LLVM::SIToFPOp>(loc, elemTy, operands[0][0])};
}
Expand All @@ -685,13 +681,7 @@ struct FPToSIOpConversion
Type elemTy, MultipleOperandsRange operands,
Location loc) const {
auto inElemTy = getElementType(op.getIn());
if (inElemTy.isBF16()) {
auto value =
FpToFpOpConversion::convertBf16ToFp32(loc, rewriter, operands[0][0]);
return {rewriter.create<LLVM::FPToSIOp>(loc, elemTy, value)};
} else {
return {rewriter.create<LLVM::FPToSIOp>(loc, elemTy, operands[0][0])};
}
return {rewriter.create<LLVM::FPToSIOp>(loc, elemTy, operands[0][0])};
}
};

Expand Down
4 changes: 2 additions & 2 deletions third_party/proton/csrc/lib/Profiler/CuptiProfiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ processActivityKernel(CuptiProfiler::CorrIdToExternIdMap &corrIdToExternId,
}
} else {
// Graph kernels
// A single grpah launch can trigger multiple kernels.
// A single graph launch can trigger multiple kernels.
// Our solution is to construct the following maps:
// --- Application threads ---
// 1. graphId -> numKernels
Expand Down Expand Up @@ -321,7 +321,7 @@ void CuptiProfiler::CuptiProfilerPimpl::doFlush() {
// This is a blocking call but it doesn’t issue any CUDA synchronization calls
// implicitly thus it’s not guaranteed that all activities are completed on
// the underlying devices.
// We do an "oppurtunistic" synchronization here to try to ensure that all
// We do an "opportunistic" synchronization here to try to ensure that all
// activities are completed on the current context.
// If the current context is not set, we don't do any synchronization.
CUcontext cuContext = nullptr;
Expand Down
1 change: 1 addition & 0 deletions third_party/proton/csrc/lib/Session/Session.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ makeContextSource(const std::string &contextSourceName) {

void Session::activate() {
profiler->start();
profiler->flush();
profiler->registerData(data.get());
}

Expand Down
16 changes: 16 additions & 0 deletions third_party/proton/test/test_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,3 +197,19 @@ def foo(x, size: tl.constexpr, y):
assert data[0]["children"][0]["children"][0]["frame"]["name"] == "foo_test_1ctas_1elems"
assert data[0]["children"][0]["children"][0]["metrics"]["flops32"] == 1.0
assert data[0]["children"][0]["children"][0]["metrics"]["Time (ns)"] > 0


def test_deactivate():
with tempfile.NamedTemporaryFile(delete=True, suffix=".hatchet") as f:
session_id = proton.start(f.name.split(".")[0], hook="triton")
proton.deactivate(session_id)
torch.randn((10, 10), device="cuda")
proton.activate(session_id)
torch.zeros((10, 10), device="cuda")
proton.deactivate(session_id)
proton.finalize()
data = json.load(f)
# Root shouldn't have device id
assert "DeviceId" not in data[0]["metrics"]
assert len(data[0]["children"]) == 1
assert "DeviceId" in data[0]["children"][0]["metrics"]
2 changes: 2 additions & 0 deletions unittest/Conversion/TritonGPUToLLVM/EmitIndicesTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1209,6 +1209,8 @@ std::vector<NvidiaMmaLLTestParams> makeNvidiaMmaV3TestCases() {

// These shapes were captured from grep'ing the TTGIR generated by Triton unit
// tests.
addTests({16, 8, 8}, 4, {{16, 16}, {32, 16}, {32, 32}, {64, 64}});
addTests({16, 8, 16}, 4, {{16, 16}, {32, 16}, {32, 32}, {64, 64}});
addTests({16, 16, 8}, 4, {{16, 16}, {32, 16}, {32, 32}, {64, 64}});
addTests({16, 16, 16}, 4, {{64, 16}, {128, 16}, {128, 128}});
addTests({16, 16, 32}, 4, {{64, 16}, {128, 16}});
Expand Down
26 changes: 16 additions & 10 deletions unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "mlir/IR/MLIRContext.h"
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Tools/StrUtil.h"
#include "llvm/Support/Signals.h"
#include <gmock/gmock.h>
#include <gtest/gtest.h>
Expand Down Expand Up @@ -364,16 +365,21 @@ TEST_F(LinearLayoutConversionsTest, MMAv2_Small3D) {
}

TEST_F(LinearLayoutConversionsTest, MMAv3_64x16) {
EXPECT_EQ(toLinearLayout({64, 16}, mma(3, 0, {16, 16, 8}, {4, 1}, {1, 1},
{1, 1}, {1, 0})),
LinearLayout(
{
{S("register"), {{0, 1}, {8, 0}, {0, 8}}},
{S("lane"), {{0, 2}, {0, 4}, {1, 0}, {2, 0}, {4, 0}}},
{S("warp"), {{16, 0}, {32, 0}}},
{S("block"), {}},
},
{S("dim0"), S("dim1")}));
SmallVector<SmallVector<unsigned>, 4> instrShapes = {
{16, 16, 8}, {16, 16, 8}, {16, 8, 8}};
for (auto instrShape : instrShapes) {
SCOPED_TRACE(triton::join(instrShape, ","));
EXPECT_EQ(toLinearLayout({64, 16}, mma(3, 0, instrShape, {4, 1}, {1, 1},
{1, 1}, {1, 0})),
LinearLayout(
{
{S("register"), {{0, 1}, {8, 0}, {0, 8}}},
{S("lane"), {{0, 2}, {0, 4}, {1, 0}, {2, 0}, {4, 0}}},
{S("warp"), {{16, 0}, {32, 0}}},
{S("block"), {}},
},
{S("dim0"), S("dim1")}));
}
}

TEST_F(LinearLayoutConversionsTest, MMAv3_128x16) {
Expand Down

0 comments on commit ac86467

Please sign in to comment.