Skip to content

Commit

Permalink
Change some OCL usages to SPIRV (#1191)
Browse files Browse the repository at this point in the history
KhronosGroup/SPIRV-LLVM-Translator@9e60105
reverted the change in SPIRV translator to translate OCL C built-ins to
SPIRV built-ins. We need to use the SPIRV built-ins directly for the
cases below:
1. `_Z25__spirv_BuiltInSubgroupIdv` lowering is faster than
`_Z16get_sub_group_idv`.
2. `_Z31intel_convert_as_bfloat16_floatt` causes 349 `RuntimeError:
Triton Error [ZE]: 0x78000018` failures.
3. `_Z32intel_convert_bfloat16_as_ushortf` causes 70 `RuntimeError:
Triton Error [ZE]: 0x78000018` failures.

Signed-off-by: Whitney Tsang <whitney.tsang@intel.com>
  • Loading branch information
whitneywhtsang authored May 27, 2024
1 parent 0dbfa04 commit da631c7
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 15 deletions.
2 changes: 1 addition & 1 deletion lib/Target/SPIRV/spirv-llvm-translator.conf
Original file line number Diff line number Diff line change
@@ -1 +1 @@
3c1ff53d6202f028e6173fed7b378e3599a76606
9e60105170110e0fd01f073954763af399d1c596
18 changes: 9 additions & 9 deletions test/Conversion/intel/arith_to_llvm.mlir
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// RUN: triton-opt %s -split-input-file --convert-triton-intel-gpu-to-llvm | FileCheck %s

// CHECK-DAG: llvm.func spir_funccc @_Z32intel_convert_bfloat16_as_ushortf(f32) -> i16
// CHECK-DAG: llvm.func spir_funccc @_Z27__spirv_ConvertFToBF16INTELf(f32) -> i16
// CHECK-DAG: #[[$ATTR_0:.+]] = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>

Expand All @@ -13,10 +13,10 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
// CHECK: %[[VAL_3:.*]] = llvm.extractvalue %[[VAL_0]][1] : !llvm.struct<(f32, f32, f32, f32)>
// CHECK: %[[VAL_4:.*]] = llvm.extractvalue %[[VAL_0]][2] : !llvm.struct<(f32, f32, f32, f32)>
// CHECK: %[[VAL_5:.*]] = llvm.extractvalue %[[VAL_0]][3] : !llvm.struct<(f32, f32, f32, f32)>
// CHECK: %[[VAL_6:.*]] = llvm.call spir_funccc @_Z32intel_convert_bfloat16_as_ushortf(%[[VAL_2]]) : (f32) -> i16
// CHECK: %[[VAL_7:.*]] = llvm.call spir_funccc @_Z32intel_convert_bfloat16_as_ushortf(%[[VAL_3]]) : (f32) -> i16
// CHECK: %[[VAL_8:.*]] = llvm.call spir_funccc @_Z32intel_convert_bfloat16_as_ushortf(%[[VAL_4]]) : (f32) -> i16
// CHECK: %[[VAL_9:.*]] = llvm.call spir_funccc @_Z32intel_convert_bfloat16_as_ushortf(%[[VAL_5]]) : (f32) -> i16
// CHECK: %[[VAL_6:.*]] = llvm.call spir_funccc @_Z27__spirv_ConvertFToBF16INTELf(%[[VAL_2]]) : (f32) -> i16
// CHECK: %[[VAL_7:.*]] = llvm.call spir_funccc @_Z27__spirv_ConvertFToBF16INTELf(%[[VAL_3]]) : (f32) -> i16
// CHECK: %[[VAL_8:.*]] = llvm.call spir_funccc @_Z27__spirv_ConvertFToBF16INTELf(%[[VAL_4]]) : (f32) -> i16
// CHECK: %[[VAL_9:.*]] = llvm.call spir_funccc @_Z27__spirv_ConvertFToBF16INTELf(%[[VAL_5]]) : (f32) -> i16
// CHECK: %[[VAL_10:.*]] = llvm.mlir.undef : !llvm.struct<(i16, i16, i16, i16)>
// CHECK: %[[VAL_11:.*]] = llvm.insertvalue %[[VAL_6]], %[[VAL_10]][0] : !llvm.struct<(i16, i16, i16, i16)>
// CHECK: %[[VAL_12:.*]] = llvm.insertvalue %[[VAL_7]], %[[VAL_11]][1] : !llvm.struct<(i16, i16, i16, i16)>
Expand All @@ -35,10 +35,10 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
// CHECK: %[[VAL_3:.*]] = llvm.extractvalue %[[VAL_0]][1] : !llvm.struct<(i16, i16, i16, i16)>
// CHECK: %[[VAL_4:.*]] = llvm.extractvalue %[[VAL_0]][2] : !llvm.struct<(i16, i16, i16, i16)>
// CHECK: %[[VAL_5:.*]] = llvm.extractvalue %[[VAL_0]][3] : !llvm.struct<(i16, i16, i16, i16)>
// CHECK: %[[VAL_6:.*]] = llvm.call spir_funccc @_Z31intel_convert_as_bfloat16_floatt(%[[VAL_2]]) : (i16) -> f32
// CHECK: %[[VAL_7:.*]] = llvm.call spir_funccc @_Z31intel_convert_as_bfloat16_floatt(%[[VAL_3]]) : (i16) -> f32
// CHECK: %[[VAL_8:.*]] = llvm.call spir_funccc @_Z31intel_convert_as_bfloat16_floatt(%[[VAL_4]]) : (i16) -> f32
// CHECK: %[[VAL_9:.*]] = llvm.call spir_funccc @_Z31intel_convert_as_bfloat16_floatt(%[[VAL_5]]) : (i16) -> f32
// CHECK: %[[VAL_6:.*]] = llvm.call spir_funccc @_Z27__spirv_ConvertBF16ToFINTELs(%[[VAL_2]]) : (i16) -> f32
// CHECK: %[[VAL_7:.*]] = llvm.call spir_funccc @_Z27__spirv_ConvertBF16ToFINTELs(%[[VAL_3]]) : (i16) -> f32
// CHECK: %[[VAL_8:.*]] = llvm.call spir_funccc @_Z27__spirv_ConvertBF16ToFINTELs(%[[VAL_4]]) : (i16) -> f32
// CHECK: %[[VAL_9:.*]] = llvm.call spir_funccc @_Z27__spirv_ConvertBF16ToFINTELs(%[[VAL_5]]) : (i16) -> f32
// CHECK: %[[VAL_10:.*]] = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32)>
// CHECK: %[[VAL_11:.*]] = llvm.insertvalue %[[VAL_6]], %[[VAL_10]][0] : !llvm.struct<(f32, f32, f32, f32)>
// CHECK: %[[VAL_12:.*]] = llvm.insertvalue %[[VAL_7]], %[[VAL_11]][1] : !llvm.struct<(f32, f32, f32, f32)>
Expand Down
4 changes: 2 additions & 2 deletions test/TritonGEN/tritongen-to-llvm.mlir
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// RUN: triton-opt -convert-tritongen-to-llvm -split-input-file %s | FileCheck %s

// CHECK-DAG: llvm.func spir_funccc @_Z16get_sub_group_idv() -> i32
// CHECK-DAG: llvm.func spir_funccc @_Z25__spirv_BuiltInSubgroupIdv() -> i32
// CHECK-DAG: llvm.func spir_funccc @_Z14get_num_groupsj(i32) -> i64
// CHECK-DAG: llvm.func spir_funccc @_Z14get_local_sizej(i32) -> i64
// CHECK-DAG: llvm.func spir_funccc @_Z12get_group_idj(i32) -> i64
Expand Down Expand Up @@ -49,7 +49,7 @@ llvm.func @gen_special_regs() -> i32 {
// CHECK: llvm.call @_Z14get_num_groupsj([[TWO3]]) : (i32) -> i64
%12 = triton_gen.grid.dim.z : i32

// CHECK: llvm.call @_Z16get_sub_group_idv() : () -> i32
// CHECK: llvm.call @_Z25__spirv_BuiltInSubgroupIdv() : () -> i32
%13 = triton_gen.subgroup.id : i32

llvm.return %1 : i32
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -665,7 +665,7 @@ struct TritonGENSubgroupIdLowering
ConversionPatternRewriter &rewriter) const override {
auto retType = rewriter.getIntegerType(32);
LLVM::CallOp callOp = createDeviceFunctionCall(
rewriter, "_Z16get_sub_group_idv", retType, {}, {});
rewriter, "_Z25__spirv_BuiltInSubgroupIdv", retType, {}, {});
rewriter.replaceOp(op, callOp);
return success();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1287,7 +1287,7 @@ struct FpToFpOpConversion
const Value &v) {
auto moduleOp =
v.getDefiningOp()->getParentWithTrait<OpTrait::SymbolTable>();
constexpr StringLiteral name = "_Z31intel_convert_as_bfloat16_floatt";
constexpr StringLiteral name = "_Z27__spirv_ConvertBF16ToFINTELs";
auto ext_func = triton::gpu::intel::lookupOrCreateSPIRVFn(moduleOp, name,
i16_ty, f32_ty);
auto call =
Expand All @@ -1308,7 +1308,7 @@ struct FpToFpOpConversion
auto moduleOp =
v.getDefiningOp()->getParentWithTrait<OpTrait::SymbolTable>();
// Intel SPIR-V extension only supports round-to-nearest-even
constexpr StringLiteral name = "_Z32intel_convert_bfloat16_as_ushortf";
constexpr StringLiteral name = "_Z27__spirv_ConvertFToBF16INTELf";
auto trunc_func = triton::gpu::intel::lookupOrCreateSPIRVFn(
moduleOp, name, f32_ty, i16_ty);
auto call = triton::gpu::intel::createSPIRVBuiltinCall(loc, rewriter,
Expand Down

0 comments on commit da631c7

Please sign in to comment.