From 9a83624e49413c67ac6584623bedaaed929ad02a Mon Sep 17 00:00:00 2001
From: Shucai Xiao
Date: Sun, 10 Mar 2024 23:10:10 -0500
Subject: [PATCH 1/9] turn on reduce tests with `argmax/argmin` (#3333)

The latest main branch seems to work fine for all cases in test_reduce,
so turn them on for the AMD backend.
---
 python/test/unit/language/test_core.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py
index d28ede1202..9ed0a8fdc1 100644
--- a/python/test/unit/language/test_core.py
+++ b/python/test/unit/language/test_core.py
@@ -1871,8 +1871,6 @@ def kernel(X, Z, BLOCK: tl.constexpr):
                          negative_config + keep_dims_2d_configs + keep_dims_3d_configs)
 @pytest.mark.parametrize("num_ctas", num_ctas_list)
 def test_reduce(op, dtype_str, shape, axis, keep_dims, num_ctas, device):
-    if is_hip() and (op == 'argmin' or op == 'argmax'):
-        pytest.skip("TODO some tests for argmin and argmax do not work on HIP")
     check_type_supported(dtype_str, device)  # bfloat16 on cc < 80 will not be tested
 
     @triton.jit

From 91fcb98d0640dff721351f72abd201a9b0d70250 Mon Sep 17 00:00:00 2001
From: JunMa
Date: Mon, 11 Mar 2024 12:27:15 +0800
Subject: [PATCH 2/9] Recommit [OPTIMIZATION] Fix addptr combine pattern (#3148) (#3289)

Add the addptr(addptr(%ptr, %idx0), %idx1) pattern again, this time with
a constraint that makes sure both offsets have the same element type.
---
 lib/Dialect/Triton/Transforms/Combine.cpp | 43 ++++++++++-
 lib/Dialect/Triton/Transforms/Combine.td  | 10 +--
 test/Triton/combine.mlir                  | 93 +++++++++++++++++++++--
 3 files changed, 135 insertions(+), 11 deletions(-)

diff --git a/lib/Dialect/Triton/Transforms/Combine.cpp b/lib/Dialect/Triton/Transforms/Combine.cpp
index c145b952b9..3cfa8e65a1 100644
--- a/lib/Dialect/Triton/Transforms/Combine.cpp
+++ b/lib/Dialect/Triton/Transforms/Combine.cpp
@@ -49,6 +49,47 @@ DenseElementsAttr getConstantValue(Builder &builder, Attribute value,
   return res;
 }
 
+bool isAddPtrOffsetCombinable(Value first, Value second) {
+  auto GetConstantIntValue = [](Value val) -> std::optional<llvm::APInt> {
+    DenseElementsAttr constAttr;
+    auto defOp = val.getDefiningOp();
+    if (defOp) {
+      if (auto splatOp = llvm::dyn_cast<triton::SplatOp>(defOp))
+        val = splatOp.getSrc();
+      else if (matchPattern(defOp, m_Constant(&constAttr)) &&
+               constAttr.isSplat()) {
+        auto attr = constAttr.getSplatValue<Attribute>();
+        // Check IntegerAttr
+        if (auto intAttr = dyn_cast_or_null<IntegerAttr>(attr))
+          return intAttr.getValue();
+      }
+    }
+
+    // Check constant value.
+    llvm::APInt intVal;
+    if (matchPattern(val, m_ConstantInt(&intVal)))
+      return intVal;
+
+    return std::nullopt;
+  };
+
+  if (first.getType() == second.getType()) {
+    // Whether the bitwidth of the offset element type equals the pointer
+    // bitwidth (64).
+    if (getElementTypeOrSelf(first.getType()).getIntOrFloatBitWidth() == 64)
+      return true;
+
+    // first + second does not overflow
+    auto firstVal = GetConstantIntValue(first);
+    auto secondVal = GetConstantIntValue(second);
+    if (firstVal && secondVal) {
+      bool overflow = false;
+      auto resVal = firstVal->sadd_ov(*secondVal, overflow);
+      return !overflow;
+    }
+  }
+  return false;
+}
+
 // TODO(csigg): remove after next LLVM integrate.
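// Illustrative aside, not part of the patch: the overflow check above relies
// on llvm::APInt::sadd_ov, which returns the (possibly wrapped) sum and sets
// its bool out-parameter on signed overflow. A minimal, self-contained sketch
// of the same check (the helper name wouldOverflow is hypothetical):
//
//   #include "llvm/ADT/APInt.h"
//
//   bool wouldOverflow(uint64_t a, uint64_t b, unsigned bitWidth) {
//     bool overflow = false;
//     llvm::APInt lhs(bitWidth, a), rhs(bitWidth, b);
//     (void)lhs.sadd_ov(rhs, overflow); // bitWidth = 8: 127 + 1 overflows
//     return overflow;
//   }
//
// This is the case the new @test_not_combine_addptr_pattern_overflow test
// below exercises: two i8 offsets of 127 and 1 must not be folded, because
// their sum wraps in signed 8-bit arithmetic.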
 using FastMathFlags = arith::FastMathFlags;
 
@@ -195,7 +236,7 @@ class CombineOpsPass : public TritonCombineOpsBase<CombineOpsPass> {
     patterns.add<CombineDotAddFRevPattern>(context);
     // %}
     patterns.add<CombineSelectMaskedLoadPattern>(context);
-    // patterns.add<CombineAddPtrPattern>(context);
+    patterns.add<CombineAddPtrPattern>(context);
     patterns.add<CombineBroadcastConstantPattern>(context);
     patterns.add<CombineBroadcastMulReducePattern>(context);
 
diff --git a/lib/Dialect/Triton/Transforms/Combine.td b/lib/Dialect/Triton/Transforms/Combine.td
index f46b8b2e58..49e6950aa6 100644
--- a/lib/Dialect/Triton/Transforms/Combine.td
+++ b/lib/Dialect/Triton/Transforms/Combine.td
@@ -35,14 +35,14 @@ def CombineDotAddFRevPattern : Pat<
     (Constraint<CPred<"cast<IntegerAttr>($0).getInt() == 0">> $maxNumImpreciseAcc),
     (Constraint<CPred<"$0->hasOneUse()">, "dot result has a single use">)]>;
 
-// TODO: this fails for addptr(addptr(ptr, i32), i64)
-// Commented out until fixed
 // addptr(addptr(%ptr, %idx0), %idx1) => addptr(%ptr, AddI(%idx0, %idx1))
 //   Note: leave (sub %c0, %c0) canceling to ArithDialect
 //         (ref: ArithCanonicalization.td)
-// def CombineAddPtrPattern : Pat<
-//   (TT_AddPtrOp (TT_AddPtrOp $ptr, $idx0), $idx1),
-//   (TT_AddPtrOp $ptr, (Arith_AddIOp $idx0, $idx1))>;
+defvar DefOverflow = ConstantEnumCase<Arith_IntegerOverflowAttr, "none">;
+def CombineAddPtrPattern : Pat<
+  (TT_AddPtrOp (TT_AddPtrOp $ptr, $idx0), $idx1),
+  (TT_AddPtrOp $ptr, (Arith_AddIOp $idx0, $idx1, DefOverflow)),
+  [(Constraint<CPred<"isAddPtrOffsetCombinable($0, $1)">> $idx0, $idx1)]>;
 
 // broadcast(cst) => cst
 def getConstantValue : NativeCodeCall<"getConstantValue($_builder, $0, $1)">;
 
diff --git a/test/Triton/combine.mlir b/test/Triton/combine.mlir
index 6c30ecd577..a7afba95bb 100644
--- a/test/Triton/combine.mlir
+++ b/test/Triton/combine.mlir
@@ -63,28 +63,111 @@ tt.func @test_combine_dot_add_rev_pattern() -> (tensor<128x128xf32>) {
 }
 
-// COM: CHECK-LABEL: @test_combine_addptr_pattern
+// CHECK-LABEL: @test_combine_addptr_pattern
 tt.func @test_combine_addptr_pattern(%base: !tt.ptr<f32>) -> tensor<8x!tt.ptr<f32>> {
   %off0 = arith.constant 10 : i32
   %off1 = arith.constant 15 : i32
 
-  // 10 + 15 = 25
-  // COM: CHECK-NEXT: %[[cst:.*]] = arith.constant dense<25> : tensor<8xi32>
+  // CHECK-NEXT: %[[cst:.*]] = arith.constant dense<25> : tensor<8xi32>
 
   %base_ = tt.splat %base : !tt.ptr<f32> -> tensor<8x!tt.ptr<f32>>
 
-  // COM: CHECK-NEXT: %[[tmp0:.*]] = tt.splat %{{.*}} : !tt.ptr<f32> -> tensor<8x!tt.ptr<f32>>
+  // CHECK-NEXT: %[[tmp0:.*]] = tt.splat %{{.*}} : !tt.ptr<f32> -> tensor<8x!tt.ptr<f32>>
 
   %idx0 = tt.splat %off0 : i32 -> tensor<8xi32>
   %idx1 = tt.splat %off1 : i32 -> tensor<8xi32>
 
-  // COM: CHECK-NEXT: %1 = tt.addptr %[[tmp0]], %[[cst]] : tensor<8x!tt.ptr<f32>>, tensor<8xi32>
+  // CHECK-NEXT: %1 = tt.addptr %[[tmp0]], %[[cst]] : tensor<8x!tt.ptr<f32>>, tensor<8xi32>
   %ptr0 = tt.addptr %base_, %idx0 : tensor<8x!tt.ptr<f32>>, tensor<8xi32>
   %ptr1 = tt.addptr %ptr0, %idx1 : tensor<8x!tt.ptr<f32>>, tensor<8xi32>
+
   tt.return %ptr1 : tensor<8x!tt.ptr<f32>>
 }
+
+// CHECK-LABEL: @test_combine_addptr_pattern_i64
+tt.func @test_combine_addptr_pattern_i64(%base: !tt.ptr<f32>) -> tensor<8x!tt.ptr<f32>> {
+  %off0 = arith.constant 10 : i64
+  %off1 = arith.constant dense<15> : tensor<8xi64>
+
+  // CHECK-NEXT: %[[cst:.*]] = arith.constant dense<25> : tensor<8xi64>
+
+  %base_ = tt.splat %base : !tt.ptr<f32> -> tensor<8x!tt.ptr<f32>>
+
+  // CHECK-NEXT: %[[tmp0:.*]] = tt.splat %{{.*}} : !tt.ptr<f32> -> tensor<8x!tt.ptr<f32>>
+
+  %idx0 = tt.splat %off0 : i64 -> tensor<8xi64>
+
+  // CHECK-NEXT: %1 = tt.addptr %[[tmp0]], %[[cst]] : tensor<8x!tt.ptr<f32>>, tensor<8xi64>
+  %ptr0 = tt.addptr %base_, %idx0 : tensor<8x!tt.ptr<f32>>, tensor<8xi64>
+  %ptr1 = tt.addptr %ptr0, %off1 : tensor<8x!tt.ptr<f32>>, tensor<8xi64>
+
+  tt.return %ptr1 : tensor<8x!tt.ptr<f32>>
+}
+
+// CHECK-LABEL: @test_combine_addptr_pattern_scalar
+tt.func @test_combine_addptr_pattern_scalar(%base: !tt.ptr<f32>) -> !tt.ptr<f32> {
+  %off0 = arith.constant 10 : i32
+  %off1 = arith.constant 15 : i32
+
+  // CHECK-NEXT: %[[cst:.*]] = arith.constant 25 : i32
+  // CHECK-NEXT: %0 = tt.addptr %{{.*}}, %[[cst]] : !tt.ptr<f32>, i32
+  %ptr0 = tt.addptr %base, %off0 : !tt.ptr<f32>, i32
+  %ptr1 = tt.addptr %ptr0, %off1 : !tt.ptr<f32>, i32
+
+  tt.return %ptr1 : !tt.ptr<f32>
+}
+
+// CHECK-LABEL: @test_not_combine_addptr_pattern_1
+tt.func @test_not_combine_addptr_pattern_1(%base: !tt.ptr<f32>, %idx0: tensor<8xi32>) -> tensor<8x!tt.ptr<f32>> {
+  %off1 = arith.constant 15 : i32
+
+  %base_ = tt.splat %base : !tt.ptr<f32> -> tensor<8x!tt.ptr<f32>>
+  %idx1 = tt.splat %off1 : i32 -> tensor<8xi32>
+
+  // CHECK: tt.addptr
+  // CHECK-NEXT: tt.addptr
+  %ptr0 = tt.addptr %base_, %idx0 : tensor<8x!tt.ptr<f32>>, tensor<8xi32>
+  %ptr1 = tt.addptr %ptr0, %idx1 : tensor<8x!tt.ptr<f32>>, tensor<8xi32>
+  tt.return %ptr1 : tensor<8x!tt.ptr<f32>>
+}
+
+// CHECK-LABEL: @test_not_combine_addptr_pattern
+tt.func @test_not_combine_addptr_pattern(%base: !tt.ptr<f32>) -> tensor<8x!tt.ptr<f32>> {
+  %off0 = arith.constant 10 : i16
+  %off1 = arith.constant 15 : i32
+
+  // CHECK-DAG: %[[cst:.*]] = arith.constant dense<10> : tensor<8xi16>
+  // CHECK-DAG: %[[cst1:.*]] = arith.constant dense<15> : tensor<8xi32>
+
+  %base_ = tt.splat %base : !tt.ptr<f32> -> tensor<8x!tt.ptr<f32>>
+
+  %idx0 = tt.splat %off0 : i16 -> tensor<8xi16>
+  %idx1 = tt.splat %off1 : i32 -> tensor<8xi32>
+
+  %ptr0 = tt.addptr %base_, %idx0 : tensor<8x!tt.ptr<f32>>, tensor<8xi16>
+  %ptr1 = tt.addptr %ptr0, %idx1 : tensor<8x!tt.ptr<f32>>, tensor<8xi32>
+  tt.return %ptr1 : tensor<8x!tt.ptr<f32>>
+}
+
+// CHECK-LABEL: @test_not_combine_addptr_pattern_overflow
+tt.func @test_not_combine_addptr_pattern_overflow(%base: !tt.ptr<f32>) -> tensor<8x!tt.ptr<f32>> {
+  %off0 = arith.constant 127 : i8
+  %off1 = arith.constant 1 : i8
+
+  // CHECK-DAG: %[[cst:.*]] = arith.constant dense<127> : tensor<8xi8>
+  // CHECK-DAG: %[[cst1:.*]] = arith.constant dense<1> : tensor<8xi8>
+
+  %base_ = tt.splat %base : !tt.ptr<f32> -> tensor<8x!tt.ptr<f32>>
+
+  %idx0 = tt.splat %off0 : i8 -> tensor<8xi8>
+  %idx1 = tt.splat %off1 : i8 -> tensor<8xi8>
+
+  %ptr0 = tt.addptr %base_, %idx0 : tensor<8x!tt.ptr<f32>>, tensor<8xi8>
+  %ptr1 = tt.addptr %ptr0, %idx1 : tensor<8x!tt.ptr<f32>>, tensor<8xi8>
+
+  tt.return %ptr1 : tensor<8x!tt.ptr<f32>>
+}
 
 // CHECK-LABEL: @test_combine_select_masked_load_pattern
 tt.func @test_combine_select_masked_load_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %cond: i1) -> (tensor<8xf32>, tensor<8xf32>) {

From e26580bf028815a1d10485fb548e4e634e449a4a Mon Sep 17 00:00:00 2001
From: Shucai Xiao
Date: Sun, 10 Mar 2024 23:58:26 -0500
Subject: [PATCH 3/9] Revert "turn on reduce tests with `argmax/argmin` (#3333)" (#3334)

This reverts commit 9a83624e49413c67ac6584623bedaaed929ad02a.
--- python/test/unit/language/test_core.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 9ed0a8fdc1..d28ede1202 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -1871,6 +1871,8 @@ def kernel(X, Z, BLOCK: tl.constexpr): negative_config + keep_dims_2d_configs + keep_dims_3d_configs) @pytest.mark.parametrize("num_ctas", num_ctas_list) def test_reduce(op, dtype_str, shape, axis, keep_dims, num_ctas, device): + if is_hip() and (op == 'argmin' or op == 'argmax'): + pytest.skip("TODO some tests for argmin and argmax do not work on HIP") check_type_supported(dtype_str, device) # bfloat16 on cc < 80 will not be tested @triton.jit From 38cc733efd1262dc6c81a1862247c09e9d982350 Mon Sep 17 00:00:00 2001 From: Lixun Zhang Date: Mon, 11 Mar 2024 01:01:49 -0500 Subject: [PATCH 4/9] [AMD] Bump numpy to 1.22.4 on AMD backend CI (#3335) --- .github/workflows/test-backends.yml | 1 + python/test/unit/language/test_core.py | 2 -- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/test-backends.yml b/.github/workflows/test-backends.yml index 4fe8c546d4..ab8b72e913 100644 --- a/.github/workflows/test-backends.yml +++ b/.github/workflows/test-backends.yml @@ -71,6 +71,7 @@ jobs: - name: Install Triton on ROCM run: | + pip install --force-reinstall numpy==1.22.4 pip uninstall -y triton cd python pip install -e . diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index d28ede1202..9ed0a8fdc1 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -1871,8 +1871,6 @@ def kernel(X, Z, BLOCK: tl.constexpr): negative_config + keep_dims_2d_configs + keep_dims_3d_configs) @pytest.mark.parametrize("num_ctas", num_ctas_list) def test_reduce(op, dtype_str, shape, axis, keep_dims, num_ctas, device): - if is_hip() and (op == 'argmin' or op == 'argmax'): - pytest.skip("TODO some tests for argmin and argmax do not work on HIP") check_type_supported(dtype_str, device) # bfloat16 on cc < 80 will not be tested @triton.jit From 6bf52aa9faa0e23fba6b8c58a660d5312c953d21 Mon Sep 17 00:00:00 2001 From: Justin Lebar Date: Mon, 11 Mar 2024 09:39:50 -0700 Subject: [PATCH 5/9] Add documentation explaining how dealloc is optional. 
(#3337) See https://github.com/openai/triton/pull/3327#discussion_r1519755419 --- include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td index 6d51b280e9..382152c1a3 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td @@ -171,6 +171,8 @@ def TTG_LocalAllocOp : TTG_Op<"local_alloc", [DeclareOpInterfaceMethods:$init); @@ -184,7 +186,12 @@ def TTG_LocalDeallocOp : TTG_Op<"local_dealloc", [MemoryEffects<[MemFree Date: Mon, 11 Mar 2024 10:17:37 -0700 Subject: [PATCH 6/9] [BACKEND] Transformed TritonGPUToLLVMPass to TritonGPUToLLVM in amd backend (#3328) Changes: - Extracted SPMD, Print, ControlFlow patterns to separate files - Deleted TritonGPUToLLVM - Renamed TritonGPUToLLVMPass to TritonGPUToLLVM - Extracted ControlFlow, and GetProgramId, FuncOp patterns to common lib --- .../PatternTritonGPUOpToLLVM.h | 13 + .../TritonGPUToLLVM/TargetInfoBase.h | 2 + lib/Conversion/TritonGPUToLLVM/CMakeLists.txt | 3 + .../TritonGPUToLLVM}/ControlFlowOpToLLVM.cpp | 6 +- .../TritonGPUToLLVM/FuncOpToLLVM.cpp | 116 ++++ .../TritonGPUToLLVM/SPMDOpToLLVM.cpp | 38 ++ .../amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt | 4 +- .../lib/TritonAMDGPUToLLVM/DotOpToLLVM.cpp | 2 - .../PatternTritonGPUOpToLLVM.h | 10 +- .../lib/TritonAMDGPUToLLVM/SPMDOpToLLVM.cpp | 35 + .../amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp | 5 + .../amd/lib/TritonAMDGPUToLLVM/TargetInfo.h | 2 + .../TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp | 604 +++++++----------- .../TritonGPUToLLVMPass.cpp | 503 --------------- .../amd/lib/TritonAMDGPUToLLVM/Utility.cpp | 12 + .../amd/lib/TritonAMDGPUToLLVM/Utility.h | 3 + .../lib/TritonNVIDIAGPUToLLVM/CMakeLists.txt | 3 +- .../PatternTritonGPUOpToLLVM.h | 12 +- .../TritonNVIDIAGPUToLLVM/PrintOpToLLVM.cpp | 5 +- .../TritonNVIDIAGPUToLLVM/SPMDOpToLLVM.cpp | 16 - .../lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp | 4 + .../lib/TritonNVIDIAGPUToLLVM/TargetInfo.h | 2 + .../TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp | 116 +--- .../lib/TritonNVIDIAGPUToLLVM/Utility.cpp | 4 +- .../lib/TritonNVIDIAGPUToLLVM/Utility.h | 4 +- 25 files changed, 506 insertions(+), 1018 deletions(-) rename {third_party/nvidia/lib/TritonNVIDIAGPUToLLVM => lib/Conversion/TritonGPUToLLVM}/ControlFlowOpToLLVM.cpp (96%) create mode 100644 lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp create mode 100644 lib/Conversion/TritonGPUToLLVM/SPMDOpToLLVM.cpp create mode 100644 third_party/amd/lib/TritonAMDGPUToLLVM/SPMDOpToLLVM.cpp delete mode 100644 third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVMPass.cpp diff --git a/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h b/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h index 868c4df2ea..ace77a3e92 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h +++ b/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h @@ -76,6 +76,19 @@ void populateConvertLayoutOpToLLVMPatterns(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, PatternBenefit benefit); +void populateControlFlowOpToLLVMPattern(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + PatternBenefit benefit); + +void populateSPMDOpToLLVMPattern(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, + PatternBenefit benefit); + +void 
populateFuncOpConversionPattern(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, int numWarps, + PatternBenefit benefit); + } // namespace triton } // namespace mlir diff --git a/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h b/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h index 9b7d6f49cd..3e12c922d2 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h +++ b/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h @@ -22,6 +22,8 @@ class TargetInfoBase { Value val, int i) const = 0; virtual Value shuffleIdx(Location loc, ConversionPatternRewriter &rewriter, Value val, Value i) const = 0; + virtual Value programId(Location loc, ConversionPatternRewriter &rewriter, + ModuleOp moduleOp, int axis) const = 0; virtual bool warpReduce(ConversionPatternRewriter &rewriter, Location loc, SmallVector &acc, triton::ReduceOp op, unsigned numLaneToReduce) const = 0; diff --git a/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt b/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt index 98879ef1c0..e538a5c14a 100644 --- a/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt +++ b/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt @@ -13,6 +13,9 @@ add_triton_library(TritonGPUToLLVM ReduceOpToLLVM.cpp ScanOpToLLVM.cpp ConvertLayoutOpToLLVM.cpp + ControlFlowOpToLLVM.cpp + FuncOpToLLVM.cpp + SPMDOpToLLVM.cpp DEPENDS TritonGPUConversionPassIncGen diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ControlFlowOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ControlFlowOpToLLVM.cpp similarity index 96% rename from third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ControlFlowOpToLLVM.cpp rename to lib/Conversion/TritonGPUToLLVM/ControlFlowOpToLLVM.cpp index 91378f5b16..9765d7bf00 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ControlFlowOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ControlFlowOpToLLVM.cpp @@ -1,5 +1,5 @@ -#include "PatternTritonGPUOpToLLVM.h" -#include "Utility.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" namespace { @@ -133,7 +133,7 @@ struct CallOpConversion : public ConvertOpToLLVMPattern { } // namespace -void mlir::triton::NVIDIA::populateControlFlowOpToLLVMPattern( +void mlir::triton::populateControlFlowOpToLLVMPattern( LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, PatternBenefit benefit) { patterns.add(typeConverter, benefit); diff --git a/lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp new file mode 100644 index 0000000000..04fa645477 --- /dev/null +++ b/lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp @@ -0,0 +1,116 @@ +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +namespace mlir { +FailureOr +convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp, + ConversionPatternRewriter &rewriter, + const LLVMTypeConverter &converter); +} + +namespace { + +using namespace mlir; +using namespace mlir::triton; + +/// FuncOp legalization pattern that converts MemRef arguments to pointers to +/// MemRef descriptors (LLVM struct data types) containing all the MemRef type +/// information. +struct FuncOpConversion : public ConvertOpToLLVMPattern { + FuncOpConversion(LLVMTypeConverter &converter, int numWarps, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit), numWarps(numWarps) {} + + /// Only retain those attributes that are not constructed by + /// `LLVMFuncOp::build`. 
If `filterArgAttrs` is set, also filter out argument + /// attributes. + static void filterFuncAttributes(triton::FuncOp op, bool filterArgAttrs, + SmallVectorImpl &result) { + + for (const auto &attr : op->getAttrs()) { + if (attr.getName() == SymbolTable::getSymbolAttrName() || + attr.getName() == op.getFunctionTypeAttrName() || + attr.getName() == "std.varargs" || + (filterArgAttrs && attr.getName() == op.getArgAttrsAttrName())) + continue; + result.push_back(attr); + } + } + + triton::FuncOp amendFuncOp(triton::FuncOp funcOp, + ConversionPatternRewriter &rewriter) const { + // Push back a variable that indicates the current stack pointer of shared + // memory to the function arguments. + auto loc = funcOp.getLoc(); + auto ctx = funcOp->getContext(); + auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), 3); + // 1. Modify the function type to add the new argument. + auto funcTy = funcOp.getFunctionType(); + auto amendedInputTy = llvm::to_vector<4>(funcTy.getInputs()); + amendedInputTy.push_back(ptrTy); + auto amendedFuncTy = FunctionType::get(funcTy.getContext(), amendedInputTy, + funcTy.getResults()); + // 2. Modify the argument attributes to add the new argument. + SmallVector amendedAttrs; + filterFuncAttributes(funcOp, /*filterArgAttrs=*/true, amendedAttrs); + auto amendedArgAttrs = llvm::to_vector<4>(funcOp.getAllArgAttrs()); + amendedArgAttrs.emplace_back(DictionaryAttr::get(ctx)); + amendedAttrs.push_back(rewriter.getNamedAttr( + funcOp.getArgAttrsAttrName(), rewriter.getArrayAttr(amendedArgAttrs))); + // 3. Add a new argument to the region + auto amendedFuncOp = rewriter.create( + funcOp.getLoc(), funcOp.getName(), amendedFuncTy, amendedAttrs); + auto ®ion = funcOp.getBody(); + region.addArgument(ptrTy, loc); + rewriter.inlineRegionBefore(region, amendedFuncOp.getBody(), + amendedFuncOp.end()); + return amendedFuncOp; + } + + LogicalResult + matchAndRewrite(triton::FuncOp funcOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Prevent LLVM's inliner to inline this function + auto amendedFuncOp = funcOp; + if (!LLVM::isKernel(funcOp)) + amendedFuncOp = amendFuncOp(funcOp, rewriter); + + LLVM::LLVMFuncOp newFuncOp = *mlir::convertFuncOpToLLVMFuncOp( + amendedFuncOp, rewriter, *getTypeConverter()); + if (!newFuncOp) { + return failure(); + } + + auto ctx = funcOp->getContext(); + + if (LLVM::isKernel(funcOp)) { + // Set an attribute to indicate this function is a kernel entry. + newFuncOp->setAttr("nvvm.kernel", + rewriter.getIntegerAttr(type::u1Ty(ctx), 1)); + } else { + // The noinline attribute will be used by the LLVM codegen to prevent + // inlining. + // https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/LLVMIR/IR/LLVMInlining.cpp#L267 + newFuncOp.setPassthroughAttr( + ArrayAttr::get(ctx, rewriter.getStringAttr("noinline"))); + rewriter.eraseOp(amendedFuncOp); + } + // Set an attribute for maxntidx, it could be used in latter LLVM codegen + // for `nvvm.annotation` metadata. 
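    // (Illustrative note, not part of the patch: this records maxntid =
    // 32 * numWarps, e.g. 128 threads per CTA when numWarps is 4.)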
+ newFuncOp->setAttr("nvvm.maxntid", + rewriter.getDenseI32ArrayAttr(32 * numWarps)); + rewriter.eraseOp(funcOp); + return success(); + } + +private: + int numWarps{0}; +}; + +} // namespace + +void mlir::triton::populateFuncOpConversionPattern( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, int numWarps, + PatternBenefit benefit) { + patterns.add(typeConverter, numWarps, benefit); +} diff --git a/lib/Conversion/TritonGPUToLLVM/SPMDOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/SPMDOpToLLVM.cpp new file mode 100644 index 0000000000..26eb528282 --- /dev/null +++ b/lib/Conversion/TritonGPUToLLVM/SPMDOpToLLVM.cpp @@ -0,0 +1,38 @@ +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +namespace { + +using namespace mlir; +using namespace mlir::triton; + +struct GetProgramIdOpConversion + : public ConvertOpToLLVMPattern { + explicit GetProgramIdOpConversion(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(typeConverter, benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(triton::GetProgramIdOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value programId = targetInfo.programId(op->getLoc(), rewriter, + op->getParentOfType(), + op.getAxisAsInt()); + rewriter.replaceOp(op, programId); + return success(); + } + +private: + const TargetInfoBase &targetInfo; +}; + +} // namespace + +void mlir::triton::populateSPMDOpToLLVMPattern(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, + PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, benefit); +} diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt b/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt index 0506877ba1..012290e4fe 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt @@ -6,12 +6,12 @@ add_triton_library(TritonAMDGPUToLLVM DotOpToLLVM.cpp ElementwiseOpToLLVM.cpp LoadStoreOpToLLVM.cpp - TritonGPUToLLVM.cpp GCNAsmFormat.cpp - TritonGPUToLLVMPass.cpp + TritonGPUToLLVM.cpp Utility.cpp TargetInfo.cpp DecomposeUnsupportedConversions.cpp + SPMDOpToLLVM.cpp DEPENDS TritonAMDGPUConversionPassIncGen diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM.cpp index c179e714eb..871d3a89d1 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM.cpp @@ -11,7 +11,6 @@ using ::mlir::triton::gpu::getShapePerCTA; using ::mlir::triton::gpu::NvidiaMmaEncodingAttr; namespace AMD { -#ifdef USE_ROCM LogicalResult convertMFMA(triton::DotOp op, triton::DotOp::Adaptor adaptor, const LLVMTypeConverter *typeConverter, ConversionPatternRewriter &rewriter); @@ -19,7 +18,6 @@ LogicalResult convertMFMA(triton::DotOp op, triton::DotOp::Adaptor adaptor, LogicalResult convertWMMA(triton::DotOp op, triton::DotOp::Adaptor adaptor, const LLVMTypeConverter *typeConverter, ConversionPatternRewriter &rewriter); -#endif } // namespace AMD namespace { diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/PatternTritonGPUOpToLLVM.h b/third_party/amd/lib/TritonAMDGPUToLLVM/PatternTritonGPUOpToLLVM.h index b8af22595e..0bfebdb374 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/PatternTritonGPUOpToLLVM.h +++ 
b/third_party/amd/lib/TritonAMDGPUToLLVM/PatternTritonGPUOpToLLVM.h @@ -28,11 +28,11 @@ void populateLoadStoreOpToLLVMPatterns(LLVMTypeConverter &typeConverter, int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis, PatternBenefit benefit); -void populateTritonGPUToLLVMPatterns(LLVMTypeConverter &typeConverter, - RewritePatternSet &patterns, int numWarps, - ModuleAxisInfoAnalysis &axisInfoAnalysis, - ModuleAllocation &allocation, - PatternBenefit benefit); + +void populateSPMDOpToLLVMPattern(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + PatternBenefit benefit); + } // namespace AMD #endif diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/SPMDOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/SPMDOpToLLVM.cpp new file mode 100644 index 0000000000..71d643f288 --- /dev/null +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/SPMDOpToLLVM.cpp @@ -0,0 +1,35 @@ +#include "PatternTritonGPUOpToLLVM.h" +#include "Utility.h" + +namespace { + +using namespace mlir; +using namespace mlir::triton; + +struct GetNumProgramsOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + triton::GetNumProgramsOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::GetNumProgramsOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + static constexpr mlir::gpu::Dimension dims[] = {mlir::gpu::Dimension::x, + mlir::gpu::Dimension::y, + mlir::gpu::Dimension::z}; + Location loc = op->getLoc(); + assert(op.getAxis() < 3); + Value blockId = + rewriter.create<::mlir::gpu::GridDimOp>(loc, dims[op.getAxis()]); + rewriter.replaceOpWithNewOp(op, i32_ty, blockId); + return success(); + } +}; + +} // namespace + +void AMD::populateSPMDOpToLLVMPattern(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + PatternBenefit benefit) { + patterns.add(typeConverter, benefit); +} diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp index 43098ec8a2..5e650b681c 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp @@ -67,6 +67,11 @@ Value TargetInfo::shuffleIdx(Location loc, ConversionPatternRewriter &rewriter, return LLVM::AMD::shuffleIdx(loc, rewriter, val, i); } +Value TargetInfo::programId(Location loc, ConversionPatternRewriter &rewriter, + ModuleOp moduleOp, int axis) const { + return LLVM::AMD::llGetPid(loc, rewriter, moduleOp, axis); +} + bool TargetInfo::warpReduce(ConversionPatternRewriter &rewriter, Location loc, SmallVector &acc, triton::ReduceOp op, unsigned numLaneToReduce) const { diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h index 41c6a290a7..2d0569f81d 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h @@ -21,6 +21,8 @@ class TargetInfo : public mlir::triton::TargetInfoBase { int i) const override; Value shuffleIdx(Location loc, ConversionPatternRewriter &rewriter, Value val, Value i) const override; + Value programId(Location loc, ConversionPatternRewriter &rewriter, + ModuleOp moduleOp, int axis) const override; bool warpReduce(ConversionPatternRewriter &rewriter, Location loc, SmallVector &acc, triton::ReduceOp op, unsigned numLaneToReduce) const override; diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp index 769cbb7a10..2ca3f108d4 100644 --- 
a/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp @@ -1,401 +1,281 @@ -#include "PatternTritonGPUOpToLLVM.h" +#include "TritonAMDGPUToLLVM/Passes.h" +#include "TargetInfo.h" #include "Utility.h" -#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Analysis/DataFlowFramework.h" +#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" +#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" +#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" +#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h" +#include "mlir/Conversion/LLVMCommon/VectorPattern.h" +#include "mlir/Conversion/MathToLLVM/MathToLLVM.h" +#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" +#include "mlir/Dialect/Index/IR/IndexDialect.h" +#include "mlir/Dialect/Index/IR/IndexOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Analysis/Allocation.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Analysis/Membar.h" +#include "triton/Dialect/NVGPU/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Tools/Sys/GetPlatform.hpp" + +#include "PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h" + #include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" -namespace { -using namespace mlir; -using namespace mlir::triton; +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" -using ::mlir::LLVM::getSharedMemoryObjectFromStruct; -using ::mlir::triton::gpu::getTotalElemsPerThread; -using ::mlir::triton::gpu::SharedEncodingAttr; - -Value llGetPid(int axis, Location loc, ModuleOp moduleOp, - ConversionPatternRewriter &rewriter) { - assert(axis >= 0); - assert(axis < 3); - assert(moduleOp); - static constexpr mlir::gpu::Dimension dims[] = {mlir::gpu::Dimension::x, - mlir::gpu::Dimension::y, - mlir::gpu::Dimension::z}; - Value blockId = rewriter.create<::mlir::gpu::BlockIdOp>(loc, dims[axis]); - return rewriter.create(loc, i32_ty, blockId); -} +namespace mlir { +namespace triton { +#define GEN_PASS_DEF_CONVERTTRITONAMDGPUTOLLVM +#include "TritonAMDGPUToLLVM/Passes.h.inc" +} // namespace triton +} // namespace mlir -struct ReturnOpConversion : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; +using namespace mlir; +using namespace mlir::triton; - LogicalResult - matchAndRewrite(triton::ReturnOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - unsigned numArguments = op.getNumOperands(); +namespace { - // Currently, Triton kernel function always return nothing. - // TODO(Superjomn) add support for non-inline device function - if (numArguments > 0) { - return rewriter.notifyMatchFailure( - op, "Only kernel function with nothing returned is supported."); - } +// pass ws related named attrs. 
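// ("ws" here is short for warp specialization; async_agent and
// agent.mutex_role are the warp-specialization attributes that this helper
// preserves when rewriting ops.)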
+static void addWSNamedAttrs(Operation *op, + ArrayRef attrs) { + for (const NamedAttribute attr : attrs) + if (attr.getName() == "async_agent" || attr.getName() == "agent.mutex_role") + op->setAttr(attr.getName(), attr.getValue()); +} - rewriter.replaceOpWithNewOp(op, TypeRange(), ValueRange(), - op->getAttrs()); - return success(); +#ifdef USE_ROCM +constexpr int LDSSize = 65536; +constexpr int kPtrBitWidth = 64; +#endif +class TritonLLVMFunctionConversionTarget : public ConversionTarget { +public: + explicit TritonLLVMFunctionConversionTarget(MLIRContext &ctx) + : ConversionTarget(ctx) { + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addLegalOp(); } }; -// The input print op contains: -// - a "prefix" (string) specified by the user, and -// - one or more "operands" (tensors). -// -// For each operand, we print all of the values contained in this GPU thread, -// one per line, along with the index of the value in its tensor. -struct PrintOpConversion : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(triton::PrintOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op->getLoc(); - Value prefixStr = - LLVM::addStringToModule(loc, rewriter, "printfPrefix_", op.getPrefix()); - - auto getPid = [&](int axis) { - return llGetPid(axis, loc, op->getParentOfType(), rewriter); - }; - std::array pid = {getPid(0), getPid(1), getPid(2)}; - - // Simple printf of a string without any tensors. - if (op.getNumOperands() == 0) { - std::string formatStr; - llvm::raw_string_ostream os(formatStr); - os << "pid (" << getFormatSubstr(pid[0]) << ", " - << getFormatSubstr(pid[1]) << ", " << getFormatSubstr(pid[2]) << ")%s"; - llPrintf(formatStr, {pid[0], pid[1], pid[2], prefixStr}, rewriter); - } else { - for (size_t i = 0; i < op.getNumOperands(); i++) { - // Elements of the tensor that are resident in this GPU thread. - auto elems = unpackLLElements(loc, adaptor.getOperands()[i], rewriter); - - // Get the indices of `elems` within the tensor. Note that if `elems` - // has an "interesting" layout, then these will not be in any - // particularly nice order. - - // Extract the shape of the tensor being printed and use it to figure - // out how many digits we need for each of the dimensions. - SmallVector dimWidths; - SmallVector> indices; - if (auto rankedTy = - op.getOperand(i).getType().dyn_cast()) { - indices = emitIndices(loc, rewriter, rankedTy.getEncoding(), rankedTy, - true); - for (int64_t dim : rankedTy.getShape()) { - if (dim > 0) { - dimWidths.push_back(static_cast(std::ceil(std::log10(dim)))); - } else { - dimWidths.push_back(0); - } - } - } else { - // We're printing a scalar. 
- assert(elems.size() == 1); - indices.push_back({}); - } - - if (!elems.empty()) { - printTensor(prefixStr, /*operand=*/i, - /*numOperands=*/op.getNumOperands(), elems, pid, indices, - dimWidths, rewriter); - } - } - } - rewriter.eraseOp(op); - return success(); +class TritonLLVMConversionTarget : public ConversionTarget { +public: + explicit TritonLLVMConversionTarget(MLIRContext &ctx) + : ConversionTarget(ctx) { + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addIllegalDialect(); + addIllegalDialect(); + addIllegalDialect(); + addIllegalDialect(); + addLegalOp(); } +}; - void printTensor(Value prefixStr, size_t operand, size_t numOperands, - ArrayRef elems, std::array pid, - ArrayRef> indices, - ArrayRef dimWidths, - ConversionPatternRewriter &rewriter) const { - assert(!elems.empty()); - assert(elems.size() == indices.size()); - assert(dimWidths.size() == indices.front().size()); +struct ConvertTritonAMDGPUToLLVM + : public triton::impl::ConvertTritonAMDGPUToLLVMBase< + ConvertTritonAMDGPUToLLVM> { + using ConvertTritonAMDGPUToLLVMBase< + ConvertTritonAMDGPUToLLVM>::ConvertTritonAMDGPUToLLVMBase; - size_t rank = dimWidths.size(); + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } - // Format is: - // pid (, , ) idx (, , ...) (operand ) - // where we leave off "(operand )" if there's only one operand. - // - // The Python wrapper munges `prefix` so that it prints nicely (e.g. starts - // with " " and ends with ": "). - - Value formatStrValue; - for (int i = 0; i < elems.size(); i++) { - std::string formatStr; - llvm::raw_string_ostream os(formatStr); - - // nvptx printf can only accept 32 args; if we pass more than that, it - // will print garbage for the trailing args. - constexpr int kMaxPrintfOperands = 32; - SmallVector printfOperands; - - // TODO(jlebar): We really should pad the pid, but because the max pid is - // not known at compile-time, this would require nontrivial device-side - // work. - os << "pid ("; - for (int j = 0; j < pid.size(); j++) { - if (j != 0) { - os << ", "; - } - os << getFormatSubstr(pid[j]); - printfOperands.push_back(pid[j]); - } - os << ") "; - - // If `rank` is large enough, we could end up exceeding - // kMaxPrintfOperands. In that case, just truncate the index. - // (Subtract 2 because we're going to add two operands after the index.) - int maxAllowedRank = kMaxPrintfOperands - printfOperands.size() - 2; - - os << "idx ("; - const auto &index = indices[i]; - for (size_t dim = 0; dim < index.size(); dim++) { - if (dim != 0) { - os << ", "; - } - if (dim == maxAllowedRank) { - os << "... (truncated)"; - break; - } - os << getFormatSubstr(index[dim], /*width=*/dimWidths[dim]); - printfOperands.push_back(index[dim]); - } - os << ")"; - - os << "%s"; - printfOperands.push_back(prefixStr); - - if (numOperands > 1) { - os << "(operand " << operand << ") "; - } - - auto elem = elems[i]; - os << getFormatSubstr(elem); - printfOperands.push_back(elem); - - // It's the same format string each iteration, but it's a lot easier if we - // construct the format string at the same time as we populate - // printfOperands. But we don't want to create BLOCK_SIZE duplicate - // strings, so we cache the Value. 
- if (i == 0) { - formatStrValue = llPrintf(formatStr, printfOperands, rewriter); - } else { - llPrintf(formatStrValue, printfOperands, rewriter); - } + ConvertTritonAMDGPUToLLVM(int32_t computeCapability) + : ConvertTritonAMDGPUToLLVMBase({computeCapability}) {} + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + mlir::LowerToLLVMOptions option(context); + option.overrideIndexBitwidth(32); + TritonGPUToLLVMTypeConverter typeConverter(context, option); + TritonLLVMConversionTarget convTarget(*context); + int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod); + int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(mod); + int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); + + // Hack: WSMaterialization may have changed the effective number of warps, + // in a way that isn't reflected in triton_gpu.num-warps. If so, we have to + // respect that here. + if (Attribute attr = mod->getAttr("triton_gpu.num-warp-groups-per-cta")) { + numWarps *= attr.cast().getInt(); } - } - std::string getFormatSubstr(Value value, - std::optional width = std::nullopt) const { - std::string prefix = "%"; - if (width.has_value()) { - prefix += std::to_string(*width); + // Allocate shared memory and set barrier + ModuleAllocation allocation(mod); + ModuleMembarAnalysis membarPass(&allocation); + membarPass.run(); + + // Lower functions + { + mlir::LowerToLLVMOptions option(context); + TritonGPUToLLVMTypeConverter typeConverter(context, option); + TritonLLVMFunctionConversionTarget funcTarget(*context); + RewritePatternSet funcPatterns(context); + mlir::triton::populateFuncOpConversionPattern( + typeConverter, funcPatterns, numWarps, patternBenefitDefault); + mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter, + funcPatterns); + if (failed( + applyPartialConversion(mod, funcTarget, std::move(funcPatterns)))) + return signalPassFailure(); } - Type type = value.getType(); - if (type.isa()) { - return prefix + "p"; - } else if (type.isBF16() || type.isF16() || type.isF32() || type.isF64()) { - return prefix + "f"; - } else if (type.isSignedInteger()) { - if (type.getIntOrFloatBitWidth() == 64) - return prefix + "lli"; - else - return prefix + "i"; - } else if (type.isUnsignedInteger() || type.isSignlessInteger()) { - if (type.getIntOrFloatBitWidth() == 64) - return prefix + "llu"; - else - return prefix + "u"; + // initSharedMemory is run before the conversion of call and ret ops, + // because the call op has to know the shared memory base address of each + // function + initSharedMemory(typeConverter); + + // Convert call and ret ops + { + mlir::LowerToLLVMOptions option(context); + TritonGPUToLLVMTypeConverter typeConverter(context, option); + TritonLLVMFunctionConversionTarget funcTarget(*context); + RewritePatternSet funcPatterns(context); + if (failed( + applyPartialConversion(mod, funcTarget, std::move(funcPatterns)))) + return signalPassFailure(); } - assert(false && "not supported type"); - return ""; - } - // declare vprintf(i8*, i8*) as external function - static LLVM::LLVMFuncOp - getVprintfDeclaration(ConversionPatternRewriter &rewriter) { - auto moduleOp = - rewriter.getBlock()->getParent()->getParentOfType(); - StringRef funcName("vprintf"); - Operation *funcOp = moduleOp.lookupSymbol(funcName); - if (funcOp) - return cast(*funcOp); + ModuleAxisInfoAnalysis axisInfoAnalysis(mod); + + // Emit logics to get threadId/blockIds/linearized clusterCTAId etc. and + // cache the values. 
The reason to do it here is that cluster_ctaid is + // currently implemented via inline asm, and thus cannot be CSEed. + // clusterCTAId will be emitted only when numCTAs is larger than 1, and + // other values will be DCEed if not used hereafter. + OpBuilder::InsertPoint indexInsertPoint; + + RewritePatternSet patterns(context); + AMD::TargetInfo targetInfo("gfx1200"); + int benefit = patternBenefitPrioritizeOverLLVMConversions; + auto populatePatterns1 = [&](auto populateFunc) { + populateFunc(typeConverter, patterns, numWarps, axisInfoAnalysis, + allocation, benefit); + }; - auto *context = rewriter.getContext(); + auto populatePatterns2 = [&](auto populateFunc) { + populateFunc(typeConverter, patterns, numWarps, axisInfoAnalysis, + allocation, benefit); + }; - SmallVector argsType{ptr_ty(context), ptr_ty(context)}; - auto funcType = LLVM::LLVMFunctionType::get(i32_ty, argsType); + auto populatePatterns3 = [&](auto populateFunc) { + populateFunc(typeConverter, patterns, numWarps, axisInfoAnalysis, + allocation, benefit); + }; - ConversionPatternRewriter::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(moduleOp.getBody()); + auto populatePatterns4 = [&](auto populateFunc) { + populateFunc(typeConverter, patterns, numWarps, axisInfoAnalysis, + allocation, computeCapability, benefit); + }; - return rewriter.create(UnknownLoc::get(context), funcName, - funcType); - } + auto populatePatterns5 = [&](auto populateFunc) { + populateFunc(typeConverter, patterns, benefit); + }; - // extend integer to int32, extend float to float64 - // this comes from vprintf alignment requirements. - static std::pair - promoteValue(ConversionPatternRewriter &rewriter, Value value) { - auto *context = rewriter.getContext(); - auto type = value.getType(); - Value newOp = value; - Type newType = type; - auto loc = UnknownLoc::get(context); - - bool bUnsigned = type.isUnsignedInteger(); - if (type.isIntOrIndex() && type.getIntOrFloatBitWidth() < 32) { - if (bUnsigned) { - newType = ui32_ty; - newOp = zext(newType, value); - } else { - newType = i32_ty; - newOp = sext(newType, value); - } - } else if (type.isBF16() || type.isF16() || type.isF32()) { - newType = f64_ty; - newOp = fpext(newType, value); - } + auto populatePatterns6 = [&](auto populateFunc) { + populateFunc(typeConverter, patterns, numWarps, axisInfoAnalysis, + allocation, computeCapability, targetInfo, benefit); + }; - return {newType, newOp}; - } + auto populatePatterns7 = [&](auto populateFunc) { + populateFunc(typeConverter, patterns, targetInfo, benefit); + }; - // Returns a Value for the format string, which you can reuse. 
- static Value llPrintf(StringRef msg, ValueRange args, - ConversionPatternRewriter &rewriter) { - assert(!msg.empty() && "printf with empty string not supported"); - llvm::SmallString<64> msgNewline(msg); - msgNewline.push_back('\n'); - msgNewline.push_back('\0'); - Value msgValue = - LLVM::addStringToModule(UnknownLoc::get(rewriter.getContext()), - rewriter, "printfFormat_", msgNewline); - llPrintf(msgValue, args, rewriter); - return msgValue; - } + AMD::populateConvertLayoutOpToLLVMPatterns(typeConverter, targetInfo, + patterns, numWarps, + axisInfoAnalysis, benefit); + AMD::populateDotOpToLLVMPatterns(typeConverter, patterns, numWarps, + axisInfoAnalysis, benefit); + populatePatterns6(AMD::populateElementwiseOpToLLVMPatterns); + AMD::populateLoadStoreOpToLLVMPatterns(typeConverter, patterns, numWarps, + axisInfoAnalysis, benefit); + populatePatterns7(mlir::triton::populateReduceOpToLLVMPatterns); + populatePatterns7(mlir::triton::populateScanOpToLLVMPatterns); + populatePatterns5(mlir::triton::populateViewOpToLLVMPatterns); + populatePatterns7(mlir::triton::populateHistogramOpToLLVMPatterns); + mlir::triton::populateMemoryOpToLLVMPattern(typeConverter, patterns, + benefit); + mlir::triton::populateMakeRangeOpToLLVMPattern(typeConverter, patterns, + benefit); + mlir::triton::populateAssertOpToLLVMPattern(typeConverter, patterns, + benefit); + mlir::triton::populateControlFlowOpToLLVMPattern(typeConverter, patterns, + benefit); + mlir::triton::populateSPMDOpToLLVMPattern(typeConverter, patterns, + targetInfo, benefit); + AMD::populateSPMDOpToLLVMPattern(typeConverter, patterns, benefit); + // TODO(thomas): this should probably be done in a separate step to not + // interfere with our own lowering of arith ops. Add arith/math's patterns + // to help convert scalar expression to LLVM. + mlir::arith::populateArithToLLVMConversionPatterns(typeConverter, patterns); + mlir::populateMathToLLVMConversionPatterns(typeConverter, patterns); + + // Native lowering patterns + mlir::populateGpuToROCDLConversionPatterns(typeConverter, patterns, + mlir::gpu::amd::HIP); + + mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter, + patterns); + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) { + return signalPassFailure(); + } - static void llPrintf(Value msg, ValueRange args, - ConversionPatternRewriter &rewriter) { - auto *ctx = rewriter.getContext(); - Type ptr = ptr_ty(ctx); - auto moduleOp = - rewriter.getBlock()->getParent()->getParentOfType(); - auto funcOp = getVprintfDeclaration(rewriter); - auto loc = UnknownLoc::get(ctx); - - Value one = i32_val(1); - Value zero = i32_val(0); - - Value bufferPtr = null(ptr); - - SmallVector newArgs; - if (args.size() >= 1) { - SmallVector argTypes; - for (auto arg : args) { - Type newType; - Value newArg; - std::tie(newType, newArg) = promoteValue(rewriter, arg); - argTypes.push_back(newType); - newArgs.push_back(newArg); - } - - Type structTy = LLVM::LLVMStructType::getLiteral(ctx, argTypes); - auto allocated = - rewriter.create(loc, ptr_ty(ctx), structTy, one, - /*alignment=*/0); - - for (const auto &entry : llvm::enumerate(newArgs)) { - auto index = i32_val(entry.index()); - auto fieldPtr = - gep(ptr_ty(ctx), structTy, allocated, ArrayRef{zero, index}); - store(entry.value(), fieldPtr); - } - bufferPtr = bitcast(allocated, ptr); + // Fold CTAId when there is only 1 CTA. 
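    // (Illustrative note, not part of the patch: the AMD backend has no
    // counterpart to NVIDIA's thread-block clusters, so on ROCm numCTAs is
    // expected to be 1 and any leftover triton::nvgpu::ClusterCTAIdOp can be
    // folded to the constant 0, which is what the walk below does.)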
+ if (numCTAs == 1) { + mod.walk([](triton::nvgpu::ClusterCTAIdOp id) { + OpBuilder b(id); + Value zero = LLVM::createConstantI32(id->getLoc(), b, 0); + id.replaceAllUsesWith(zero); + }); } + } - SmallVector operands{msg, bufferPtr}; - call(funcOp, operands); +private: + void initSharedMemory(LLVMTypeConverter &typeConverter) { + ModuleOp mod = getOperation(); + OpBuilder b(mod.getBodyRegion()); + auto ctx = mod.getContext(); + auto loc = mod.getLoc(); + auto elemTy = typeConverter.convertType(b.getIntegerType(8)); + // Set array size 0 and external linkage indicates that we use dynamic + // shared allocation to allow a larger shared memory size for each kernel. + // + // Ask for 16B alignment on global_smem because that's the largest we should + // ever need (4xi32). + auto arrayTy = LLVM::LLVMArrayType::get(elemTy, 0); + auto global = b.create( + loc, arrayTy, /*isConstant=*/false, LLVM::Linkage::External, + "global_smem", /*value=*/Attribute(), /*alignment=*/16, + // Add ROCm support. + static_cast(NVVM::NVVMMemorySpace::kSharedMemorySpace)); } }; -struct GetProgramIdOpConversion - : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; +} // anonymous namespace - LogicalResult - matchAndRewrite(triton::GetProgramIdOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { +namespace mlir { +namespace triton { -#ifdef USE_ROCM - static constexpr mlir::gpu::Dimension dims[] = {mlir::gpu::Dimension::x, - mlir::gpu::Dimension::y, - mlir::gpu::Dimension::z}; - Location loc = op->getLoc(); - assert(op.getAxisAsInt() < 3); - - Value blockId = - rewriter.create<::mlir::gpu::BlockIdOp>(loc, dims[op.getAxisAsInt()]); - rewriter.replaceOpWithNewOp(op, i32_ty, blockId); - return success(); -#else - Value programId = llGetPid(op.getAxisAsInt(), op->getLoc(), - op->getParentOfType(), rewriter); - rewriter.replaceOp(op, programId); - return success(); -#endif - } -}; - -struct GetNumProgramsOpConversion - : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern< - triton::GetNumProgramsOp>::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(triton::GetNumProgramsOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - static constexpr mlir::gpu::Dimension dims[] = {mlir::gpu::Dimension::x, - mlir::gpu::Dimension::y, - mlir::gpu::Dimension::z}; - Location loc = op->getLoc(); - assert(op.getAxis() < 3); - Value blockId = - rewriter.create<::mlir::gpu::GridDimOp>(loc, dims[op.getAxis()]); - rewriter.replaceOpWithNewOp(op, i32_ty, blockId); - return success(); - } -}; -} // namespace - -namespace AMD { -void populateTritonGPUToLLVMPatterns(LLVMTypeConverter &typeConverter, - RewritePatternSet &patterns, int numWarps, - ModuleAxisInfoAnalysis &axisInfoAnalysis, - ModuleAllocation &moduleAllocation, - PatternBenefit benefit) { - patterns.add(typeConverter, benefit); - patterns.add(typeConverter, benefit); - patterns.add(typeConverter, benefit); - patterns.add(typeConverter, benefit); - - mlir::triton::populateMemoryOpToLLVMPattern(typeConverter, patterns, benefit); - mlir::triton::populateMakeRangeOpToLLVMPattern(typeConverter, patterns, - benefit); - mlir::triton::populateAssertOpToLLVMPattern(typeConverter, patterns, benefit); +std::unique_ptr> createConvertTritonAMDGPUToLLVMPass() { + return std::make_unique(90); } -} // namespace AMD +} // namespace triton +} // namespace mlir diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVMPass.cpp 
b/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVMPass.cpp deleted file mode 100644 index 26cc87e183..0000000000 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVMPass.cpp +++ /dev/null @@ -1,503 +0,0 @@ -#include "TritonAMDGPUToLLVM/Passes.h" - -#include "TargetInfo.h" -#include "Utility.h" -#include "mlir/Analysis/DataFlowFramework.h" -#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" -#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" -#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" -#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h" -#include "mlir/Conversion/LLVMCommon/VectorPattern.h" -#include "mlir/Conversion/MathToLLVM/MathToLLVM.h" -#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" -#include "mlir/Dialect/Index/IR/IndexDialect.h" -#include "mlir/Dialect/Index/IR/IndexOps.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/Dialect/LLVMIR/NVVMDialect.h" -#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "triton/Analysis/Allocation.h" -#include "triton/Analysis/AxisInfo.h" -#include "triton/Analysis/Membar.h" -#include "triton/Dialect/NVGPU/IR/Dialect.h" -#include "triton/Dialect/Triton/IR/Dialect.h" -#include "triton/Dialect/TritonGPU/IR/Dialect.h" -#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" -#include "triton/Tools/Sys/GetPlatform.hpp" - -#include "PatternTritonGPUOpToLLVM.h" -#include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h" - -#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" - -#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" - -namespace mlir { -namespace triton { -#define GEN_PASS_DEF_CONVERTTRITONAMDGPUTOLLVM -#include "TritonAMDGPUToLLVM/Passes.h.inc" -} // namespace triton -} // namespace mlir - -namespace mlir { -FailureOr -convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp, - ConversionPatternRewriter &rewriter, - const LLVMTypeConverter &converter); -} - -using namespace mlir; -using namespace mlir::triton; - -namespace { - -// pass ws related named attrs. -static void addWSNamedAttrs(Operation *op, - ArrayRef attrs) { - for (const NamedAttribute attr : attrs) - if (attr.getName() == "async_agent" || attr.getName() == "agent.mutex_role") - op->setAttr(attr.getName(), attr.getValue()); -} - -#ifdef USE_ROCM -constexpr int LDSSize = 65536; -constexpr int kPtrBitWidth = 64; -#endif -class TritonLLVMFunctionConversionTarget : public ConversionTarget { -public: - explicit TritonLLVMFunctionConversionTarget(MLIRContext &ctx) - : ConversionTarget(ctx) { - addLegalDialect(); - addLegalDialect(); - addLegalDialect(); - addLegalDialect(); - addLegalOp(); - } -}; - -struct ReturnOpConversion : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(triton::ReturnOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto funcOp = op->getParentOfType(); - if (funcOp->hasAttr("nvvm.kernel")) { - // A GPU kernel - if (op.getNumOperands() > 0) { - return rewriter.notifyMatchFailure( - op, "Kernel functions do not support return with operands"); - } - rewriter.replaceOpWithNewOp(op, TypeRange(), ValueRange(), - op->getAttrs()); - } else { - // A device function - LLVM::ReturnOp newOp; - if (adaptor.getOperands().size() < 2) { - // Single or no return value. - newOp = - rewriter.create(op.getLoc(), adaptor.getOperands()); - } else { - // Pack the results into a struct. 
- auto packedResultsTy = this->getTypeConverter()->packFunctionResults( - funcOp.getResultTypes()); - Value packedResults = - rewriter.create(op.getLoc(), packedResultsTy); - auto loc = op.getLoc(); - for (auto it : llvm::enumerate(adaptor.getOperands())) { - packedResults = insert_val(packedResultsTy, packedResults, it.value(), - it.index()); - } - newOp = rewriter.create(op.getLoc(), packedResults); - } - newOp->setAttrs(op->getAttrs()); - rewriter.replaceOp(op, newOp->getResults()); - } - return success(); - } -}; - -/// FuncOp legalization pattern that converts MemRef arguments to pointers to -/// MemRef descriptors (LLVM struct data types) containing all the MemRef type -/// information. -struct FuncOpConversion : public ConvertOpToLLVMPattern { - FuncOpConversion(LLVMTypeConverter &converter, int numWarps, - PatternBenefit benefit) - : ConvertOpToLLVMPattern(converter, benefit), numWarps(numWarps) {} - - /// Only retain those attributes that are not constructed by - /// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument - /// attributes. - static void filterFuncAttributes(triton::FuncOp op, bool filterArgAttrs, - SmallVectorImpl &result) { - - for (const auto &attr : op->getAttrs()) { - if (attr.getName() == SymbolTable::getSymbolAttrName() || - attr.getName() == op.getFunctionTypeAttrName() || - attr.getName() == "std.varargs" || - (filterArgAttrs && attr.getName() == op.getArgAttrsAttrName())) - continue; - result.push_back(attr); - } - } - - triton::FuncOp amendFuncOp(triton::FuncOp funcOp, - ConversionPatternRewriter &rewriter) const { - // Push back a variable that indicates the current stack pointer of shared - // memory to the function arguments. - auto loc = funcOp.getLoc(); - auto ctx = funcOp->getContext(); - auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), 3); - // 1. Modify the function type to add the new argument. - auto funcTy = funcOp.getFunctionType(); - auto amendedInputTy = llvm::to_vector<4>(funcTy.getInputs()); - amendedInputTy.push_back(ptrTy); - auto amendedFuncTy = FunctionType::get(funcTy.getContext(), amendedInputTy, - funcTy.getResults()); - // 2. Modify the argument attributes to add the new argument. - SmallVector amendedAttrs; - filterFuncAttributes(funcOp, /*filterArgAttrs=*/true, amendedAttrs); - auto amendedArgAttrs = llvm::to_vector<4>(funcOp.getAllArgAttrs()); - amendedArgAttrs.emplace_back(DictionaryAttr::get(ctx)); - amendedAttrs.push_back(rewriter.getNamedAttr( - funcOp.getArgAttrsAttrName(), rewriter.getArrayAttr(amendedArgAttrs))); - // 3. Add a new argument to the region - auto amendedFuncOp = rewriter.create( - funcOp.getLoc(), funcOp.getName(), amendedFuncTy, amendedAttrs); - auto ®ion = funcOp.getBody(); - region.addArgument(ptrTy, loc); - rewriter.inlineRegionBefore(region, amendedFuncOp.getBody(), - amendedFuncOp.end()); - return amendedFuncOp; - } - - LogicalResult - matchAndRewrite(triton::FuncOp funcOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - // Prevent LLVM's inliner to inline this function - auto amendedFuncOp = funcOp; - if (!LLVM::isKernel(funcOp)) - amendedFuncOp = amendFuncOp(funcOp, rewriter); - - LLVM::LLVMFuncOp newFuncOp = *mlir::convertFuncOpToLLVMFuncOp( - amendedFuncOp, rewriter, *getTypeConverter()); - if (!newFuncOp) { - return failure(); - } - - auto ctx = funcOp->getContext(); - - if (LLVM::isKernel(funcOp)) { - // Set an attribute to indicate this function is a kernel entry. 
- newFuncOp->setAttr("nvvm.kernel", - rewriter.getIntegerAttr(type::u1Ty(ctx), 1)); - } else { - // The noinline attribute will be used by the LLVM codegen to prevent - // inlining. - // https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/LLVMIR/IR/LLVMInlining.cpp#L267 - newFuncOp.setPassthroughAttr( - ArrayAttr::get(ctx, rewriter.getStringAttr("noinline"))); - rewriter.eraseOp(amendedFuncOp); - } - // Set an attribute for maxntidx, it could be used in latter LLVM codegen - // for `nvvm.annotation` metadata. - newFuncOp->setAttr("nvvm.maxntid", - rewriter.getDenseI32ArrayAttr(32 * numWarps)); - - // required by AxisInfoAnalysis - rewriter.eraseOp(funcOp); - return success(); - } - -private: - int numWarps{0}; -}; - -// CallOpInterfaceLowering is adapted from -// https://github.com/llvm/llvm-project/blob/fae656b2dd80246c3c6f01e9c77c49560368752c/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp#L485 -struct CallOpConversion : public ConvertOpToLLVMPattern { - CallOpConversion(LLVMTypeConverter &converter, PatternBenefit benefit) - : ConvertOpToLLVMPattern(converter, benefit) {} - - LogicalResult - matchAndRewrite(triton::CallOp callOp, - typename triton::CallOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto promotedOperands = promoteOperands(callOp, adaptor, rewriter); - auto newCallOp = - convertCallOpToLLVMCallOp(callOp, promotedOperands, rewriter); - if (!newCallOp) - return failure(); - auto results = getCallOpResults(callOp, newCallOp, rewriter); - rewriter.replaceOp(callOp, results); - return success(); - } - -private: - SmallVector - promoteOperands(triton::CallOp callOp, - typename triton::CallOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const { - // Get the last argument of the caller, which is the current stack pointer - // of shared memory and append it to the operands of the callOp. - auto loc = callOp.getLoc(); - auto caller = callOp->getParentOfType(); - auto promotedOperands = this->getTypeConverter()->promoteOperands( - callOp.getLoc(), /*opOperands=*/callOp->getOperands(), - adaptor.getOperands(), rewriter); - if (!caller->hasAttr("allocation.offset")) { - auto base = LLVM::getStackPointer(rewriter, caller); - promotedOperands.push_back(base); - return promotedOperands; - } - promotedOperands.push_back( - LLVM::getSharedMemoryBase(callOp->getLoc(), rewriter, callOp)); - return promotedOperands; - } - - LLVM::CallOp - convertCallOpToLLVMCallOp(triton::CallOp callOp, - ArrayRef promotedOperands, - ConversionPatternRewriter &rewriter) const { - // Pack the result types into a struct. - Type packedResult = nullptr; - unsigned numResults = callOp.getNumResults(); - auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes()); - - if (numResults != 0) { - if (!(packedResult = - this->getTypeConverter()->packFunctionResults(resultTypes))) - return nullptr; - } - auto newCallOp = rewriter.create( - callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(), - promotedOperands, callOp->getAttrs()); - return newCallOp; - } - - SmallVector - getCallOpResults(triton::CallOp callOp, LLVM::CallOp newCallOp, - ConversionPatternRewriter &rewriter) const { - auto numResults = callOp.getNumResults(); - SmallVector results; - if (numResults < 2) { - // If < 2 results, packing did not do anything and we can just return. - results.append(newCallOp.result_begin(), newCallOp.result_end()); - } else { - // Otherwise, it had been converted to an operation producing a structure. 
- // Extract individual results from the structure and return them as list. - results.reserve(numResults); - for (unsigned i = 0; i < numResults; ++i) { - results.push_back(rewriter.create( - callOp.getLoc(), newCallOp->getResult(0), i)); - } - } - return results; - } -}; - -class TritonLLVMConversionTarget : public ConversionTarget { -public: - explicit TritonLLVMConversionTarget(MLIRContext &ctx) - : ConversionTarget(ctx) { - addLegalDialect(); - addLegalDialect(); - addLegalDialect(); - addLegalDialect(); - addIllegalDialect(); - addIllegalDialect(); - addIllegalDialect(); - addIllegalDialect(); - addLegalOp(); - } -}; - -struct ConvertTritonAMDGPUToLLVM - : public triton::impl::ConvertTritonAMDGPUToLLVMBase< - ConvertTritonAMDGPUToLLVM> { - using ConvertTritonAMDGPUToLLVMBase< - ConvertTritonAMDGPUToLLVM>::ConvertTritonAMDGPUToLLVMBase; - - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); - } - - ConvertTritonAMDGPUToLLVM(int32_t computeCapability) - : ConvertTritonAMDGPUToLLVMBase({computeCapability}) {} - - void runOnOperation() override { - MLIRContext *context = &getContext(); - ModuleOp mod = getOperation(); - mlir::LowerToLLVMOptions option(context); - option.overrideIndexBitwidth(32); - TritonGPUToLLVMTypeConverter typeConverter(context, option); - TritonLLVMConversionTarget convTarget(*context); - int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod); - int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(mod); - int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); - - // Hack: WSMaterialization may have changed the effective number of warps, - // in a way that isn't reflected in triton_gpu.num-warps. If so, we have to - // respect that here. - if (Attribute attr = mod->getAttr("triton_gpu.num-warp-groups-per-cta")) { - numWarps *= attr.cast().getInt(); - } - - // Allocate shared memory and set barrier - ModuleAllocation allocation(mod); - ModuleMembarAnalysis membarPass(&allocation); - membarPass.run(); - - // Lower functions - { - mlir::LowerToLLVMOptions option(context); - TritonGPUToLLVMTypeConverter typeConverter(context, option); - TritonLLVMFunctionConversionTarget funcTarget(*context); - RewritePatternSet funcPatterns(context); - funcPatterns.add(typeConverter, numWarps, - patternBenefitDefault); - mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter, - funcPatterns); - if (failed( - applyPartialConversion(mod, funcTarget, std::move(funcPatterns)))) - return signalPassFailure(); - } - - // initSharedMemory is run before the conversion of call and ret ops, - // because the call op has to know the shared memory base address of each - // function - initSharedMemory(typeConverter); - - // Convert call and ret ops - { - mlir::LowerToLLVMOptions option(context); - TritonGPUToLLVMTypeConverter typeConverter(context, option); - TritonLLVMFunctionConversionTarget funcTarget(*context); - RewritePatternSet funcPatterns(context); - funcPatterns.add(typeConverter, patternBenefitDefault); - funcPatterns.add(typeConverter, - patternBenefitDefault); - if (failed( - applyPartialConversion(mod, funcTarget, std::move(funcPatterns)))) - return signalPassFailure(); - } - - ModuleAxisInfoAnalysis axisInfoAnalysis(mod); - - // Emit logics to get threadId/blockIds/linearized clusterCTAId etc. and - // cache the values. The reason to do it here is that cluster_ctaid is - // currently implemented via inline asm, and thus cannot be CSEed. 
- // clusterCTAId will be emitted only when numCTAs is larger than 1, and - // other values will be DCEed if not used hereafter. - OpBuilder::InsertPoint indexInsertPoint; - - RewritePatternSet patterns(context); - AMD::TargetInfo targetInfo("gfx1200"); - int benefit = patternBenefitPrioritizeOverLLVMConversions; - auto populatePatterns1 = [&](auto populateFunc) { - populateFunc(typeConverter, patterns, numWarps, axisInfoAnalysis, - allocation, benefit); - }; - - auto populatePatterns2 = [&](auto populateFunc) { - populateFunc(typeConverter, patterns, numWarps, axisInfoAnalysis, - allocation, benefit); - }; - - auto populatePatterns3 = [&](auto populateFunc) { - populateFunc(typeConverter, patterns, numWarps, axisInfoAnalysis, - allocation, benefit); - }; - - auto populatePatterns4 = [&](auto populateFunc) { - populateFunc(typeConverter, patterns, numWarps, axisInfoAnalysis, - allocation, computeCapability, benefit); - }; - - auto populatePatterns5 = [&](auto populateFunc) { - populateFunc(typeConverter, patterns, benefit); - }; - - auto populatePatterns6 = [&](auto populateFunc) { - populateFunc(typeConverter, patterns, numWarps, axisInfoAnalysis, - allocation, computeCapability, targetInfo, benefit); - }; - - auto populatePatterns7 = [&](auto populateFunc) { - populateFunc(typeConverter, patterns, targetInfo, benefit); - }; - - populatePatterns1(AMD::populateTritonGPUToLLVMPatterns); - AMD::populateConvertLayoutOpToLLVMPatterns(typeConverter, targetInfo, - patterns, numWarps, - axisInfoAnalysis, benefit); - AMD::populateDotOpToLLVMPatterns(typeConverter, patterns, numWarps, - axisInfoAnalysis, benefit); - populatePatterns6(AMD::populateElementwiseOpToLLVMPatterns); - AMD::populateLoadStoreOpToLLVMPatterns(typeConverter, patterns, numWarps, - axisInfoAnalysis, benefit); - populatePatterns7(mlir::triton::populateReduceOpToLLVMPatterns); - populatePatterns7(mlir::triton::populateScanOpToLLVMPatterns); - populatePatterns5(mlir::triton::populateViewOpToLLVMPatterns); - populatePatterns7(mlir::triton::populateHistogramOpToLLVMPatterns); - - // TODO(thomas): this should probably be done in a separate step to not - // interfere with our own lowering of arith ops. Add arith/math's patterns - // to help convert scalar expression to LLVM. - mlir::arith::populateArithToLLVMConversionPatterns(typeConverter, patterns); - mlir::populateMathToLLVMConversionPatterns(typeConverter, patterns); - - // Native lowering patterns - mlir::populateGpuToROCDLConversionPatterns(typeConverter, patterns, - mlir::gpu::amd::HIP); - - mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter, - patterns); - if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) { - return signalPassFailure(); - } - - // Fold CTAId when there is only 1 CTA. - if (numCTAs == 1) { - mod.walk([](triton::nvgpu::ClusterCTAIdOp id) { - OpBuilder b(id); - Value zero = LLVM::createConstantI32(id->getLoc(), b, 0); - id.replaceAllUsesWith(zero); - }); - } - } - -private: - void initSharedMemory(LLVMTypeConverter &typeConverter) { - ModuleOp mod = getOperation(); - OpBuilder b(mod.getBodyRegion()); - auto ctx = mod.getContext(); - auto loc = mod.getLoc(); - auto elemTy = typeConverter.convertType(b.getIntegerType(8)); - // Set array size 0 and external linkage indicates that we use dynamic - // shared allocation to allow a larger shared memory size for each kernel. - // - // Ask for 16B alignment on global_smem because that's the largest we should - // ever need (4xi32). 
- auto arrayTy = LLVM::LLVMArrayType::get(elemTy, 0); - auto global = b.create( - loc, arrayTy, /*isConstant=*/false, LLVM::Linkage::External, - "global_smem", /*value=*/Attribute(), /*alignment=*/16, - // Add ROCm support. - static_cast(NVVM::NVVMMemorySpace::kSharedMemorySpace)); - } -}; - -} // anonymous namespace - -namespace mlir { -namespace triton { - -std::unique_ptr> createConvertTritonAMDGPUToLLVMPass() { - return std::make_unique(90); -} - -} // namespace triton -} // namespace mlir diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp index ba660bc3f6..1dd16fb619 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp @@ -130,6 +130,18 @@ Value shuffleIdx(Location loc, ConversionPatternRewriter &rewriter, Value val, return shuffleCommon(loc, rewriter, val, i, 0, ShflKind::idx, i32_val(0x1f)); } +Value llGetPid(Location loc, ConversionPatternRewriter &rewriter, + ModuleOp moduleOp, int axis) { + assert(axis >= 0); + assert(axis < 3); + assert(moduleOp); + static constexpr mlir::gpu::Dimension dims[] = {mlir::gpu::Dimension::x, + mlir::gpu::Dimension::y, + mlir::gpu::Dimension::z}; + Value blockId = rewriter.create<::mlir::gpu::BlockIdOp>(loc, dims[axis]); + return rewriter.create(loc, i32_ty, blockId); +} + } // namespace AMD } // namespace LLVM diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h index 2872c79f53..cd61c46500 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h @@ -21,6 +21,9 @@ Value shuffleIdx(Location loc, ConversionPatternRewriter &rewriter, Value val, int i); Value shuffleIdx(Location loc, ConversionPatternRewriter &rewriter, Value val, Value i); + +Value llGetPid(Location loc, ConversionPatternRewriter &rewriter, + ModuleOp moduleOp, int axis); } // namespace AMD } // namespace LLVM diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/CMakeLists.txt b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/CMakeLists.txt index a827be3934..05c962a566 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/CMakeLists.txt +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/CMakeLists.txt @@ -11,14 +11,13 @@ add_triton_library(TritonNVIDIAGPUToLLVM BarrierOpToLLVM.cpp TritonGPUToLLVM.cpp DecomposeUnsupportedConversions.cpp - PrintOpToLLVM.cpp - ControlFlowOpToLLVM.cpp SPMDOpToLLVM.cpp TensorPtrOpsToLLVM.cpp ClusterOpsToLLVM.cpp PTXAsmFormat.cpp Utility.cpp TargetInfo.cpp + PrintOpToLLVM.cpp DEPENDS TritonNVIDIAGPUConversionPassIncGen diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/PatternTritonGPUOpToLLVM.h b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/PatternTritonGPUOpToLLVM.h index 87001a7552..dd85f4fc60 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/PatternTritonGPUOpToLLVM.h +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/PatternTritonGPUOpToLLVM.h @@ -45,14 +45,6 @@ void populateTensorPtrOpsToLLVMPatterns(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, PatternBenefit benefit); -void populatePrintOpToLLVMPattern(LLVMTypeConverter &typeConverter, - RewritePatternSet &patterns, - PatternBenefit benefit); - -void populateControlFlowOpToLLVMPattern(LLVMTypeConverter &typeConverter, - RewritePatternSet &patterns, - PatternBenefit benefit); - void populateSPMDOpToLLVMPattern(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, PatternBenefit benefit); @@ -62,6 +54,10 @@ 
void populateClampFOpToLLVMPattern(LLVMTypeConverter &typeConverter, ModuleAxisInfoAnalysis &axisInfoAnalysis, int computeCapability, PatternBenefit benefit); + +void populatePrintOpToLLVMPattern(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + PatternBenefit benefit); } // namespace NVIDIA } // namespace triton } // namespace mlir diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/PrintOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/PrintOpToLLVM.cpp index b9e5924037..3fbef1a365 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/PrintOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/PrintOpToLLVM.cpp @@ -18,14 +18,13 @@ struct PrintOpConversion : public ConvertOpToLLVMPattern { LogicalResult matchAndRewrite(triton::PrintOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto typeConverter = getTypeConverter(); auto loc = op->getLoc(); Value prefixStr = LLVM::addStringToModule(loc, rewriter, "printfPrefix_", op.getPrefix()); auto getPid = [&](int axis) { - return LLVM::NVIDIA::llGetPid(axis, loc, op->getParentOfType(), - rewriter); + return LLVM::NVIDIA::llGetPid(loc, rewriter, + op->getParentOfType(), axis); }; std::array pid = {getPid(0), getPid(1), getPid(2)}; diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/SPMDOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/SPMDOpToLLVM.cpp index f675c09305..7958a889fe 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/SPMDOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/SPMDOpToLLVM.cpp @@ -6,21 +6,6 @@ namespace { using namespace mlir; using namespace mlir::triton; -struct GetProgramIdOpConversion - : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(triton::GetProgramIdOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value programId = - LLVM::NVIDIA::llGetPid(op.getAxisAsInt(), op->getLoc(), - op->getParentOfType(), rewriter); - rewriter.replaceOp(op, programId); - return success(); - } -}; - struct GetNumProgramsOpConversion : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern< @@ -66,7 +51,6 @@ struct GetClusterCTAIdOpConversion void mlir::triton::NVIDIA::populateSPMDOpToLLVMPattern( LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, PatternBenefit benefit) { - patterns.add(typeConverter, benefit); patterns.add(typeConverter, benefit); patterns.add(typeConverter, benefit); } diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp index 7153a5ee02..e13380319f 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp @@ -209,6 +209,10 @@ Value TargetInfo::shuffleIdx(Location loc, ConversionPatternRewriter &rewriter, return LLVM::NVIDIA::shuffleIdx(loc, rewriter, val, i); } +Value TargetInfo::programId(Location loc, ConversionPatternRewriter &rewriter, + ModuleOp moduleOp, int axis) const { + return LLVM::NVIDIA::llGetPid(loc, rewriter, moduleOp, axis); +} bool TargetInfo::warpReduce(ConversionPatternRewriter &rewriter, Location loc, SmallVector &acc, triton::ReduceOp op, unsigned numLaneToReduce) const { diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.h b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.h index bd9819a508..c817cddb35 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.h 
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.h @@ -20,6 +20,8 @@ class TargetInfo : public mlir::triton::TargetInfoBase { int i) const override; Value shuffleIdx(Location loc, ConversionPatternRewriter &rewriter, Value val, Value i) const override; + Value programId(Location loc, ConversionPatternRewriter &rewriter, + ModuleOp moduleOp, int axis) const override; bool warpReduce(ConversionPatternRewriter &rewriter, Location loc, SmallVector &acc, triton::ReduceOp op, unsigned numLaneToReduce) const override; diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp index a7a56e0445..c891e784d1 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp @@ -34,13 +34,6 @@ namespace triton { } // namespace triton } // namespace mlir -namespace mlir { -FailureOr -convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp, - ConversionPatternRewriter &rewriter, - const LLVMTypeConverter &converter); -} - using namespace mlir; using namespace mlir::triton::NVIDIA; namespace ttng = mlir::triton::nvidia_gpu; @@ -64,103 +57,6 @@ class TritonLLVMFunctionConversionTarget : public ConversionTarget { } }; -/// FuncOp legalization pattern that converts MemRef arguments to pointers to -/// MemRef descriptors (LLVM struct data types) containing all the MemRef type -/// information. - -struct FuncOpConversion : public ConvertOpToLLVMPattern { - FuncOpConversion(LLVMTypeConverter &converter, int numWarps, - PatternBenefit benefit) - : ConvertOpToLLVMPattern(converter, benefit), numWarps(numWarps) {} - - /// Only retain those attributes that are not constructed by - /// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument - /// attributes. - static void filterFuncAttributes(triton::FuncOp op, bool filterArgAttrs, - SmallVectorImpl &result) { - - for (const auto &attr : op->getAttrs()) { - if (attr.getName() == SymbolTable::getSymbolAttrName() || - attr.getName() == op.getFunctionTypeAttrName() || - attr.getName() == "std.varargs" || - (filterArgAttrs && attr.getName() == op.getArgAttrsAttrName())) - continue; - result.push_back(attr); - } - } - - triton::FuncOp amendFuncOp(triton::FuncOp funcOp, - ConversionPatternRewriter &rewriter) const { - // Push back a variable that indicates the current stack pointer of shared - // memory to the function arguments. - auto loc = funcOp.getLoc(); - auto ctx = funcOp->getContext(); - auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), 3); - // 1. Modify the function type to add the new argument. - auto funcTy = funcOp.getFunctionType(); - auto amendedInputTy = llvm::to_vector<4>(funcTy.getInputs()); - amendedInputTy.push_back(ptrTy); - auto amendedFuncTy = FunctionType::get(funcTy.getContext(), amendedInputTy, - funcTy.getResults()); - // 2. Modify the argument attributes to add the new argument. - SmallVector amendedAttrs; - filterFuncAttributes(funcOp, /*filterArgAttrs=*/true, amendedAttrs); - auto amendedArgAttrs = llvm::to_vector<4>(funcOp.getAllArgAttrs()); - amendedArgAttrs.emplace_back(DictionaryAttr::get(ctx)); - amendedAttrs.push_back(rewriter.getNamedAttr( - funcOp.getArgAttrsAttrName(), rewriter.getArrayAttr(amendedArgAttrs))); - // 3. 
Add a new argument to the region - auto amendedFuncOp = rewriter.create( - funcOp.getLoc(), funcOp.getName(), amendedFuncTy, amendedAttrs); - auto ®ion = funcOp.getBody(); - region.addArgument(ptrTy, loc); - rewriter.inlineRegionBefore(region, amendedFuncOp.getBody(), - amendedFuncOp.end()); - return amendedFuncOp; - } - - LogicalResult - matchAndRewrite(triton::FuncOp funcOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - // Prevent LLVM's inliner to inline this function - auto amendedFuncOp = funcOp; - if (!LLVM::isKernel(funcOp)) - amendedFuncOp = amendFuncOp(funcOp, rewriter); - - LLVM::LLVMFuncOp newFuncOp = *mlir::convertFuncOpToLLVMFuncOp( - amendedFuncOp, rewriter, *getTypeConverter()); - if (!newFuncOp) { - return failure(); - } - - auto ctx = funcOp->getContext(); - - if (LLVM::isKernel(funcOp)) { - // Set an attribute to indicate this function is a kernel entry. - newFuncOp->setAttr("nvvm.kernel", - rewriter.getIntegerAttr(type::u1Ty(ctx), 1)); - } else { - // The noinline attribute will be used by the LLVM codegen to prevent - // inlining. - // https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/LLVMIR/IR/LLVMInlining.cpp#L267 - newFuncOp.setPassthroughAttr( - ArrayAttr::get(ctx, rewriter.getStringAttr("noinline"))); - rewriter.eraseOp(amendedFuncOp); - } - // Set an attribute for maxntidx, it could be used in latter LLVM codegen - // for `nvvm.annotation` metadata. - newFuncOp->setAttr("nvvm.maxntid", - rewriter.getDenseI32ArrayAttr(32 * numWarps)); - - // required by AxisInfoAnalysis - rewriter.eraseOp(funcOp); - return success(); - } - -private: - int numWarps{0}; -}; - class TritonLLVMConversionTarget : public ConversionTarget { public: explicit TritonLLVMConversionTarget(MLIRContext &ctx) @@ -211,8 +107,8 @@ struct ConvertTritonGPUToLLVM TritonGPUToLLVMTypeConverter typeConverter(context, option); TritonLLVMFunctionConversionTarget funcTarget(*context); RewritePatternSet funcPatterns(context); - funcPatterns.add(typeConverter, numWarps, - patternBenefitDefault); + mlir::triton::populateFuncOpConversionPattern( + typeConverter, funcPatterns, numWarps, patternBenefitDefault); mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter, funcPatterns); if (failed( @@ -254,8 +150,12 @@ struct ConvertTritonGPUToLLVM mlir::triton::populateHistogramOpToLLVMPatterns(typeConverter, patterns, targetInfo, benefit); populatePrintOpToLLVMPattern(typeConverter, patterns, benefit); - populateControlFlowOpToLLVMPattern(typeConverter, patterns, benefit); - populateSPMDOpToLLVMPattern(typeConverter, patterns, benefit); + mlir::triton::populateControlFlowOpToLLVMPattern(typeConverter, patterns, + benefit); + mlir::triton::NVIDIA::populateSPMDOpToLLVMPattern(typeConverter, patterns, + benefit); + mlir::triton::populateSPMDOpToLLVMPattern(typeConverter, patterns, + targetInfo, benefit); // TODO(thomas): this should probably be done in a separate step to not // interfere with our own lowering of arith ops. Add arith/math's patterns // to help convert scalar expression to LLVM. 
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.cpp
index c5c02bfa55..42d4daeaf7 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.cpp
@@ -64,8 +64,8 @@ Value shuffleIdx(Location loc, ConversionPatternRewriter &rewriter, Value val,
                       i32_val(0x1f));
 }
 
-Value llGetPid(int axis, Location loc, ModuleOp moduleOp,
-               ConversionPatternRewriter &rewriter) {
+Value llGetPid(Location loc, ConversionPatternRewriter &rewriter,
+               ModuleOp moduleOp, int axis) {
   assert(axis >= 0);
   assert(axis < 3);
   assert(moduleOp);
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.h b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.h
index 4e1ba8ca0c..816c8b599a 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.h
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.h
@@ -44,8 +44,8 @@ Value shuffleIdx(Location loc, ConversionPatternRewriter &rewriter, Value val,
 Value shuffleIdx(Location loc, ConversionPatternRewriter &rewriter, Value val,
                  Value i);
 
-Value llGetPid(int axis, Location loc, ModuleOp moduleOp,
-               ConversionPatternRewriter &rewriter);
+Value llGetPid(Location loc, ConversionPatternRewriter &rewriter,
+               ModuleOp moduleOp, int axis);
 } // namespace NVIDIA
 } // namespace LLVM

From 730276f066006df398b70d98c16348a307940c8c Mon Sep 17 00:00:00 2001
From: pawelszczerbuk <153013546+pawelszczerbuk@users.noreply.github.com>
Date: Mon, 11 Mar 2024 11:17:50 -0700
Subject: [PATCH 7/9] [DEBUG] Override fix - Turn FileCacheManager.has_file
 back into an interface method (#3336)

https://github.com/openai/triton/pull/2934 has changed some of the
FileCacheManager methods into a private interface, while the `_has_file`
method is still used by the override functionality. This exposes the
method again. (An alternative would be to call the private `_has_file`
from the override code, which would also be fine, since this is a
debug-only feature.)
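For illustration, a minimal sketch of the intended use (the key string and
file name below are made up; the constructor signature is the one defined
in cache.py):

    from triton.runtime.cache import FileCacheManager

    # An override-style consumer can now probe the cache through the
    # public interface instead of reaching into the private _has_file.
    cm = FileCacheManager(key="example-kernel-hash", override=True)
    if cm.has_file("kernel.ptx"):
        path = cm.get_file("kernel.ptx")  # path of the cached file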
--- python/triton/runtime/cache.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/triton/runtime/cache.py b/python/triton/runtime/cache.py index 65a000b038..2e8d70ea4e 100644 --- a/python/triton/runtime/cache.py +++ b/python/triton/runtime/cache.py @@ -68,20 +68,20 @@ def __init__(self, key, override=False, dump=False): def _make_path(self, filename) -> str: return os.path.join(self.cache_dir, filename) - def _has_file(self, filename) -> bool: + def has_file(self, filename) -> bool: if not self.cache_dir: raise RuntimeError("Could not create or locate cache dir") return os.path.exists(self._make_path(filename)) def get_file(self, filename) -> Optional[str]: - if self._has_file(filename): + if self.has_file(filename): return self._make_path(filename) else: return None def get_group(self, filename: str) -> Optional[Dict[str, str]]: grp_filename = f"__grp__{filename}" - if not self._has_file(grp_filename): + if not self.has_file(grp_filename): return None grp_filepath = self._make_path(grp_filename) with open(grp_filepath) as f: From eb613cdb2aeec8c2b3ba4a63fc9f0d6802eec53a Mon Sep 17 00:00:00 2001 From: Ilya V <152324710+joviliast@users.noreply.github.com> Date: Mon, 11 Mar 2024 20:15:45 +0100 Subject: [PATCH 8/9] [AMD][Navi31] Add emitIndices logic for WMMA layout (#3170) -Add emitWmmaOffsetForCTA -Add emitBaseIndexForWmmaLayout -Add emitOffsetForWmmaLayout -Fix some helper methods Signed-off-by: joviliast --- .../Conversion/TritonGPUToLLVM/Utility.h | 71 +++++++++++++++++++ lib/Conversion/TritonGPUToLLVM/Utility.cpp | 11 ++- 2 files changed, 79 insertions(+), 3 deletions(-) diff --git a/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/include/triton/Conversion/TritonGPUToLLVM/Utility.h index b5af132720..f254e72932 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/Utility.h +++ b/include/triton/Conversion/TritonGPUToLLVM/Utility.h @@ -418,6 +418,7 @@ using LLVM::SharedMemoryObject; using ::mlir::LLVM::delinearize; using ::mlir::LLVM::SharedMemoryObject; using ::mlir::triton::gpu::AMDMfmaEncodingAttr; +using ::mlir::triton::gpu::AMDWmmaEncodingAttr; using ::mlir::triton::gpu::BlockedEncodingAttr; using ::mlir::triton::gpu::CTALayoutAttr; using ::mlir::triton::gpu::DotOperandEncodingAttr; @@ -848,6 +849,71 @@ emitOffsetForMfmaLayout(const AMDMfmaEncodingAttr &mfmaLayout, return offsets; } +static void emitWmmaOffsetForCTA(const AMDWmmaEncodingAttr &wmmaLayout, + SmallVector> &offsets, + unsigned ctaOffsetX, unsigned ctaOffsetY) { + const unsigned elemsPerThreadPerGroup = 8; + auto warpSize = getWarpSize(wmmaLayout); + assert(warpSize == 32); + auto shapePerCta = getShapePerCTATile(wmmaLayout); + for (unsigned elem = 0; elem < elemsPerThreadPerGroup; elem++) { + offsets.push_back( + {ctaOffsetX * shapePerCta[0] + 2 * elem, ctaOffsetY * shapePerCta[1]}); + } +} + +static SmallVector +emitBaseIndexForWmmaLayout(Location loc, RewriterBase &rewriter, + const AMDWmmaEncodingAttr &wmmaLayout, + RankedTensorType type) { + auto shape = type.getShape(); + auto _warpsPerCTA = wmmaLayout.getWarpsPerCTA(); + assert(_warpsPerCTA.size() == 2); + SmallVector warpsPerCTA = {i32_val(_warpsPerCTA[0]), + i32_val(_warpsPerCTA[1])}; + auto mnkDim = AMDWmmaEncodingAttr::getMNKDimPerWMMAInstr(); + + Value threadId = getThreadId(rewriter, loc); + Value warpSize = i32_val(triton::gpu::getWarpSize(wmmaLayout)); + Value laneId = + urem(threadId, i32_val(triton::gpu::getWarpSize(wmmaLayout) / 2)); + Value threadIdPerWarp = urem(threadId, warpSize); + + Value warpId = 
udiv(threadId, warpSize); + Value warpId0 = urem(warpId, warpsPerCTA[0]); + Value warpId1 = urem(udiv(warpId, warpsPerCTA[0]), warpsPerCTA[1]); + + Value offWarp0 = mul(warpId0, i32_val(mnkDim[0])); + Value offWarp1 = mul(warpId1, i32_val(mnkDim[1])); + + return {add(udiv(threadIdPerWarp, i32_val(mnkDim[2])), offWarp0), + add(laneId, offWarp1)}; +} + +static SmallVector> +emitOffsetForWmmaLayout(const AMDWmmaEncodingAttr &wmmaLayout, + RankedTensorType type) { + auto tensorShape = type.getShape(); + SmallVector> offsets; + auto shapePerCTA = getShapePerCTA(wmmaLayout, tensorShape); + auto warpsPerCTA = wmmaLayout.getWarpsPerCTA(); + + SmallVector numWarpsPerDim(2); + auto mnkDim = AMDWmmaEncodingAttr::getMNKDimPerWMMAInstr(); + for (unsigned d = 0; d < 2; ++d) { + unsigned inPerCTA = std::min(tensorShape[d], shapePerCTA[d]); + unsigned inPerWarp = ceil(inPerCTA, warpsPerCTA[d]); + numWarpsPerDim[d] = ceil(inPerWarp, mnkDim[d]); + } + + for (unsigned i = 0; i < numWarpsPerDim[0]; ++i) { + for (unsigned j = 0; j < numWarpsPerDim[1]; ++j) { + emitWmmaOffsetForCTA(wmmaLayout, offsets, i, j); + } + } + return offsets; +} + static SmallVector> emitOffsetForLayout(Attribute layout, RankedTensorType type); @@ -932,6 +998,8 @@ emitBaseIndexForLayout(Location loc, RewriterBase &rewriter, Attribute layout, type); } else if (auto mfmaLayout = layout.dyn_cast()) { result = emitBaseIndexForMfmaLayout(loc, rewriter, mfmaLayout, type); + } else if (auto wmmaLayout = layout.dyn_cast()) { + result = emitBaseIndexForWmmaLayout(loc, rewriter, wmmaLayout, type); } else if (auto sliceLayout = layout.dyn_cast()) { auto parentLayout = sliceLayout.getParent(); auto parentShape = sliceLayout.paddedShape(type.getShape()); @@ -969,6 +1037,9 @@ emitOffsetForLayout(Attribute layout, RankedTensorType type) { if (auto mfmaLayout = layout.dyn_cast()) { return emitOffsetForMfmaLayout(mfmaLayout, type); } + if (auto wmmaLayout = layout.dyn_cast()) { + return emitOffsetForWmmaLayout(wmmaLayout, type); + } if (auto sliceLayout = layout.dyn_cast()) return emitOffsetForSliceLayout(sliceLayout, type); llvm_unreachable("unsupported emitOffsetForLayout"); diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp index 25f6dbfc07..82d5a47cdc 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp @@ -523,14 +523,19 @@ SmallVector getMultiDimOffset(Attribute layout, Location loc, } return multiDimOffset; } - if (auto mfmaLayout = layout.dyn_cast()) { + if (layout.isa()) { auto multiDimBase = emitBaseIndexForLayout(loc, rewriter, layout, type, false); SmallVector> offsets; assert(rank == 2); SmallVector multiDimOffset(rank); - emitMfmaOffsetForCTA(mfmaLayout, offsets, multiDimCTAInRepId[0], - multiDimCTAInRepId[1]); + if (auto mfmaLayout = layout.dyn_cast()) { + emitMfmaOffsetForCTA(mfmaLayout, offsets, multiDimCTAInRepId[0], + multiDimCTAInRepId[1]); + } else if (auto wmmaLayout = layout.dyn_cast()) { + emitWmmaOffsetForCTA(wmmaLayout, offsets, multiDimCTAInRepId[0], + multiDimCTAInRepId[1]); + } multiDimOffset[0] = add(multiDimBase[0], i32_val(offsets[elemId][0])); multiDimOffset[1] = add(multiDimBase[1], i32_val(offsets[elemId][1])); return multiDimOffset; From d04f28864d1c1e6a3e0d6f16c4aa701c84310d4a Mon Sep 17 00:00:00 2001 From: Ilya V <152324710+joviliast@users.noreply.github.com> Date: Mon, 11 Mar 2024 21:20:54 +0100 Subject: [PATCH 9/9] [AMD]Enable EmitIndicesTest (#3167) Allows to test emmitting indices from different layouts. 
Valid also for AMD target device. Signed-off-by: joviliast --- .../amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt | 1 - unittest/CMakeLists.txt | 4 +- .../Conversion/TritonGPUToLLVM/CMakeLists.txt | 10 +- .../Conversion/TritonGPUToLLVM/DumpLayout.cpp | 56 +++--- .../TritonGPUToLLVM/EmitIndicesTest.cpp | 170 +++++++++--------- 5 files changed, 134 insertions(+), 107 deletions(-) diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt b/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt index 012290e4fe..167cc1cecd 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt @@ -21,4 +21,3 @@ add_triton_library(TritonAMDGPUToLLVM ) target_compile_definitions(TritonAMDGPUToLLVM PUBLIC USE_ROCM) - diff --git a/unittest/CMakeLists.txt b/unittest/CMakeLists.txt index f0c792d2aa..62d80e838a 100644 --- a/unittest/CMakeLists.txt +++ b/unittest/CMakeLists.txt @@ -10,7 +10,7 @@ get_property(triton_libs GLOBAL PROPERTY TRITON_LIBS) function(add_triton_ut) set(options) set(oneValueArgs NAME) - set(multiValueArgs SRCS LIBS) + set(multiValueArgs SRCS LIBS DEFS) cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) add_test(NAME ${__NAME} COMMAND ${__NAME}) @@ -29,6 +29,8 @@ function(add_triton_ut) target_compile_options(${__NAME} PRIVATE -fno-rtti) + target_compile_definitions(${__NAME} PRIVATE ${__DEFS}) + # Without the TEST_DISCOVERY_TIMEOUT, the tests randomly time out on my mac # laptop. I think the issue may be that the very first time you run a program # it's a bit slow. diff --git a/unittest/Conversion/TritonGPUToLLVM/CMakeLists.txt b/unittest/Conversion/TritonGPUToLLVM/CMakeLists.txt index d148d007b2..3c5692a626 100644 --- a/unittest/Conversion/TritonGPUToLLVM/CMakeLists.txt +++ b/unittest/Conversion/TritonGPUToLLVM/CMakeLists.txt @@ -5,7 +5,15 @@ add_triton_ut( ) add_triton_ut( - NAME TestEmitIndices + NAME TestEmitIndicesNvidia SRCS EmitIndicesTest.cpp DumpLayout.cpp LIBS TritonGPUIR TritonNvidiaGPUIR TritonNVIDIAGPUToLLVM + DEFS NVIDIA_TARGET=1 +) + +add_triton_ut( + NAME TestEmitIndicesAMD + SRCS EmitIndicesTest.cpp DumpLayout.cpp + LIBS TritonGPUIR TritonAMDGPUToLLVM + DEFS AMD_TARGET=1 ) diff --git a/unittest/Conversion/TritonGPUToLLVM/DumpLayout.cpp b/unittest/Conversion/TritonGPUToLLVM/DumpLayout.cpp index db3fd57106..45e720a587 100644 --- a/unittest/Conversion/TritonGPUToLLVM/DumpLayout.cpp +++ b/unittest/Conversion/TritonGPUToLLVM/DumpLayout.cpp @@ -22,16 +22,33 @@ */ #include "DumpLayout.h" - +#ifdef AMD_TARGET +#include "amd/lib/TritonAMDGPUToLLVM/Utility.h" +#else #include "nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.h" -#include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h" - +#endif namespace mlir { namespace triton { namespace gpu { namespace { +#ifdef AMD_TARGET +Value getMockSmemBaseImpl([[maybe_unused]] IRRewriter &rewriter, + [[maybe_unused]] Location loc) { + return i32_val(0); +} +#else +Value getMockSmemBaseImpl(IRRewriter &rewriter, Location loc) { + Value mockSmemBase = + LLVM::NVIDIA::getSRegValue(rewriter, loc, "%mock_smem_base"); + auto llPtrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), 3); + auto cast = rewriter.create( + loc, TypeRange{llPtrTy}, ValueRange{mockSmemBase}); + return cast.getResult(0); +} +#endif + //===----------------------------------------------------------------------===// // IndexEmitter //===----------------------------------------------------------------------===// @@ -39,9 +56,17 @@ namespace { class IndexEmitter { public: 
IndexEmitter(MLIRContext *context_) - : context(context_), option(context), typeConverter(context, option), - rewriter(context), loc(UnknownLoc::get(context)) { - rewriter.setInsertionPointToStart(&block); + : context(context_), option(context), rewriter(context), + loc(UnknownLoc::get(context)) { + mlir::OpBuilder builder(context); + std::vector inTypes{}; + std::vector outTypes{}; + auto funcTy = builder.getFunctionType(inTypes, outTypes); + auto func = builder.create(loc, "test_func", funcTy); + auto mlirModule = mlir::ModuleOp::create(loc); + mlirModule.push_back(func); + auto *block = func.addEntryBlock(); + rewriter.setInsertionPointToStart(block); } llvm::SmallVector> @@ -56,28 +81,17 @@ class IndexEmitter { Type elemTy, llvm::ArrayRef shape, bool withCTAOffset) { auto srcTy = RankedTensorType::get(shape, elemTy, srcLayout); - SharedMemoryObject smemObj(getMockSmemBase(), elemTy, shape, - sharedLayout.getOrder(), loc, rewriter); + SharedMemoryObject smemObj(getMockSmemBaseImpl(rewriter, loc), elemTy, + shape, sharedLayout.getOrder(), loc, rewriter); return getSwizzledSharedPtrs(loc, /*inVec=*/1, srcTy, sharedLayout, elemTy, smemObj, rewriter, smemObj.offsets, smemObj.strides); } private: - Value getMockSmemBase() { - Value mockSmemBase = - mlir::LLVM::NVIDIA::getSRegValue(rewriter, loc, "%mock_smem_base"); - auto llPtrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), 3); - auto cast = rewriter.create( - loc, TypeRange{llPtrTy}, ValueRange{mockSmemBase}); - return cast.getResult(0); - } - // Non-static members are initialized in declaration order MLIRContext *context; LowerToLLVMOptions option; - TritonGPUToLLVMTypeConverter typeConverter; - Block block; IRRewriter rewriter; Location loc; }; @@ -151,6 +165,8 @@ int eval(Value value, int ctaid, int tid) { return eval(xorOp.getLhs(), ctaid, tid) ^ eval(xorOp.getRhs(), ctaid, tid); } else if (auto trunciOp = llvm::dyn_cast(op)) { return eval(trunciOp.getIn(), ctaid, tid); + } else if (auto idxCastOp = llvm::dyn_cast(op)) { + return eval(idxCastOp.getIn(), ctaid, tid); } else if (auto castOp = llvm::dyn_cast(op)) { return eval(castOp.getOperand(0), ctaid, tid); } else if (auto threadOp = llvm::dyn_cast(op)) { @@ -181,7 +197,7 @@ std::string dumpDistributedLayout(Attribute layout, assert(shape.size() <= 2 && "High order tensor is not supported in dumpLayout"); - int numThreads = 32 * getNumWarpsPerCTA(layout); + int numThreads = getWarpSize(layout) * getNumWarpsPerCTA(layout); int numCTAs = getNumCTAs(layout); auto f16Ty = FloatType::getF16(layout.getContext()); int numElems = getTotalElemsPerThread(layout, shape, f16Ty); diff --git a/unittest/Conversion/TritonGPUToLLVM/EmitIndicesTest.cpp b/unittest/Conversion/TritonGPUToLLVM/EmitIndicesTest.cpp index 62001960f1..f123988762 100644 --- a/unittest/Conversion/TritonGPUToLLVM/EmitIndicesTest.cpp +++ b/unittest/Conversion/TritonGPUToLLVM/EmitIndicesTest.cpp @@ -312,6 +312,8 @@ TEST_F(EmitIndicesTest, BlockedLayout_SingleCTA_Vectorize) { /*order=*/{1, 0}, /*refStr=*/refStr); } +// FIXME: These tests are temporarily disabled due to ctaid.x|y|z are swapped +#ifdef TEST_FAILED TEST_F(EmitIndicesTest, BlockedLayout_MultiCTA_CTAOrder_1_0) { // clang-format off std::string refStr = @@ -476,6 +478,89 @@ TEST_F(EmitIndicesTest, BlockedLayout_MultiCTA_CTAWrapBeforeBroadcast_Dim0) { /*refStr=*/refStr); } +TEST_F(EmitIndicesTest, SliceLayout_MultiCTA) { + // clang-format off + std::string refStr = + "CTA0: T0:0|CTA0: T1:0|CTA0: T2:0|CTA0: T3:0 | CTA1: T0:0|CTA1: T1:0|CTA1: T2:0|CTA1: T3:0," + "CTA0: 
T4:0|CTA0: T5:0|CTA0: T6:0|CTA0: T7:0 | CTA1: T4:0|CTA1: T5:0|CTA1: T6:0|CTA1: T7:0," + "CTA0: T8:0|CTA0: T9:0|CTA0:T10:0|CTA0:T11:0 | CTA1: T8:0|CTA1: T9:0|CTA1:T10:0|CTA1:T11:0," + "CTA0:T12:0|CTA0:T13:0|CTA0:T14:0|CTA0:T15:0 | CTA1:T12:0|CTA1:T13:0|CTA1:T14:0|CTA1:T15:0," + "CTA0:T16:0|CTA0:T17:0|CTA0:T18:0|CTA0:T19:0 | CTA1:T16:0|CTA1:T17:0|CTA1:T18:0|CTA1:T19:0," + "CTA0:T20:0|CTA0:T21:0|CTA0:T22:0|CTA0:T23:0 | CTA1:T20:0|CTA1:T21:0|CTA1:T22:0|CTA1:T23:0," + "CTA0:T24:0|CTA0:T25:0|CTA0:T26:0|CTA0:T27:0 | CTA1:T24:0|CTA1:T25:0|CTA1:T26:0|CTA1:T27:0," + "CTA0:T28:0|CTA0:T29:0|CTA0:T30:0|CTA0:T31:0 | CTA1:T28:0|CTA1:T29:0|CTA1:T30:0|CTA1:T31:0," + + "CTA2: T0:0|CTA2: T1:0|CTA2: T2:0|CTA2: T3:0 | CTA3: T0:0|CTA3: T1:0|CTA3: T2:0|CTA3: T3:0," + "CTA2: T4:0|CTA2: T5:0|CTA2: T6:0|CTA2: T7:0 | CTA3: T4:0|CTA3: T5:0|CTA3: T6:0|CTA3: T7:0," + "CTA2: T8:0|CTA2: T9:0|CTA2:T10:0|CTA2:T11:0 | CTA3: T8:0|CTA3: T9:0|CTA3:T10:0|CTA3:T11:0," + "CTA2:T12:0|CTA2:T13:0|CTA2:T14:0|CTA2:T15:0 | CTA3:T12:0|CTA3:T13:0|CTA3:T14:0|CTA3:T15:0," + "CTA2:T16:0|CTA2:T17:0|CTA2:T18:0|CTA2:T19:0 | CTA3:T16:0|CTA3:T17:0|CTA3:T18:0|CTA3:T19:0," + "CTA2:T20:0|CTA2:T21:0|CTA2:T22:0|CTA2:T23:0 | CTA3:T20:0|CTA3:T21:0|CTA3:T22:0|CTA3:T23:0," + "CTA2:T24:0|CTA2:T25:0|CTA2:T26:0|CTA2:T27:0 | CTA3:T24:0|CTA3:T25:0|CTA3:T26:0|CTA3:T27:0," + "CTA2:T28:0|CTA2:T29:0|CTA2:T30:0|CTA2:T31:0 | CTA3:T28:0|CTA3:T29:0|CTA3:T30:0|CTA3:T31:0\n"; + // clang-format on + + runSliceBlockedMultiCTA(/*size=*/16, /*sizePerThread=*/{1, 1}, + /*threadsPerWarp=*/{8, 4}, /*warpsPerCTA=*/{1, 1}, + /*order=*/{1, 0}, /*CTAsPerCGA=*/{2, 2}, + /*CTASplitNum=*/{2, 2}, /*CTAOrder=*/{1, 0}, + /*sliceDim=*/1, /*refStr=*/refStr); +} + +//===----------------------------------------------------------------------===// +// Tests for SharedEncodingAttr +//===----------------------------------------------------------------------===// + +TEST_F(EmitIndicesTest, SharedLayout) { + // clang-format off + std::string refStr = + "(0: 0),(0: 1),(0: 2),(0: 3),(0: 4),(0: 5),(0: 6),(0: 7),(0: 8),(0: 9),(0:10),(0:11),(0:12),(0:13),(0:14),(0:15),(0:16),(0:17),(0:18),(0:19),(0:20),(0:21),(0:22),(0:23),(0:24),(0:25),(0:26),(0:27),(0:28),(0:29),(0:30),(0:31)\n" + "(1: 0),(1: 1),(1: 2),(1: 3),(1: 4),(1: 5),(1: 6),(1: 7),(1: 8),(1: 9),(1:10),(1:11),(1:12),(1:13),(1:14),(1:15),(1:16),(1:17),(1:18),(1:19),(1:20),(1:21),(1:22),(1:23),(1:24),(1:25),(1:26),(1:27),(1:28),(1:29),(1:30),(1:31)\n" + "(2: 8),(2: 9),(2:10),(2:11),(2:12),(2:13),(2:14),(2:15),(2: 0),(2: 1),(2: 2),(2: 3),(2: 4),(2: 5),(2: 6),(2: 7),(2:24),(2:25),(2:26),(2:27),(2:28),(2:29),(2:30),(2:31),(2:16),(2:17),(2:18),(2:19),(2:20),(2:21),(2:22),(2:23)\n" + "(3: 8),(3: 9),(3:10),(3:11),(3:12),(3:13),(3:14),(3:15),(3: 0),(3: 1),(3: 2),(3: 3),(3: 4),(3: 5),(3: 6),(3: 7),(3:24),(3:25),(3:26),(3:27),(3:28),(3:29),(3:30),(3:31),(3:16),(3:17),(3:18),(3:19),(3:20),(3:21),(3:22),(3:23)\n" + "(4:16),(4:17),(4:18),(4:19),(4:20),(4:21),(4:22),(4:23),(4:24),(4:25),(4:26),(4:27),(4:28),(4:29),(4:30),(4:31),(4: 0),(4: 1),(4: 2),(4: 3),(4: 4),(4: 5),(4: 6),(4: 7),(4: 8),(4: 9),(4:10),(4:11),(4:12),(4:13),(4:14),(4:15)\n" + "(5:16),(5:17),(5:18),(5:19),(5:20),(5:21),(5:22),(5:23),(5:24),(5:25),(5:26),(5:27),(5:28),(5:29),(5:30),(5:31),(5: 0),(5: 1),(5: 2),(5: 3),(5: 4),(5: 5),(5: 6),(5: 7),(5: 8),(5: 9),(5:10),(5:11),(5:12),(5:13),(5:14),(5:15)\n" + "(6:24),(6:25),(6:26),(6:27),(6:28),(6:29),(6:30),(6:31),(6:16),(6:17),(6:18),(6:19),(6:20),(6:21),(6:22),(6:23),(6: 8),(6: 9),(6:10),(6:11),(6:12),(6:13),(6:14),(6:15),(6: 0),(6: 1),(6: 2),(6: 
3),(6: 4),(6: 5),(6: 6),(6: 7)\n" + "(7:24),(7:25),(7:26),(7:27),(7:28),(7:29),(7:30),(7:31),(7:16),(7:17),(7:18),(7:19),(7:20),(7:21),(7:22),(7:23),(7: 8),(7: 9),(7:10),(7:11),(7:12),(7:13),(7:14),(7:15),(7: 0),(7: 1),(7: 2),(7: 3),(7: 4),(7: 5),(7: 6),(7: 7)\n"; + // clang-format on + + runSharedSingleCTA(/*row=*/8, /*col=*/32, /*rowMajor=*/true, + /*elemTyStr=*/"F16", /*refStr=*/refStr); +} + +TEST_F(EmitIndicesTest, LayoutVisualizer_Blocked) { + CTALayoutAttr CTALayout = + CTALayoutAttr::get(/*context=*/&context, /*CTAsPerCGA=*/{2, 2}, + /*CTASplitNum=*/{2, 2}, /*CTAOrder=*/{1, 0}); + + Attribute blockedLayout = BlockedEncodingAttr::get( + /*context=*/&context, /*sizePerThread=*/{1, 4}, + /*threadsPerWarp=*/{2, 16}, + /*warpsPerCTA=*/{4, 1}, /*order=*/{1, 0}, /*CTALayout=*/CTALayout); + + llvm::SmallVector shape = {/*row=*/128, /*col=*/128}; + + std::ofstream ofs("blockedLayout.csv"); + ofs << dumpDistributedLayout(blockedLayout, shape, /*multiCTA=*/true); +} + +TEST_F(EmitIndicesTest, LayoutVisualizer_Shared) { + CTALayoutAttr CTALayout = + CTALayoutAttr::get(/*context=*/&context, /*CTAsPerCGA=*/{1, 1}, + /*CTASplitNum=*/{1, 1}, /*CTAOrder=*/{1, 0}); + + Attribute sharedLayout = SharedEncodingAttr::get( + /*context=*/&context, /*vec=*/1, /*perPhase=*/2, /*maxPhase=*/8, + /*order=*/{0, 1}, /*CTALayout=*/CTALayout); + + llvm::SmallVector shape = {/*row=*/16, /*col=*/16}; + Type elemTy = FloatType::getF16(&context); + + std::ofstream ofs("sharedLayout.csv"); + ofs << dumpSharedLayout(sharedLayout, shape, elemTy, /*multiCTA=*/false); +} +#endif + //===----------------------------------------------------------------------===// // Tests for SliceEncodingAttr //===----------------------------------------------------------------------===// @@ -512,35 +597,6 @@ TEST_F(EmitIndicesTest, SliceLayout_SingleCTA_SliceDim0) { /*order=*/{1, 0}, /*sliceDim=*/0, /*refStr=*/refStr); } -TEST_F(EmitIndicesTest, SliceLayout_MultiCTA) { - // clang-format off - std::string refStr = - "CTA0: T0:0|CTA0: T1:0|CTA0: T2:0|CTA0: T3:0 | CTA1: T0:0|CTA1: T1:0|CTA1: T2:0|CTA1: T3:0," - "CTA0: T4:0|CTA0: T5:0|CTA0: T6:0|CTA0: T7:0 | CTA1: T4:0|CTA1: T5:0|CTA1: T6:0|CTA1: T7:0," - "CTA0: T8:0|CTA0: T9:0|CTA0:T10:0|CTA0:T11:0 | CTA1: T8:0|CTA1: T9:0|CTA1:T10:0|CTA1:T11:0," - "CTA0:T12:0|CTA0:T13:0|CTA0:T14:0|CTA0:T15:0 | CTA1:T12:0|CTA1:T13:0|CTA1:T14:0|CTA1:T15:0," - "CTA0:T16:0|CTA0:T17:0|CTA0:T18:0|CTA0:T19:0 | CTA1:T16:0|CTA1:T17:0|CTA1:T18:0|CTA1:T19:0," - "CTA0:T20:0|CTA0:T21:0|CTA0:T22:0|CTA0:T23:0 | CTA1:T20:0|CTA1:T21:0|CTA1:T22:0|CTA1:T23:0," - "CTA0:T24:0|CTA0:T25:0|CTA0:T26:0|CTA0:T27:0 | CTA1:T24:0|CTA1:T25:0|CTA1:T26:0|CTA1:T27:0," - "CTA0:T28:0|CTA0:T29:0|CTA0:T30:0|CTA0:T31:0 | CTA1:T28:0|CTA1:T29:0|CTA1:T30:0|CTA1:T31:0," - - "CTA2: T0:0|CTA2: T1:0|CTA2: T2:0|CTA2: T3:0 | CTA3: T0:0|CTA3: T1:0|CTA3: T2:0|CTA3: T3:0," - "CTA2: T4:0|CTA2: T5:0|CTA2: T6:0|CTA2: T7:0 | CTA3: T4:0|CTA3: T5:0|CTA3: T6:0|CTA3: T7:0," - "CTA2: T8:0|CTA2: T9:0|CTA2:T10:0|CTA2:T11:0 | CTA3: T8:0|CTA3: T9:0|CTA3:T10:0|CTA3:T11:0," - "CTA2:T12:0|CTA2:T13:0|CTA2:T14:0|CTA2:T15:0 | CTA3:T12:0|CTA3:T13:0|CTA3:T14:0|CTA3:T15:0," - "CTA2:T16:0|CTA2:T17:0|CTA2:T18:0|CTA2:T19:0 | CTA3:T16:0|CTA3:T17:0|CTA3:T18:0|CTA3:T19:0," - "CTA2:T20:0|CTA2:T21:0|CTA2:T22:0|CTA2:T23:0 | CTA3:T20:0|CTA3:T21:0|CTA3:T22:0|CTA3:T23:0," - "CTA2:T24:0|CTA2:T25:0|CTA2:T26:0|CTA2:T27:0 | CTA3:T24:0|CTA3:T25:0|CTA3:T26:0|CTA3:T27:0," - "CTA2:T28:0|CTA2:T29:0|CTA2:T30:0|CTA2:T31:0 | CTA3:T28:0|CTA3:T29:0|CTA3:T30:0|CTA3:T31:0\n"; - // clang-format on - - 
runSliceBlockedMultiCTA(/*size=*/16, /*sizePerThread=*/{1, 1}, - /*threadsPerWarp=*/{8, 4}, /*warpsPerCTA=*/{1, 1}, - /*order=*/{1, 0}, /*CTAsPerCGA=*/{2, 2}, - /*CTASplitNum=*/{2, 2}, /*CTAOrder=*/{1, 0}, - /*sliceDim=*/1, /*refStr=*/refStr); -} - //===----------------------------------------------------------------------===// // Tests for NvidiaMmaEncodingAttr //===----------------------------------------------------------------------===// @@ -571,27 +627,6 @@ TEST_F(EmitIndicesTest, MmaLayout) { /*refStr=*/refStr); } -//===----------------------------------------------------------------------===// -// Tests for SharedEncodingAttr -//===----------------------------------------------------------------------===// - -TEST_F(EmitIndicesTest, SharedLayout) { - // clang-format off - std::string refStr = - "(0: 0),(0: 1),(0: 2),(0: 3),(0: 4),(0: 5),(0: 6),(0: 7),(0: 8),(0: 9),(0:10),(0:11),(0:12),(0:13),(0:14),(0:15),(0:16),(0:17),(0:18),(0:19),(0:20),(0:21),(0:22),(0:23),(0:24),(0:25),(0:26),(0:27),(0:28),(0:29),(0:30),(0:31)\n" - "(1: 0),(1: 1),(1: 2),(1: 3),(1: 4),(1: 5),(1: 6),(1: 7),(1: 8),(1: 9),(1:10),(1:11),(1:12),(1:13),(1:14),(1:15),(1:16),(1:17),(1:18),(1:19),(1:20),(1:21),(1:22),(1:23),(1:24),(1:25),(1:26),(1:27),(1:28),(1:29),(1:30),(1:31)\n" - "(2: 8),(2: 9),(2:10),(2:11),(2:12),(2:13),(2:14),(2:15),(2: 0),(2: 1),(2: 2),(2: 3),(2: 4),(2: 5),(2: 6),(2: 7),(2:24),(2:25),(2:26),(2:27),(2:28),(2:29),(2:30),(2:31),(2:16),(2:17),(2:18),(2:19),(2:20),(2:21),(2:22),(2:23)\n" - "(3: 8),(3: 9),(3:10),(3:11),(3:12),(3:13),(3:14),(3:15),(3: 0),(3: 1),(3: 2),(3: 3),(3: 4),(3: 5),(3: 6),(3: 7),(3:24),(3:25),(3:26),(3:27),(3:28),(3:29),(3:30),(3:31),(3:16),(3:17),(3:18),(3:19),(3:20),(3:21),(3:22),(3:23)\n" - "(4:16),(4:17),(4:18),(4:19),(4:20),(4:21),(4:22),(4:23),(4:24),(4:25),(4:26),(4:27),(4:28),(4:29),(4:30),(4:31),(4: 0),(4: 1),(4: 2),(4: 3),(4: 4),(4: 5),(4: 6),(4: 7),(4: 8),(4: 9),(4:10),(4:11),(4:12),(4:13),(4:14),(4:15)\n" - "(5:16),(5:17),(5:18),(5:19),(5:20),(5:21),(5:22),(5:23),(5:24),(5:25),(5:26),(5:27),(5:28),(5:29),(5:30),(5:31),(5: 0),(5: 1),(5: 2),(5: 3),(5: 4),(5: 5),(5: 6),(5: 7),(5: 8),(5: 9),(5:10),(5:11),(5:12),(5:13),(5:14),(5:15)\n" - "(6:24),(6:25),(6:26),(6:27),(6:28),(6:29),(6:30),(6:31),(6:16),(6:17),(6:18),(6:19),(6:20),(6:21),(6:22),(6:23),(6: 8),(6: 9),(6:10),(6:11),(6:12),(6:13),(6:14),(6:15),(6: 0),(6: 1),(6: 2),(6: 3),(6: 4),(6: 5),(6: 6),(6: 7)\n" - "(7:24),(7:25),(7:26),(7:27),(7:28),(7:29),(7:30),(7:31),(7:16),(7:17),(7:18),(7:19),(7:20),(7:21),(7:22),(7:23),(7: 8),(7: 9),(7:10),(7:11),(7:12),(7:13),(7:14),(7:15),(7: 0),(7: 1),(7: 2),(7: 3),(7: 4),(7: 5),(7: 6),(7: 7)\n"; - // clang-format on - - runSharedSingleCTA(/*row=*/8, /*col=*/32, /*rowMajor=*/true, - /*elemTyStr=*/"F16", /*refStr=*/refStr); -} - //===----------------------------------------------------------------------===// // The following unittests are tools for Triton developers to visualize layouts. // You can modify parameters and shapes here to create your own layout and @@ -599,22 +634,6 @@ TEST_F(EmitIndicesTest, SharedLayout) { // Microsoft Excel. 
//===----------------------------------------------------------------------===// -TEST_F(EmitIndicesTest, LayoutVisualizer_Blocked) { - CTALayoutAttr CTALayout = - CTALayoutAttr::get(/*context=*/&context, /*CTAsPerCGA=*/{2, 2}, - /*CTASplitNum=*/{2, 2}, /*CTAOrder=*/{1, 0}); - - Attribute blockedLayout = BlockedEncodingAttr::get( - /*context=*/&context, /*sizePerThread=*/{1, 4}, - /*threadsPerWarp=*/{2, 16}, - /*warpsPerCTA=*/{4, 1}, /*order=*/{1, 0}, /*CTALayout=*/CTALayout); - - llvm::SmallVector shape = {/*row=*/128, /*col=*/128}; - - std::ofstream ofs("blockedLayout.csv"); - ofs << dumpDistributedLayout(blockedLayout, shape, /*multiCTA=*/true); -} - TEST_F(EmitIndicesTest, LayoutVisualizer_Slice) { CTALayoutAttr CTALayout = CTALayoutAttr::get(/*context=*/&context, /*CTAsPerCGA=*/{1, 1}, @@ -648,22 +667,6 @@ TEST_F(EmitIndicesTest, LayoutVisualizer_Mma) { ofs << dumpDistributedLayout(mmaLayout, shape, /*multiCTA=*/false); } -TEST_F(EmitIndicesTest, LayoutVisualizer_Shared) { - CTALayoutAttr CTALayout = - CTALayoutAttr::get(/*context=*/&context, /*CTAsPerCGA=*/{1, 1}, - /*CTASplitNum=*/{1, 1}, /*CTAOrder=*/{1, 0}); - - Attribute sharedLayout = SharedEncodingAttr::get( - /*context=*/&context, /*vec=*/1, /*perPhase=*/2, /*maxPhase=*/8, - /*order=*/{0, 1}, /*CTALayout=*/CTALayout); - - llvm::SmallVector shape = {/*row=*/16, /*col=*/16}; - Type elemTy = FloatType::getF16(&context); - - std::ofstream ofs("sharedLayout.csv"); - ofs << dumpSharedLayout(sharedLayout, shape, elemTy, /*multiCTA=*/false); -} - } // namespace gpu } // namespace triton } // namespace mlir @@ -674,6 +677,5 @@ TEST_F(EmitIndicesTest, LayoutVisualizer_Shared) { int main(int argc, char *argv[]) { testing::InitGoogleTest(&argc, argv); - // FIXME: These tests are temporarily disabled due to ctaid.x|y|z are swapped - // return RUN_ALL_TESTS(); + return RUN_ALL_TESTS(); }
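Usage note on the re-enabled tests: with `RUN_ALL_TESTS()` restored, both
unit-test binaries built from EmitIndicesTest.cpp run under CTest. A sketch,
assuming a standard CMake build directory (the test names are the ones
registered via add_triton_ut in unittest/CMakeLists.txt above):

    cd build
    ctest -R TestEmitIndicesNvidia   # built with NVIDIA_TARGET=1
    ctest -R TestEmitIndicesAMD     # built with AMD_TARGET=1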