diff --git a/third_party/intel/lib/Target/LLVMIR/Dialect/TritonGEN/TritonGENToLLVMIRTranslation.cpp b/third_party/intel/lib/Target/LLVMIR/Dialect/TritonGEN/TritonGENToLLVMIRTranslation.cpp index 3b5a94cabc..1c710e6527 100644 --- a/third_party/intel/lib/Target/LLVMIR/Dialect/TritonGEN/TritonGENToLLVMIRTranslation.cpp +++ b/third_party/intel/lib/Target/LLVMIR/Dialect/TritonGEN/TritonGENToLLVMIRTranslation.cpp @@ -35,6 +35,10 @@ class TritonGENDialectLLVMIRTranslationInterface amendOperation(Operation *op, ArrayRef instructions, NamedAttribute attribute, LLVM::ModuleTranslation &moduleTranslation) const final { + // Skip the attribute if it is not a TritonGEN attribute. + if (!attribute.getName().getValue().starts_with("triton_gen")) + return success(); + llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext(); llvm::Function *llvmFunc = moduleTranslation.lookupFunction(cast(op).getName()); @@ -53,6 +57,14 @@ class TritonGENDialectLLVMIRTranslationInterface // The attribute is converted into metadata and added to the function. void amendKernel(llvm::LLVMContext &llvmContext, llvm::Function *llvmFunc, NamedAttribute attribute) const { + StringRef name = attribute.getName().getValue(); + assert((name == triton::TritonGEN::TritonGENDialect:: + getMaxWorkGroupSizeAttrName() || + name == triton::TritonGEN::TritonGENDialect:: + getReqdWorkGroupSizeAttrName() || + name == triton::TritonGEN::TritonGENDialect:: + getReqdSubGroupSizeAttrName()) && + "Unexpected attribute"); SmallVector metadata; llvm::Type *i64 = llvm::IntegerType::get(llvmContext, 64); for (int64_t i : @@ -61,7 +73,7 @@ class TritonGENDialectLLVMIRTranslationInterface metadata.push_back(llvm::ConstantAsMetadata::get(constant)); } llvm::MDNode *node = llvm::MDNode::get(llvmContext, metadata); - llvmFunc->setMetadata(attribute.getName().getValue().drop_front(11), node); + llvmFunc->setMetadata(name.drop_front(11), node); } }; } // namespace