diff --git a/python/tutorials/09-experimental-block-pointer.py b/python/tutorials/09-experimental-block-pointer.py
index e8c1d18d4d..1bf0e88009 100644
--- a/python/tutorials/09-experimental-block-pointer.py
+++ b/python/tutorials/09-experimental-block-pointer.py
@@ -167,7 +167,7 @@ def matmul_kernel_with_block_pointers(
     # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block.
     # of fp32 values for higher accuracy.
     # `accumulator` will be converted back to fp16 after the loop.
-    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=c_ptr.type.element_ty)
     for k in range(0, K, BLOCK_SIZE_K):
         # Load with boundary checks, no need to calculate the mask manually.
         # For better performance, you may remove some axis from the boundary
@@ -182,7 +182,7 @@ def matmul_kernel_with_block_pointers(
         # See above `Advance a Block Pointer` section for details.
         a_block_ptr = tl.advance(a_block_ptr, (0, BLOCK_SIZE_K))
         b_block_ptr = tl.advance(b_block_ptr, (BLOCK_SIZE_K, 0))
-    c = accumulator.to(tl.float32)
+    c = accumulator.to(c_ptr.type.element_ty)
     # ----------------------------------------------------------------
     # Write back the block of the output matrix C with boundary checks.
     # See above `Load/Store a Block Pointer` section for details.
@@ -194,7 +194,7 @@
 # We can now create a convenience wrapper function that only takes two input tensors,
 # and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel.
-def matmul(a, b):
+def matmul(a, b, res_dtype):
     # Check constraints.
     assert a.shape[1] == b.shape[0], "Incompatible dimensions"
     assert a.is_contiguous(), "Matrix A must be contiguous"
@@ -202,7 +202,7 @@ def matmul(a, b):
     M, K = a.shape
     K, N = b.shape
     # Allocates output.
-    c = torch.empty((M, N), device=a.device, dtype=torch.float32)
+    c = torch.empty((M, N), device=a.device, dtype=res_dtype)
     # 1D launch kernel where each block gets its own program.
     grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
     matmul_kernel_with_block_pointers[grid](
@@ -222,11 +222,22 @@ def matmul(a, b):
 # Still we can test our matrix multiplication with block pointers against a native torch implementation (i.e., cuBLAS).
 torch.manual_seed(0)
-for dtype in [torch.float16, torch.bfloat16]:
-    a = torch.randn((512, 512), device='xpu', dtype=dtype)
-    b = torch.randn((512, 512), device='xpu', dtype=dtype)
-    triton_output = matmul(a, b)
-    torch_output = torch.matmul(a, b).to(torch.float32)
+for dtype, res_dtype in [(torch.float16, torch.float32), (torch.bfloat16, torch.float32), (torch.int8, torch.int32)]:
+    if dtype.is_floating_point:
+        a = torch.randn((512, 512), device='xpu', dtype=dtype)
+        b = torch.randn((512, 512), device='xpu', dtype=dtype)
+    else:
+        a = torch.randint(low=-127, high=128, size=(512, 512), device='xpu', dtype=dtype)
+        b = torch.randint(low=-127, high=128, size=(512, 512), device='xpu', dtype=dtype)
+
+    triton_output = matmul(a, b, res_dtype)
+    if dtype.is_floating_point:
+        torch_output = torch.matmul(a, b).to(res_dtype)
+    else:
+        # torch.matmul clamps values to input dtype; IPEX doesn't support int32 matmul
+        torch_output = torch.matmul(a.to(device='cpu', dtype=res_dtype), b.to(device='cpu',
+                                                                              dtype=res_dtype)).to(device='xpu')
+
     print(f"triton_output={triton_output}")
     print(f"torch_output={torch_output}")
diff --git a/test/TritonIntelGPU/match-target-size.mlir b/test/TritonIntelGPU/match-target-size.mlir
index 3110907987..c121aa5aab 100644
--- a/test/TritonIntelGPU/match-target-size.mlir
+++ b/test/TritonIntelGPU/match-target-size.mlir
@@ -167,3 +167,59 @@ tt.func public @simplify_scf_for(%arg0: tensor<16x8xf16>, %arg1: tensor<16x8xf16
   }
   tt.return
 }
+
+// -----
+
+// COM: Test transformation for int8 datatype
+
+// CHECK-LABEL: @matmul_kernel_with_block_pointers
+#warp = #triton_intel_gpu.warp<{sizePerThread = [8, 32], threadsPerWarp = [1, 1], order = [1, 0]}>
+tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: i32) {
+  // CHECK-DAG: [[C0:%.*]] = arith.constant 0 : i32
+  // CHECK-DAG: [[C32:%.*]] = arith.constant 32 : i32
+  %cst = arith.constant dense<0> : tensor<8x32xi32, #warp>
+  %c0_i32 = arith.constant 0 : i32
+  %c0_i64 = arith.constant 0 : i64
+  %c1_i64 = arith.constant 1 : i64
+  %c64_i32 = arith.constant 64 : i32
+  // CHECK: [[TPTR_A:%.*]] = tt.make_tensor_ptr %arg0, [{{.*}}, {{.*}}], [{{.*}}, {{.*}}], [{{.*}}, [[C0]]]
+  // CHECK: [[TPTR_B1:%.*]] = tt.make_tensor_ptr %arg1, [{{.*}}, {{.*}}], [{{.*}}, {{.*}}], [[[C0]], {{.*}}]
+  // CHECK: [[TPTR_B2:%.*]] = tt.make_tensor_ptr %arg1, [{{.*}}, {{.*}}], [{{.*}}, {{.*}}], [[[C32]], {{.*}}]
+  %tptr_a = tt.make_tensor_ptr %arg0, [%c0_i64, %c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : >>
+  %tptr_b = tt.make_tensor_ptr %arg1, [%c0_i64,%c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : >>
+  // CHECK: [[LOOP_RES:%.*]]:5 = scf.for {{.*}} = {{.*}} to {{.*}} step {{.*}} iter_args([[ITER_1:%.*]] = {{.*}}, [[ITER_2:%.*]] = {{.*}}, [[TPTR_A_ITER:%.*]] = [[TPTR_A]], [[TPTR_B1_ITER:%.*]] = [[TPTR_B1]], [[TPTR_B2_ITER:%.*]] = [[TPTR_B2]])
+  %35:3 = scf.for %arg9 = %c0_i32 to %arg5 step %c64_i32 iter_args(%arg10 = %cst, %arg11 = %tptr_a, %arg12 = %tptr_b) -> (tensor<8x32xi32, #warp>, !tt.ptr>>, !tt.ptr>>) : i32 {
+    // CHECK: [[LD_A:%.*]] = tt.load [[TPTR_A_ITER]] {DotIdx = 0 : i32, boundaryCheck = array} : !tt.ptr>
+    // CHECK: [[LD_B1:%.*]] = tt.load [[TPTR_B1_ITER]] {DotIdx = 1 : i32, boundaryCheck = array} : !tt.ptr>
+    // CHECK: [[LD_B2:%.*]] = tt.load [[TPTR_B2_ITER]] {DotIdx = 1 : i32, boundaryCheck = array} : !tt.ptr>
+    %46 = tt.load %arg11 {boundaryCheck = array} : !tt.ptr>>
+    %47 = tt.load %arg12 {boundaryCheck = array} : !tt.ptr>>
+    // CHECK: [[EX_A_0:%.*]] = triton_intel_gpu.extract [[LD_A]][0] : tensor<8x64xi8> -> tensor<8x32xi8>
+    // CHECK: [[EX_B1_0:%.*]] = triton_intel_gpu.extract [[LD_B1]][0] : tensor<32x32xi8> -> tensor<32x16xi8>
+    // CHECK: [[DOT_1:%.*]] = tt.dot [[EX_A_0]], [[EX_B1_0]], [[ITER_1]], inputPrecision = tf32 : tensor<8x32xi8> * tensor<32x16xi8> -> tensor<8x16xi32>
+    // CHECK: [[EX_A_1:%.*]] = triton_intel_gpu.extract [[LD_A]][1] : tensor<8x64xi8> -> tensor<8x32xi8>
+    // CHECK: [[EX_B2_0:%.*]] = triton_intel_gpu.extract [[LD_B2]][0] : tensor<32x32xi8> -> tensor<32x16xi8>
+    // CHECK: [[DOT_2:%.*]] = tt.dot [[EX_A_1]], [[EX_B2_0]], [[DOT_1]], inputPrecision = tf32 : tensor<8x32xi8> * tensor<32x16xi8> -> tensor<8x16xi32>
+    // CHECK: [[EX_A_0:%.*]] = triton_intel_gpu.extract [[LD_A]][0] : tensor<8x64xi8> -> tensor<8x32xi8>
+    // CHECK: [[EX_B1_1:%.*]] = triton_intel_gpu.extract [[LD_B1]][1] : tensor<32x32xi8> -> tensor<32x16xi8>
+    // CHECK: [[DOT_3:%.*]] = tt.dot [[EX_A_0]], [[EX_B1_1]], [[ITER_2]], inputPrecision = tf32 : tensor<8x32xi8> * tensor<32x16xi8> -> tensor<8x16xi32>
+    // CHECK: [[EX_A_1:%.*]] = triton_intel_gpu.extract [[LD_A]][1] : tensor<8x64xi8> -> tensor<8x32xi8>
+    // CHECK: [[EX_B2_1:%.*]] = triton_intel_gpu.extract [[LD_B2]][1] : tensor<32x32xi8> -> tensor<32x16xi8>
+    // CHECK: [[DOT_4:%.*]] = tt.dot [[EX_A_1]], [[EX_B2_1]], [[DOT_3]], inputPrecision = tf32 : tensor<8x32xi8> * tensor<32x16xi8> -> tensor<8x16xi32>
+    %48 = tt.dot %46, %47, %arg10, inputPrecision = tf32 : tensor<8x64xi8, #triton_gpu.dot_op<{opIdx = 0, parent = #warp}>> * tensor<64x32xi8, #triton_gpu.dot_op<{opIdx = 1, parent = #warp}>> -> tensor<8x32xi32, #warp>
+    // CHECK: [[ADV_A:%.*]] = tt.advance [[TPTR_A_ITER]],
+    // CHECK: [[ADV_B1:%.*]] = tt.advance [[TPTR_B1_ITER]],
+    // CHECK: [[ADV_B2:%.*]] = tt.advance [[TPTR_B2_ITER]],
+    %49 = tt.advance %arg11, [%c0_i32, %c64_i32] : >>
+    %50 = tt.advance %arg12, [%c64_i32, %c0_i32] : >>
+    // CHECK: scf.yield [[DOT_2]], [[DOT_4]], [[ADV_A]], [[ADV_B1]], [[ADV_B2]]
+    scf.yield %48, %49, %50 : tensor<8x32xi32, #warp>, !tt.ptr>>, !tt.ptr>>
+  } {triton_gpu.workload = 3 : i32}
+  // CHECK: [[TPTR_C1:%.*]] = tt.make_tensor_ptr %arg2,
+  // CHECK: [[TPTR_C2:%.*]] = tt.make_tensor_ptr %arg2,
+  %tptr_c = tt.make_tensor_ptr %arg2, [%c0_i64, %c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : >
+  // CHECK: tt.store [[TPTR_C1:%.*]], [[LOOP_RES]]#0 {boundaryCheck = array} : !tt.ptr>
+  // CHECK: tt.store [[TPTR_C2:%.*]], [[LOOP_RES]]#1 {boundaryCheck = array} : !tt.ptr>
+  tt.store %tptr_c, %35#0 {boundaryCheck = array} : !tt.ptr>
+  tt.return
+}
diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp
index 81149a8a1b..98208c2ed3 100644
--- a/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp
+++ b/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp
@@ -121,11 +121,22 @@ class LoadStorePrefetchOpConversion
            "only support 1d/2d load/store/prefetch for now");
     unsigned dataSize = tensorType.getElementType().getIntOrFloatBitWidth();
-    unsigned blockWidth = tensorType.getShape()[1];
-    assert(blockWidth == 16 || blockWidth == 32 && "only support 16/32 block");
-    unsigned vBlks = blockWidth == 32 ? 2 : 1;
-    blockWidth = 16;
     unsigned blockHeight = tensorType.getShape()[0];
+    unsigned blockWidth = tensorType.getShape()[1];
+    assert((blockWidth == 16 || blockWidth == 32 || blockWidth == 64) &&
+           "only support 16/32/64 block");
+    auto idxAttr = op->template getAttrOfType<IntegerAttr>("DotIdx");
+    unsigned vBlks = 1;
+    if (dataSize == 16) {
+      vBlks = ceil(blockWidth, 16U);
+      blockWidth = 16;
+    } else if (dataSize == 8 && idxAttr) {
+      unsigned blockWidthUnit = idxAttr.getInt() == 0 ? 32 : 16;
+      vBlks = ceil(blockWidth, blockWidthUnit);
+      blockWidth = blockWidthUnit;
+    }
+    assert((vBlks == 1 || vBlks == 2) && "only support 1 or 2 blocks");
+
     Value ptr = op.getPtr();
     if (auto cast = dyn_cast(ptr.getDefiningOp()))
@@ -160,7 +171,7 @@ class LoadStorePrefetchOpConversion
     Value offsetY = extract_element(tensorPtr, i32_val(1));
     if constexpr (std::is_same_v) {
-      auto idxAttr = op->template getAttrOfType<IntegerAttr>("DotIdx");
+      assert(idxAttr && "Dot index attribute missing");
       unsigned idx = idxAttr.getInt();
       Type resType =
           this->getTypeConverter()->convertType(op->getResult(0).getType());
@@ -215,6 +226,12 @@ class DotOpConversion : public ConvertTritonGPUOpToLLVMPattern {
         return TritonGEN::PrecisionType::FP16;
       else if (type == rewriter.getTF32Type())
         return TritonGEN::PrecisionType::TF32;
+      else if (type.isInteger(8)) {
+        if (type.isUnsignedInteger())
+          return TritonGEN::PrecisionType::U8;
+        return TritonGEN::PrecisionType::S8;
+      }
+
       llvm_unreachable("add more support for PrecisionType");
       return TritonGEN::PrecisionType::UNUSED;
     };
diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/MatchTargetSize.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/MatchTargetSize.cpp
index 35828c29ac..aa00a7886e 100644
--- a/third_party/intel/lib/TritonIntelGPUTransforms/MatchTargetSize.cpp
+++ b/third_party/intel/lib/TritonIntelGPUTransforms/MatchTargetSize.cpp
@@ -71,27 +71,60 @@ namespace {
 class TargetArchNativeSizes {
 public:
   struct DotShape {
-    DotShape() = default;
     DotShape(unsigned m, unsigned n, unsigned k) : m(m), n(n), k(k) {
       assert(m != 0 && n != 0 && k != 0 && "expecting valid shape");
     }
-    unsigned m = 0;
-    unsigned n = 0;
-    unsigned k = 0;
+    const unsigned m;
+    const unsigned n;
+    const unsigned k;
+  };
+
+  struct BlockMemShape {
+    BlockMemShape(unsigned rowsA, unsigned columnsA, unsigned rowsB,
+                  unsigned columnsB)
+        : rowsA(rowsA), columnsA(columnsA), rowsB(rowsB), columnsB(columnsB) {
+      assert(rowsA != 0 && columnsA != 0 && rowsB != 0 && columnsB != 0 &&
+             "expecting valid shape");
+    }
+
+    const unsigned rowsA;
+    const unsigned columnsA;
+    const unsigned rowsB;
+    const unsigned columnsB;
   };
 
   TargetArchNativeSizes() = default;
-  TargetArchNativeSizes(DotShape dotShape, unsigned loadStoreSize)
-      : dotShape(dotShape), loadStoreSize(loadStoreSize) {}
 
-  void setDotShape(DotShape shape) { dotShape = shape; }
+  void setDotShape(unsigned bitWidth, DotShape &&shape) {
+    assert(!dotShapes.contains(bitWidth) && "Dot shape already set");
+    dotShapes.try_emplace(bitWidth, std::move(shape));
+  }
+  void setBlockMemShape(unsigned bitWidth, BlockMemShape &&shape) {
+    assert(!blockMemShapes.contains(bitWidth) &&
+           "Block memory access shape already set");
+    blockMemShapes.try_emplace(bitWidth, std::move(shape));
+  }
   void setLoadStoreSize(unsigned size) { loadStoreSize = size; }
-  const DotShape &getDotShape() const { return dotShape; }
+  const DotShape &getDotShape(unsigned bitWidth) const {
+    assert(dotShapes.contains(bitWidth) &&
+           "No dot shape configured for bit width");
+    return dotShapes.at(bitWidth);
+  }
+  const BlockMemShape &getBlockMemShape(unsigned bitWidth) const {
+    assert(blockMemShapes.contains(bitWidth) &&
+           "No block memory access shape configured for bit width");
+    return blockMemShapes.at(bitWidth);
+  }
   unsigned getLoadStoreSize() const { return loadStoreSize; }
 
 private:
-  DotShape dotShape;
+  /// Stores the natively supported dot shape per bitwidth of the operand data
+  /// type, e.g. 16 -> 8x16x16 (MxKxN) for [b]float16 on PVC.
+  llvm::SmallDenseMap<unsigned, DotShape> dotShapes;
+  /// Stores the natively supported shapes for 2D block reads of dot operands,
+  /// per element type bitwidth.
+  llvm::SmallDenseMap<unsigned, BlockMemShape> blockMemShapes;
   unsigned loadStoreSize = 0;
 };
@@ -388,9 +421,16 @@ class ScfPattern : public OpRewritePattern {
 void MatchTargetSizePass::initNativeOperationSizes() {
   // FIXME: sets the target dot shape natively supported by the target
   // architecture using the target architecture information when available.
-  // These value works for PVC.
-  TargetArchNativeSizes::DotShape shape(8, 16, 16);
-  nativeSizes.setDotShape(shape);
+  // These values work for PVC.
+
+  nativeSizes.setDotShape(8, {8, 16, 32});
+  nativeSizes.setDotShape(16, {8, 16, 16});
+  nativeSizes.setDotShape(32, {8, 16, 8});
+
+  nativeSizes.setBlockMemShape(8, {16, 64, 32, 32});
+  nativeSizes.setBlockMemShape(16, {32, 32, 32, 32});
+  nativeSizes.setBlockMemShape(32, {8, 8, 8, 16});
+
   nativeSizes.setLoadStoreSize(512); // max 512DW;
 }
@@ -453,14 +493,16 @@ MatchTargetSizePass::getSubOpSize(RankedTensorType type) const {
   // Dot operation.
   if (dotAttrs.count(layout)) {
-    const auto &dotShape = nativeSizes.getDotShape();
+    const TargetArchNativeSizes::DotShape &dotShape =
+        nativeSizes.getDotShape(type.getElementTypeBitWidth());
     SmallVector nativeDotSize{dotShape.m, dotShape.n};
     return nativeDotSize;
   }
   // Load/Store operations.
   ArrayRef shape = type.getShape();
-  const unsigned sizeInBytes = type.getElementTypeBitWidth() / 8;
+  const unsigned sizeInBits = type.getElementTypeBitWidth();
+  const unsigned sizeInBytes = sizeInBits / 8;
   unsigned maxLoadStoreSize = nativeSizes.getLoadStoreSize();
   SmallVector subSize(shape.size());
@@ -470,14 +512,29 @@ MatchTargetSizePass::getSubOpSize(RankedTensorType type) const {
     subSize[0] = std::min(max, shape[0]);
   } break;
   case 2: {
-    // 32 = 2 * 16(subgroupSize) which is for large load/store
-    int64_t colLimit =
-        (isa(layout)) ? 32
-                      : 0;
-    subSize[1] = (shape[1] > colLimit) ? colLimit : shape[1];
-    // FIXME: From gfxspec, max 2d block load height is 32
-    int64_t max = 32;
-    subSize[0] = std::min(max, shape[0]);
+    if (isa(layout)) {
+      // 32 = 2 * 16(subgroupSize) which is for large load/store
+      subSize[1] = std::min(32L, shape[1]);
+      // FIXME: From gfxspec, max 2d block load height is 32
+      subSize[0] = std::min(32L, shape[0]);
+    } else if (auto dotLayout = dyn_cast(layout)) {
+      const TargetArchNativeSizes::BlockMemShape &memShape =
+          nativeSizes.getBlockMemShape(sizeInBits);
+      switch (dotLayout.getOpIdx()) {
+      case 0:
+        subSize[1] =
+            std::min(static_cast<int64_t>(memShape.columnsA), shape[1]);
+        subSize[0] = std::min(static_cast<int64_t>(memShape.rowsA), shape[0]);
+        break;
+      case 1:
+        subSize[1] =
+            std::min(static_cast<int64_t>(memShape.columnsB), shape[1]);
+        subSize[0] = std::min(static_cast<int64_t>(memShape.rowsB), shape[0]);
+        break;
+      }
+    } else {
+      llvm_unreachable("Unsupported layout");
+    }
   } break;
   default:
     llvm_unreachable("Unsupported shape");
@@ -593,7 +650,9 @@ void MatchTargetSizePass::transformDotOp(tt::DotOp dot) {
   int64_t m = aShape[0];
   int64_t n = bShape[1];
   int64_t k = aShape[1];
-  const auto &dotShape = nativeSizes.getDotShape();
+  const TargetArchNativeSizes::DotShape &dotShape =
+      nativeSizes.getDotShape(aType.getElementTypeBitWidth());
+
   OpBuilder b(dot);
   Location loc = dot.getLoc();
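
The new tutorial test computes the int8 reference product on the CPU in the accumulator dtype because, as the comment in the patch notes, torch.matmul keeps the input dtype and IPEX does not support int32 matmul on 'xpu'. A minimal sketch of that reference path, assuming only stock PyTorch and using the CPU so it runs without an 'xpu' device:

```python
import torch

torch.manual_seed(0)
a = torch.randint(low=-127, high=128, size=(512, 512), dtype=torch.int8)
b = torch.randint(low=-127, high=128, size=(512, 512), dtype=torch.int8)

# Widen the operands to the accumulator dtype before multiplying, mirroring the
# a.to(dtype=res_dtype) / b.to(dtype=res_dtype) fallback in the patch; multiplying
# the int8 tensors directly would accumulate in int8 and overflow.
reference = torch.matmul(a.to(torch.int32), b.to(torch.int32))
assert reference.dtype == torch.int32
print(reference[:2, :2])
```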
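For readers following the LoadStorePrefetchOpConversion change: the patch splits a dot-operand tile into one or two hardware 2D blocks, using a 32-column unit for the int8 A operand and a 16-column unit otherwise. The Python restatement below is only an illustration of that arithmetic (the name pick_block_shape is hypothetical; the real code works on MLIR types and the DotIdx attribute):

```python
from math import ceil

def pick_block_shape(data_size_bits, block_width, dot_idx=None):
    """Return (block_width, v_blks) for a 2D block load of a tile."""
    assert block_width in (16, 32, 64), "only 16/32/64 block widths supported"
    v_blks = 1
    if data_size_bits == 16:
        v_blks = ceil(block_width / 16)
        block_width = 16
    elif data_size_bits == 8 and dot_idx is not None:
        unit = 32 if dot_idx == 0 else 16  # A operand: 32 columns, B operand: 16
        v_blks = ceil(block_width / unit)
        block_width = unit
    assert v_blks in (1, 2), "only 1 or 2 blocks supported"
    return block_width, v_blks

# An 8x64 int8 A tile becomes two 32-wide blocks and a 32x32 int8 B tile becomes
# two 16-wide blocks, matching the extract/dot pairs in the MLIR test above.
assert pick_block_shape(8, 64, dot_idx=0) == (32, 2)
assert pick_block_shape(8, 32, dot_idx=1) == (16, 2)
assert pick_block_shape(16, 32) == (16, 2)
```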