-
Notifications
You must be signed in to change notification settings - Fork 46
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge commit 'd04f28864d1c1e6a3e0d6f16c4aa701c84310d4a'
- Loading branch information
Showing
44 changed files
with
896 additions
and
1,162 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" | ||
#include "triton/Conversion/TritonGPUToLLVM/Utility.h" | ||
|
||
namespace mlir { | ||
FailureOr<LLVM::LLVMFuncOp> | ||
convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp, | ||
ConversionPatternRewriter &rewriter, | ||
const LLVMTypeConverter &converter); | ||
} | ||
|
||
namespace { | ||
|
||
using namespace mlir; | ||
using namespace mlir::triton; | ||
|
||
/// FuncOp legalization pattern that converts MemRef arguments to pointers to | ||
/// MemRef descriptors (LLVM struct data types) containing all the MemRef type | ||
/// information. | ||
struct FuncOpConversion : public ConvertOpToLLVMPattern<triton::FuncOp> { | ||
FuncOpConversion(LLVMTypeConverter &converter, int numWarps, | ||
PatternBenefit benefit) | ||
: ConvertOpToLLVMPattern(converter, benefit), numWarps(numWarps) {} | ||
|
||
/// Only retain those attributes that are not constructed by | ||
/// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument | ||
/// attributes. | ||
static void filterFuncAttributes(triton::FuncOp op, bool filterArgAttrs, | ||
SmallVectorImpl<NamedAttribute> &result) { | ||
|
||
for (const auto &attr : op->getAttrs()) { | ||
if (attr.getName() == SymbolTable::getSymbolAttrName() || | ||
attr.getName() == op.getFunctionTypeAttrName() || | ||
attr.getName() == "std.varargs" || | ||
(filterArgAttrs && attr.getName() == op.getArgAttrsAttrName())) | ||
continue; | ||
result.push_back(attr); | ||
} | ||
} | ||
|
||
triton::FuncOp amendFuncOp(triton::FuncOp funcOp, | ||
ConversionPatternRewriter &rewriter) const { | ||
// Push back a variable that indicates the current stack pointer of shared | ||
// memory to the function arguments. | ||
auto loc = funcOp.getLoc(); | ||
auto ctx = funcOp->getContext(); | ||
auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), 3); | ||
// 1. Modify the function type to add the new argument. | ||
auto funcTy = funcOp.getFunctionType(); | ||
auto amendedInputTy = llvm::to_vector<4>(funcTy.getInputs()); | ||
amendedInputTy.push_back(ptrTy); | ||
auto amendedFuncTy = FunctionType::get(funcTy.getContext(), amendedInputTy, | ||
funcTy.getResults()); | ||
// 2. Modify the argument attributes to add the new argument. | ||
SmallVector<NamedAttribute> amendedAttrs; | ||
filterFuncAttributes(funcOp, /*filterArgAttrs=*/true, amendedAttrs); | ||
auto amendedArgAttrs = llvm::to_vector<4>(funcOp.getAllArgAttrs()); | ||
amendedArgAttrs.emplace_back(DictionaryAttr::get(ctx)); | ||
amendedAttrs.push_back(rewriter.getNamedAttr( | ||
funcOp.getArgAttrsAttrName(), rewriter.getArrayAttr(amendedArgAttrs))); | ||
// 3. Add a new argument to the region | ||
auto amendedFuncOp = rewriter.create<triton::FuncOp>( | ||
funcOp.getLoc(), funcOp.getName(), amendedFuncTy, amendedAttrs); | ||
auto ®ion = funcOp.getBody(); | ||
region.addArgument(ptrTy, loc); | ||
rewriter.inlineRegionBefore(region, amendedFuncOp.getBody(), | ||
amendedFuncOp.end()); | ||
return amendedFuncOp; | ||
} | ||
|
||
LogicalResult | ||
matchAndRewrite(triton::FuncOp funcOp, OpAdaptor adaptor, | ||
ConversionPatternRewriter &rewriter) const override { | ||
// Prevent LLVM's inliner to inline this function | ||
auto amendedFuncOp = funcOp; | ||
if (!LLVM::isKernel(funcOp)) | ||
amendedFuncOp = amendFuncOp(funcOp, rewriter); | ||
|
||
LLVM::LLVMFuncOp newFuncOp = *mlir::convertFuncOpToLLVMFuncOp( | ||
amendedFuncOp, rewriter, *getTypeConverter()); | ||
if (!newFuncOp) { | ||
return failure(); | ||
} | ||
|
||
auto ctx = funcOp->getContext(); | ||
|
||
if (LLVM::isKernel(funcOp)) { | ||
// Set an attribute to indicate this function is a kernel entry. | ||
newFuncOp->setAttr("nvvm.kernel", | ||
rewriter.getIntegerAttr(type::u1Ty(ctx), 1)); | ||
} else { | ||
// The noinline attribute will be used by the LLVM codegen to prevent | ||
// inlining. | ||
// https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/LLVMIR/IR/LLVMInlining.cpp#L267 | ||
newFuncOp.setPassthroughAttr( | ||
ArrayAttr::get(ctx, rewriter.getStringAttr("noinline"))); | ||
rewriter.eraseOp(amendedFuncOp); | ||
} | ||
// Set an attribute for maxntidx, it could be used in latter LLVM codegen | ||
// for `nvvm.annotation` metadata. | ||
newFuncOp->setAttr("nvvm.maxntid", | ||
rewriter.getDenseI32ArrayAttr(32 * numWarps)); | ||
rewriter.eraseOp(funcOp); | ||
return success(); | ||
} | ||
|
||
private: | ||
int numWarps{0}; | ||
}; | ||
|
||
} // namespace | ||
|
||
void mlir::triton::populateFuncOpConversionPattern( | ||
LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, int numWarps, | ||
PatternBenefit benefit) { | ||
patterns.add<FuncOpConversion>(typeConverter, numWarps, benefit); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" | ||
#include "triton/Conversion/TritonGPUToLLVM/Utility.h" | ||
|
||
namespace { | ||
|
||
using namespace mlir; | ||
using namespace mlir::triton; | ||
|
||
struct GetProgramIdOpConversion | ||
: public ConvertOpToLLVMPattern<triton::GetProgramIdOp> { | ||
explicit GetProgramIdOpConversion(LLVMTypeConverter &typeConverter, | ||
const TargetInfoBase &targetInfo, | ||
PatternBenefit benefit = 1) | ||
: ConvertOpToLLVMPattern<triton::GetProgramIdOp>(typeConverter, benefit), | ||
targetInfo(targetInfo) {} | ||
|
||
LogicalResult | ||
matchAndRewrite(triton::GetProgramIdOp op, OpAdaptor adaptor, | ||
ConversionPatternRewriter &rewriter) const override { | ||
Value programId = targetInfo.programId(op->getLoc(), rewriter, | ||
op->getParentOfType<ModuleOp>(), | ||
op.getAxisAsInt()); | ||
rewriter.replaceOp(op, programId); | ||
return success(); | ||
} | ||
|
||
private: | ||
const TargetInfoBase &targetInfo; | ||
}; | ||
|
||
} // namespace | ||
|
||
void mlir::triton::populateSPMDOpToLLVMPattern(LLVMTypeConverter &typeConverter, | ||
RewritePatternSet &patterns, | ||
const TargetInfoBase &targetInfo, | ||
PatternBenefit benefit) { | ||
patterns.add<GetProgramIdOpConversion>(typeConverter, targetInfo, benefit); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.