From 14c28b93bb9211a6266795ec6105a22a8ef84de1 Mon Sep 17 00:00:00 2001 From: Fangrui Song Date: Mon, 20 May 2024 09:59:04 -0700 Subject: [PATCH 1/4] [backend][amd] Remove unneeded setUseAssemblerInfoForParsing (#3948) Follow-up to #3933. This might be copied from llvm/lib/Target/AMDGPU/AMDGPUAsmPrinter.cpp. `setUseAssemblerInfoForParsing` was only used for -mattr=dumpcode and should not be needed here. The llvm-project plan is to continue improving the integrated assembler and remove `setUseAssemblerInfoForParsing`. While here, drop the byte-order mark. --- third_party/amd/python/triton_amd.cc | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/third_party/amd/python/triton_amd.cc b/third_party/amd/python/triton_amd.cc index 50f640d66b..7c1f12c517 100644 --- a/third_party/amd/python/triton_amd.cc +++ b/third_party/amd/python/triton_amd.cc @@ -1,4 +1,4 @@ -#include "TritonAMDGPUToLLVM/Passes.h" +#include "TritonAMDGPUToLLVM/Passes.h" #include "TritonAMDGPUToLLVM/TargetUtils.h" #include "TritonAMDGPUTransforms/Passes.h" #include "mlir/Pass/PassManager.h" @@ -188,11 +188,6 @@ void init_triton_amd(py::module &&m) { std::move(ce), *sti, mcOptions.MCRelaxAll, mcOptions.MCIncrementalLinkerCompatible, /*DWARFMustBeAtTheEnd=*/false)); - // This was removed in https://github.com/llvm/llvm-project/pull/91082, - // but reverted in a later LLVM version. - // TODO(khasanovaa): uncomment the following line on the next LLVM - // update if it remains reverted. - // mcStreamer->setUseAssemblerInfoForParsing(true); std::unique_ptr parser( createMCAsmParser(srcMgr, ctx, *mcStreamer, *mai)); From 57101d3baee8973748acfdf9c361f72741e9e4ad Mon Sep 17 00:00:00 2001 From: ravil-mobile Date: Mon, 20 May 2024 19:06:24 +0200 Subject: [PATCH 2/4] [AMD] Added MLIR and LLVM timing info (#3949) This PR adds the functionality to profile MLIR/LLVM passes. A user can set `MLIR_ENABLE_TIMING` or `LLVM_ENABLE_TIMING` env. variable to track the elapsed time of each MLIR or LLVM pass. --- README.md | 2 ++ include/triton/Tools/Sys/GetEnv.hpp | 2 ++ python/src/ir.cc | 6 ++++++ python/src/llvm.cc | 22 ++++++++++++++++++++++ 4 files changed, 32 insertions(+) diff --git a/README.md b/README.md index b62251715a..3841c6f3e8 100644 --- a/README.md +++ b/README.md @@ -197,6 +197,8 @@ For detailed instructions on how to debug Triton's frontend, please refer to thi Loop strength reduction is known to cause up to 10% performance changes for certain kernels with register pressure. - `TRITON_ALWAYS_COMPILE=1` forces to compile kernels regardless of cache hit. +- `MLIR_ENABLE_TIMING` dumps the timing information for each MLIR pass. +- `LLVM_ENABLE_TIMING` dumps the timing information for each LLVM pass. 
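As an illustrative usage sketch (the kernel below is hypothetical and not taken from this patch): both variables are plain boolean switches read at compile time, so they can be set in the shell, e.g. `MLIR_ENABLE_TIMING=1 python program.py`, or from Python before the first kernel is compiled. Since they are cache-invalidating, flipping them forces a recompile, and the per-pass report is printed during that compilation.

```python
import os

# Assumed usage: enable per-pass timing before any Triton kernel is compiled.
os.environ["MLIR_ENABLE_TIMING"] = "1"  # MLIR pass-pipeline timing
os.environ["LLVM_ENABLE_TIMING"] = "1"  # LLVM pass timing

import torch
import triton
import triton.language as tl


@triton.jit
def copy_kernel(src_ptr, dst_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    tl.store(dst_ptr + offs, tl.load(src_ptr + offs))


src = torch.arange(128, device="cuda", dtype=torch.float32)
dst = torch.empty_like(src)
# The timing reports are emitted while this first call compiles the kernel.
copy_kernel[(1,)](src, dst, BLOCK=128)
```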
# Changelog diff --git a/include/triton/Tools/Sys/GetEnv.hpp b/include/triton/Tools/Sys/GetEnv.hpp index 035019d315..12584aa8f1 100644 --- a/include/triton/Tools/Sys/GetEnv.hpp +++ b/include/triton/Tools/Sys/GetEnv.hpp @@ -18,8 +18,10 @@ inline const std::set CACHE_INVALIDATING_ENV_VARS = { "DISABLE_MMA_V3", "DISABLE_PTXAS_OPT", "LLVM_IR_ENABLE_DUMP", + "LLVM_ENABLE_TIMING", "MLIR_ENABLE_DIAGNOSTICS", "MLIR_ENABLE_DUMP", + "MLIR_ENABLE_TIMING", "TRITON_DISABLE_LINE_INFO", "TRITON_DISABLE_RESHAPE_ENCODING_INFERENCE", "TRITON_ENABLE_LLVM_DEBUG", diff --git a/python/src/ir.cc b/python/src/ir.cc index ade64a9753..0befdc491b 100644 --- a/python/src/ir.cc +++ b/python/src/ir.cc @@ -1583,6 +1583,7 @@ void init_triton_ir(py::module &&m) { .def("run", [](PassManager &self, ModuleOp &mod) { // TODO: maybe dump module to file and print error for better // diagnostics + auto reproducerPath = triton::tools::getStrEnv("TRITON_REPRODUCER_PATH"); if (!reproducerPath.empty()) { @@ -1616,6 +1617,11 @@ void init_triton_ir(py::module &&m) { ::llvm::setCurrentDebugTypes(debugTypes.data(), debugTypes.size()); } + bool haveTiming = ::triton::tools::getBoolEnv("MLIR_ENABLE_TIMING"); + if (haveTiming) { + self.enableTiming(); + } + if (failed(self.run(mod.getOperation()))) throw std::runtime_error("PassManager::run failed"); }); diff --git a/python/src/llvm.cc b/python/src/llvm.cc index 1b061d5997..0bfad31d46 100644 --- a/python/src/llvm.cc +++ b/python/src/llvm.cc @@ -81,7 +81,23 @@ std::string translateLLVMIRToASM(llvm::Module &module, llvm::legacy::PassManager pm; pm.add(llvm::createAlwaysInlinerLegacyPass()); pm.add(llvm::createVerifierPass()); + + const bool enabledTiming = triton::tools::getBoolEnv("LLVM_ENABLE_TIMING"); + if (enabledTiming) { + llvm::TimePassesIsEnabled = true; + llvm::TimePassesPerRun = true; + } + pm.run(module); + + SmallString<0> timePassesStr; + raw_svector_ostream reportStream(timePassesStr); + + if (enabledTiming) { + reportAndResetTimings(&reportStream); + llvm::dbgs() << reportStream.str(); + timePassesStr.clear(); + } // module->print(llvm::outs(), nullptr); // create machine @@ -116,6 +132,12 @@ std::string translateLLVMIRToASM(llvm::Module &module, : llvm::CodeGenFileType::AssemblyFile; machine->addPassesToEmitFile(pass, pstream, nullptr, fileType); pass.run(module); + + if (enabledTiming) { + reportAndResetTimings(&reportStream); + llvm::dbgs() << reportStream.str(); + timePassesStr.clear(); + } } return result; } From 708814624b1fde6271eda76b081dd35e4a0dc6c6 Mon Sep 17 00:00:00 2001 From: Justin Lebar Date: Mon, 20 May 2024 10:31:34 -0700 Subject: [PATCH 3/4] Add maxnreg kernel config. (#3944) Like num_stages and num_warps, this tunes the generated kernel, controlling the maximum number of registers that one thread can consume. 
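A minimal usage sketch (the kernel, sizes, and register budget below are made up for illustration; only the `maxnreg` keyword comes from this change): `maxnreg` maps to the PTX `.maxnreg` directive and is currently honored on the CUDA backend only.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    y = tl.load(y_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x + y, mask=mask)


x = torch.randn(1024, device="cuda")
y = torch.randn(1024, device="cuda")
out = torch.empty_like(x)

# Passed at launch time, alongside num_warps/num_stages.
add_kernel[(4,)](x, y, out, x.numel(), BLOCK=256, maxnreg=64)

# Or explored by the autotuner; maxnreg=None (the default) emits no .maxnreg.
configs = [
    triton.Config({"BLOCK": 256}, num_warps=4, maxnreg=64),
    triton.Config({"BLOCK": 256}, num_warps=4),
]
```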
--- .../TritonGPUToLLVM/FuncOpToLLVM.cpp | 2 + python/src/llvm.cc | 45 ++++++++++++++-- python/test/unit/language/test_core.py | 35 ++++++++++++ python/triton/runtime/autotuner.py | 54 ++++++++++--------- python/triton/runtime/interpreter.py | 2 +- third_party/amd/backend/compiler.py | 10 ++-- third_party/nvidia/backend/compiler.py | 18 ++++--- 7 files changed, 124 insertions(+), 42 deletions(-) diff --git a/lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp index 04fa645477..47f40ebecd 100644 --- a/lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp @@ -87,6 +87,7 @@ struct FuncOpConversion : public ConvertOpToLLVMPattern<triton::FuncOp> { // Set an attribute to indicate this function is a kernel entry. newFuncOp->setAttr("nvvm.kernel", rewriter.getIntegerAttr(type::u1Ty(ctx), 1)); + newFuncOp.setLinkage(LLVM::Linkage::External); } else { // The noinline attribute will be used by the LLVM codegen to prevent // inlining. @@ -94,6 +95,7 @@ struct FuncOpConversion : public ConvertOpToLLVMPattern<triton::FuncOp> { newFuncOp.setPassthroughAttr( ArrayAttr::get(ctx, rewriter.getStringAttr("noinline"))); rewriter.eraseOp(amendedFuncOp); + newFuncOp.setLinkage(LLVM::Linkage::Internal); } // Set an attribute for maxntidx, it could be used in latter LLVM codegen // for `nvvm.annotation` metadata. diff --git a/python/src/llvm.cc b/python/src/llvm.cc index 0bfad31d46..0039d1a2f5 100644 --- a/python/src/llvm.cc +++ b/python/src/llvm.cc @@ -184,6 +184,9 @@ void init_triton_llvm(py::module &&m) { .def( "get_functions", [](llvm::Module *mod) -> llvm::Module::FunctionListType & { + // Note: Backends assume that we are compiling exactly one kernel + // (i.e. one function that's called by the CPU) and that it's + // the first function in this list. return mod->getFunctionList(); }, ret::reference_internal) @@ -194,14 +197,33 @@ void init_triton_llvm(py::module &&m) { }); py::class_<llvm::Function>(m, "function", py::module_local()) + .def_property_readonly( + "name", [](llvm::Function *fn) { return fn->getName().str(); }) .def("set_calling_conv", &llvm::Function::setCallingConv) .def("add_fn_attr", [](llvm::Function *fn, std::string &name, std::string &val) { fn->addFnAttr(name, val); }) + + // Sets the nvvm.maxnreg property on the given function. + .def("set_nvvm_maxnreg", + [](llvm::Function *fn, int maxnreg) { + auto op = MDNode::get( + fn->getContext(), + { + ValueAsMetadata::get(fn), + MDString::get(fn->getContext(), "maxnreg"), + ConstantAsMetadata::get(ConstantInt::get( + Type::getInt32Ty(fn->getContext()), maxnreg)), + }); + fn->getParent() + ->getOrInsertNamedMetadata("nvvm.annotations") + ->addOperand(op); }) - .def("has_public_visibility", - [](llvm::Function *fn) { - return fn->getVisibility() == llvm::GlobalValue::DefaultVisibility; + // External functions that are definitions (i.e. not declarations) are + // kernel functions.
+ .def("is_declaration", &llvm::Function::isDeclaration) + .def("is_external_linkage", [](llvm::Function *fn) { + return fn->getLinkage() == llvm::GlobalValue::ExternalLinkage; + }); // optimization levels py::class_(m, "optimization_level", @@ -358,11 +380,26 @@ void init_triton_llvm(py::module &&m) { } libMod->setTargetTriple(dstMod->getTargetTriple()); libMod->setDataLayout(dstMod->getDataLayout()); + + std::unordered_set externalFns; + for (llvm::Function &fn : libMod->functions()) { + if (!fn.isDeclaration()) + externalFns.insert(fn.getName().str()); + } + if (linker.linkInModule(std::move(libMod), llvm::Linker::Flags::LinkOnlyNeeded)) { std::string message = "Failed to link library at " + path; throw std::invalid_argument(message); } + + // Mark linked-in functions as internal because backends use external + // linkage as a signifier of kernel functions. + for (llvm::Function &fn : dstMod->functions()) { + if (externalFns.count(fn.getName().str())) { + fn.setLinkage(llvm::GlobalValue::InternalLinkage); + } + } } }); } diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index a4b0a4798d..934732147b 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -5304,3 +5304,38 @@ def test_tl_range(device): ptx = pgm.asm['ptx'] # check that the loop got pipelined with the right number of stages. assert 'cp.async.wait_group 0x6' in ptx + + +@triton.jit(noinline=True) +def maxnreg_noinline1(X): + tl.store(X, 0) + + +@triton.jit(noinline=True) +def maxnreg_noinline2(X): + tl.store(X, 0) + + +def test_maxnreg(device): + assert not is_interpreter(), "this test won't work with the interpreter" + if is_hip(): + pytest.skip('maxnreg only works on CUDA') + + # triton kernel + @triton.jit + def kernel(X): + maxnreg_noinline1(X) + tl.store(X, 0) + maxnreg_noinline2(X) + + X = torch.empty(1, dtype=torch.int32, device=device) + k = kernel[(1, )](X, maxnreg=42) + + # Ensure that .maxnreg is set on the kernel function (marked with .entry) + # and not on either of the noinline functions (marked with .func). + try: + assert re.search(r'\.visible \.entry [^{;]*\.maxnreg 42', k.asm["ptx"]) + assert not re.search(r'\.visible \.func [^{;]*\.maxnreg', k.asm["ptx"]) + except AssertionError: + print("Failing ptx:\n", k.asm["ptx"]) + raise diff --git a/python/triton/runtime/autotuner.py b/python/triton/runtime/autotuner.py index c8e7ae04bd..9441f46ae8 100644 --- a/python/triton/runtime/autotuner.py +++ b/python/triton/runtime/autotuner.py @@ -104,7 +104,7 @@ def _bench(self, *args, config, **meta): raise ValueError(f"Conflicting meta-parameters: {', '.join(conflicts)}." 
" Make sure that you don't re-define auto-tuned symbols.") # augment meta-parameters with tunable ones - current = dict(meta, **config.kwargs) + current = dict(meta, **config.all_kwargs()) full_nargs = {**self.nargs, **current} def kernel_call(): @@ -114,9 +114,6 @@ def kernel_call(): try: self.fn.run( *args, - num_warps=config.num_warps, - num_stages=config.num_stages, - num_ctas=config.num_ctas, **current, ) except Exception as e: @@ -170,16 +167,12 @@ def run(self, *args, **kwargs): if os.getenv("TRITON_PRINT_AUTOTUNING", None) == "1" and not used_cached_result: print(f"Triton autotuning for function {self.base_fn.__name__} finished after " f"{self.bench_time:.2f}s; best config selected: {self.best_config};") - full_nargs = {**self.nargs, **kwargs, **self.best_config.kwargs} if config.pre_hook is not None: - config.pre_hook(full_nargs) + config.pre_hook({**self.nargs, **kwargs, **config.all_kwargs()}) ret = self.fn.run( *args, - num_warps=config.num_warps, - num_stages=config.num_stages, - num_ctas=config.num_ctas, **kwargs, - **config.kwargs, + **config.all_kwargs(), ) self.nargs = None return ret @@ -194,14 +187,10 @@ def prune_configs(self, kwargs): top_k = int(len(self.configs) * top_k) if len(pruned_configs) > top_k: est_timing = { - config: - self.perf_model( + config: self.perf_model( **self.nargs, **kwargs, - **config.kwargs, - num_stages=config.num_stages, - num_warps=config.num_warps, - num_ctas=config.num_ctas, + **config.all_kwargs(), ) for config in pruned_configs } @@ -212,15 +201,11 @@ def warmup(self, *args, **kwargs): self.nargs = dict(zip(self.arg_names, args)) ret = [] for config in self.prune_configs(kwargs): - ret.append( - self.fn.warmup( - *args, - num_warps=config.num_warps, - num_ctas=config.num_ctas, - num_stages=config.num_stages, - **kwargs, - **config.kwargs, - )) + ret.append(self.fn.warmup( + *args, + **kwargs, + **config.all_kwargs(), + )) self.nargs = None return ret @@ -239,17 +224,33 @@ class Config: Mostly useful for matrix multiplication workloads on SM80+ GPUs. :type num_ctas: int :ivar num_ctas: number of blocks in a block cluster. SM90+ only. + :type maxnreg: Optional[int] + :ivar maxnreg: maximum number of registers one thread can use. Corresponds + to ptx .maxnreg directive. Not supported on all platforms. :ivar pre_hook: a function that will be called before the kernel is called. Parameters of this function are args. 
""" - def __init__(self, kwargs, num_warps=4, num_stages=2, num_ctas=1, pre_hook=None): + def __init__(self, kwargs, num_warps=4, num_stages=2, num_ctas=1, maxnreg=None, pre_hook=None): self.kwargs = kwargs self.num_warps = num_warps self.num_ctas = num_ctas self.num_stages = num_stages + self.maxnreg = maxnreg self.pre_hook = pre_hook + def all_kwargs(self): + return self.kwargs | { + k: v + for (k, v) in ( + ("num_warps", self.num_warps), + ("num_ctas", self.num_ctas), + ("num_stages", self.num_stages), + ("maxnreg", self.maxnreg), + ) + if v is not None + } + def __str__(self): res = [] for k, v in self.kwargs.items(): @@ -257,6 +258,7 @@ def __str__(self): res.append(f"num_warps: {self.num_warps}") res.append(f"num_ctas: {self.num_ctas}") res.append(f"num_stages: {self.num_stages}") + res.append(f"maxnreg: {self.maxnreg}") return ", ".join(res) diff --git a/python/triton/runtime/interpreter.py b/python/triton/runtime/interpreter.py index 4f892f669e..a82832ecf9 100644 --- a/python/triton/runtime/interpreter.py +++ b/python/triton/runtime/interpreter.py @@ -1027,7 +1027,7 @@ def _implicit_cvt(arg): interpreter_builder = InterpreterBuilder() # These keywords are not supported by the interpreter -RESERVED_KWS = ["num_warps", "num_stages", "num_ctas", "enable_fp_fusion", "grid"] +RESERVED_KWS = ["num_warps", "num_stages", "num_ctas", "enable_fp_fusion", "grid", "maxnreg"] class GridExecutor: diff --git a/third_party/amd/backend/compiler.py b/third_party/amd/backend/compiler.py index 479e60c265..a077b0b60f 100644 --- a/third_party/amd/backend/compiler.py +++ b/third_party/amd/backend/compiler.py @@ -186,13 +186,13 @@ def make_llir(src, metadata, options): amd.set_bool_control_constant(llvm_mod, "__oclc_wavefrontsize64", options.warp_size == 64) # Set kernel attributes first given this may affect later optimizations. - kernels = [fn for fn in llvm_mod.get_functions() if not fn.is_declaration()] + fns = [fn for fn in llvm_mod.get_functions() if not fn.is_declaration()] # The public kernel should be kernel 0. - kernels[0].set_calling_conv(amd.CALLING_CONV_AMDGPU_KERNEL) - kernels[0].add_fn_attr("amdgpu-flat-work-group-size", f"1,{options.num_warps*options.warp_size}") - kernels[0].add_fn_attr("amdgpu-waves-per-eu", f"{options.waves_per_eu}") + fns[0].set_calling_conv(amd.CALLING_CONV_AMDGPU_KERNEL) + fns[0].add_fn_attr("amdgpu-flat-work-group-size", f"1,{options.num_warps*options.warp_size}") + fns[0].add_fn_attr("amdgpu-waves-per-eu", f"{options.waves_per_eu}") denormal_mode = "preserve-sign" if options.allow_flush_denorm else "ieee" - kernels[0].add_fn_attr("denormal-fp-math-f32", denormal_mode) + fns[0].add_fn_attr("denormal-fp-math-f32", denormal_mode) if options.extern_libs: paths = [path for (name, path) in options.extern_libs if amd.need_extern_lib(llvm_mod, name)] diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py index dfb0f92b24..a9f389ad61 100644 --- a/third_party/nvidia/backend/compiler.py +++ b/third_party/nvidia/backend/compiler.py @@ -3,7 +3,7 @@ from dataclasses import dataclass import functools -from typing import Any, Tuple +from typing import Any, Tuple, Optional import hashlib import re import tempfile @@ -63,6 +63,9 @@ class CUDAOptions: num_warps: int = 4 num_ctas: int = 1 num_stages: int = 3 + # maxnreg corresponds to the ptx parameter .maxnreg, which controls the + # maximum number of 32-bit registers used by one thread. 
+ maxnreg: Optional[int] = None cluster_dims: tuple = (1, 1, 1) ptx_version: int = None enable_fp_fusion: bool = True @@ -217,15 +220,18 @@ def make_llir(src, metadata, options, capability): context = llvm.context() llvm_mod = llvm.to_module(mod, context) nvidia.set_nvvm_reflect_ftz(llvm_mod) + + # Set maxnreg on all kernels, if it was provided. + if options.maxnreg is not None: + for k in llvm_mod.get_functions(): + if not k.is_declaration() and k.is_external_linkage(): + k.set_nvvm_maxnreg(options.maxnreg) + if options.extern_libs: paths = [path for (name, path) in options.extern_libs] llvm.link_extern_libs(llvm_mod, paths) + llvm.optimize_module(llvm_mod, llvm.OPTIMIZE_O3) - # Set kernel attributes - # kernels = [fn for fn in llvm_mod.get_functions() if fn.has_public_visibility() and not fn.is_declaration()] - # assert len(kernels) == 1 - # kernels[0].add_fn_attr("nvvm.maxntid", f"1, {options.num_warps*32}") - # kernels[0].add_fn_attr("nvvm.kernel", "1") # Get some metadata metadata["shared"] = src.get_int_attr("triton_gpu.shared") From 14800bfa0cfa0930fcaed22fc1e91a8e13430580 Mon Sep 17 00:00:00 2001 From: Whitney Tsang Date: Mon, 20 May 2024 20:18:03 -0400 Subject: [PATCH 4/4] [CI] Reduce usages of `check_cuda_or_hip` (#3955) Reduce usages of `check_cuda_or_hip`, so that more tests can be run for downstream backends by default. Signed-off-by: Whitney Tsang --- python/test/unit/language/test_core.py | 75 ++++++++++---------------- 1 file changed, 29 insertions(+), 46 deletions(-) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 934732147b..0bddb84cc3 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -1347,13 +1347,7 @@ def test_atomic_rmw(op, dtype_x_str, mode, sem, device): if is_interpreter(): if dtype_x_str == 'float16': pytest.skip("Only test atomic float16 ops on GPU") - else: - check_cuda_or_hip(device) - capability = torch.cuda.get_device_capability() - if capability[0] < 7: - if dtype_x_str == 'float16': - pytest.skip("Only test atomic float16 ops on devices with sm >= 70") n_programs = 5 # triton kernel @@ -3015,25 +3009,24 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dty if in_dtype == 'bfloat16': pytest.skip("bfloat16 is not supported in the interpreter") else: - check_cuda_or_hip(device) - - capability = torch.cuda.get_device_capability() - - if capability[0] < 7: - pytest.skip("Only test tl.dot() on devices with sm >= 70") - if capability[0] < 8: - if capability[1] == 0 and in_dtype == 'int8': - pytest.skip("Only test int8 on devices with sm >= 75") - if input_precision != "ieee": - pytest.skip("Only test tf32 on devices with sm >= 80") - if capability[0] == 7: - if (M, N, K, num_warps) in [(128, 256, 32, 8), (64, 128, 128, 4), (64, 128, 128, 2)]: - pytest.skip("shared memory out of resource") - if out_dtype == 'float16': - # TODO: support out_dtype=float16 for tl.dot on V100 - pytest.skip("Only test out_dtype=float16 on devices with sm >=80") - if capability[0] < 9 and in_dtype == 'float8e4nv': - pytest.skip("float8e4nv not supported on sm <= 80") + if is_cuda(): + capability = torch.cuda.get_device_capability() + + if capability[0] < 7: + pytest.skip("Only test tl.dot() on devices with sm >= 70") + if capability[0] < 8: + if capability[1] == 0 and in_dtype == 'int8': + pytest.skip("Only test int8 on devices with sm >= 75") + if input_precision != "ieee": + pytest.skip("Only test tf32 on devices with sm >= 80") + if capability[0] == 7: + 
if (M, N, K, num_warps) in [(128, 256, 32, 8), (64, 128, 128, 4), (64, 128, 128, 2)]: + pytest.skip("shared memory out of resource") + if out_dtype == 'float16': + # TODO: support out_dtype=float16 for tl.dot on V100 + pytest.skip("Only test out_dtype=float16 on devices with sm >=80") + if capability[0] < 9 and in_dtype == 'float8e4nv': + pytest.skip("float8e4nv not supported on sm <= 80") if is_hip() and (in_dtype == 'float8e4nv' or in_dtype == 'float8e5'): pytest.skip("float8e4nv and float8e5 not supported on HIP") if is_hip() and (input_precision != "ieee"): @@ -4220,10 +4213,8 @@ def kernel(X, Y, BLOCK: tl.constexpr): @pytest.mark.parametrize("num_ctas", num_ctas_list) def test_inline_asm(num_ctas, device): - check_cuda_or_hip(device) - - if is_hip(): - pytest.skip("test_inline_asm is not supported in HIP") + if not is_cuda(): + pytest.skip("test_inline_asm is only supported in CUDA") @triton.jit def kernel(X, Y, Z, n: tl.constexpr, BLOCK: tl.constexpr): @@ -4250,10 +4241,8 @@ def kernel(X, Y, Z, n: tl.constexpr, BLOCK: tl.constexpr): @pytest.mark.parametrize("num_ctas", num_ctas_list) def test_inline_asm_packed(num_ctas, device): - check_cuda_or_hip(device) - - if is_hip(): - pytest.skip("test_inline_asm is not supported in HIP") + if not is_cuda(): + pytest.skip("test_inline_asm is only supported in CUDA") @triton.jit def kernel(X, Y, BLOCK: tl.constexpr): @@ -4279,10 +4268,8 @@ def kernel(X, Y, BLOCK: tl.constexpr): @pytest.mark.parametrize('num_ctas', num_ctas_list) def test_inline_asm_with_pointers(num_ctas, device): - check_cuda_or_hip(device) - - if is_hip(): - pytest.skip('test_inline_asm is not supported in HIP') + if not is_cuda(): + pytest.skip('test_inline_asm is only supported in CUDA') @triton.jit def kernel(X, Y, BLOCK: tl.constexpr): @@ -4306,9 +4293,8 @@ def kernel(X, Y, BLOCK: tl.constexpr): def test_inline_asm_multiple_outputs(device): - check_cuda_or_hip(device) - if is_hip(): - pytest.skip('This test uses PTX inline assembly, so is not compatible with AMD') + if not is_cuda(): + pytest.skip('test_inline_asm is only supported in CUDA') @triton.jit def kernel(A, B, C, D, BLOCK: tl.constexpr): @@ -4353,9 +4339,8 @@ def kernel(A, B, C, D, BLOCK: tl.constexpr): def test_inline_asm_packed_multiple_outputs(device): - check_cuda_or_hip(device) - if is_hip(): - pytest.skip('This test uses PTX inline assembly, so is not compatible with AMD') + if not is_cuda(): + pytest.skip('test_inline_asm is only supported in CUDA') @triton.jit def kernel(A, B, C, D, BLOCK: tl.constexpr): @@ -4697,7 +4682,6 @@ def nested_while(data, countPtr): def test_num_threads(device): if is_hip(): pytest.skip("test_num_threads is not supported in HIP") - check_cuda_or_hip(device) @triton.jit def kernel(Out): @@ -5087,8 +5071,7 @@ def matmul_kernel( # def test_fp8_dot_acc(in_type_str, low_precision_acc, device): if is_hip(): pytest.skip('test_fp8_dot_acc for HIP currently broken in upstream.') - if not is_interpreter(): - check_cuda_or_hip(device) + if is_cuda(): cc = torch.cuda.get_device_capability() if cc[0] >= 9 and in_type_str == "float8e4b15": pytest.skip("Dot op does not support fp8e4b15 on CUDA arch >= 90")