Skip to content

Commit

Permalink
[AMD]Enable EmitIndicesTest (#3167)
Browse files Browse the repository at this point in the history
Allows to test emmitting indices from different layouts. Valid also for
AMD target device.

Signed-off-by: joviliast <iveselov.nn@gmail.com>
  • Loading branch information
joviliast authored Mar 11, 2024
1 parent eb613cd commit d04f288
Show file tree
Hide file tree
Showing 5 changed files with 134 additions and 107 deletions.
1 change: 0 additions & 1 deletion third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,3 @@ add_triton_library(TritonAMDGPUToLLVM
)

target_compile_definitions(TritonAMDGPUToLLVM PUBLIC USE_ROCM)

4 changes: 3 additions & 1 deletion unittest/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ get_property(triton_libs GLOBAL PROPERTY TRITON_LIBS)
function(add_triton_ut)
set(options)
set(oneValueArgs NAME)
set(multiValueArgs SRCS LIBS)
set(multiValueArgs SRCS LIBS DEFS)
cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
add_test(NAME ${__NAME}
COMMAND ${__NAME})
Expand All @@ -29,6 +29,8 @@ function(add_triton_ut)

target_compile_options(${__NAME} PRIVATE -fno-rtti)

target_compile_definitions(${__NAME} PRIVATE ${__DEFS})

# Without the TEST_DISCOVERY_TIMEOUT, the tests randomly time out on my mac
# laptop. I think the issue may be that the very first time you run a program
# it's a bit slow.
Expand Down
10 changes: 9 additions & 1 deletion unittest/Conversion/TritonGPUToLLVM/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,15 @@ add_triton_ut(
)

add_triton_ut(
NAME TestEmitIndices
NAME TestEmitIndicesNvidia
SRCS EmitIndicesTest.cpp DumpLayout.cpp
LIBS TritonGPUIR TritonNvidiaGPUIR TritonNVIDIAGPUToLLVM
DEFS NVIDIA_TARGET=1
)

add_triton_ut(
NAME TestEmitIndicesAMD
SRCS EmitIndicesTest.cpp DumpLayout.cpp
LIBS TritonGPUIR TritonAMDGPUToLLVM
DEFS AMD_TARGET=1
)
56 changes: 36 additions & 20 deletions unittest/Conversion/TritonGPUToLLVM/DumpLayout.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,26 +22,51 @@
*/

#include "DumpLayout.h"

#ifdef AMD_TARGET
#include "amd/lib/TritonAMDGPUToLLVM/Utility.h"
#else
#include "nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.h"
#include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h"

#endif
namespace mlir {
namespace triton {
namespace gpu {

namespace {

#ifdef AMD_TARGET
Value getMockSmemBaseImpl([[maybe_unused]] IRRewriter &rewriter,
[[maybe_unused]] Location loc) {
return i32_val(0);
}
#else
Value getMockSmemBaseImpl(IRRewriter &rewriter, Location loc) {
Value mockSmemBase =
LLVM::NVIDIA::getSRegValue(rewriter, loc, "%mock_smem_base");
auto llPtrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), 3);
auto cast = rewriter.create<UnrealizedConversionCastOp>(
loc, TypeRange{llPtrTy}, ValueRange{mockSmemBase});
return cast.getResult(0);
}
#endif

//===----------------------------------------------------------------------===//
// IndexEmitter
//===----------------------------------------------------------------------===//

class IndexEmitter {
public:
IndexEmitter(MLIRContext *context_)
: context(context_), option(context), typeConverter(context, option),
rewriter(context), loc(UnknownLoc::get(context)) {
rewriter.setInsertionPointToStart(&block);
: context(context_), option(context), rewriter(context),
loc(UnknownLoc::get(context)) {
mlir::OpBuilder builder(context);
std::vector<mlir::Type> inTypes{};
std::vector<mlir::Type> outTypes{};
auto funcTy = builder.getFunctionType(inTypes, outTypes);
auto func = builder.create<mlir::triton::FuncOp>(loc, "test_func", funcTy);
auto mlirModule = mlir::ModuleOp::create(loc);
mlirModule.push_back(func);
auto *block = func.addEntryBlock();
rewriter.setInsertionPointToStart(block);
}

llvm::SmallVector<llvm::SmallVector<Value>>
Expand All @@ -56,28 +81,17 @@ class IndexEmitter {
Type elemTy, llvm::ArrayRef<int64_t> shape,
bool withCTAOffset) {
auto srcTy = RankedTensorType::get(shape, elemTy, srcLayout);
SharedMemoryObject smemObj(getMockSmemBase(), elemTy, shape,
sharedLayout.getOrder(), loc, rewriter);
SharedMemoryObject smemObj(getMockSmemBaseImpl(rewriter, loc), elemTy,
shape, sharedLayout.getOrder(), loc, rewriter);
return getSwizzledSharedPtrs(loc, /*inVec=*/1, srcTy, sharedLayout, elemTy,
smemObj, rewriter, smemObj.offsets,
smemObj.strides);
}

private:
Value getMockSmemBase() {
Value mockSmemBase =
mlir::LLVM::NVIDIA::getSRegValue(rewriter, loc, "%mock_smem_base");
auto llPtrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), 3);
auto cast = rewriter.create<UnrealizedConversionCastOp>(
loc, TypeRange{llPtrTy}, ValueRange{mockSmemBase});
return cast.getResult(0);
}

// Non-static members are initialized in declaration order
MLIRContext *context;
LowerToLLVMOptions option;
TritonGPUToLLVMTypeConverter typeConverter;
Block block;
IRRewriter rewriter;
Location loc;
};
Expand Down Expand Up @@ -151,6 +165,8 @@ int eval(Value value, int ctaid, int tid) {
return eval(xorOp.getLhs(), ctaid, tid) ^ eval(xorOp.getRhs(), ctaid, tid);
} else if (auto trunciOp = llvm::dyn_cast<arith::TruncIOp>(op)) {
return eval(trunciOp.getIn(), ctaid, tid);
} else if (auto idxCastOp = llvm::dyn_cast<arith::IndexCastOp>(op)) {
return eval(idxCastOp.getIn(), ctaid, tid);
} else if (auto castOp = llvm::dyn_cast<UnrealizedConversionCastOp>(op)) {
return eval(castOp.getOperand(0), ctaid, tid);
} else if (auto threadOp = llvm::dyn_cast<mlir::gpu::ThreadIdOp>(op)) {
Expand Down Expand Up @@ -181,7 +197,7 @@ std::string dumpDistributedLayout(Attribute layout,
assert(shape.size() <= 2 &&
"High order tensor is not supported in dumpLayout");

int numThreads = 32 * getNumWarpsPerCTA(layout);
int numThreads = getWarpSize(layout) * getNumWarpsPerCTA(layout);
int numCTAs = getNumCTAs(layout);
auto f16Ty = FloatType::getF16(layout.getContext());
int numElems = getTotalElemsPerThread(layout, shape, f16Ty);
Expand Down
170 changes: 86 additions & 84 deletions unittest/Conversion/TritonGPUToLLVM/EmitIndicesTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,8 @@ TEST_F(EmitIndicesTest, BlockedLayout_SingleCTA_Vectorize) {
/*order=*/{1, 0}, /*refStr=*/refStr);
}

// FIXME: These tests are temporarily disabled due to ctaid.x|y|z are swapped
#ifdef TEST_FAILED
TEST_F(EmitIndicesTest, BlockedLayout_MultiCTA_CTAOrder_1_0) {
// clang-format off
std::string refStr =
Expand Down Expand Up @@ -476,6 +478,89 @@ TEST_F(EmitIndicesTest, BlockedLayout_MultiCTA_CTAWrapBeforeBroadcast_Dim0) {
/*refStr=*/refStr);
}

TEST_F(EmitIndicesTest, SliceLayout_MultiCTA) {
// clang-format off
std::string refStr =
"CTA0: T0:0|CTA0: T1:0|CTA0: T2:0|CTA0: T3:0 | CTA1: T0:0|CTA1: T1:0|CTA1: T2:0|CTA1: T3:0,"
"CTA0: T4:0|CTA0: T5:0|CTA0: T6:0|CTA0: T7:0 | CTA1: T4:0|CTA1: T5:0|CTA1: T6:0|CTA1: T7:0,"
"CTA0: T8:0|CTA0: T9:0|CTA0:T10:0|CTA0:T11:0 | CTA1: T8:0|CTA1: T9:0|CTA1:T10:0|CTA1:T11:0,"
"CTA0:T12:0|CTA0:T13:0|CTA0:T14:0|CTA0:T15:0 | CTA1:T12:0|CTA1:T13:0|CTA1:T14:0|CTA1:T15:0,"
"CTA0:T16:0|CTA0:T17:0|CTA0:T18:0|CTA0:T19:0 | CTA1:T16:0|CTA1:T17:0|CTA1:T18:0|CTA1:T19:0,"
"CTA0:T20:0|CTA0:T21:0|CTA0:T22:0|CTA0:T23:0 | CTA1:T20:0|CTA1:T21:0|CTA1:T22:0|CTA1:T23:0,"
"CTA0:T24:0|CTA0:T25:0|CTA0:T26:0|CTA0:T27:0 | CTA1:T24:0|CTA1:T25:0|CTA1:T26:0|CTA1:T27:0,"
"CTA0:T28:0|CTA0:T29:0|CTA0:T30:0|CTA0:T31:0 | CTA1:T28:0|CTA1:T29:0|CTA1:T30:0|CTA1:T31:0,"

"CTA2: T0:0|CTA2: T1:0|CTA2: T2:0|CTA2: T3:0 | CTA3: T0:0|CTA3: T1:0|CTA3: T2:0|CTA3: T3:0,"
"CTA2: T4:0|CTA2: T5:0|CTA2: T6:0|CTA2: T7:0 | CTA3: T4:0|CTA3: T5:0|CTA3: T6:0|CTA3: T7:0,"
"CTA2: T8:0|CTA2: T9:0|CTA2:T10:0|CTA2:T11:0 | CTA3: T8:0|CTA3: T9:0|CTA3:T10:0|CTA3:T11:0,"
"CTA2:T12:0|CTA2:T13:0|CTA2:T14:0|CTA2:T15:0 | CTA3:T12:0|CTA3:T13:0|CTA3:T14:0|CTA3:T15:0,"
"CTA2:T16:0|CTA2:T17:0|CTA2:T18:0|CTA2:T19:0 | CTA3:T16:0|CTA3:T17:0|CTA3:T18:0|CTA3:T19:0,"
"CTA2:T20:0|CTA2:T21:0|CTA2:T22:0|CTA2:T23:0 | CTA3:T20:0|CTA3:T21:0|CTA3:T22:0|CTA3:T23:0,"
"CTA2:T24:0|CTA2:T25:0|CTA2:T26:0|CTA2:T27:0 | CTA3:T24:0|CTA3:T25:0|CTA3:T26:0|CTA3:T27:0,"
"CTA2:T28:0|CTA2:T29:0|CTA2:T30:0|CTA2:T31:0 | CTA3:T28:0|CTA3:T29:0|CTA3:T30:0|CTA3:T31:0\n";
// clang-format on

runSliceBlockedMultiCTA(/*size=*/16, /*sizePerThread=*/{1, 1},
/*threadsPerWarp=*/{8, 4}, /*warpsPerCTA=*/{1, 1},
/*order=*/{1, 0}, /*CTAsPerCGA=*/{2, 2},
/*CTASplitNum=*/{2, 2}, /*CTAOrder=*/{1, 0},
/*sliceDim=*/1, /*refStr=*/refStr);
}

//===----------------------------------------------------------------------===//
// Tests for SharedEncodingAttr
//===----------------------------------------------------------------------===//

TEST_F(EmitIndicesTest, SharedLayout) {
// clang-format off
std::string refStr =
"(0: 0),(0: 1),(0: 2),(0: 3),(0: 4),(0: 5),(0: 6),(0: 7),(0: 8),(0: 9),(0:10),(0:11),(0:12),(0:13),(0:14),(0:15),(0:16),(0:17),(0:18),(0:19),(0:20),(0:21),(0:22),(0:23),(0:24),(0:25),(0:26),(0:27),(0:28),(0:29),(0:30),(0:31)\n"
"(1: 0),(1: 1),(1: 2),(1: 3),(1: 4),(1: 5),(1: 6),(1: 7),(1: 8),(1: 9),(1:10),(1:11),(1:12),(1:13),(1:14),(1:15),(1:16),(1:17),(1:18),(1:19),(1:20),(1:21),(1:22),(1:23),(1:24),(1:25),(1:26),(1:27),(1:28),(1:29),(1:30),(1:31)\n"
"(2: 8),(2: 9),(2:10),(2:11),(2:12),(2:13),(2:14),(2:15),(2: 0),(2: 1),(2: 2),(2: 3),(2: 4),(2: 5),(2: 6),(2: 7),(2:24),(2:25),(2:26),(2:27),(2:28),(2:29),(2:30),(2:31),(2:16),(2:17),(2:18),(2:19),(2:20),(2:21),(2:22),(2:23)\n"
"(3: 8),(3: 9),(3:10),(3:11),(3:12),(3:13),(3:14),(3:15),(3: 0),(3: 1),(3: 2),(3: 3),(3: 4),(3: 5),(3: 6),(3: 7),(3:24),(3:25),(3:26),(3:27),(3:28),(3:29),(3:30),(3:31),(3:16),(3:17),(3:18),(3:19),(3:20),(3:21),(3:22),(3:23)\n"
"(4:16),(4:17),(4:18),(4:19),(4:20),(4:21),(4:22),(4:23),(4:24),(4:25),(4:26),(4:27),(4:28),(4:29),(4:30),(4:31),(4: 0),(4: 1),(4: 2),(4: 3),(4: 4),(4: 5),(4: 6),(4: 7),(4: 8),(4: 9),(4:10),(4:11),(4:12),(4:13),(4:14),(4:15)\n"
"(5:16),(5:17),(5:18),(5:19),(5:20),(5:21),(5:22),(5:23),(5:24),(5:25),(5:26),(5:27),(5:28),(5:29),(5:30),(5:31),(5: 0),(5: 1),(5: 2),(5: 3),(5: 4),(5: 5),(5: 6),(5: 7),(5: 8),(5: 9),(5:10),(5:11),(5:12),(5:13),(5:14),(5:15)\n"
"(6:24),(6:25),(6:26),(6:27),(6:28),(6:29),(6:30),(6:31),(6:16),(6:17),(6:18),(6:19),(6:20),(6:21),(6:22),(6:23),(6: 8),(6: 9),(6:10),(6:11),(6:12),(6:13),(6:14),(6:15),(6: 0),(6: 1),(6: 2),(6: 3),(6: 4),(6: 5),(6: 6),(6: 7)\n"
"(7:24),(7:25),(7:26),(7:27),(7:28),(7:29),(7:30),(7:31),(7:16),(7:17),(7:18),(7:19),(7:20),(7:21),(7:22),(7:23),(7: 8),(7: 9),(7:10),(7:11),(7:12),(7:13),(7:14),(7:15),(7: 0),(7: 1),(7: 2),(7: 3),(7: 4),(7: 5),(7: 6),(7: 7)\n";
// clang-format on

runSharedSingleCTA(/*row=*/8, /*col=*/32, /*rowMajor=*/true,
/*elemTyStr=*/"F16", /*refStr=*/refStr);
}

TEST_F(EmitIndicesTest, LayoutVisualizer_Blocked) {
CTALayoutAttr CTALayout =
CTALayoutAttr::get(/*context=*/&context, /*CTAsPerCGA=*/{2, 2},
/*CTASplitNum=*/{2, 2}, /*CTAOrder=*/{1, 0});

Attribute blockedLayout = BlockedEncodingAttr::get(
/*context=*/&context, /*sizePerThread=*/{1, 4},
/*threadsPerWarp=*/{2, 16},
/*warpsPerCTA=*/{4, 1}, /*order=*/{1, 0}, /*CTALayout=*/CTALayout);

llvm::SmallVector<int64_t> shape = {/*row=*/128, /*col=*/128};

std::ofstream ofs("blockedLayout.csv");
ofs << dumpDistributedLayout(blockedLayout, shape, /*multiCTA=*/true);
}

TEST_F(EmitIndicesTest, LayoutVisualizer_Shared) {
CTALayoutAttr CTALayout =
CTALayoutAttr::get(/*context=*/&context, /*CTAsPerCGA=*/{1, 1},
/*CTASplitNum=*/{1, 1}, /*CTAOrder=*/{1, 0});

Attribute sharedLayout = SharedEncodingAttr::get(
/*context=*/&context, /*vec=*/1, /*perPhase=*/2, /*maxPhase=*/8,
/*order=*/{0, 1}, /*CTALayout=*/CTALayout);

llvm::SmallVector<int64_t> shape = {/*row=*/16, /*col=*/16};
Type elemTy = FloatType::getF16(&context);

std::ofstream ofs("sharedLayout.csv");
ofs << dumpSharedLayout(sharedLayout, shape, elemTy, /*multiCTA=*/false);
}
#endif

//===----------------------------------------------------------------------===//
// Tests for SliceEncodingAttr
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -512,35 +597,6 @@ TEST_F(EmitIndicesTest, SliceLayout_SingleCTA_SliceDim0) {
/*order=*/{1, 0}, /*sliceDim=*/0, /*refStr=*/refStr);
}

TEST_F(EmitIndicesTest, SliceLayout_MultiCTA) {
// clang-format off
std::string refStr =
"CTA0: T0:0|CTA0: T1:0|CTA0: T2:0|CTA0: T3:0 | CTA1: T0:0|CTA1: T1:0|CTA1: T2:0|CTA1: T3:0,"
"CTA0: T4:0|CTA0: T5:0|CTA0: T6:0|CTA0: T7:0 | CTA1: T4:0|CTA1: T5:0|CTA1: T6:0|CTA1: T7:0,"
"CTA0: T8:0|CTA0: T9:0|CTA0:T10:0|CTA0:T11:0 | CTA1: T8:0|CTA1: T9:0|CTA1:T10:0|CTA1:T11:0,"
"CTA0:T12:0|CTA0:T13:0|CTA0:T14:0|CTA0:T15:0 | CTA1:T12:0|CTA1:T13:0|CTA1:T14:0|CTA1:T15:0,"
"CTA0:T16:0|CTA0:T17:0|CTA0:T18:0|CTA0:T19:0 | CTA1:T16:0|CTA1:T17:0|CTA1:T18:0|CTA1:T19:0,"
"CTA0:T20:0|CTA0:T21:0|CTA0:T22:0|CTA0:T23:0 | CTA1:T20:0|CTA1:T21:0|CTA1:T22:0|CTA1:T23:0,"
"CTA0:T24:0|CTA0:T25:0|CTA0:T26:0|CTA0:T27:0 | CTA1:T24:0|CTA1:T25:0|CTA1:T26:0|CTA1:T27:0,"
"CTA0:T28:0|CTA0:T29:0|CTA0:T30:0|CTA0:T31:0 | CTA1:T28:0|CTA1:T29:0|CTA1:T30:0|CTA1:T31:0,"

"CTA2: T0:0|CTA2: T1:0|CTA2: T2:0|CTA2: T3:0 | CTA3: T0:0|CTA3: T1:0|CTA3: T2:0|CTA3: T3:0,"
"CTA2: T4:0|CTA2: T5:0|CTA2: T6:0|CTA2: T7:0 | CTA3: T4:0|CTA3: T5:0|CTA3: T6:0|CTA3: T7:0,"
"CTA2: T8:0|CTA2: T9:0|CTA2:T10:0|CTA2:T11:0 | CTA3: T8:0|CTA3: T9:0|CTA3:T10:0|CTA3:T11:0,"
"CTA2:T12:0|CTA2:T13:0|CTA2:T14:0|CTA2:T15:0 | CTA3:T12:0|CTA3:T13:0|CTA3:T14:0|CTA3:T15:0,"
"CTA2:T16:0|CTA2:T17:0|CTA2:T18:0|CTA2:T19:0 | CTA3:T16:0|CTA3:T17:0|CTA3:T18:0|CTA3:T19:0,"
"CTA2:T20:0|CTA2:T21:0|CTA2:T22:0|CTA2:T23:0 | CTA3:T20:0|CTA3:T21:0|CTA3:T22:0|CTA3:T23:0,"
"CTA2:T24:0|CTA2:T25:0|CTA2:T26:0|CTA2:T27:0 | CTA3:T24:0|CTA3:T25:0|CTA3:T26:0|CTA3:T27:0,"
"CTA2:T28:0|CTA2:T29:0|CTA2:T30:0|CTA2:T31:0 | CTA3:T28:0|CTA3:T29:0|CTA3:T30:0|CTA3:T31:0\n";
// clang-format on

runSliceBlockedMultiCTA(/*size=*/16, /*sizePerThread=*/{1, 1},
/*threadsPerWarp=*/{8, 4}, /*warpsPerCTA=*/{1, 1},
/*order=*/{1, 0}, /*CTAsPerCGA=*/{2, 2},
/*CTASplitNum=*/{2, 2}, /*CTAOrder=*/{1, 0},
/*sliceDim=*/1, /*refStr=*/refStr);
}

//===----------------------------------------------------------------------===//
// Tests for NvidiaMmaEncodingAttr
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -571,50 +627,13 @@ TEST_F(EmitIndicesTest, MmaLayout) {
/*refStr=*/refStr);
}

//===----------------------------------------------------------------------===//
// Tests for SharedEncodingAttr
//===----------------------------------------------------------------------===//

TEST_F(EmitIndicesTest, SharedLayout) {
// clang-format off
std::string refStr =
"(0: 0),(0: 1),(0: 2),(0: 3),(0: 4),(0: 5),(0: 6),(0: 7),(0: 8),(0: 9),(0:10),(0:11),(0:12),(0:13),(0:14),(0:15),(0:16),(0:17),(0:18),(0:19),(0:20),(0:21),(0:22),(0:23),(0:24),(0:25),(0:26),(0:27),(0:28),(0:29),(0:30),(0:31)\n"
"(1: 0),(1: 1),(1: 2),(1: 3),(1: 4),(1: 5),(1: 6),(1: 7),(1: 8),(1: 9),(1:10),(1:11),(1:12),(1:13),(1:14),(1:15),(1:16),(1:17),(1:18),(1:19),(1:20),(1:21),(1:22),(1:23),(1:24),(1:25),(1:26),(1:27),(1:28),(1:29),(1:30),(1:31)\n"
"(2: 8),(2: 9),(2:10),(2:11),(2:12),(2:13),(2:14),(2:15),(2: 0),(2: 1),(2: 2),(2: 3),(2: 4),(2: 5),(2: 6),(2: 7),(2:24),(2:25),(2:26),(2:27),(2:28),(2:29),(2:30),(2:31),(2:16),(2:17),(2:18),(2:19),(2:20),(2:21),(2:22),(2:23)\n"
"(3: 8),(3: 9),(3:10),(3:11),(3:12),(3:13),(3:14),(3:15),(3: 0),(3: 1),(3: 2),(3: 3),(3: 4),(3: 5),(3: 6),(3: 7),(3:24),(3:25),(3:26),(3:27),(3:28),(3:29),(3:30),(3:31),(3:16),(3:17),(3:18),(3:19),(3:20),(3:21),(3:22),(3:23)\n"
"(4:16),(4:17),(4:18),(4:19),(4:20),(4:21),(4:22),(4:23),(4:24),(4:25),(4:26),(4:27),(4:28),(4:29),(4:30),(4:31),(4: 0),(4: 1),(4: 2),(4: 3),(4: 4),(4: 5),(4: 6),(4: 7),(4: 8),(4: 9),(4:10),(4:11),(4:12),(4:13),(4:14),(4:15)\n"
"(5:16),(5:17),(5:18),(5:19),(5:20),(5:21),(5:22),(5:23),(5:24),(5:25),(5:26),(5:27),(5:28),(5:29),(5:30),(5:31),(5: 0),(5: 1),(5: 2),(5: 3),(5: 4),(5: 5),(5: 6),(5: 7),(5: 8),(5: 9),(5:10),(5:11),(5:12),(5:13),(5:14),(5:15)\n"
"(6:24),(6:25),(6:26),(6:27),(6:28),(6:29),(6:30),(6:31),(6:16),(6:17),(6:18),(6:19),(6:20),(6:21),(6:22),(6:23),(6: 8),(6: 9),(6:10),(6:11),(6:12),(6:13),(6:14),(6:15),(6: 0),(6: 1),(6: 2),(6: 3),(6: 4),(6: 5),(6: 6),(6: 7)\n"
"(7:24),(7:25),(7:26),(7:27),(7:28),(7:29),(7:30),(7:31),(7:16),(7:17),(7:18),(7:19),(7:20),(7:21),(7:22),(7:23),(7: 8),(7: 9),(7:10),(7:11),(7:12),(7:13),(7:14),(7:15),(7: 0),(7: 1),(7: 2),(7: 3),(7: 4),(7: 5),(7: 6),(7: 7)\n";
// clang-format on

runSharedSingleCTA(/*row=*/8, /*col=*/32, /*rowMajor=*/true,
/*elemTyStr=*/"F16", /*refStr=*/refStr);
}

//===----------------------------------------------------------------------===//
// The following unittests are tools for Triton developers to visualize layouts.
// You can modify parameters and shapes here to create your own layout and
// tensor. The output will be saved into a csv file which can be opened with
// Microsoft Excel.
//===----------------------------------------------------------------------===//

TEST_F(EmitIndicesTest, LayoutVisualizer_Blocked) {
CTALayoutAttr CTALayout =
CTALayoutAttr::get(/*context=*/&context, /*CTAsPerCGA=*/{2, 2},
/*CTASplitNum=*/{2, 2}, /*CTAOrder=*/{1, 0});

Attribute blockedLayout = BlockedEncodingAttr::get(
/*context=*/&context, /*sizePerThread=*/{1, 4},
/*threadsPerWarp=*/{2, 16},
/*warpsPerCTA=*/{4, 1}, /*order=*/{1, 0}, /*CTALayout=*/CTALayout);

llvm::SmallVector<int64_t> shape = {/*row=*/128, /*col=*/128};

std::ofstream ofs("blockedLayout.csv");
ofs << dumpDistributedLayout(blockedLayout, shape, /*multiCTA=*/true);
}

TEST_F(EmitIndicesTest, LayoutVisualizer_Slice) {
CTALayoutAttr CTALayout =
CTALayoutAttr::get(/*context=*/&context, /*CTAsPerCGA=*/{1, 1},
Expand Down Expand Up @@ -648,22 +667,6 @@ TEST_F(EmitIndicesTest, LayoutVisualizer_Mma) {
ofs << dumpDistributedLayout(mmaLayout, shape, /*multiCTA=*/false);
}

TEST_F(EmitIndicesTest, LayoutVisualizer_Shared) {
CTALayoutAttr CTALayout =
CTALayoutAttr::get(/*context=*/&context, /*CTAsPerCGA=*/{1, 1},
/*CTASplitNum=*/{1, 1}, /*CTAOrder=*/{1, 0});

Attribute sharedLayout = SharedEncodingAttr::get(
/*context=*/&context, /*vec=*/1, /*perPhase=*/2, /*maxPhase=*/8,
/*order=*/{0, 1}, /*CTALayout=*/CTALayout);

llvm::SmallVector<int64_t> shape = {/*row=*/16, /*col=*/16};
Type elemTy = FloatType::getF16(&context);

std::ofstream ofs("sharedLayout.csv");
ofs << dumpSharedLayout(sharedLayout, shape, elemTy, /*multiCTA=*/false);
}

} // namespace gpu
} // namespace triton
} // namespace mlir
Expand All @@ -674,6 +677,5 @@ TEST_F(EmitIndicesTest, LayoutVisualizer_Shared) {

int main(int argc, char *argv[]) {
testing::InitGoogleTest(&argc, argv);
// FIXME: These tests are temporarily disabled due to ctaid.x|y|z are swapped
// return RUN_ALL_TESTS();
return RUN_ALL_TESTS();
}

0 comments on commit d04f288

Please sign in to comment.