Skip to content

Commit

Permalink
Address code review comments
Browse files Browse the repository at this point in the history
Signed-off-by: Tiotto, Ettore <ettore.tiotto@intel.com>
  • Loading branch information
etiotto committed Mar 11, 2024
1 parent 7cae167 commit 32dbaa8
Showing 1 changed file with 6 additions and 19 deletions.
25 changes: 6 additions & 19 deletions third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
Expand Down Expand Up @@ -44,20 +45,6 @@ using namespace mlir::triton;
// Helper Functions
//===----------------------------------------------------------------------===//

static LLVM::LLVMFuncOp
getOrCreateFunction(StringRef funcName, Type retType, ArrayRef<Type> argTypes,
ModuleOp moduleOp, Location loc,
ConversionPatternRewriter &rewriter) {
Operation *funcOp = moduleOp.lookupSymbol(funcName);
if (funcOp)
return cast<LLVM::LLVMFuncOp>(funcOp);

auto funcType = LLVM::LLVMFunctionType::get(retType, argTypes);
ConversionPatternRewriter::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(moduleOp.getBody());
return rewriter.create<LLVM::LLVMFuncOp>(loc, funcName, funcType);
};

static LLVM::CallOp createDeviceFunctionCall(
ConversionPatternRewriter &rewriter, StringRef funcName, Type retType,
ArrayRef<Type> argTypes, ArrayRef<Value> args, bool convergent = false) {
Expand All @@ -68,7 +55,7 @@ static LLVM::CallOp createDeviceFunctionCall(
rewriter.getArrayAttr(StringAttr::get(context, "convergent"));

LLVM::LLVMFuncOp funcOp =
getOrCreateFunction(funcName, retType, argTypes, moduleOp, loc, rewriter);
LLVM::lookupOrCreateFn(moduleOp, funcName, argTypes, retType);
funcOp.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
if (convergent)
funcOp.setPassthroughAttr(convergentAttr);
Expand Down Expand Up @@ -215,10 +202,10 @@ static LLVM::CallOp createGenISADPAS(TritonGEN::MatrixDPASOp op,
std::string funcName = llvm::GenISAIntrinsic::getName(
llvm::GenISAIntrinsic::GenISA_sub_group_dpas, llvmTypes);

LLVM::LLVMFuncOp funcOp = getOrCreateFunction(
funcName, resType,
{opTypes[0], aTy, bTy, int32Ty, int32Ty, int32Ty, int32Ty, int1Ty},
moduleOp, loc, rewriter);
ArrayRef<Type> argTypes{opTypes[0], aTy, bTy, int32Ty,
int32Ty, int32Ty, int32Ty, int1Ty};
LLVM::LLVMFuncOp funcOp =
LLVM::lookupOrCreateFn(moduleOp, funcName, argTypes, resType);
funcOp.setCConv(LLVM::cconv::CConv::SPIR_FUNC);

auto precA = rewriter.create<LLVM::ConstantOp>(loc, int32Ty,
Expand Down

0 comments on commit 32dbaa8

Please sign in to comment.