From a05990485c7000036ab34fe93a5ed0cbeb93b997 Mon Sep 17 00:00:00 2001 From: Whitney Tsang Date: Thu, 25 Apr 2024 16:04:47 +0000 Subject: [PATCH] Revert "Replace GenISA ftof usage with arith::TruncFOp" This reverts commit 548612eddd9a56c1d9477b2f6c5ed97cee79f3a9. --- .../ElementwiseOpToLLVM.cpp | 32 +++++++++++++++++-- 1 file changed, 30 insertions(+), 2 deletions(-) diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/ElementwiseOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/ElementwiseOpToLLVM.cpp index 6d79c9606c..67d6da8aef 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/ElementwiseOpToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -1,4 +1,5 @@ #include "PatternTritonGPUOpToLLVM.h" +#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/IR/MLIRContext.h" #include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" @@ -1329,6 +1330,33 @@ struct FpToFpOpConversion ConversionPatternRewriter &rewriter, const Value &v, const RoundingMode rounding) { MLIRContext *ctx = rewriter.getContext(); + + LLVM::RoundingMode roundingMode; + switch (rounding) { + case RoundingMode::RTNE: + roundingMode = LLVM::RoundingMode::NearestTiesToEven; + break; + case RoundingMode::RTZ: + roundingMode = LLVM::RoundingMode::TowardZero; + break; + default: + llvm::errs() << "WARNING: unsupported rounding mode for f32->f16 " + "conversion: " + << stringifyRoundingMode(rounding) << "\n"; + llvm_unreachable(""); + } + + NamedAttrList convertedAttr; + convertedAttr.set(LLVM::ConstrainedFPTruncIntr::getRoundingModeAttrName(), + LLVM::RoundingModeAttr::get(ctx, roundingMode)); + convertedAttr.set( + LLVM::ConstrainedFPTruncIntr::getFPExceptionBehaviorAttrName(), + arith::getLLVMDefaultFPExceptionBehavior(*ctx)); + return rewriter.create(loc, f16_ty, v, + convertedAttr); + +#if 0 + // FIXME: test_typeconvert_downcast fails when lower to arith::TruncFOp. arith::RoundingMode roundingMode; switch (rounding) { case RoundingMode::RTNE: @@ -1345,6 +1373,7 @@ struct FpToFpOpConversion } return rewriter.create( loc, f16_ty, v, arith::RoundingModeAttr::get(ctx, roundingMode)); +#endif } std::pair @@ -1994,10 +2023,9 @@ struct TruncFOpConversion return {// Trunc uses the default rounding mode: RTNE FpToFpOpConversion::convertFp32ToBf16( loc, rewriter, operands[0][0], RoundingMode::RTNE)}; - } else if (!op.getRoundingModeAttr()) { + } else { return {rewriter.create(loc, elemTy, operands[0][0])}; } - return {}; } };