diff --git a/test/Target/LLVMIR/triton-gen.mlir b/test/Target/LLVMIR/triton-gen.mlir index 6ef952f8c3..a4774ac7a2 100644 --- a/test/Target/LLVMIR/triton-gen.mlir +++ b/test/Target/LLVMIR/triton-gen.mlir @@ -16,3 +16,29 @@ llvm.func spir_kernelcc @test_reqd_work_group_size() attributes {triton_gen.reqd // CHECK-DAG: ![[REQD_SUB_GROUP_SIZE]] = !{i64 32} // CHECK-DAG: ![[MAX_WORK_GROUP_SIZE]] = !{i64 128, i64 1, i64 1} // CHECK-DAG: ![[REQD_WORK_GROUP_SIZE]] = !{i64 128, i64 1, i64 2} + +// ----- + +llvm.func @foo(%arg0: !llvm.ptr, %arg1: !llvm.ptr) + +// CHECK-LABEL: define void @triton_gen.cache_controls( +// CHECK-SAME: ptr %[[#ARG0:]]) { +llvm.func @triton_gen.cache_controls(%arg0: !llvm.ptr) { + // CHECK: %[[#LOAD:]] = load i32, ptr %[[#ARG0]], align 4, !spirv.DecorationCacheControlINTEL ![[#DECORATION0:]] + %0 = llvm.load %arg0 {triton_gen.DecorationCacheControlINTEL = #triton_gen.decoration_cache_control<#triton_gen.load_cache_control<0, Cached, 0>, #triton_gen.load_cache_control<1, Uncached, 0>>} : !llvm.ptr -> i32 + // CHECK: store i32 %[[#LOAD]], ptr %[[#ARG0]], align 4, !spirv.DecorationCacheControlINTEL ![[#DECORATION1:]] + llvm.store %0, %arg0 {triton_gen.DecorationCacheControlINTEL = #triton_gen.decoration_cache_control<#triton_gen.store_cache_control<0, WriteBack, 1>, #triton_gen.store_cache_control<1, Streaming, 1>>} : i32, !llvm.ptr + // CHECK: call void @foo(ptr %[[#ARG0]], ptr %[[#ARG0]]), !spirv.DecorationCacheControlINTEL ![[#DECORATION2:]] + llvm.call @foo(%arg0, %arg0) {triton_gen.DecorationCacheControlINTEL = #triton_gen.decoration_cache_control<#triton_gen.store_cache_control<0, Uncached, 0>, #triton_gen.load_cache_control<0, Cached, 1>>} : (!llvm.ptr, !llvm.ptr) -> () + llvm.return +} + +// CHECK-DAG: ![[#DECORATION0]] = !{![[#CACHECONTROL0:]], ![[#CACHECONTROL1:]]} +// CHECK-DAG: ![[#DECORATION1]] = !{![[#CACHECONTROL2:]], ![[#CACHECONTROL3:]]} +// CHECK-DAG: ![[#DECORATION2]] = !{![[#CACHECONTROL4:]], ![[#CACHECONTROL5:]]} +// CHECK-DAG: ![[#CACHECONTROL0]] = !{i32 6442, i32 0, i32 1, i32 0} +// CHECK-DAG: ![[#CACHECONTROL1]] = !{i32 6442, i32 1, i32 0, i32 0} +// CHECK-DAG: ![[#CACHECONTROL2]] = !{i32 6443, i32 0, i32 2, i32 1} +// CHECK-DAG: ![[#CACHECONTROL3]] = !{i32 6443, i32 1, i32 3, i32 1} +// CHECK-DAG: ![[#CACHECONTROL4]] = !{i32 6443, i32 0, i32 0, i32 0} +// CHECK-DAG: ![[#CACHECONTROL5]] = !{i32 6442, i32 0, i32 1, i32 1} diff --git a/test/TritonGEN/tritongen-invalid.mlir b/test/TritonGEN/tritongen-invalid.mlir index 69097c990f..20af5ff2d8 100644 --- a/test/TritonGEN/tritongen-invalid.mlir +++ b/test/TritonGEN/tritongen-invalid.mlir @@ -1,5 +1,21 @@ // RUN: triton-opt -split-input-file -verify-diagnostics %s +llvm.func @triton_gen.duplicated_cache_controls(%arg0: !llvm.ptr) { + // expected-error @+1 {{'triton_gen.decoration_cache_controls' cannot specify more than one cache control decoration of the same nature for the same cache level}} + %0 = llvm.load %arg0 {triton_gen.DecorationCacheControlINTEL = #triton_gen.decoration_cache_control<#triton_gen.load_cache_control<0, Uncached, 0>, #triton_gen.load_cache_control<0, Cached, 0>>} : !llvm.ptr -> i32 + llvm.return +} + +// ----- + +llvm.func @triton_gen.illegal_cache_controls_attr(%arg0: !llvm.ptr) { + // expected-error @+1 {{'triton_gen.decoration_cache_controls' only accepts LoadCacheControlDecorationAttr and StoreCacheControlDecorationAttr attributes}} + %0 = llvm.load %arg0 {triton_gen.DecorationCacheControlINTEL = #triton_gen.decoration_cache_control<1 : i32>} : !llvm.ptr -> i32 + llvm.return +} + +// ----- + llvm.func @triton_gen.dpas(%c : vector<8xi32>, %a : vector<8xi16>, %b : vector<8xi32>) { // expected-error @+1 {{'triton_gen.dpas' op expecting repeat count to be 1, 2, 4, or 8}} %0 = triton_gen.dpas %c, %a, %b {pa=i8, pb=i8, rc=16} : (vector<8xi32>, vector<8xi16>, vector<8xi32>) -> vector<8xi32> diff --git a/test/TritonGEN/tritongen.mlir b/test/TritonGEN/tritongen.mlir index b1baa74d25..8fadf8552e 100644 --- a/test/TritonGEN/tritongen.mlir +++ b/test/TritonGEN/tritongen.mlir @@ -138,6 +138,19 @@ llvm.func @triton_gen.dpas(%c : vector<8xi32>, %a : vector<8xi16>, %b : vector<8 llvm.return } +llvm.func @foo(%arg0: !llvm.ptr, %arg1: !llvm.ptr) + +llvm.func @triton_gen.cache_controls(%arg0: !llvm.ptr) { + // CHECK: llvm.func @triton_gen.cache_controls(%arg0: !llvm.ptr) + // CHECK-NEXT: %0 = llvm.load %arg0 {triton_gen.DecorationCacheControlINTEL = #triton_gen.decoration_cache_control<#triton_gen.load_cache_control<0, Cached, 0>, #triton_gen.load_cache_control<1, Uncached, 0>>} : !llvm.ptr -> i32 + %0 = llvm.load %arg0 {triton_gen.DecorationCacheControlINTEL = #triton_gen.decoration_cache_control<#triton_gen.load_cache_control<0, Cached, 0>, #triton_gen.load_cache_control<1, Uncached, 0>>} : !llvm.ptr -> i32 + // CHECK-NEXT: llvm.store %0, %arg0 {triton_gen.DecorationCacheControlINTEL = #triton_gen.decoration_cache_control<#triton_gen.store_cache_control<0, WriteBack, 1>, #triton_gen.store_cache_control<1, Streaming, 1>>} : i32, !llvm.ptr + llvm.store %0, %arg0 {triton_gen.DecorationCacheControlINTEL = #triton_gen.decoration_cache_control<#triton_gen.store_cache_control<0, WriteBack, 1>, #triton_gen.store_cache_control<1, Streaming, 1>>} : i32, !llvm.ptr + // CHECK-NEXT: llvm.call @foo(%arg0, %arg0) {triton_gen.DecorationCacheControlINTEL = #triton_gen.decoration_cache_control<#triton_gen.store_cache_control<0, Uncached, 0>, #triton_gen.load_cache_control<0, Cached, 1>>} : (!llvm.ptr, !llvm.ptr) -> () + llvm.call @foo(%arg0, %arg0) {triton_gen.DecorationCacheControlINTEL = #triton_gen.decoration_cache_control<#triton_gen.store_cache_control<0, Uncached, 0>, #triton_gen.load_cache_control<0, Cached, 1>>} : (!llvm.ptr, !llvm.ptr) -> () + llvm.return +} + llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { // CHECK: llvm.func @triton_gen.2Dblockload(%arg0: !llvm.ptr, %arg1: i32, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32) { // CHECK-NEXT: %0 = triton_gen.2Dblockload %arg0, %arg1, %arg2, %arg3, %arg4, %arg5 {elem_size_in_bits = 16, tile_width = 16, tile_height = 16, v_blocks = 1, transpose = false, vnni_transform = false, cache_control = Default} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<16xf16> diff --git a/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENAttrDefs.td b/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENAttrDefs.td index 46c2098653..8029990166 100644 --- a/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENAttrDefs.td +++ b/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENAttrDefs.td @@ -9,8 +9,17 @@ #ifndef TRITONGEN_ATTRDEFS #define TRITONGEN_ATTRDEFS +include "intel/include/Dialect/TritonGEN/IR/TritonGENDialect.td" + +include "mlir/IR/AttrTypeBase.td" include "mlir/IR/EnumAttr.td" +class TritonGEN_Attr traits = []> + : AttrDef { + let mnemonic = attrMnemonic; + let cppNamespace = "::mlir::triton::TritonGEN"; +} + /// Enum attribute of the different reduce kinds. def TritonGEN_ReduceKindAttr : I32EnumAttr<"ReduceKind", "TritonGEN reduce kind", [ @@ -129,4 +138,102 @@ def TritonGEN_MemScope : I32EnumAttr<"MemScope", let cppNamespace = "::mlir::triton::TritonGEN"; } +class TritonGEN_LoadStoreCacheControlDecoration + : TritonGEN_Attr { + let summary = "An attribute specifying " # !tolower(loadOrStore) # " cache control"; + let description = [{ + A }] # !tolower(loadOrStore) # [{ cache control attribute has a one-to-one + correspondance with the SPIR-V decoration shown in + https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/INTEL/SPV_INTEL_cache_controls.asciidoc#decorations. The + only differences are there is no need to add the SPIR-V decoration key, as + that will be inferred from the attribute type, and an additional + `operand_number` parameter is needed, as this is applied to pointer user + operations, following + https://github.com/KhronosGroup/SPIRV-LLVM-Translator/pull/2587 design. + }]; + let parameters = (ins "uint32_t":$cache_level, + loadOrStore # "CacheControlDecorationEnum":$cache_control, + "uint32_t":$operand_number); + let assemblyFormat = "`<` $cache_level `,` $cache_control `,` $operand_number `>`"; +} + +/// Enum attribute for load cache controls. +/// +/// See +/// https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/INTEL/SPV_INTEL_cache_controls.asciidoc#Load_Cache_Control. +def TritonGEN_LoadCacheControlDecorationEnum : I32EnumAttr<"LoadCacheControlDecorationEnum", + "TritonGEN load cache controls", + [ I32EnumAttrCase<"Uncached", 0, "Uncached">, + I32EnumAttrCase<"Cached", 1, "Cached">, + I32EnumAttrCase<"Streaming", 2, "Streaming">, + I32EnumAttrCase<"InvalidateAfterRead", 3, "InvalidateAfterRead">, + I32EnumAttrCase<"ConstCached", 4, "ConstCached"> + ]> { + let cppNamespace = "::mlir::triton::TritonGEN"; +} + +def TritonGEN_LoadCacheControlDecoration + : TritonGEN_LoadStoreCacheControlDecoration<"Load">; + +/// Enum attribute for store cache controls. +/// +/// See +/// https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/INTEL/SPV_INTEL_cache_controls.asciidoc#Store_Cache_Control. +def TritonGEN_StoreCacheControlDecorationEnum : I32EnumAttr<"StoreCacheControlDecorationEnum", + "TritonGEN store cache controls", + [ I32EnumAttrCase<"Uncached", 0, "Uncached">, + I32EnumAttrCase<"WriteThrough", 1, "WriteThrough">, + I32EnumAttrCase<"WriteBack", 2, "WriteBack">, + I32EnumAttrCase<"Streaming", 3, "Streaming">, + ]> { + let cppNamespace = "::mlir::triton::TritonGEN"; +} + +def TritonGEN_StoreCacheControlDecoration + : TritonGEN_LoadStoreCacheControlDecoration<"Store">; + +def TritonGEN_CacheControlsDecoration + : TritonGEN_Attr<"DecorationCacheControl", "decoration_cache_control"> { + let summary = "An attribute specifying an operation cache controls"; + let description = [{ + Attribute corresponding to `!spirv.DecorationCacheControlINTEL` metadata + described in + https://github.com/KhronosGroup/SPIRV-LLVM-Translator/pull/2587. This + metadata is a list of `LoadCacheControlDecoration` and + `StoreCacheControlDecoration` attributes specifying cache control + information. + + The following MLIR code using this attribute: + + ```mlir + llvm.func @triton_gen.cache_controls(%arg0: !llvm.ptr) { + %0 = llvm.load %arg0 {triton_gen.DecorationCacheControlINTEL = #triton_gen.decoration_cache_control<#triton_gen.store_cache_control<0, Uncached, 0>, #triton_gen.store_cache_control<1, WriteThrough, 0>, #triton_gen.load_cache_control<0, Cached, 0>, #triton_gen.load_cache_control<1, Uncached, 0>>} : !llvm.ptr -> i32 + llvm.return + } + ``` + + Will be translated to the following LLVM IR: + + ```llvm + define void @triton_gen.cache_controls(ptr %0) { + %1 = load i32, ptr %0, align 4, !spirv.DecorationCacheControlINTEL !1 + ret void + } + + !1 = !{!2, !3, !4, !5} + !2 = !{i32 6443, i32 0, i32 0, i32 0} + !3 = !{i32 6443, i32 1, i32 1, i32 0} + !4 = !{i32 6442, i32 0, i32 1, i32 0} + !5 = !{i32 6442, i32 1, i32 0, i32 0} + ``` + + $decorations must be a non-empty list of cache controls attributes and two + attributes of the same nature (load or store) in the list cannot be applied + to the same cache level, as per the SPIR-V validation rules. + }]; + let parameters = (ins ArrayRefParameter<"::mlir::Attribute">:$decorations); + let assemblyFormat = "`<` $decorations `>`"; + let genVerifyDecl = 1; +} + #endif // TRITONGEN_ATTRDEFS diff --git a/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENDialect.td b/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENDialect.td index 232e849f6f..c7ade3af3d 100644 --- a/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENDialect.td +++ b/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENDialect.td @@ -15,6 +15,7 @@ def TritonGEN_Dialect : Dialect { let name = "triton_gen"; let cppNamespace = "::mlir::triton::TritonGEN"; let summary = "The TritonGEN dialect in Triton."; + let useDefaultAttributePrinterParser = 1; let description = [{ TritonGEN is a dialect for representing operations on Intel GPUs. @@ -40,6 +41,12 @@ def TritonGEN_Dialect : Dialect { static constexpr ::llvm::StringLiteral getReqdSubGroupSizeAttrName() { return ::llvm::StringLiteral("triton_gen.intel_reqd_sub_group_size"); } + + /// Get the name for the attribute used to specify cache control + /// decorations. + static constexpr ::llvm::StringRef getCacheControlsAttrName() { + return ::llvm::StringLiteral("triton_gen.DecorationCacheControlINTEL"); + } }]; } diff --git a/third_party/intel/lib/Dialect/TritonGEN/IR/CMakeLists.txt b/third_party/intel/lib/Dialect/TritonGEN/IR/CMakeLists.txt index 1c53ec8a2f..4cae1191e5 100644 --- a/third_party/intel/lib/Dialect/TritonGEN/IR/CMakeLists.txt +++ b/third_party/intel/lib/Dialect/TritonGEN/IR/CMakeLists.txt @@ -1,4 +1,5 @@ add_triton_library(TritonGENIR + TritonGENAttrs.cpp TritonGENDialect.cpp TritonGENOps.cpp diff --git a/third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENAttrs.cpp b/third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENAttrs.cpp new file mode 100644 index 0000000000..cecda52560 --- /dev/null +++ b/third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENAttrs.cpp @@ -0,0 +1,54 @@ +//===- TritonGENAttrs.cpp - TritonGEN Attributes Definition --------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "intel/include/Dialect/TritonGEN/IR/TritonGENDialect.h" + +#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/TypeSwitch.h" + +namespace mlir::triton::TritonGEN { +//===----------------------------------------------------------------------===// +// triton_gen.decoration_cache_control +//===----------------------------------------------------------------------===// + +LogicalResult TritonGEN::DecorationCacheControlAttr::verify( + llvm::function_ref emitError, + ArrayRef decorations) { + llvm::SmallSet loadCacheLevels; + llvm::SmallSet storeCacheLevels; + for (Attribute attr : decorations) { + LogicalResult res = + TypeSwitch(attr) + .Case([emitError, &loadCacheLevels, + &storeCacheLevels]( + auto attr) + -> LogicalResult { + llvm::SmallSet &cacheLevels = + std::is_same_v + ? loadCacheLevels + : storeCacheLevels; + if (!cacheLevels.insert(attr.getCacheLevel()).second) + return emitError() + << "'triton_gen.decoration_cache_controls' cannot " + "specify more than one cache control decoration of " + "the same nature for the same cache level"; + return success(); + }) + .Default([emitError](Attribute attr) -> LogicalResult { + return emitError() + << "'triton_gen.decoration_cache_controls' only accepts " + "LoadCacheControlDecorationAttr and " + "StoreCacheControlDecorationAttr attributes"; + }); + if (failed(res)) + return res; + } + return success(); +} +} // namespace mlir::triton::TritonGEN diff --git a/third_party/intel/lib/Target/LLVMIR/Dialect/TritonGEN/TritonGENToLLVMIRTranslation.cpp b/third_party/intel/lib/Target/LLVMIR/Dialect/TritonGEN/TritonGENToLLVMIRTranslation.cpp index 1c710e6527..cce8f6c720 100644 --- a/third_party/intel/lib/Target/LLVMIR/Dialect/TritonGEN/TritonGENToLLVMIRTranslation.cpp +++ b/third_party/intel/lib/Target/LLVMIR/Dialect/TritonGEN/TritonGENToLLVMIRTranslation.cpp @@ -35,10 +35,77 @@ class TritonGENDialectLLVMIRTranslationInterface amendOperation(Operation *op, ArrayRef instructions, NamedAttribute attribute, LLVM::ModuleTranslation &moduleTranslation) const final { - // Skip the attribute if it is not a TritonGEN attribute. - if (!attribute.getName().getValue().starts_with("triton_gen")) - return success(); + StringRef attrName = attribute.getName().getValue(); + if (attrName == + triton::TritonGEN::TritonGENDialect::getCacheControlsAttrName()) { + auto decorationAttr = + dyn_cast( + attribute.getValue()); + if (!decorationAttr) + return op->emitOpError( + "Expecting triton_gen.decoration_cache_control attribute"); + if (instructions.size() != 1) + return op->emitOpError("Expecting a single instruction"); + return handleDecorationCacheControl(op, instructions.front(), + decorationAttr); + } + if (attrName.starts_with("triton_gen")) + return handleTritonGenAttr(op, attribute, moduleTranslation); + return success(); + } + +private: + template + static llvm::Metadata *getConstantIntMD(llvm::Type *type, IntTy val) { + return llvm::ConstantAsMetadata::get(llvm::ConstantInt::get(type, val)); + } + static LogicalResult handleDecorationCacheControl( + Operation *op, llvm::Instruction *inst, + triton::TritonGEN::DecorationCacheControlAttr attribute) { + ArrayRef attrs = attribute.getDecorations(); + SmallVector decorations; + llvm::LLVMContext &ctx = inst->getContext(); + llvm::Type *i32Ty = llvm::IntegerType::getInt32Ty(ctx); + llvm::transform( + attrs, std::back_inserter(decorations), + [&ctx, i32Ty](Attribute attr) -> llvm::Metadata * { + return TypeSwitch(attr) + .Case( + [&ctx, i32Ty](auto attr) { + constexpr size_t decorationCacheControlArity = 4; + constexpr uint32_t loadCacheControlKey = 6442; + constexpr uint32_t storeCacheControlKey = 6443; + constexpr uint32_t decorationKey = + std::is_same_v< + decltype(attr), + triton::TritonGEN::LoadCacheControlDecorationAttr> + ? loadCacheControlKey + : storeCacheControlKey; + std::array values{ + decorationKey, attr.getCacheLevel(), + static_cast(attr.getCacheControl()), + attr.getOperandNumber()}; + std::array + metadata; + llvm::transform(values, metadata.begin(), + [i32Ty](uint32_t value) { + return getConstantIntMD(i32Ty, value); + }); + return llvm::MDNode::get(ctx, metadata); + }); + }); + constexpr llvm::StringLiteral decorationCacheControlMDName = + "spirv.DecorationCacheControlINTEL"; + inst->setMetadata(decorationCacheControlMDName, + llvm::MDNode::get(ctx, decorations)); + return success(); + } + + LogicalResult + handleTritonGenAttr(Operation *op, NamedAttribute attribute, + LLVM::ModuleTranslation &moduleTranslation) const { llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext(); llvm::Function *llvmFunc = moduleTranslation.lookupFunction(cast(op).getName()); @@ -47,7 +114,6 @@ class TritonGENDialectLLVMIRTranslationInterface return success(); } -private: // Checks if the given operation is a kernel function. bool isKernel(Operation *op) const { auto fn = dyn_cast(op);