diff --git a/third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp b/third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp index 0085764da7..8cbedecae1 100644 --- a/third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp +++ b/third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp @@ -12,6 +12,7 @@ #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" #include "mlir/Dialect/LLVMIR/LLVMAttrs.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" @@ -44,20 +45,6 @@ using namespace mlir::triton; // Helper Functions //===----------------------------------------------------------------------===// -static LLVM::LLVMFuncOp -getOrCreateFunction(StringRef funcName, Type retType, ArrayRef argTypes, - ModuleOp moduleOp, Location loc, - ConversionPatternRewriter &rewriter) { - Operation *funcOp = moduleOp.lookupSymbol(funcName); - if (funcOp) - return cast(funcOp); - - auto funcType = LLVM::LLVMFunctionType::get(retType, argTypes); - ConversionPatternRewriter::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(moduleOp.getBody()); - return rewriter.create(loc, funcName, funcType); -}; - static LLVM::CallOp createDeviceFunctionCall( ConversionPatternRewriter &rewriter, StringRef funcName, Type retType, ArrayRef argTypes, ArrayRef args, bool convergent = false) { @@ -68,7 +55,7 @@ static LLVM::CallOp createDeviceFunctionCall( rewriter.getArrayAttr(StringAttr::get(context, "convergent")); LLVM::LLVMFuncOp funcOp = - getOrCreateFunction(funcName, retType, argTypes, moduleOp, loc, rewriter); + LLVM::lookupOrCreateFn(moduleOp, funcName, argTypes, retType); funcOp.setCConv(LLVM::cconv::CConv::SPIR_FUNC); if (convergent) funcOp.setPassthroughAttr(convergentAttr); @@ -215,10 +202,10 @@ static LLVM::CallOp createGenISADPAS(TritonGEN::MatrixDPASOp op, std::string funcName = llvm::GenISAIntrinsic::getName( llvm::GenISAIntrinsic::GenISA_sub_group_dpas, llvmTypes); - LLVM::LLVMFuncOp funcOp = getOrCreateFunction( - funcName, resType, - {opTypes[0], aTy, bTy, int32Ty, int32Ty, int32Ty, int32Ty, int1Ty}, - moduleOp, loc, rewriter); + ArrayRef argTypes{opTypes[0], aTy, bTy, int32Ty, + int32Ty, int32Ty, int32Ty, int1Ty}; + LLVM::LLVMFuncOp funcOp = + LLVM::lookupOrCreateFn(moduleOp, funcName, argTypes, resType); funcOp.setCConv(LLVM::cconv::CConv::SPIR_FUNC); auto precA = rewriter.create(loc, int32Ty,