Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid post-processing generated LLVM IR #1150

Merged
merged 4 commits into from
May 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions bin/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
get_property(translation_libs GLOBAL PROPERTY MLIR_TRANSLATION_LIBS)
get_property(triton_libs GLOBAL PROPERTY TRITON_LIBS)

add_llvm_executable(triton-opt triton-opt.cpp PARTIAL_SOURCES_INTENDED)
Expand Down Expand Up @@ -95,3 +96,26 @@ target_link_libraries(triton-llvm-opt PRIVATE
TritonIntelGPUIR
)
export_executable_symbols_for_plugins(triton-llvm-opt)

add_llvm_executable(triton-translate
triton-translate.cpp

PARTIAL_SOURCES_INTENDED
DEPENDS
intrinsics_gen
SUPPORT_PLUGINS
)
llvm_update_compile_flags(triton-translate)
target_link_libraries(triton-translate
PRIVATE
${dialect_libs}
${translation_libs}
MLIRIR
MLIRParser
MLIRPass
MLIRSPIRVDialect
MLIRTranslateLib
MLIRSupport
TritonGENToLLVMIRTranslation
)
mlir_check_link_libraries(triton-translate)
51 changes: 51 additions & 0 deletions bin/triton-translate.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
//===- triton-translate.cpp - Triton Translate Driver ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This is a command line utility that translates a file from/to MLIR using one
// of the registered translations.
//
//===----------------------------------------------------------------------===//

#include "intel/include/Target/LLVMIR/Dialect/TritonGEN/TritonGENToLLVMIRTranslation.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/InitAllTranslations.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Target/LLVMIR/Dialect/All.h"
#include "mlir/Target/LLVMIR/Export.h"
#include "mlir/Tools/mlir-translate/MlirTranslateMain.h"
#include "mlir/Tools/mlir-translate/Translation.h"
#include "llvm/IR/Module.h"

using namespace mlir;

namespace mlir {
inline void registerTritonTranslations() {
static TranslateFromMLIRRegistration registration(
"triton-to-llvmir", "Translate Triton to LLVMIR",
[](Operation *op, raw_ostream &output) {
llvm::LLVMContext llvmContext;
auto llvmModule = translateModuleToLLVMIR(op, llvmContext);
if (!llvmModule)
return failure();

llvmModule->print(output, nullptr);
return success();
},
[](DialectRegistry &registry) {
registry.insert<func::FuncDialect>();
registerAllToLLVMIRTranslations(registry);
registerTritonGENDialectTranslation(registry);
});
}
} // namespace mlir

int main(int argc, char **argv) {
registerAllTranslations();
registerTritonTranslations();
return failed(mlirTranslateMain(argc, argv, "MLIR Translation Testing Tool"));
}
69 changes: 1 addition & 68 deletions python/src/llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -135,70 +135,6 @@ static uint32_t findKernels(llvm::Module &M,
return numKernels;
}

/// Amend SPIR kernels in the given LLVM module by translating GEN passthrough
/// attributes into LLVM metadata.
static void amendLLVMIR(llvm::Module &llvmMod, llvm::LLVMContext &ctx) {
// Collect SPIR kernels.
std::set<llvm::Function *> kernels;
uint32_t numKernels = findKernels(llvmMod, kernels);
assert(numKernels == 1 && "Expecting a single SPIR kernel");
llvm::Function *kernel = *kernels.begin();

// Given a string \p str of the form "n1,n2,...", parse it as a
// vector of integers (n1,n2,...).
auto extractFromString = [](StringRef str) -> SmallVector<int64_t> {
auto parseAsInt = [](StringRef str, int64_t &intVal) {
bool failed = str.getAsInteger(10, intVal);
return !failed;
};

SmallVector<int64_t> result;
std::pair<StringRef, StringRef> pair;
do {
pair = str.split(',');
str = pair.second;
int64_t intVal;
if (!parseAsInt(pair.first, intVal))
break;

result.push_back(intVal);
} while (true);

return result;
};

// Attach metadata to \p func given its name \p attrName and value \p attrVal.
auto attachMetadata = [&](StringRef attrName, StringRef attrVal,
llvm::Function *func) {
SmallVector<llvm::Metadata *, 3> metadata;
llvm::Type *i64 = llvm::IntegerType::get(ctx, 64);
for (int64_t val : extractFromString(attrVal))
metadata.push_back(
llvm::ConstantAsMetadata::get(llvm::ConstantInt::get(i64, val)));

llvm::MDNode *node = llvm::MDNode::get(ctx, metadata);
func->setMetadata(attrName, node);
};

// Attach required metadata to the kernel.
using namespace mlir::triton;
SmallVector<llvm::StringLiteral> genAttrs{
TritonGEN::TritonGENDialect::getMaxWorkGroupSizeAttrName(),
TritonGEN::TritonGENDialect::getReqdWorkGroupSizeAttrName(),
TritonGEN::TritonGENDialect::getReqdSubGroupSizeAttrName()};

for (llvm::StringLiteral genAttr : genAttrs) {
if (!kernel->hasFnAttribute(genAttr))
continue;

Attribute fnAttr = kernel->getFnAttribute(genAttr);
assert(fnAttr.isStringAttribute() && "Expecting a string attribute");
attachMetadata(fnAttr.getKindAsString().split('.').second,
fnAttr.getValueAsString(), kernel);
kernel->removeFnAttr(genAttr);
}
}

void init_triton_llvm(py::module &&m) {

py::class_<llvm::LLVMContext>(m, "context", py::module_local())
Expand Down Expand Up @@ -271,10 +207,7 @@ void init_triton_llvm(py::module &&m) {
m.def(
"to_module",
[](mlir::ModuleOp &mod, llvm::LLVMContext &ctx) {
std::unique_ptr<llvm::Module> llvmMod =
mlir::translateModuleToLLVMIR(mod, ctx);
amendLLVMIR(*llvmMod, ctx);
return llvmMod;
return mlir::translateModuleToLLVMIR(mod, ctx);
},
py::keep_alive<0, 2>());

Expand Down
1 change: 1 addition & 0 deletions test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ configure_lit_site_cfg(

set(TRITON_TEST_DEPENDS
triton-opt
triton-translate
)

set(FILECHECK_PATH "${LLVM_LIBRARY_DIR}/../bin/FileCheck")
Expand Down
2 changes: 1 addition & 1 deletion test/Conversion/intel/tritongpu_to_gen.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
// CHECK: llvm.func spir_kernelcc @test_empty_kernel(%arg0: i64, %arg1: !llvm.ptr<1>)
// Here the 128 comes from the 4 in module attribute multiples 32
// CHECK-SAME: attributes {passthrough = {{.*}}"gen.intel_reqd_sub_group_size", "32"{{.*}}"gen.max_work_group_size", "128,1,1"{{.*}}} {
// CHECK-SAME: attributes {triton_gen.intel_reqd_sub_group_size = [32 : i32], triton_gen.max_work_group_size = [128 : i32, 1 : i32, 1 : i32]} {
tt.func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) {
// CHECK: llvm.return
tt.return
Expand Down
18 changes: 18 additions & 0 deletions test/Target/LLVMIR/triton-gen.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
// RUN: triton-translate -triton-to-llvmir -split-input-file %s | FileCheck %s

// CHECK: define spir_kernel void @test_intel_reqd_sub_group_size() !intel_reqd_sub_group_size ![[REQD_SUB_GROUP_SIZE:.*]] {
llvm.func spir_kernelcc @test_intel_reqd_sub_group_size() attributes {triton_gen.intel_reqd_sub_group_size = [32 : i32]} {
llvm.return
}
// CHECK: define spir_kernel void @test_max_work_group_size() !max_work_group_size ![[MAX_WORK_GROUP_SIZE:.*]] {
llvm.func spir_kernelcc @test_max_work_group_size() attributes {triton_gen.max_work_group_size = [128 : i32, 1 : i32, 1 : i32]} {
llvm.return
}
// CHECK: define spir_kernel void @test_reqd_work_group_size() !reqd_work_group_size ![[REQD_WORK_GROUP_SIZE:.*]] {
llvm.func spir_kernelcc @test_reqd_work_group_size() attributes {triton_gen.reqd_work_group_size = [128 : i32, 1 : i32, 2 : i32]} {
llvm.return
}

// CHECK-DAG: ![[REQD_SUB_GROUP_SIZE]] = !{i64 32}
// CHECK-DAG: ![[MAX_WORK_GROUP_SIZE]] = !{i64 128, i64 1, i64 1}
// CHECK-DAG: ![[REQD_WORK_GROUP_SIZE]] = !{i64 128, i64 1, i64 2}
1 change: 1 addition & 0 deletions third_party/intel/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ add_triton_plugin(TritonXPU

LINK_LIBS
TritonGENToLLVM
TritonGENToLLVMIRTranslation
TritonIntelGPUToLLVM
TritonIntelGPUTransforms
)
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,19 @@ def TritonGEN_Dialect : Dialect {
/// Get the name of the attribute used to annotate max work group size
/// required for kernels.
static constexpr ::llvm::StringLiteral getMaxWorkGroupSizeAttrName() {
return ::llvm::StringLiteral("gen.max_work_group_size");
return ::llvm::StringLiteral("triton_gen.max_work_group_size");
whitneywhtsang marked this conversation as resolved.
Show resolved Hide resolved
}

/// Get the name of the attribute used to annotate exact work group size
/// required for kernels.
static constexpr ::llvm::StringLiteral getReqdWorkGroupSizeAttrName() {
return ::llvm::StringLiteral("gen.reqd_work_group_size");
return ::llvm::StringLiteral("triton_gen.reqd_work_group_size");
}

/// Get the name for the attribute used to annotate the exact sub group
/// size required for kernels.
static constexpr ::llvm::StringLiteral getReqdSubGroupSizeAttrName() {
return ::llvm::StringLiteral("gen.intel_reqd_sub_group_size");
return ::llvm::StringLiteral("triton_gen.intel_reqd_sub_group_size");
}
}];
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
//===-TritonGENToLLVMIRTranslation.h-TritonGEN Dialect to LLVM IR - C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This provides registration calls for TritonGEN dialect to LLVM IR
// translation.
//
//===----------------------------------------------------------------------===//

#ifndef TRITON_TARGET_LLVMIR_DIALECT_TRITONGEN_TRITONGENTOLLVMIRTRANSLATION_H
#define TRITON_TARGET_LLVMIR_DIALECT_TRITONGEN_TRITONGENTOLLVMIRTRANSLATION_H

namespace mlir {

class DialectRegistry;
class MLIRContext;

/// Register the TritonGEN dialect and the translation from it to the LLVM IR in
/// the given registry;
void registerTritonGENDialectTranslation(DialectRegistry &registry);

/// Register the TritonGEN dialect and the translation from it in the registry
/// associated with the given context.
void registerTritonGENDialectTranslation(MLIRContext &context);

} // namespace mlir

#endif // TRITON_TARGET_LLVMIR_DIALECT_TRITONGEN_TRITONGENTOLLVMIRTRANSLATION_H
1 change: 1 addition & 0 deletions third_party/intel/lib/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
add_subdirectory(Dialect)
add_subdirectory(GPUToTritonGEN)
add_subdirectory(Target)
add_subdirectory(TritonGENToLLVM)
add_subdirectory(TritonIntelGPUToLLVM)
add_subdirectory(TritonIntelGPUTransforms)
1 change: 1 addition & 0 deletions third_party/intel/lib/Target/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
add_subdirectory(LLVMIR)
1 change: 1 addition & 0 deletions third_party/intel/lib/Target/LLVMIR/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
add_subdirectory(Dialect)
1 change: 1 addition & 0 deletions third_party/intel/lib/Target/LLVMIR/Dialect/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
add_subdirectory(TritonGEN)
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
add_mlir_translation_library(TritonGENToLLVMIRTranslation
TritonGENToLLVMIRTranslation.cpp

LINK_COMPONENTS
Core

LINK_LIBS PUBLIC
MLIRIR
TritonGENIR
MLIRLLVMDialect
MLIRSupport
MLIRTargetLLVMIRExport
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
//===-TritonGENToLLVMIRTranslation.cpp - TritonGEN Dialect to LLVM IR -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a translation between the TritonGEN dialect and LLVM IR.
//
//===----------------------------------------------------------------------===//

#include "Target/LLVMIR/Dialect/TritonGEN/TritonGENToLLVMIRTranslation.h"

#include "Dialect/TritonGEN/IR/TritonGENDialect.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
#include "mlir/Target/LLVMIR/ModuleTranslation.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Metadata.h"

namespace {
using namespace mlir;
class TritonGENDialectLLVMIRTranslationInterface
: public LLVMTranslationDialectInterface {
public:
using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;

LogicalResult
amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
NamedAttribute attribute,
LLVM::ModuleTranslation &moduleTranslation) const final {
// Skip the attribute if it is not a TritonGEN attribute.
if (!attribute.getName().getValue().starts_with("triton_gen"))
return success();

llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
llvm::Function *llvmFunc =
moduleTranslation.lookupFunction(cast<LLVM::LLVMFuncOp>(op).getName());
if (isKernel(op))
amendKernel(llvmContext, llvmFunc, attribute);
whitneywhtsang marked this conversation as resolved.
Show resolved Hide resolved
return success();
}

private:
// Checks if the given operation is a kernel function.
bool isKernel(Operation *op) const {
auto fn = dyn_cast<LLVM::LLVMFuncOp>(op);
return fn && fn.getCConv() == LLVM::CConv::SPIR_KERNEL;
}

// The attribute is converted into metadata and added to the function.
void amendKernel(llvm::LLVMContext &llvmContext, llvm::Function *llvmFunc,
NamedAttribute attribute) const {
StringRef name = attribute.getName().getValue();
assert((name == triton::TritonGEN::TritonGENDialect::
getMaxWorkGroupSizeAttrName() ||
name == triton::TritonGEN::TritonGENDialect::
getReqdWorkGroupSizeAttrName() ||
name == triton::TritonGEN::TritonGENDialect::
getReqdSubGroupSizeAttrName()) &&
"Unexpected attribute");
SmallVector<llvm::Metadata *, 3> metadata;
llvm::Type *i64 = llvm::IntegerType::get(llvmContext, 64);
for (int64_t i :
extractFromIntegerArrayAttr<int64_t>(attribute.getValue())) {
llvm::Constant *constant = llvm::ConstantInt::get(i64, i);
metadata.push_back(llvm::ConstantAsMetadata::get(constant));
}
llvm::MDNode *node = llvm::MDNode::get(llvmContext, metadata);
llvmFunc->setMetadata(name.drop_front(11), node);
}
};
} // namespace

namespace mlir {
void registerTritonGENDialectTranslation(DialectRegistry &registry) {
registry.insert<triton::TritonGEN::TritonGENDialect>();
registry.addExtension(
+[](MLIRContext *ctx, triton::TritonGEN::TritonGENDialect *dialect) {
dialect->addInterfaces<TritonGENDialectLLVMIRTranslationInterface>();
});
}

void registerTritonGENDialectTranslation(MLIRContext &context) {
DialectRegistry registry;
registerTritonGENDialectTranslation(registry);
context.appendDialectRegistry(registry);
}
} // namespace mlir
Loading