Skip to content

Commit

Permalink
Avoid post-processing generated LLVM IR
Browse files Browse the repository at this point in the history
Signed-off-by: Whitney Tsang <whitney.tsang@intel.com>
  • Loading branch information
whitneywhtsang committed May 18, 2024
1 parent 3acbc5d commit e3064fd
Show file tree
Hide file tree
Showing 14 changed files with 147 additions and 83 deletions.
1 change: 1 addition & 0 deletions bin/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
get_property(translation_libs GLOBAL PROPERTY MLIR_TRANSLATION_LIBS)
get_property(triton_libs GLOBAL PROPERTY TRITON_LIBS)

add_llvm_executable(triton-opt triton-opt.cpp PARTIAL_SOURCES_INTENDED)
Expand Down
69 changes: 1 addition & 68 deletions python/src/llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -135,70 +135,6 @@ static uint32_t findKernels(llvm::Module &M,
return numKernels;
}

/// Amend SPIR kernels in the given LLVM module by translating GEN passthrough
/// attributes into LLVM metadata.
static void amendLLVMIR(llvm::Module &llvmMod, llvm::LLVMContext &ctx) {
// Collect SPIR kernels.
std::set<llvm::Function *> kernels;
uint32_t numKernels = findKernels(llvmMod, kernels);
assert(numKernels == 1 && "Expecting a single SPIR kernel");
llvm::Function *kernel = *kernels.begin();

// Given a string \p str of the form "n1,n2,...", parse it as a
// vector of integers (n1,n2,...).
auto extractFromString = [](StringRef str) -> SmallVector<int64_t> {
auto parseAsInt = [](StringRef str, int64_t &intVal) {
bool failed = str.getAsInteger(10, intVal);
return !failed;
};

SmallVector<int64_t> result;
std::pair<StringRef, StringRef> pair;
do {
pair = str.split(',');
str = pair.second;
int64_t intVal;
if (!parseAsInt(pair.first, intVal))
break;

result.push_back(intVal);
} while (true);

return result;
};

// Attach metadata to \p func given its name \p attrName and value \p attrVal.
auto attachMetadata = [&](StringRef attrName, StringRef attrVal,
llvm::Function *func) {
SmallVector<llvm::Metadata *, 3> metadata;
llvm::Type *i64 = llvm::IntegerType::get(ctx, 64);
for (int64_t val : extractFromString(attrVal))
metadata.push_back(
llvm::ConstantAsMetadata::get(llvm::ConstantInt::get(i64, val)));

llvm::MDNode *node = llvm::MDNode::get(ctx, metadata);
func->setMetadata(attrName, node);
};

// Attach required metadata to the kernel.
using namespace mlir::triton;
SmallVector<llvm::StringLiteral> genAttrs{
TritonGEN::TritonGENDialect::getMaxWorkGroupSizeAttrName(),
TritonGEN::TritonGENDialect::getReqdWorkGroupSizeAttrName(),
TritonGEN::TritonGENDialect::getReqdSubGroupSizeAttrName()};

for (llvm::StringLiteral genAttr : genAttrs) {
if (!kernel->hasFnAttribute(genAttr))
continue;

Attribute fnAttr = kernel->getFnAttribute(genAttr);
assert(fnAttr.isStringAttribute() && "Expecting a string attribute");
attachMetadata(fnAttr.getKindAsString().split('.').second,
fnAttr.getValueAsString(), kernel);
kernel->removeFnAttr(genAttr);
}
}

void init_triton_llvm(py::module &&m) {

py::class_<llvm::LLVMContext>(m, "context", py::module_local())
Expand Down Expand Up @@ -271,10 +207,7 @@ void init_triton_llvm(py::module &&m) {
m.def(
"to_module",
[](mlir::ModuleOp &mod, llvm::LLVMContext &ctx) {
std::unique_ptr<llvm::Module> llvmMod =
mlir::translateModuleToLLVMIR(mod, ctx);
amendLLVMIR(*llvmMod, ctx);
return llvmMod;
return mlir::translateModuleToLLVMIR(mod, ctx);
},
py::keep_alive<0, 2>());

Expand Down
2 changes: 1 addition & 1 deletion test/Conversion/intel/tritongpu_to_gen.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
// CHECK: llvm.func spir_kernelcc @test_empty_kernel(%arg0: i64, %arg1: !llvm.ptr<1>)
// Here the 128 comes from the 4 in module attribute multiples 32
// CHECK-SAME: attributes {passthrough = {{.*}}"gen.intel_reqd_sub_group_size", "32"{{.*}}"gen.max_work_group_size", "128,1,1"{{.*}}} {
// CHECK-SAME: attributes {triton_gen.intel_reqd_sub_group_size = [32 : i32], triton_gen.max_work_group_size = [128 : i32, 1 : i32, 1 : i32]} {
tt.func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) {
// CHECK: llvm.return
tt.return
Expand Down
1 change: 1 addition & 0 deletions third_party/intel/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ add_triton_plugin(TritonXPU

LINK_LIBS
TritonGENToLLVM
TritonGENToLLVMIRTranslation
TritonIntelGPUToLLVM
TritonIntelGPUTransforms
)
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,19 @@ def TritonGEN_Dialect : Dialect {
/// Get the name of the attribute used to annotate max work group size
/// required for kernels.
static constexpr ::llvm::StringLiteral getMaxWorkGroupSizeAttrName() {
return ::llvm::StringLiteral("gen.max_work_group_size");
return ::llvm::StringLiteral("triton_gen.max_work_group_size");
}

/// Get the name of the attribute used to annotate exact work group size
/// required for kernels.
static constexpr ::llvm::StringLiteral getReqdWorkGroupSizeAttrName() {
return ::llvm::StringLiteral("gen.reqd_work_group_size");
return ::llvm::StringLiteral("triton_gen.reqd_work_group_size");
}

/// Get the name for the attribute used to annotate the exact sub group
/// size required for kernels.
static constexpr ::llvm::StringLiteral getReqdSubGroupSizeAttrName() {
return ::llvm::StringLiteral("gen.intel_reqd_sub_group_size");
return ::llvm::StringLiteral("triton_gen.intel_reqd_sub_group_size");
}
}];
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
//===-TritonGENToLLVMIRTranslation.h-TritonGEN Dialect to LLVM IR - C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This provides registration calls for TritonGEN dialect to LLVM IR
// translation.
//
//===----------------------------------------------------------------------===//

#ifndef TRITON_TARGET_LLVMIR_DIALECT_TRITONGEN_TRITONGENTOLLVMIRTRANSLATION_H
#define TRITON_TARGET_LLVMIR_DIALECT_TRITONGEN_TRITONGENTOLLVMIRTRANSLATION_H

namespace mlir {

class DialectRegistry;
class MLIRContext;

/// Register the TritonGEN dialect and the translation from it to the LLVM IR in
/// the given registry;
void registerTritonGENDialectTranslation(DialectRegistry &registry);

/// Register the TritonGEN dialect and the translation from it in the registry
/// associated with the given context.
void registerTritonGENDialectTranslation(MLIRContext &context);

} // namespace mlir

#endif // TRITON_TARGET_LLVMIR_DIALECT_TRITONGEN_TRITONGENTOLLVMIRTRANSLATION_H
1 change: 1 addition & 0 deletions third_party/intel/lib/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
add_subdirectory(Dialect)
add_subdirectory(GPUToTritonGEN)
add_subdirectory(Target)
add_subdirectory(TritonGENToLLVM)
add_subdirectory(TritonIntelGPUToLLVM)
add_subdirectory(TritonIntelGPUTransforms)
1 change: 1 addition & 0 deletions third_party/intel/lib/Target/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
add_subdirectory(LLVMIR)
1 change: 1 addition & 0 deletions third_party/intel/lib/Target/LLVMIR/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
add_subdirectory(Dialect)
1 change: 1 addition & 0 deletions third_party/intel/lib/Target/LLVMIR/Dialect/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
add_subdirectory(TritonGEN)
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
add_mlir_translation_library(TritonGENToLLVMIRTranslation
TritonGENToLLVMIRTranslation.cpp

LINK_COMPONENTS
Core

LINK_LIBS PUBLIC
MLIRIR
TritonGENIR
MLIRLLVMDialect
MLIRSupport
MLIRTargetLLVMIRExport
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
//===-TritonGENToLLVMIRTranslation.cpp - TritonGEN Dialect to LLVM IR -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a translation between the TritonGEN dialect and LLVM IR.
//
//===----------------------------------------------------------------------===//

#include "Target/LLVMIR/Dialect/TritonGEN/TritonGENToLLVMIRTranslation.h"

#include "Dialect/TritonGEN/IR/TritonGENDialect.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
#include "mlir/Target/LLVMIR/ModuleTranslation.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Metadata.h"

namespace {
using namespace mlir;
class TritonGENDialectLLVMIRTranslationInterface
: public LLVMTranslationDialectInterface {
public:
using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;

LogicalResult
amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
NamedAttribute attribute,
LLVM::ModuleTranslation &moduleTranslation) const final {
llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
llvm::Function *llvmFunc =
moduleTranslation.lookupFunction(cast<LLVM::LLVMFuncOp>(op).getName());
if (isKernel(op))
amendKernel(llvmContext, llvmFunc, attribute);
return success();
}

private:
// Checks if the given operation is a kernel function.
bool isKernel(Operation *op) const {
auto fn = dyn_cast<LLVM::LLVMFuncOp>(op);
return fn && fn.getCConv() == LLVM::CConv::SPIR_KERNEL;
}

// The attribute is converted into metadata and added to the function.
void amendKernel(llvm::LLVMContext &llvmContext, llvm::Function *llvmFunc,
NamedAttribute attribute) const {
SmallVector<llvm::Metadata *, 3> metadata;
llvm::Type *i64 = llvm::IntegerType::get(llvmContext, 64);
for (int64_t i :
extractFromIntegerArrayAttr<int64_t>(attribute.getValue())) {
llvm::Constant *constant = llvm::ConstantInt::get(i64, i);
metadata.push_back(llvm::ConstantAsMetadata::get(constant));
}
llvm::MDNode *node = llvm::MDNode::get(llvmContext, metadata);
llvmFunc->setMetadata(attribute.getName().getValue().drop_front(11), node);
}
};
} // namespace

namespace mlir {
void registerTritonGENDialectTranslation(DialectRegistry &registry) {
registry.insert<triton::TritonGEN::TritonGENDialect>();
registry.addExtension(
+[](MLIRContext *ctx, triton::TritonGEN::TritonGENDialect *dialect) {
dialect->addInterfaces<TritonGENDialectLLVMIRTranslationInterface>();
});
}

void registerTritonGENDialectTranslation(MLIRContext &context) {
DialectRegistry registry;
registerTritonGENDialectTranslation(registry);
context.appendDialectRegistry(registry);
}
} // namespace mlir
17 changes: 6 additions & 11 deletions third_party/intel/lib/TritonIntelGPUToLLVM/PipelineManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,17 +110,12 @@ struct FuncOpConversion : public ConvertOpToLLVMPattern<triton::FuncOp> {
if (LLVM::isKernel(funcOp))
newFuncOp.setCConv(LLVM::CConv::SPIR_KERNEL);

auto maxWorkGroupSizeAttr = rewriter.getArrayAttr(
{rewriter.getStringAttr(
TritonGEN::TritonGENDialect::getMaxWorkGroupSizeAttrName()),
rewriter.getStringAttr(std::to_string(threadsPerWarp * numWarps) +
",1,1")});
auto reqSubGroupSizeAttr = rewriter.getArrayAttr(
{rewriter.getStringAttr(
TritonGEN::TritonGENDialect::getReqdSubGroupSizeAttrName()),
rewriter.getStringAttr(std::to_string(threadsPerWarp))});
newFuncOp.setPassthroughAttr(
ArrayAttr::get(ctx, {reqSubGroupSizeAttr, maxWorkGroupSizeAttr}));
NamedAttrList attrs;
attrs.append(TritonGEN::TritonGENDialect::getMaxWorkGroupSizeAttrName(),
rewriter.getI32ArrayAttr({threadsPerWarp * numWarps, 1, 1}));
attrs.append(TritonGEN::TritonGENDialect::getReqdSubGroupSizeAttrName(),
rewriter.getI32ArrayAttr({threadsPerWarp}));
newFuncOp->setDialectAttrs(attrs);

if (!LLVM::isKernel(funcOp)) {
newFuncOp.setPassthroughAttr(
Expand Down
2 changes: 2 additions & 0 deletions third_party/intel/triton_xpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "intel/include/Dialect/TritonGEN/IR/TritonGENDialect.h"
#include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h"
#include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h"
#include "intel/include/Target/LLVMIR/Dialect/TritonGEN/TritonGENToLLVMIRTranslation.h"
#include "intel/include/TritonIntelGPUToLLVM/Passes.h"

#include "triton/Conversion/TritonToTritonGPU/Passes.h"
Expand Down Expand Up @@ -74,6 +75,7 @@ void init_triton_intel(py::module &&m) {
mlir::DialectRegistry registry;
registry
.insert<TritonGEN::TritonGENDialect, intel::TritonIntelGPUDialect>();
mlir::registerTritonGENDialectTranslation(registry);
context.appendDialectRegistry(registry);
context.loadAllAvailableDialects();
});
Expand Down

0 comments on commit e3064fd

Please sign in to comment.