Skip to content

Commit

Permalink
[NFI]: Cleanup ElementwiseOpToLLVM.cpp (#2973)
Browse files Browse the repository at this point in the history
Signed-off-by: Tiotto, Ettore <ettore.tiotto@intel.com>
  • Loading branch information
etiotto authored Dec 10, 2024
1 parent 0773668 commit d16a1dd
Showing 1 changed file with 76 additions and 84 deletions.
160 changes: 76 additions & 84 deletions third_party/intel/lib/TritonIntelGPUToLLVM/ElementwiseOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,33 +14,8 @@ namespace {
/* ----- FP8E5M2 ------ */
// This data-type is the standard FP8E5M2 format
static SmallVector<Value>
Fp16_to_Fp8E5M2_func(Location loc, ConversionPatternRewriter &rewriter,
Fp16_to_Fp8E5M2_RTNE(Location loc, ConversionPatternRewriter &rewriter,
const SmallVector<Value> &v) {
auto fp16x2VecTy = vec_ty(f16_ty, 2);
Value fp16x2Vec0 = undef(fp16x2VecTy);
Value fp16x2Vec1 = undef(fp16x2VecTy);
fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v[0], i32_val(0));
fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v[1], i32_val(1));
fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v[2], i32_val(0));
fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v[3], i32_val(1));

Value a0 = bitcast(fp16x2Vec0, i32_ty);
Value a1 = bitcast(fp16x2Vec1, i32_ty);

auto fp8x4VecTy = vec_ty(i8_ty, 4);
a0 = bitcast(a0, fp8x4VecTy);
a1 = bitcast(a1, fp8x4VecTy);

return {extract_element(i8_ty, a0, i32_val(1)),
extract_element(i8_ty, a0, i32_val(3)),
extract_element(i8_ty, a1, i32_val(1)),
extract_element(i8_ty, a1, i32_val(3))};
}

static SmallVector<Value>
Fp16_to_Fp8E5M2_RTNE_func(Location loc, ConversionPatternRewriter &rewriter,
const SmallVector<Value> &v) {

Value val = zext(i32_ty, bitcast(v[0], i16_ty));
Value sign = and_(i32_ty, val, i32_val(0x8000));
Value nosign = and_(i32_ty, val, i32_val(0x7fff));
Expand All @@ -63,8 +38,32 @@ Fp16_to_Fp8E5M2_RTNE_func(Location loc, ConversionPatternRewriter &rewriter,
}

static SmallVector<Value>
Fp8E5M2_to_Fp16_func(Location loc, ConversionPatternRewriter &rewriter,
const SmallVector<Value> &v) {
Fp16_to_Fp8E5M2_RTZ(Location loc, ConversionPatternRewriter &rewriter,
const SmallVector<Value> &v) {
auto fp16x2VecTy = vec_ty(f16_ty, 2);
Value fp16x2Vec0 = undef(fp16x2VecTy);
Value fp16x2Vec1 = undef(fp16x2VecTy);
fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v[0], i32_val(0));
fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v[1], i32_val(1));
fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v[2], i32_val(0));
fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v[3], i32_val(1));

Value a0 = bitcast(fp16x2Vec0, i32_ty);
Value a1 = bitcast(fp16x2Vec1, i32_ty);

auto fp8x4VecTy = vec_ty(i8_ty, 4);
a0 = bitcast(a0, fp8x4VecTy);
a1 = bitcast(a1, fp8x4VecTy);

return {extract_element(i8_ty, a0, i32_val(1)),
extract_element(i8_ty, a0, i32_val(3)),
extract_element(i8_ty, a1, i32_val(1)),
extract_element(i8_ty, a1, i32_val(3))};
}

static SmallVector<Value> Fp8E5M2_to_Fp16(Location loc,
ConversionPatternRewriter &rewriter,
const SmallVector<Value> &v) {
auto fp8x4VecTy = vec_ty(i8_ty, 4);
Value a0 = undef(fp8x4VecTy);
a0 = insert_element(fp8x4VecTy, a0, int_val(8, 0), i32_val(0));
Expand All @@ -89,9 +88,9 @@ Fp8E5M2_to_Fp16_func(Location loc, ConversionPatternRewriter &rewriter,
extract_element(f16_ty, fp16x2Vec1, i32_val(1))};
}

static SmallVector<Value>
Fp8E5M2_to_Bf16_func(Location loc, ConversionPatternRewriter &rewriter,
const SmallVector<Value> &v) {
static SmallVector<Value> Fp8E5M2_to_Bf16(Location loc,
ConversionPatternRewriter &rewriter,
const SmallVector<Value> &v) {
auto fp8x4VecTy = vec_ty(i8_ty, 4);
Value a0 = undef(fp8x4VecTy);
a0 = insert_element(fp8x4VecTy, a0, int_val(8, 0), i32_val(0));
Expand Down Expand Up @@ -178,9 +177,9 @@ Fp8E5M2_to_Bf16_func(Location loc, ConversionPatternRewriter &rewriter,
extract_element(bf16_ty, bf16x2Vec1, i32_val(1))};
}

static SmallVector<Value>
Bf16_to_Fp8E5M2_func(Location loc, ConversionPatternRewriter &rewriter,
const SmallVector<Value> &v) {
static SmallVector<Value> Bf16_to_Fp8E5M2(Location loc,
ConversionPatternRewriter &rewriter,
const SmallVector<Value> &v) {
auto bf16x2VecTy = vec_ty(bf16_ty, 2);
Value bf16x2Vec0 = undef(bf16x2VecTy);
Value bf16x2Vec1 = undef(bf16x2VecTy);
Expand Down Expand Up @@ -259,8 +258,8 @@ Bf16_to_Fp8E5M2_func(Location loc, ConversionPatternRewriter &rewriter,
}

static SmallVector<Value>
Bf16_to_Fp8E5M2_RTNE_func(Location loc, ConversionPatternRewriter &rewriter,
const SmallVector<Value> &v) {
Bf16_to_Fp8E5M2_RTNE(Location loc, ConversionPatternRewriter &rewriter,
const SmallVector<Value> &v) {
Value val = zext(i32_ty, bitcast(v[0], i16_ty));
Value sign = and_(i32_ty, val, i32_val(0x8000));
Value nosign = and_(i32_ty, val, i32_val(0x7fff));
Expand Down Expand Up @@ -320,8 +319,8 @@ Bf16_to_Fp8E5M2_RTNE_func(Location loc, ConversionPatternRewriter &rewriter,
// - has multiple nans (when all exponent bits are 1)
// - has an exponent bias of 15 (vs. 7 for fp8e4m3)
static SmallVector<Value>
Fp8E4M3B15_to_Fp16_func(Location loc, ConversionPatternRewriter &rewriter,
const SmallVector<Value> &v) {
Fp8E4M3B15_to_Fp16(Location loc, ConversionPatternRewriter &rewriter,
const SmallVector<Value> &v) {
auto fp8x4VecTy = vec_ty(i8_ty, 4);
Value a0 = undef(fp8x4VecTy);
a0 = insert_element(fp8x4VecTy, a0, int_val(8, 0), i32_val(0));
Expand Down Expand Up @@ -357,8 +356,8 @@ Fp8E4M3B15_to_Fp16_func(Location loc, ConversionPatternRewriter &rewriter,
}

static SmallVector<Value>
Fp16_to_Fp8E4M3B15_func(Location loc, ConversionPatternRewriter &rewriter,
const SmallVector<Value> &v) {
Fp16_to_Fp8E4M3B15(Location loc, ConversionPatternRewriter &rewriter,
const SmallVector<Value> &v) {
auto fp16x2VecTy = vec_ty(f16_ty, 2);
Value fp16x2Vec0 = undef(fp16x2VecTy);
Value fp16x2Vec1 = undef(fp16x2VecTy);
Expand Down Expand Up @@ -404,9 +403,9 @@ Fp16_to_Fp8E4M3B15_func(Location loc, ConversionPatternRewriter &rewriter,
// has more than a single NaN values.

// Fp8E4M3 -> Fp16 (packed)
static SmallVector<Value>
Fp8E4M3Nv_to_Fp16_func(Location loc, ConversionPatternRewriter &rewriter,
const SmallVector<Value> &v) {
static SmallVector<Value> Fp8E4M3Nv_to_Fp16(Location loc,
ConversionPatternRewriter &rewriter,
const SmallVector<Value> &v) {
auto fp8x4VecTy = vec_ty(i8_ty, 4);
Value a0 = undef(fp8x4VecTy);
a0 = insert_element(fp8x4VecTy, a0, int_val(8, 0), i32_val(0));
Expand Down Expand Up @@ -478,9 +477,9 @@ Fp8E4M3Nv_to_Fp16_func(Location loc, ConversionPatternRewriter &rewriter,
}

// Fp16 -> Fp8E4M3 (packed)
static SmallVector<Value>
Fp16_to_Fp8E4M3Nv_func(Location loc, ConversionPatternRewriter &rewriter,
const SmallVector<Value> &v) {
static SmallVector<Value> Fp16_to_Fp8E4M3Nv(Location loc,
ConversionPatternRewriter &rewriter,
const SmallVector<Value> &v) {
auto fp16x2VecTy = vec_ty(f16_ty, 2);
Value fp16x2Vec0 = undef(fp16x2VecTy);

Expand All @@ -503,8 +502,8 @@ Fp16_to_Fp8E4M3Nv_func(Location loc, ConversionPatternRewriter &rewriter,
}

static SmallVector<Value>
Fp16_to_Fp8E4M3Nv_RTNE_func(Location loc, ConversionPatternRewriter &rewriter,
const SmallVector<Value> &v) {
Fp16_to_Fp8E4M3Nv_RTNE(Location loc, ConversionPatternRewriter &rewriter,
const SmallVector<Value> &v) {
Value val = zext(i32_ty, bitcast(v[0], i16_ty));
Value sign = and_(i32_ty, val, i32_val(0x8000));
Value nosign = and_(i32_ty, val, i32_val(0x7fff));
Expand Down Expand Up @@ -556,9 +555,9 @@ Fp16_to_Fp8E4M3Nv_RTNE_func(Location loc, ConversionPatternRewriter &rewriter,
return {extract_element(i8_ty, res, i32_val(1))};
}

static SmallVector<Value>
Fp8E4M3Nv_to_Bf16_func(Location loc, ConversionPatternRewriter &rewriter,
const SmallVector<Value> &v) {
static SmallVector<Value> Fp8E4M3Nv_to_Bf16(Location loc,
ConversionPatternRewriter &rewriter,
const SmallVector<Value> &v) {
auto fp8x4VecTy = vec_ty(i8_ty, 4);
Value a0 = undef(fp8x4VecTy);
a0 = insert_element(fp8x4VecTy, a0, int_val(8, 0), i32_val(0));
Expand Down Expand Up @@ -656,9 +655,9 @@ Fp8E4M3Nv_to_Bf16_func(Location loc, ConversionPatternRewriter &rewriter,
extract_element(bf16_ty, bf16x2Vec1, i32_val(1))};
}

static SmallVector<Value>
Bf16_to_Fp8E4M3Nv_func(Location loc, ConversionPatternRewriter &rewriter,
const SmallVector<Value> &v) {
static SmallVector<Value> Bf16_to_Fp8E4M3Nv(Location loc,
ConversionPatternRewriter &rewriter,
const SmallVector<Value> &v) {
auto bf16x2VecTy = vec_ty(bf16_ty, 2);
Value bf16x2Vec0 = undef(bf16x2VecTy);
Value bf16x2Vec1 = undef(bf16x2VecTy);
Expand Down Expand Up @@ -737,8 +736,8 @@ Bf16_to_Fp8E4M3Nv_func(Location loc, ConversionPatternRewriter &rewriter,
}

static SmallVector<Value>
Bf16_to_Fp8E4M3Nv_RTNE_func(Location loc, ConversionPatternRewriter &rewriter,
const SmallVector<Value> &v) {
Bf16_to_Fp8E4M3Nv_RTNE(Location loc, ConversionPatternRewriter &rewriter,
const SmallVector<Value> &v) {
Value val = zext(i32_ty, bitcast(v[0], i16_ty));
Value sign = and_(i32_ty, val, i32_val(0x8000));
Value nosign = and_(i32_ty, val, i32_val(0x7fff));
Expand Down Expand Up @@ -790,9 +789,9 @@ Bf16_to_Fp8E4M3Nv_RTNE_func(Location loc, ConversionPatternRewriter &rewriter,
return {extract_element(i8_ty, res, i32_val(1))};
}

static SmallVector<Value> Bf16_to_Fp16_func(Location loc,
ConversionPatternRewriter &rewriter,
const SmallVector<Value> &v) {
static SmallVector<Value> Bf16_to_Fp16(Location loc,
ConversionPatternRewriter &rewriter,
const SmallVector<Value> &v) {
auto bf16x2VecTy = vec_ty(bf16_ty, 2);

Value bf16x2Vec = undef(bf16x2VecTy);
Expand Down Expand Up @@ -997,39 +996,34 @@ struct FpToFpOpConversion
std::pair<ConverterT, size_t>>
srcMap = {
// F8 -> F16
{{F8E4M3B15TyID, F16TyID, undefRounding},
{Fp8E4M3B15_to_Fp16_func, 4}},
{{F8E4M3TyID, F16TyID, undefRounding}, {Fp8E4M3Nv_to_Fp16_func, 2}},
{{F8E5M2TyID, F16TyID, undefRounding}, {Fp8E5M2_to_Fp16_func, 4}},
{{F8E4M3B15TyID, F16TyID, undefRounding}, {Fp8E4M3B15_to_Fp16, 4}},
{{F8E4M3TyID, F16TyID, undefRounding}, {Fp8E4M3Nv_to_Fp16, 2}},
{{F8E5M2TyID, F16TyID, undefRounding}, {Fp8E5M2_to_Fp16, 4}},
// F16 -> F8
{{F16TyID, F8E4M3B15TyID, RoundingMode::RTZ},
{Fp16_to_Fp8E4M3B15_func, 4}},
{Fp16_to_Fp8E4M3B15, 4}},
{{F16TyID, F8E4M3B15TyID, RoundingMode::RTNE},
// TODO: provide proper implementation for RTNE rounding.
{Fp16_to_Fp8E4M3B15_func, 4}},
{{F16TyID, F8E4M3TyID, RoundingMode::RTZ},
{Fp16_to_Fp8E4M3Nv_func, 2}},
{Fp16_to_Fp8E4M3B15, 4}},
{{F16TyID, F8E4M3TyID, RoundingMode::RTZ}, {Fp16_to_Fp8E4M3Nv, 2}},
{{F16TyID, F8E4M3TyID, RoundingMode::RTNE},
{Fp16_to_Fp8E4M3Nv_RTNE_func, 1}},
{Fp16_to_Fp8E4M3Nv_RTNE, 1}},
{{F16TyID, F8E5M2TyID, RoundingMode::RTZ},
{Fp16_to_Fp8E5M2_func, 4}},
{Fp16_to_Fp8E5M2_RTZ, 4}},
{{F16TyID, F8E5M2TyID, RoundingMode::RTNE},
{Fp16_to_Fp8E5M2_RTNE_func, 1}},
{Fp16_to_Fp8E5M2_RTNE, 1}},
// F8 -> BF16
{{F8E5M2TyID, BF16TyID, undefRounding}, {Fp8E5M2_to_Bf16_func, 4}},
{{F8E4M3TyID, BF16TyID, undefRounding},
{Fp8E4M3Nv_to_Bf16_func, 4}},
{{F8E5M2TyID, BF16TyID, undefRounding}, {Fp8E5M2_to_Bf16, 4}},
{{F8E4M3TyID, BF16TyID, undefRounding}, {Fp8E4M3Nv_to_Bf16, 4}},
// BF16 -> F8
{{BF16TyID, F8E5M2TyID, RoundingMode::RTZ},
{Bf16_to_Fp8E5M2_func, 4}},
{{BF16TyID, F8E5M2TyID, RoundingMode::RTZ}, {Bf16_to_Fp8E5M2, 4}},
{{BF16TyID, F8E5M2TyID, RoundingMode::RTNE},
{Bf16_to_Fp8E5M2_RTNE_func, 1}},
{{BF16TyID, F8E4M3TyID, RoundingMode::RTZ},
{Bf16_to_Fp8E4M3Nv_func, 4}},
{Bf16_to_Fp8E5M2_RTNE, 1}},
{{BF16TyID, F8E4M3TyID, RoundingMode::RTZ}, {Bf16_to_Fp8E4M3Nv, 4}},
{{BF16TyID, F8E4M3TyID, RoundingMode::RTNE},
{Bf16_to_Fp8E4M3Nv_RTNE_func, 1}},
{Bf16_to_Fp8E4M3Nv_RTNE, 1}},
// BF16 -> F16
{{BF16TyID, F16TyID, undefRounding}, {Bf16_to_Fp16_func, 2}},
{{BF16TyID, F16TyID, undefRounding}, {Bf16_to_Fp16, 2}},
};

std::tuple<TypeID, TypeID, RoundingMode> key = {
Expand Down Expand Up @@ -1097,6 +1091,7 @@ struct FpToFpOpConversion
auto [cvtFunc, numElements] =
getConversionFunc(srcType, dstType, roundingMode);
SmallVector<Value> inVals;
inVals.reserve(std::min(numElements, operands.size()));
for (unsigned i = 0; i < std::min(numElements, operands.size()); i++) {
inVals.push_back(operands[i][0]);
}
Expand Down Expand Up @@ -1323,9 +1318,8 @@ struct TruncFOpConversion
return {// Trunc uses the default rounding mode: RTNE
intel::convertFp32ToBf16(loc, rewriter, operands[0][0],
RoundingMode::RTNE)};
} else {
return {rewriter.create<LLVM::FPTruncOp>(loc, elemTy, operands[0][0])};
}
return {rewriter.create<LLVM::FPTruncOp>(loc, elemTy, operands[0][0])};
}
};

Expand Down Expand Up @@ -1488,7 +1482,6 @@ void populateElementwiseOpToLLVMPatterns(
LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
ModuleAxisInfoAnalysis &axisInfoAnalysis, const TargetInfoBase &targetInfo,
PatternBenefit benefit) {
using namespace mlir::triton::gpu;

patterns.add<PreciseSqrtOpConversion>(typeConverter, axisInfoAnalysis,
benefit);
Expand All @@ -1511,7 +1504,6 @@ void populateElementwiseOpToLLVMPatterns(
patterns.add<ExtFOpConversion>(typeConverter, axisInfoAnalysis, benefit);
patterns.add<TruncFOpConversion>(typeConverter, axisInfoAnalysis, benefit);
patterns.add<FPToSIOpConversion>(typeConverter, axisInfoAnalysis, benefit);

patterns.add<SIToFPOpConversion>(typeConverter, axisInfoAnalysis, benefit);
patterns.add<FpToFpOpConversion>(typeConverter, axisInfoAnalysis, benefit);

Expand Down

0 comments on commit d16a1dd

Please sign in to comment.