From 49719dd320b84e8837b585a89b89eeb3de3a7e52 Mon Sep 17 00:00:00 2001 From: Victor Perez Date: Fri, 10 May 2024 16:06:07 +0100 Subject: [PATCH 1/9] [TritonGEN] Add `triton_gen.cache_control` operation Add `triton_gen.cache_control` operation to represent [SPV_INTEL_cache_controls](https://htmlpreview.github.io/?https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/INTEL/SPV_INTEL_cache_controls.html) decorations in MLIR. This operation does not convert to any operation in a different dialect, as it supports translation straight to LLVM IR as metadata. `triton-translate` is a new tool used to test this translation. Signed-off-by: Victor Perez --- test/Target/LLVMIR/triton-gen.mlir | 26 ++++ test/TritonGEN/tritongen-invalid.mlir | 16 +++ test/TritonGEN/tritongen.mlir | 15 +++ .../Dialect/TritonGEN/IR/TritonGENAttrDefs.td | 66 +++++++++++ .../Dialect/TritonGEN/IR/TritonGENDialect.td | 1 + .../Dialect/TritonGEN/IR/TritonGENOps.td | 16 +++ .../lib/Dialect/TritonGEN/IR/TritonGENOps.cpp | 36 ++++++ .../TritonGENToLLVMIRTranslation.cpp | 111 +++++++++++++++++- 8 files changed, 283 insertions(+), 4 deletions(-) diff --git a/test/Target/LLVMIR/triton-gen.mlir b/test/Target/LLVMIR/triton-gen.mlir index 6ef952f8c3..6aed1a541f 100644 --- a/test/Target/LLVMIR/triton-gen.mlir +++ b/test/Target/LLVMIR/triton-gen.mlir @@ -16,3 +16,29 @@ llvm.func spir_kernelcc @test_reqd_work_group_size() attributes {triton_gen.reqd // CHECK-DAG: ![[REQD_SUB_GROUP_SIZE]] = !{i64 32} // CHECK-DAG: ![[MAX_WORK_GROUP_SIZE]] = !{i64 128, i64 1, i64 1} // CHECK-DAG: ![[REQD_WORK_GROUP_SIZE]] = !{i64 128, i64 1, i64 2} + +// ----- + +// CHECK-LABEL: define void @triton_gen.cache_controls( +// CHECK-SAME: ptr %[[#ARG0:]]) { +llvm.func @triton_gen.cache_controls(%arg0: !llvm.ptr) { + %0 = triton_gen.cache_controls %arg0, [#triton_gen.store_cache_control<0, Uncached>, #triton_gen.store_cache_control<1, WriteThrough>, #triton_gen.load_cache_control<0, Cached>, #triton_gen.load_cache_control<1, Uncached>] : !llvm.ptr + %1 = triton_gen.cache_controls %arg0, [#triton_gen.store_cache_control<0, WriteBack>, #triton_gen.store_cache_control<1, Streaming>, #triton_gen.load_cache_control<0, Streaming>, #triton_gen.load_cache_control<1, InvalidateAfterRead>, #triton_gen.load_cache_control<2, ConstCached>] : !llvm.ptr + // CHECK: %[[#LOAD:]] = load i32, ptr %[[#ARG0]], align 4, !spirv.DecorationCacheControlINTEL ![[#DECORATION0:]] + %2 = llvm.load %0 : !llvm.ptr -> i32 + // CHECK: store i32 %[[#LOAD]], ptr %[[#ARG0]], align 4, !spirv.DecorationCacheControlINTEL ![[#DECORATION1:]] + llvm.store %2, %1 : i32, !llvm.ptr + llvm.return +} + +// CHECK-DAG: ![[#DECORATION0]] = !{![[#CACHECONTROL0:]], ![[#CACHECONTROL1:]], ![[#CACHECONTROL2:]], ![[#CACHECONTROL3:]]} +// CHECK-DAG: ![[#CACHECONTROL0]] = !{i32 6443, i32 0, i32 0, i32 0} +// CHECK-DAG: ![[#CACHECONTROL1]] = !{i32 6443, i32 1, i32 1, i32 0} +// CHECK-DAG: ![[#CACHECONTROL2]] = !{i32 6442, i32 0, i32 1, i32 0} +// CHECK-DAG: ![[#CACHECONTROL3]] = !{i32 6442, i32 1, i32 0, i32 0} +// CHECK-DAG: ![[#DECORATION1]] = !{![[#CACHECONTROL4:]], ![[#CACHECONTROL5:]], ![[#CACHECONTROL6:]], ![[#CACHECONTROL7:]], ![[#CACHECONTROL8:]]} +// CHECK-DAG: ![[#CACHECONTROL4]] = !{i32 6443, i32 0, i32 2, i32 1} +// CHECK-DAG: ![[#CACHECONTROL5]] = !{i32 6443, i32 1, i32 3, i32 1} +// CHECK-DAG: ![[#CACHECONTROL6]] = !{i32 6442, i32 0, i32 2, i32 1} +// CHECK-DAG: ![[#CACHECONTROL7]] = !{i32 6442, i32 1, i32 3, i32 1} +// CHECK-DAG: ![[#CACHECONTROL8]] = !{i32 6442, i32 2, i32 4, i32 1} diff --git a/test/TritonGEN/tritongen-invalid.mlir b/test/TritonGEN/tritongen-invalid.mlir index 69097c990f..ed77454988 100644 --- a/test/TritonGEN/tritongen-invalid.mlir +++ b/test/TritonGEN/tritongen-invalid.mlir @@ -1,5 +1,21 @@ // RUN: triton-opt -split-input-file -verify-diagnostics %s +llvm.func @triton_gen.empty_cache_controls(%arg0: !llvm.ptr) { + // expected-error @+1 {{'triton_gen.cache_controls' op expecting at least one cache control decoration}} + %0 = triton_gen.cache_controls %arg0, [] : !llvm.ptr + llvm.return +} + +// ----- + +llvm.func @triton_gen.duplicated_cache_controls(%arg0: !llvm.ptr) { + // expected-error @+1 {{'triton_gen.cache_controls' op cannot specify more than one cache control decoration of the same nature for the same cache level}} + %0 = triton_gen.cache_controls %arg0, [#triton_gen.store_cache_control<0, Uncached>, #triton_gen.store_cache_control<0, Streaming>] : !llvm.ptr + llvm.return +} + +// ----- + llvm.func @triton_gen.dpas(%c : vector<8xi32>, %a : vector<8xi16>, %b : vector<8xi32>) { // expected-error @+1 {{'triton_gen.dpas' op expecting repeat count to be 1, 2, 4, or 8}} %0 = triton_gen.dpas %c, %a, %b {pa=i8, pb=i8, rc=16} : (vector<8xi32>, vector<8xi16>, vector<8xi32>) -> vector<8xi32> diff --git a/test/TritonGEN/tritongen.mlir b/test/TritonGEN/tritongen.mlir index 4bfbc1c549..9489b66e5a 100644 --- a/test/TritonGEN/tritongen.mlir +++ b/test/TritonGEN/tritongen.mlir @@ -105,6 +105,21 @@ llvm.func @triton_gen.dpas(%c : vector<8xi32>, %a : vector<8xi16>, %b : vector<8 llvm.return } +llvm.func @triton_gen.cache_controls(%arg0: !llvm.ptr) { + // CHECK-LABEL: llvm.func @triton_gen.cache_controls( + // CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr) { + // CHECK: %[[VAL_1:.*]] = triton_gen.cache_controls %[[VAL_0]], [#triton_gen.store_cache_control<0, Uncached>, #triton_gen.store_cache_control<1, WriteThrough>, #triton_gen.load_cache_control<0, Cached>, #triton_gen.load_cache_control<1, Uncached>] : !llvm.ptr + // CHECK: %[[VAL_2:.*]] = triton_gen.cache_controls %[[VAL_0]], [#triton_gen.store_cache_control<0, WriteBack>, #triton_gen.store_cache_control<1, Streaming>, #triton_gen.load_cache_control<0, Streaming>, #triton_gen.load_cache_control<1, InvalidateAfterRead>, #triton_gen.load_cache_control<2, ConstCached>] : !llvm.ptr + // CHECK: %[[VAL_3:.*]] = llvm.load %[[VAL_1]] : !llvm.ptr -> i32 + // CHECK: llvm.store %[[VAL_3]], %[[VAL_2]] : i32, !llvm.ptr + // CHECK: llvm.return + %0 = triton_gen.cache_controls %arg0, [#triton_gen.store_cache_control<0, Uncached>, #triton_gen.store_cache_control<1, WriteThrough>, #triton_gen.load_cache_control<0, Cached>, #triton_gen.load_cache_control<1, Uncached>] : !llvm.ptr + %1 = triton_gen.cache_controls %arg0, [#triton_gen.store_cache_control<0, WriteBack>, #triton_gen.store_cache_control<1, Streaming>, #triton_gen.load_cache_control<0, Streaming>, #triton_gen.load_cache_control<1, InvalidateAfterRead>, #triton_gen.load_cache_control<2, ConstCached>] : !llvm.ptr + %2 = llvm.load %0 : !llvm.ptr -> i32 + llvm.store %2, %1 : i32, !llvm.ptr + llvm.return +} + llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { // CHECK: llvm.func @triton_gen.2Dblockload(%arg0: !llvm.ptr, %arg1: i32, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32) { // CHECK-NEXT: %0 = triton_gen.2Dblockload %arg0, %arg1, %arg2, %arg3, %arg4, %arg5 {elem_size_in_bits = 16, tile_width = 16, tile_height = 16, v_blocks = 1, transpose = false, vnni_transform = false, cache_control = Default} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<16xf16> diff --git a/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENAttrDefs.td b/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENAttrDefs.td index b609719818..71ab9d2e56 100644 --- a/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENAttrDefs.td +++ b/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENAttrDefs.td @@ -9,8 +9,17 @@ #ifndef TRITONGEN_ATTRDEFS #define TRITONGEN_ATTRDEFS +include "intel/include/Dialect/TritonGEN/IR/TritonGENDialect.td" + +include "mlir/IR/AttrTypeBase.td" include "mlir/IR/EnumAttr.td" +class TritonGEN_Attr traits = []> + : AttrDef { + let mnemonic = attrMnemonic; + let cppNamespace = "::mlir::triton::TritonGEN"; +} + /// Enum attribute of the different shuffle kinds. def TritonGEN_ShflKindAttr : I32EnumAttr<"ShflKind", "TritonGEN shuffle kind", [ @@ -109,4 +118,61 @@ def TritonGEN_MemScope : I32EnumAttr<"MemScope", let cppNamespace = "::mlir::triton::TritonGEN"; } +/// Enum attribute for load cache controls. +/// +/// See +/// https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/INTEL/SPV_INTEL_cache_controls.asciidoc#Load_Cache_Control. +def TritonGEN_LoadCacheControlDecorationEnum : I32EnumAttr<"LoadCacheControlDecorationEnum", + "TritonGEN load cache controls", + [ I32EnumAttrCase<"Uncached", 0, "Uncached">, + I32EnumAttrCase<"Cached", 1, "Cached">, + I32EnumAttrCase<"Streaming", 2, "Streaming">, + I32EnumAttrCase<"InvalidateAfterRead", 3, "InvalidateAfterRead">, + I32EnumAttrCase<"ConstCached", 4, "ConstCached"> + ]> { + let cppNamespace = "::mlir::triton::TritonGEN"; +} + +class TritonGEN_LoadStoreCacheControlDecoration + : TritonGEN_Attr { + let summary = "An attribute specifying " # !tolower(loadOrStore) # " cache control"; + let description = [{ + A }] # !tolower(loadOrStore) # [{ cache control attribute has a one-to-one correspondance with the SPIR-V + decoration shown in + https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/INTEL/SPV_INTEL_cache_controls.asciidoc#decorations. The + only difference is there is no need to add the SPIR-V decoration key as that + will be inferred from the value. + }]; + let parameters = (ins "uint32_t":$cache_level, + loadOrStore # "CacheControlDecorationEnum":$cache_control); + let assemblyFormat = "`<` $cache_level `,` $cache_control `>`"; +} + +def TritonGEN_LoadCacheControlDecoration + : TritonGEN_LoadStoreCacheControlDecoration<"Load">; + +/// Enum attribute for store cache controls. +/// +/// See +/// https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/INTEL/SPV_INTEL_cache_controls.asciidoc#Store_Cache_Control. +def TritonGEN_StoreCacheControlDecorationEnum : I32EnumAttr<"StoreCacheControlDecorationEnum", + "TritonGEN store cache controls", + [ I32EnumAttrCase<"Uncached", 0, "Uncached">, + I32EnumAttrCase<"WriteThrough", 1, "WriteThrough">, + I32EnumAttrCase<"WriteBack", 2, "WriteBack">, + I32EnumAttrCase<"Streaming", 3, "Streaming">, + ]> { + let cppNamespace = "::mlir::triton::TritonGEN"; +} + +def TritonGEN_StoreCacheControlDecoration + : TritonGEN_LoadStoreCacheControlDecoration<"Store">; + +def CacheControls + : AnyAttrOf<[TritonGEN_LoadCacheControlDecoration, + TritonGEN_StoreCacheControlDecoration]>; + +def CacheControlsArray + : TypedArrayAttrBase; + #endif // TRITONGEN_ATTRDEFS diff --git a/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENDialect.td b/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENDialect.td index 232e849f6f..872d48ebc3 100644 --- a/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENDialect.td +++ b/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENDialect.td @@ -15,6 +15,7 @@ def TritonGEN_Dialect : Dialect { let name = "triton_gen"; let cppNamespace = "::mlir::triton::TritonGEN"; let summary = "The TritonGEN dialect in Triton."; + let useDefaultAttributePrinterParser = 1; let description = [{ TritonGEN is a dialect for representing operations on Intel GPUs. diff --git a/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENOps.td b/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENOps.td index 98f35a8c84..a12d24a99a 100644 --- a/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENOps.td +++ b/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENOps.td @@ -263,6 +263,22 @@ def TritonGEN_SubGroupShuffleOp : TritonGEN_Op<"sub_group_shuffle", [ }]; } +//===----------------------------------------------------------------------===// +// Metaoperations +//===----------------------------------------------------------------------===// + +def TritonGEN_CacheControls + : TritonGEN_Op<"cache_controls", + [AllTypesMatch<["ptr", "decorated_ptr"]>, Pure]> { + let arguments = (ins LLVM_AnyPointer:$ptr, + CacheControlsArray:$cache_controls); + let results = (outs LLVM_AnyPointer:$decorated_ptr); + let assemblyFormat = [{ + $ptr `,` $cache_controls attr-dict `:` type($ptr) + }]; + let hasVerifier = 1; +} + //===----------------------------------------------------------------------===// // Matrix operations //===----------------------------------------------------------------------===// diff --git a/third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENOps.cpp b/third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENOps.cpp index cae45ade80..01f63371f0 100644 --- a/third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENOps.cpp +++ b/third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENOps.cpp @@ -12,6 +12,8 @@ #include "intel/include/Dialect/TritonGEN/IR/TritonGENDialect.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/TypeSwitch.h" #include using namespace mlir; @@ -148,6 +150,40 @@ LogicalResult TritonGEN::MatrixDPASOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// triton_gen.cache_controls +//===----------------------------------------------------------------------===// + +LogicalResult TritonGEN::CacheControls::verify() { + ArrayRef cacheControls = getCacheControls().getValue(); + if (cacheControls.empty()) + return emitOpError("expecting at least one cache control decoration"); + llvm::SmallSetVector loadCacheLevels; + llvm::SmallSetVector storeCacheLevels; + for (Attribute attr : cacheControls) { + LogicalResult res = + TypeSwitch(attr) + .Case([this, &loadCacheLevels, + &storeCacheLevels]( + auto attr) + -> LogicalResult { + llvm::SmallSetVector &cacheLevels = + std::is_same_v + ? loadCacheLevels + : storeCacheLevels; + if (!cacheLevels.insert(attr.getCacheLevel())) + return emitOpError( + "cannot specify more than one cache control " + "decoration of the same nature for the same cache level"); + return success(); + }); + if (failed(res)) + return res; + } + return success(); +} + //===----------------------------------------------------------------------===// // gen.2Dblockload //===----------------------------------------------------------------------===// diff --git a/third_party/intel/lib/Target/LLVMIR/Dialect/TritonGEN/TritonGENToLLVMIRTranslation.cpp b/third_party/intel/lib/Target/LLVMIR/Dialect/TritonGEN/TritonGENToLLVMIRTranslation.cpp index 1c710e6527..36137f5e98 100644 --- a/third_party/intel/lib/Target/LLVMIR/Dialect/TritonGEN/TritonGENToLLVMIRTranslation.cpp +++ b/third_party/intel/lib/Target/LLVMIR/Dialect/TritonGEN/TritonGENToLLVMIRTranslation.cpp @@ -31,14 +31,118 @@ class TritonGENDialectLLVMIRTranslationInterface public: using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface; + constexpr static std::size_t decorationCacheControlArity = 4; + constexpr static llvm::StringLiteral decorationCacheControlAttrName = + "triton_gen.DecorationCacheControlINTEL"; + LogicalResult amendOperation(Operation *op, ArrayRef instructions, NamedAttribute attribute, LLVM::ModuleTranslation &moduleTranslation) const final { - // Skip the attribute if it is not a TritonGEN attribute. - if (!attribute.getName().getValue().starts_with("triton_gen")) - return success(); + StringRef attrName = attribute.getName().getValue(); + if (attrName == decorationCacheControlAttrName) { + assert(instructions.size() == 1 && "Expecting a single instruction"); + return handleDecorationCacheControl(instructions.front(), attribute, + moduleTranslation); + } + if (attrName.starts_with("triton_gen")) + return handleTritonGenAttr(op, attribute, moduleTranslation); + return success(); + } + LogicalResult + convertOperation(Operation *operation, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation) const final { + return TypeSwitch(operation) + .Case([&moduleTranslation](triton::TritonGEN::CacheControls op) { + llvm::Value *ptr = moduleTranslation.lookupValue(op.getPtr()); + moduleTranslation.mapValue(op, ptr); + Builder mlirBuilder(op); + for (OpOperand &use : op->getUses()) + appendDecoration(mlirBuilder, decorationCacheControlAttrName, + use.getOwner(), op.getCacheControls(), + use.getOperandNumber()); + return success(); + }) + .Default([](Operation *op) { + return op->emitError("unsupported TritonGEN operation: ") + << op->getName(); + }); + } + +private: + template + static llvm::Metadata *getConstantIntMD(llvm::Type *type, IntTy val) { + return llvm::ConstantAsMetadata::get(llvm::ConstantInt::get(type, val)); + } + + static void appendDecoration(Builder &mlirBuilder, llvm::StringRef mdKey, + Operation *op, ArrayAttr newAttrs, + unsigned operandNumber) { + auto attr = op->getAttrOfType(mdKey); + SmallVector attrs = attr + ? SmallVector{attr.getValue()} + : SmallVector{}; + llvm::transform( + newAttrs.getValue(), std::back_inserter(attrs), + [&mlirBuilder, operandNumber](Attribute attr) -> Attribute { + return TypeSwitch(attr) + .Case( + [&mlirBuilder, operandNumber](auto attr) { + constexpr int32_t loadCacheControlKey = 6442; + constexpr int32_t storeCacheControlKey = 6443; + constexpr int32_t key = + std::is_same_v< + decltype(attr), + triton::TritonGEN::LoadCacheControlDecorationAttr> + ? loadCacheControlKey + : storeCacheControlKey; + int32_t cacheLevel = attr.getCacheLevel(); + int32_t cacheControl = + static_cast(attr.getCacheControl()); + return mlirBuilder.getDenseI32ArrayAttr( + {key, cacheLevel, cacheControl, + static_cast(operandNumber)}); + }); + }); + op->setAttr(mdKey, mlirBuilder.getArrayAttr(attrs)); + } + + static LogicalResult + handleDecorationCacheControl(llvm::Instruction *inst, + NamedAttribute attribute, + LLVM::ModuleTranslation &moduleTranslation) { + assert(attribute.getName() == decorationCacheControlAttrName && + "Expecting decoration cache key"); + auto arrayAttr = cast(attribute.getValue()); + ArrayRef attrs = arrayAttr.getValue(); + SmallVector decorations; + llvm::LLVMContext &ctx = inst->getContext(); + llvm::transform( + attrs, std::back_inserter(decorations), [&ctx](Attribute attr) { + auto arrayAttr = cast(attr); + ArrayRef attrs = arrayAttr.asArrayRef(); + assert(attrs.size() == decorationCacheControlArity && + "Invalid decoration cache attribute arity"); + constexpr unsigned numBits = 32; + llvm::Type *type = llvm::IntegerType::get(ctx, numBits); + std::array metadata; + llvm::transform(attrs, metadata.begin(), [type](int val) { + return getConstantIntMD(type, val); + }); + return llvm::MDNode::get(ctx, metadata); + }); + constexpr static llvm::StringLiteral decorationCacheControlMDName = + "spirv.DecorationCacheControlINTEL"; + inst->setMetadata(decorationCacheControlMDName, + llvm::MDNode::get(ctx, decorations)); + return success(); + } + + LogicalResult + handleTritonGenAttr(Operation *op, NamedAttribute attribute, + LLVM::ModuleTranslation &moduleTranslation) const { llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext(); llvm::Function *llvmFunc = moduleTranslation.lookupFunction(cast(op).getName()); @@ -47,7 +151,6 @@ class TritonGENDialectLLVMIRTranslationInterface return success(); } -private: // Checks if the given operation is a kernel function. bool isKernel(Operation *op) const { auto fn = dyn_cast(op); From aca16892865d99e7c74e43408f9cd6d0e6f32184 Mon Sep 17 00:00:00 2001 From: Victor Perez Date: Wed, 29 May 2024 16:03:42 +0100 Subject: [PATCH 2/9] Document new operation. --- .../Dialect/TritonGEN/IR/TritonGENOps.td | 37 +++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENOps.td b/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENOps.td index a12d24a99a..5db9a58429 100644 --- a/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENOps.td +++ b/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENOps.td @@ -270,6 +270,43 @@ def TritonGEN_SubGroupShuffleOp : TritonGEN_Op<"sub_group_shuffle", [ def TritonGEN_CacheControls : TritonGEN_Op<"cache_controls", [AllTypesMatch<["ptr", "decorated_ptr"]>, Pure]> { + let summary = "Operation implementing SPIR-V cache control decorations"; + let description = [{ + This operation implements an MLIR representation of SPIR-V cache controls. + `triton_gen.cache_controls` can be translated directly to LLVM IR. User + instructions in the resulting translation will present metadata representing + cache controls as specified in + https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/INTEL/SPV_INTEL_cache_controls.asciidoc. + This way, the following MLIR code: + + ```mlir + llvm.func @triton_gen.cache_controls(%arg0: !llvm.ptr) { + %0 = triton_gen.cache_controls %arg0, [#triton_gen.store_cache_control<0, WriteBack>, #triton_gen.store_cache_control<1, Streaming>, #triton_gen.load_cache_control<0, Streaming>, #triton_gen.load_cache_control<1, InvalidateAfterRead>, #triton_gen.load_cache_control<2, ConstCached>] : !llvm.ptr + %1 = llvm.load %0 : !llvm.ptr -> i32 + llvm.return + } + ``` + + Will be translated to the following LLVM IR: + + ```llvm + define void @triton_gen.cache_controls(ptr %0) { + %2 = load i32, ptr %0, align 4, !spirv.DecorationCacheControlINTEL !1 + ret void + } + + !1 = !{!2, !3, !4, !5, !6} + !2 = !{i32 6443, i32 0, i32 2, i32 0} + !3 = !{i32 6443, i32 1, i32 3, i32 0} + !4 = !{i32 6442, i32 0, i32 2, i32 0} + !5 = !{i32 6442, i32 1, i32 3, i32 0} + !6 = !{i32 6442, i32 2, i32 4, i32 0} + ``` + + $cache_controls must be a non-empty list of cache controls attributes + and two attributes of the same nature (load or store) in the list cannot be + applied to the same cache level, as per the SPIR-V validation rules. + }]; let arguments = (ins LLVM_AnyPointer:$ptr, CacheControlsArray:$cache_controls); let results = (outs LLVM_AnyPointer:$decorated_ptr); From 8172daa1a86ac45390f0093521acc535e50c23fd Mon Sep 17 00:00:00 2001 From: Victor Perez Date: Wed, 29 May 2024 16:09:15 +0100 Subject: [PATCH 3/9] NIT --- .../Dialect/TritonGEN/TritonGENToLLVMIRTranslation.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/third_party/intel/lib/Target/LLVMIR/Dialect/TritonGEN/TritonGENToLLVMIRTranslation.cpp b/third_party/intel/lib/Target/LLVMIR/Dialect/TritonGEN/TritonGENToLLVMIRTranslation.cpp index 36137f5e98..be0a6aa0f3 100644 --- a/third_party/intel/lib/Target/LLVMIR/Dialect/TritonGEN/TritonGENToLLVMIRTranslation.cpp +++ b/third_party/intel/lib/Target/LLVMIR/Dialect/TritonGEN/TritonGENToLLVMIRTranslation.cpp @@ -31,7 +31,6 @@ class TritonGENDialectLLVMIRTranslationInterface public: using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface; - constexpr static std::size_t decorationCacheControlArity = 4; constexpr static llvm::StringLiteral decorationCacheControlAttrName = "triton_gen.DecorationCacheControlINTEL"; @@ -121,6 +120,8 @@ class TritonGENDialectLLVMIRTranslationInterface llvm::LLVMContext &ctx = inst->getContext(); llvm::transform( attrs, std::back_inserter(decorations), [&ctx](Attribute attr) { + constexpr std::size_t decorationCacheControlArity = 4; + auto arrayAttr = cast(attr); ArrayRef attrs = arrayAttr.asArrayRef(); assert(attrs.size() == decorationCacheControlArity && @@ -133,7 +134,7 @@ class TritonGENDialectLLVMIRTranslationInterface }); return llvm::MDNode::get(ctx, metadata); }); - constexpr static llvm::StringLiteral decorationCacheControlMDName = + constexpr llvm::StringLiteral decorationCacheControlMDName = "spirv.DecorationCacheControlINTEL"; inst->setMetadata(decorationCacheControlMDName, llvm::MDNode::get(ctx, decorations)); From 2cca769037f64efb002c8ad843fc7459bef27282 Mon Sep 17 00:00:00 2001 From: Victor Perez Date: Wed, 29 May 2024 16:29:30 +0100 Subject: [PATCH 4/9] Fail gracefully instead of erroring out --- .../TritonGENToLLVMIRTranslation.cpp | 43 ++++++++++--------- 1 file changed, 23 insertions(+), 20 deletions(-) diff --git a/third_party/intel/lib/Target/LLVMIR/Dialect/TritonGEN/TritonGENToLLVMIRTranslation.cpp b/third_party/intel/lib/Target/LLVMIR/Dialect/TritonGEN/TritonGENToLLVMIRTranslation.cpp index be0a6aa0f3..b122373150 100644 --- a/third_party/intel/lib/Target/LLVMIR/Dialect/TritonGEN/TritonGENToLLVMIRTranslation.cpp +++ b/third_party/intel/lib/Target/LLVMIR/Dialect/TritonGEN/TritonGENToLLVMIRTranslation.cpp @@ -40,8 +40,9 @@ class TritonGENDialectLLVMIRTranslationInterface LLVM::ModuleTranslation &moduleTranslation) const final { StringRef attrName = attribute.getName().getValue(); if (attrName == decorationCacheControlAttrName) { - assert(instructions.size() == 1 && "Expecting a single instruction"); - return handleDecorationCacheControl(instructions.front(), attribute, + if (instructions.size() != 1) + return op->emitOpError("Expecting a single instruction"); + return handleDecorationCacheControl(op, instructions.front(), attribute, moduleTranslation); } if (attrName.starts_with("triton_gen")) @@ -64,7 +65,7 @@ class TritonGENDialectLLVMIRTranslationInterface return success(); }) .Default([](Operation *op) { - return op->emitError("unsupported TritonGEN operation: ") + return op->emitOpError("unsupported TritonGEN operation: ") << op->getName(); }); } @@ -109,31 +110,33 @@ class TritonGENDialectLLVMIRTranslationInterface } static LogicalResult - handleDecorationCacheControl(llvm::Instruction *inst, + handleDecorationCacheControl(Operation *op, llvm::Instruction *inst, NamedAttribute attribute, LLVM::ModuleTranslation &moduleTranslation) { assert(attribute.getName() == decorationCacheControlAttrName && "Expecting decoration cache key"); - auto arrayAttr = cast(attribute.getValue()); + auto arrayAttr = dyn_cast(attribute.getValue()); + if (!arrayAttr) + return op->emitOpError("unexpected attribute type"); ArrayRef attrs = arrayAttr.getValue(); SmallVector decorations; llvm::LLVMContext &ctx = inst->getContext(); - llvm::transform( - attrs, std::back_inserter(decorations), [&ctx](Attribute attr) { - constexpr std::size_t decorationCacheControlArity = 4; + for (Attribute attr : attrs) { + constexpr std::size_t decorationCacheControlArity = 4; - auto arrayAttr = cast(attr); - ArrayRef attrs = arrayAttr.asArrayRef(); - assert(attrs.size() == decorationCacheControlArity && - "Invalid decoration cache attribute arity"); - constexpr unsigned numBits = 32; - llvm::Type *type = llvm::IntegerType::get(ctx, numBits); - std::array metadata; - llvm::transform(attrs, metadata.begin(), [type](int val) { - return getConstantIntMD(type, val); - }); - return llvm::MDNode::get(ctx, metadata); - }); + auto arrayAttr = dyn_cast(attr); + if (!arrayAttr) + return op->emitOpError("unexpected attribute type"); + ArrayRef attrs = arrayAttr.asArrayRef(); + if (attrs.size() != decorationCacheControlArity) + return op->emitOpError("Invalid decoration cache attribute arity"); + constexpr unsigned numBits = 32; + llvm::Type *type = llvm::IntegerType::get(ctx, numBits); + std::array metadata; + llvm::transform(attrs, metadata.begin(), + [type](int val) { return getConstantIntMD(type, val); }); + decorations.push_back(llvm::MDNode::get(ctx, metadata)); + } constexpr llvm::StringLiteral decorationCacheControlMDName = "spirv.DecorationCacheControlINTEL"; inst->setMetadata(decorationCacheControlMDName, From 52063ecfadb1e62c9956d6f75ab32f2691819737 Mon Sep 17 00:00:00 2001 From: Victor Perez Date: Thu, 30 May 2024 09:39:08 +0100 Subject: [PATCH 5/9] Use `SmallSet` --- .../lib/Dialect/TritonGEN/IR/TritonGENOps.cpp | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENOps.cpp b/third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENOps.cpp index 01f63371f0..0d7dfaf7ba 100644 --- a/third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENOps.cpp +++ b/third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENOps.cpp @@ -12,7 +12,7 @@ #include "intel/include/Dialect/TritonGEN/IR/TritonGENDialect.h" #include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallSet.h" #include "llvm/ADT/TypeSwitch.h" #include @@ -158,8 +158,8 @@ LogicalResult TritonGEN::CacheControls::verify() { ArrayRef cacheControls = getCacheControls().getValue(); if (cacheControls.empty()) return emitOpError("expecting at least one cache control decoration"); - llvm::SmallSetVector loadCacheLevels; - llvm::SmallSetVector storeCacheLevels; + llvm::SmallSet loadCacheLevels; + llvm::SmallSet storeCacheLevels; for (Attribute attr : cacheControls) { LogicalResult res = TypeSwitch(attr) @@ -168,14 +168,14 @@ LogicalResult TritonGEN::CacheControls::verify() { &storeCacheLevels]( auto attr) -> LogicalResult { - llvm::SmallSetVector &cacheLevels = + llvm::SmallSet &cacheLevels = std::is_same_v ? loadCacheLevels : storeCacheLevels; - if (!cacheLevels.insert(attr.getCacheLevel())) + if (!cacheLevels.insert(attr.getCacheLevel()).second) return emitOpError( - "cannot specify more than one cache control " - "decoration of the same nature for the same cache level"); + "cannot specify more than one cache control decoration of " + "the same nature for the same cache level"); return success(); }); if (failed(res)) From 076bf462503bddf7fe82fe5a4be59b6cc5613b92 Mon Sep 17 00:00:00 2001 From: Victor Perez Date: Thu, 30 May 2024 16:48:32 +0100 Subject: [PATCH 6/9] Use attribute to represent cache controls Signed-off-by: Victor Perez --- test/Target/LLVMIR/triton-gen.mlir | 14 ++- test/TritonGEN/tritongen-invalid.mlir | 12 +- test/TritonGEN/tritongen.mlir | 20 ++-- .../Dialect/TritonGEN/IR/TritonGENAttrDefs.td | 83 +++++++++---- .../Dialect/TritonGEN/IR/TritonGENDialect.td | 6 + .../Dialect/TritonGEN/IR/TritonGENOps.td | 53 -------- .../lib/Dialect/TritonGEN/IR/CMakeLists.txt | 1 + .../Dialect/TritonGEN/IR/TritonGENAttrs.cpp | 57 +++++++++ .../lib/Dialect/TritonGEN/IR/TritonGENOps.cpp | 36 ------ .../TritonGENToLLVMIRTranslation.cpp | 113 ++++++------------ 10 files changed, 186 insertions(+), 209 deletions(-) create mode 100644 third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENAttrs.cpp diff --git a/test/Target/LLVMIR/triton-gen.mlir b/test/Target/LLVMIR/triton-gen.mlir index 6aed1a541f..ef0a279baa 100644 --- a/test/Target/LLVMIR/triton-gen.mlir +++ b/test/Target/LLVMIR/triton-gen.mlir @@ -19,26 +19,30 @@ llvm.func spir_kernelcc @test_reqd_work_group_size() attributes {triton_gen.reqd // ----- +llvm.func @foo(%arg0: !llvm.ptr, %arg1: !llvm.ptr) + // CHECK-LABEL: define void @triton_gen.cache_controls( // CHECK-SAME: ptr %[[#ARG0:]]) { llvm.func @triton_gen.cache_controls(%arg0: !llvm.ptr) { - %0 = triton_gen.cache_controls %arg0, [#triton_gen.store_cache_control<0, Uncached>, #triton_gen.store_cache_control<1, WriteThrough>, #triton_gen.load_cache_control<0, Cached>, #triton_gen.load_cache_control<1, Uncached>] : !llvm.ptr - %1 = triton_gen.cache_controls %arg0, [#triton_gen.store_cache_control<0, WriteBack>, #triton_gen.store_cache_control<1, Streaming>, #triton_gen.load_cache_control<0, Streaming>, #triton_gen.load_cache_control<1, InvalidateAfterRead>, #triton_gen.load_cache_control<2, ConstCached>] : !llvm.ptr // CHECK: %[[#LOAD:]] = load i32, ptr %[[#ARG0]], align 4, !spirv.DecorationCacheControlINTEL ![[#DECORATION0:]] - %2 = llvm.load %0 : !llvm.ptr -> i32 + %0 = llvm.load %arg0 {triton_gen.DecorationCacheControlINTEL = #triton_gen.decoration_cache_control<#triton_gen.store_cache_control<0, Uncached, 0>, #triton_gen.store_cache_control<1, WriteThrough, 0>, #triton_gen.load_cache_control<0, Cached, 0>, #triton_gen.load_cache_control<1, Uncached, 0>>} : !llvm.ptr -> i32 // CHECK: store i32 %[[#LOAD]], ptr %[[#ARG0]], align 4, !spirv.DecorationCacheControlINTEL ![[#DECORATION1:]] - llvm.store %2, %1 : i32, !llvm.ptr + llvm.store %0, %arg0 {triton_gen.DecorationCacheControlINTEL = #triton_gen.decoration_cache_control<#triton_gen.store_cache_control<0, WriteBack, 1>, #triton_gen.store_cache_control<1, Streaming, 1>, #triton_gen.load_cache_control<0, Streaming, 1>, #triton_gen.load_cache_control<1, InvalidateAfterRead, 1>, #triton_gen.load_cache_control<2, ConstCached, 1>>} : i32, !llvm.ptr + // CHECK: call void @foo(ptr %[[#ARG0]], ptr %[[#ARG0]]), !spirv.DecorationCacheControlINTEL ![[#DECORATION2:]] + llvm.call @foo(%arg0, %arg0) {triton_gen.DecorationCacheControlINTEL =#triton_gen.decoration_cache_control<#triton_gen.store_cache_control<0, Uncached, 0>, #triton_gen.load_cache_control<0, Cached, 1>>} : (!llvm.ptr, !llvm.ptr) -> () llvm.return } // CHECK-DAG: ![[#DECORATION0]] = !{![[#CACHECONTROL0:]], ![[#CACHECONTROL1:]], ![[#CACHECONTROL2:]], ![[#CACHECONTROL3:]]} +// CHECK-DAG: ![[#DECORATION1]] = !{![[#CACHECONTROL4:]], ![[#CACHECONTROL5:]], ![[#CACHECONTROL6:]], ![[#CACHECONTROL7:]], ![[#CACHECONTROL8:]]} +// CHECK-DAG: ![[#DECORATION2]] = !{![[#CACHECONTROL0:]], ![[#CACHECONTROL9:]]} // CHECK-DAG: ![[#CACHECONTROL0]] = !{i32 6443, i32 0, i32 0, i32 0} // CHECK-DAG: ![[#CACHECONTROL1]] = !{i32 6443, i32 1, i32 1, i32 0} // CHECK-DAG: ![[#CACHECONTROL2]] = !{i32 6442, i32 0, i32 1, i32 0} // CHECK-DAG: ![[#CACHECONTROL3]] = !{i32 6442, i32 1, i32 0, i32 0} -// CHECK-DAG: ![[#DECORATION1]] = !{![[#CACHECONTROL4:]], ![[#CACHECONTROL5:]], ![[#CACHECONTROL6:]], ![[#CACHECONTROL7:]], ![[#CACHECONTROL8:]]} // CHECK-DAG: ![[#CACHECONTROL4]] = !{i32 6443, i32 0, i32 2, i32 1} // CHECK-DAG: ![[#CACHECONTROL5]] = !{i32 6443, i32 1, i32 3, i32 1} // CHECK-DAG: ![[#CACHECONTROL6]] = !{i32 6442, i32 0, i32 2, i32 1} // CHECK-DAG: ![[#CACHECONTROL7]] = !{i32 6442, i32 1, i32 3, i32 1} // CHECK-DAG: ![[#CACHECONTROL8]] = !{i32 6442, i32 2, i32 4, i32 1} +// CHECK-DAG: ![[#CACHECONTROL9]] = !{i32 6442, i32 0, i32 1, i32 1} diff --git a/test/TritonGEN/tritongen-invalid.mlir b/test/TritonGEN/tritongen-invalid.mlir index ed77454988..fb0d117d60 100644 --- a/test/TritonGEN/tritongen-invalid.mlir +++ b/test/TritonGEN/tritongen-invalid.mlir @@ -1,16 +1,16 @@ // RUN: triton-opt -split-input-file -verify-diagnostics %s -llvm.func @triton_gen.empty_cache_controls(%arg0: !llvm.ptr) { - // expected-error @+1 {{'triton_gen.cache_controls' op expecting at least one cache control decoration}} - %0 = triton_gen.cache_controls %arg0, [] : !llvm.ptr +llvm.func @triton_gen.duplicated_cache_controls(%arg0: !llvm.ptr) { + // expected-error @+1 {{'triton_gen.decoration_cache_controls' cannot specify more than one cache control decoration of the same nature for the same cache level}} + %0 = llvm.load %arg0 {triton_gen.DecorationCacheControlINTEL = #triton_gen.decoration_cache_control<#triton_gen.store_cache_control<0, Uncached, 0>, #triton_gen.store_cache_control<0, Streaming, 0>>} : !llvm.ptr -> i32 llvm.return } // ----- -llvm.func @triton_gen.duplicated_cache_controls(%arg0: !llvm.ptr) { - // expected-error @+1 {{'triton_gen.cache_controls' op cannot specify more than one cache control decoration of the same nature for the same cache level}} - %0 = triton_gen.cache_controls %arg0, [#triton_gen.store_cache_control<0, Uncached>, #triton_gen.store_cache_control<0, Streaming>] : !llvm.ptr +llvm.func @triton_gen.illegal_cache_controls_attr(%arg0: !llvm.ptr) { + // expected-error @+1 {{'triton_gen.decoration_cache_controls' only accepts LoadCacheControlDecorationAttr and StoreCacheControlDecorationAttr attributes}} + %0 = llvm.load %arg0 {triton_gen.DecorationCacheControlINTEL =#triton_gen.decoration_cache_control<1 : i32>} : !llvm.ptr -> i32 llvm.return } diff --git a/test/TritonGEN/tritongen.mlir b/test/TritonGEN/tritongen.mlir index 9489b66e5a..8fe926dfbd 100644 --- a/test/TritonGEN/tritongen.mlir +++ b/test/TritonGEN/tritongen.mlir @@ -105,18 +105,16 @@ llvm.func @triton_gen.dpas(%c : vector<8xi32>, %a : vector<8xi16>, %b : vector<8 llvm.return } +llvm.func @foo(%arg0: !llvm.ptr, %arg1: !llvm.ptr) + llvm.func @triton_gen.cache_controls(%arg0: !llvm.ptr) { - // CHECK-LABEL: llvm.func @triton_gen.cache_controls( - // CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr) { - // CHECK: %[[VAL_1:.*]] = triton_gen.cache_controls %[[VAL_0]], [#triton_gen.store_cache_control<0, Uncached>, #triton_gen.store_cache_control<1, WriteThrough>, #triton_gen.load_cache_control<0, Cached>, #triton_gen.load_cache_control<1, Uncached>] : !llvm.ptr - // CHECK: %[[VAL_2:.*]] = triton_gen.cache_controls %[[VAL_0]], [#triton_gen.store_cache_control<0, WriteBack>, #triton_gen.store_cache_control<1, Streaming>, #triton_gen.load_cache_control<0, Streaming>, #triton_gen.load_cache_control<1, InvalidateAfterRead>, #triton_gen.load_cache_control<2, ConstCached>] : !llvm.ptr - // CHECK: %[[VAL_3:.*]] = llvm.load %[[VAL_1]] : !llvm.ptr -> i32 - // CHECK: llvm.store %[[VAL_3]], %[[VAL_2]] : i32, !llvm.ptr - // CHECK: llvm.return - %0 = triton_gen.cache_controls %arg0, [#triton_gen.store_cache_control<0, Uncached>, #triton_gen.store_cache_control<1, WriteThrough>, #triton_gen.load_cache_control<0, Cached>, #triton_gen.load_cache_control<1, Uncached>] : !llvm.ptr - %1 = triton_gen.cache_controls %arg0, [#triton_gen.store_cache_control<0, WriteBack>, #triton_gen.store_cache_control<1, Streaming>, #triton_gen.load_cache_control<0, Streaming>, #triton_gen.load_cache_control<1, InvalidateAfterRead>, #triton_gen.load_cache_control<2, ConstCached>] : !llvm.ptr - %2 = llvm.load %0 : !llvm.ptr -> i32 - llvm.store %2, %1 : i32, !llvm.ptr + // CHECK: llvm.func @triton_gen.cache_controls(%arg0: !llvm.ptr) + // CHECK-NEXT: %0 = llvm.load %arg0 {triton_gen.DecorationCacheControlINTEL = #triton_gen.decoration_cache_control<#triton_gen.store_cache_control<0, Uncached, 0>, #triton_gen.store_cache_control<1, WriteThrough, 0>, #triton_gen.load_cache_control<0, Cached, 0>, #triton_gen.load_cache_control<1, Uncached, 0>>} : !llvm.ptr -> i32 + %0 = llvm.load %arg0 {triton_gen.DecorationCacheControlINTEL = #triton_gen.decoration_cache_control<#triton_gen.store_cache_control<0, Uncached, 0>, #triton_gen.store_cache_control<1, WriteThrough, 0>, #triton_gen.load_cache_control<0, Cached, 0>, #triton_gen.load_cache_control<1, Uncached, 0>>} : !llvm.ptr -> i32 + // CHECK-NEXT: llvm.store %0, %arg0 {triton_gen.DecorationCacheControlINTEL = #triton_gen.decoration_cache_control<#triton_gen.store_cache_control<0, WriteBack, 1>, #triton_gen.store_cache_control<1, Streaming, 1>, #triton_gen.load_cache_control<0, Streaming, 1>, #triton_gen.load_cache_control<1, InvalidateAfterRead, 1>, #triton_gen.load_cache_control<2, ConstCached, 1>>} : i32, !llvm.ptr + llvm.store %0, %arg0 {triton_gen.DecorationCacheControlINTEL = #triton_gen.decoration_cache_control<#triton_gen.store_cache_control<0, WriteBack, 1>, #triton_gen.store_cache_control<1, Streaming, 1>, #triton_gen.load_cache_control<0, Streaming, 1>, #triton_gen.load_cache_control<1, InvalidateAfterRead, 1>, #triton_gen.load_cache_control<2, ConstCached, 1>>} : i32, !llvm.ptr + // CHECK-NEXT: llvm.call @foo(%arg0, %arg0) {triton_gen.DecorationCacheControlINTEL = #triton_gen.decoration_cache_control<#triton_gen.store_cache_control<0, Uncached, 0>, #triton_gen.load_cache_control<0, Cached, 1>>} : (!llvm.ptr, !llvm.ptr) -> () + llvm.call @foo(%arg0, %arg0) {triton_gen.DecorationCacheControlINTEL = #triton_gen.decoration_cache_control<#triton_gen.store_cache_control<0, Uncached, 0>, #triton_gen.load_cache_control<0, Cached, 1>>} : (!llvm.ptr, !llvm.ptr) -> () llvm.return } diff --git a/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENAttrDefs.td b/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENAttrDefs.td index 71ab9d2e56..0003bfad29 100644 --- a/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENAttrDefs.td +++ b/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENAttrDefs.td @@ -118,6 +118,25 @@ def TritonGEN_MemScope : I32EnumAttr<"MemScope", let cppNamespace = "::mlir::triton::TritonGEN"; } +class TritonGEN_LoadStoreCacheControlDecoration + : TritonGEN_Attr { + let summary = "An attribute specifying " # !tolower(loadOrStore) # " cache control"; + let description = [{ + A }] # !tolower(loadOrStore) # [{ cache control attribute has a one-to-one + correspondance with the SPIR-V decoration shown in + https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/INTEL/SPV_INTEL_cache_controls.asciidoc#decorations. The + only differences are there is no need to add the SPIR-V decoration key, as + that will be inferred from the attribute type, and an additional + `operand_number` parameter is needed, as this is applied to pointer user + operations, following + https://github.com/KhronosGroup/SPIRV-LLVM-Translator/pull/2587 design. + }]; + let parameters = (ins "uint32_t":$cache_level, + loadOrStore # "CacheControlDecorationEnum":$cache_control, + "uint32_t":$operand_number); + let assemblyFormat = "`<` $cache_level `,` $cache_control `,` $operand_number `>`"; +} + /// Enum attribute for load cache controls. /// /// See @@ -133,21 +152,6 @@ def TritonGEN_LoadCacheControlDecorationEnum : I32EnumAttr<"LoadCacheControlDeco let cppNamespace = "::mlir::triton::TritonGEN"; } -class TritonGEN_LoadStoreCacheControlDecoration - : TritonGEN_Attr { - let summary = "An attribute specifying " # !tolower(loadOrStore) # " cache control"; - let description = [{ - A }] # !tolower(loadOrStore) # [{ cache control attribute has a one-to-one correspondance with the SPIR-V - decoration shown in - https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/INTEL/SPV_INTEL_cache_controls.asciidoc#decorations. The - only difference is there is no need to add the SPIR-V decoration key as that - will be inferred from the value. - }]; - let parameters = (ins "uint32_t":$cache_level, - loadOrStore # "CacheControlDecorationEnum":$cache_control); - let assemblyFormat = "`<` $cache_level `,` $cache_control `>`"; -} - def TritonGEN_LoadCacheControlDecoration : TritonGEN_LoadStoreCacheControlDecoration<"Load">; @@ -168,11 +172,48 @@ def TritonGEN_StoreCacheControlDecorationEnum : I32EnumAttr<"StoreCacheControlDe def TritonGEN_StoreCacheControlDecoration : TritonGEN_LoadStoreCacheControlDecoration<"Store">; -def CacheControls - : AnyAttrOf<[TritonGEN_LoadCacheControlDecoration, - TritonGEN_StoreCacheControlDecoration]>; - -def CacheControlsArray - : TypedArrayAttrBase; +def TritonGEN_CacheControlsDecoration + : TritonGEN_Attr<"DecorationCacheControl", "decoration_cache_control"> { + let summary = "An attribute specifying an operation cache controls"; + let description = [{ + Attribute corresponding to `!spirv.DecorationCacheControlINTEL` metadata + described in + https://github.com/KhronosGroup/SPIRV-LLVM-Translator/pull/2587. This + metadata is a list of `LoadCacheControlDecoration` and + `StoreCacheControlDecoration` attributes specifying cache control + information. + + The following MLIR code using this attribute: + + ```mlir + llvm.func @triton_gen.cache_controls(%arg0: !llvm.ptr) { + %0 = llvm.load %arg0 {triton_gen.DecorationCacheControlINTEL = #triton_gen.decoration_cache_control<#triton_gen.store_cache_control<0, Uncached, 0>, #triton_gen.store_cache_control<1, WriteThrough, 0>, #triton_gen.load_cache_control<0, Cached, 0>, #triton_gen.load_cache_control<1, Uncached, 0>>} : !llvm.ptr -> i32 + llvm.return + } + ``` + + Will be translated to the following LLVM IR: + + ```llvm + define void @triton_gen.cache_controls(ptr %0) { + %1 = load i32, ptr %0, align 4, !spirv.DecorationCacheControlINTEL !1 + ret void + } + + !1 = !{!2, !3, !4, !5} + !2 = !{i32 6443, i32 0, i32 0, i32 0} + !3 = !{i32 6443, i32 1, i32 1, i32 0} + !4 = !{i32 6442, i32 0, i32 1, i32 0} + !5 = !{i32 6442, i32 1, i32 0, i32 0} + ``` + + $decorations must be a non-empty list of cache controls attributes and two + attributes of the same nature (load or store) in the list cannot be applied + to the same cache level, as per the SPIR-V validation rules. + }]; + let parameters = (ins ArrayRefParameter<"::mlir::Attribute">:$decorations); + let assemblyFormat = "`<` $decorations `>`"; + let genVerifyDecl = 1; +} #endif // TRITONGEN_ATTRDEFS diff --git a/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENDialect.td b/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENDialect.td index 872d48ebc3..c7ade3af3d 100644 --- a/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENDialect.td +++ b/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENDialect.td @@ -41,6 +41,12 @@ def TritonGEN_Dialect : Dialect { static constexpr ::llvm::StringLiteral getReqdSubGroupSizeAttrName() { return ::llvm::StringLiteral("triton_gen.intel_reqd_sub_group_size"); } + + /// Get the name for the attribute used to specify cache control + /// decorations. + static constexpr ::llvm::StringRef getCacheControlsAttrName() { + return ::llvm::StringLiteral("triton_gen.DecorationCacheControlINTEL"); + } }]; } diff --git a/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENOps.td b/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENOps.td index 5db9a58429..98f35a8c84 100644 --- a/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENOps.td +++ b/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENOps.td @@ -263,59 +263,6 @@ def TritonGEN_SubGroupShuffleOp : TritonGEN_Op<"sub_group_shuffle", [ }]; } -//===----------------------------------------------------------------------===// -// Metaoperations -//===----------------------------------------------------------------------===// - -def TritonGEN_CacheControls - : TritonGEN_Op<"cache_controls", - [AllTypesMatch<["ptr", "decorated_ptr"]>, Pure]> { - let summary = "Operation implementing SPIR-V cache control decorations"; - let description = [{ - This operation implements an MLIR representation of SPIR-V cache controls. - `triton_gen.cache_controls` can be translated directly to LLVM IR. User - instructions in the resulting translation will present metadata representing - cache controls as specified in - https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/INTEL/SPV_INTEL_cache_controls.asciidoc. - This way, the following MLIR code: - - ```mlir - llvm.func @triton_gen.cache_controls(%arg0: !llvm.ptr) { - %0 = triton_gen.cache_controls %arg0, [#triton_gen.store_cache_control<0, WriteBack>, #triton_gen.store_cache_control<1, Streaming>, #triton_gen.load_cache_control<0, Streaming>, #triton_gen.load_cache_control<1, InvalidateAfterRead>, #triton_gen.load_cache_control<2, ConstCached>] : !llvm.ptr - %1 = llvm.load %0 : !llvm.ptr -> i32 - llvm.return - } - ``` - - Will be translated to the following LLVM IR: - - ```llvm - define void @triton_gen.cache_controls(ptr %0) { - %2 = load i32, ptr %0, align 4, !spirv.DecorationCacheControlINTEL !1 - ret void - } - - !1 = !{!2, !3, !4, !5, !6} - !2 = !{i32 6443, i32 0, i32 2, i32 0} - !3 = !{i32 6443, i32 1, i32 3, i32 0} - !4 = !{i32 6442, i32 0, i32 2, i32 0} - !5 = !{i32 6442, i32 1, i32 3, i32 0} - !6 = !{i32 6442, i32 2, i32 4, i32 0} - ``` - - $cache_controls must be a non-empty list of cache controls attributes - and two attributes of the same nature (load or store) in the list cannot be - applied to the same cache level, as per the SPIR-V validation rules. - }]; - let arguments = (ins LLVM_AnyPointer:$ptr, - CacheControlsArray:$cache_controls); - let results = (outs LLVM_AnyPointer:$decorated_ptr); - let assemblyFormat = [{ - $ptr `,` $cache_controls attr-dict `:` type($ptr) - }]; - let hasVerifier = 1; -} - //===----------------------------------------------------------------------===// // Matrix operations //===----------------------------------------------------------------------===// diff --git a/third_party/intel/lib/Dialect/TritonGEN/IR/CMakeLists.txt b/third_party/intel/lib/Dialect/TritonGEN/IR/CMakeLists.txt index 1c53ec8a2f..4cae1191e5 100644 --- a/third_party/intel/lib/Dialect/TritonGEN/IR/CMakeLists.txt +++ b/third_party/intel/lib/Dialect/TritonGEN/IR/CMakeLists.txt @@ -1,4 +1,5 @@ add_triton_library(TritonGENIR + TritonGENAttrs.cpp TritonGENDialect.cpp TritonGENOps.cpp diff --git a/third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENAttrs.cpp b/third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENAttrs.cpp new file mode 100644 index 0000000000..8ef13f6ecb --- /dev/null +++ b/third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENAttrs.cpp @@ -0,0 +1,57 @@ +//===- TritonGENAttrs.cpp - TritonGEN Attributes Definition --------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "intel/include/Dialect/TritonGEN/IR/TritonGENDialect.h" + +#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/TypeSwitch.h" + +namespace mlir::triton::TritonGEN { +//===----------------------------------------------------------------------===// +// triton_gen.decoration_cache_control +//===----------------------------------------------------------------------===// + +LogicalResult TritonGEN::DecorationCacheControlAttr::verify( + llvm::function_ref emitError, + ArrayRef decorations) { + if (decorations.empty()) + return emitError() << "'triton_gen.decoration_cache_controls' expecting at " + "least one cache control decoration"; + llvm::SmallSet loadCacheLevels; + llvm::SmallSet storeCacheLevels; + for (Attribute attr : decorations) { + LogicalResult res = + TypeSwitch(attr) + .Case([emitError, &loadCacheLevels, + &storeCacheLevels]( + auto attr) + -> LogicalResult { + llvm::SmallSet &cacheLevels = + std::is_same_v + ? loadCacheLevels + : storeCacheLevels; + if (!cacheLevels.insert(attr.getCacheLevel()).second) + return emitError() + << "'triton_gen.decoration_cache_controls' cannot " + "specify more than one cache control decoration of " + "the same nature for the same cache level"; + return success(); + }) + .Default([emitError](Attribute attr) -> LogicalResult { + return emitError() + << "'triton_gen.decoration_cache_controls' only accepts " + "LoadCacheControlDecorationAttr and " + "StoreCacheControlDecorationAttr attributes"; + }); + if (failed(res)) + return res; + } + return success(); +} +} // namespace mlir::triton::TritonGEN diff --git a/third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENOps.cpp b/third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENOps.cpp index 0d7dfaf7ba..cae45ade80 100644 --- a/third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENOps.cpp +++ b/third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENOps.cpp @@ -12,8 +12,6 @@ #include "intel/include/Dialect/TritonGEN/IR/TritonGENDialect.h" #include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SmallSet.h" -#include "llvm/ADT/TypeSwitch.h" #include using namespace mlir; @@ -150,40 +148,6 @@ LogicalResult TritonGEN::MatrixDPASOp::verify() { return success(); } -//===----------------------------------------------------------------------===// -// triton_gen.cache_controls -//===----------------------------------------------------------------------===// - -LogicalResult TritonGEN::CacheControls::verify() { - ArrayRef cacheControls = getCacheControls().getValue(); - if (cacheControls.empty()) - return emitOpError("expecting at least one cache control decoration"); - llvm::SmallSet loadCacheLevels; - llvm::SmallSet storeCacheLevels; - for (Attribute attr : cacheControls) { - LogicalResult res = - TypeSwitch(attr) - .Case([this, &loadCacheLevels, - &storeCacheLevels]( - auto attr) - -> LogicalResult { - llvm::SmallSet &cacheLevels = - std::is_same_v - ? loadCacheLevels - : storeCacheLevels; - if (!cacheLevels.insert(attr.getCacheLevel()).second) - return emitOpError( - "cannot specify more than one cache control decoration of " - "the same nature for the same cache level"); - return success(); - }); - if (failed(res)) - return res; - } - return success(); -} - //===----------------------------------------------------------------------===// // gen.2Dblockload //===----------------------------------------------------------------------===// diff --git a/third_party/intel/lib/Target/LLVMIR/Dialect/TritonGEN/TritonGENToLLVMIRTranslation.cpp b/third_party/intel/lib/Target/LLVMIR/Dialect/TritonGEN/TritonGENToLLVMIRTranslation.cpp index b122373150..cce8f6c720 100644 --- a/third_party/intel/lib/Target/LLVMIR/Dialect/TritonGEN/TritonGENToLLVMIRTranslation.cpp +++ b/third_party/intel/lib/Target/LLVMIR/Dialect/TritonGEN/TritonGENToLLVMIRTranslation.cpp @@ -31,112 +31,71 @@ class TritonGENDialectLLVMIRTranslationInterface public: using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface; - constexpr static llvm::StringLiteral decorationCacheControlAttrName = - "triton_gen.DecorationCacheControlINTEL"; - LogicalResult amendOperation(Operation *op, ArrayRef instructions, NamedAttribute attribute, LLVM::ModuleTranslation &moduleTranslation) const final { StringRef attrName = attribute.getName().getValue(); - if (attrName == decorationCacheControlAttrName) { + if (attrName == + triton::TritonGEN::TritonGENDialect::getCacheControlsAttrName()) { + auto decorationAttr = + dyn_cast( + attribute.getValue()); + if (!decorationAttr) + return op->emitOpError( + "Expecting triton_gen.decoration_cache_control attribute"); if (instructions.size() != 1) return op->emitOpError("Expecting a single instruction"); - return handleDecorationCacheControl(op, instructions.front(), attribute, - moduleTranslation); + return handleDecorationCacheControl(op, instructions.front(), + decorationAttr); } if (attrName.starts_with("triton_gen")) return handleTritonGenAttr(op, attribute, moduleTranslation); return success(); } - LogicalResult - convertOperation(Operation *operation, llvm::IRBuilderBase &builder, - LLVM::ModuleTranslation &moduleTranslation) const final { - return TypeSwitch(operation) - .Case([&moduleTranslation](triton::TritonGEN::CacheControls op) { - llvm::Value *ptr = moduleTranslation.lookupValue(op.getPtr()); - moduleTranslation.mapValue(op, ptr); - Builder mlirBuilder(op); - for (OpOperand &use : op->getUses()) - appendDecoration(mlirBuilder, decorationCacheControlAttrName, - use.getOwner(), op.getCacheControls(), - use.getOperandNumber()); - return success(); - }) - .Default([](Operation *op) { - return op->emitOpError("unsupported TritonGEN operation: ") - << op->getName(); - }); - } - private: template static llvm::Metadata *getConstantIntMD(llvm::Type *type, IntTy val) { return llvm::ConstantAsMetadata::get(llvm::ConstantInt::get(type, val)); } - static void appendDecoration(Builder &mlirBuilder, llvm::StringRef mdKey, - Operation *op, ArrayAttr newAttrs, - unsigned operandNumber) { - auto attr = op->getAttrOfType(mdKey); - SmallVector attrs = attr - ? SmallVector{attr.getValue()} - : SmallVector{}; + static LogicalResult handleDecorationCacheControl( + Operation *op, llvm::Instruction *inst, + triton::TritonGEN::DecorationCacheControlAttr attribute) { + ArrayRef attrs = attribute.getDecorations(); + SmallVector decorations; + llvm::LLVMContext &ctx = inst->getContext(); + llvm::Type *i32Ty = llvm::IntegerType::getInt32Ty(ctx); llvm::transform( - newAttrs.getValue(), std::back_inserter(attrs), - [&mlirBuilder, operandNumber](Attribute attr) -> Attribute { - return TypeSwitch(attr) + attrs, std::back_inserter(decorations), + [&ctx, i32Ty](Attribute attr) -> llvm::Metadata * { + return TypeSwitch(attr) .Case( - [&mlirBuilder, operandNumber](auto attr) { - constexpr int32_t loadCacheControlKey = 6442; - constexpr int32_t storeCacheControlKey = 6443; - constexpr int32_t key = + [&ctx, i32Ty](auto attr) { + constexpr size_t decorationCacheControlArity = 4; + constexpr uint32_t loadCacheControlKey = 6442; + constexpr uint32_t storeCacheControlKey = 6443; + constexpr uint32_t decorationKey = std::is_same_v< decltype(attr), triton::TritonGEN::LoadCacheControlDecorationAttr> ? loadCacheControlKey : storeCacheControlKey; - int32_t cacheLevel = attr.getCacheLevel(); - int32_t cacheControl = - static_cast(attr.getCacheControl()); - return mlirBuilder.getDenseI32ArrayAttr( - {key, cacheLevel, cacheControl, - static_cast(operandNumber)}); + std::array values{ + decorationKey, attr.getCacheLevel(), + static_cast(attr.getCacheControl()), + attr.getOperandNumber()}; + std::array + metadata; + llvm::transform(values, metadata.begin(), + [i32Ty](uint32_t value) { + return getConstantIntMD(i32Ty, value); + }); + return llvm::MDNode::get(ctx, metadata); }); }); - op->setAttr(mdKey, mlirBuilder.getArrayAttr(attrs)); - } - - static LogicalResult - handleDecorationCacheControl(Operation *op, llvm::Instruction *inst, - NamedAttribute attribute, - LLVM::ModuleTranslation &moduleTranslation) { - assert(attribute.getName() == decorationCacheControlAttrName && - "Expecting decoration cache key"); - auto arrayAttr = dyn_cast(attribute.getValue()); - if (!arrayAttr) - return op->emitOpError("unexpected attribute type"); - ArrayRef attrs = arrayAttr.getValue(); - SmallVector decorations; - llvm::LLVMContext &ctx = inst->getContext(); - for (Attribute attr : attrs) { - constexpr std::size_t decorationCacheControlArity = 4; - - auto arrayAttr = dyn_cast(attr); - if (!arrayAttr) - return op->emitOpError("unexpected attribute type"); - ArrayRef attrs = arrayAttr.asArrayRef(); - if (attrs.size() != decorationCacheControlArity) - return op->emitOpError("Invalid decoration cache attribute arity"); - constexpr unsigned numBits = 32; - llvm::Type *type = llvm::IntegerType::get(ctx, numBits); - std::array metadata; - llvm::transform(attrs, metadata.begin(), - [type](int val) { return getConstantIntMD(type, val); }); - decorations.push_back(llvm::MDNode::get(ctx, metadata)); - } constexpr llvm::StringLiteral decorationCacheControlMDName = "spirv.DecorationCacheControlINTEL"; inst->setMetadata(decorationCacheControlMDName, From 2ab1910911c3a42abe2bd2c7537591d04a2d9966 Mon Sep 17 00:00:00 2001 From: Victor Perez Date: Thu, 30 May 2024 17:00:15 +0100 Subject: [PATCH 7/9] Allow empty decorations --- third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENAttrs.cpp | 3 --- 1 file changed, 3 deletions(-) diff --git a/third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENAttrs.cpp b/third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENAttrs.cpp index 8ef13f6ecb..cecda52560 100644 --- a/third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENAttrs.cpp +++ b/third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENAttrs.cpp @@ -19,9 +19,6 @@ namespace mlir::triton::TritonGEN { LogicalResult TritonGEN::DecorationCacheControlAttr::verify( llvm::function_ref emitError, ArrayRef decorations) { - if (decorations.empty()) - return emitError() << "'triton_gen.decoration_cache_controls' expecting at " - "least one cache control decoration"; llvm::SmallSet loadCacheLevels; llvm::SmallSet storeCacheLevels; for (Attribute attr : decorations) { From a2536768ee94b160116ef22b69f084706dc7a7d1 Mon Sep 17 00:00:00 2001 From: Victor Perez Date: Fri, 31 May 2024 15:36:25 +0100 Subject: [PATCH 8/9] Update tests --- test/Target/LLVMIR/triton-gen.mlir | 28 ++++++++++++---------------- test/TritonGEN/tritongen.mlir | 8 ++++---- 2 files changed, 16 insertions(+), 20 deletions(-) diff --git a/test/Target/LLVMIR/triton-gen.mlir b/test/Target/LLVMIR/triton-gen.mlir index ef0a279baa..a4774ac7a2 100644 --- a/test/Target/LLVMIR/triton-gen.mlir +++ b/test/Target/LLVMIR/triton-gen.mlir @@ -25,24 +25,20 @@ llvm.func @foo(%arg0: !llvm.ptr, %arg1: !llvm.ptr) // CHECK-SAME: ptr %[[#ARG0:]]) { llvm.func @triton_gen.cache_controls(%arg0: !llvm.ptr) { // CHECK: %[[#LOAD:]] = load i32, ptr %[[#ARG0]], align 4, !spirv.DecorationCacheControlINTEL ![[#DECORATION0:]] - %0 = llvm.load %arg0 {triton_gen.DecorationCacheControlINTEL = #triton_gen.decoration_cache_control<#triton_gen.store_cache_control<0, Uncached, 0>, #triton_gen.store_cache_control<1, WriteThrough, 0>, #triton_gen.load_cache_control<0, Cached, 0>, #triton_gen.load_cache_control<1, Uncached, 0>>} : !llvm.ptr -> i32 + %0 = llvm.load %arg0 {triton_gen.DecorationCacheControlINTEL = #triton_gen.decoration_cache_control<#triton_gen.load_cache_control<0, Cached, 0>, #triton_gen.load_cache_control<1, Uncached, 0>>} : !llvm.ptr -> i32 // CHECK: store i32 %[[#LOAD]], ptr %[[#ARG0]], align 4, !spirv.DecorationCacheControlINTEL ![[#DECORATION1:]] - llvm.store %0, %arg0 {triton_gen.DecorationCacheControlINTEL = #triton_gen.decoration_cache_control<#triton_gen.store_cache_control<0, WriteBack, 1>, #triton_gen.store_cache_control<1, Streaming, 1>, #triton_gen.load_cache_control<0, Streaming, 1>, #triton_gen.load_cache_control<1, InvalidateAfterRead, 1>, #triton_gen.load_cache_control<2, ConstCached, 1>>} : i32, !llvm.ptr + llvm.store %0, %arg0 {triton_gen.DecorationCacheControlINTEL = #triton_gen.decoration_cache_control<#triton_gen.store_cache_control<0, WriteBack, 1>, #triton_gen.store_cache_control<1, Streaming, 1>>} : i32, !llvm.ptr // CHECK: call void @foo(ptr %[[#ARG0]], ptr %[[#ARG0]]), !spirv.DecorationCacheControlINTEL ![[#DECORATION2:]] - llvm.call @foo(%arg0, %arg0) {triton_gen.DecorationCacheControlINTEL =#triton_gen.decoration_cache_control<#triton_gen.store_cache_control<0, Uncached, 0>, #triton_gen.load_cache_control<0, Cached, 1>>} : (!llvm.ptr, !llvm.ptr) -> () + llvm.call @foo(%arg0, %arg0) {triton_gen.DecorationCacheControlINTEL = #triton_gen.decoration_cache_control<#triton_gen.store_cache_control<0, Uncached, 0>, #triton_gen.load_cache_control<0, Cached, 1>>} : (!llvm.ptr, !llvm.ptr) -> () llvm.return } -// CHECK-DAG: ![[#DECORATION0]] = !{![[#CACHECONTROL0:]], ![[#CACHECONTROL1:]], ![[#CACHECONTROL2:]], ![[#CACHECONTROL3:]]} -// CHECK-DAG: ![[#DECORATION1]] = !{![[#CACHECONTROL4:]], ![[#CACHECONTROL5:]], ![[#CACHECONTROL6:]], ![[#CACHECONTROL7:]], ![[#CACHECONTROL8:]]} -// CHECK-DAG: ![[#DECORATION2]] = !{![[#CACHECONTROL0:]], ![[#CACHECONTROL9:]]} -// CHECK-DAG: ![[#CACHECONTROL0]] = !{i32 6443, i32 0, i32 0, i32 0} -// CHECK-DAG: ![[#CACHECONTROL1]] = !{i32 6443, i32 1, i32 1, i32 0} -// CHECK-DAG: ![[#CACHECONTROL2]] = !{i32 6442, i32 0, i32 1, i32 0} -// CHECK-DAG: ![[#CACHECONTROL3]] = !{i32 6442, i32 1, i32 0, i32 0} -// CHECK-DAG: ![[#CACHECONTROL4]] = !{i32 6443, i32 0, i32 2, i32 1} -// CHECK-DAG: ![[#CACHECONTROL5]] = !{i32 6443, i32 1, i32 3, i32 1} -// CHECK-DAG: ![[#CACHECONTROL6]] = !{i32 6442, i32 0, i32 2, i32 1} -// CHECK-DAG: ![[#CACHECONTROL7]] = !{i32 6442, i32 1, i32 3, i32 1} -// CHECK-DAG: ![[#CACHECONTROL8]] = !{i32 6442, i32 2, i32 4, i32 1} -// CHECK-DAG: ![[#CACHECONTROL9]] = !{i32 6442, i32 0, i32 1, i32 1} +// CHECK-DAG: ![[#DECORATION0]] = !{![[#CACHECONTROL0:]], ![[#CACHECONTROL1:]]} +// CHECK-DAG: ![[#DECORATION1]] = !{![[#CACHECONTROL2:]], ![[#CACHECONTROL3:]]} +// CHECK-DAG: ![[#DECORATION2]] = !{![[#CACHECONTROL4:]], ![[#CACHECONTROL5:]]} +// CHECK-DAG: ![[#CACHECONTROL0]] = !{i32 6442, i32 0, i32 1, i32 0} +// CHECK-DAG: ![[#CACHECONTROL1]] = !{i32 6442, i32 1, i32 0, i32 0} +// CHECK-DAG: ![[#CACHECONTROL2]] = !{i32 6443, i32 0, i32 2, i32 1} +// CHECK-DAG: ![[#CACHECONTROL3]] = !{i32 6443, i32 1, i32 3, i32 1} +// CHECK-DAG: ![[#CACHECONTROL4]] = !{i32 6443, i32 0, i32 0, i32 0} +// CHECK-DAG: ![[#CACHECONTROL5]] = !{i32 6442, i32 0, i32 1, i32 1} diff --git a/test/TritonGEN/tritongen.mlir b/test/TritonGEN/tritongen.mlir index 8fe926dfbd..ba5d839805 100644 --- a/test/TritonGEN/tritongen.mlir +++ b/test/TritonGEN/tritongen.mlir @@ -109,10 +109,10 @@ llvm.func @foo(%arg0: !llvm.ptr, %arg1: !llvm.ptr) llvm.func @triton_gen.cache_controls(%arg0: !llvm.ptr) { // CHECK: llvm.func @triton_gen.cache_controls(%arg0: !llvm.ptr) - // CHECK-NEXT: %0 = llvm.load %arg0 {triton_gen.DecorationCacheControlINTEL = #triton_gen.decoration_cache_control<#triton_gen.store_cache_control<0, Uncached, 0>, #triton_gen.store_cache_control<1, WriteThrough, 0>, #triton_gen.load_cache_control<0, Cached, 0>, #triton_gen.load_cache_control<1, Uncached, 0>>} : !llvm.ptr -> i32 - %0 = llvm.load %arg0 {triton_gen.DecorationCacheControlINTEL = #triton_gen.decoration_cache_control<#triton_gen.store_cache_control<0, Uncached, 0>, #triton_gen.store_cache_control<1, WriteThrough, 0>, #triton_gen.load_cache_control<0, Cached, 0>, #triton_gen.load_cache_control<1, Uncached, 0>>} : !llvm.ptr -> i32 - // CHECK-NEXT: llvm.store %0, %arg0 {triton_gen.DecorationCacheControlINTEL = #triton_gen.decoration_cache_control<#triton_gen.store_cache_control<0, WriteBack, 1>, #triton_gen.store_cache_control<1, Streaming, 1>, #triton_gen.load_cache_control<0, Streaming, 1>, #triton_gen.load_cache_control<1, InvalidateAfterRead, 1>, #triton_gen.load_cache_control<2, ConstCached, 1>>} : i32, !llvm.ptr - llvm.store %0, %arg0 {triton_gen.DecorationCacheControlINTEL = #triton_gen.decoration_cache_control<#triton_gen.store_cache_control<0, WriteBack, 1>, #triton_gen.store_cache_control<1, Streaming, 1>, #triton_gen.load_cache_control<0, Streaming, 1>, #triton_gen.load_cache_control<1, InvalidateAfterRead, 1>, #triton_gen.load_cache_control<2, ConstCached, 1>>} : i32, !llvm.ptr + // CHECK-NEXT: %0 = llvm.load %arg0 {triton_gen.DecorationCacheControlINTEL = #triton_gen.decoration_cache_control<#triton_gen.load_cache_control<0, Cached, 0>, #triton_gen.load_cache_control<1, Uncached, 0>>} : !llvm.ptr -> i32 + %0 = llvm.load %arg0 {triton_gen.DecorationCacheControlINTEL = #triton_gen.decoration_cache_control<#triton_gen.load_cache_control<0, Cached, 0>, #triton_gen.load_cache_control<1, Uncached, 0>>} : !llvm.ptr -> i32 + // CHECK-NEXT: llvm.store %0, %arg0 {triton_gen.DecorationCacheControlINTEL = #triton_gen.decoration_cache_control<#triton_gen.store_cache_control<0, WriteBack, 1>, #triton_gen.store_cache_control<1, Streaming, 1>>} : i32, !llvm.ptr + llvm.store %0, %arg0 {triton_gen.DecorationCacheControlINTEL = #triton_gen.decoration_cache_control<#triton_gen.store_cache_control<0, WriteBack, 1>, #triton_gen.store_cache_control<1, Streaming, 1>>} : i32, !llvm.ptr // CHECK-NEXT: llvm.call @foo(%arg0, %arg0) {triton_gen.DecorationCacheControlINTEL = #triton_gen.decoration_cache_control<#triton_gen.store_cache_control<0, Uncached, 0>, #triton_gen.load_cache_control<0, Cached, 1>>} : (!llvm.ptr, !llvm.ptr) -> () llvm.call @foo(%arg0, %arg0) {triton_gen.DecorationCacheControlINTEL = #triton_gen.decoration_cache_control<#triton_gen.store_cache_control<0, Uncached, 0>, #triton_gen.load_cache_control<0, Cached, 1>>} : (!llvm.ptr, !llvm.ptr) -> () llvm.return From 440b16d9ad3f9b8357d6df25d91b8ccc9593b038 Mon Sep 17 00:00:00 2001 From: Victor Perez Date: Fri, 31 May 2024 16:02:58 +0100 Subject: [PATCH 9/9] Update invalid test --- test/TritonGEN/tritongen-invalid.mlir | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/TritonGEN/tritongen-invalid.mlir b/test/TritonGEN/tritongen-invalid.mlir index fb0d117d60..20af5ff2d8 100644 --- a/test/TritonGEN/tritongen-invalid.mlir +++ b/test/TritonGEN/tritongen-invalid.mlir @@ -2,7 +2,7 @@ llvm.func @triton_gen.duplicated_cache_controls(%arg0: !llvm.ptr) { // expected-error @+1 {{'triton_gen.decoration_cache_controls' cannot specify more than one cache control decoration of the same nature for the same cache level}} - %0 = llvm.load %arg0 {triton_gen.DecorationCacheControlINTEL = #triton_gen.decoration_cache_control<#triton_gen.store_cache_control<0, Uncached, 0>, #triton_gen.store_cache_control<0, Streaming, 0>>} : !llvm.ptr -> i32 + %0 = llvm.load %arg0 {triton_gen.DecorationCacheControlINTEL = #triton_gen.decoration_cache_control<#triton_gen.load_cache_control<0, Uncached, 0>, #triton_gen.load_cache_control<0, Cached, 0>>} : !llvm.ptr -> i32 llvm.return } @@ -10,7 +10,7 @@ llvm.func @triton_gen.duplicated_cache_controls(%arg0: !llvm.ptr) { llvm.func @triton_gen.illegal_cache_controls_attr(%arg0: !llvm.ptr) { // expected-error @+1 {{'triton_gen.decoration_cache_controls' only accepts LoadCacheControlDecorationAttr and StoreCacheControlDecorationAttr attributes}} - %0 = llvm.load %arg0 {triton_gen.DecorationCacheControlINTEL =#triton_gen.decoration_cache_control<1 : i32>} : !llvm.ptr -> i32 + %0 = llvm.load %arg0 {triton_gen.DecorationCacheControlINTEL = #triton_gen.decoration_cache_control<1 : i32>} : !llvm.ptr -> i32 llvm.return }