Skip to content

Commit

Permalink
Merge commit 'd04f28864d1c1e6a3e0d6f16c4aa701c84310d4a'
Browse files Browse the repository at this point in the history
  • Loading branch information
whitneywhtsang committed Mar 18, 2024
2 parents 66b5079 + d04f288 commit 0469c40
Show file tree
Hide file tree
Showing 44 changed files with 896 additions and 1,162 deletions.
1 change: 1 addition & 0 deletions .github/workflows/test-backends.yml
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ jobs:
- name: Install Triton on ROCM
run: |
pip install --force-reinstall numpy==1.22.4
pip uninstall -y triton
cd python
pip install -e .
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,19 @@ void populateConvertLayoutOpToLLVMPatterns(LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns,
PatternBenefit benefit);

void populateControlFlowOpToLLVMPattern(LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns,
PatternBenefit benefit);

void populateSPMDOpToLLVMPattern(LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns,
const TargetInfoBase &targetInfo,
PatternBenefit benefit);

void populateFuncOpConversionPattern(LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns, int numWarps,
PatternBenefit benefit);

} // namespace triton
} // namespace mlir

Expand Down
2 changes: 2 additions & 0 deletions include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ class TargetInfoBase {
Value val, int i) const = 0;
virtual Value shuffleIdx(Location loc, ConversionPatternRewriter &rewriter,
Value val, Value i) const = 0;
virtual Value programId(Location loc, ConversionPatternRewriter &rewriter,
ModuleOp moduleOp, int axis) const = 0;
virtual bool warpReduce(ConversionPatternRewriter &rewriter, Location loc,
SmallVector<Value> &acc, triton::ReduceOp op,
unsigned numLaneToReduce) const = 0;
Expand Down
71 changes: 71 additions & 0 deletions include/triton/Conversion/TritonGPUToLLVM/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,7 @@ using LLVM::SharedMemoryObject;
using ::mlir::LLVM::delinearize;
using ::mlir::LLVM::SharedMemoryObject;
using ::mlir::triton::gpu::AMDMfmaEncodingAttr;
using ::mlir::triton::gpu::AMDWmmaEncodingAttr;
using ::mlir::triton::gpu::BlockedEncodingAttr;
using ::mlir::triton::gpu::CTALayoutAttr;
using ::mlir::triton::gpu::DotOperandEncodingAttr;
Expand Down Expand Up @@ -848,6 +849,71 @@ emitOffsetForMfmaLayout(const AMDMfmaEncodingAttr &mfmaLayout,
return offsets;
}

// Appends to `offsets` the per-thread element offsets for one CTA tile of an
// AMD WMMA layout, placed at tile coordinates (ctaOffsetX, ctaOffsetY).
// Each call contributes 8 offsets, strided by 2 along dim 0; dim 1 is fixed
// to the tile's base column.
static void emitWmmaOffsetForCTA(const AMDWmmaEncodingAttr &wmmaLayout,
                                 SmallVector<SmallVector<unsigned>> &offsets,
                                 unsigned ctaOffsetX, unsigned ctaOffsetY) {
  constexpr unsigned kElemsPerThreadPerGroup = 8;
  [[maybe_unused]] auto warpSize = getWarpSize(wmmaLayout);
  assert(warpSize == 32);
  auto tileShape = getShapePerCTATile(wmmaLayout);
  const unsigned baseX = ctaOffsetX * tileShape[0];
  const unsigned baseY = ctaOffsetY * tileShape[1];
  for (unsigned e = 0; e < kElemsPerThreadPerGroup; ++e)
    offsets.push_back({baseX + 2 * e, baseY});
}

// Computes the (dim0, dim1) base index of the first tensor element owned by
// the current thread under an AMD WMMA layout.
//
// Warps are laid out row-major over warpsPerCTA. Within a warp, lanes in the
// upper half repeat the lower half's lane index (laneId = tid % (warpSize/2)),
// and the dim0 coordinate advances by one for every mnkDim[2] lanes.
//
// Fixes vs. previous version: removed the unused local `shape` and queried
// the warp size once instead of twice. Behavior is otherwise unchanged.
static SmallVector<Value>
emitBaseIndexForWmmaLayout(Location loc, RewriterBase &rewriter,
                           const AMDWmmaEncodingAttr &wmmaLayout,
                           RankedTensorType type) {
  auto _warpsPerCTA = wmmaLayout.getWarpsPerCTA();
  assert(_warpsPerCTA.size() == 2);
  SmallVector<Value> warpsPerCTA = {i32_val(_warpsPerCTA[0]),
                                    i32_val(_warpsPerCTA[1])};
  auto mnkDim = AMDWmmaEncodingAttr::getMNKDimPerWMMAInstr();

  Value threadId = getThreadId(rewriter, loc);
  // Query the warp size once; reuse the plain integer for the lane mask.
  unsigned iWarpSize = triton::gpu::getWarpSize(wmmaLayout);
  Value warpSize = i32_val(iWarpSize);
  // Both halves of the warp map onto the same lane index range [0, warpSize/2).
  Value laneId = urem(threadId, i32_val(iWarpSize / 2));
  Value threadIdPerWarp = urem(threadId, warpSize);

  // Decompose the linear warp id into a 2D warp coordinate (row-major).
  Value warpId = udiv(threadId, warpSize);
  Value warpId0 = urem(warpId, warpsPerCTA[0]);
  Value warpId1 = urem(udiv(warpId, warpsPerCTA[0]), warpsPerCTA[1]);

  // Per-warp offsets scale by the WMMA instruction's M and N dimensions.
  Value offWarp0 = mul(warpId0, i32_val(mnkDim[0]));
  Value offWarp1 = mul(warpId1, i32_val(mnkDim[1]));

  return {add(udiv(threadIdPerWarp, i32_val(mnkDim[2])), offWarp0),
          add(laneId, offWarp1)};
}

// Enumerates all element offsets covered by a WMMA layout for tensor `type`:
// for every repetition of the per-warp tile along dim0 and dim1, delegate to
// emitWmmaOffsetForCTA to append that tile's per-thread offsets.
static SmallVector<SmallVector<unsigned>>
emitOffsetForWmmaLayout(const AMDWmmaEncodingAttr &wmmaLayout,
                        RankedTensorType type) {
  auto shape = type.getShape();
  auto ctaShape = getShapePerCTA(wmmaLayout, shape);
  auto warpGrid = wmmaLayout.getWarpsPerCTA();
  auto instrDims = AMDWmmaEncodingAttr::getMNKDimPerWMMAInstr();

  // Number of tile repetitions required along one dimension.
  auto repsAlong = [&](unsigned dim) -> unsigned {
    unsigned clipped = std::min<unsigned>(shape[dim], ctaShape[dim]);
    unsigned perWarp = ceil<unsigned>(clipped, warpGrid[dim]);
    return ceil<unsigned>(perWarp, instrDims[dim]);
  };
  const unsigned repsDim0 = repsAlong(0);
  const unsigned repsDim1 = repsAlong(1);

  SmallVector<SmallVector<unsigned>> offsets;
  for (unsigned r = 0; r < repsDim0; ++r)
    for (unsigned c = 0; c < repsDim1; ++c)
      emitWmmaOffsetForCTA(wmmaLayout, offsets, r, c);
  return offsets;
}

static SmallVector<SmallVector<unsigned>>
emitOffsetForLayout(Attribute layout, RankedTensorType type);

Expand Down Expand Up @@ -932,6 +998,8 @@ emitBaseIndexForLayout(Location loc, RewriterBase &rewriter, Attribute layout,
type);
} else if (auto mfmaLayout = layout.dyn_cast<AMDMfmaEncodingAttr>()) {
result = emitBaseIndexForMfmaLayout(loc, rewriter, mfmaLayout, type);
} else if (auto wmmaLayout = layout.dyn_cast<AMDWmmaEncodingAttr>()) {
result = emitBaseIndexForWmmaLayout(loc, rewriter, wmmaLayout, type);
} else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
auto parentLayout = sliceLayout.getParent();
auto parentShape = sliceLayout.paddedShape(type.getShape());
Expand Down Expand Up @@ -969,6 +1037,9 @@ emitOffsetForLayout(Attribute layout, RankedTensorType type) {
if (auto mfmaLayout = layout.dyn_cast<AMDMfmaEncodingAttr>()) {
return emitOffsetForMfmaLayout(mfmaLayout, type);
}
if (auto wmmaLayout = layout.dyn_cast<AMDWmmaEncodingAttr>()) {
return emitOffsetForWmmaLayout(wmmaLayout, type);
}
if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>())
return emitOffsetForSliceLayout(sliceLayout, type);
llvm_unreachable("unsupported emitOffsetForLayout");
Expand Down
9 changes: 8 additions & 1 deletion include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,8 @@ def TTG_LocalAllocOp : TTG_Op<"local_alloc", [DeclareOpInterfaceMethods<MemoryEf
let description = [{
This operation allocates a buffer in shared memory and returns a descriptor
containing the address and a view of the buffer.

Explicitly deallocating a buffer is optional; see local_dealloc.
}];
let arguments = (ins Optional<TT_Tensor>:$init);

Expand All @@ -184,7 +186,12 @@ def TTG_LocalDeallocOp : TTG_Op<"local_dealloc", [MemoryEffects<[MemFree<SharedM
let summary = "dealloc buffer";

let description = [{
This operation deallocate a buffer explicitly. Using the buffer after this operation is undefined.
This operation deallocates a buffer explicitly. Using the buffer after this
operation is undefined.

This operation is optional. If you don't explicitly dealloc a buffer, the
compiler assumes it's deallocated at the first point that post-dominates all
uses of the alloc.
}];

let arguments = (ins TT_MemDescType:$ptr);
Expand Down
3 changes: 3 additions & 0 deletions lib/Conversion/TritonGPUToLLVM/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ add_triton_library(TritonGPUToLLVM
ReduceOpToLLVM.cpp
ScanOpToLLVM.cpp
ConvertLayoutOpToLLVM.cpp
ControlFlowOpToLLVM.cpp
FuncOpToLLVM.cpp
SPMDOpToLLVM.cpp

DEPENDS
TritonGPUConversionPassIncGen
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#include "PatternTritonGPUOpToLLVM.h"
#include "Utility.h"
#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h"
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"

namespace {

Expand Down Expand Up @@ -133,7 +133,7 @@ struct CallOpConversion : public ConvertOpToLLVMPattern<triton::CallOp> {

} // namespace

void mlir::triton::NVIDIA::populateControlFlowOpToLLVMPattern(
void mlir::triton::populateControlFlowOpToLLVMPattern(
LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
PatternBenefit benefit) {
patterns.add<ReturnOpConversion>(typeConverter, benefit);
Expand Down
116 changes: 116 additions & 0 deletions lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h"
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"

namespace mlir {
FailureOr<LLVM::LLVMFuncOp>
convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
ConversionPatternRewriter &rewriter,
const LLVMTypeConverter &converter);
}

namespace {

using namespace mlir;
using namespace mlir::triton;

/// FuncOp legalization pattern that converts MemRef arguments to pointers to
/// MemRef descriptors (LLVM struct data types) containing all the MemRef type
/// information.
/// FuncOp legalization pattern that converts MemRef arguments to pointers to
/// MemRef descriptors (LLVM struct data types) containing all the MemRef type
/// information.
struct FuncOpConversion : public ConvertOpToLLVMPattern<triton::FuncOp> {
  FuncOpConversion(LLVMTypeConverter &converter, int numWarps,
                   PatternBenefit benefit)
      : ConvertOpToLLVMPattern(converter, benefit), numWarps(numWarps) {}

  /// Only retain those attributes that are not constructed by
  /// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument
  /// attributes.
  static void filterFuncAttributes(triton::FuncOp op, bool filterArgAttrs,
                                   SmallVectorImpl<NamedAttribute> &result) {
    for (const auto &attr : op->getAttrs()) {
      if (attr.getName() == SymbolTable::getSymbolAttrName() ||
          attr.getName() == op.getFunctionTypeAttrName() ||
          attr.getName() == "std.varargs" ||
          (filterArgAttrs && attr.getName() == op.getArgAttrsAttrName()))
        continue;
      result.push_back(attr);
    }
  }

  /// Returns a clone of `funcOp` with one extra trailing argument: a pointer
  /// in address space 3 (shared memory) that carries the current shared-memory
  /// stack pointer into non-kernel (device) functions.
  triton::FuncOp amendFuncOp(triton::FuncOp funcOp,
                             ConversionPatternRewriter &rewriter) const {
    // Push back a variable that indicates the current stack pointer of shared
    // memory to the function arguments.
    auto loc = funcOp.getLoc();
    auto ctx = funcOp->getContext();
    auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), 3);
    // 1. Modify the function type to add the new argument.
    auto funcTy = funcOp.getFunctionType();
    auto amendedInputTy = llvm::to_vector<4>(funcTy.getInputs());
    amendedInputTy.push_back(ptrTy);
    auto amendedFuncTy = FunctionType::get(funcTy.getContext(), amendedInputTy,
                                           funcTy.getResults());
    // 2. Modify the argument attributes to add the new argument.
    SmallVector<NamedAttribute> amendedAttrs;
    filterFuncAttributes(funcOp, /*filterArgAttrs=*/true, amendedAttrs);
    // getAllArgAttrs() yields a null ArrayAttr when no argument attributes
    // were ever set; guard so we don't iterate a null attribute.
    SmallVector<Attribute, 4> amendedArgAttrs;
    if (auto argAttrs = funcOp.getAllArgAttrs())
      amendedArgAttrs = llvm::to_vector<4>(argAttrs);
    amendedArgAttrs.emplace_back(DictionaryAttr::get(ctx));
    amendedAttrs.push_back(rewriter.getNamedAttr(
        funcOp.getArgAttrsAttrName(), rewriter.getArrayAttr(amendedArgAttrs)));
    // 3. Add a new argument to the region.
    auto amendedFuncOp = rewriter.create<triton::FuncOp>(
        funcOp.getLoc(), funcOp.getName(), amendedFuncTy, amendedAttrs);
    auto &region = funcOp.getBody();
    region.addArgument(ptrTy, loc);
    rewriter.inlineRegionBefore(region, amendedFuncOp.getBody(),
                                amendedFuncOp.end());
    return amendedFuncOp;
  }

  LogicalResult
  matchAndRewrite(triton::FuncOp funcOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Device (non-kernel) functions receive an extra shared-memory pointer
    // argument; kernels are converted as-is.
    auto amendedFuncOp = funcOp;
    if (!LLVM::isKernel(funcOp))
      amendedFuncOp = amendFuncOp(funcOp, rewriter);

    // Check for failure BEFORE dereferencing: operator* on a failed FailureOr
    // is undefined behavior (asserts in debug builds).
    FailureOr<LLVM::LLVMFuncOp> maybeNewFuncOp =
        mlir::convertFuncOpToLLVMFuncOp(amendedFuncOp, rewriter,
                                        *getTypeConverter());
    if (failed(maybeNewFuncOp))
      return failure();
    LLVM::LLVMFuncOp newFuncOp = *maybeNewFuncOp;

    auto ctx = funcOp->getContext();

    if (LLVM::isKernel(funcOp)) {
      // Set an attribute to indicate this function is a kernel entry.
      newFuncOp->setAttr("nvvm.kernel",
                         rewriter.getIntegerAttr(type::u1Ty(ctx), 1));
    } else {
      // The noinline attribute will be used by the LLVM codegen to prevent
      // inlining.
      // https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/LLVMIR/IR/LLVMInlining.cpp#L267
      newFuncOp.setPassthroughAttr(
          ArrayAttr::get(ctx, rewriter.getStringAttr("noinline")));
      rewriter.eraseOp(amendedFuncOp);
    }
    // Set an attribute for maxntidx, it could be used in latter LLVM codegen
    // for `nvvm.annotation` metadata.
    newFuncOp->setAttr("nvvm.maxntid",
                       rewriter.getDenseI32ArrayAttr(32 * numWarps));
    rewriter.eraseOp(funcOp);
    return success();
  }

private:
  // Warps per CTA; used to derive the `nvvm.maxntid` value (32 * numWarps).
  int numWarps{0};
};

} // namespace

// Registers the FuncOpConversion pattern that lowers triton::FuncOp to an
// LLVM function. `numWarps` is forwarded to the pattern, which uses it to set
// the `nvvm.maxntid` attribute (32 * numWarps) on the converted function.
void mlir::triton::populateFuncOpConversionPattern(
    LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, int numWarps,
    PatternBenefit benefit) {
  patterns.add<FuncOpConversion>(typeConverter, numWarps, benefit);
}
38 changes: 38 additions & 0 deletions lib/Conversion/TritonGPUToLLVM/SPMDOpToLLVM.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h"
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"

namespace {

using namespace mlir;
using namespace mlir::triton;

// Lowers triton::GetProgramIdOp by delegating to the target-specific
// program-id lowering supplied via `targetInfo`.
struct GetProgramIdOpConversion
    : public ConvertOpToLLVMPattern<triton::GetProgramIdOp> {
  explicit GetProgramIdOpConversion(LLVMTypeConverter &typeConverter,
                                    const TargetInfoBase &targetInfo,
                                    PatternBenefit benefit = 1)
      : ConvertOpToLLVMPattern<triton::GetProgramIdOp>(typeConverter, benefit),
        targetInfo(targetInfo) {}

  LogicalResult
  matchAndRewrite(triton::GetProgramIdOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto moduleOp = op->getParentOfType<ModuleOp>();
    // The target decides how the program id for this axis is materialized.
    Value pid = targetInfo.programId(op->getLoc(), rewriter, moduleOp,
                                     op.getAxisAsInt());
    rewriter.replaceOp(op, pid);
    return success();
  }

private:
  const TargetInfoBase &targetInfo;
};

} // namespace

// Registers the GetProgramIdOpConversion pattern. `targetInfo` provides the
// target-specific program-id lowering (see TargetInfoBase::programId).
void mlir::triton::populateSPMDOpToLLVMPattern(LLVMTypeConverter &typeConverter,
                                               RewritePatternSet &patterns,
                                               const TargetInfoBase &targetInfo,
                                               PatternBenefit benefit) {
  patterns.add<GetProgramIdOpConversion>(typeConverter, targetInfo, benefit);
}
11 changes: 8 additions & 3 deletions lib/Conversion/TritonGPUToLLVM/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -523,14 +523,19 @@ SmallVector<Value> getMultiDimOffset(Attribute layout, Location loc,
}
return multiDimOffset;
}
if (auto mfmaLayout = layout.dyn_cast<AMDMfmaEncodingAttr>()) {
if (layout.isa<AMDMfmaEncodingAttr, AMDWmmaEncodingAttr>()) {
auto multiDimBase =
emitBaseIndexForLayout(loc, rewriter, layout, type, false);
SmallVector<SmallVector<unsigned>> offsets;
assert(rank == 2);
SmallVector<Value> multiDimOffset(rank);
emitMfmaOffsetForCTA(mfmaLayout, offsets, multiDimCTAInRepId[0],
multiDimCTAInRepId[1]);
if (auto mfmaLayout = layout.dyn_cast<AMDMfmaEncodingAttr>()) {
emitMfmaOffsetForCTA(mfmaLayout, offsets, multiDimCTAInRepId[0],
multiDimCTAInRepId[1]);
} else if (auto wmmaLayout = layout.dyn_cast<AMDWmmaEncodingAttr>()) {
emitWmmaOffsetForCTA(wmmaLayout, offsets, multiDimCTAInRepId[0],
multiDimCTAInRepId[1]);
}
multiDimOffset[0] = add(multiDimBase[0], i32_val(offsets[elemId][0]));
multiDimOffset[1] = add(multiDimBase[1], i32_val(offsets[elemId][1]));
return multiDimOffset;
Expand Down
Loading

0 comments on commit 0469c40

Please sign in to comment.