Merge commit '82e7a32179d6d3ecadac88a06916ba2b52bcfbdb'
whitneywhtsang committed Dec 20, 2024
2 parents 13725c1 + 82e7a32 commit b9afb6c
Showing 32 changed files with 799 additions and 564 deletions.
@@ -25,7 +25,6 @@ namespace triton {
constexpr int patternBenefitDefault = 1;
constexpr int patternBenefitPrioritizeOverLLVMConversions = 10;
constexpr int patternBenefitClampOptimizedPattern = 20;
constexpr int patternBenefitConvertLayoutOptimizedPattern = 20;

struct BackendCallbacks {
/**
@@ -50,7 +49,7 @@ void populateElementwiseOpToLLVMPatterns(
// callback receives 1) the current source op, 2) the number of issued LLVM
// instructions and 3) their input types. Each MLIR backend can provide a
// callback and, thus, handle backend-specific behaviors.
void populateMemoryOpToLLVMPattern(
void populateMemoryOpToLLVMPatterns(
LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo,
RewritePatternSet &patterns, PatternBenefit benefit,
std::optional<BackendCallbacks> backendCallbacks = std::nullopt);
@@ -102,10 +101,6 @@ void populateConvertLayoutOpToLLVMPatterns(LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns,
PatternBenefit benefit);

void populateConvertLayoutOpUsingLinearLayoutsToLLVMPattern(
LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo,
RewritePatternSet &patterns, PatternBenefit benefit);

void populateControlFlowOpToLLVMPattern(LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns,
const TargetInfoBase &targetInfo,
3 changes: 1 addition & 2 deletions lib/Analysis/Utility.cpp
@@ -720,8 +720,7 @@ bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
auto ans = mmaLayout.getVersionMajor() == 3 &&
dotOperandLayout.getOpIdx() == 0 &&
mmaLayout.getWarpsPerCTA()[1] == 1 &&
!cvtNeedsSharedMemory(parentTy, srcTy) &&
(elementTypeSize == 16 || elementTypeSize == 8) &&
!cvtNeedsSharedMemory(parentTy, srcTy) && elementTypeSize == 8 &&
dotOperandLayout.getKWidth() == 32 / elementTypeSize;
return ans;
}
52 changes: 17 additions & 35 deletions lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
@@ -16,11 +16,7 @@

namespace {

using ::mlir::LLVM::getMultiDimOffset;
using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
using ::mlir::LLVM::getWrappedMultiDimOffset;
using ::mlir::LLVM::linearize;

using namespace mlir;
using namespace mlir::triton::gpu;

// XXX(Keren): A temporary knob to control the use of legacy MMA conversion
@@ -105,13 +101,14 @@ struct ConvertLayoutOpConversion
// of performance issue observed.
for (unsigned elemId = 0; elemId < accumSizePerThread; elemId += vec) {
SmallVector<Value> multiDimOffset =
getMultiDimOffset(layout, loc, rewriter, targetInfo, elemId, type,
multiDimCTAInRepId, shapePerCTATile);
SmallVector<Value> multiDimOffsetWrapped = getWrappedMultiDimOffset(
rewriter, loc, multiDimOffset, origRepShape, shapePerCTATile,
shapePerCTA);
Value offset = linearize(rewriter, loc, multiDimOffsetWrapped,
paddedRepShape, outOrd);
LLVM::getMultiDimOffset(layout, loc, rewriter, targetInfo, elemId,
type, multiDimCTAInRepId, shapePerCTATile);
SmallVector<Value> multiDimOffsetWrapped =
LLVM::getWrappedMultiDimOffset(rewriter, loc, multiDimOffset,
origRepShape, shapePerCTATile,
shapePerCTA);
Value offset = LLVM::linearize(rewriter, loc, multiDimOffsetWrapped,
paddedRepShape, outOrd);
auto elemPtrTy = smemBase.getType();
Value ptr = gep(elemPtrTy, llvmElemTy, smemBase, offset);
auto vecTy = vec_ty(llvmElemTy, vec);
@@ -267,7 +264,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
// conversions. TODO(jlebar): Eventually we want this to be the only pattern.
explicit ConvertLayoutOpUsingLinearLayoutsConversion(
LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo,
PatternBenefit benefit = 2)
PatternBenefit benefit = 1)
: ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) {
}

@@ -395,16 +392,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
if (!layoutIsOK(srcTy.getEncoding()) || !layoutIsOK(dstTy.getEncoding())) {
return failure();
}
// FIXME [Dot LL] Remove this once we implement this trick in LLs
if (matchMmaV3AndDotOperandLayout(srcTy, dstTy)) {
return failure();
}

// The following check can be removed when generalized warp shuffle
// conversions are ready:
if (matchMFMAAndDotOperandShuffleCase(srcTy, dstTy)) {
return failure();
}

assert(cvtNeedsSharedMemory(srcTy, dstTy));

@@ -666,22 +653,17 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion

} // namespace

void mlir::triton::populateConvertLayoutOpUsingLinearLayoutsToLLVMPattern(
void mlir::triton::populateConvertLayoutOpToLLVMPatterns(
LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo,
RewritePatternSet &patterns, PatternBenefit benefit) {
if (useLegacyMMAConversion) {
// Prioritize the legacy MMA conversion over the LinearLayout conversion.
// Only for debugging purposes.
patterns.add<ConvertLayoutOpConversion>(typeConverter, targetInfo,
benefit.getBenefit() + 1);
}
patterns.add<ConvertLayoutOpUsingLinearLayoutsConversion>(
typeConverter, targetInfo, benefit);
}

void mlir::triton::populateConvertLayoutOpToLLVMPatterns(
LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo,
RewritePatternSet &patterns, PatternBenefit benefit) {
// We prefer using the linear layout conversion, so it gets a higher benefit.
// Eventually the LL conversion will subsume all of the others and be the only
// one left.
mlir::triton::populateConvertLayoutOpUsingLinearLayoutsToLLVMPattern(
typeConverter, targetInfo, patterns, benefit.getBenefit() + 1);
patterns.add<ConvertLayoutOpBlockedToDotOpShortcutConversion>(
typeConverter, targetInfo, benefit);
patterns.add<ConvertLayoutOpConversion>(typeConverter, targetInfo, benefit);
}
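
With this consolidation, populateConvertLayoutOpToLLVMPatterns registers the linear-layout conversion at the caller's benefit, and the legacy ConvertLayoutOpConversion outranks it only when the useLegacyMMAConversion debug knob bumps its benefit by one. A toy sketch of benefit-ordered pattern selection (plain Python, not the MLIR API; all names below are illustrative):

from typing import Callable, List, Optional, Tuple

Pattern = Tuple[int, str, Callable[[str], Optional[str]]]

def apply_first_match(op: str, patterns: List[Pattern]) -> str:
    # The rewrite driver tries matching patterns in descending benefit order.
    for benefit, name, rewrite in sorted(patterns, key=lambda p: -p[0]):
        result = rewrite(op)
        if result is not None:  # analogous to matchAndRewrite() succeeding
            print(f"applied {name} (benefit={benefit})")
            return result
    raise RuntimeError(f"no pattern matched {op}")

use_legacy_mma_conversion = False  # stand-in for the debug knob
benefit = 1
patterns: List[Pattern] = [
    (benefit, "linear-layout conversion", lambda op: f"ll({op})"),
]
if use_legacy_mma_conversion:
    # benefit + 1 makes the legacy conversion win over the linear-layout one
    patterns.append((benefit + 1, "legacy MMA conversion", lambda op: f"legacy({op})"))

print(apply_first_match("convert_layout", patterns))  # linear-layout path wins
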
43 changes: 7 additions & 36 deletions lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp
@@ -121,33 +121,12 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {

// FIXME [Dot LL]
// Do for all DotOperandEncodingAttr once we have LLs for all of them
static bool isSupportedDotOpLayout(MemDescType srcTy,
RankedTensorType dstTy) {
auto srcLayout = cast<SharedEncodingAttr>(srcTy.getEncoding());
auto dstLayout = dstTy.getEncoding();
auto bitwidth = dstTy.getElementTypeBitWidth();
auto rank = dstTy.getRank();
static bool isSupportedLayout(Attribute dstLayout) {
if (isa<BlockedEncodingAttr, MmaEncodingTrait, SliceEncodingAttr,
LinearEncodingAttr>(dstLayout))
return true;
if (auto dot = dyn_cast<DotOperandEncodingAttr>(dstLayout)) {
auto vecWidth = 32 / bitwidth;
auto kWidth = dot.getKWidth();
auto kOrder = dot.getOpIdx() == 0 ? rank - 1 : rank - 2;
if (auto mma = dyn_cast<NvidiaMmaEncodingAttr>(dot.getParent())) {
auto needTrans = kOrder != srcLayout.getOrder()[0];
auto canUseLdmatrix =
(bitwidth == 16 || (!needTrans)) && (kWidth == vecWidth);
if (mma.isHopper()) {
// I think we should be able to remove this condition, but it's here
// as the legacy ldmatrix path does not support it
canUseLdmatrix &= srcTy.getElementTypeBitWidth() * kWidth == 32;
}
// If we remove this one, ldmatrix will IMA. It can probably be relaxed
// though
canUseLdmatrix &=
srcTy.getShape()[0] >= 8 &&
srcTy.getShape()[1] >= 4 * kWidth & dstTy.getRank() <= 2;
return !canUseLdmatrix;
}
if (isa<AMDMfmaEncodingAttr, AMDWmmaEncodingAttr>(dot.getParent()))
if (isa<MmaEncodingTrait>(dot.getParent()))
return true;
}
return false;
@@ -156,12 +135,9 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
LogicalResult
matchAndRewrite(LocalLoadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
MemDescType srcTy = op.getSrc().getType();
RankedTensorType dstTy = op.getType();
Attribute dstLayout = dstTy.getEncoding();
if (isa<BlockedEncodingAttr, MmaEncodingTrait, SliceEncodingAttr,
LinearEncodingAttr>(dstLayout) ||
isSupportedDotOpLayout(srcTy, dstTy)) {
if (isSupportedLayout(dstLayout)) {
return lowerSharedToDistributed(op, adaptor, getTypeConverter(),
rewriter);
}
@@ -198,11 +174,6 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
auto loc = op.getLoc();
auto srcTy = op.getSrc().getType();
auto dstTy = op.getResult().getType();
auto dstShape = dstTy.getShape();
auto srcSharedLayout = cast<SharedEncodingAttr>(srcTy.getEncoding());
assert((!isa<DotOperandEncodingAttr>(dstTy.getEncoding()) ||
isSupportedDotOpLayout(srcTy, dstTy)) &&
"Unexpected rank of ConvertLayout(shared->distributed)");

auto smemObj = LLVM::getSharedMemoryObjectFromStruct(
loc, adaptor.getSrc(),
@@ -265,7 +236,7 @@ struct LocalStoreOpConversion

} // namespace

void mlir::triton::populateMemoryOpToLLVMPattern(
void mlir::triton::populateMemoryOpToLLVMPatterns(
LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo,
RewritePatternSet &patterns, PatternBenefit benefit,
std::optional<BackendCallbacks> backendCallbacks) {
26 changes: 22 additions & 4 deletions lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp
@@ -17,9 +17,11 @@ namespace {
// dot(a, b, inputPrecision="tf32x3") ->
// let aBig = f32ToTF32(a), aSmall = a - aBig;
// let bBig = f32ToTF32(b), bSmall = b - bBig;
// dot(aSmall, bBig, inputPrecision="tf32") +
// dot(aBig, bSmall, inputPrecision="tf32") +
// dot(aBig, bBig, inputPrecision="tf32")
// let small = dot(aSmall, bBig, inputPrecision="tf32") +
// dot(aBig, bSmall, inputPrecision="tf32")
// let masked_nans = replaceNansWithZeros(small)
// let big = dot(aBig, bBig, inputPrecision="tf32")
// return big + masked_nans;
class TF32x3 : public OpRewritePattern<DotOp> {
public:
using OpRewritePattern::OpRewritePattern;
@@ -62,6 +64,13 @@ class TF32x3 : public OpRewritePattern<DotOp> {
InputPrecision::TF32,
dotOp.getMaxNumImpreciseAcc());
};
auto replaceNansWithZeros = [&](Value value) -> Value {
auto nans = rewriter.create<arith::CmpFOp>(
dotOp->getLoc(), arith::CmpFPredicate::UNO, value, value);
auto zero = zeroLike(value);
return rewriter.create<arith::SelectOp>(dotOp->getLoc(), nans, zero,
value);
};

auto aBig = f32ToTF32(dotOp.getA());
auto aSmall = sub(dotOp.getA(), aBig);
@@ -73,7 +82,16 @@

auto dot1 = dot(aSmall, bBig, zero);
auto dot2 = dot(aBig, bSmall, dot1);
auto dot3 = dot(aBig, bBig, dot2);

// If lhs is 1.0, we will have lhs_high = 1.0 and lhs_low = 0.0.
// If rhs is +infinity, we will have:
// +infinity * 1.0 = +infinity
// +infinity * 0.0 = NaN
// We would get the wrong result if we sum these partial products. Instead,
// we must override any accumulated result if the last partial product is
// non-finite.
auto dot2withZeroedNans = replaceNansWithZeros(dot2);
auto dot3 = dot(aBig, bBig, dot2withZeroedNans);

auto sum = add(dot3, dotOp.getC());

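
The masking step matters because the residual ("small") operands can be NaN even when the inputs are cleanly infinite: for b = +inf, the split gives bSmall = inf - inf = NaN. A scalar sketch of the failure and the fix, using plain Python floats rather than Triton tensors (f32ToTF32 is modeled as the identity, which is exact for 1.0 and +inf):

import math

def f32_to_tf32(x: float) -> float:
    return x  # stand-in: TF32 truncation is exact for 1.0 and +inf

def tf32x3_scalar(a: float, b: float, mask_nans: bool) -> float:
    a_big = f32_to_tf32(a); a_small = a - a_big  # a=1.0 -> big=1.0, small=0.0
    b_big = f32_to_tf32(b); b_small = b - b_big  # b=inf -> big=inf, small=inf-inf=nan
    small = a_small * b_big + a_big * b_small    # 0.0*inf + 1.0*nan == nan
    if mask_nans and math.isnan(small):
        small = 0.0                              # replaceNansWithZeros
    big = a_big * b_big                          # 1.0 * inf == inf
    return big + small

print(tf32x3_scalar(1.0, math.inf, mask_nans=False))  # nan (wrong result)
print(tf32x3_scalar(1.0, math.inf, mask_nans=True))   # inf (matches 1.0 * inf)
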
5 changes: 4 additions & 1 deletion lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp
@@ -399,7 +399,10 @@ struct MMAV3UseRegOperand
dotOp.getContext(), /*opIdx=*/0, srcEnc, /*kWidth=*/kWidth);
auto newTy = RankedTensorType::get(srcTy.getShape(), srcTy.getElementType(),
dotOperandEnc);
if (!matchMmaV3AndDotOperandLayout(srcTy, newTy))
// TODO(Keren): relax the condition once
// https://github.com/triton-lang/triton/pull/5419 is merged
if (!cvtReordersRegisters(srcTy, newTy) &&
!matchMmaV3AndDotOperandLayout(srcTy, newTy))
return failure();

Value newOperand =
@@ -103,7 +103,8 @@ Operation *mlir::triton::predicateOp(RewriterBase &rewriter, Operation *op,
return op;
}

assert("don't know how to predicate this op" && false);
op->emitError("pipeliner doesn't know how to predicate this op.");
llvm::report_fatal_error("Fatal pipeliner error");
return op;
}

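
The switch away from assert matters because asserts compile out of release builds, where the old code would silently fall through and return the op unpredicated; emitError plus report_fatal_error diagnoses and aborts unconditionally. A rough Python analogue of the before/after behavior (hypothetical helpers, not Triton API):

def predicate_op_old(op):
    assert False, "don't know how to predicate this op"  # stripped under python -O
    return op  # a release build would silently return the op unpredicated

def predicate_op_new(op):
    # Diagnose, then abort unconditionally (mirrors emitError + report_fatal_error).
    raise RuntimeError(f"pipeliner doesn't know how to predicate this op: {op!r}")
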
42 changes: 42 additions & 0 deletions python/test/unit/language/test_core.py
@@ -1690,6 +1690,48 @@ def change_value(X, BLOCK_SIZE: tl.constexpr, sem: tl.constexpr):
assert (torch.equal(X, Y))


@pytest.mark.interpreter
@pytest.mark.skipif((is_cuda() and torch.cuda.get_device_capability()[0] < 9) or is_hip(),
reason="Requires compute capability >= 9 for NV")
def test_load_scope_sem_coop_grid_cta_not_one(device):

@triton.jit
def kernel_r(ptrs, BLOCK_SIZE: tl.constexpr):
numel = 512
offset = tl.program_id(0) * BLOCK_SIZE
index = offset
mask = index < numel
a = tl.load(ptrs, mask=mask)
tl.store(ptrs, a)

block_size = 128
data = torch.zeros((128, ), device=device, dtype=torch.float32)

out = kernel_r[(2, )](data, BLOCK_SIZE=block_size, num_ctas=4, launch_cooperative_grid=True)
out = kernel_r[(2, )](data, BLOCK_SIZE=block_size, num_ctas=4, launch_cooperative_grid=False)


@pytest.mark.interpreter
@pytest.mark.skipif(is_hip(), reason="Not implemented for AMD At this moment")
def test_load_scope_sem_coop_grid_cta_one(device):

@triton.jit
def kernel_r(ptrs, BLOCK_SIZE: tl.constexpr):
numel = 512
offset = tl.program_id(0) * BLOCK_SIZE
index = offset
mask = index < numel
a = tl.load(ptrs, mask=mask)
tl.store(ptrs, a)

block_size = 128
data = torch.zeros((128, ), device=device, dtype=torch.float32)

# Should do nothing different for num_ctas=1 (with coop launch grid)
out = kernel_r[(2, )](data, BLOCK_SIZE=block_size, num_ctas=1, launch_cooperative_grid=True)
out = kernel_r[(2, )](data, BLOCK_SIZE=block_size, num_ctas=1, launch_cooperative_grid=False)


# ---------------
# test cast
# ---------------
12 changes: 8 additions & 4 deletions python/triton/runtime/autotuner.py
@@ -4,7 +4,7 @@
import os
import time
import inspect
from typing import Dict
from typing import Dict, Tuple, List, Optional

from .jit import KernelInterface
from .errors import OutOfResources, PTXASError
@@ -23,7 +23,7 @@ def __init__(
restore_value,
pre_hook=None,
post_hook=None,
prune_configs_by: Dict = None,
prune_configs_by: Optional[Dict] = None,
warmup=None,
rep=None,
use_cuda_graph=False,
@@ -40,7 +40,7 @@ def __init__(
else:
self.configs = configs
self.keys = key
self.cache = {}
self.cache: Dict[Tuple, Config] = {}
self.arg_names = arg_names

# Reset to zero or restore values
@@ -211,14 +211,18 @@ def run(self, *args, **kwargs):
self.nargs = None
return ret

def prune_configs(self, kwargs):
def prune_configs(self, kwargs: Dict) -> List[Config]:
pruned_configs = self.configs
if self.early_config_prune:
pruned_configs = self.early_config_prune(self.configs, self.nargs, **kwargs)
if self.perf_model:
top_k = self.configs_top_k
if isinstance(top_k, float) and top_k <= 1.0:
top_k = int(len(self.configs) * top_k)
elif not isinstance(top_k, int):
# Slice index must be an integer
raise TypeError(f"Error while pruning configs, top_k must be either 1) a float <= 1.0 or 2) an int")

if len(pruned_configs) > top_k:
est_timing = {
config: self.perf_model(
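
The new elif tightens prune_configs: a float top_k no greater than 1.0 is read as a fraction of the config list, an int as an absolute count, and anything else (including floats above 1.0) now raises up front instead of failing later when top_k is used as a slice bound. A standalone sketch of that resolution logic (resolve_top_k is a hypothetical helper, not Triton API):

from typing import Union

def resolve_top_k(top_k: Union[int, float], num_configs: int) -> int:
    if isinstance(top_k, float) and top_k <= 1.0:
        return int(num_configs * top_k)  # read as a fraction of the configs
    if not isinstance(top_k, int):
        # A slice index must be an integer; floats > 1.0 are rejected too.
        raise TypeError("top_k must be either 1) a float <= 1.0 or 2) an int")
    return top_k

print(resolve_top_k(0.25, 8))  # 2
print(resolve_top_k(3, 8))     # 3
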
3 changes: 2 additions & 1 deletion python/triton/runtime/jit.py
@@ -501,7 +501,7 @@ def _call_hook(
name = self.fn.__name__
module = self.fn.__module__
arg_reprs = ", ".join([f"{param.name}: {ty}" for param, ty in zip(self.params, key[1])])
repr = f"{name}[num_warps={options.num_warps}, num_ctas={options.num_ctas}, num_stages={options.num_stages}, enable_fp_fusion={options.enable_fp_fusion}]({arg_reprs})"
repr = f"{name}[num_warps={options.num_warps}, num_ctas={options.num_ctas}, num_stages={options.num_stages}, enable_fp_fusion={options.enable_fp_fusion}, launch_cooperative_grid={options.launch_cooperative_grid}]({arg_reprs})"

class JitFunctionInfo:

@@ -521,6 +521,7 @@ def __init__(self, module, name, jit_function):
'num_ctas': options.num_ctas,
'num_stages': options.num_stages,
'enable_fp_fusion': options.enable_fp_fusion,
'launch_cooperative_grid': options.launch_cooperative_grid,
'extern_libs': options.extern_libs,
'configs': configs,
'specialization_data': specialization_data,
Some files were not shown because too many files changed in this diff.