Skip to content

Commit

Permalink
address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
whitneywhtsang committed Apr 29, 2024
1 parent b875267 commit f55a08d
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 20 deletions.
2 changes: 1 addition & 1 deletion include/triton/Dialect/TritonIntelGPU/Transforms/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ DPASEngineType getDPASType(DotOp op);
// Infers the encoding of the source of op given the result encoding.
std::optional<Attribute> inferSrcEncoding(Operation *op, Attribute encoding);

// Retuns true is the operation is an expensive load or store operation.
// Returns true if the operation is an expensive load or store operation.
bool isExpensiveLoadOrStore(Operation *op);

// Returns true if the tensor type has a dot dpas encoding.
Expand Down
37 changes: 18 additions & 19 deletions third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,16 +143,16 @@ struct LoadOpConversion
LoadStoreConversionBase(axisAnalysisPass) {}

/// Holds the values related to a block pointer
// It includes the offset base for Y and X, base height and width, row and
// It includes the offset base for X and Y, base height and width, row and
// column stride, and the base value.
struct BlockPointerValues {
Value base;
Value offsetBaseX;
Value offsetBaseY;
Value baseWidth;
Value baseHeight;
Value rowStride;
Value colStride;
Value offsetBaseX;
Value offsetBaseY;
};

// Unpack values as the params to 2DBlockLoad Payload:
Expand All @@ -168,12 +168,12 @@ struct LoadOpConversion
"unexpected number of values unpacked from a block pointer");
BlockPointerValues values{
.base = elems[6],
.offsetBaseX = elems[1],
.offsetBaseY = elems[0],
.baseWidth = elems[3],
.baseHeight = elems[2],
.rowStride = elems[4],
.colStride = elems[5],
.offsetBaseX = elems[1],
.offsetBaseY = elems[0],
};
return values;
}
Expand Down Expand Up @@ -211,8 +211,9 @@ struct LoadOpConversion
SmallVector<Value> multiDimWarpId =
delinearize(rewriter, loc, warpId, warpsPerCTA, order);

bool isOperandA = (opIdx == 0);
SmallVector<unsigned> operandShape =
opIdx == 0 ? dpasLayout.getShapeA() : dpasLayout.getShapeB();
isOperandA ? dpasLayout.getShapeA() : dpasLayout.getShapeB();
SmallVector<int64_t> elemsPerInstr = {operandShape[0], operandShape[1]};
int64_t elemsPerLane = product<int64_t>(elemsPerInstr) /
product<unsigned>(getThreadsPerWarp(dpasLayout));
Expand All @@ -221,10 +222,10 @@ struct LoadOpConversion
typeConverter->convertType(eltTy), elemsPerLane);

// pack scalar to i16 for operand A, to i32 for operand B.
Type elemType = (opIdx == 0) ? i16_ty : i32_ty;
Type elemType = isOperandA ? i16_ty : i32_ty;
unsigned opsPerChannel = dpasLayout.getOpsPerChannel();
elemsPerLane = (opIdx == 0) ? elemsPerLane / (opsPerChannel == 4 ? 2 : 1)
: elemsPerLane / opsPerChannel;
elemsPerLane = isOperandA ? elemsPerLane / (opsPerChannel == 4 ? 2 : 1)
: elemsPerLane / opsPerChannel;
Type load2DGenXType = LLVM::getFixedVectorType(elemType, elemsPerLane);

// Outer dim, A is the M, B is the N. Inner dim, the K
Expand All @@ -233,14 +234,9 @@ struct LoadOpConversion
Value outerDimWarpId =
urem(multiDimWarpId[opIdx], i32_val(outerDimWarpNum));

BlockPointerValues blockPtrStruct =
auto [base, baseWidth, baseHeight, rowStride, colStride, offsetBaseX,
offsetBaseY] =
getValuesFromBlockPointerStruct(adaptor.getPtr(), rewriter);
Value base = blockPtrStruct.base;
Value offsetBaseX = blockPtrStruct.offsetBaseX;
Value offsetBaseY = blockPtrStruct.offsetBaseY;
Value baseHeight = blockPtrStruct.baseHeight;
Value baseWidth = blockPtrStruct.baseWidth;
Value rowStride = blockPtrStruct.rowStride;

// Load the operand.
int64_t numRepOuter = numReps[opIdx];
Expand All @@ -250,12 +246,12 @@ struct LoadOpConversion
for (int outer = 0; outer < numRepOuter; ++outer) {
for (int k = 0; k < numRepK; ++k) {
Value offsetX =
(opIdx == 0)
isOperandA
? i32_val(k * elemsPerInstr[1])
: add(mul(outerDimWarpId, i32_val(elemsPerInstr[opIdx])),
i32_val(outer * outerDimWarpNum * elemsPerInstr[opIdx]));
Value offsetY =
(opIdx == 0)
isOperandA
? add(mul(outerDimWarpId, i32_val(elemsPerInstr[opIdx])),
i32_val(outer * outerDimWarpNum * elemsPerInstr[opIdx]))
: i32_val(k * elemsPerInstr[0]);
Expand Down Expand Up @@ -292,7 +288,7 @@ struct LoadOpConversion
getIntAttr(i1_ty, 0),
/*vnni_transform*/
getIntAttr(i1_ty,
(opIdx == 0 || eltTy.getIntOrFloatBitWidth() == 32)
(isOperandA || eltTy.getIntOrFloatBitWidth() == 32)
? /*A vnni=false*/ 0
: /*B vnni=true*/ 1));

Expand Down Expand Up @@ -333,6 +329,9 @@ struct LoadOpConversion
if (isTensorPointerType(ptr.getType()))
return rewriteTensorPointerLoad(op, adaptor, rewriter);

assert(!isTensorPointerType(ptr.getType()) &&
"Cannot convert load with a tensor pointer into LLVM; "
"this case should be transformed to normal load before lowering");
Value llPtr = adaptor.getPtr();
Value llMask = adaptor.getMask();
Value llOther = adaptor.getOther();
Expand Down

0 comments on commit f55a08d

Please sign in to comment.