Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Minimal viable support for int8 on the block pointer path #1155

Merged
merged 21 commits into from
May 30, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 20 additions & 9 deletions python/tutorials/09-experimental-block-pointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def matmul_kernel_with_block_pointers(
# We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block.
# of fp32 values for higher accuracy.
# `accumulator` will be converted back to fp16 after the loop.
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=c_ptr.type.element_ty)
for k in range(0, K, BLOCK_SIZE_K):
# Load with boundary checks, no need to calculate the mask manually.
# For better performance, you may remove some axis from the boundary
Expand All @@ -184,7 +184,7 @@ def matmul_kernel_with_block_pointers(
# See above `Advance a Block Pointer` section for details.
a_block_ptr = tl.advance(a_block_ptr, (0, BLOCK_SIZE_K))
b_block_ptr = tl.advance(b_block_ptr, (BLOCK_SIZE_K, 0))
c = accumulator.to(tl.float32)
c = accumulator.to(c_ptr.type.element_ty)
# ----------------------------------------------------------------
# Write back the block of the output matrix C with boundary checks.
# See above `Load/Store a Block Pointer` section for details.
Expand All @@ -196,15 +196,15 @@ def matmul_kernel_with_block_pointers(

# We can now create a convenience wrapper function that only takes two input tensors,
# and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel.
def matmul(a, b):
def matmul(a, b, res_dtype):
# Check constraints.
assert a.shape[1] == b.shape[0], "Incompatible dimensions"
assert a.is_contiguous(), "Matrix A must be contiguous"
assert b.is_contiguous(), "Matrix B must be contiguous"
M, K = a.shape
K, N = b.shape
# Allocates output.
c = torch.empty((M, N), device=a.device, dtype=torch.float32)
c = torch.empty((M, N), device=a.device, dtype=res_dtype)
# 1D launch kernel where each block gets its own program.
grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
matmul_kernel_with_block_pointers[grid](
Expand All @@ -224,11 +224,22 @@ def matmul(a, b):
# Still we can test our matrix multiplication with block pointers against a native torch implementation (i.e., cuBLAS).

torch.manual_seed(0)
for dtype in [torch.float16, torch.bfloat16]:
a = torch.randn((512, 512), device='xpu', dtype=dtype)
b = torch.randn((512, 512), device='xpu', dtype=dtype)
triton_output = matmul(a, b)
torch_output = torch.matmul(a, b).to(torch.float32)
for dtype, res_dtype in [(torch.float16, torch.float32), (torch.bfloat16, torch.float32), (torch.int8, torch.int32)]:
if dtype.is_floating_point:
a = torch.randn((512, 512), device='xpu', dtype=dtype)
b = torch.randn((512, 512), device='xpu', dtype=dtype)
else:
a = torch.randint(low=-127, high=128, size=(512, 512), device='xpu', dtype=dtype)
b = torch.randint(low=-127, high=128, size=(512, 512), device='xpu', dtype=dtype)

triton_output = matmul(a, b, res_dtype)
if dtype.is_floating_point:
torch_output = torch.matmul(a, b).to(res_dtype)
else:
# torch.matmul clamps values to input dtype; IPEX doesn't support int32 matmul
torch_output = torch.matmul(a.to(device='cpu', dtype=res_dtype), b.to(device='cpu',
dtype=res_dtype)).to(device='xpu')

print(f"triton_output={triton_output}")
print(f"torch_output={torch_output}")

Expand Down
23 changes: 18 additions & 5 deletions third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,11 +121,22 @@ class LoadStorePrefetchOpConversion
"only support 1d/2d load/store/prefetch for now");

unsigned dataSize = tensorType.getElementType().getIntOrFloatBitWidth();
unsigned blockWidth = tensorType.getShape()[1];
assert(blockWidth == 16 || blockWidth == 32 && "only support 16/32 block");
unsigned vBlks = blockWidth == 32 ? 2 : 1;
blockWidth = 16;
unsigned blockHeight = tensorType.getShape()[0];
unsigned blockWidth = tensorType.getShape()[1];
assert(blockWidth == 16 || blockWidth == 32 ||
blockWidth == 64 && "only support 16/32/64 block");
jopperm marked this conversation as resolved.
Show resolved Hide resolved
auto idxAttr = op->template getAttrOfType<mlir::IntegerAttr>("DotIdx");
unsigned vBlks = 1;
if (dataSize == 16) {
vBlks = blockWidth / 16;
jopperm marked this conversation as resolved.
Show resolved Hide resolved
whitneywhtsang marked this conversation as resolved.
Show resolved Hide resolved
blockWidth = 16;
} else if (dataSize == 8 && idxAttr) {
unsigned blockWidthUnit = idxAttr.getInt() == 0 ? 32 : 16;
vBlks = llvm::divideCeil(blockWidth, blockWidthUnit);
whitneywhtsang marked this conversation as resolved.
Show resolved Hide resolved
blockWidth = blockWidthUnit;
}
assert(vBlks == 1 || vBlks == 2 && "only support 1 or 2 blocks");
jopperm marked this conversation as resolved.
Show resolved Hide resolved

Value ptr = op.getPtr();
if (auto cast =
dyn_cast<mlir::UnrealizedConversionCastOp>(ptr.getDefiningOp()))
Expand All @@ -151,7 +162,7 @@ class LoadStorePrefetchOpConversion
Value offsetY = extract_element(tensorPtr, i32_val(1));

if constexpr (std::is_same_v<OpType, LoadOp>) {
auto idxAttr = op->template getAttrOfType<mlir::IntegerAttr>("DotIdx");
assert(idxAttr && "Dot index attribute missing");
unsigned idx = idxAttr.getInt();
Type resType =
this->getTypeConverter()->convertType(op->getResult(0).getType());
Expand Down Expand Up @@ -206,6 +217,8 @@ class DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<DotOp> {
return TritonGEN::PrecisionType::FP16;
else if (type == rewriter.getTF32Type())
return TritonGEN::PrecisionType::TF32;
else if (type == i8_ty)
return TritonGEN::PrecisionType::S8;
llvm_unreachable("add more support for PrecisionType");
return TritonGEN::PrecisionType::UNUSED;
};
Expand Down
63 changes: 44 additions & 19 deletions third_party/intel/lib/TritonIntelGPUTransforms/MatchTargetSize.cpp
whitneywhtsang marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -82,16 +82,18 @@ class TargetArchNativeSizes {
};

TargetArchNativeSizes() = default;
TargetArchNativeSizes(DotShape dotShape, unsigned loadStoreSize)
: dotShape(dotShape), loadStoreSize(loadStoreSize) {}

void setDotShape(DotShape shape) { dotShape = shape; }
void setDotShape(unsigned bitWidth, DotShape shape) {
dotShapes[bitWidth] = shape;
}
void setLoadStoreSize(unsigned size) { loadStoreSize = size; }
const DotShape &getDotShape() const { return dotShape; }
std::optional<DotShape> getDotShape(unsigned bitWidth) const {
return dotShapes.lookup(bitWidth);
}
unsigned getLoadStoreSize() const { return loadStoreSize; }

private:
DotShape dotShape;
llvm::SmallDenseMap<unsigned, std::optional<DotShape>> dotShapes;
jopperm marked this conversation as resolved.
Show resolved Hide resolved
jopperm marked this conversation as resolved.
Show resolved Hide resolved
unsigned loadStoreSize = 0;
};

Expand Down Expand Up @@ -389,8 +391,9 @@ void MatchTargetSizePass::initNativeOperationSizes() {
// FIXME: sets the target dot shape natively supported by the target
// architecture using the target architecture information when available.
// These value works for PVC.
TargetArchNativeSizes::DotShape shape(8, 16, 16);
nativeSizes.setDotShape(shape);
nativeSizes.setDotShape(8, {/*m=*/8, /*n=*/16, /*k=*/32});
nativeSizes.setDotShape(16, {/*m=*/8, /*n=*/16, /*k=*/16});
nativeSizes.setDotShape(32, {/*m=*/8, /*n=*/16, /*k=*/8});
nativeSizes.setLoadStoreSize(512); // max 512DW;
}

Expand Down Expand Up @@ -453,14 +456,16 @@ MatchTargetSizePass::getSubOpSize(RankedTensorType type) const {

// Dot operation.
if (dotAttrs.count(layout)) {
const auto &dotShape = nativeSizes.getDotShape();
SmallVector<int64_t> nativeDotSize{dotShape.m, dotShape.n};
auto dotShape = nativeSizes.getDotShape(type.getElementTypeBitWidth());
jopperm marked this conversation as resolved.
Show resolved Hide resolved
assert(dotShape.has_value() && "Unknown dot shape");
jopperm marked this conversation as resolved.
Show resolved Hide resolved
SmallVector<int64_t> nativeDotSize{dotShape->m, dotShape->n};
return nativeDotSize;
}

// Load/Store operations.
ArrayRef<int64_t> shape = type.getShape();
const unsigned sizeInBytes = type.getElementTypeBitWidth() / 8;
const unsigned sizeInBits = type.getElementTypeBitWidth();
const unsigned sizeInBytes = sizeInBits / 8;
unsigned maxLoadStoreSize = nativeSizes.getLoadStoreSize();

SmallVector<int64_t> subSize(shape.size());
Expand All @@ -470,14 +475,31 @@ MatchTargetSizePass::getSubOpSize(RankedTensorType type) const {
subSize[0] = std::min(max, shape[0]);
} break;
case 2: {
// 32 = 2 * 16(subgroupSize) which is for large load/store
int64_t colLimit =
(isa<ttgi::WarpEncodingAttr, ttg::DotOperandEncodingAttr>(layout)) ? 32
: 0;
subSize[1] = (shape[1] > colLimit) ? colLimit : shape[1];
// FIXME: From gfxspec, max 2d block load height is 32
int64_t max = 32;
subSize[0] = std::min(max, shape[0]);
if (isa<ttgi::WarpEncodingAttr>(layout) ||
(isa<triton::gpu::DotOperandEncodingAttr>(layout) &&
sizeInBits == 16)) {
// 32 = 2 * 16(subgroupSize) which is for large load/store
int64_t colLimit = 32;
subSize[1] = std::min(colLimit, shape[1]);
// FIXME: From gfxspec, max 2d block load height is 32
int64_t rowLimit = 32;
subSize[0] = std::min(rowLimit, shape[0]);
} else if (auto dotLayout = dyn_cast<ttg::DotOperandEncodingAttr>(layout);
dotLayout && sizeInBits == 8) {
// FIXME: These settings underutilize the memory bandwidth.
switch (dotLayout.getOpIdx()) {
case 0:
subSize[0] = std::min(16L, shape[0]);
subSize[1] = std::min(64L, shape[1]); // 2 blocks of 32 cols each
break;
case 1:
subSize[0] = std::min(32L, shape[0]);
subSize[1] = std::min(32L, shape[1]); // 2 blocks of 16 cols each
break;
}
jopperm marked this conversation as resolved.
Show resolved Hide resolved
} else {
llvm_unreachable("Unsupported layout");
}
} break;
default:
llvm_unreachable("Unsupported shape");
Expand Down Expand Up @@ -593,7 +615,10 @@ void MatchTargetSizePass::transformDotOp(tt::DotOp dot) {
int64_t m = aShape[0];
int64_t n = bShape[1];
int64_t k = aShape[1];
const auto &dotShape = nativeSizes.getDotShape();
auto dotShapeOrNone = nativeSizes.getDotShape(aType.getElementTypeBitWidth());
assert(dotShapeOrNone.has_value() && "Unknown dot shape");
const auto &dotShape = *dotShapeOrNone;

OpBuilder b(dot);
Location loc = dot.getLoc();

Expand Down