[GEN]: Introduce the GEN dialect in Triton - part 7 (#642)
Lower `tt.dot` to `triton_gen.dpas` (renamed in this PR from
`triton_gen.matrix.dpas`) and convert that operation to the GenISA dpas intrinsic.
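
For context, the lowering chain this change completes looks roughly as follows (a sketch assembled from the tests in this PR; `%a_vec`, `%b_vec`, `%c_vec`, `%a_cast`, `%b_cast` are hypothetical names for the per-lane operand vectors the lowering produces):

    // TritonGPU IR: a dot over f32 operands with TF32 allowed.
    %0 = tt.dot %a, %b, %c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<8x8xf32, #dot_operand_a> * tensor<8x16xf32, #dot_operand_b> -> tensor<8x16xf32, #mma>

    // After lowering to the TritonGEN dialect:
    %1 = triton_gen.dpas %c_vec, %a_vec, %b_vec {pa = tf32, pb = tf32, rc = 8} : (vector<8xf32>, vector<4xf32>, vector<8xf32>) -> vector<8xf32>

    // After conversion to the LLVM dialect: a GenISA intrinsic call
    // (precision/depth/rc constants and the i1 flag elided here).
    %2 = llvm.call @llvm.genx.GenISA.sub.group.dpas.v8f32.v8f32.v8i16.v8i32(%c_vec, %a_cast, %b_cast, ...) : (vector<8xf32>, vector<8xi16>, vector<8xi32>, i32, i32, i32, i32, i1) -> vector<8xf32>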

---------

Signed-off-by: Tiotto, Ettore <ettore.tiotto@intel.com>
Signed-off-by: Whitney Tsang <whitney.tsang@intel.com>
Co-authored-by: Whitney Tsang <whitney.tsang@intel.com>
etiotto and whitneywhtsang authored Mar 11, 2024
1 parent c5c9e52 commit 804dbeb
Showing 8 changed files with 265 additions and 144 deletions.
50 changes: 23 additions & 27 deletions include/triton/Dialect/TritonGEN/IR/TritonGENOps.td
@@ -208,7 +208,7 @@ class FixedVectorOfRankAndType<list<int> allowedRanks,

def TritonGEN_MatrixElemType : AnyTypeOf<[AnyI8, AnyI16, AnyI32, F32, F16, BF16]>;

-def TritonGEN_MatrixDPASOp : TritonGEN_Op<"matrix.dpas">,
+def TritonGEN_MatrixDPASOp : TritonGEN_Op<"dpas">,
Results<(outs FixedVectorOf<[TritonGEN_MatrixElemType]>:$d)>,
Arguments<(ins
FixedVectorOfRankAndType<[1], [TritonGEN_MatrixElemType]>:$c,
@@ -222,44 +222,40 @@ def TritonGEN_MatrixDPASOp : TritonGEN_Op<"matrix.dpas">,
let summary = "GEN matrix multiply-add (for PVC)";

string baseDescription = [{
-    The 'gen.matrix.dpas' operation is a matrix multiply-add operation as follows:
+    The 'gen.dpas' operation is a matrix multiply-add operation as follows:

      D = C + A x B

    where
      D : MxN
      C : MxN
      A : MxK
      B : KxN

-     M : repeat count ($rc), must be 1, 2, 4, or 8
+     M : repeat count, must be 1, 2, 4, or 8
      N : fixed execution size, must be 16
      K : depth * OPS_PER_CHAN
      OPS_PER_CHAN
        1 : for TF32
        2 : for 16-bit precision (BF, HF)
        4 : for 8-bit precision (FP8, UB, B)
        8 : for less-than 8-bit precision (U4/S4, U2/S2).

    If depth is 8, K would be 8, 16, 32, or 64 (based on OPS_PER_CHAN).

    $a, $b, $c, $d - matrix A, B, C, D, respectively
    $pa, $pb - precision of matrix A and B respectively
-    $rc - repect count
-  }];
-
-  string llvmBuilder = [{
-    $d = createGenISADPAS(op, builder, moduleTranslation);
+    $rc - repeat count
  }];

let assemblyFormat = [{
-    operands attr-dict `:` functional-type(operands, results)
+    operands ` ` `{` `pa` `=` $pa `,` `pb` `=` $pb `,` `rc` `=` $rc `}` attr-dict `:` functional-type(operands, results)
}];

let hasVerifier = 1;
}
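
A worked example of the shape rules above (a sketch mirroring the round-trip test added in this PR): for s8 operands OPS_PER_CHAN is 4, so at systolic depth 8 the op computes K = 8 * 4 = 32; with rc = 8 that is an 8x16 accumulator updated by an 8x32 times 32x16 product. The operand vectors are the per-work-item slices of those matrices (the op is executed cooperatively by the 16-wide subgroup, the fixed execution size N), which the verifier checks by bit-size:

    // A: 16 x i8 = 128 bits = rc (8) x 16 bits per lane.
    // B: 32 x i8 = 256 bits = systolic depth (8) x 32 bits per lane.
    %0 = triton_gen.dpas %c, %a, %b {pa = s8, pb = s8, rc = 8} : (vector<8xi32>, vector<16xi8>, vector<32xi8>) -> vector<8xi32>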

-def TritonGEN_Matrix2DBlockLoadOp : TritonGEN_Op<"matrix.2Dblockload">,
+def TritonGEN_Matrix2DBlockLoadOp : TritonGEN_Op<"2Dblockload">,
Results<(outs FixedVectorOf<[TritonGEN_MatrixElemType]>:$res)>,
Arguments<(ins
Arg<LLVM_AnyPointer, "", [MemRead]>:$ptr,
@@ -279,7 +275,7 @@ def TritonGEN_Matrix2DBlockLoadOp : TritonGEN_Op<"matrix.2Dblockload">,
let summary = "GEN 2D block load";

string baseDescription = [{
-    The 'gen.matrix.2Dblockload' operation loads a submatrix from an array in memory.
+    The 'gen.2Dblockload' operation loads a submatrix from an array in memory.
$ptr - the base address of the memory array
$base_width, $base_height, $base_pitch - the shape of the memory array
$x, $y, $tile_width, $tile_height - the starting offsets and shape of the submatrix to load
@@ -306,7 +302,7 @@ def TritonGEN_Matrix2DBlockLoadOp : TritonGEN_Op<"matrix.2Dblockload">,
let hasVerifier = 1;
}

-def TritonGEN_Matrix2DBlockStoreOp : TritonGEN_Op<"matrix.2Dblockstore">,
+def TritonGEN_Matrix2DBlockStoreOp : TritonGEN_Op<"2Dblockstore">,
Arguments<(ins
Arg<LLVM_AnyPointer, "", [MemWrite]>:$ptr,
I32:$base_width,
@@ -326,7 +322,7 @@ def TritonGEN_Matrix2DBlockStoreOp : TritonGEN_Op<"matrix.2Dblockstore">,
let summary = "GEN 2D block store";

string baseDescription = [{
-    The 'gen.matrix.2Dblockstore' operation stores to a submatrix from an array in memory.
+    The 'gen.2Dblockstore' operation stores a submatrix to an array in memory.
$ptr - the base address of the memory array
$base_width, $base_height, $base_pitch - the shape of the memory array
    $x, $y, $tile_width, $tile_height - the starting offsets and shape of the submatrix to store
112 changes: 45 additions & 67 deletions lib/Dialect/TritonGEN/IR/TritonGENOps.cpp
@@ -56,9 +56,6 @@ LogicalResult TritonGEN::MatrixDPASOp::verify() {
Type AElemTy = ATy.getElementType();
Type BElemTy = BTy.getElementType();
Type CElemTy = CTy.getElementType();
-  if (AElemTy != BElemTy)
-    return this->emitOpError(
-        "element type of 2nd (A) and 3rd (B) operands must match");

// ATy is required to be vector<RC x i16> as hard coded by IGC.
if (ATy.getNumElements() * AElemTy.getIntOrFloatBitWidth() != getRc() * 16)
@@ -71,71 +68,52 @@ LogicalResult TritonGEN::MatrixDPASOp::verify() {
return this->emitOpError(
"3rd operand (B) bit-size should be systolic depth (8) times 32");

-  return TypeSwitch<Type, LogicalResult>(AElemTy)
-      .Case<Float32Type>([&](auto ty) -> LogicalResult {
-        if (precision != TritonGEN::PrecisionType::TF32)
-          return this->emitOpError("precision should be TF32 when 2nd (A) or "
-                                   "3rd (B) operand element type is f32");
-        if (!CElemTy.isF32())
-          return this->emitOpError("the element type for 1st operand (C) and "
-                                   "the result should be f32");
-        return success();
-      })
-      .Case<BFloat16Type>([&](auto ty) -> LogicalResult {
-        if (precision != TritonGEN::PrecisionType::BF16)
-          return this->emitOpError(
-              "precision should be BF16 when 2nd (A) or 3rd (B) operand "
-              "element type is bf16");
-        if (!CElemTy.isF32())
-          return this->emitOpError(
-              "the element type for 1st operand (C) and the "
-              "result should be f32");
-        return success();
-      })
-      .Case<Float16Type>([&](auto ty) -> LogicalResult {
-        if (precision != TritonGEN::PrecisionType::FP16)
-          return this->emitOpError("precision should be FP16 when 2nd (A) or "
-                                   "3rd (B) operand element type is f16");
-        if (!CElemTy.isF32())
-          return this->emitOpError(
-              "the element type for 1st operand (C) and the "
-              "result should be f32");
-        return success();
-      })
-      .Case<IntegerType>([&](auto ty) -> LogicalResult {
-        if (!ty.isInteger(8))
-          return this->emitOpError(
-              "expecting 2nd (A) or 3rd (B) operand element type to be f32, "
-              "bf16, f16, or i8");
-
-        if (precision == TritonGEN::PrecisionType::U8) {
-          if (ty.isSigned())
-            return this->emitOpError(
-                "precision should be S8 when 2nd (A) or 3rd (B) operand "
-                "element type is signed i8");
-        } else if (precision == TritonGEN::PrecisionType::S8) {
-          if (ty.isUnsigned())
-            return this->emitOpError(
-                "precision should be U8 when 2nd (A) or 3rd (B) operand "
-                "element type is unsigned i8");
-        } else
-          return this->emitOpError("precision should be U8 or S8 when 2nd (A) "
-                                   "or 3rd (B) operand element type is i8");
-
-        if (!CElemTy.isInteger(32))
-          return this->emitOpError("the element type for 1st operand (C) and "
-                                   "the result should be i32");
-
-        return success();
-      })
-      .Default([&](mlir::Type) -> LogicalResult {
-        return this->emitOpError("expecting 2nd (A) or 3rd (B) operand element "
-                                 "type to be f32, bf16, f16, or i8");
-      });
+  if (precision == TritonGEN::PrecisionType::U8 ||
+      precision == TritonGEN::PrecisionType::S8) {
+    if (!CElemTy.isInteger(32))
+      return this->emitOpError("the element type for 1st operand (C) and "
+                               "the result should be i32");
+  } else if (!CElemTy.isF32())
+    return this->emitOpError("the element type for 1st operand (C) and the "
+                             "result should be f32");
+
+  switch (precision) {
+  case TritonGEN::PrecisionType::TF32:
+    if (!AElemTy.isa<Float32Type>() && !AElemTy.isInteger(32))
+      return this->emitOpError("A and B operand element type should be f32 or "
+                               "i32 when precision type is tf32");
+    break;
+  case TritonGEN::PrecisionType::BF16:
+    if (!AElemTy.isa<BFloat16Type>() && !AElemTy.isInteger(16))
+      return this->emitOpError("A and B operand element type should be bf16 or "
+                               "i16 when precision type is bf16");
+    break;
+  case TritonGEN::PrecisionType::FP16:
+    if (!AElemTy.isa<Float16Type>() && !AElemTy.isInteger(16))
+      return this->emitOpError("A and B operand element type should be f16 or "
+                               "i16 when precision type is f16");
+    break;
+  case TritonGEN::PrecisionType::U8:
+    if (!(AElemTy.isInteger(8) && !AElemTy.cast<IntegerType>().isSigned()) &&
+        !AElemTy.isInteger(16))
+      return this->emitOpError("A and B operand element type should be u8, i8, "
+                               "or i16 when precision type is u8");
+    break;
+  case TritonGEN::PrecisionType::S8:
+    if (!(AElemTy.isInteger(8) && !AElemTy.cast<IntegerType>().isUnsigned()) &&
+        !AElemTy.isInteger(16))
+      return this->emitOpError("A and B operand element type should be s8, i8, "
+                               "or i16 when precision type is s8");
+    break;
+  default:
+    return this->emitOpError(
+        "expecting precision type to be tf32, bf16, fp16, u8, or s8");
+  }
+  return success();
}
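
To make the new accumulator rule concrete (a sketch; the operand names are illustrative, and both cases correspond to tests added below): u8/s8 precisions require an i32 accumulator and result, while tf32/bf16/fp16 require f32:

    // Accepted: s8 precision with an i32 accumulator.
    %ok = triton_gen.dpas %c_i32, %a, %b {pa = s8, pb = s8, rc = 8} : (vector<8xi32>, vector<16xi8>, vector<32xi8>) -> vector<8xi32>

    // Rejected by the verifier: "the element type for 1st operand (C) and the result should be i32".
    %bad = triton_gen.dpas %c_i8, %a, %b {pa = s8, pb = s8, rc = 8} : (vector<8xi8>, vector<16xi8>, vector<32xi8>) -> vector<8xi8>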

//===----------------------------------------------------------------------===//
-// gen.matrix.2Dblockload
+// gen.2Dblockload
//===----------------------------------------------------------------------===//

static std::optional<int> getConstantInt(Value v) {
@@ -206,15 +184,15 @@ template <typename Op> static LogicalResult verifyInput(Op op) {
}

//===----------------------------------------------------------------------===//
-// gen.matrix.2Dblockload
+// gen.2Dblockload
//===----------------------------------------------------------------------===//

LogicalResult TritonGEN::Matrix2DBlockLoadOp::verify() {
return verifyInput(*this);
}

//===----------------------------------------------------------------------===//
-// gen.matrix.2Dblockstore
+// gen.2Dblockstore
//===----------------------------------------------------------------------===//

LogicalResult TritonGEN::Matrix2DBlockStoreOp::verify() {
12 changes: 6 additions & 6 deletions test/Conversion/tritongpu_to_gen.mlir
@@ -562,7 +562,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: dot_f32_f16_f16_f32_1
tt.func @dot_f32_f16_f16_f32_1(%a: tensor<8x16xf16, #dot_operand_a>, %b: tensor<16x16xf16, #dot_operand_b>, %c: tensor<8x16xf32, #mma>) {
-    // CHECK: genx.matrix.dpas {{.*}}, {{.*}}, {{.*}} {pa = #genx.precision_type<FP16>, pb = #genx.precision_type<FP16>, rc = 8 : i32} : (vector<8xf32>, vector<8xi16>, vector<8xi32>) -> vector<8xf32>
+    // CHECK: llvm.call @llvm.genx.GenISA.sub.group.dpas.v8f32.v8f32.v8i16.v8i32
%0 = tt.dot %a, %b, %c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<8x16xf16, #dot_operand_a> * tensor<16x16xf16, #dot_operand_b> -> tensor<8x16xf32, #mma>
tt.return
}
@@ -578,7 +578,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
// CHECK-LABEL: dot_f32_f16_f16_f32_2
tt.func @dot_f32_f16_f16_f32_2(%a: tensor<16x16xf16, #dot_operand_a>, %b: tensor<16x16xf16, #dot_operand_b>, %c: tensor<16x16xf32, #mma>) {
// COM: 2 repetitions along axis for M.
-    // CHECK-COUNT-2: genx.matrix.dpas {{.*}}, {{.*}}, {{.*}} {pa = #genx.precision_type<FP16>, pb = #genx.precision_type<FP16>, rc = 8 : i32} : (vector<8xf32>, vector<8xi16>, vector<8xi32>) -> vector<8xf32>
+    // CHECK-COUNT-2: llvm.call @llvm.genx.GenISA.sub.group.dpas.v8f32.v8f32.v8i16.v8i32
%0 = tt.dot %a, %b, %c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<16x16xf16, #dot_operand_a> * tensor<16x16xf16, #dot_operand_b> -> tensor<16x16xf32, #mma>
tt.return
}
@@ -593,7 +593,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: dot_i32_i8_i8_i32_1
tt.func @dot_i32_i8_i8_i32_1(%a: tensor<8x32xi8, #dot_operand_a>, %b: tensor<32x16xi8, #dot_operand_b>, %c: tensor<8x16xi32, #mma>) {
-    // CHECK: genx.matrix.dpas {{.*}}, {{.*}}, {{.*}} {pa = #genx.precision_type<S8>, pb = #genx.precision_type<S8>, rc = 8 : i32} : (vector<8xi32>, vector<8xi16>, vector<8xi32>) -> vector<8xi32>
+    // CHECK: llvm.call @llvm.genx.GenISA.sub.group.dpas.v8i32.v8i32.v8i16.v8i32
%0 = tt.dot %a, %b, %c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<8x32xi8, #dot_operand_a> * tensor<32x16xi8, #dot_operand_b> -> tensor<8x16xi32, #mma>
tt.return
}
@@ -609,7 +609,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
// CHECK-LABEL: dot_i32_i8_i8_i32_2
tt.func @dot_i32_i8_i8_i32_2(%a: tensor<8x64xi8, #dot_operand_a>, %b: tensor<64x16xi8, #dot_operand_b>, %c: tensor<8x16xi32, #mma>) {
// COM: 2 repetition along axis for K.
-    // CHECK: genx.matrix.dpas {{.*}}, {{.*}}, {{.*}} {pa = #genx.precision_type<S8>, pb = #genx.precision_type<S8>, rc = 8 : i32} : (vector<8xi32>, vector<8xi16>, vector<8xi32>) -> vector<8xi32>
+    // CHECK: llvm.call @llvm.genx.GenISA.sub.group.dpas.v8i32.v8i32.v8i16.v8i32
%0 = tt.dot %a, %b, %c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<8x64xi8, #dot_operand_a> * tensor<64x16xi8, #dot_operand_b> -> tensor<8x16xi32, #mma>
tt.return
}
@@ -624,7 +624,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: dot_f32_tf32_tf32_f32_1
tt.func @dot_f32_tf32_tf32_f32_1(%a: tensor<8x8xf32, #dot_operand_a>, %b: tensor<8x16xf32, #dot_operand_b>, %c: tensor<8x16xf32, #mma>) {
-    // CHECK: genx.matrix.dpas {{.*}}, {{.*}}, {{.*}} {pa = #genx.precision_type<TF32>, pb = #genx.precision_type<TF32>, rc = 8 : i32} : (vector<8xf32>, vector<4xi32>, vector<8xi32>) -> vector<8xf32>
+    // CHECK: llvm.call @llvm.genx.GenISA.sub.group.dpas.v8f32.v8f32.v8i16.v8i32
%0 = tt.dot %a, %b, %c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<8x8xf32, #dot_operand_a> * tensor<8x16xf32, #dot_operand_b> -> tensor<8x16xf32, #mma>
tt.return
}
@@ -640,7 +640,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
// CHECK-LABEL: dot_f32_tf32_tf32_f32_2
tt.func @dot_f32_tf32_tf32_f32_2(%a: tensor<8x8xf32, #dot_operand_a>, %b: tensor<8x32xf32, #dot_operand_b>, %c: tensor<8x32xf32, #mma>) {
// COM: 2 repetitions along axis for N.
-    // CHECK-COUNT-2: genx.matrix.dpas {{.*}}, {{.*}}, {{.*}} {pa = #genx.precision_type<TF32>, pb = #genx.precision_type<TF32>, rc = 8 : i32} : (vector<8xf32>, vector<4xi32>, vector<8xi32>) -> vector<8xf32>
+    // CHECK-COUNT-2: llvm.call @llvm.genx.GenISA.sub.group.dpas.v8f32.v8f32.v8i16.v8i32
%0 = tt.dot %a, %b, %c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<8x8xf32, #dot_operand_a> * tensor<8x32xf32, #dot_operand_b> -> tensor<8x32xf32, #mma>
tt.return
}
64 changes: 64 additions & 0 deletions test/TritonGEN/tritongen-invalid.mlir
@@ -21,3 +21,67 @@ llvm.func @triton_gen.fptofp(%a : f32) {
%0 = triton_gen.fptofp %a : f32 to f16
llvm.return
}

// -----

llvm.func @triton_gen.dpas(%c : vector<8xi32>, %a : vector<16xi8>, %b : vector<32xi8>) {
// expected-error @+1 {{'triton_gen.dpas' op expecting repeat count to be 1, 2, 4, or 8}}
%0 = triton_gen.dpas %c, %a, %b {pa=s8, pb=s8, rc=6} : (vector<8xi32>, vector<16xi8>, vector<32xi8>) -> vector<8xi32>
llvm.return
}

// -----

llvm.func @triton_gen.dpas(%c : vector<8xi32>, %a : vector<16xi8>, %b : vector<32xi8>) {
// expected-error @+1 {{'triton_gen.dpas' op expecting precision of matrix A and B to be the same}}
%0 = triton_gen.dpas %c, %a, %b {pa=s8, pb=u8, rc=8} : (vector<8xi32>, vector<16xi8>, vector<32xi8>) -> vector<8xi32>
llvm.return
}

// -----

llvm.func @triton_gen.dpas(%c : vector<8xi8>, %a : vector<16xi8>, %b : vector<32xi8>) {
// expected-error @+1 {{'triton_gen.dpas' op 1st operand (C) and result (D) should have the same type}}
%0 = triton_gen.dpas %c, %a, %b {pa=s8, pb=s8, rc=8} : (vector<8xi8>, vector<16xi8>, vector<32xi8>) -> vector<8xi32>
llvm.return
}

// -----

llvm.func @triton_gen.dpas(%c : vector<16xi32>, %a : vector<16xi8>, %b : vector<32xi8>) {
// expected-error @+1 {{'triton_gen.dpas' op the dimension for 1st operand (C) and result (D) should match repeat count}}
%0 = triton_gen.dpas %c, %a, %b {pa=s8, pb=s8, rc=8} : (vector<16xi32>, vector<16xi8>, vector<32xi8>) -> vector<16xi32>
llvm.return
}

// -----

llvm.func @triton_gen.dpas(%c : vector<8xi32>, %a : vector<8xi8>, %b : vector<8xi8>) {
// expected-error @+1 {{'triton_gen.dpas' op 2nd operand (A) bit-size should be repeat count times 16}}
%0 = triton_gen.dpas %c, %a, %b {pa=s8, pb=s8, rc=8} : (vector<8xi32>, vector<8xi8>, vector<8xi8>) -> vector<8xi32>
llvm.return
}

// -----

llvm.func @triton_gen.dpas(%c : vector<8xi32>, %a : vector<16xi8>, %b : vector<16xi8>) {
// expected-error @+1 {{'triton_gen.dpas' op 3rd operand (B) bit-size should be systolic depth (8) times 32}}
%0 = triton_gen.dpas %c, %a, %b {pa=s8, pb=s8, rc=8} : (vector<8xi32>, vector<16xi8>, vector<16xi8>) -> vector<8xi32>
llvm.return
}

// -----

llvm.func @triton_gen.dpas(%c : vector<8xi8>, %a : vector<16xi8>, %b : vector<32xi8>) {
// expected-error @+1 {{'triton_gen.dpas' op the element type for 1st operand (C) and the result should be i32}}
%0 = triton_gen.dpas %c, %a, %b {pa=s8, pb=s8, rc=8} : (vector<8xi8>, vector<16xi8>, vector<32xi8>) -> vector<8xi8>
llvm.return
}

// -----

llvm.func @triton_gen.dpas(%c : vector<8xf32>, %a : vector<8xf16>, %b : vector<16xf16>) {
// expected-error @+1 {{'triton_gen.dpas' op expecting precision type to be tf32, bf16, fp16, u8, or s8}}
%0 = triton_gen.dpas %c, %a, %b {pa=s4, pb=s4, rc=8} : (vector<8xf32>, vector<8xf16>, vector<16xf16>) -> vector<8xf32>
llvm.return
}
16 changes: 16 additions & 0 deletions test/TritonGEN/tritongen-to-llvm.mlir
@@ -114,3 +114,19 @@ llvm.func @triton_gen.sub_group_shuffle() {
%16 = triton_gen.sub_group_shuffle xor %15, %0 : f64 -> f64
llvm.return
}

// -----

llvm.func @triton_gen.dpas.f32(%c : vector<8xf32>, %a : vector<4xf32>, %b : vector<8xf32>) {
// CHECK-DAG: [[A:%.*]] = llvm.bitcast %arg1 : vector<4xf32> to vector<8xi16>
// CHECK-DAG: [[B:%.*]] = llvm.bitcast %arg2 : vector<8xf32> to vector<8xi32>
// CHECK-DAG: [[CST_8a:%.*]] = llvm.mlir.constant(8 : i32) : i32
// CHECK-DAG: [[CST_8b:%.*]] = llvm.mlir.constant(8 : i32) : i32
// CHECK-DAG: [[CST_8c:%.*]] = llvm.mlir.constant(8 : i32) : i32
// CHECK-DAG: [[CST_8d:%.*]] = llvm.mlir.constant(8 : i32) : i32
// CHECK-DAG: [[CST_FALSE:%.*]] = llvm.mlir.constant(false) : i1
// CHECK-NEXT: llvm.call @llvm.genx.GenISA.sub.group.dpas.v8f32.v8f32.v8i16.v8i32
  // CHECK-SAME: (%arg0, [[A]], [[B]], [[CST_8a]], [[CST_8b]], [[CST_8c]], [[CST_8d]], [[CST_FALSE]]) : (vector<8xf32>, vector<8xi16>, vector<8xi32>, i32, i32, i32, i32, i1) -> vector<8xf32>
%0 = triton_gen.dpas %c, %a, %b {pa = tf32, pb = tf32, rc = 8} : (vector<8xf32>, vector<4xf32>, vector<8xf32>) -> vector<8xf32>
llvm.return
}
6 changes: 6 additions & 0 deletions test/TritonGEN/tritongen.mlir
@@ -90,3 +90,9 @@ llvm.func @triton_gen.fptofp(%a: f32, %b: f16) {
%8 = triton_gen.fptofp %b : f16 to f32
llvm.return
}

llvm.func @triton_gen.dpas(%c : vector<8xi32>, %a : vector<16xi8>, %b : vector<32xi8>) {
// CHECK: %0 = triton_gen.dpas %arg0, %arg1, %arg2 {pa = s8, pb = s8, rc = 8} : (vector<8xi32>, vector<16xi8>, vector<32xi8>) -> vector<8xi32>
%0 = triton_gen.dpas %c, %a, %b {pa=s8, pb=s8, rc=8} : (vector<8xi32>, vector<16xi8>, vector<32xi8>) -> vector<8xi32>
llvm.return
}