Commit 323baad

Merge commit '14800bfa0cfa0930fcaed22fc1e91a8e13430580'

whitneywhtsang committed May 22, 2024
2 parents f634d91 + 14800bf
Showing 12 changed files with 177 additions and 69 deletions.
2 changes: 2 additions & 0 deletions README.md
@@ -173,6 +173,8 @@ For detailed instructions on how to debug Triton's frontend, please refer to this
Loop strength reduction is known to cause up to 10% performance changes for
certain kernels with register pressure.
- `TRITON_ALWAYS_COMPILE=1` forces to compile kernels regardless of cache hit.
- `MLIR_ENABLE_TIMING` dumps the timing information for each MLIR pass.
- `LLVM_ENABLE_TIMING` dumps the timing information for each LLVM pass.

# Changelog

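Both timing variables are read when a kernel is first compiled and, per the GetEnv.hpp change below, both invalidate the compilation cache, so toggling them forces a fresh compile instead of reusing a cached binary. A minimal usage sketch, assuming a CUDA device and an illustrative copy kernel:

```python
import os

# Set before the kernel is first compiled; both variables are part of the
# compilation cache key, so flipping them triggers a rebuild.
os.environ["MLIR_ENABLE_TIMING"] = "1"  # per-pass timing for MLIR passes
os.environ["LLVM_ENABLE_TIMING"] = "1"  # per-pass timing for LLVM passes

import torch
import triton
import triton.language as tl


@triton.jit
def copy_kernel(src, dst, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    tl.store(dst + offs, tl.load(src + offs, mask=mask), mask=mask)


x = torch.randn(1024, device="cuda")
y = torch.empty_like(x)
# The timing reports are dumped during this first (cold) compilation.
copy_kernel[(1, )](x, y, x.numel(), BLOCK=1024)
```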
2 changes: 2 additions & 0 deletions include/triton/Tools/Sys/GetEnv.hpp
@@ -18,8 +18,10 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
"DISABLE_MMA_V3",
"DISABLE_PTXAS_OPT",
"LLVM_IR_ENABLE_DUMP",
"LLVM_ENABLE_TIMING",
"MLIR_ENABLE_DIAGNOSTICS",
"MLIR_ENABLE_DUMP",
"MLIR_ENABLE_TIMING",
"TRITON_DISABLE_LINE_INFO",
"TRITON_DISABLE_RESHAPE_ENCODING_INFERENCE",
"TRITON_ENABLE_LLVM_DEBUG",
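Membership in CACHE_INVALIDATING_ENV_VARS means the variable's value is folded into the kernel cache key, so changing it forces recompilation rather than a stale cache hit. A toy illustration of that mechanism (not Triton's actual hashing code):

```python
import hashlib
import os

# Illustrative subset; the authoritative list lives in GetEnv.hpp.
CACHE_INVALIDATING_ENV_VARS = [
    "LLVM_ENABLE_TIMING",
    "MLIR_ENABLE_TIMING",
]


def cache_key(kernel_src: str) -> str:
    # Fold the value of every cache-invalidating variable into the hash,
    # so a different value yields a different key and hence a rebuild.
    h = hashlib.sha256(kernel_src.encode())
    for var in CACHE_INVALIDATING_ENV_VARS:
        h.update(f"{var}={os.environ.get(var, '')}".encode())
    return h.hexdigest()
```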
2 changes: 2 additions & 0 deletions lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp
@@ -87,13 +87,15 @@ struct FuncOpConversion : public ConvertOpToLLVMPattern<triton::FuncOp> {
// Set an attribute to indicate this function is a kernel entry.
newFuncOp->setAttr("nvvm.kernel",
rewriter.getIntegerAttr(type::u1Ty(ctx), 1));
newFuncOp.setLinkage(LLVM::Linkage::External);
} else {
// The noinline attribute will be used by the LLVM codegen to prevent
// inlining.
// https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/LLVMIR/IR/LLVMInlining.cpp#L267
newFuncOp.setPassthroughAttr(
ArrayAttr::get(ctx, rewriter.getStringAttr("noinline")));
rewriter.eraseOp(amendedFuncOp);
newFuncOp.setLinkage(LLVM::Linkage::Internal);
}
// Set an attribute for maxntidx; it can be used later in LLVM codegen
// for `nvvm.annotation` metadata.
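With this change the kernel entry keeps external linkage while every other function becomes internal and noinline, which later lets backends identify kernels by linkage alone. From the Python side the split corresponds to how functions are declared; a sketch with illustrative bodies:

```python
import triton
import triton.language as tl


@triton.jit(noinline=True)
def helper(X):
    # Lowered as a non-kernel function: internal linkage plus the
    # "noinline" passthrough attribute.
    tl.store(X, 0)


@triton.jit
def kernel(X):
    # Kernel entry: tagged with the nvvm.kernel attribute and kept external.
    helper(X)
```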
6 changes: 6 additions & 0 deletions python/src/ir.cc
@@ -1599,6 +1599,7 @@ void init_triton_ir(py::module &&m) {
.def("run", [](PassManager &self, ModuleOp &mod) {
// TODO: maybe dump module to file and print error for better
// diagnostics

auto reproducerPath =
triton::tools::getStrEnv("TRITON_REPRODUCER_PATH");
if (!reproducerPath.empty()) {
@@ -1632,6 +1633,11 @@
::llvm::setCurrentDebugTypes(debugTypes.data(), debugTypes.size());
}

bool haveTiming = ::triton::tools::getBoolEnv("MLIR_ENABLE_TIMING");
if (haveTiming) {
self.enableTiming();
}

if (failed(self.run(mod.getOperation())))
throw std::runtime_error("PassManager::run failed");
});
73 changes: 69 additions & 4 deletions python/src/llvm.cc
@@ -82,7 +82,24 @@ std::string translateLLVMIRToASM(llvm::Module &module,
llvm::legacy::PassManager pm;
pm.add(llvm::createAlwaysInlinerLegacyPass());
pm.add(llvm::createVerifierPass());

const bool enabledTiming =
mlir::triton::tools::getBoolEnv("LLVM_ENABLE_TIMING");
if (enabledTiming) {
llvm::TimePassesIsEnabled = true;
llvm::TimePassesPerRun = true;
}

pm.run(module);

SmallString<0> timePassesStr;
raw_svector_ostream reportStream(timePassesStr);

if (enabledTiming) {
reportAndResetTimings(&reportStream);
llvm::dbgs() << reportStream.str();
timePassesStr.clear();
}
// module->print(llvm::outs(), nullptr);

// create machine
@@ -117,6 +134,12 @@
: llvm::CodeGenFileType::AssemblyFile;
machine->addPassesToEmitFile(pass, pstream, nullptr, fileType);
pass.run(module);

if (enabledTiming) {
reportAndResetTimings(&reportStream);
llvm::dbgs() << reportStream.str();
timePassesStr.clear();
}
}
return result;
}
@@ -175,6 +198,9 @@ void init_triton_llvm(py::module &&m) {
.def(
"get_functions",
[](llvm::Module *mod) -> llvm::Module::FunctionListType & {
// Note: Backends assume that we are compiling exactly one kernel
// (i.e. one function that's called by the CPU) and that it's
// the first function in this list.
return mod->getFunctionList();
},
ret::reference_internal)
@@ -185,14 +211,33 @@
});

  py::class_<llvm::Function>(m, "function", py::module_local())
      .def_property_readonly(
          "name", [](llvm::Function *fn) { return fn->getName().str(); })
      .def("set_calling_conv", &llvm::Function::setCallingConv)
      .def("add_fn_attr", [](llvm::Function *fn, std::string &name,
                             std::string &val) { fn->addFnAttr(name, val); })
-     .def("has_public_visibility",
-          [](llvm::Function *fn) {
-            return fn->getVisibility() == llvm::GlobalValue::DefaultVisibility;
+
+     // Sets the nvvm.maxnreg property on the given function.
+     .def("set_nvvm_maxnreg",
+          [](llvm::Function *fn, int maxnreg) {
+            auto op = MDNode::get(
+                fn->getContext(),
+                {
+                    ValueAsMetadata::get(fn),
+                    MDString::get(fn->getContext(), "maxnreg"),
+                    ConstantAsMetadata::get(ConstantInt::get(
+                        Type::getInt32Ty(fn->getContext()), maxnreg)),
+                });
+            fn->getParent()
+                ->getOrInsertNamedMetadata("nvvm.annotations")
+                ->addOperand(op);
           })
-     .def("is_declaration", &llvm::Function::isDeclaration);
+     // External functions that are definitions (i.e. not declarations) are
+     // kernel functions.
+     .def("is_declaration", &llvm::Function::isDeclaration)
+     .def("is_external_linkage", [](llvm::Function *fn) {
+       return fn->getLinkage() == llvm::GlobalValue::ExternalLinkage;
+     });
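A hypothetical sketch of how a backend could combine these bindings; `mod` stands for the LLVM module handle produced by Triton's translation step, and the register budget is arbitrary:

```python
def set_kernel_maxnreg(mod, maxnreg: int) -> None:
    # Definitions (not declarations) with external linkage are the kernels;
    # backends additionally assume the kernel comes first in this list.
    kernels = [fn for fn in mod.get_functions()
               if not fn.is_declaration() and fn.is_external_linkage()]
    # Adds {fn, "maxnreg", i32 <maxnreg>} to the module's nvvm.annotations.
    kernels[0].set_nvvm_maxnreg(maxnreg)
```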

// optimization levels
py::class_<llvm::OptimizationLevel>(m, "optimization_level",
@@ -385,11 +430,31 @@
}
libMod->setTargetTriple(dstMod->getTargetTriple());
libMod->setDataLayout(dstMod->getDataLayout());

std::unordered_set<std::string> externalFns;
for (llvm::Function &fn : libMod->functions()) {
if (!fn.isDeclaration())
externalFns.insert(fn.getName().str());
}

if (linker.linkInModule(std::move(libMod),
llvm::Linker::Flags::LinkOnlyNeeded)) {
std::string message = "Failed to link library at " + path;
throw std::invalid_argument(message);
}

// Mark linked-in functions as internal because backends use external
// linkage as a signifier of kernel functions.
for (llvm::Function &fn : dstMod->functions()) {
if (externalFns.count(fn.getName().str())) {
// FIXME: Temporary workaround to avoid marking SPIR_FUNC functions
// with InternalLinkage, which causes test_subprocess.py::test_assert
// to fail.
if (fn.getCallingConv() == CallingConv::SPIR_FUNC)
continue;
fn.setLinkage(llvm::GlobalValue::InternalLinkage);
}
}
}
});
}
65 changes: 45 additions & 20 deletions python/test/unit/language/test_core.py
@@ -1388,8 +1388,6 @@ def test_atomic_rmw(op, dtype_x_str, mode, sem, device):
    if is_interpreter():
        if dtype_x_str == 'float16':
            pytest.xfail("Only test atomic float16 ops on GPU")
-    else:
-        check_cuda_or_hip(device)

    n_programs = 5

@@ -4300,10 +4298,8 @@ def kernel(X, Y, BLOCK: tl.constexpr):

@pytest.mark.parametrize("num_ctas", num_ctas_list)
def test_inline_asm(num_ctas, device):
-    check_cuda_or_hip(device)
-
-    if is_hip():
-        pytest.skip("test_inline_asm is not supported in HIP")
+    if not is_cuda():
+        pytest.skip("test_inline_asm is only supported in CUDA")

    @triton.jit
    def kernel(X, Y, Z, n: tl.constexpr, BLOCK: tl.constexpr):
@@ -4330,10 +4326,8 @@

@pytest.mark.parametrize("num_ctas", num_ctas_list)
def test_inline_asm_packed(num_ctas, device):
-    check_cuda_or_hip(device)
-
-    if is_hip():
-        pytest.skip("test_inline_asm is not supported in HIP")
+    if not is_cuda():
+        pytest.skip("test_inline_asm is only supported in CUDA")

    @triton.jit
    def kernel(X, Y, BLOCK: tl.constexpr):
@@ -4359,10 +4353,8 @@

@pytest.mark.parametrize('num_ctas', num_ctas_list)
def test_inline_asm_with_pointers(num_ctas, device):
-    check_cuda_or_hip(device)
-
-    if is_hip():
-        pytest.skip('test_inline_asm is not supported in HIP')
+    if not is_cuda():
+        pytest.skip('test_inline_asm is only supported in CUDA')

    @triton.jit
    def kernel(X, Y, BLOCK: tl.constexpr):
@@ -4386,9 +4378,8 @@


def test_inline_asm_multiple_outputs(device):
-    check_cuda_or_hip(device)
-    if is_hip():
-        pytest.skip('This test uses PTX inline assembly, so is not compatible with AMD')
+    if not is_cuda():
+        pytest.skip('test_inline_asm is only supported in CUDA')

    @triton.jit
    def kernel(A, B, C, D, BLOCK: tl.constexpr):
@@ -4433,9 +4424,8 @@


def test_inline_asm_packed_multiple_outputs(device):
-    check_cuda_or_hip(device)
-    if is_hip():
-        pytest.skip('This test uses PTX inline assembly, so is not compatible with AMD')
+    if not is_cuda():
+        pytest.skip('test_inline_asm is only supported in CUDA')

    @triton.jit
    def kernel(A, B, C, D, BLOCK: tl.constexpr):
@@ -5404,3 +5394,38 @@ def test_tl_range(device):
    ptx = pgm.asm['ptx']
    # check that the loop got pipelined with the right number of stages.
    assert 'cp.async.wait_group 0x6' in ptx


@triton.jit(noinline=True)
def maxnreg_noinline1(X):
    tl.store(X, 0)


@triton.jit(noinline=True)
def maxnreg_noinline2(X):
    tl.store(X, 0)


def test_maxnreg(device):
    assert not is_interpreter(), "this test won't work with the interpreter"
    if not is_cuda():
        pytest.xfail('maxnreg only works on CUDA')

    # triton kernel
    @triton.jit
    def kernel(X):
        maxnreg_noinline1(X)
        tl.store(X, 0)
        maxnreg_noinline2(X)

    X = torch.empty(1, dtype=torch.int32, device=device)
    k = kernel[(1, )](X, maxnreg=42)

    # Ensure that .maxnreg is set on the kernel function (marked with .entry)
    # and not on either of the noinline functions (marked with .func).
    try:
        assert re.search(r'\.visible \.entry [^{;]*\.maxnreg 42', k.asm["ptx"])
        assert not re.search(r'\.visible \.func [^{;]*\.maxnreg', k.asm["ptx"])
    except AssertionError:
        print("Failing ptx:\n", k.asm["ptx"])
        raise