Commit
[intel] Refactor shared memory representation in TTGIR
prathams417 authored and whitneywhtsang committed Mar 17, 2024
1 parent a37de50 commit 0b0b8d7
Showing 8 changed files with 214 additions and 274 deletions.
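At a glance: shared-memory values, previously modeled as ordinary tensors carrying a #shared encoding, become a dedicated memory-descriptor type. triton_gpu.alloc_tensor is replaced by triton_gpu.local_alloc, triton_gpu.extract_slice by triton_gpu.memdesc_subview, and convert_layout into or out of a #shared encoding by triton_gpu.local_alloc / triton_gpu.local_load on !tt.memdesc values. The sketch below is illustrative only, not part of the diff; it is assembled from the updated lit tests in this commit, with the #shared0 attribute copied verbatim from test/Conversion/tritongpu_to_gen.mlir.

#shared0 = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
  tt.func @refactor_sketch() {
    // Before: tensor-typed shared memory, sliced with extract_slice:
    //   %0 = triton_gpu.alloc_tensor : tensor<128x16x32xf32, #shared0>
    //   %1 = triton_gpu.extract_slice %0[%index, 0, 0][1, 16, 32][1, 1, 1]
    //        : tensor<128x16x32xf32, #shared0> to tensor<16x32xf32, #shared0>
    // After: a distinct !tt.memdesc type; note that subview indices are now
    // SSA values rather than mixed static/dynamic offsets.
    %index = arith.constant 1 : i32
    %zero = arith.constant 0 : i32
    %0 = triton_gpu.local_alloc : () -> !tt.memdesc<128x16x32xf32, #shared0>
    %1 = triton_gpu.memdesc_subview %0[%index, %zero, %zero] : !tt.memdesc<128x16x32xf32, #shared0> -> !tt.memdesc<16x32xf32, #shared0>
    tt.return
  }
}
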
3 changes: 3 additions & 0 deletions python/test/unit/language/test_core.py
@@ -4526,6 +4526,9 @@ def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device):
# skip even if scratch buffer equal to lds_size, because real scratch buffer is typically larger due to padding
if scratch_shape[0] * scratch_shape[1] * int32_size >= lds_size:
pytest.skip("Scratch buffer is too large")
if is_xpu() and M == 128 and N == 128 and interm_layout and (dst_layout.sz_per_thread == [1, 8]
or dst_layout.sz_per_thread == [4, 4]):
pytest.skip("FIXME: out of resource: shared memory")

layouts = f"""
#src = {src_layout}
29 changes: 13 additions & 16 deletions test/Conversion/tritongpu_to_gen.mlir
@@ -321,8 +321,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
tt.func @basic_alloc_tensor() {
// CHECK-NEXT: llvm.mlir.constant
// CHECK-NEXT: llvm.getelementptr
// CHECK-NEXT: llvm.bitcast
%0 = triton_gpu.alloc_tensor : tensor<16x16xf16, #shared0>
%0 = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #shared0>
tt.return
}
}
@@ -331,20 +330,15 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :

#shared0 = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: basic_extract_slice(%arg0: !llvm.ptr<3>)
tt.func @basic_extract_slice() {
// CHECK-LABEL: basic_subview(%arg0: !llvm.ptr<3>)
tt.func @basic_subview() {
// CHECK: llvm.extractvalue
// CHECK-NEXT: llvm.extractvalue
// CHECK-NEXT: llvm.extractvalue
// CHECK-NEXT: llvm.extractvalue
// CHECK-NEXT: llvm.extractvalue
// CHECK-NEXT: llvm.extractvalue
// CHECK-NEXT: llvm.extractvalue
// CHECK-NEXT: llvm.add
// CHECK-NEXT: llvm.mlir.constant(0 : i32) : i32
// CHECK-NEXT: llvm.add
// CHECK-NEXT: llvm.mlir.constant(0 : i32) : i32
// CHECK-NEXT: llvm.add
// CHECK-NEXT: llvm.mlir.constant(0 : i32) : i32
// CHECK-NEXT: llvm.mul
// CHECK-NEXT: llvm.add
@@ -354,8 +348,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
// CHECK-NEXT: llvm.add
// CHECK-NEXT: llvm.getelementptr
%index = arith.constant 1 : i32
%0 = triton_gpu.alloc_tensor : tensor<128x16x32xf32, #shared0>
%1 = triton_gpu.extract_slice %0[%index, 0, 0][1, 16, 32][1, 1, 1] : tensor<128x16x32xf32, #shared0> to tensor<16x32xf32, #shared0>
%zero = arith.constant 0 : i32
%0 = triton_gpu.local_alloc : () -> !tt.memdesc<128x16x32xf32, #shared0>
%1 = triton_gpu.memdesc_subview %0[%index, %zero, %zero] : !tt.memdesc<128x16x32xf32, #shared0> -> !tt.memdesc<16x32xf32, #shared0>
tt.return
}
}
@@ -659,7 +654,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 :
// CHECK-SAME: vector<8xf32>, !llvm.ptr<3>
// CHECK: llvm.store
// CHECK-SAME: vector<8xf32>, !llvm.ptr<3>
%0 = triton_gpu.convert_layout %arg0 : tensor<128x32xf32, #blocked0> -> tensor<128x32xf32, #shared0>
%0 = triton_gpu.local_alloc %arg0 : (tensor<128x32xf32, #blocked0>) -> !tt.memdesc<128x32xf32, #shared0>
tt.return
}
}
@@ -938,11 +933,12 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32} {
// CHECK-LABEL: test_base_index_cache
tt.func @test_base_index_cache(%arg0: tensor<128x32xf32, #blocked0>) {
// CHECK: llvm.mlir.constant(0 : i32) : i32
// CHECK: llvm.mlir.constant(0 : i32) : i32
// CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK-NEXT: llvm.call @_Z12get_local_idj([[ZERO]]) : (i32) -> i64
%0 = triton_gpu.convert_layout %arg0 : tensor<128x32xf32, #blocked0> -> tensor<128x32xf32, #shared0>
%1 = triton_gpu.convert_layout %arg0 : tensor<128x32xf32, #blocked0> -> tensor<128x32xf32, #shared0>
%0 = triton_gpu.local_alloc %arg0 : (tensor<128x32xf32, #blocked0>) -> !tt.memdesc<128x32xf32, #shared0>
%1 = triton_gpu.local_alloc %arg0 : (tensor<128x32xf32, #blocked0>) -> !tt.memdesc<128x32xf32, #shared0>
tt.return
}
}
@@ -953,13 +949,14 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 :
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32} {
// CHECK-LABEL: test_index_cache_different_block
tt.func @test_index_cache_different_block(%arg0: tensor<128x32xf32, #blocked0>, %arg1: i1) {
// CHECK: llvm.mlir.constant(0 : i32) : i32
// CHECK: llvm.mlir.constant(0 : i32) : i32
// CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK-NEXT: llvm.call @_Z12get_local_idj([[ZERO]]) : (i32) -> i64
%0 = triton_gpu.convert_layout %arg0 : tensor<128x32xf32, #blocked0> -> tensor<128x32xf32, #shared0>
%0 = triton_gpu.local_alloc %arg0 : (tensor<128x32xf32, #blocked0>) -> !tt.memdesc<128x32xf32, #shared0>
cf.cond_br %arg1, ^bb1, ^bb2
^bb1: // pred: ^bb0
%1 = triton_gpu.convert_layout %arg0 : tensor<128x32xf32, #blocked0> -> tensor<128x32xf32, #shared0>
%1 = triton_gpu.local_alloc %arg0 : (tensor<128x32xf32, #blocked0>) -> !tt.memdesc<128x32xf32, #shared0>
cf.br ^bb2
^bb2: // 2 preds: ^bb0, ^bb1
tt.return
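Taken together, the store and load sides of the new representation compose as a write/read pair. The round-trip below is an illustration, not part of the commit: the local_alloc-with-operand line mirrors the tests above, while the textual form of triton_gpu.local_load and the concrete #blocked0/#shared0 attribute values are assumptions made for this sketch.

#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32} {
  tt.func @roundtrip_through_shared(%arg0: tensor<128x32xf32, #blocked0>) {
    // Distributed -> shared: a local_alloc with a tensor operand replaces the
    // old convert_layout to a #shared encoding.
    %0 = triton_gpu.local_alloc %arg0 : (tensor<128x32xf32, #blocked0>) -> !tt.memdesc<128x32xf32, #shared0>
    // Shared -> distributed: a local_load, lowered by the new
    // LocalLoadOpConversion pattern in the C++ file below.
    %1 = triton_gpu.local_load %0 : !tt.memdesc<128x32xf32, #shared0> -> tensor<128x32xf32, #blocked0>
    tt.return
  }
}
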
263 changes: 121 additions & 142 deletions third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp
@@ -30,12 +30,131 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter,
Location loc, Value tensor,
DotOperandEncodingAttr bEncoding,
const SharedMemoryObject &smemObj,
TritonGPUToLLVMTypeConverter *typeConverter, Value thread);
const LLVMTypeConverter *typeConverter, Value thread);

} // namespace intel
} // namespace SharedToDotOperandDPAS

namespace {

// shared -> dot_operand if the result layout is dpas
Value lowerSharedToDotOperandDPAS(
triton::gpu::LocalLoadOp op, triton::gpu::LocalLoadOpAdaptor adaptor,
const LLVMTypeConverter *typeConverter, ConversionPatternRewriter &rewriter,
const DpasEncodingAttr &dpasLayout,
const DotOperandEncodingAttr &dotOperandLayout, bool isOuter) {
auto loc = op.getLoc();
Value src = op.getSrc();
Value dst = op.getResult();

auto llvmElemTy = typeConverter->convertType(
src.getType().cast<RankedTensorType>().getElementType());

auto smemObj = getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(),
llvmElemTy, rewriter);
Value res;
if (!isOuter) {
res = SharedToDotOperandDPAS::intel::convertLayout(
dotOperandLayout.getOpIdx(), rewriter, loc, src, dotOperandLayout,
smemObj, typeConverter, tid_val());
} else {
assert(false && "unsupported DPAS layout found");
}
return res;
}
// shared -> dpas_operand
LogicalResult lowerSharedToDotOperand(triton::gpu::LocalLoadOp op,
triton::gpu::LocalLoadOpAdaptor adaptor,
const LLVMTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter) {
auto loc = op.getLoc();
auto dstEnc = op.getType().getEncoding().cast<DotOperandEncodingAttr>();
auto sharedLayout =
op.getSrc().getType().getEncoding().cast<SharedEncodingAttr>();

int K;
if (dstEnc.getOpIdx() == 0) // $a
K = op.getType().getShape()[sharedLayout.getOrder()[0]];
else // $b
K = op.getType().getShape()[sharedLayout.getOrder()[1]];
bool isOuter = K == 1;

Value res;
if (auto dpasLayout =
dstEnc.getParent().dyn_cast_or_null<DpasEncodingAttr>()) {
res = lowerSharedToDotOperandDPAS(op, adaptor, typeConverter, rewriter,
dpasLayout, dstEnc, isOuter);
} else if (auto blockedLayout =
dstEnc.getParent().dyn_cast_or_null<BlockedEncodingAttr>()) {
auto thread = getThreadId(rewriter, loc);
res = SharedToDotOperandFMA::convertLayout(
dstEnc.getOpIdx(), op.getSrc(), adaptor.getSrc(), blockedLayout, thread,
loc, typeConverter, rewriter);
} else {
assert(false && "Unsupported dot operand layout found");
}

rewriter.replaceOp(op, res);
return success();
}

LogicalResult lowerSharedToDistributed(triton::gpu::LocalLoadOp op,
triton::gpu::LocalLoadOpAdaptor adaptor,
const LLVMTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter) {
auto loc = op.getLoc();
auto srcTy = op.getSrc().getType();
auto dstTy = op.getResult().getType();
auto dstShape = dstTy.getShape();
assert(dstShape.size() <= 2 &&
"Unexpected rank of ConvertLayout(shared->blocked)");
auto srcSharedLayout = srcTy.getEncoding().cast<SharedEncodingAttr>();
auto dstLayout = dstTy.getEncoding();
auto inOrd = getOrder(srcSharedLayout);

auto smemObj = getSharedMemoryObjectFromStruct(
loc, adaptor.getSrc(), typeConverter->convertType(srcTy.getElementType()),
rewriter);
auto elemTy = typeConverter->convertType(dstTy.getElementType());

auto srcStrides =
getStridesFromShapeAndOrder(srcTy.getShape(), inOrd, loc, rewriter);
auto dstIndices = emitIndices(loc, rewriter, dstLayout, dstTy, true);

SmallVector<Value> outVals = loadSharedToDistributed(
op.getResult(), dstIndices, op.getSrc(), smemObj, elemTy, loc, rewriter);

Value result = packLLElements(loc, typeConverter, outVals, rewriter, dstTy);
rewriter.replaceOp(op, result);

return success();
}

struct LocalLoadOpConversion
: public ConvertOpToLLVMPattern<triton::gpu::LocalLoadOp> {
public:
using ConvertOpToLLVMPattern<
triton::gpu::LocalLoadOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(triton::gpu::LocalLoadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
MemDescType srcTy = op.getSrc().getType();
RankedTensorType dstTy = op.getType();
Attribute srcLayout = srcTy.getEncoding();
Attribute dstLayout = dstTy.getEncoding();
if (dstLayout.isa<DotOperandEncodingAttr>()) {
return lowerSharedToDotOperand(op, adaptor, getTypeConverter(), rewriter);
}
if (srcLayout.isa<SharedEncodingAttr>() &&
isaDistributedLayout(dstLayout)) {
return lowerSharedToDistributed(op, adaptor, getTypeConverter(),
rewriter);
}
return failure();
}
};

struct ConvertLayoutOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::gpu::ConvertLayoutOp> {
public:
@@ -49,21 +168,9 @@ struct ConvertLayoutOpConversion
RankedTensorType dstTy = op.getType();
Attribute srcLayout = srcTy.getEncoding();
Attribute dstLayout = dstTy.getEncoding();
if (isaDistributedLayout(srcLayout) &&
dstLayout.isa<SharedEncodingAttr>()) {
return lowerDistributedToShared(op, adaptor, rewriter);
}
if (srcLayout.isa<SharedEncodingAttr>() &&
dstLayout.isa<DotOperandEncodingAttr>()) {
return lowerSharedToDotOperand(op, adaptor, rewriter);
}
if (isaDistributedLayout(srcLayout) && isaDistributedLayout(dstLayout)) {
return lowerDistributedToDistributed(op, adaptor, rewriter);
}
if (srcLayout.isa<SharedEncodingAttr>() &&
isaDistributedLayout(dstLayout)) {
return lowerSharedToDistributed(op, adaptor, rewriter);
}
// TODO: to be implemented
llvm_unreachable("unsupported layout conversion");
return failure();
@@ -446,39 +553,6 @@
return success();
}

LogicalResult
lowerSharedToDistributed(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto loc = op.getLoc();
auto srcTy = op.getSrc().getType();
auto dstTy = op.getResult().getType();
auto dstShape = dstTy.getShape();
assert(dstShape.size() <= 2 &&
"Unexpected rank of ConvertLayout(shared->blocked)");
auto srcSharedLayout = srcTy.getEncoding().cast<SharedEncodingAttr>();
auto dstLayout = dstTy.getEncoding();
auto inOrd = getOrder(srcSharedLayout);

auto smemObj = getSharedMemoryObjectFromStruct(
loc, adaptor.getSrc(),
getTypeConverter()->convertType(srcTy.getElementType()), rewriter);
auto elemTy = getTypeConverter()->convertType(dstTy.getElementType());

auto srcStrides =
getStridesFromShapeAndOrder(srcTy.getShape(), inOrd, loc, rewriter);
auto dstIndices = emitIndices(loc, rewriter, dstLayout, dstTy, true);

SmallVector<Value> outVals =
loadSharedToDistributed(op.getResult(), dstIndices, op.getSrc(),
smemObj, elemTy, loc, rewriter);

Value result =
packLLElements(loc, getTypeConverter(), outVals, rewriter, dstTy);
rewriter.replaceOp(op, result);

return success();
}

Value computeStMatrixAddr(Value laneId, int matStride, Location loc,
ConversionPatternRewriter &rewriter) const {
Value rowInMat = urem(laneId, i32_val(8)); // row in the 8x8 matrix
@@ -514,107 +588,12 @@ struct ConvertLayoutOpConversion
getTypeConverter()->convertType(elemTy), smemBase, offset);
rewriter.create<triton::nvgpu::StoreMatrixOp>(loc, addr, inputs);
}

// blocked -> shared.
// Swizzling in shared memory to avoid bank conflict. Normally used for
// A/B operands of dots.
LogicalResult
lowerDistributedToShared(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto loc = op.getLoc();
auto srcTy = op.getSrc().getType();
auto dstTy = op.getType();
auto dstShapePerCTA = triton::gpu::getShapePerCTA(dstTy);
auto srcLayout = srcTy.getEncoding();
auto outOrd = dstTy.getEncoding().cast<SharedEncodingAttr>().getOrder();
assert(srcTy.getShape().size() == 2 ||
(srcTy.getShape().size() <= 3 && outOrd[2] == 0) &&
"Unexpected rank of ConvertLayout(blocked->shared)");
Value smemBase =
LLVM::utils::getSharedMemoryBase(loc, rewriter, op.getOperation());
auto elemTy = getTypeConverter()->convertType(srcTy.getElementType());
auto elemPtrTy = ptr_ty(rewriter.getContext(), 3);
smemBase = bitcast(smemBase, elemPtrTy);

int32_t elemSize = elemTy.getIntOrFloatBitWidth();
unsigned numElems = triton::gpu::getTotalElemsPerThread(srcTy);
auto dstStrides =
getStridesFromShapeAndOrder(dstShapePerCTA, outOrd, loc, rewriter);
auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcTy, false);
auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter);
storeDistributedToShared(op.getSrc(), inVals, dstStrides, srcIndices,
op.getResult(), smemBase, elemTy, loc, rewriter);
auto smemObj = SharedMemoryObject(smemBase, elemTy, dstShapePerCTA, outOrd,
loc, rewriter);
auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
rewriter.replaceOp(op, retVal);
return success();
}

// shared -> dpas_operand
LogicalResult
lowerSharedToDotOperand(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto loc = op.getLoc();
auto dstEnc = op.getType().getEncoding().cast<DotOperandEncodingAttr>();
auto sharedLayout =
op.getSrc().getType().getEncoding().cast<SharedEncodingAttr>();

int K;
if (dstEnc.getOpIdx() == 0) // $a
K = op.getType().getShape()[sharedLayout.getOrder()[0]];
else // $b
K = op.getType().getShape()[sharedLayout.getOrder()[1]];
bool isOuter = K == 1;

Value res;
if (auto dpasLayout =
dstEnc.getParent().dyn_cast_or_null<DpasEncodingAttr>()) {
res = lowerSharedToDotOperandDPAS(op, adaptor, rewriter, dpasLayout,
dstEnc, isOuter);
} else if (auto blockedLayout =
dstEnc.getParent().dyn_cast_or_null<BlockedEncodingAttr>()) {
auto thread = getThreadId(rewriter, loc);
res = SharedToDotOperandFMA::convertLayout(
dstEnc.getOpIdx(), op.getSrc(), adaptor.getSrc(), blockedLayout,
thread, loc, getTypeConverter(), rewriter);
} else {
assert(false && "Unsupported dot operand layout found");
}

rewriter.replaceOp(op, res);
return success();
}

// shared -> dot_operand if the result layout is dpas
Value lowerSharedToDotOperandDPAS(
triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, const DpasEncodingAttr &dpasLayout,
const DotOperandEncodingAttr &dotOperandLayout, bool isOuter) const {
auto loc = op.getLoc();
Value src = op.getSrc();
Value dst = op.getResult();

auto llvmElemTy = getTypeConverter()->convertType(
src.getType().cast<RankedTensorType>().getElementType());

auto smemObj = getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(),
llvmElemTy, rewriter);
Value res;
if (!isOuter) {
res = SharedToDotOperandDPAS::intel::convertLayout(
dotOperandLayout.getOpIdx(), rewriter, loc, src, dotOperandLayout,
smemObj, getTypeConverter(), tid_val());
} else {
assert(false && "unsupported DPAS layout found");
}
return res;
}
};
} // namespace

void mlir::triton::intel::populateConvertLayoutOpToLLVMPatterns(
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
PatternBenefit benefit) {
patterns.add<ConvertLayoutOpConversion>(typeConverter, benefit);
patterns.add<LocalLoadOpConversion>(typeConverter, benefit);
}
