Skip to content

Commit

Permalink
[GEN] Fix mangled name for dpas (#862)
Browse files Browse the repository at this point in the history
Signed-off-by: Whitney Tsang <whitney.tsang@intel.com>
  • Loading branch information
whitneywhtsang authored Apr 12, 2024
1 parent 0ed19e4 commit b9dc24a
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 8 deletions.
4 changes: 2 additions & 2 deletions test/Conversion/tritongpu_to_gen.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -549,7 +549,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: dot_i32_i8_i8_i32_1
tt.func @dot_i32_i8_i8_i32_1(%a: tensor<8x32xi8, #dot_operand_a>, %b: tensor<32x16xi8, #dot_operand_b>, %c: tensor<8x16xi32, #mma>) {
// CHECK: llvm.call @_Z36intel_sub_group_i8_i8_matrix_mad_k32Dv8_sDv8_iDv8_i(%{{.*}}, %{{.*}}, %{{.*}}) {passthrough = ["convergent"]} : (vector<8xi16>, vector<8xi32>, vector<8xi32>) -> vector<8xi32>
// CHECK: llvm.call @_Z36intel_sub_group_i8_i8_matrix_mad_k32Dv8_sDv8_iS0_(%{{.*}}, %{{.*}}, %{{.*}}) {passthrough = ["convergent"]} : (vector<8xi16>, vector<8xi32>, vector<8xi32>) -> vector<8xi32>
%0 = tt.dot %a, %b, %c, inputPrecision = tf32 : tensor<8x32xi8, #dot_operand_a> * tensor<32x16xi8, #dot_operand_b> -> tensor<8x16xi32, #mma>
tt.return
}
Expand All @@ -565,7 +565,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
// CHECK-LABEL: dot_i32_i8_i8_i32_2
tt.func @dot_i32_i8_i8_i32_2(%a: tensor<8x64xi8, #dot_operand_a>, %b: tensor<64x16xi8, #dot_operand_b>, %c: tensor<8x16xi32, #mma>) {
// COM: 2 repetition along axis for K.
// CHECK: llvm.call @_Z36intel_sub_group_i8_i8_matrix_mad_k32Dv8_sDv8_iDv8_i(%{{.*}}, %{{.*}}, %{{.*}}) {passthrough = ["convergent"]} : (vector<8xi16>, vector<8xi32>, vector<8xi32>) -> vector<8xi32>
// CHECK: llvm.call @_Z36intel_sub_group_i8_i8_matrix_mad_k32Dv8_sDv8_iS0_(%{{.*}}, %{{.*}}, %{{.*}}) {passthrough = ["convergent"]} : (vector<8xi16>, vector<8xi32>, vector<8xi32>) -> vector<8xi32>
%0 = tt.dot %a, %b, %c, inputPrecision = tf32 : tensor<8x64xi8, #dot_operand_a> * tensor<64x16xi8, #dot_operand_b> -> tensor<8x16xi32, #mma>
tt.return
}
Expand Down
8 changes: 4 additions & 4 deletions test/TritonGEN/tritongen-to-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -141,26 +141,26 @@ llvm.func @triton_gen.sub_group_shuffle() {

// -----

// CHECK: llvm.func spir_funccc @_Z36intel_sub_group_i8_i8_matrix_mad_k32Dv8_sDv8_iDv8_i(vector<8xi16>, vector<8xi32>, vector<8xi32>) -> vector<8xi32> attributes {passthrough = ["convergent"]}
// CHECK: llvm.func spir_funccc @_Z36intel_sub_group_i8_i8_matrix_mad_k32Dv8_sDv8_iS0_(vector<8xi16>, vector<8xi32>, vector<8xi32>) -> vector<8xi32> attributes {passthrough = ["convergent"]}

llvm.func @triton_gen.dpas.i8(%c : vector<8xi32>, %a : vector<16xi8>, %b : vector<32xi8>) {
// CHECK: llvm.func @triton_gen.dpas.i8(%arg0: vector<8xi32>, %arg1: vector<16xi8>, %arg2: vector<32xi8>) {
// CHECK-DAG: [[A:%.*]] = llvm.bitcast %arg1 : vector<16xi8> to vector<8xi16>
// CHECK-DAG: [[B:%.*]] = llvm.bitcast %arg2 : vector<32xi8> to vector<8xi32>
// CHECK-NEXT: llvm.call @_Z36intel_sub_group_i8_i8_matrix_mad_k32Dv8_sDv8_iDv8_i([[A]], [[B]], %arg0) {passthrough = ["convergent"]} : (vector<8xi16>, vector<8xi32>, vector<8xi32>) -> vector<8xi32>
// CHECK-NEXT: llvm.call @_Z36intel_sub_group_i8_i8_matrix_mad_k32Dv8_sDv8_iS0_([[A]], [[B]], %arg0) {passthrough = ["convergent"]} : (vector<8xi16>, vector<8xi32>, vector<8xi32>) -> vector<8xi32>
%0 = triton_gen.dpas %c, %a, %b {pa = i8, pb = i8, rc = 8} : (vector<8xi32>, vector<16xi8>, vector<32xi8>) -> vector<8xi32>
llvm.return
}

// -----

// CHECK: llvm.func spir_funccc @_Z36intel_sub_group_u8_u8_matrix_mad_k32Dv8_sDv8_iDv8_i(vector<8xi16>, vector<8xi32>, vector<8xi32>) -> vector<8xi32> attributes {passthrough = ["convergent"]}
// CHECK: llvm.func spir_funccc @_Z36intel_sub_group_u8_u8_matrix_mad_k32Dv8_sDv8_iS0_(vector<8xi16>, vector<8xi32>, vector<8xi32>) -> vector<8xi32> attributes {passthrough = ["convergent"]}

llvm.func @triton_gen.dpas.u8(%c : vector<8xi32>, %a : vector<16xi8>, %b : vector<32xi8>) {
// CHECK: llvm.func @triton_gen.dpas.u8(%arg0: vector<8xi32>, %arg1: vector<16xi8>, %arg2: vector<32xi8>) {
// CHECK-DAG: [[A:%.*]] = llvm.bitcast %arg1 : vector<16xi8> to vector<8xi16>
// CHECK-DAG: [[B:%.*]] = llvm.bitcast %arg2 : vector<32xi8> to vector<8xi32>
// CHECK-NEXT: llvm.call @_Z36intel_sub_group_u8_u8_matrix_mad_k32Dv8_sDv8_iDv8_i([[A]], [[B]], %arg0) {passthrough = ["convergent"]} : (vector<8xi16>, vector<8xi32>, vector<8xi32>) -> vector<8xi32>
// CHECK-NEXT: llvm.call @_Z36intel_sub_group_u8_u8_matrix_mad_k32Dv8_sDv8_iS0_([[A]], [[B]], %arg0) {passthrough = ["convergent"]} : (vector<8xi16>, vector<8xi32>, vector<8xi32>) -> vector<8xi32>
%0 = triton_gen.dpas %c, %a, %b {pa = u8, pb = u8, rc = 8} : (vector<8xi32>, vector<16xi8>, vector<32xi8>) -> vector<8xi32>
llvm.return
}
Expand Down
7 changes: 5 additions & 2 deletions third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,9 +176,12 @@ static LLVM::CallOp createGenISADPAS(TritonGEN::MatrixDPASOp op,
stringifyPrecisionType(op.getPb()).str() + "_matrix_mad_k" +
std::to_string(8 /*systolic depth*/ *
getNumOperandsPerDword(precisionA));
std::string bMangledTy = getTypeMangling(bTy);
std::string cMangledTy = getTypeMangling(opTypes[0]);
if (bMangledTy == cMangledTy)
cMangledTy = "S0_";
fnName = "_Z" + std::to_string(fnName.size()) + fnName +
getTypeMangling(aTy) + getTypeMangling(bTy) +
getTypeMangling(opTypes[0]);
getTypeMangling(aTy) + bMangledTy + cMangledTy;
SmallVector<Type> argTypes{aTy, bTy, opTypes[0]};
SmallVector<Value> args{a, b, op.getC()};

Expand Down

0 comments on commit b9dc24a

Please sign in to comment.