Skip to content

Commit

Permalink
Add the nested layout with dot op layout.
Browse files Browse the repository at this point in the history
Such as `#triton_gpu.slice<{dim = 1, parent = #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>}>`
  • Loading branch information
chengjunlu committed Apr 25, 2024
1 parent eb631f8 commit 20237ca
Show file tree
Hide file tree
Showing 5 changed files with 199 additions and 0 deletions.
17 changes: 17 additions & 0 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -773,6 +773,11 @@ def MmaEncodingTrait : AttrInterface<"MmaEncodingTrait"> {

InterfaceMethod<"Return size per thread for dot operands.", "SmallVector<unsigned>",
"getSizePerThreadForOperands", (ins "unsigned":$opIdx)>,

InterfaceMethod<"Return element sizes per thread for dot operands.", "SmallVector<unsigned>",
"getElemsPerThreadForOperands", (ins "ArrayRef<int64_t>":$tensorShape,
"Type":$eltTy,
"unsigned":$opIdx)>,
];
}

Expand Down Expand Up @@ -904,6 +909,10 @@ V [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,
return contigPerThread;
};

// Per-thread element counts for a dot operand of the given tensor shape.
// This encoding does not provide an implementation: the body only marks the
// call as unreachable, so callers must not invoke it on this encoding.
// NOTE(review): presumably present to satisfy the MmaEncodingTrait interface
// method of the same name -- confirm against the enclosing attribute.
SmallVector<unsigned> getElemsPerThreadForOperands(ArrayRef<int64_t> shape, Type eltTy, unsigned opIdx) const {
llvm_unreachable("getElemsPerThreadForOperands is not supported.");
};

}];

let hasCustomAssemblyFormat = 1;
Expand Down Expand Up @@ -968,6 +977,10 @@ Suppose we have a tensor with shape [32, 48], `warpsPerCTA` set to [2, 3].
SmallVector<unsigned> contigPerThread(rank, 1);
return contigPerThread;
};

// Per-thread element counts for a dot operand of the given tensor shape.
// Unsupported here: none of the arguments are inspected and the body only
// flags the call as unreachable.
// NOTE(review): looks like a stub added so this attribute matches the
// MmaEncodingTrait method list -- confirm against the enclosing attribute.
SmallVector<unsigned> getElemsPerThreadForOperands(ArrayRef<int64_t> shape, Type eltTy, unsigned opIdx) const {
llvm_unreachable("getElemsPerThreadForOperands is not supported.");
};
}];
}

Expand Down Expand Up @@ -1169,6 +1182,10 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
return contigPerThread;
};

// Per-thread element counts for a dot operand of the given tensor shape.
// No implementation exists for this encoding; the function ignores `shape`,
// `eltTy` and `opIdx` and marks the call as unreachable.
// NOTE(review): presumably a placeholder for the MmaEncodingTrait interface
// method with the same signature -- confirm against the enclosing attribute.
SmallVector<unsigned> getElemsPerThreadForOperands(ArrayRef<int64_t> shape, Type eltTy, unsigned opIdx) const {
llvm_unreachable("getElemsPerThreadForOperands is not supported.");
};

}];

let hasCustomAssemblyFormat = 1;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ along the row (resp. col) dimension.
SmallVector<unsigned> getShapeC() const;
SmallVector<int64_t> getDPASRepetitions(ArrayRef<int64_t> shape, int opIdx) const;
SmallVector<unsigned> getSizePerThreadForOperands(unsigned opIdx) const;
SmallVector<unsigned> getElemsPerThreadForOperands(ArrayRef<int64_t> shape, Type eltTy, unsigned opIdx) const;
SmallVector<unsigned> getShapePerCTATileForDotOperands(ArrayRef<int64_t> shape, int opIdx) const;
unsigned getTotalElemsPerThreadForOperands(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;

Expand Down
3 changes: 3 additions & 0 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -908,6 +908,9 @@ unsigned SharedEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,
/// Returns the number of elements each thread holds for a dot-operand tensor
/// of the given shape.
///
/// The query is forwarded to the parent layout when the parent implements
/// MmaEncodingTrait, passing along this operand's index. Any other parent
/// layout is not supported and is flagged as unreachable.
SmallVector<unsigned>
DotOperandEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape,
                                          Type eltTy) const {
  auto mmaParent = getParent().dyn_cast<MmaEncodingTrait>();
  if (!mmaParent) {
    llvm_unreachable("getElemsPerThread is not supported for dot operand");
    return SmallVector<unsigned>();
  }
  return mmaParent.getElemsPerThreadForOperands(shape, eltTy, getOpIdx());
}
Expand Down
9 changes: 9 additions & 0 deletions lib/Dialect/TritonIntelGPU/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,15 @@ DpasEncodingAttr::getSizePerThreadForOperands(unsigned opIdx) const {
}
}

/// Returns, for each of the two tensor dimensions, how many elements of a dot
/// operand a single thread holds: the per-thread size of one operand tile
/// multiplied by the number of DPAS repetitions along that dimension.
///
/// \param shape  tensor shape forwarded to getDPASRepetitions.
/// \param eltTy  element type; not used by this implementation.
/// \param opIdx  operand index forwarded to both helper queries.
SmallVector<unsigned> DpasEncodingAttr::getElemsPerThreadForOperands(
    ArrayRef<int64_t> shape, Type eltTy, unsigned opIdx) const {
  const SmallVector<unsigned> perThreadSize =
      getSizePerThreadForOperands(opIdx);
  const SmallVector<int64_t> reps = getDPASRepetitions(shape, opIdx);

  SmallVector<unsigned> elemsPerThread;
  // The product is formed in 64 bits and then narrowed, as before.
  elemsPerThread.push_back(static_cast<unsigned>(perThreadSize[0] * reps[0]));
  elemsPerThread.push_back(static_cast<unsigned>(perThreadSize[1] * reps[1]));
  return elemsPerThread;
}

SmallVector<unsigned> DpasEncodingAttr::getContigPerThread() {
unsigned threadsPerWarp = getSubGroupSize();
auto shapeC = getShapeC();
Expand Down
169 changes: 169 additions & 0 deletions third_party/intel/lib/TritonIntelGPUToLLVM/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,87 @@ emitOffsetForDpasLayoutPerCTA(const DpasEncodingAttr &dpasLayout,
}
}

/// Computes the static {row, col} offsets of every element a thread holds for
/// a dot-operand tensor whose parent layout is a DPAS layout.
///
/// Offsets are produced in the order: outer repetition, then K repetition,
/// then element within one instruction tile. Each entry is the sum of the
/// repetition offset (a multiple of the CTA tile shape) and the element
/// offset inside the tile.
///
/// \param dotLayout dot-operand layout; its parent must be a DpasEncodingAttr.
/// \param type      tensor type used to obtain the per-CTA shape.
/// \returns one two-element offset vector per element owned by a thread.
///
/// Reports a fatal error when the sub-group size is smaller than the number
/// of lanes one row of the operand requires. A parent layout other than DPAS,
/// or an operand index other than 0 or 1, is flagged as unreachable.
static SmallVector<SmallVector<unsigned>>
emitOffsetForDotOpLayout(const DotOperandEncodingAttr &dotLayout,
                         RankedTensorType type) {
  auto dpasLayout = dotLayout.getParent().dyn_cast<DpasEncodingAttr>();
  if (!dpasLayout)
    llvm_unreachable("unsupported parent layout in emitOffsetForDotOpLayout");

  SmallVector<SmallVector<unsigned>> offsets;
  auto shapePerCTA = triton::gpu::getShapePerCTA(type);

  auto opIdx = dotLayout.getOpIdx();
  SmallVector<int64_t> numReps =
      dpasLayout.getDPASRepetitions(shapePerCTA, opIdx);

  // Shape of the operand tile handled by one warp for one instruction.
  SmallVector<unsigned> warpShape;
  if (opIdx == 0) {
    warpShape = dpasLayout.getShapeA();
  } else {
    warpShape = dpasLayout.getShapeB();
  }

  unsigned warpSize = triton::gpu::getWarpSize(dpasLayout);
  unsigned numElemPerInstPerThread = product<unsigned>(warpShape) / warpSize;

  unsigned executionSize = dpasLayout.getExecutionSize();
  unsigned opsPerChannel = dpasLayout.getOpsPerChannel();

  unsigned rowsPerWarp = 0;
  unsigned numElemPerInstPerRowPerThread = 0;
  switch (opIdx) {
  case 0: {
    assert((opsPerChannel == 4 || opsPerChannel == 2 || opsPerChannel == 1) &&
           "invalid opsPerChannel number.");
    // Unlike operand B, operand A packs values into i16 for scalar bit
    // widths <= 16, so only opsPerChannel == 4 puts two values in one lane.
    unsigned packedOpsPerLane = opsPerChannel == 4 ? 2 : 1;
    // For operand A, warpShape is the A tile shape.
    unsigned packedColNum = warpShape[1] / packedOpsPerLane;
    if (warpSize < packedColNum) {
      llvm::report_fatal_error(
          "DpasEncodingAttr sub-group size could not "
          "be smaller than the threads required per row for A operand.");
    }
    rowsPerWarp = warpSize / packedColNum;
    numElemPerInstPerRowPerThread = packedOpsPerLane;
  } break;
  case 1: {
    if (warpSize < executionSize) {
      llvm::report_fatal_error(
          "DpasEncodingAttr sub-group size could not "
          "be smaller than the execution size for B operand.");
    }
    rowsPerWarp = warpSize / executionSize;
    rowsPerWarp = rowsPerWarp * opsPerChannel;
    numElemPerInstPerRowPerThread = 1;
  } break;
  default:
    // Without this the per-row values would be used uninitialized below.
    llvm_unreachable("invalid opIdx in emitOffsetForDotOpLayout");
  }

  auto shapePerCTATile = triton::gpu::getShapePerCTATile(dotLayout);
  // The outer (non-K) dimension is dim 0 for operand A and dim 1 for B.
  int64_t numRepOuter = numReps[opIdx];
  int64_t numRepK = numReps[(opIdx == 0) ? 1 : 0];
  for (int dimOuter = 0; dimOuter < numRepOuter; ++dimOuter)
    for (int k = 0; k < numRepK; ++k)
      for (unsigned elemId = 0; elemId < numElemPerInstPerThread; ++elemId) {
        uint32_t repRowIndex =
            shapePerCTATile[0] * (opIdx == 0 ? dimOuter : k);
        uint32_t repColIndex =
            shapePerCTATile[1] * (opIdx == 0 ? k : dimOuter);
        uint32_t elemRowIndex =
            (elemId / numElemPerInstPerRowPerThread) * rowsPerWarp;
        uint32_t elemColIndex = elemId % numElemPerInstPerRowPerThread;
        offsets.push_back(
            {repRowIndex + elemRowIndex, repColIndex + elemColIndex});
      }

  return offsets;
}

static SmallVector<SmallVector<unsigned>>
emitOffsetForDpasLayout(const DpasEncodingAttr &dpasLayout,
RankedTensorType type) {
Expand All @@ -197,6 +278,89 @@ emitOffsetForDpasLayout(const DpasEncodingAttr &dpasLayout,
// -----------------------------------------------------------------------
// Dpas layout indices
// -----------------------------------------------------------------------
/// Emits IR computing the {row, col} coordinates of the first element owned
/// by the current thread for a dot-operand tensor whose parent layout is a
/// DPAS layout.
///
/// The result is the sum of a warp offset (the warp's position in the CTA,
/// wrapped to the number of warp tiles that fit the per-CTA shape, times the
/// warp tile shape) and a lane offset that depends on the operand index.
///
/// \param loc       location attached to the emitted operations.
/// \param rewriter  rewriter used to create the operations.
/// \param dotLayout dot-operand layout; its parent must be a DpasEncodingAttr.
/// \param type      tensor type used to obtain the per-CTA shape.
/// \returns two Values: the base row index and the base column index.
///
/// Reports a fatal error when the sub-group size is smaller than the number
/// of lanes one row of the operand requires. A parent layout other than DPAS,
/// or an operand index other than 0 or 1, is flagged as unreachable.
static SmallVector<Value>
emitBaseIndexForDotOpLayout(Location loc, RewriterBase &rewriter,
                            const DotOperandEncodingAttr &dotLayout,
                            RankedTensorType type) {
  auto dpasLayout = dotLayout.getParent().dyn_cast<DpasEncodingAttr>();
  if (!dpasLayout)
    llvm_unreachable(
        "unsupported parent layout in emitBaseIndexForDotOpLayout");

  Value threadId = getThreadId(rewriter, loc);
  unsigned warpSize = triton::gpu::getWarpSize(dpasLayout);
  Value warpId = udiv(threadId, i32_val(warpSize));
  Value laneId = urem(threadId, i32_val(warpSize));

  const SmallVector<unsigned> warpsPerCTA = dpasLayout.getWarpsPerCTA();
  SmallVector<unsigned> order = triton::gpu::getOrder(dpasLayout);
  auto shapePerCTA = triton::gpu::getShapePerCTA(type);

  // Shape of the operand tile handled by one warp for one instruction.
  SmallVector<unsigned> warpShape;
  auto opIdx = dotLayout.getOpIdx();
  if (opIdx == 0) {
    warpShape = dpasLayout.getShapeA();
  } else {
    warpShape = dpasLayout.getShapeB();
  }

  SmallVector<Value> multiDimWarpId =
      mlir::LLVM::delinearize(rewriter, loc, warpId, warpsPerCTA, order);

  // Wrap the warp id so warps beyond the tensor extent reuse earlier tiles.
  Value rowWarpId =
      urem(multiDimWarpId[0],
           i32_val(mlir::ceil<unsigned>(shapePerCTA[0], warpShape[0])));
  Value colWarpId =
      urem(multiDimWarpId[1],
           i32_val(mlir::ceil<unsigned>(shapePerCTA[1], warpShape[1])));
  Value rowWarpOffset = mul(rowWarpId, i32_val(warpShape[0]));
  Value colWarpOffset = mul(colWarpId, i32_val(warpShape[1]));

  // Compute the 2-dim coordinates, within the warp tile, of the first
  // element owned by this thread.
  unsigned executionSize = dpasLayout.getExecutionSize();
  unsigned opsPerChannel = dpasLayout.getOpsPerChannel();

  Value laneRowIndex, laneColIndex;
  switch (opIdx) {
  case 0: {
    assert((opsPerChannel == 4 || opsPerChannel == 2 || opsPerChannel == 1) &&
           "invalid opsPerChannel number.");
    // Unlike operand B, operand A packs values into i16 for scalar bit
    // widths <= 16, so only opsPerChannel == 4 puts two values in one lane.
    unsigned packedOpsPerLane = opsPerChannel == 4 ? 2 : 1;
    // For operand A, warpShape is the A tile shape.
    unsigned packedColNum = warpShape[1] / packedOpsPerLane;
    if (warpSize < packedColNum) {
      llvm::report_fatal_error(
          "DpasEncodingAttr sub-group size could not "
          "be smaller than the threads required per row for A operand.");
    }
    laneRowIndex = udiv(laneId, i32_val(packedColNum));
    laneColIndex = urem(laneId, i32_val(packedColNum));
    laneColIndex = mul(laneColIndex, i32_val(packedOpsPerLane));
  } break;
  case 1: {
    if (warpSize < executionSize) {
      llvm::report_fatal_error(
          "DpasEncodingAttr sub-group size could not "
          "be smaller than the execution size for B operand.");
    }
    laneRowIndex = udiv(laneId, i32_val(executionSize));
    laneRowIndex = mul(laneRowIndex, i32_val(opsPerChannel));
    laneColIndex = urem(laneId, i32_val(executionSize));
  } break;
  default:
    // Without this the lane indices would stay null and be passed to add().
    llvm_unreachable("invalid opIdx in emitBaseIndexForDotOpLayout");
  }

  SmallVector<Value> multiDimBase = {add(laneRowIndex, rowWarpOffset),
                                     add(laneColIndex, colWarpOffset)};
  return multiDimBase;
}

static SmallVector<Value>
emitBaseIndexForDpasLayout(Location loc, RewriterBase &rewriter,
Expand Down Expand Up @@ -330,6 +494,8 @@ emitBaseIndexForLayout(Location loc, RewriterBase &rewriter, Attribute layout,
result.erase(result.begin() + sliceLayout.getDim());
// CTAOffset has been added in emitBaseIndexForLayout of parentLayout
return result;
} else if (auto dotLayout = layout.dyn_cast<DotOperandEncodingAttr>()) {
result = emitBaseIndexForDotOpLayout(loc, rewriter, dotLayout, type);
} else {
llvm_unreachable("unsupported emitBaseIndexForLayout");
}
Expand All @@ -348,6 +514,9 @@ emitOffsetForLayout(Attribute layout, RankedTensorType type) {
if (auto dpasLayout = layout.dyn_cast<DpasEncodingAttr>()) {
return emitOffsetForDpasLayout(dpasLayout, type);
}
if (auto dotLayout = layout.dyn_cast<DotOperandEncodingAttr>()) {
return emitOffsetForDotOpLayout(dotLayout, type);
}
if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>())
return ::intel::emitOffsetForSliceLayout(sliceLayout, type);
return ::emitOffsetForLayout(layout, type);
Expand Down

0 comments on commit 20237ca

Please sign in to comment.