Merge OpenAI Triton commit 14800bf #1174

Merged: 5 commits, May 22, 2024

2 changes: 2 additions & 0 deletions README.md
@@ -173,6 +173,8 @@ For detailed instructions on how to debug Triton's frontend, please refer to this
Loop strength reduction is known to cause up to 10% performance changes for
certain kernels with register pressure.
- `TRITON_ALWAYS_COMPILE=1` forces kernels to be compiled regardless of cache hits.
- `MLIR_ENABLE_TIMING` dumps the timing information for each MLIR pass.
- `LLVM_ENABLE_TIMING` dumps the timing information for each LLVM pass.

# Changelog

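The two new variables are read as boolean environment flags at compile time, so setting them before the first compilation is enough. A minimal usage sketch follows, assuming a CUDA device; the copy kernel and tensor names are illustrative and not part of this change. `TRITON_ALWAYS_COMPILE=1`, documented in the same README section, keeps a cache hit from hiding the reports.

```python
# Sketch: enable per-pass timing reports for a single Triton compile.
# MLIR_ENABLE_TIMING / LLVM_ENABLE_TIMING are the env vars added in this PR;
# the kernel below is a hypothetical example, not taken from the diff.
import os

os.environ["MLIR_ENABLE_TIMING"] = "1"     # per-MLIR-pass timing
os.environ["LLVM_ENABLE_TIMING"] = "1"     # per-LLVM-pass timing
os.environ["TRITON_ALWAYS_COMPILE"] = "1"  # bypass the cache so the compile actually runs

import torch
import triton
import triton.language as tl


@triton.jit
def copy_kernel(src_ptr, dst_ptr, N: tl.constexpr):
    # Trivial copy kernel whose compilation triggers the timing reports.
    offs = tl.arange(0, N)
    tl.store(dst_ptr + offs, tl.load(src_ptr + offs))


src = torch.arange(64, dtype=torch.float32, device="cuda")
dst = torch.empty_like(src)
copy_kernel[(1, )](src, dst, N=64)  # timing reports are printed during compilation
```

The MLIR report comes from `PassManager::enableTiming()` and the LLVM report from `llvm::TimePassesIsEnabled`, matching the `ir.cc` and `llvm.cc` hunks below.
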
2 changes: 2 additions & 0 deletions include/triton/Tools/Sys/GetEnv.hpp
@@ -18,8 +18,10 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
"DISABLE_MMA_V3",
"DISABLE_PTXAS_OPT",
"LLVM_IR_ENABLE_DUMP",
"LLVM_ENABLE_TIMING",
"MLIR_ENABLE_DIAGNOSTICS",
"MLIR_ENABLE_DUMP",
"MLIR_ENABLE_TIMING",
"TRITON_DISABLE_LINE_INFO",
"TRITON_DISABLE_RESHAPE_ENCODING_INFERENCE",
"TRITON_ENABLE_LLVM_DEBUG",
2 changes: 2 additions & 0 deletions lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp
@@ -87,13 +87,15 @@ struct FuncOpConversion : public ConvertOpToLLVMPattern<triton::FuncOp> {
// Set an attribute to indicate this function is a kernel entry.
newFuncOp->setAttr("nvvm.kernel",
rewriter.getIntegerAttr(type::u1Ty(ctx), 1));
newFuncOp.setLinkage(LLVM::Linkage::External);
} else {
// The noinline attribute will be used by the LLVM codegen to prevent
// inlining.
// https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/LLVMIR/IR/LLVMInlining.cpp#L267
newFuncOp.setPassthroughAttr(
ArrayAttr::get(ctx, rewriter.getStringAttr("noinline")));
rewriter.eraseOp(amendedFuncOp);
newFuncOp.setLinkage(LLVM::Linkage::Internal);
}
// Set an attribute for maxntidx; it could be used in later LLVM codegen
// for `nvvm.annotation` metadata.
6 changes: 6 additions & 0 deletions python/src/ir.cc
@@ -1599,6 +1599,7 @@ void init_triton_ir(py::module &&m) {
.def("run", [](PassManager &self, ModuleOp &mod) {
// TODO: maybe dump module to file and print error for better
// diagnostics

auto reproducerPath =
triton::tools::getStrEnv("TRITON_REPRODUCER_PATH");
if (!reproducerPath.empty()) {
@@ -1632,6 +1633,11 @@
::llvm::setCurrentDebugTypes(debugTypes.data(), debugTypes.size());
}

bool haveTiming = ::triton::tools::getBoolEnv("MLIR_ENABLE_TIMING");
if (haveTiming) {
self.enableTiming();
}

if (failed(self.run(mod.getOperation())))
throw std::runtime_error("PassManager::run failed");
});
73 changes: 69 additions & 4 deletions python/src/llvm.cc
@@ -82,7 +82,24 @@ std::string translateLLVMIRToASM(llvm::Module &module,
llvm::legacy::PassManager pm;
pm.add(llvm::createAlwaysInlinerLegacyPass());
pm.add(llvm::createVerifierPass());

const bool enabledTiming =
mlir::triton::tools::getBoolEnv("LLVM_ENABLE_TIMING");
if (enabledTiming) {
llvm::TimePassesIsEnabled = true;
llvm::TimePassesPerRun = true;
}

pm.run(module);

SmallString<0> timePassesStr;
raw_svector_ostream reportStream(timePassesStr);

if (enabledTiming) {
reportAndResetTimings(&reportStream);
llvm::dbgs() << reportStream.str();
timePassesStr.clear();
}
// module->print(llvm::outs(), nullptr);

// create machine
@@ -117,6 +134,12 @@ std::string translateLLVMIRToASM(llvm::Module &module,
: llvm::CodeGenFileType::AssemblyFile;
machine->addPassesToEmitFile(pass, pstream, nullptr, fileType);
pass.run(module);

if (enabledTiming) {
reportAndResetTimings(&reportStream);
llvm::dbgs() << reportStream.str();
timePassesStr.clear();
}
}
return result;
}
@@ -175,6 +198,9 @@ void init_triton_llvm(py::module &&m) {
.def(
"get_functions",
[](llvm::Module *mod) -> llvm::Module::FunctionListType & {
// Note: Backends assume that we are compiling exactly one kernel
// (i.e. one function that's called by the CPU) and that it's
// the first function in this list.
return mod->getFunctionList();
},
ret::reference_internal)
@@ -185,14 +211,33 @@
});

py::class_<llvm::Function>(m, "function", py::module_local())
.def_property_readonly(
"name", [](llvm::Function *fn) { return fn->getName().str(); })
.def("set_calling_conv", &llvm::Function::setCallingConv)
.def("add_fn_attr", [](llvm::Function *fn, std::string &name,
std::string &val) { fn->addFnAttr(name, val); })
.def("has_public_visibility",
[](llvm::Function *fn) {
return fn->getVisibility() == llvm::GlobalValue::DefaultVisibility;

// Sets the nvvm.maxnreg property on the given function.
.def("set_nvvm_maxnreg",
[](llvm::Function *fn, int maxnreg) {
auto op = MDNode::get(
fn->getContext(),
{
ValueAsMetadata::get(fn),
MDString::get(fn->getContext(), "maxnreg"),
ConstantAsMetadata::get(ConstantInt::get(
Type::getInt32Ty(fn->getContext()), maxnreg)),
});
fn->getParent()
->getOrInsertNamedMetadata("nvvm.annotations")
->addOperand(op);
})
.def("is_declaration", &llvm::Function::isDeclaration);
// External functions that are definitions (i.e. not declarations) are
// kernel functions.
.def("is_declaration", &llvm::Function::isDeclaration)
.def("is_external_linkage", [](llvm::Function *fn) {
return fn->getLinkage() == llvm::GlobalValue::ExternalLinkage;
});

// optimization levels
py::class_<llvm::OptimizationLevel>(m, "optimization_level",
@@ -385,11 +430,31 @@
}
libMod->setTargetTriple(dstMod->getTargetTriple());
libMod->setDataLayout(dstMod->getDataLayout());

std::unordered_set<std::string> externalFns;
for (llvm::Function &fn : libMod->functions()) {
if (!fn.isDeclaration())
externalFns.insert(fn.getName().str());
}

if (linker.linkInModule(std::move(libMod),
llvm::Linker::Flags::LinkOnlyNeeded)) {
std::string message = "Failed to link library at " + path;
throw std::invalid_argument(message);
}

// Mark linked-in functions as internal because backends use external
// linkage as a signifier of kernel functions.
for (llvm::Function &fn : dstMod->functions()) {
if (externalFns.count(fn.getName().str())) {
// FIXME: Temporary workaround to avoid marking SPIR_FUNC functions
// with InternalLinkage, which causes test_subprocess.py::test_assert
// to fail.
if (fn.getCallingConv() == CallingConv::SPIR_FUNC)
continue;
fn.setLinkage(llvm::GlobalValue::InternalLinkage);
}
}
}
});
}
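
The bindings above encode a convention that several comments in this hunk spell out: after linking, the kernel is the external-linkage definition in the module, and per-kernel register limits are attached as `nvvm.annotations` metadata. Below is a sketch of how backend-side helpers might use the new `is_external_linkage`, `is_declaration`, and `set_nvvm_maxnreg` bindings; the helper names are hypothetical, and `llvm_mod` stands for an `llvm.module` object produced earlier in the compile pipeline (obtaining one is outside this diff).

```python
# Sketch of the kernel-identification convention described in the comments above.
# `llvm_mod` is assumed to be a module object exposed by python/src/llvm.cc.

def find_kernel(llvm_mod):
    """Return the function backends treat as the kernel entry:
    the external-linkage definition left after linked-in library
    functions were demoted to internal linkage."""
    for fn in llvm_mod.get_functions():
        if not fn.is_declaration() and fn.is_external_linkage():
            return fn
    raise RuntimeError("no external-linkage definition (kernel) found")


def apply_maxnreg(llvm_mod, maxnreg):
    """Attach an nvvm 'maxnreg' annotation to the kernel only."""
    kernel = find_kernel(llvm_mod)
    kernel.set_nvvm_maxnreg(maxnreg)
    return kernel.name
```

At the user level the same register limit is requested through the `maxnreg` kernel launch option, which is what the new `test_maxnreg` test below exercises.
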
65 changes: 45 additions & 20 deletions python/test/unit/language/test_core.py
@@ -1388,8 +1388,6 @@ def test_atomic_rmw(op, dtype_x_str, mode, sem, device):
if is_interpreter():
if dtype_x_str == 'float16':
pytest.xfail("Only test atomic float16 ops on GPU")
else:
check_cuda_or_hip(device)

n_programs = 5

@@ -4300,10 +4298,8 @@ def kernel(X, Y, BLOCK: tl.constexpr):

@pytest.mark.parametrize("num_ctas", num_ctas_list)
def test_inline_asm(num_ctas, device):
check_cuda_or_hip(device)

if is_hip():
pytest.skip("test_inline_asm is not supported in HIP")
if not is_cuda():
pytest.skip("test_inline_asm is only supported in CUDA")

@triton.jit
def kernel(X, Y, Z, n: tl.constexpr, BLOCK: tl.constexpr):
@@ -4330,10 +4326,8 @@ def kernel(X, Y, Z, n: tl.constexpr, BLOCK: tl.constexpr):

@pytest.mark.parametrize("num_ctas", num_ctas_list)
def test_inline_asm_packed(num_ctas, device):
check_cuda_or_hip(device)

if is_hip():
pytest.skip("test_inline_asm is not supported in HIP")
if not is_cuda():
pytest.skip("test_inline_asm is only supported in CUDA")

@triton.jit
def kernel(X, Y, BLOCK: tl.constexpr):
@@ -4359,10 +4353,8 @@ def kernel(X, Y, BLOCK: tl.constexpr):

@pytest.mark.parametrize('num_ctas', num_ctas_list)
def test_inline_asm_with_pointers(num_ctas, device):
check_cuda_or_hip(device)

if is_hip():
pytest.skip('test_inline_asm is not supported in HIP')
if not is_cuda():
pytest.skip('test_inline_asm is only supported in CUDA')

@triton.jit
def kernel(X, Y, BLOCK: tl.constexpr):
@@ -4386,9 +4378,8 @@ def kernel(X, Y, BLOCK: tl.constexpr):


def test_inline_asm_multiple_outputs(device):
check_cuda_or_hip(device)
if is_hip():
pytest.skip('This test uses PTX inline assembly, so is not compatible with AMD')
if not is_cuda():
pytest.skip('test_inline_asm is only supported in CUDA')

@triton.jit
def kernel(A, B, C, D, BLOCK: tl.constexpr):
@@ -4433,9 +4424,8 @@ def kernel(A, B, C, D, BLOCK: tl.constexpr):


def test_inline_asm_packed_multiple_outputs(device):
check_cuda_or_hip(device)
if is_hip():
pytest.skip('This test uses PTX inline assembly, so is not compatible with AMD')
if not is_cuda():
pytest.skip('test_inline_asm is only supported in CUDA')

@triton.jit
def kernel(A, B, C, D, BLOCK: tl.constexpr):
@@ -5404,3 +5394,38 @@ def test_tl_range(device):
ptx = pgm.asm['ptx']
# check that the loop got pipelined with the right number of stages.
assert 'cp.async.wait_group 0x6' in ptx


@triton.jit(noinline=True)
def maxnreg_noinline1(X):
tl.store(X, 0)


@triton.jit(noinline=True)
def maxnreg_noinline2(X):
tl.store(X, 0)


def test_maxnreg(device):
assert not is_interpreter(), "this test won't work with the interpreter"
if not is_cuda():
pytest.xfail('maxnreg only works on CUDA')

# triton kernel
@triton.jit
def kernel(X):
maxnreg_noinline1(X)
tl.store(X, 0)
maxnreg_noinline2(X)

X = torch.empty(1, dtype=torch.int32, device=device)
k = kernel[(1, )](X, maxnreg=42)

# Ensure that .maxnreg is set on the kernel function (marked with .entry)
# and not on either of the noinline functions (marked with .func).
try:
assert re.search(r'\.visible \.entry [^{;]*\.maxnreg 42', k.asm["ptx"])
assert not re.search(r'\.visible \.func [^{;]*\.maxnreg', k.asm["ptx"])
except AssertionError:
print("Failing ptx:\n", k.asm["ptx"])
raise