From e3064fd8edc9436bb15ef3359527323e78c0c5b6 Mon Sep 17 00:00:00 2001 From: Whitney Tsang Date: Wed, 15 May 2024 23:50:53 +0000 Subject: [PATCH] Avoid post-processing generated LLVM IR Signed-off-by: Whitney Tsang --- bin/CMakeLists.txt | 1 + python/src/llvm.cc | 69 +-------------- test/Conversion/intel/tritongpu_to_gen.mlir | 2 +- third_party/intel/CMakeLists.txt | 1 + .../Dialect/TritonGEN/IR/TritonGENDialect.td | 6 +- .../TritonGEN/TritonGENToLLVMIRTranslation.h | 32 +++++++ third_party/intel/lib/CMakeLists.txt | 1 + third_party/intel/lib/Target/CMakeLists.txt | 1 + .../intel/lib/Target/LLVMIR/CMakeLists.txt | 1 + .../lib/Target/LLVMIR/Dialect/CMakeLists.txt | 1 + .../LLVMIR/Dialect/TritonGEN/CMakeLists.txt | 13 +++ .../TritonGENToLLVMIRTranslation.cpp | 83 +++++++++++++++++++ .../TritonIntelGPUToLLVM/PipelineManager.h | 17 ++-- third_party/intel/triton_xpu.cc | 2 + 14 files changed, 147 insertions(+), 83 deletions(-) create mode 100644 third_party/intel/include/Target/LLVMIR/Dialect/TritonGEN/TritonGENToLLVMIRTranslation.h create mode 100644 third_party/intel/lib/Target/CMakeLists.txt create mode 100644 third_party/intel/lib/Target/LLVMIR/CMakeLists.txt create mode 100644 third_party/intel/lib/Target/LLVMIR/Dialect/CMakeLists.txt create mode 100644 third_party/intel/lib/Target/LLVMIR/Dialect/TritonGEN/CMakeLists.txt create mode 100644 third_party/intel/lib/Target/LLVMIR/Dialect/TritonGEN/TritonGENToLLVMIRTranslation.cpp diff --git a/bin/CMakeLists.txt b/bin/CMakeLists.txt index 9b8badd94a..524ad21898 100644 --- a/bin/CMakeLists.txt +++ b/bin/CMakeLists.txt @@ -1,5 +1,6 @@ get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) +get_property(translation_libs GLOBAL PROPERTY MLIR_TRANSLATION_LIBS) get_property(triton_libs GLOBAL PROPERTY TRITON_LIBS) add_llvm_executable(triton-opt triton-opt.cpp PARTIAL_SOURCES_INTENDED) diff --git a/python/src/llvm.cc b/python/src/llvm.cc index f8b8aedc09..afe5eca589 100644 --- a/python/src/llvm.cc +++ b/python/src/llvm.cc @@ -135,70 +135,6 @@ static uint32_t findKernels(llvm::Module &M, return numKernels; } -/// Amend SPIR kernels in the given LLVM module by translating GEN passthrough -/// attributes into LLVM metadata. -static void amendLLVMIR(llvm::Module &llvmMod, llvm::LLVMContext &ctx) { - // Collect SPIR kernels. - std::set kernels; - uint32_t numKernels = findKernels(llvmMod, kernels); - assert(numKernels == 1 && "Expecting a single SPIR kernel"); - llvm::Function *kernel = *kernels.begin(); - - // Given a string \p str of the form "n1,n2,...", parse it as a - // vector of integers (n1,n2,...). - auto extractFromString = [](StringRef str) -> SmallVector { - auto parseAsInt = [](StringRef str, int64_t &intVal) { - bool failed = str.getAsInteger(10, intVal); - return !failed; - }; - - SmallVector result; - std::pair pair; - do { - pair = str.split(','); - str = pair.second; - int64_t intVal; - if (!parseAsInt(pair.first, intVal)) - break; - - result.push_back(intVal); - } while (true); - - return result; - }; - - // Attach metadata to \p func given its name \p attrName and value \p attrVal. - auto attachMetadata = [&](StringRef attrName, StringRef attrVal, - llvm::Function *func) { - SmallVector metadata; - llvm::Type *i64 = llvm::IntegerType::get(ctx, 64); - for (int64_t val : extractFromString(attrVal)) - metadata.push_back( - llvm::ConstantAsMetadata::get(llvm::ConstantInt::get(i64, val))); - - llvm::MDNode *node = llvm::MDNode::get(ctx, metadata); - func->setMetadata(attrName, node); - }; - - // Attach required metadata to the kernel. - using namespace mlir::triton; - SmallVector genAttrs{ - TritonGEN::TritonGENDialect::getMaxWorkGroupSizeAttrName(), - TritonGEN::TritonGENDialect::getReqdWorkGroupSizeAttrName(), - TritonGEN::TritonGENDialect::getReqdSubGroupSizeAttrName()}; - - for (llvm::StringLiteral genAttr : genAttrs) { - if (!kernel->hasFnAttribute(genAttr)) - continue; - - Attribute fnAttr = kernel->getFnAttribute(genAttr); - assert(fnAttr.isStringAttribute() && "Expecting a string attribute"); - attachMetadata(fnAttr.getKindAsString().split('.').second, - fnAttr.getValueAsString(), kernel); - kernel->removeFnAttr(genAttr); - } -} - void init_triton_llvm(py::module &&m) { py::class_(m, "context", py::module_local()) @@ -271,10 +207,7 @@ void init_triton_llvm(py::module &&m) { m.def( "to_module", [](mlir::ModuleOp &mod, llvm::LLVMContext &ctx) { - std::unique_ptr llvmMod = - mlir::translateModuleToLLVMIR(mod, ctx); - amendLLVMIR(*llvmMod, ctx); - return llvmMod; + return mlir::translateModuleToLLVMIR(mod, ctx); }, py::keep_alive<0, 2>()); diff --git a/test/Conversion/intel/tritongpu_to_gen.mlir b/test/Conversion/intel/tritongpu_to_gen.mlir index d732b76f9b..916b9e4c28 100644 --- a/test/Conversion/intel/tritongpu_to_gen.mlir +++ b/test/Conversion/intel/tritongpu_to_gen.mlir @@ -3,7 +3,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { // CHECK: llvm.func spir_kernelcc @test_empty_kernel(%arg0: i64, %arg1: !llvm.ptr<1>) // Here the 128 comes from the 4 in module attribute multiples 32 - // CHECK-SAME: attributes {passthrough = {{.*}}"gen.intel_reqd_sub_group_size", "32"{{.*}}"gen.max_work_group_size", "128,1,1"{{.*}}} { + // CHECK-SAME: attributes {triton_gen.intel_reqd_sub_group_size = [32 : i32], triton_gen.max_work_group_size = [128 : i32, 1 : i32, 1 : i32]} { tt.func @test_empty_kernel(%lb : index, %A : !tt.ptr) { // CHECK: llvm.return tt.return diff --git a/third_party/intel/CMakeLists.txt b/third_party/intel/CMakeLists.txt index 54fb5a374a..6a1b044a32 100644 --- a/third_party/intel/CMakeLists.txt +++ b/third_party/intel/CMakeLists.txt @@ -8,6 +8,7 @@ add_triton_plugin(TritonXPU LINK_LIBS TritonGENToLLVM + TritonGENToLLVMIRTranslation TritonIntelGPUToLLVM TritonIntelGPUTransforms ) diff --git a/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENDialect.td b/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENDialect.td index e3243812d0..232e849f6f 100644 --- a/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENDialect.td +++ b/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENDialect.td @@ -26,19 +26,19 @@ def TritonGEN_Dialect : Dialect { /// Get the name of the attribute used to annotate max work group size /// required for kernels. static constexpr ::llvm::StringLiteral getMaxWorkGroupSizeAttrName() { - return ::llvm::StringLiteral("gen.max_work_group_size"); + return ::llvm::StringLiteral("triton_gen.max_work_group_size"); } /// Get the name of the attribute used to annotate exact work group size /// required for kernels. static constexpr ::llvm::StringLiteral getReqdWorkGroupSizeAttrName() { - return ::llvm::StringLiteral("gen.reqd_work_group_size"); + return ::llvm::StringLiteral("triton_gen.reqd_work_group_size"); } /// Get the name for the attribute used to annotate the exact sub group /// size required for kernels. static constexpr ::llvm::StringLiteral getReqdSubGroupSizeAttrName() { - return ::llvm::StringLiteral("gen.intel_reqd_sub_group_size"); + return ::llvm::StringLiteral("triton_gen.intel_reqd_sub_group_size"); } }]; } diff --git a/third_party/intel/include/Target/LLVMIR/Dialect/TritonGEN/TritonGENToLLVMIRTranslation.h b/third_party/intel/include/Target/LLVMIR/Dialect/TritonGEN/TritonGENToLLVMIRTranslation.h new file mode 100644 index 0000000000..0d30a06de5 --- /dev/null +++ b/third_party/intel/include/Target/LLVMIR/Dialect/TritonGEN/TritonGENToLLVMIRTranslation.h @@ -0,0 +1,32 @@ +//===-TritonGENToLLVMIRTranslation.h-TritonGEN Dialect to LLVM IR - C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This provides registration calls for TritonGEN dialect to LLVM IR +// translation. +// +//===----------------------------------------------------------------------===// + +#ifndef TRITON_TARGET_LLVMIR_DIALECT_TRITONGEN_TRITONGENTOLLVMIRTRANSLATION_H +#define TRITON_TARGET_LLVMIR_DIALECT_TRITONGEN_TRITONGENTOLLVMIRTRANSLATION_H + +namespace mlir { + +class DialectRegistry; +class MLIRContext; + +/// Register the TritonGEN dialect and the translation from it to the LLVM IR in +/// the given registry; +void registerTritonGENDialectTranslation(DialectRegistry ®istry); + +/// Register the TritonGEN dialect and the translation from it in the registry +/// associated with the given context. +void registerTritonGENDialectTranslation(MLIRContext &context); + +} // namespace mlir + +#endif // TRITON_TARGET_LLVMIR_DIALECT_TRITONGEN_TRITONGENTOLLVMIRTRANSLATION_H diff --git a/third_party/intel/lib/CMakeLists.txt b/third_party/intel/lib/CMakeLists.txt index 9babba170e..7be0ec45ca 100644 --- a/third_party/intel/lib/CMakeLists.txt +++ b/third_party/intel/lib/CMakeLists.txt @@ -1,5 +1,6 @@ add_subdirectory(Dialect) add_subdirectory(GPUToTritonGEN) +add_subdirectory(Target) add_subdirectory(TritonGENToLLVM) add_subdirectory(TritonIntelGPUToLLVM) add_subdirectory(TritonIntelGPUTransforms) diff --git a/third_party/intel/lib/Target/CMakeLists.txt b/third_party/intel/lib/Target/CMakeLists.txt new file mode 100644 index 0000000000..39d31dc9b5 --- /dev/null +++ b/third_party/intel/lib/Target/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(LLVMIR) diff --git a/third_party/intel/lib/Target/LLVMIR/CMakeLists.txt b/third_party/intel/lib/Target/LLVMIR/CMakeLists.txt new file mode 100644 index 0000000000..0ca0f41c5a --- /dev/null +++ b/third_party/intel/lib/Target/LLVMIR/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(Dialect) diff --git a/third_party/intel/lib/Target/LLVMIR/Dialect/CMakeLists.txt b/third_party/intel/lib/Target/LLVMIR/Dialect/CMakeLists.txt new file mode 100644 index 0000000000..92e4da0a8e --- /dev/null +++ b/third_party/intel/lib/Target/LLVMIR/Dialect/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(TritonGEN) diff --git a/third_party/intel/lib/Target/LLVMIR/Dialect/TritonGEN/CMakeLists.txt b/third_party/intel/lib/Target/LLVMIR/Dialect/TritonGEN/CMakeLists.txt new file mode 100644 index 0000000000..18fe611f1b --- /dev/null +++ b/third_party/intel/lib/Target/LLVMIR/Dialect/TritonGEN/CMakeLists.txt @@ -0,0 +1,13 @@ +add_mlir_translation_library(TritonGENToLLVMIRTranslation + TritonGENToLLVMIRTranslation.cpp + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRIR + TritonGENIR + MLIRLLVMDialect + MLIRSupport + MLIRTargetLLVMIRExport +) diff --git a/third_party/intel/lib/Target/LLVMIR/Dialect/TritonGEN/TritonGENToLLVMIRTranslation.cpp b/third_party/intel/lib/Target/LLVMIR/Dialect/TritonGEN/TritonGENToLLVMIRTranslation.cpp new file mode 100644 index 0000000000..3b5a94cabc --- /dev/null +++ b/third_party/intel/lib/Target/LLVMIR/Dialect/TritonGEN/TritonGENToLLVMIRTranslation.cpp @@ -0,0 +1,83 @@ +//===-TritonGENToLLVMIRTranslation.cpp - TritonGEN Dialect to LLVM IR -----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a translation between the TritonGEN dialect and LLVM IR. +// +//===----------------------------------------------------------------------===// + +#include "Target/LLVMIR/Dialect/TritonGEN/TritonGENToLLVMIRTranslation.h" + +#include "Dialect/TritonGEN/IR/TritonGENDialect.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h" +#include "mlir/Target/LLVMIR/ModuleTranslation.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Metadata.h" + +namespace { +using namespace mlir; +class TritonGENDialectLLVMIRTranslationInterface + : public LLVMTranslationDialectInterface { +public: + using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface; + + LogicalResult + amendOperation(Operation *op, ArrayRef instructions, + NamedAttribute attribute, + LLVM::ModuleTranslation &moduleTranslation) const final { + llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext(); + llvm::Function *llvmFunc = + moduleTranslation.lookupFunction(cast(op).getName()); + if (isKernel(op)) + amendKernel(llvmContext, llvmFunc, attribute); + return success(); + } + +private: + // Checks if the given operation is a kernel function. + bool isKernel(Operation *op) const { + auto fn = dyn_cast(op); + return fn && fn.getCConv() == LLVM::CConv::SPIR_KERNEL; + } + + // The attribute is converted into metadata and added to the function. + void amendKernel(llvm::LLVMContext &llvmContext, llvm::Function *llvmFunc, + NamedAttribute attribute) const { + SmallVector metadata; + llvm::Type *i64 = llvm::IntegerType::get(llvmContext, 64); + for (int64_t i : + extractFromIntegerArrayAttr(attribute.getValue())) { + llvm::Constant *constant = llvm::ConstantInt::get(i64, i); + metadata.push_back(llvm::ConstantAsMetadata::get(constant)); + } + llvm::MDNode *node = llvm::MDNode::get(llvmContext, metadata); + llvmFunc->setMetadata(attribute.getName().getValue().drop_front(11), node); + } +}; +} // namespace + +namespace mlir { +void registerTritonGENDialectTranslation(DialectRegistry ®istry) { + registry.insert(); + registry.addExtension( + +[](MLIRContext *ctx, triton::TritonGEN::TritonGENDialect *dialect) { + dialect->addInterfaces(); + }); +} + +void registerTritonGENDialectTranslation(MLIRContext &context) { + DialectRegistry registry; + registerTritonGENDialectTranslation(registry); + context.appendDialectRegistry(registry); +} +} // namespace mlir diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/PipelineManager.h b/third_party/intel/lib/TritonIntelGPUToLLVM/PipelineManager.h index b9b88f621f..cf7b3b9d99 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/PipelineManager.h +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/PipelineManager.h @@ -110,17 +110,12 @@ struct FuncOpConversion : public ConvertOpToLLVMPattern { if (LLVM::isKernel(funcOp)) newFuncOp.setCConv(LLVM::CConv::SPIR_KERNEL); - auto maxWorkGroupSizeAttr = rewriter.getArrayAttr( - {rewriter.getStringAttr( - TritonGEN::TritonGENDialect::getMaxWorkGroupSizeAttrName()), - rewriter.getStringAttr(std::to_string(threadsPerWarp * numWarps) + - ",1,1")}); - auto reqSubGroupSizeAttr = rewriter.getArrayAttr( - {rewriter.getStringAttr( - TritonGEN::TritonGENDialect::getReqdSubGroupSizeAttrName()), - rewriter.getStringAttr(std::to_string(threadsPerWarp))}); - newFuncOp.setPassthroughAttr( - ArrayAttr::get(ctx, {reqSubGroupSizeAttr, maxWorkGroupSizeAttr})); + NamedAttrList attrs; + attrs.append(TritonGEN::TritonGENDialect::getMaxWorkGroupSizeAttrName(), + rewriter.getI32ArrayAttr({threadsPerWarp * numWarps, 1, 1})); + attrs.append(TritonGEN::TritonGENDialect::getReqdSubGroupSizeAttrName(), + rewriter.getI32ArrayAttr({threadsPerWarp})); + newFuncOp->setDialectAttrs(attrs); if (!LLVM::isKernel(funcOp)) { newFuncOp.setPassthroughAttr( diff --git a/third_party/intel/triton_xpu.cc b/third_party/intel/triton_xpu.cc index d8e3153288..3312f1270a 100644 --- a/third_party/intel/triton_xpu.cc +++ b/third_party/intel/triton_xpu.cc @@ -4,6 +4,7 @@ #include "intel/include/Dialect/TritonGEN/IR/TritonGENDialect.h" #include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h" #include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h" +#include "intel/include/Target/LLVMIR/Dialect/TritonGEN/TritonGENToLLVMIRTranslation.h" #include "intel/include/TritonIntelGPUToLLVM/Passes.h" #include "triton/Conversion/TritonToTritonGPU/Passes.h" @@ -74,6 +75,7 @@ void init_triton_intel(py::module &&m) { mlir::DialectRegistry registry; registry .insert(); + mlir::registerTritonGENDialectTranslation(registry); context.appendDialectRegistry(registry); context.loadAllAvailableDialects(); });