From 548612eddd9a56c1d9477b2f6c5ed97cee79f3a9 Mon Sep 17 00:00:00 2001 From: Whitney Tsang Date: Thu, 25 Apr 2024 02:30:49 +0000 Subject: [PATCH] Replace GenISA ftof usage with arith::TruncFOp Signed-off-by: Whitney Tsang --- .../ElementwiseOpToLLVM.cpp | 32 ++----------------- 1 file changed, 2 insertions(+), 30 deletions(-) diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/ElementwiseOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/ElementwiseOpToLLVM.cpp index 67d6da8aef..6d79c9606c 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/ElementwiseOpToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -1,5 +1,4 @@ #include "PatternTritonGPUOpToLLVM.h" -#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/IR/MLIRContext.h" #include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" @@ -1330,33 +1329,6 @@ struct FpToFpOpConversion ConversionPatternRewriter &rewriter, const Value &v, const RoundingMode rounding) { MLIRContext *ctx = rewriter.getContext(); - - LLVM::RoundingMode roundingMode; - switch (rounding) { - case RoundingMode::RTNE: - roundingMode = LLVM::RoundingMode::NearestTiesToEven; - break; - case RoundingMode::RTZ: - roundingMode = LLVM::RoundingMode::TowardZero; - break; - default: - llvm::errs() << "WARNING: unsupported rounding mode for f32->f16 " - "conversion: " - << stringifyRoundingMode(rounding) << "\n"; - llvm_unreachable(""); - } - - NamedAttrList convertedAttr; - convertedAttr.set(LLVM::ConstrainedFPTruncIntr::getRoundingModeAttrName(), - LLVM::RoundingModeAttr::get(ctx, roundingMode)); - convertedAttr.set( - LLVM::ConstrainedFPTruncIntr::getFPExceptionBehaviorAttrName(), - arith::getLLVMDefaultFPExceptionBehavior(*ctx)); - return rewriter.create(loc, f16_ty, v, - convertedAttr); - -#if 0 - // FIXME: test_typeconvert_downcast fails when lower to arith::TruncFOp. arith::RoundingMode roundingMode; switch (rounding) { case RoundingMode::RTNE: @@ -1373,7 +1345,6 @@ struct FpToFpOpConversion } return rewriter.create( loc, f16_ty, v, arith::RoundingModeAttr::get(ctx, roundingMode)); -#endif } std::pair @@ -2023,9 +1994,10 @@ struct TruncFOpConversion return {// Trunc uses the default rounding mode: RTNE FpToFpOpConversion::convertFp32ToBf16( loc, rewriter, operands[0][0], RoundingMode::RTNE)}; - } else { + } else if (!op.getRoundingModeAttr()) { return {rewriter.create(loc, elemTy, operands[0][0])}; } + return {}; } };