Skip to content

Commit

Permalink
Merge branch 'llvm-target' into annotated-ptr-op
Browse files Browse the repository at this point in the history
  • Loading branch information
victor-eds authored May 31, 2024
2 parents 440b16d + 5cf6579 commit d0cd06a
Show file tree
Hide file tree
Showing 9 changed files with 293 additions and 3 deletions.
1 change: 1 addition & 0 deletions bin/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ target_link_libraries(triton-opt PRIVATE
MLIROptLib
MLIRPass
MLIRTransforms
MLIRSPIRVDialect
)

mlir_check_all_link_libraries(triton-opt)
Expand Down
2 changes: 2 additions & 0 deletions bin/triton-opt.cpp
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
#include "./RegisterTritonDialects.h"

#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Tools/mlir-opt/MlirOptMain.h"

int main(int argc, char **argv) {
mlir::DialectRegistry registry;
registerTritonDialects(registry);
registry.insert<mlir::spirv::SPIRVDialect>();

return mlir::asMainReturnCode(mlir::MlirOptMain(
argc, argv, "Triton (GPU) optimizer driver\n", registry));
Expand Down
145 changes: 145 additions & 0 deletions test/TritonGEN/tritongen-to-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,151 @@ llvm.func @triton_gen.named_barrier(%barrier_id : i32, %thread_group_count : i32

// -----

// Clustered case: the target's subgroup_size (32) differs from the reduce
// `size` (16), so every reduction below is expected to lower to the
// GenISA.WaveClustered intrinsic, with the cluster size passed as an extra i32
// operand. The expected i8 kind constant goes from 0 (sum) up to 12 (fmax).
module attributes {
spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Kernel, Addresses, GroupNonUniformShuffle, Int64], []>, #spirv.resource_limits<subgroup_size = 32>>
} {
llvm.func @triton_gen.sub_group_reduce() {
%0 = llvm.mlir.constant(0 : i32) : i32
// CHECK: [[VAL:%.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: [[KIND:%.*]] = llvm.mlir.constant(0 : i8) : i8
// CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32
// CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.i32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (i32, i8, i32, i32) -> i32
%1 = triton_gen.sub_group_reduce sum %0 {size = 16} : i32
// CHECK: [[KIND:%.*]] = llvm.mlir.constant(1 : i8) : i8
// CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32
// CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.i32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (i32, i8, i32, i32) -> i32
%2 = triton_gen.sub_group_reduce prod %0 {size = 16} : i32
// CHECK: [[KIND:%.*]] = llvm.mlir.constant(2 : i8) : i8
// CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32
// CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.i32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (i32, i8, i32, i32) -> i32
%3 = triton_gen.sub_group_reduce umin %0 {size = 16} : i32
// CHECK: [[KIND:%.*]] = llvm.mlir.constant(3 : i8) : i8
// CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32
// CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.i32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (i32, i8, i32, i32) -> i32
%4 = triton_gen.sub_group_reduce umax %0 {size = 16} : i32
// CHECK: [[KIND:%.*]] = llvm.mlir.constant(4 : i8) : i8
// CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32
// CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.i32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (i32, i8, i32, i32) -> i32
%5 = triton_gen.sub_group_reduce imin %0 {size = 16} : i32
// CHECK: [[KIND:%.*]] = llvm.mlir.constant(5 : i8) : i8
// CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32
// CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.i32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (i32, i8, i32, i32) -> i32
%6 = triton_gen.sub_group_reduce imax %0 {size = 16} : i32
// CHECK: [[KIND:%.*]] = llvm.mlir.constant(6 : i8) : i8
// CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32
// CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.i32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (i32, i8, i32, i32) -> i32
%7 = triton_gen.sub_group_reduce or %0 {size = 16} : i32
// CHECK: [[KIND:%.*]] = llvm.mlir.constant(7 : i8) : i8
// CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32
// CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.i32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (i32, i8, i32, i32) -> i32
%8 = triton_gen.sub_group_reduce xor %0 {size = 16} : i32
// CHECK: [[KIND:%.*]] = llvm.mlir.constant(8 : i8) : i8
// CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32
// CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.i32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (i32, i8, i32, i32) -> i32
%9 = triton_gen.sub_group_reduce and %0 {size = 16} : i32
%10 = llvm.mlir.constant(0.0 : f32) : f32
// CHECK: [[VAL:%.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32
// CHECK: [[KIND:%.*]] = llvm.mlir.constant(9 : i8) : i8
// CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32
// CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.f32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (f32, i8, i32, i32) -> f32
%11 = triton_gen.sub_group_reduce fsum %10 {size = 16} : f32
// CHECK: [[KIND:%.*]] = llvm.mlir.constant(10 : i8) : i8
// CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32
// CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.f32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (f32, i8, i32, i32) -> f32
%12 = triton_gen.sub_group_reduce fprod %10 {size = 16} : f32
// CHECK: [[KIND:%.*]] = llvm.mlir.constant(11 : i8) : i8
// CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32
// CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.f32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (f32, i8, i32, i32) -> f32
%13 = triton_gen.sub_group_reduce fmin %10 {size = 16} : f32
// CHECK: [[KIND:%.*]] = llvm.mlir.constant(12 : i8) : i8
// CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32
// CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.f32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (f32, i8, i32, i32) -> f32
%14 = triton_gen.sub_group_reduce fmax %10 {size = 16} : f32
llvm.return
}
}

// -----

// Whole-subgroup case: the target's subgroup_size (16) equals the reduce
// `size` (16), so every reduction below is expected to lower to the
// GenISA.WaveAll intrinsic, which takes no cluster-size operand. The expected
// i8 kind constant goes from 0 (sum) up to 12 (fmax).
module attributes {
spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Kernel, Addresses, GroupNonUniformShuffle, Int64], []>, #spirv.resource_limits<subgroup_size = 16>>
} {
llvm.func @triton_gen.sub_group_reduce() {
%0 = llvm.mlir.constant(0 : i32) : i32
// CHECK: [[VAL:%.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: [[KIND:%.*]] = llvm.mlir.constant(0 : i8) : i8
// CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: llvm.call @llvm.genx.GenISA.WaveAll.i32([[VAL]], [[KIND]], [[ZERO]]) : (i32, i8, i32) -> i32
%1 = triton_gen.sub_group_reduce sum %0 {size = 16} : i32
// CHECK: [[KIND:%.*]] = llvm.mlir.constant(1 : i8) : i8
// CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: llvm.call @llvm.genx.GenISA.WaveAll.i32([[VAL]], [[KIND]], [[ZERO]]) : (i32, i8, i32) -> i32
%2 = triton_gen.sub_group_reduce prod %0 {size = 16} : i32
// CHECK: [[KIND:%.*]] = llvm.mlir.constant(2 : i8) : i8
// CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: llvm.call @llvm.genx.GenISA.WaveAll.i32([[VAL]], [[KIND]], [[ZERO]]) : (i32, i8, i32) -> i32
%3 = triton_gen.sub_group_reduce umin %0 {size = 16} : i32
// CHECK: [[KIND:%.*]] = llvm.mlir.constant(3 : i8) : i8
// CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: llvm.call @llvm.genx.GenISA.WaveAll.i32([[VAL]], [[KIND]], [[ZERO]]) : (i32, i8, i32) -> i32
%4 = triton_gen.sub_group_reduce umax %0 {size = 16} : i32
// CHECK: [[KIND:%.*]] = llvm.mlir.constant(4 : i8) : i8
// CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: llvm.call @llvm.genx.GenISA.WaveAll.i32([[VAL]], [[KIND]], [[ZERO]]) : (i32, i8, i32) -> i32
%5 = triton_gen.sub_group_reduce imin %0 {size = 16} : i32
// CHECK: [[KIND:%.*]] = llvm.mlir.constant(5 : i8) : i8
// CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: llvm.call @llvm.genx.GenISA.WaveAll.i32([[VAL]], [[KIND]], [[ZERO]]) : (i32, i8, i32) -> i32
%6 = triton_gen.sub_group_reduce imax %0 {size = 16} : i32
// CHECK: [[KIND:%.*]] = llvm.mlir.constant(6 : i8) : i8
// CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: llvm.call @llvm.genx.GenISA.WaveAll.i32([[VAL]], [[KIND]], [[ZERO]]) : (i32, i8, i32) -> i32
%7 = triton_gen.sub_group_reduce or %0 {size = 16} : i32
// CHECK: [[KIND:%.*]] = llvm.mlir.constant(7 : i8) : i8
// CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: llvm.call @llvm.genx.GenISA.WaveAll.i32([[VAL]], [[KIND]], [[ZERO]]) : (i32, i8, i32) -> i32
%8 = triton_gen.sub_group_reduce xor %0 {size = 16} : i32
// CHECK: [[KIND:%.*]] = llvm.mlir.constant(8 : i8) : i8
// CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: llvm.call @llvm.genx.GenISA.WaveAll.i32([[VAL]], [[KIND]], [[ZERO]]) : (i32, i8, i32) -> i32
%9 = triton_gen.sub_group_reduce and %0 {size = 16} : i32
%10 = llvm.mlir.constant(0.0 : f32) : f32
// CHECK: [[VAL:%.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32
// CHECK: [[KIND:%.*]] = llvm.mlir.constant(9 : i8) : i8
// CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: llvm.call @llvm.genx.GenISA.WaveAll.f32([[VAL]], [[KIND]], [[ZERO]]) : (f32, i8, i32) -> f32
%11 = triton_gen.sub_group_reduce fsum %10 {size = 16} : f32
// CHECK: [[KIND:%.*]] = llvm.mlir.constant(10 : i8) : i8
// CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: llvm.call @llvm.genx.GenISA.WaveAll.f32([[VAL]], [[KIND]], [[ZERO]]) : (f32, i8, i32) -> f32
%12 = triton_gen.sub_group_reduce fprod %10 {size = 16} : f32
// CHECK: [[KIND:%.*]] = llvm.mlir.constant(11 : i8) : i8
// CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: llvm.call @llvm.genx.GenISA.WaveAll.f32([[VAL]], [[KIND]], [[ZERO]]) : (f32, i8, i32) -> f32
%13 = triton_gen.sub_group_reduce fmin %10 {size = 16} : f32
// CHECK: [[KIND:%.*]] = llvm.mlir.constant(12 : i8) : i8
// CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: llvm.call @llvm.genx.GenISA.WaveAll.f32([[VAL]], [[KIND]], [[ZERO]]) : (f32, i8, i32) -> f32
%14 = triton_gen.sub_group_reduce fmax %10 {size = 16} : f32
llvm.return
}
}

// -----

// CHECK-DAG: llvm.func spir_funccc @_Z21sub_group_shuffle_xordj(f64, i32) -> f64 attributes {passthrough = ["convergent"]}
// CHECK-DAG: llvm.func spir_funccc @_Z21sub_group_shuffle_xorfj(f32, i32) -> f32 attributes {passthrough = ["convergent"]}
// CHECK-DAG: llvm.func spir_funccc @_Z21sub_group_shuffle_xorDhj(f16, i32) -> f16 attributes {passthrough = ["convergent"]}
Expand Down
33 changes: 33 additions & 0 deletions test/TritonGEN/tritongen.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,39 @@ llvm.func @triton_gen.named_barrier_wait(%barrier_id : i32) {
llvm.return
}

// Syntax test for triton_gen.sub_group_reduce: each of the thirteen reduce
// kinds is written once (integer kinds on an i32 value, f-prefixed kinds on an
// f32 value) and is expected to appear in the tool output in the same textual
// form, including the `{size = 16}` clause.
llvm.func @triton_gen.sub_group_reduce() {
// CHECK-LABEL: triton_gen.sub_group_reduce
%0 = llvm.mlir.constant(0 : i32) : i32
// CHECK: triton_gen.sub_group_reduce sum %0 {size = 16} : i32
%1 = triton_gen.sub_group_reduce sum %0 {size = 16} : i32
// CHECK: triton_gen.sub_group_reduce prod %0 {size = 16} : i32
%2 = triton_gen.sub_group_reduce prod %0 {size = 16} : i32
// CHECK: triton_gen.sub_group_reduce umin %0 {size = 16} : i32
%3 = triton_gen.sub_group_reduce umin %0 {size = 16} : i32
// CHECK: triton_gen.sub_group_reduce umax %0 {size = 16} : i32
%4 = triton_gen.sub_group_reduce umax %0 {size = 16} : i32
// CHECK: triton_gen.sub_group_reduce imin %0 {size = 16} : i32
%5 = triton_gen.sub_group_reduce imin %0 {size = 16} : i32
// CHECK: triton_gen.sub_group_reduce imax %0 {size = 16} : i32
%6 = triton_gen.sub_group_reduce imax %0 {size = 16} : i32
// CHECK: triton_gen.sub_group_reduce or %0 {size = 16} : i32
%7 = triton_gen.sub_group_reduce or %0 {size = 16} : i32
// CHECK: triton_gen.sub_group_reduce xor %0 {size = 16} : i32
%8 = triton_gen.sub_group_reduce xor %0 {size = 16} : i32
// CHECK: triton_gen.sub_group_reduce and %0 {size = 16} : i32
%9 = triton_gen.sub_group_reduce and %0 {size = 16} : i32
%10 = llvm.mlir.constant(0.0 : f32) : f32
// CHECK: triton_gen.sub_group_reduce fsum %10 {size = 16} : f32
%11 = triton_gen.sub_group_reduce fsum %10 {size = 16} : f32
// CHECK: triton_gen.sub_group_reduce fprod %10 {size = 16} : f32
%12 = triton_gen.sub_group_reduce fprod %10 {size = 16} : f32
// CHECK: triton_gen.sub_group_reduce fmin %10 {size = 16} : f32
%13 = triton_gen.sub_group_reduce fmin %10 {size = 16} : f32
// CHECK: triton_gen.sub_group_reduce fmax %10 {size = 16} : f32
%14 = triton_gen.sub_group_reduce fmax %10 {size = 16} : f32
llvm.return
}

llvm.func @triton_gen.sub_group_shuffle() {
// CHECK-LABEL: triton_gen.sub_group_shuffle
%0 = llvm.mlir.constant(0 : i32) : i32
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,26 @@ class TritonGEN_Attr<string name, string attrMnemonic, list<Trait> traits = []>
let cppNamespace = "::mlir::triton::TritonGEN";
}

/// Enum attribute of the different reduce kinds.
///
/// Each case pairs a C++ enumerator, a numeric value and the keyword used in
/// the textual IR (e.g. `sum`, `fmax`).
///
/// NOTE(review): the TritonGEN-to-LLVM lowering casts the enumerator to an
/// integer and passes it as the i8 `kind` operand of the GenISA wave
/// intrinsics, so the numeric values here presumably have to match what those
/// intrinsics expect -- confirm before renumbering or inserting cases.
def TritonGEN_ReduceKindAttr : I32EnumAttr<"ReduceKind", "TritonGEN reduce kind",
[
I32EnumAttrCase<"SUM", 0, "sum">,
I32EnumAttrCase<"PROD", 1, "prod">,
I32EnumAttrCase<"UMIN", 2, "umin">,
I32EnumAttrCase<"UMAX", 3, "umax">,
I32EnumAttrCase<"IMIN", 4, "imin">,
I32EnumAttrCase<"IMAX", 5, "imax">,
I32EnumAttrCase<"OR", 6, "or">,
I32EnumAttrCase<"XOR", 7, "xor">,
I32EnumAttrCase<"AND", 8, "and">,
I32EnumAttrCase<"FSUM", 9, "fsum">,
I32EnumAttrCase<"FPROD", 10, "fprod">,
I32EnumAttrCase<"FMIN", 11, "fmin">,
I32EnumAttrCase<"FMAX", 12, "fmax">
]> {
let cppNamespace = "::mlir::triton::TritonGEN";
}

/// Enum attribute of the different shuffle kinds.
def TritonGEN_ShflKindAttr : I32EnumAttr<"ShflKind", "TritonGEN shuffle kind",
[
Expand Down
23 changes: 23 additions & 0 deletions third_party/intel/include/Dialect/TritonGEN/IR/TritonGENOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,29 @@ def TritonGEN_NamedBarrierWaitOp : TritonGEN_Op<"named_barrier_wait">,

def IntegerOrFloatType : AnyTypeOf<[AnyInteger, AnyFloat]>;

// Subgroup (optionally clustered) reduction.
//
// Operand `value` and result `res` are integers or floats and must have the
// same type (AllTypesMatch). `kind` is a ReduceKind enum attribute and `size`
// is an i32 attribute giving the cluster size.
//
// Textual form, as produced by the assembly format below:
//   %r = triton_gen.sub_group_reduce sum %v {size = 16} : i32
def TritonGEN_SubGroupReduceOp : TritonGEN_Op<"sub_group_reduce", [
AllTypesMatch<["res", "value"]>]>,
Results<(outs IntegerOrFloatType:$res)>,
Arguments<(ins IntegerOrFloatType:$value,
TritonGEN_ReduceKindAttr:$kind,
I32Attr:$size)> {
let summary = "Subgroup reduce";

let description = [{
The `triton_gen.sub_group_reduce` operation is invoked by all work items in
a subgroup, each of them providing a $value. The $size argument is used to
form groups of $size consecutive work items called clusters. Each cluster
performs the reduction operation identified by $kind. The result of the
cluster reduction is propagated to the work items belonging to that cluster.
}];

let assemblyFormat = [{
$kind $value ` ` `{` `size` `=` $size `}` attr-dict `:` type($value)
}];

// The verifier is implemented by SubGroupReduceOp::verify() in C++.
let hasVerifier = 1;
}

def TritonGEN_SubGroupShuffleOp : TritonGEN_Op<"sub_group_shuffle", [
TypesMatchWith<"result and value have the same type",
"res", "value", "$_self">]>,
Expand Down
9 changes: 9 additions & 0 deletions third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,15 @@ template <typename Op> static LogicalResult verifyInput(Op op) {
return success();
}

//===----------------------------------------------------------------------===//
// gen.sub_group_reduce
//===----------------------------------------------------------------------===//

/// Verifies the target-independent invariants of `sub_group_reduce`.
///
/// The `size` attribute is the number of consecutive work items forming a
/// cluster, so it has to be at least 1; clustered subgroup reductions
/// additionally require the cluster size to be a power of two.
///
/// \returns success() if `size` is a non-zero power of two, otherwise a
///          failure carrying an op error.
LogicalResult TritonGEN::SubGroupReduceOp::verify() {
  const uint32_t size = getSize();
  // `size & (size - 1)` clears the lowest set bit; the result is zero only
  // when exactly one bit is set, i.e. when `size` is a power of two.
  if (size == 0 || (size & (size - 1)) != 0)
    return emitOpError("expecting size to be a non-zero power of 2");

  // TODO: Also check that size does not exceed the target's subgroup size once
  // a target environment is guaranteed to be attached wherever this op is
  // verified.
  return success();
}

//===----------------------------------------------------------------------===//
// gen.matrix.dpas
//===----------------------------------------------------------------------===//
Expand Down
2 changes: 2 additions & 0 deletions third_party/intel/lib/TritonGENToLLVM/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,6 @@ add_triton_library(TritonGENToLLVM

LINK_LIBS PUBLIC
GenISAIntrinsics
MLIRLLVMDialect
MLIRSPIRVDialect
)
61 changes: 58 additions & 3 deletions third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
Expand Down Expand Up @@ -95,6 +97,13 @@ static std::string getTypeMangling(Type ty) {
});
}

/// Return the subgroup size recorded in the resource limits of the SPIR-V
/// target environment that `spirv::lookupTargetEnv` finds for \p op.
///
/// A target environment attribute is required; its absence is treated as a
/// programming error and trips the assertion below.
static int getSubgroupSize(Operation *op) {
  spirv::TargetEnvAttr targetEnv = spirv::lookupTargetEnv(op);
  assert(targetEnv && "Expecting valid target env attribute");
  auto limits = targetEnv.getResourceLimits();
  return limits.getSubgroupSize();
}

static LLVM::CallOp createSubGroupShuffle(ConversionPatternRewriter &rewriter,
Value value, Value mask,
TritonGEN::ShflKind kind) {
Expand Down Expand Up @@ -903,6 +912,51 @@ struct TritonGENNamedBarrierWaitLowering
}
};

/// Lowers `triton_gen.sub_group_reduce` to a call to a GenISA wave intrinsic.
///
/// When the op's `size` equals the subgroup size of the enclosing SPIR-V
/// target environment, `GenISA_WaveAll` is called with (value, kind, 0).
/// Otherwise `GenISA_WaveClustered` is called with (value, kind, size, 0).
/// The intrinsic is declared in the parent module on first use and marked
/// with the SPIR_FUNC calling convention.
struct TritonSubGroupReduceLowering
    : public ConvertOpToLLVMPattern<TritonGEN::SubGroupReduceOp> {
  using ConvertOpToLLVMPattern<
      TritonGEN::SubGroupReduceOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(TritonGEN::SubGroupReduceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // NOTE: `loc` and `rewriter` are referenced implicitly by the i8_ty /
    // i32_ty / i32_val helper macros, so these names must not change.
    Location loc = op.getLoc();

    // Read the operand through the adaptor so the call is built from the value
    // the conversion framework has remapped, rather than the original operand.
    Value val = adaptor.getValue();
    Type valTy = val.getType();

    // The reduce kind is forwarded as its numeric enum value in an i8 constant.
    auto kind = rewriter.create<LLVM::ConstantOp>(
        loc, i8_ty, static_cast<int>(op.getKind()));

    const bool coversSubgroup = getSubgroupSize(op) == op.getSize();
    const auto intrinsicId = coversSubgroup
                                 ? llvm::GenISAIntrinsic::GenISA_WaveAll
                                 : llvm::GenISAIntrinsic::GenISA_WaveClustered;

    // Operands common to both intrinsics come first.
    SmallVector<Type> argTypes{valTy, i8_ty};
    SmallVector<Value> args{val, kind};
    if (!coversSubgroup) {
      // Only the clustered intrinsic takes the cluster size.
      auto size = rewriter.create<LLVM::ConstantOp>(
          loc, i32_ty, static_cast<int>(op.getSize()));
      argTypes.push_back(i32_ty);
      args.push_back(size);
    }
    // Both intrinsics take a trailing i32 operand, which is always 0 here.
    argTypes.push_back(i32_ty);
    args.push_back(i32_val(0));

    // The intrinsic name is mangled on the LLVM IR type of the value, so the
    // MLIR type is translated using a scratch LLVM context.
    llvm::LLVMContext llvmContext;
    LLVM::TypeToLLVMIRTranslator typeTranslator(llvmContext);
    const std::string funcName = llvm::GenISAIntrinsic::getName(
        intrinsicId, {typeTranslator.translateType(valTy)});

    auto moduleOp = op->getParentOfType<ModuleOp>();
    LLVM::LLVMFuncOp funcOp =
        LLVM::lookupOrCreateFn(moduleOp, funcName, argTypes, valTy);
    funcOp.setCConv(LLVM::cconv::CConv::SPIR_FUNC);

    rewriter.replaceOp(op, rewriter.create<LLVM::CallOp>(loc, funcOp, args));
    return success();
  }
};

struct TritonSubGroupShuffleLowering
: public ConvertOpToLLVMPattern<TritonGEN::SubGroupShuffleOp> {
using ConvertOpToLLVMPattern<
Expand Down Expand Up @@ -1066,9 +1120,10 @@ void mlir::triton::populateTritonGENToLLVMConversionPatterns(
TritonGENSubgroupIdLowering, TritonGENBarrierLowering,
TritonGENSplitBarrierSignalLowering, TritonGENSplitBarrierWaitLowering,
TritonGENNamedBarrierSignalLowering, TritonGENNamedBarrierWaitLowering,
TritonSubGroupShuffleLowering, TritonMatrixDPASLowering,
TritonMatrix2DBlockLoadLowering, TritonMatrix2DBlockStoreLowering,
TritonMatrix2DBlockPrefetchLowering>(converter);
TritonSubGroupReduceLowering, TritonSubGroupShuffleLowering,
TritonMatrixDPASLowering, TritonMatrix2DBlockLoadLowering,
TritonMatrix2DBlockStoreLowering, TritonMatrix2DBlockPrefetchLowering>(
converter);
}

void registerConvertTritonTritonGENToLLVMInterface(DialectRegistry &registry) {
Expand Down

0 comments on commit d0cd06a

Please sign in to comment.