Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TritonGEN] Add triton_gen.cache_controls operation #1087

Merged
merged 13 commits into from
Jun 1, 2024
30 changes: 30 additions & 0 deletions test/Target/LLVMIR/triton-gen.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,33 @@ llvm.func spir_kernelcc @test_reqd_work_group_size() attributes {triton_gen.reqd
// CHECK-DAG: ![[REQD_SUB_GROUP_SIZE]] = !{i64 32}
// CHECK-DAG: ![[MAX_WORK_GROUP_SIZE]] = !{i64 128, i64 1, i64 1}
// CHECK-DAG: ![[REQD_WORK_GROUP_SIZE]] = !{i64 128, i64 1, i64 2}

// -----

llvm.func @foo(%arg0: !llvm.ptr, %arg1: !llvm.ptr)

// Verifies that a `triton_gen.DecorationCacheControlINTEL` attribute attached
// to `llvm.load`, `llvm.store` and `llvm.call` is translated into
// `!spirv.DecorationCacheControlINTEL` metadata on the matching LLVM IR
// instruction. The metadata node referenced by each instruction is captured
// here and its contents are verified by the directives that follow this
// function.
// CHECK-LABEL: define void @triton_gen.cache_controls(
// CHECK-SAME: ptr %[[#ARG0:]]) {
llvm.func @triton_gen.cache_controls(%arg0: !llvm.ptr) {
// CHECK: %[[#LOAD:]] = load i32, ptr %[[#ARG0]], align 4, !spirv.DecorationCacheControlINTEL ![[#DECORATION0:]]
%0 = llvm.load %arg0 {triton_gen.DecorationCacheControlINTEL = #triton_gen.decoration_cache_control<#triton_gen.store_cache_control<0, Uncached, 0>, #triton_gen.store_cache_control<1, WriteThrough, 0>, #triton_gen.load_cache_control<0, Cached, 0>, #triton_gen.load_cache_control<1, Uncached, 0>>} : !llvm.ptr -> i32
// CHECK: store i32 %[[#LOAD]], ptr %[[#ARG0]], align 4, !spirv.DecorationCacheControlINTEL ![[#DECORATION1:]]
llvm.store %0, %arg0 {triton_gen.DecorationCacheControlINTEL = #triton_gen.decoration_cache_control<#triton_gen.store_cache_control<0, WriteBack, 1>, #triton_gen.store_cache_control<1, Streaming, 1>, #triton_gen.load_cache_control<0, Streaming, 1>, #triton_gen.load_cache_control<1, InvalidateAfterRead, 1>, #triton_gen.load_cache_control<2, ConstCached, 1>>} : i32, !llvm.ptr
// CHECK: call void @foo(ptr %[[#ARG0]], ptr %[[#ARG0]]), !spirv.DecorationCacheControlINTEL ![[#DECORATION2:]]
llvm.call @foo(%arg0, %arg0) {triton_gen.DecorationCacheControlINTEL =#triton_gen.decoration_cache_control<#triton_gen.store_cache_control<0, Uncached, 0>, #triton_gen.load_cache_control<0, Cached, 1>>} : (!llvm.ptr, !llvm.ptr) -> ()
victor-eds marked this conversation as resolved.
Show resolved Hide resolved
llvm.return
}

// Contents of the decoration lists captured above. Every list element is a
// 4-operand node {decoration key, cache level, cache control, operand number},
// where the key is 6442 for load and 6443 for store cache controls.
//
// The first decoration of the call (`store_cache_control<0, Uncached, 0>`) is
// the same node as the first decoration of the load, so the third list below
// uses the already-captured variable for it instead of capturing it again;
// capturing it again would leave the call's first decoration unverified.
// CHECK-DAG: ![[#DECORATION0]] = !{![[#CACHECONTROL0:]], ![[#CACHECONTROL1:]], ![[#CACHECONTROL2:]], ![[#CACHECONTROL3:]]}
// CHECK-DAG: ![[#DECORATION1]] = !{![[#CACHECONTROL4:]], ![[#CACHECONTROL5:]], ![[#CACHECONTROL6:]], ![[#CACHECONTROL7:]], ![[#CACHECONTROL8:]]}
// CHECK-DAG: ![[#DECORATION2]] = !{![[#CACHECONTROL0]], ![[#CACHECONTROL9:]]}
// CHECK-DAG: ![[#CACHECONTROL0]] = !{i32 6443, i32 0, i32 0, i32 0}
// CHECK-DAG: ![[#CACHECONTROL1]] = !{i32 6443, i32 1, i32 1, i32 0}
// CHECK-DAG: ![[#CACHECONTROL2]] = !{i32 6442, i32 0, i32 1, i32 0}
// CHECK-DAG: ![[#CACHECONTROL3]] = !{i32 6442, i32 1, i32 0, i32 0}
// CHECK-DAG: ![[#CACHECONTROL4]] = !{i32 6443, i32 0, i32 2, i32 1}
// CHECK-DAG: ![[#CACHECONTROL5]] = !{i32 6443, i32 1, i32 3, i32 1}
// CHECK-DAG: ![[#CACHECONTROL6]] = !{i32 6442, i32 0, i32 2, i32 1}
// CHECK-DAG: ![[#CACHECONTROL7]] = !{i32 6442, i32 1, i32 3, i32 1}
// CHECK-DAG: ![[#CACHECONTROL8]] = !{i32 6442, i32 2, i32 4, i32 1}
// CHECK-DAG: ![[#CACHECONTROL9]] = !{i32 6442, i32 0, i32 1, i32 1}
16 changes: 16 additions & 0 deletions test/TritonGEN/tritongen-invalid.mlir
Original file line number Diff line number Diff line change
@@ -1,5 +1,21 @@
// RUN: triton-opt -split-input-file -verify-diagnostics %s

// Negative test: two store cache controls that both target cache level 0 must
// be rejected by the `decoration_cache_control` attribute verifier.
llvm.func @triton_gen.duplicated_cache_controls(%arg0: !llvm.ptr) {
// expected-error @+1 {{'triton_gen.decoration_cache_controls' cannot specify more than one cache control decoration of the same nature for the same cache level}}
%0 = llvm.load %arg0 {triton_gen.DecorationCacheControlINTEL = #triton_gen.decoration_cache_control<#triton_gen.store_cache_control<0, Uncached, 0>, #triton_gen.store_cache_control<0, Streaming, 0>>} : !llvm.ptr -> i32
victor-eds marked this conversation as resolved.
Show resolved Hide resolved
llvm.return
}

// -----

// Negative test: a decoration list element that is neither a load nor a store
// cache control attribute (here an integer attribute) must be rejected by the
// `decoration_cache_control` attribute verifier.
llvm.func @triton_gen.illegal_cache_controls_attr(%arg0: !llvm.ptr) {
// expected-error @+1 {{'triton_gen.decoration_cache_controls' only accepts LoadCacheControlDecorationAttr and StoreCacheControlDecorationAttr attributes}}
%0 = llvm.load %arg0 {triton_gen.DecorationCacheControlINTEL =#triton_gen.decoration_cache_control<1 : i32>} : !llvm.ptr -> i32
llvm.return
}

// -----

llvm.func @triton_gen.dpas(%c : vector<8xi32>, %a : vector<8xi16>, %b : vector<8xi32>) {
// expected-error @+1 {{'triton_gen.dpas' op expecting repeat count to be 1, 2, 4, or 8}}
%0 = triton_gen.dpas %c, %a, %b {pa=i8, pb=i8, rc=16} : (vector<8xi32>, vector<8xi16>, vector<8xi32>) -> vector<8xi32>
Expand Down
13 changes: 13 additions & 0 deletions test/TritonGEN/tritongen.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,19 @@ llvm.func @triton_gen.dpas(%c : vector<8xi32>, %a : vector<8xi16>, %b : vector<8
llvm.return
}

llvm.func @foo(%arg0: !llvm.ptr, %arg1: !llvm.ptr)

// Round-trip test: the `triton_gen.DecorationCacheControlINTEL` attribute on
// `llvm.load`, `llvm.store` and `llvm.call` must be parsed and printed back
// unchanged, including the order of the nested load/store cache controls.
llvm.func @triton_gen.cache_controls(%arg0: !llvm.ptr) {
// CHECK: llvm.func @triton_gen.cache_controls(%arg0: !llvm.ptr)
// CHECK-NEXT: %0 = llvm.load %arg0 {triton_gen.DecorationCacheControlINTEL = #triton_gen.decoration_cache_control<#triton_gen.store_cache_control<0, Uncached, 0>, #triton_gen.store_cache_control<1, WriteThrough, 0>, #triton_gen.load_cache_control<0, Cached, 0>, #triton_gen.load_cache_control<1, Uncached, 0>>} : !llvm.ptr -> i32
%0 = llvm.load %arg0 {triton_gen.DecorationCacheControlINTEL = #triton_gen.decoration_cache_control<#triton_gen.store_cache_control<0, Uncached, 0>, #triton_gen.store_cache_control<1, WriteThrough, 0>, #triton_gen.load_cache_control<0, Cached, 0>, #triton_gen.load_cache_control<1, Uncached, 0>>} : !llvm.ptr -> i32
// CHECK-NEXT: llvm.store %0, %arg0 {triton_gen.DecorationCacheControlINTEL = #triton_gen.decoration_cache_control<#triton_gen.store_cache_control<0, WriteBack, 1>, #triton_gen.store_cache_control<1, Streaming, 1>, #triton_gen.load_cache_control<0, Streaming, 1>, #triton_gen.load_cache_control<1, InvalidateAfterRead, 1>, #triton_gen.load_cache_control<2, ConstCached, 1>>} : i32, !llvm.ptr
llvm.store %0, %arg0 {triton_gen.DecorationCacheControlINTEL = #triton_gen.decoration_cache_control<#triton_gen.store_cache_control<0, WriteBack, 1>, #triton_gen.store_cache_control<1, Streaming, 1>, #triton_gen.load_cache_control<0, Streaming, 1>, #triton_gen.load_cache_control<1, InvalidateAfterRead, 1>, #triton_gen.load_cache_control<2, ConstCached, 1>>} : i32, !llvm.ptr
// CHECK-NEXT: llvm.call @foo(%arg0, %arg0) {triton_gen.DecorationCacheControlINTEL = #triton_gen.decoration_cache_control<#triton_gen.store_cache_control<0, Uncached, 0>, #triton_gen.load_cache_control<0, Cached, 1>>} : (!llvm.ptr, !llvm.ptr) -> ()
llvm.call @foo(%arg0, %arg0) {triton_gen.DecorationCacheControlINTEL = #triton_gen.decoration_cache_control<#triton_gen.store_cache_control<0, Uncached, 0>, #triton_gen.load_cache_control<0, Cached, 1>>} : (!llvm.ptr, !llvm.ptr) -> ()
llvm.return
}

llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) {
// CHECK: llvm.func @triton_gen.2Dblockload(%arg0: !llvm.ptr, %arg1: i32, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32) {
// CHECK-NEXT: %0 = triton_gen.2Dblockload %arg0, %arg1, %arg2, %arg3, %arg4, %arg5 {elem_size_in_bits = 16, tile_width = 16, tile_height = 16, v_blocks = 1, transpose = false, vnni_transform = false, cache_control = Default} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<16xf16>
Expand Down
107 changes: 107 additions & 0 deletions third_party/intel/include/Dialect/TritonGEN/IR/TritonGENAttrDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,17 @@
#ifndef TRITONGEN_ATTRDEFS
#define TRITONGEN_ATTRDEFS

include "intel/include/Dialect/TritonGEN/IR/TritonGENDialect.td"

include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/EnumAttr.td"

/// Base class for attributes defined by the TritonGEN dialect.
///
/// `name` is the attribute definition name forwarded to `AttrDef`,
/// `attrMnemonic` is the mnemonic used for the attribute in the textual IR, and
/// `traits` is an optional list of attribute traits. The generated C++ classes
/// are placed in the `::mlir::triton::TritonGEN` namespace.
class TritonGEN_Attr<string name, string attrMnemonic, list<Trait> traits = []>
: AttrDef<TritonGEN_Dialect, name, traits> {
let mnemonic = attrMnemonic;
let cppNamespace = "::mlir::triton::TritonGEN";
}

/// Enum attribute of the different shuffle kinds.
def TritonGEN_ShflKindAttr : I32EnumAttr<"ShflKind", "TritonGEN shuffle kind",
[
Expand Down Expand Up @@ -109,4 +118,102 @@ def TritonGEN_MemScope : I32EnumAttr<"MemScope",
let cppNamespace = "::mlir::triton::TritonGEN";
}

/// Common definition of the load and store cache control attributes.
///
/// `loadOrStore` must be either "Load" or "Store". It selects the attribute
/// name (`<loadOrStore>CacheControlDecoration`), the mnemonic
/// (`load_cache_control` / `store_cache_control`) and the enum used for the
/// `cache_control` parameter (`<loadOrStore>CacheControlDecorationEnum`).
///
/// The attribute is written as `<cache_level, cache_control, operand_number>`.
class TritonGEN_LoadStoreCacheControlDecoration<string loadOrStore>
: TritonGEN_Attr<loadOrStore # "CacheControlDecoration", !tolower(loadOrStore) # "_cache_control"> {
let summary = "An attribute specifying " # !tolower(loadOrStore) # " cache control";
let description = [{
A }] # !tolower(loadOrStore) # [{ cache control attribute has a one-to-one
correspondence with the SPIR-V decoration shown in
https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/INTEL/SPV_INTEL_cache_controls.asciidoc#decorations. The
only differences are that there is no need to add the SPIR-V decoration key, as
that will be inferred from the attribute type, and an additional
`operand_number` parameter is needed, as this is applied to pointer user
operations, following
https://github.com/KhronosGroup/SPIRV-LLVM-Translator/pull/2587 design.
}];
let parameters = (ins "uint32_t":$cache_level,
loadOrStore # "CacheControlDecorationEnum":$cache_control,
"uint32_t":$operand_number);
let assemblyFormat = "`<` $cache_level `,` $cache_control `,` $operand_number `>`";
}

/// Enum attribute for load cache controls.
///
/// The integer value of each case is the value emitted as the cache control
/// operand of the `!spirv.DecorationCacheControlINTEL` metadata when the
/// attribute is translated to LLVM IR.
///
/// See
/// https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/INTEL/SPV_INTEL_cache_controls.asciidoc#Load_Cache_Control.
def TritonGEN_LoadCacheControlDecorationEnum : I32EnumAttr<"LoadCacheControlDecorationEnum",
"TritonGEN load cache controls",
[ I32EnumAttrCase<"Uncached", 0, "Uncached">,
I32EnumAttrCase<"Cached", 1, "Cached">,
I32EnumAttrCase<"Streaming", 2, "Streaming">,
I32EnumAttrCase<"InvalidateAfterRead", 3, "InvalidateAfterRead">,
I32EnumAttrCase<"ConstCached", 4, "ConstCached">
]> {
let cppNamespace = "::mlir::triton::TritonGEN";
}

/// Load cache control attribute (`#triton_gen.load_cache_control<...>`), whose
/// `cache_control` parameter is a `LoadCacheControlDecorationEnum`.
def TritonGEN_LoadCacheControlDecoration
: TritonGEN_LoadStoreCacheControlDecoration<"Load">;

/// Enum attribute for store cache controls.
///
/// The integer value of each case is the value emitted as the cache control
/// operand of the `!spirv.DecorationCacheControlINTEL` metadata when the
/// attribute is translated to LLVM IR.
///
/// See
/// https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/INTEL/SPV_INTEL_cache_controls.asciidoc#Store_Cache_Control.
def TritonGEN_StoreCacheControlDecorationEnum : I32EnumAttr<"StoreCacheControlDecorationEnum",
"TritonGEN store cache controls",
[ I32EnumAttrCase<"Uncached", 0, "Uncached">,
I32EnumAttrCase<"WriteThrough", 1, "WriteThrough">,
I32EnumAttrCase<"WriteBack", 2, "WriteBack">,
I32EnumAttrCase<"Streaming", 3, "Streaming">,
]> {
let cppNamespace = "::mlir::triton::TritonGEN";
}

/// Store cache control attribute (`#triton_gen.store_cache_control<...>`),
/// whose `cache_control` parameter is a `StoreCacheControlDecorationEnum`.
def TritonGEN_StoreCacheControlDecoration
: TritonGEN_LoadStoreCacheControlDecoration<"Store">;

/// Attribute holding the list of load/store cache controls applied to an
/// operation (`#triton_gen.decoration_cache_control<...>`).
///
/// The list elements are stored as generic `::mlir::Attribute`s; a custom
/// verifier is declared (`genVerifyDecl`) to enforce the constraints spelled
/// out at the end of the description.
def TritonGEN_CacheControlsDecoration
: TritonGEN_Attr<"DecorationCacheControl", "decoration_cache_control"> {
let summary = "An attribute specifying an operation cache controls";
let description = [{
Attribute corresponding to `!spirv.DecorationCacheControlINTEL` metadata
described in
https://github.com/KhronosGroup/SPIRV-LLVM-Translator/pull/2587. This
metadata is a list of `LoadCacheControlDecoration` and
`StoreCacheControlDecoration` attributes specifying cache control
information.

The following MLIR code using this attribute:

```mlir
llvm.func @triton_gen.cache_controls(%arg0: !llvm.ptr) {
%0 = llvm.load %arg0 {triton_gen.DecorationCacheControlINTEL = #triton_gen.decoration_cache_control<#triton_gen.store_cache_control<0, Uncached, 0>, #triton_gen.store_cache_control<1, WriteThrough, 0>, #triton_gen.load_cache_control<0, Cached, 0>, #triton_gen.load_cache_control<1, Uncached, 0>>} : !llvm.ptr -> i32
llvm.return
}
```

Will be translated to the following LLVM IR:

```llvm
define void @triton_gen.cache_controls(ptr %0) {
%1 = load i32, ptr %0, align 4, !spirv.DecorationCacheControlINTEL !1
ret void
}

!1 = !{!2, !3, !4, !5}
!2 = !{i32 6443, i32 0, i32 0, i32 0}
!3 = !{i32 6443, i32 1, i32 1, i32 0}
!4 = !{i32 6442, i32 0, i32 1, i32 0}
!5 = !{i32 6442, i32 1, i32 0, i32 0}
```

$decorations must be a non-empty list of cache controls attributes and two
attributes of the same nature (load or store) in the list cannot be applied
to the same cache level, as per the SPIR-V validation rules.
}];
let parameters = (ins ArrayRefParameter<"::mlir::Attribute">:$decorations);
let assemblyFormat = "`<` $decorations `>`";
let genVerifyDecl = 1;
}

#endif // TRITONGEN_ATTRDEFS
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def TritonGEN_Dialect : Dialect {
let name = "triton_gen";
let cppNamespace = "::mlir::triton::TritonGEN";
let summary = "The TritonGEN dialect in Triton.";
let useDefaultAttributePrinterParser = 1;

let description = [{
TritonGEN is a dialect for representing operations on Intel GPUs.
Expand All @@ -40,6 +41,12 @@ def TritonGEN_Dialect : Dialect {
static constexpr ::llvm::StringLiteral getReqdSubGroupSizeAttrName() {
return ::llvm::StringLiteral("triton_gen.intel_reqd_sub_group_size");
}

/// Get the name for the attribute used to specify cache control
/// decorations.
///
/// Returned as a `StringLiteral`, matching the other attribute-name getters
/// of this dialect. `StringLiteral` derives from `StringRef`, so callers that
/// compare against or bind the result to a `StringRef` are unaffected.
static constexpr ::llvm::StringLiteral getCacheControlsAttrName() {
return ::llvm::StringLiteral("triton_gen.DecorationCacheControlINTEL");
}
}];
}

Expand Down
1 change: 1 addition & 0 deletions third_party/intel/lib/Dialect/TritonGEN/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
add_triton_library(TritonGENIR
TritonGENAttrs.cpp
TritonGENDialect.cpp
TritonGENOps.cpp

Expand Down
57 changes: 57 additions & 0 deletions third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENAttrs.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
//===- TritonGENAttrs.cpp - TritonGEN Attributes Definition --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "intel/include/Dialect/TritonGEN/IR/TritonGENDialect.h"

#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/TypeSwitch.h"

namespace mlir::triton::TritonGEN {
//===----------------------------------------------------------------------===//
// triton_gen.decoration_cache_control
//===----------------------------------------------------------------------===//

/// Verifies the decoration list of a `decoration_cache_control` attribute.
///
/// Fails (emitting a diagnostic through \p emitError) when:
///  - \p decorations is empty;
///  - an element is neither a `LoadCacheControlDecorationAttr` nor a
///    `StoreCacheControlDecorationAttr`;
///  - two load decorations, or two store decorations, name the same cache
///    level.
/// Elements are examined in order and the first violation is reported.
LogicalResult TritonGEN::DecorationCacheControlAttr::verify(
    llvm::function_ref<InFlightDiagnostic()> emitError,
    ArrayRef<Attribute> decorations) {
  if (decorations.empty())
    return emitError() << "'triton_gen.decoration_cache_controls' expecting at "
                          "least one cache control decoration";

  // Cache levels already claimed by a load / store decoration. Load and store
  // levels are tracked independently: the same level may appear once in each.
  llvm::SmallSet<uint32_t, 3> seenLoadLevels;
  llvm::SmallSet<uint32_t, 3> seenStoreLevels;

  for (Attribute decoration : decorations) {
    llvm::SmallSet<uint32_t, 3> *seenLevels = nullptr;
    uint32_t level = 0;

    if (auto loadControl =
            dyn_cast<LoadCacheControlDecorationAttr>(decoration)) {
      seenLevels = &seenLoadLevels;
      level = loadControl.getCacheLevel();
    } else if (auto storeControl =
                   dyn_cast<StoreCacheControlDecorationAttr>(decoration)) {
      seenLevels = &seenStoreLevels;
      level = storeControl.getCacheLevel();
    } else {
      return emitError()
             << "'triton_gen.decoration_cache_controls' only accepts "
                "LoadCacheControlDecorationAttr and "
                "StoreCacheControlDecorationAttr attributes";
    }

    // `insert` reports through `.second` whether the level was newly added.
    const bool isNewLevel = seenLevels->insert(level).second;
    if (!isNewLevel)
      return emitError()
             << "'triton_gen.decoration_cache_controls' cannot "
                "specify more than one cache control decoration of "
                "the same nature for the same cache level";
  }

  return success();
}
} // namespace mlir::triton::TritonGEN
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,77 @@ class TritonGENDialectLLVMIRTranslationInterface
amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
NamedAttribute attribute,
LLVM::ModuleTranslation &moduleTranslation) const final {
// Skip the attribute if it is not a TritonGEN attribute.
if (!attribute.getName().getValue().starts_with("triton_gen"))
return success();
StringRef attrName = attribute.getName().getValue();
if (attrName ==
triton::TritonGEN::TritonGENDialect::getCacheControlsAttrName()) {
auto decorationAttr =
dyn_cast<triton::TritonGEN::DecorationCacheControlAttr>(
attribute.getValue());
if (!decorationAttr)
return op->emitOpError(
"Expecting triton_gen.decoration_cache_control attribute");
if (instructions.size() != 1)
return op->emitOpError("Expecting a single instruction");
return handleDecorationCacheControl(op, instructions.front(),
decorationAttr);
}
if (attrName.starts_with("triton_gen"))
return handleTritonGenAttr(op, attribute, moduleTranslation);
return success();
}

private:
/// Wraps the integer \p val, materialized as a constant of LLVM type \p type,
/// into a metadata operand.
template <typename IntTy>
static llvm::Metadata *getConstantIntMD(llvm::Type *type, IntTy val) {
  auto *constant = llvm::ConstantInt::get(type, val);
  return llvm::ConstantAsMetadata::get(constant);
}

/// Attaches `spirv.DecorationCacheControlINTEL` metadata to \p inst.
///
/// Every load/store cache control in \p attribute becomes one metadata node
/// with four i32 operands, in this order: decoration key (6442 for load, 6443
/// for store), cache level, cache control value and operand number. The nodes
/// are gathered, in attribute order, into a single list node that is set on
/// the instruction. Always returns success; \p op is not used.
static LogicalResult handleDecorationCacheControl(
    Operation *op, llvm::Instruction *inst,
    triton::TritonGEN::DecorationCacheControlAttr attribute) {
  using LoadAttr = triton::TritonGEN::LoadCacheControlDecorationAttr;
  using StoreAttr = triton::TritonGEN::StoreCacheControlDecorationAttr;

  llvm::LLVMContext &llvmCtx = inst->getContext();
  llvm::Type *int32Ty = llvm::IntegerType::getInt32Ty(llvmCtx);

  // Builds the 4-operand node for one load or store cache control. The
  // decoration key is picked at compile time from the concrete attribute type.
  auto buildNode = [&llvmCtx, int32Ty](auto cacheControl) -> llvm::Metadata * {
    constexpr bool isLoad = std::is_same_v<decltype(cacheControl), LoadAttr>;
    constexpr uint32_t spirvKey = isLoad ? 6442 : 6443;
    const std::array<llvm::Metadata *, 4> operands{
        getConstantIntMD(int32Ty, spirvKey),
        getConstantIntMD(int32Ty, cacheControl.getCacheLevel()),
        getConstantIntMD(
            int32Ty, static_cast<uint32_t>(cacheControl.getCacheControl())),
        getConstantIntMD(int32Ty, cacheControl.getOperandNumber())};
    return llvm::MDNode::get(llvmCtx, operands);
  };

  SmallVector<llvm::Metadata *> nodes;
  for (Attribute decoration : attribute.getDecorations()) {
    llvm::Metadata *node =
        TypeSwitch<Attribute, llvm::Metadata *>(decoration)
            .Case<LoadAttr, StoreAttr>(buildNode);
    nodes.push_back(node);
  }

  constexpr llvm::StringLiteral metadataName =
      "spirv.DecorationCacheControlINTEL";
  inst->setMetadata(metadataName, llvm::MDNode::get(llvmCtx, nodes));
  return success();
}

LogicalResult
handleTritonGenAttr(Operation *op, NamedAttribute attribute,
LLVM::ModuleTranslation &moduleTranslation) const {
llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
llvm::Function *llvmFunc =
moduleTranslation.lookupFunction(cast<LLVM::LLVMFuncOp>(op).getName());
Expand All @@ -47,7 +114,6 @@ class TritonGENDialectLLVMIRTranslationInterface
return success();
}

private:
// Checks if the given operation is a kernel function.
bool isKernel(Operation *op) const {
auto fn = dyn_cast<LLVM::LLVMFuncOp>(op);
Expand Down