diff --git a/.github/workflows/test-backends.yml b/.github/workflows/test-backends.yml index d2ab76f17b..332758f0a2 100644 --- a/.github/workflows/test-backends.yml +++ b/.github/workflows/test-backends.yml @@ -55,16 +55,6 @@ jobs: python3 setup.py build python3 -m pip install --no-build-isolation -vvv '.[tests]' - - name: Run shared middle-layer lit tests - run: | - python3 -m pip install lit - cd python - LIT_TEST_DIR="build/$(ls build | grep -i cmake)/third_party/triton_shared/test" - if [ ! -d "${LIT_TEST_DIR}" ]; then - echo "Coult not find '${LIT_TEST_DIR}'" ; exit -1 - fi - lit -v "${LIT_TEST_DIR}" - Integration-Tests-AMD: needs: Runner-Preparation diff --git a/.gitignore b/.gitignore index d533f6f099..0b08ef149e 100644 --- a/.gitignore +++ b/.gitignore @@ -6,9 +6,15 @@ python/build/ python/triton.egg-info/ python/triton/_C/libtriton.pyd python/triton/_C/libtriton.so -python/triton/backends/cuda +python/triton/backends/nvidia python/triton/backends/xpu +# Backends copied from submodules +python/triton/backends/ +!python/triton/backends/__init__.py +!python/triton/backends/compiler.py +!python/triton/backends/driver.py + # Python caches __pycache__/ *.py[cod] @@ -46,3 +52,6 @@ docs/getting-started/tutorials /compile_commands.json .vscode .vs + +# Vim +*.swp diff --git a/.gitmodules b/.gitmodules index e69de29bb2..d964b77672 100644 --- a/.gitmodules +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "third_party/amd"] + path = third_party/amd + url = https://github.com/ptillet/triton.git diff --git a/CMakeLists.txt b/CMakeLists.txt index d9a1586f9f..bee30c80a7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -19,6 +19,8 @@ if(NOT WIN32) list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake") endif() + + # Options option(TRITON_BUILD_TUTORIALS "Build C++ Triton tutorials" ON) option(TRITON_BUILD_PYTHON_MODULE "Build Python Triton bindings" OFF) @@ -82,8 +84,41 @@ include(TableGen) # required by AddMLIR include(AddLLVM) include(AddMLIR) +# Utilities +function(add_triton_object name) + cmake_parse_arguments(ARG "" "" "DEPENDS;LINK_LIBS" ${ARGN}) + add_library(${name} OBJECT) + target_sources(${name} + PRIVATE ${ARG_UNPARSED_ARGUMENTS} + INTERFACE $<TARGET_OBJECTS:${name}> + ) + + + # add_library(${name} OBJECT ${ARG_UNPARSED_ARGUMENTS}) + if(ARG_DEPENDS) + add_dependencies(${name} ${ARG_DEPENDS}) + endif() + if(ARG_LINK_LIBS) + target_link_libraries(${name} PUBLIC ${ARG_LINK_LIBS}) + endif() +endfunction(add_triton_object) + +set_property(GLOBAL PROPERTY TRITON_LIBS "") +function(add_triton_library name) + set_property(GLOBAL APPEND PROPERTY TRITON_LIBS ${name}) + add_triton_object(${name} ${ARGN}) + llvm_update_compile_flags(${name}) +endfunction() + +set_property(GLOBAL PROPERTY TRITON_PLUGINS "") +function(add_triton_plugin name) + set_property(GLOBAL APPEND PROPERTY TRITON_PLUGINS ${name}) + add_triton_object(${name} ${ARGN}) +endfunction() + + # Disable warnings that show up in external code (gtest;pybind11) -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -Wno-covered-switch-default") +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -Wno-covered-switch-default -fvisibility=hidden") include_directories(".") include_directories(${MLIR_INCLUDE_DIRS}) @@ -99,9 +134,6 @@ add_subdirectory(lib) set(TRITON_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}") set(TRITON_BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}") -get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) -get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) - # TODO: Figure out which target is sufficient to fix errors; triton is
apparently not enough. Currently set linking libstdc++fs for all targets # to support some old version GCC compilers like 8.3.0. @@ -128,33 +160,26 @@ if(TRITON_BUILD_PYTHON_MODULE) add_link_options(${Python3_LINK_OPTIONS}) endif() - set(TRITON_CODEGEN_BACKENDS "xpu") foreach(CODEGEN_BACKEND ${TRITON_CODEGEN_BACKENDS}) add_subdirectory(third_party/${CODEGEN_BACKEND}) endforeach() + + get_property(triton_libs GLOBAL PROPERTY TRITON_LIBS) + get_property(triton_plugins GLOBAL PROPERTY TRITON_PLUGINS) set(TRITON_LIBRARIES - TritonIR - TritonAnalysis - TritonTransforms - TritonToTritonGPU - TritonGPUIR - TritonGPUTransforms - TritonLLVMIR - TritonNvidiaGPUIR - MLIRAMDGPUDialect - TritonAnalysis - NVGPUToLLVM - TritonNvidiaGPUTransforms - TritonGPUToLLVM + ${triton_libs} + ${triton_plugins} TritonSPIRV + + # mlir + MLIRAMDGPUDialect MLIRNVVMDialect MLIRNVVMToLLVMIRTranslation MLIRGPUToNVVMTransforms MLIRGPUToGPURuntimeTransforms MLIRGPUTransforms - - # optimizations + MLIRIR MLIRControlFlowToLLVM MLIRBytecodeWriter MLIRPass @@ -166,7 +191,9 @@ if(TRITON_BUILD_PYTHON_MODULE) MLIRROCDLToLLVMIRTranslation MLIRGENXToLLVMIRTranslation MLIRGPUDialect - MLIRIR + MLIRSCFToControlFlow + MLIRIndexToLLVM + MLIRGPUToROCDLTransforms # LLVM LLVMPasses @@ -180,12 +207,14 @@ if(TRITON_BUILD_PYTHON_MODULE) ) # Define triton library + string(JOIN "," TRITON_BACKENDS_TUPLE ${TRITON_CODEGEN_BACKENDS}) + set(TRITON_BACKENDS_TUPLE "(${TRITON_BACKENDS_TUPLE})") + add_compile_definitions(TRITON_BACKENDS_TUPLE=${TRITON_BACKENDS_TUPLE}) add_library(triton SHARED ${PYTHON_SRC_PATH}/main.cc ${PYTHON_SRC_PATH}/ir.cc ${PYTHON_SRC_PATH}/passes.cc ${PYTHON_SRC_PATH}/interpreter.cc - ${PYTHON_SRC_PATH}/llvm.cc - ${CMAKE_CURRENT_SOURCE_DIR}/third_party/xpu/triton_xpu.cc) + ${PYTHON_SRC_PATH}/llvm.cc) # Link triton with its dependencies target_link_libraries(triton PUBLIC ${TRITON_LIBRARIES}) @@ -195,7 +224,6 @@ if(TRITON_BUILD_PYTHON_MODULE) target_link_libraries(triton PRIVATE z) endif() target_link_options(triton PRIVATE ${LLVM_LDFLAGS} ${GenISAIntrinsics_LDFLAGS}) - set_target_properties(triton PROPERTIES INTERFACE_LINK_LIBRARIES "") endif() if(UNIX AND NOT APPLE) @@ -210,7 +238,7 @@ if(TRITON_BUILD_PYTHON_MODULE AND NOT WIN32) set(PYTHON_LDFLAGS "-undefined dynamic_lookup -flto") endif() - target_link_libraries(triton ${PYTHON_LDFLAGS}) + target_link_libraries(triton PRIVATE ${PYTHON_LDFLAGS}) endif() add_subdirectory(bin) diff --git a/bin/CMakeLists.txt b/bin/CMakeLists.txt index 91e76b5c38..7f358917b2 100644 --- a/bin/CMakeLists.txt +++ b/bin/CMakeLists.txt @@ -12,6 +12,7 @@ target_link_libraries(triton-opt PRIVATE TritonNvidiaGPUTransforms ${dialect_libs} ${conversion_libs} + ${triton_libs} # tests TritonTestAnalysis # MLIR core @@ -33,6 +34,7 @@ target_link_libraries(triton-reduce PRIVATE TritonNvidiaGPUTransforms ${dialect_libs} ${conversion_libs} + ${triton_libs} # tests TritonTestAnalysis # MLIR core @@ -43,6 +45,29 @@ target_link_libraries(triton-reduce PRIVATE mlir_check_all_link_libraries(triton-reduce) +add_llvm_executable(triton-lsp triton-lsp.cpp PARTIAL_SOURCES_INTENDED) +mlir_check_all_link_libraries(triton-lsp) + +llvm_update_compile_flags(triton-lsp) +target_link_libraries(triton-lsp PRIVATE + TritonAnalysis + TritonTransforms + TritonGPUTransforms + TritonNvidiaGPUTransforms + ${dialect_libs} + ${conversion_libs} + ${triton_libs} + # tests + TritonTestAnalysis + # MLIR core + MLIRLspServerLib + MLIRPass + MLIRTransforms +) + +mlir_check_all_link_libraries(triton-lsp) + + add_llvm_executable(triton-llvm-opt 
triton-llvm-opt.cpp diff --git a/bin/triton-lsp.cpp b/bin/triton-lsp.cpp new file mode 100644 index 0000000000..b185b03748 --- /dev/null +++ b/bin/triton-lsp.cpp @@ -0,0 +1,11 @@ +#include "./RegisterTritonDialects.h" + +#include "mlir/Tools/mlir-lsp-server/MlirLspServerMain.h" + +int main(int argc, char **argv) { + mlir::DialectRegistry registry; + registerTritonDialects(registry); + + mlir::MLIRContext context(registry); + return mlir::failed(mlir::MlirLspServerMain(argc, argv, registry)); +} diff --git a/include/triton/Dialect/Triton/IR/TritonAttrDefs.td b/include/triton/Dialect/Triton/IR/TritonAttrDefs.td index 97a22d97e7..f65daa1ddc 100644 --- a/include/triton/Dialect/Triton/IR/TritonAttrDefs.td +++ b/include/triton/Dialect/Triton/IR/TritonAttrDefs.td @@ -97,5 +97,14 @@ def TT_RoundingModeAttr : I32EnumAttr< let cppNamespace = "::mlir::triton"; } +// PropagateNan. +def TT_PropagateNanAttr : I32EnumAttr< + "PropagateNan", "", + [ + I32EnumAttrCase<"NONE", 0, "none">, + I32EnumAttrCase<"ALL", 0xFFFF, "all">, + ]> { + let cppNamespace = "::mlir::triton"; +} #endif diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td index b757de9feb..34d2a11577 100644 --- a/include/triton/Dialect/Triton/IR/TritonOps.td +++ b/include/triton/Dialect/Triton/IR/TritonOps.td @@ -100,6 +100,24 @@ def TT_FpToFpOp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape, let hasVerifier = 1; } +def TT_ClampFOp : TT_Op<"clampf", [Elementwise, + SameOperandsAndResultType, + Pure]> { + let summary = "Clamp operation for floating point types"; + + let description = [{ + Clamp operation for floating point types. + + The operation takes three arguments: x, min, and max. It returns a tensor of the same shape as x with its values clamped to the range [min, max]. 
+ }]; + + let arguments = (ins TT_FloatLike:$x, TT_FloatLike:$min, TT_FloatLike:$max, TT_PropagateNanAttr:$propagateNan); + + let results = (outs TT_FloatLike:$result); + + let assemblyFormat = "$x `,` $min `,` $max attr-dict `:` type($result)"; +} + // // Pointer Arith Ops // diff --git a/include/triton/Target/PTX/TmaMetadata.h b/include/triton/Target/PTX/TmaMetadata.h index eb11a74693..0eb9e14377 100644 --- a/include/triton/Target/PTX/TmaMetadata.h +++ b/include/triton/Target/PTX/TmaMetadata.h @@ -24,7 +24,7 @@ #ifndef TRITON_TARGET_PTX_TMAMETADATA_H #define TRITON_TARGET_PTX_TMAMETADATA_H -#include "third_party/cuda/backend/include/cuda.h" +#include "third_party/nvidia/backend/include/cuda.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Debug.h" #include "llvm/Support/Format.h" diff --git a/lib/Analysis/CMakeLists.txt b/lib/Analysis/CMakeLists.txt index aecc2345ac..a84f0649b6 100644 --- a/lib/Analysis/CMakeLists.txt +++ b/lib/Analysis/CMakeLists.txt @@ -1,4 +1,4 @@ -add_mlir_library(TritonAnalysis +add_triton_library(TritonAnalysis AxisInfo.cpp Allocation.cpp Membar.cpp @@ -10,7 +10,6 @@ add_mlir_library(TritonAnalysis TritonGPUAttrDefsIncGen LINK_LIBS PUBLIC - ASMBuilder MLIRAnalysis MLIRLLVMDialect TritonIR diff --git a/lib/Conversion/NVGPUToLLVM/CMakeLists.txt b/lib/Conversion/NVGPUToLLVM/CMakeLists.txt index 9af2636866..153a9d6de3 100644 --- a/lib/Conversion/NVGPUToLLVM/CMakeLists.txt +++ b/lib/Conversion/NVGPUToLLVM/CMakeLists.txt @@ -1,16 +1,9 @@ -add_mlir_conversion_library(NVGPUToLLVM +add_triton_library(NVGPUToLLVM NVGPUToLLVMPass.cpp - ADDITIONAL_HEADER_DIRS - ${PROJECT_SOURCE_DIR}/include/triton/Conversion/NVGPUToLLVM - ${PROJECT_BINARY_DIR}/include/triton/Conversion/NVGPUToLLVM - DEPENDS NVGPUConversionPassIncGen - LINK_COMPONENTS - Core - LINK_LIBS PUBLIC MLIRIR MLIRPass diff --git a/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt b/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt index 76a2e3438f..d6349d86a9 100644 --- a/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt +++ b/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt @@ -1,10 +1,4 @@ -# Separate out PTX/GCN builders to avoid cyclic dependencies as TritonAnalysis -# depends on it. 
-set(LLVM_OPTIONAL_SOURCES - PTXAsmFormat.cpp - ) - -add_mlir_conversion_library(TritonGPUToLLVM +add_triton_library(TritonGPUToLLVM ConvertLayoutOpToLLVM/SharedToDotOperandDPAS.cpp ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp ConvertLayoutOpToLLVM/SharedToDotOperandMMAv1.cpp @@ -16,8 +10,8 @@ add_mlir_conversion_library(TritonGPUToLLVM DotOpToLLVM/MMAv2.cpp DotOpToLLVM/WGMMA.cpp DotOpToLLVM.cpp - ElementwiseOpToLLVM.cpp HistogramOpToLLVM.cpp + ElementwiseOpToLLVM.cpp LoadStoreOpToLLVM.cpp BarrierOpToLLVM.cpp TritonGPUToLLVM.cpp @@ -30,19 +24,12 @@ add_mlir_conversion_library(TritonGPUToLLVM TensorPtrOpsToLLVM.cpp ClusterOpsToLLVM.cpp RegReallocOpToLLVM.cpp - - ADDITIONAL_HEADER_DIRS - ${PROJECT_SOURCE_DIR}/include/triton/Conversion/TritonGPUToLLVM - ${PROJECT_BINARY_DIR}/include/triton/Conversion/TritonGPUToLLVM + PTXAsmFormat.cpp DEPENDS TritonGPUConversionPassIncGen - LINK_COMPONENTS - Core - LINK_LIBS PUBLIC - ASMBuilder MLIRIR MLIRPass MLIRGENXDialect @@ -58,14 +45,3 @@ add_mlir_conversion_library(TritonGPUToLLVM TritonNvidiaGPUTransforms NVGPUIR ) - -add_mlir_library(ASMBuilder - PTXAsmFormat.cpp - - DEPENDS - TritonTableGen - - LINK_LIBS PUBLIC - MLIRAnalysis - MLIRLLVMDialect -) diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp index d5136894ad..dbd4ca46d7 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -658,7 +658,13 @@ struct ConvertLayoutOpConversion processReplicaForMMAV1(loc, rewriter, /*stNotRd*/ true, srcTy, multiDimRepId, inVec, paddedRepShape, outOrd, vals, smemBase, shape); - else + else if (isStMatrixCompatible(srcTy) && accumNumReplicates == 1 && + outOrd[0] == 1 && paddedRepShape[1] % 8 == 0) { + Value llvmSrc = adaptor.getSrc(); + storeDistributedToSharedWithStMatrix(srcTy, llvmSrc, smemBase, + paddedRepShape, origRepShape, + loc, rewriter); + } else processReplica(loc, rewriter, /*stNotRd*/ true, srcTy, inNumCTAsEachRep, multiDimRepId, inVec, paddedRepShape, origRepShape, outOrd, vals, smemBase); @@ -747,6 +753,104 @@ struct ConvertLayoutOpConversion return success(); } + Value computeStMatrixAddr(Value laneId, int matStride, Location loc, + ConversionPatternRewriter &rewriter) const { + Value rowInMat = urem(laneId, i32_val(8)); // row in the 8x8 matrix + // linear index of the matrix in the 2x2 matrices + // Decompose matIndex => s_0, s_1, that is the coordinate in 2x2 matrices in + // a warp. 
+ Value matIndex = udiv(laneId, i32_val(8)); + Value s0 = urem(matIndex, i32_val(2)); + Value s1 = udiv(matIndex, i32_val(2)); + Value mIndex = add(rowInMat, mul(s0, i32_val(8))); + int m8n8Stride = 8; + Value offset = + add(mul(mIndex, i32_val(matStride)), mul(s1, i32_val(m8n8Stride))); + return offset; + } + + void stMatrixm8n8x4(Value offset, ArrayRef vals, int indexOffset, + Value smemBase, Type elemTy, Location loc, + ConversionPatternRewriter &rewriter) const { + SmallVector inputs; + auto prTy = ptr_ty(rewriter.getContext(), 3); + // Pack the input into 2xf16 + Type packedTy = vec_ty(vals[0].getType(), 2); + for (int i = 0; i < 4; i++) { + Value input = undef(packedTy); + for (int j = 0; j < 2; j++) { + input = insert_element(packedTy, input, vals[indexOffset + i * 2 + j], + i32_val(j)); + } + inputs.push_back(bitcast(input, i32_ty)); + } + Value addr = gep(smemBase.getType(), + getTypeConverter()->convertType(elemTy), smemBase, offset); + rewriter.create(loc, addr, inputs); + } + + void storeDistributedToSharedWithStMatrix( + RankedTensorType tensorTy, Value llvmSrc, Value smemBase, + ArrayRef paddedRepShape, ArrayRef origRepShape, + Location loc, ConversionPatternRewriter &rewriter) const { + auto shapePerCTA = getShapePerCTA(tensorTy); + auto mmaLayout = tensorTy.getEncoding().cast(); + auto order = triton::gpu::getOrder(mmaLayout); + auto warpsPerCTA = mmaLayout.getWarpsPerCTA(); + auto shapePerCTATile = getShapePerCTATile(mmaLayout); + ArrayRef mmaShape = mmaLayout.getInstrShape(); + // 4xm8n8 matches exactly the size of 1 warp of wgmma layout for 16bit type + // and has a shape of 16x16. + int instrN = mmaShape[1] * warpsPerCTA[1]; + int instrM = mmaShape[0] * warpsPerCTA[0]; + std::array numRep = {ceil((int)origRepShape[0], instrM), + ceil((int)origRepShape[1], instrN)}; + + Value thread = getThreadId(rewriter, loc); + Value warp = udiv(thread, i32_val(32)); + Value lane = urem(thread, i32_val(32)); + + SmallVector multiDimWarpId = + delinearize(rewriter, loc, warp, warpsPerCTA); + + auto inVals = getTypeConverter()->unpackLLElements(loc, llvmSrc, rewriter); + // Compute the relative offset for each lane. + Value stMatrixLaneOffset = + computeStMatrixAddr(lane, paddedRepShape[1], loc, rewriter); + multiDimWarpId[0] = mul(multiDimWarpId[0], i32_val(mmaShape[0])); + multiDimWarpId[1] = mul(multiDimWarpId[1], i32_val(mmaShape[1])); + SmallVector multiDimOffsetWrapped = + getWrappedMultiDimOffset(rewriter, loc, multiDimWarpId, origRepShape, + shapePerCTATile, shapePerCTA); + Value relativeOffset = + linearize(rewriter, loc, multiDimOffsetWrapped, paddedRepShape, order); + relativeOffset = add(relativeOffset, stMatrixLaneOffset); + int indexOffset = 0; + int m8n8x4Stride = 16; + int numNChunk = mmaShape[1] / m8n8x4Stride; + for (int m = 0; m < numRep[0]; m++) { + for (int n = 0; n < numRep[1]; n++) { + for (int k = 0; k < numNChunk; k++) { + Value addr = + add(relativeOffset, i32_val(k * m8n8x4Stride + n * instrN + + m * instrM * paddedRepShape[1])); + stMatrixm8n8x4(addr, inVals, indexOffset, smemBase, + tensorTy.getElementType(), loc, rewriter); + indexOffset += 8; + } + } + } + } + + bool isStMatrixCompatible(RankedTensorType tensorTy) const { + auto mmaLayout = tensorTy.getEncoding().dyn_cast(); + if (!mmaLayout || !mmaLayout.isHopper()) + return false; + if (tensorTy.getElementType().getIntOrFloatBitWidth() != 16) + return false; + return true; + } + // blocked -> shared. // Swizzling in shared memory to avoid bank conflict. Normally used for // A/B operands of dots. 
@@ -774,97 +878,11 @@ struct ConvertLayoutOpConversion int32_t elemSize = elemTy.getIntOrFloatBitWidth(); auto mmaLayout = srcLayout.dyn_cast(); unsigned numElems = triton::gpu::getTotalElemsPerThread(srcTy); - if (mmaLayout && mmaLayout.isHopper() && elemSize == 16 && - inOrd == outOrd && numElems >= 16) { - auto inVals = - getTypeConverter()->unpackLLElements(loc, adaptor.getSrc(), rewriter); - - auto srcShapePerCTA = getShapePerCTA(mmaLayout, srcShape); - auto instrShape = mmaLayout.getInstrShape(); - auto warpsPerCTA = mmaLayout.getWarpsPerCTA(); - uint32_t repM = - ceil(srcShapePerCTA[0], instrShape[0] * warpsPerCTA[0]); - uint32_t numElemsPerRep = numElems / repM; - // rowStride in bytes - uint32_t rowStrideInBytes = dstShapePerCTA[outOrd[0]] * 2; - uint32_t swizzlingByteWidth = rowStrideInBytes; - if (swizzlingByteWidth > 128) - swizzlingByteWidth = 128; - - unsigned numElemsPerSwizzlingRow = swizzlingByteWidth * 8 / elemSize; - unsigned leadingDimOffset = - numElemsPerSwizzlingRow * srcShapePerCTA[outOrd[1]]; - - auto ptrSharedTy = LLVM::LLVMPointerType::get(rewriter.getContext(), 3); - - uint32_t rowsPerRep = getShapePerCTATile(mmaLayout)[0]; - - Value threadId = getThreadId(rewriter, loc); - Value warpId = udiv(threadId, i32_val(32)); - Value warpId0 = urem(urem(warpId, i32_val(warpsPerCTA[0])), - i32_val(srcShape[0] / instrShape[0])); - - unsigned inVec = - inOrd == outOrd ? triton::gpu::getContigPerThread(mmaLayout)[inOrd[0]] - : 1; - unsigned outVec = dstSharedLayout.getVec(); - unsigned minVec = std::min(outVec, inVec); - assert(minVec == 2); - auto wordTy = vec_ty(elemTy, minVec); - - for (int rep = 0; rep < repM; ++rep) { - Value rowOfWarp = add(mul(warpId0, i32_val(instrShape[0])), - i32_val(rep * rowsPerRep)); - uint32_t elemIdxOffset = rep * numElemsPerRep; - - for (unsigned idx = 0; idx < numElemsPerRep; idx += 8) { - uint32_t elemIdx = elemIdxOffset + idx; - - Value offset = rewriter.create( - loc, i32_ty, threadId, rowOfWarp, i32_val(idx), leadingDimOffset, - numElemsPerSwizzlingRow, true); - - Value addr = gep(elemPtrTy, getTypeConverter()->convertType(elemTy), - smemBase, offset); - - Value words[4]; - for (unsigned i = 0; i < 8; ++i) { - if (i % minVec == 0) - words[i / 2] = undef(wordTy); - words[i / 2] = insert_element( - wordTy, words[i / 2], inVals[elemIdx + i], i32_val(i % minVec)); - } - - rewriter.create( - loc, bitcast(addr, ptrSharedTy), - ValueRange{bitcast(words[0], i32_ty), bitcast(words[1], i32_ty), - bitcast(words[2], i32_ty), bitcast(words[3], i32_ty)}); - } - } - // TODO[shuhaoj]: change hard code style of numThreads. Hide async_agent - // attr. Better way to determine barId (number of agents are limited). 
- if (auto optionalAgentId = getWSAgentId(op)) { - int agentId = *optionalAgentId, roleId = 0; - if (auto optionalRoleId = getWSRoleId(op)) - roleId = *optionalRoleId; - int barId = agentId + roleId + nameBarrierIdBegin; - assert(barId < nameBarrierIdEnd); - auto bar = rewriter.create( - loc, i32_ty, rewriter.getI32IntegerAttr(barId)); - auto kNumThreads = rewriter.create( - loc, i32_ty, rewriter.getI32IntegerAttr(128)); - rewriter.create(loc, bar, - kNumThreads); - } else { - barrier(); - } - } else { - auto dstStrides = - getStridesFromShapeAndOrder(dstShapePerCTA, outOrd, loc, rewriter); - auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcTy, false); - storeDistributedToShared(src, adaptor.getSrc(), dstStrides, srcIndices, - dst, smemBase, elemTy, loc, rewriter); - } + auto dstStrides = + getStridesFromShapeAndOrder(dstShapePerCTA, outOrd, loc, rewriter); + auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcTy, false); + storeDistributedToShared(src, adaptor.getSrc(), dstStrides, srcIndices, dst, + smemBase, elemTy, loc, rewriter); auto smemObj = SharedMemoryObject(smemBase, elemTy, dstShapePerCTA, outOrd, loc, rewriter); auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter); diff --git a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp index e9e17c2d4b..2e6ddca9d5 100644 --- a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -2286,6 +2286,117 @@ struct AbsFOpConversion } }; +struct ClampFOpConversion + : ElementwiseOpConversionBase { + using Base = + ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + explicit ClampFOpConversion(TritonGPUToLLVMTypeConverter &typeConverter, + ModuleAxisInfoAnalysis &axisAnalysisPass, + int computeCapability, Target target, + PatternBenefit benefit = 1) + : ElementwiseOpConversionBase(typeConverter, axisAnalysisPass, target, + benefit), + computeCapability(computeCapability) {} + + SmallVector createDestOps(mlir::triton::ClampFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + bool xorsignAbsAvailable = (computeCapability >= 90); + // Pattern matching the sequence of clamp(x, -limit, limit) to generate more + // efficient PTX code. + // NOTE: This pattern matching is not general enough, but it is sufficient. + // We detect only two cases here: + // 1. where the "-limit" is computed as 0 - limit: + // %cst = arith.constant dense<0.000000e+00> + // %8 = tt.load %7, %2 + // %11 = arith.subf %cst, %8 + // %12 = tt.clamp %5, %11, %8 + // 2. where "-limit" and "limit" are constants. 
+ // %cst_6 = arith.constant dense<-6.0000e+00> + // %cst_7 = arith.constant dense<6.0000e+00> + // %160 = tt.clamp %158, %cst_6, %cst_7 + bool clipPatternFound = false; + + auto getSplatInitializer = [](Value v) -> std::optional<double> { + if (auto constOp = v.getDefiningOp<arith::ConstantOp>()) { + if (auto attr = constOp.getValueAttr() + .dyn_cast<DenseElementsAttr>()) { + if (attr.isSplat()) { + return attr.getSplatValue<APFloat>().convertToDouble(); + } + } + } + return std::nullopt; + }; + + if (xorsignAbsAvailable) { + if (auto subOp = op.getOperand(1).getDefiningOp<arith::SubFOp>()) { + if (subOp.getOperand(1) == op.getOperand(2)) { + auto initializer = getSplatInitializer(subOp.getOperand(0)); + if (initializer.has_value() && initializer.value() == 0.0) { + clipPatternFound = true; + } + } + } else { + auto initializer1 = getSplatInitializer(op.getOperand(1)); + auto initializer2 = getSplatInitializer(op.getOperand(2)); + if (initializer1.has_value() && initializer2.has_value() && + initializer1.value() == -initializer2.value()) { + clipPatternFound = true; + } + } + } + + assert(elemTy.isF32() || elemTy.isF16()); + + if (clipPatternFound) { + // min.xorsign.abs + PTXBuilder ptxBuilder; + bool propNan = (op.getPropagateNan() == mlir::triton::PropagateNan::ALL); + auto &minXorsign = ptxBuilder.create("min") + ->o("NaN", propNan) + .o("xorsign") + .o("abs"); + const char *outType = nullptr; + const char *inType = nullptr; + if (elemTy.isF32()) { + minXorsign.o("f32"); + outType = "=f"; + inType = "f"; + } else if (elemTy.isF16()) { + minXorsign.o("f16"); + outType = "=h"; + inType = "h"; + } + auto output = ptxBuilder.newOperand(outType); + auto inputA = ptxBuilder.newOperand(operands[0][0], inType); + auto inputB = ptxBuilder.newOperand(operands[0][2], inType); + minXorsign(output, inputA, inputB); + + return {ptxBuilder.launch(rewriter, loc, elemTy, false)}; + } + + // Clip pattern not found, use min/max. + if (op.getPropagateNan() == triton::PropagateNan::ALL) { + auto v = rewriter.create<LLVM::MaximumOp>(loc, elemTy, operands[0][0], + operands[0][1]); + return {rewriter.create<LLVM::MinimumOp>(loc, v, operands[0][2])}; + } + + assert(op.getPropagateNan() == triton::PropagateNan::NONE); + auto v = rewriter.create<LLVM::MaxNumOp>(loc, elemTy, operands[0][0], + operands[0][1]); + return {rewriter.create<LLVM::MinNumOp>(loc, v, operands[0][2])}; + } + +private: + int computeCapability; +}; + /// The lowering of index_cast becomes an integer conversion since index /// becomes an integer. If the bit width of the source and target integer /// types is the same, just erase the cast.
If the target type is wider, @@ -2443,4 +2554,6 @@ void populateElementwiseOpToLLVMPatterns( // __nv_expf for higher-precision calculation patterns.add(typeConverter, axisInfoAnalysis, target, benefit); + patterns.add(typeConverter, axisInfoAnalysis, + computeCapability, target); } diff --git a/lib/Conversion/TritonToTritonGPU/CMakeLists.txt b/lib/Conversion/TritonToTritonGPU/CMakeLists.txt index 834d10a4de..d770aeb22c 100644 --- a/lib/Conversion/TritonToTritonGPU/CMakeLists.txt +++ b/lib/Conversion/TritonToTritonGPU/CMakeLists.txt @@ -1,16 +1,9 @@ -add_mlir_conversion_library(TritonToTritonGPU +add_triton_library(TritonToTritonGPU TritonToTritonGPUPass.cpp - ADDITIONAL_HEADER_DIRS - ${PROJECT_SOURCE_DIR}/include/triton/Conversion/TritonToTritonGPU - ${PROJECT_BINARY_DIR}/include/triton/Conversion/TritonToTritonGPU - DEPENDS TritonConversionPassIncGen - LINK_COMPONENTS - Core - LINK_LIBS PUBLIC MLIRIR MLIRPass diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp index d17d110f9f..fb52f71203 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp @@ -511,6 +511,7 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter, GenericOpPattern, GenericOpPattern, TritonBroadcastPattern, GenericOpPattern, TritonCatPattern, TritonInterleaveOpPattern, + GenericOpPattern, GenericOpPattern, TritonReducePattern, GenericOpPattern, TritonScanPattern, GenericOpPattern, diff --git a/lib/Dialect/NVGPU/IR/CMakeLists.txt b/lib/Dialect/NVGPU/IR/CMakeLists.txt index 24a93ce58e..1fd118d2be 100644 --- a/lib/Dialect/NVGPU/IR/CMakeLists.txt +++ b/lib/Dialect/NVGPU/IR/CMakeLists.txt @@ -1,4 +1,4 @@ -add_mlir_dialect_library(NVGPUIR +add_triton_library(NVGPUIR Dialect.cpp DEPENDS diff --git a/lib/Dialect/Triton/IR/CMakeLists.txt b/lib/Dialect/Triton/IR/CMakeLists.txt index 6ee110718c..71165b17fc 100644 --- a/lib/Dialect/Triton/IR/CMakeLists.txt +++ b/lib/Dialect/Triton/IR/CMakeLists.txt @@ -1,4 +1,4 @@ -add_mlir_dialect_library(TritonIR +add_triton_library(TritonIR Dialect.cpp Ops.cpp Types.cpp diff --git a/lib/Dialect/Triton/Transforms/CMakeLists.txt b/lib/Dialect/Triton/Transforms/CMakeLists.txt index d06c01566c..2983987506 100644 --- a/lib/Dialect/Triton/Transforms/CMakeLists.txt +++ b/lib/Dialect/Triton/Transforms/CMakeLists.txt @@ -2,7 +2,7 @@ set(LLVM_TARGET_DEFINITIONS Combine.td) mlir_tablegen(TritonCombine.inc -gen-rewriters) add_public_tablegen_target(TritonCombineIncGen) -add_mlir_dialect_library(TritonTransforms +add_triton_library(TritonTransforms Combine.cpp ReorderBroadcast.cpp RewriteTensorPointer.cpp diff --git a/lib/Dialect/TritonGPU/IR/CMakeLists.txt b/lib/Dialect/TritonGPU/IR/CMakeLists.txt index bab4a5dec4..82cf23f052 100644 --- a/lib/Dialect/TritonGPU/IR/CMakeLists.txt +++ b/lib/Dialect/TritonGPU/IR/CMakeLists.txt @@ -1,4 +1,4 @@ -add_mlir_dialect_library(TritonGPUIR +add_triton_library(TritonGPUIR Dialect.cpp Traits.cpp Types.cpp diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index a479c8aa50..fbf9e9c770 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -313,17 +313,17 @@ SmallVector getCTASplitNum(Attribute layout) { } SmallVector getCTAOrder(Attribute layout) { - ArrayRef ref; + SmallVector res; if (auto distributedLayout = layout.dyn_cast()) { - ref = distributedLayout.getCTAOrder(); + res = distributedLayout.getCTAOrder(); } else if 
(auto mfmaLayout = layout.dyn_cast()) { return {0, 1}; } else if (auto sharedLayout = layout.dyn_cast()) { - ref = sharedLayout.getCTALayout().getCTAOrder(); + res = SmallVector(sharedLayout.getCTALayout().getCTAOrder()); } else { llvm::report_fatal_error("Unimplemented usage of getCTAOrder"); } - return SmallVector(ref.begin(), ref.end()); + return res; } SmallVector getShapePerCTA(ArrayRef CTASplitNum, diff --git a/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt b/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt index 2cd6e26725..c93feb8151 100644 --- a/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt +++ b/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt @@ -1,4 +1,4 @@ -add_mlir_dialect_library(TritonGPUTransforms +add_triton_library(TritonGPUTransforms AccelerateMatmul.cpp Coalesce.cpp DecomposeConversions.cpp diff --git a/lib/Dialect/TritonNvidiaGPU/IR/CMakeLists.txt b/lib/Dialect/TritonNvidiaGPU/IR/CMakeLists.txt index 99f2ef6b70..4369542327 100644 --- a/lib/Dialect/TritonNvidiaGPU/IR/CMakeLists.txt +++ b/lib/Dialect/TritonNvidiaGPU/IR/CMakeLists.txt @@ -1,4 +1,4 @@ -add_mlir_dialect_library(TritonNvidiaGPUIR +add_triton_library(TritonNvidiaGPUIR Dialect.cpp Ops.cpp Traits.cpp diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/CMakeLists.txt b/lib/Dialect/TritonNvidiaGPU/Transforms/CMakeLists.txt index 53674ebfc6..a147b7e996 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/CMakeLists.txt +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/CMakeLists.txt @@ -1,4 +1,4 @@ -add_mlir_dialect_library(TritonNvidiaGPUTransforms +add_triton_library(TritonNvidiaGPUTransforms MaterializeLoadStore.cpp PlanCTA.cpp WSDecomposing.cpp diff --git a/lib/Target/LLVMIR/CMakeLists.txt b/lib/Target/LLVMIR/CMakeLists.txt index 510cfab9c8..f2f9adf8f4 100644 --- a/lib/Target/LLVMIR/CMakeLists.txt +++ b/lib/Target/LLVMIR/CMakeLists.txt @@ -1,10 +1,7 @@ -add_mlir_translation_library(TritonLLVMIR +add_triton_library(TritonLLVMIR LLVMDIScope.cpp LLVMIRBreakPhiStruct.cpp - LINK_COMPONENTS - Core - DEPENDS LLVMIRIncGen diff --git a/python/setup.py b/python/setup.py index a23cf2b11b..caf564c57e 100644 --- a/python/setup.py +++ b/python/setup.py @@ -15,6 +15,44 @@ from setuptools import Extension, setup from setuptools.command.build_ext import build_ext from setuptools.command.build_py import build_py +from dataclasses import dataclass + + +@dataclass +class Backend: + name: str + package_data: dict + src_dir: str + + +def _copy_backends(active): + ret = [] + root_dir = os.path.join(os.pardir, "third_party") + for backend in active: + curr_path = os.path.join(root_dir, backend) + backend_path = os.path.join(curr_path, "backend") + # initialize submodule if there is one + try: + subprocess.run(["git", "submodule", "update", "--init", f"{backend}"], check=True, + stdout=subprocess.DEVNULL, cwd=root_dir) + except subprocess.CalledProcessError: + pass + except FileNotFoundError: + pass + # check conditions + assert backend in os.listdir(root_dir), f"{backend} is requested for install but not present in {root_dir}" + assert os.path.exists(backend_path), f"{backend_path} does not exist!" 
+ for file in ["compiler.py", "driver.py"]: + assert os.path.exists(os.path.join(backend_path, file)) + # copy backend over + dst_path = os.path.join(os.path.dirname(__file__), "triton", "backends", backend) + if os.path.exists(dst_path): + shutil.rmtree(dst_path) + shutil.copytree(backend_path, dst_path) + # update + package_data = [f"{os.path.relpath(p, backend_path)}/*" for p, _, _, in os.walk(backend_path)] + ret.append(Backend(name=backend, package_data=package_data, src_dir=curr_path)) + return ret # Taken from https://github.com/pytorch/pytorch/blob/master/tools/setup_helpers/env.py @@ -106,6 +144,9 @@ def open_url(url): return urllib.request.urlopen(request) +# ---- package data --- + + def get_thirdparty_packages(triton_cache_path): packages = [get_pybind11_package_info(), get_llvm_package_info()] thirdparty_cmake_args = [] @@ -135,9 +176,6 @@ def get_thirdparty_packages(triton_cache_path): return thirdparty_cmake_args -# ---- package data --- - - def download_and_copy(src_path, variable, version, url_func): if variable in os.environ: return @@ -146,7 +184,7 @@ def download_and_copy(src_path, variable, version, url_func): if arch == "x86_64": arch = "64" url = url_func(arch, version) - dst_path = os.path.join(base_dir, os.pardir, "third_party", "cuda", "backend", src_path) + dst_path = os.path.join(base_dir, os.pardir, "third_party", "nvidia", "backend", src_path) is_linux = platform.system() == "Linux" download = False if is_linux: @@ -249,20 +287,14 @@ def build_extension(self, ext): # python directories python_include_dir = sysconfig.get_path("platinclude") cmake_args = [ - "-G", - "Ninja", # Ninja is much faster than make + "-G", "Ninja", # Ninja is much faster than make "-DCMAKE_MAKE_PROGRAM=" + ninja_dir, # Pass explicit path to ninja otherwise cmake may cache a temporary path - "-DCMAKE_EXPORT_COMPILE_COMMANDS=ON", - "-DLLVM_ENABLE_WERROR=ON", - "-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir, - "-DLLVM_SPIRV_DIR=" + llvm_spirv_path, - "-DTRITON_BUILD_TUTORIALS=OFF", - "-DTRITON_BUILD_PYTHON_MODULE=ON", - "-DPython3_EXECUTABLE:FILEPATH=" + sys.executable, - "-DCMAKE_EXPORT_COMPILE_COMMANDS:BOOL=ON", - "-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON", - "-DPYTHON_INCLUDE_DIRS=" + python_include_dir, + "-DLLVM_SPIRV_DIR=" + llvm_spirv_path, "-DCMAKE_EXPORT_COMPILE_COMMANDS=ON", "-DLLVM_ENABLE_WERROR=ON", + "-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir, "-DTRITON_BUILD_TUTORIALS=OFF", + "-DTRITON_BUILD_PYTHON_MODULE=ON", "-DPython3_EXECUTABLE:FILEPATH=" + sys.executable, + "-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON", "-DPYTHON_INCLUDE_DIRS=" + python_include_dir, + "-DTRITON_CODEGEN_BACKENDS=" + ';'.join([b.name for b in backends]) ] if lit_dir is not None: cmake_args.append("-DLLVM_EXTERNAL_LIT=" + lit_dir) @@ -345,14 +377,11 @@ def build_extension(self, ext): url_func=lambda arch, version: f"https://anaconda.org/nvidia/cuda-nvdisasm/12.3.52/download/linux-{arch}/cuda-nvdisasm-{version}-0.tar.bz2", ) +backends = _copy_backends(["xpu"]) -plugins = ["xpu"] -for plugin in plugins: - src_path = os.path.join(os.pardir, "third_party", plugin, "backend") - dst_path = os.path.join(os.path.dirname(__file__), "triton", "backends", plugin) - if os.path.exists(dst_path): - shutil.rmtree(dst_path) - shutil.copytree(src_path, dst_path) +package_data = dict() +package_data["triton/tools"] = ["compile.h", "compile.c"] +package_data.update({f"triton/backends/{b.name}": b.package_data for b in backends}) setup( name=os.environ.get("TRITON_WHEEL_NAME", "triton"), @@ -371,14 +400,10 @@ def build_extension(self, ext): 
"triton/ops/blocksparse", "triton/runtime", "triton/backends", - "triton/backends/xpu", "triton/tools", - ], + ] + [f'triton/backends/{backend.name}' for backend in backends], install_requires=["filelock"], - package_data={ - "triton/tools": ["compile.h", "compile.c"], - "triton/backends/xpu": ["bin/*", "lib/*", "include/*"], - }, + package_data=package_data, include_package_data=True, ext_modules=[CMakeExtension("triton", "triton/_C/")], cmdclass={"build_ext": CMakeBuild, "build_py": CMakeBuildPy, "clean": CMakeClean}, diff --git a/python/src/ir.cc b/python/src/ir.cc index e8b5340a51..d7c2db0b3c 100644 --- a/python/src/ir.cc +++ b/python/src/ir.cc @@ -184,6 +184,10 @@ void init_triton_ir(py::module &&m) { .value("RTZ", mlir::triton::RoundingMode::RTZ) .value("RTNE", mlir::triton::RoundingMode::RTNE); + py::enum_(m, "PROPAGATE_NAN", py::module_local()) + .value("NONE", mlir::triton::PropagateNan::NONE) + .value("ALL", mlir::triton::PropagateNan::ALL); + py::class_(m, "context", py::module_local()) .def(py::init<>()); @@ -1037,6 +1041,13 @@ void init_triton_ir(py::module &&m) { mlir::Value &rhs) -> mlir::Value { return mlir::Value(self.create(lhs, rhs)); }) + .def("create_clampf", + [](TritonOpBuilder &self, mlir::Value &input, mlir::Value &min, + mlir::Value &max, + mlir::triton::PropagateNan propagateNan) -> mlir::Value { + return mlir::Value(self.create( + input, min, max, propagateNan)); + }) // AddPtr (similar to GEP) .def("create_addptr", [](TritonOpBuilder &self, mlir::Value &ptr, diff --git a/python/src/main.cc b/python/src/main.cc index 36f5b2679f..5ad4be7d55 100644 --- a/python/src/main.cc +++ b/python/src/main.cc @@ -1,12 +1,43 @@ #include namespace py = pybind11; +#define FOR_EACH_1(MACRO, X) MACRO(X) +#define FOR_EACH_2(MACRO, X, ...) MACRO(X) FOR_EACH_1(MACRO, __VA_ARGS__) +#define FOR_EACH_3(MACRO, X, ...) MACRO(X) FOR_EACH_2(MACRO, __VA_ARGS__) +#define FOR_EACH_4(MACRO, X, ...) MACRO(X) FOR_EACH_3(MACRO, __VA_ARGS__) + +#define FOR_EACH_NARG(...) FOR_EACH_NARG_(__VA_ARGS__, FOR_EACH_RSEQ_N()) +#define FOR_EACH_NARG_(...) FOR_EACH_ARG_N(__VA_ARGS__) +#define FOR_EACH_ARG_N(_1, _2, _3, _4, N, ...) N +#define FOR_EACH_RSEQ_N() 4, 3, 2, 1, 0 + +#define CONCATENATE(x, y) CONCATENATE1(x, y) +#define CONCATENATE1(x, y) x##y + +#define FOR_EACH(MACRO, ...) \ + CONCATENATE(FOR_EACH_, FOR_EACH_NARG_HELPER(__VA_ARGS__))(MACRO, __VA_ARGS__) +#define FOR_EACH_NARG_HELPER(...) FOR_EACH_NARG(__VA_ARGS__) + +// New macro to remove parentheses +#define REMOVE_PARENS(...) __VA_ARGS__ + +// Intermediate macro to ensure correct expansion +#define FOR_EACH_P_INTERMEDIATE(MACRO, ...) 
FOR_EACH(MACRO, __VA_ARGS__) + +// Modified FOR_EACH to handle parentheses +#define FOR_EACH_P(MACRO, ARGS_WITH_PARENS) \ + FOR_EACH_P_INTERMEDIATE(MACRO, REMOVE_PARENS ARGS_WITH_PARENS) + +#define DECLARE_BACKEND(name) void init_triton_##name(pybind11::module &&m); + +#define INIT_BACKEND(name) init_triton_##name(m.def_submodule(#name)); + void init_triton_env_vars(pybind11::module &m); void init_triton_ir(pybind11::module &&m); void init_triton_llvm(pybind11::module &&m); void init_triton_interpreter(pybind11::module &&m); void init_triton_passes(pybind11::module &&m); -void init_triton_xpu(pybind11::module &&m); +FOR_EACH_P(DECLARE_BACKEND, TRITON_BACKENDS_TUPLE) PYBIND11_MODULE(libtriton, m) { m.doc() = "Python bindings to the C++ Triton API"; @@ -15,5 +46,5 @@ PYBIND11_MODULE(libtriton, m) { init_triton_passes(m.def_submodule("passes")); init_triton_interpreter(m.def_submodule("interpreter")); init_triton_llvm(m.def_submodule("llvm")); - init_triton_xpu(m.def_submodule("xpu")); + FOR_EACH_P(INIT_BACKEND, TRITON_BACKENDS_TUPLE) } diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index ce6252ac97..c96e51f6ac 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -821,7 +821,7 @@ def where_scalar_condition(a_ptr, out_ptr, BLOCK_SIZE: tl.constexpr): # TODO: Tests with unsigned integers failed at compilation stage. -@pytest.mark.parametrize("dtype", int_dtypes + float_dtypes + ["bfloat16"]) +@pytest.mark.parametrize("dtype", int_dtypes + uint_dtypes + float_dtypes + ["bfloat16"]) @pytest.mark.parametrize("op", ["maximum", "minimum"]) def test_maximum_minium(dtype, op, device): expr = f'tl.{op}(x, y)' @@ -1566,56 +1566,6 @@ def deserialize_fp8(np_data, in_dtype): return np_data -@pytest.mark.parametrize("in_dtype", [tl.float8e4b15, tl.float8e4b15x4, tl.float8e4nv, tl.float8e5]) -@pytest.mark.parametrize("out_dtype", [torch.float16, torch.float32]) -def test_fp8_fpN_roundtrip(in_dtype, out_dtype, device): - """ - For all possible float8 values (ref_fp8 = range(0, 256)), test that: - - conversion tri_fp16 = convert(input=ref_fp8, out=out_dtype) matches the reference - - conversion tri_fp8 = convert(input=tri_fp16, out=out_dtype) matches the original - this is only possible if both conversions are correct - """ - check_type_supported(in_dtype, device) - check_type_supported(out_dtype, device) - if is_hip(): - pytest.skip('test_fp8_fpN_roundtrip not supported on HIP.') - - @triton.jit - def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): - offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask = offsets < n_elements - input = tl.load(input_ptr + offsets, mask=mask) - output = input - tl.store(output_ptr + offsets, output, mask=mask) - - # initialize array containing all possible f8 values except NaN - ref_fp8 = np.array(range(-128, 128), dtype=np.int8) - exp_mask = 0b01111111 ^ ((1 << in_dtype.fp_mantissa_width) - 1) - is_nan = (ref_fp8 & 0b01111100) == 128 - 2**in_dtype.fp_mantissa_width - is_subnormal = np.logical_or((ref_fp8 & exp_mask) == 0, (ref_fp8 & exp_mask) == exp_mask) - tri_fp8 = torch.from_numpy(serialize_fp8(ref_fp8, in_dtype)).xpu() - # check that non-subnormal fp8 are correctly converted to fp16 - tri_fp16 = torch.empty(256, dtype=out_dtype, device=device) - copy_kernel[(1, )](triton.reinterpret(tri_fp8, in_dtype), tri_fp16, tri_fp16.shape[0], BLOCK_SIZE=1024) - ref_fp8 = torch.from_numpy(ref_fp8).xpu() - ref_fp16 = 
convert_float_to_float32(ref_fp8, in_dtype) - assert torch.all(tri_fp16[~is_subnormal] == ref_fp16[~is_subnormal]) - # check that values are properly converted back to float8 - ref_fp8 = torch.empty_like(tri_fp16, dtype=torch.int8, device=device) - copy_kernel[(1, )](tri_fp16, triton.reinterpret(ref_fp8, in_dtype), tri_fp16.shape[0], BLOCK_SIZE=1024) - if in_dtype == tl.float8e4b15: - assert torch.all(tri_fp8[:127] == ref_fp8[:127]) - assert torch.all(tri_fp8[128:255] == ref_fp8[128:255]) - if not is_xpu(device): - assert ref_fp8[126] == ref_fp8[127] # -1.875 saturates to -1.75 - assert ref_fp8[254] == ref_fp8[255] # 1.875 saturates to 1.75 - else: - import warnings - warnings.warn("Assertions above fails on XPU", RuntimeWarning) - else: - assert torch.all(tri_fp8[~is_subnormal] == ref_fp8[~is_subnormal]) - - # --------------- # test reduce # --------------- @@ -4428,32 +4378,107 @@ def mul_add(data): # ----------------------- -@pytest.mark.parametrize("propagate_nan", ['tl.PropagateNan.NONE', 'tl.PropagateNan.ALL']) -@pytest.mark.parametrize("func", ['tl.minimum', 'tl.maximum']) -def test_propagate_nan(propagate_nan, func, device): - if is_xpu(device) and propagate_nan == 'tl.PropagateNan.ALL': +@pytest.mark.parametrize("dtype", ['float16', 'float32']) +@pytest.mark.parametrize("propagate_nan", ['NONE', 'ALL']) +@pytest.mark.parametrize("func", ['minimum', 'maximum', 'clamp']) +def test_propagate_nan(dtype, propagate_nan, func, device): + if is_xpu(device) and propagate_nan == 'ALL': pytest.skip("FIXME: Incorrect result on XPU") @triton.jit - def kernel(A, B, C): - tl.store(C, FUNC(tl.load(A), tl.load(B), propagate_nan=PROPAGATE_NAN)) - - kernel = patch_kernel(kernel, {'FUNC': func, 'PROPAGATE_NAN': propagate_nan}) + def kernel(A, B, C, propagate_nan: tl.constexpr, func: tl.constexpr): + if func == 'clamp': + tl.store( + C, + getattr(tl, func)(tl.load(A), -tl.load(B), tl.load(B), + propagate_nan=getattr(tl.PropagateNan, propagate_nan))) + else: + tl.store(C, + getattr(tl, func)(tl.load(A), tl.load(B), propagate_nan=getattr(tl.PropagateNan, propagate_nan))) for mode in ['A', 'B', 'both']: - A = torch.randn((1, ), device=device, dtype=torch.float32) + A = torch.randn((1, ), device=device, dtype=getattr(torch, dtype)) if mode == 'A' or mode == 'both': A[0] = torch.nan - B = torch.randn((1, ), device=device, dtype=torch.float32) + B = torch.randn((1, ), device=device, dtype=getattr(torch, dtype)) if mode == 'B' or mode == 'both': B[0] = torch.nan - C = torch.zeros_like(A, device=device, dtype=torch.float32) - kernel[(1, )](A, B, C) + C = torch.zeros_like(A, device=device, dtype=getattr(torch, dtype)) + kernel[(1, )](A, B, C, propagate_nan, func) - if mode == 'both' or eval(propagate_nan) == tl.PropagateNan.ALL: + if mode == 'both' or propagate_nan == 'ALL': assert torch.isnan(C[0]) else: assert not torch.isnan(C[0]) +# ----------------------- +# test clamp +# ----------------------- + + +@pytest.mark.parametrize("dtype", ['float16', 'float32']) +def test_clamp(dtype, device): + + @triton.jit + def kernel(x_ptr, min_ptr, max_ptr, out_ptr, ref_ptr, N, BLOCK_SIZE: tl.constexpr): + + off = tl.arange(0, BLOCK_SIZE) + mask = off < N + x = tl.load(x_ptr + off, mask=mask) + min = tl.load(min_ptr + off, mask=mask) + max = tl.load(max_ptr + off, mask=mask) + out = out_ptr + off + ref = ref_ptr + off + + tl.store(out, tl.clamp(x, min, max), mask=mask) + ref_val = tl.minimum(tl.maximum(x, min), max) + tl.store(ref, ref_val, mask=mask) + + size = 128 + + x = torch.randn((size, ), device=device, 
dtype=getattr(torch, dtype)) + a = torch.randn((size, ), device=device, dtype=getattr(torch, dtype)) + b = torch.randn((size, ), device=device, dtype=getattr(torch, dtype)) + min = torch.min(a, b) + max = torch.max(a, b) + out = torch.zeros_like(x, device=device, dtype=getattr(torch, dtype)) + ref = torch.zeros_like(x, device=device, dtype=getattr(torch, dtype)) + + kernel[(size, )](x, min, max, out, ref, x.numel(), BLOCK_SIZE=size) + + torch.testing.assert_close(out, ref) + + +# Test for symmetric clamp(x, -limit, limit), as it may go through optimized +# codegen in the backends +@pytest.mark.parametrize("dtype", ['float16', 'float32']) +def test_clamp_symmetric(dtype, device): + + @triton.jit + def kernel(x_ptr, limit_ptr, out_ptr, ref_ptr, N, BLOCK_SIZE: tl.constexpr): + + off = tl.arange(0, BLOCK_SIZE) + mask = off < N + x = tl.load(x_ptr + off, mask=mask) + limit = tl.load(limit_ptr + off, mask=mask) + out = out_ptr + off + ref = ref_ptr + off + + tl.store(out, tl.clamp(x, -limit, limit), mask=mask) + ref_val = tl.minimum(tl.maximum(x, -limit), limit) + tl.store(ref, ref_val, mask=mask) + + size = 128 + + x = torch.randn((size, ), device=device, dtype=getattr(torch, dtype)) + limit = torch.randn((size, ), device=device, dtype=getattr(torch, dtype)).abs() + out = torch.zeros_like(x, device=device, dtype=getattr(torch, dtype)) + ref = torch.zeros_like(x, device=device, dtype=getattr(torch, dtype)) + + kernel[(size, )](x, limit, out, ref, x.numel(), BLOCK_SIZE=size) + + torch.testing.assert_close(out, ref) + + # ----------------------- # test sort # ----------------------- diff --git a/python/test/unit/language/test_line_info.py b/python/test/unit/language/test_line_info.py index 5992fe44e7..1d35d1afdc 100644 --- a/python/test/unit/language/test_line_info.py +++ b/python/test/unit/language/test_line_info.py @@ -166,9 +166,9 @@ def test_line_info(func: str): elif func == "multi_files": assert (check_file_lines(file_lines, "test_line_info.py", 48)) assert (check_file_lines(file_lines, "test_line_info.py", 50)) - assert (check_file_lines(file_lines, "standard.py", 35)) + assert (check_file_lines(file_lines, "standard.py", 33)) + assert (check_file_lines(file_lines, "standard.py", 34)) assert (check_file_lines(file_lines, "standard.py", 36)) - assert (check_file_lines(file_lines, "standard.py", 38)) elif func == "autotune": assert (check_file_lines(file_lines, "test_line_info.py", 61)) assert (check_file_lines(file_lines, "test_line_info.py", 62)) diff --git a/python/test/unit/tools/test_aot.py b/python/test/unit/tools/test_aot.py index 5ec600e92e..4565696575 100644 --- a/python/test/unit/tools/test_aot.py +++ b/python/test/unit/tools/test_aot.py @@ -7,7 +7,7 @@ import numpy as np import triton -from triton.backends.cuda.driver import include_dir, library_dir +from triton.backends.nvidia.driver import include_dir, library_dir kernel_utils_src = """ import triton diff --git a/python/triton/backends/__init__.py b/python/triton/backends/__init__.py index cabcec71b8..fbf65d9e90 100644 --- a/python/triton/backends/__init__.py +++ b/python/triton/backends/__init__.py @@ -1,5 +1,5 @@ import os -import importlib +import importlib.util import inspect from dataclasses import dataclass from .driver import DriverBase diff --git a/python/triton/language/__init__.py b/python/triton/language/__init__.py index b5e44ea03a..c4b5bf8864 100644 --- a/python/triton/language/__init__.py +++ b/python/triton/language/__init__.py @@ -4,7 +4,6 @@ from . import math from . 
import extra from .standard import ( - PropagateNan, argmax, argmin, cdiv, @@ -25,6 +24,7 @@ zeros_like, ) from .core import ( + PropagateNan, TRITON_MAX_TENSOR_NUMEL, abs, advance, @@ -45,6 +45,7 @@ cat, constexpr, cos, + clamp, debug_barrier, device_assert, device_print, @@ -137,6 +138,7 @@ "builtin", "cat", "cdiv", + "clamp", "constexpr", "cos", "cumprod", diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 98f596c499..dfc057333c 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -14,6 +14,8 @@ TRITON_BUILTIN = "__triton_builtin__" +PropagateNan = ir.PROPAGATE_NAN + def builtin(fn: T) -> T: """Mark a function as a builtin.""" @@ -1359,6 +1361,36 @@ def fdiv(x, y, ieee_rounding=False, _builder=None): return semantic.fdiv(x, y, ieee_rounding, _builder) +@builtin +def clamp(x, min, max, propagate_nan: constexpr = PropagateNan.NONE, _builder=None): + """ + Clamps the input tensor :code:`x` within the range [min, max]. + Behavior when :code:`min` > :code:`max` is undefined. + + :param x: the input tensor + :type x: Block + :param min: the lower bound for clamping + :type min: Block + :param max: the upper bound for clamping + :type max: Block + :param propagate_nan: whether to propagate NaN values. Applies only to the :code:`x` tensor. + If either :code:`min` or :code:`max` is NaN, the result is undefined. + :type propagate_nan: tl.PropagateNan + + .. seealso:: :class:`tl.PropagateNan` + """ + x = _to_tensor(x, _builder) + min = _to_tensor(min, _builder) + max = _to_tensor(max, _builder) + x = _promote_bfloat16_to_float32(x, _builder=_builder) + min = _promote_bfloat16_to_float32(min, _builder=_builder) + max = _promote_bfloat16_to_float32(max, _builder=_builder) + + propagate_nan = _constexpr_to_value(propagate_nan) + + return semantic.clamp(x, min, max, propagate_nan, _builder) + + def _add_math_1arg_docstr(name: str) -> Callable[[T], T]: def _decorator(func: T) -> T: diff --git a/python/triton/language/math.py b/python/triton/language/math.py index 7b3d3366e5..3c12be1683 100644 --- a/python/triton/language/math.py +++ b/python/triton/language/math.py @@ -1,5 +1,3 @@ -from enum import IntEnum - from . import core @@ -8,16 +6,6 @@ def is_spirv(): return torch.xpu.is_available() -class PropagateNan(IntEnum): - """ - PropagateNan is an enum class that specifies how NaNs are handled in min/max operations. - PropagateNan.ALL means that if either input is NaN, the result is NaN. PropagateNan.NONE - means that if either input is NaN, the result is the non-NaN input. This is the default. 
- """ - ALL = 0xFFFFFFFF - NONE = 0x00000000 - - @core.extern def clz(arg0, _builder=None): if is_spirv(): @@ -63,7 +51,7 @@ def byte_perm(arg0, arg1, arg2, _builder=None): @core.extern -def min(arg0, arg1, propagate_nan: core.constexpr = PropagateNan.NONE, _builder=None): +def min(arg0, arg1, propagate_nan: core.constexpr = core.PropagateNan.NONE, _builder=None): arg0 = core._to_tensor(arg0, _builder) arg1 = core._to_tensor(arg1, _builder) arg0 = core._promote_bfloat16_to_float32(arg0, _builder=_builder) @@ -71,22 +59,22 @@ def min(arg0, arg1, propagate_nan: core.constexpr = PropagateNan.NONE, _builder= arg0, arg1 = core.binary_op_type_legalization(arg0, arg1, _builder) dtype = arg0.dtype if dtype.is_floating(): - if propagate_nan == core.constexpr(PropagateNan.ALL): + if propagate_nan == core.constexpr(core.PropagateNan.ALL): return core.tensor(_builder.create_minimumf(arg0.handle, arg1.handle), arg0.type) - elif propagate_nan == core.constexpr(PropagateNan.NONE): + elif propagate_nan == core.constexpr(core.PropagateNan.NONE): return core.tensor(_builder.create_minnumf(arg0.handle, arg1.handle), arg0.type) else: assert False, f"Unexpected propagate_nan {propagate_nan}" elif dtype.is_int_signed(): return core.tensor(_builder.create_minsi(arg0.handle, arg1.handle), arg0.type) elif dtype.is_int_unsigned(): - return core.tensor(_builder.create_minui(arg0.handle, arg1.handle), arg0.dtype) + return core.tensor(_builder.create_minui(arg0.handle, arg1.handle), arg0.type) else: assert False, f"Unexpected dtype {dtype}" @core.extern -def max(arg0, arg1, propagate_nan: core.constexpr = PropagateNan.NONE, _builder=None): +def max(arg0, arg1, propagate_nan: core.constexpr = core.PropagateNan.NONE, _builder=None): arg0 = core._to_tensor(arg0, _builder) arg1 = core._to_tensor(arg1, _builder) arg0 = core._promote_bfloat16_to_float32(arg0, _builder=_builder) @@ -94,16 +82,16 @@ def max(arg0, arg1, propagate_nan: core.constexpr = PropagateNan.NONE, _builder= arg0, arg1 = core.binary_op_type_legalization(arg0, arg1, _builder) dtype = arg0.dtype if dtype.is_floating(): - if propagate_nan == core.constexpr(PropagateNan.ALL): + if propagate_nan == core.constexpr(core.PropagateNan.ALL): return core.tensor(_builder.create_maximumf(arg0.handle, arg1.handle), arg0.type) - elif propagate_nan == core.constexpr(PropagateNan.NONE): + elif propagate_nan == core.constexpr(core.PropagateNan.NONE): return core.tensor(_builder.create_maxnumf(arg0.handle, arg1.handle), arg0.type) else: assert False, f"Unexpected propagate_nan {propagate_nan}" elif dtype.is_int_signed(): return core.tensor(_builder.create_maxsi(arg0.handle, arg1.handle), arg0.type) elif dtype.is_int_unsigned(): - return core.tensor(_builder.create_maxui(arg0.handle, arg1.handle), arg0.dtype) + return core.tensor(_builder.create_maxui(arg0.handle, arg1.handle), arg0.type) else: assert False, f"Unexpected dtype {dtype}" diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index a0cb792484..0fe2bec6df 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -249,6 +249,25 @@ def mod(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: assert False +############## +# other arithmetic ops +############## + + +def clamp(x: tl.tensor, min: tl.tensor, max: tl.tensor, propagate_nan: tl.PropagateNan, builder: ir.builder): + min, max = binary_op_type_checking_impl(min, max, builder) + x, min = binary_op_type_checking_impl(x, min, builder) + x, max = binary_op_type_checking_impl(x, 
max, builder) + + dtype = x.dtype + if dtype.is_floating(): + if propagate_nan == tl.PropagateNan.ALL: + return tl.tensor(builder.create_clampf(x.handle, min.handle, max.handle, propagate_nan), x.type) + return tl.tensor(builder.create_clampf(x.handle, min.handle, max.handle, propagate_nan), x.type) + else: + assert False, f"Unexpected dtype {dtype}. Only floating point clamp is supported" + + ############## # bitwise ops ############## diff --git a/python/triton/language/standard.py b/python/triton/language/standard.py index 7d8b3af67b..4ffdc13318 100644 --- a/python/triton/language/standard.py +++ b/python/triton/language/standard.py @@ -7,8 +7,6 @@ # Standard library # ----------------------- -from .math import PropagateNan - @jit def cdiv(x, div): @@ -101,7 +99,7 @@ def zeros_like(input): @jit -def minimum(x, y, propagate_nan: core.constexpr = PropagateNan.NONE): +def minimum(x, y, propagate_nan: core.constexpr = core.PropagateNan.NONE): """ Computes the element-wise minimum of :code:`x` and :code:`y`. @@ -118,7 +116,7 @@ def minimum(x, y, propagate_nan: core.constexpr = PropagateNan.NONE): @jit -def maximum(x, y, propagate_nan: core.constexpr = PropagateNan.NONE): +def maximum(x, y, propagate_nan: core.constexpr = core.PropagateNan.NONE): """ Computes the element-wise maximum of :code:`x` and :code:`y`. diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index d226caa050..9e47051cc1 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -11,7 +11,6 @@ import torch import intel_extension_for_pytorch as ipex -from .interpreter import InterpretedFunction from ..runtime.driver import driver @@ -594,6 +593,7 @@ def jit( def decorator(fn: T) -> JITFunction[T]: assert callable(fn) if os.getenv("TRITON_INTERPRET", "0") == "1": + from .interpreter import InterpretedFunction return InterpretedFunction(fn) else: return JITFunction( diff --git a/test/Conversion/tritongpu_to_llvm_hopper.mlir b/test/Conversion/tritongpu_to_llvm_hopper.mlir index 3f09757757..0a006b90fc 100644 --- a/test/Conversion/tritongpu_to_llvm_hopper.mlir +++ b/test/Conversion/tritongpu_to_llvm_hopper.mlir @@ -304,3 +304,32 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c tt.return } } + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +// CHECK-LABEL: clamp +module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @clamp(%x : tensor<1024xf32, #blocked>, %limit : tensor<1024xf32, #blocked>) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<1024xf32, #blocked> + %neg_limit = arith.subf %cst, %limit : tensor<1024xf32, #blocked> + + // CHECK: %{{[a-zA-Z0-9]+}} = llvm.inline_asm asm_dialect = att operand_attrs = [] "min.xorsign.abs.f32 $0, $1, $2;", "=f,f,f" %{{[a-zA-Z0-9]+}}, %{{[a-zA-Z0-9]+}} : (f32, f32) -> f32 + %12 = tt.clampf %x, %neg_limit, %limit {propagateNan = 0 : i32} : tensor<1024xf32, #blocked> + tt.return + } +} + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = 
[1, 0], instrShape = [16, 256, 16]}> +// CHECK-LABEL: convert_mma_to_blocked +module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func @convert_mma_to_blocked(%a: tensor<128x256xf16, #mma>) { + // CHECK-COUNT-16: nvgpu.stmatrix + // CHECK: nvvm.barrier0 + %c = triton_gpu.convert_layout %a : (tensor<128x256xf16, #mma>) -> tensor<128x256xf16, #blocked> + tt.return + } +} diff --git a/test/lib/Analysis/CMakeLists.txt b/test/lib/Analysis/CMakeLists.txt index e3c7743795..1bf9d84702 100644 --- a/test/lib/Analysis/CMakeLists.txt +++ b/test/lib/Analysis/CMakeLists.txt @@ -6,6 +6,5 @@ add_mlir_library(TritonTestAnalysis LINK_LIBS PUBLIC MLIRPass - TritonAnalysis - ${dialect_libs} + ${triton_libs} ) diff --git a/third_party/amd b/third_party/amd new file mode 160000 index 0000000000..a3c7061800 --- /dev/null +++ b/third_party/amd @@ -0,0 +1 @@ +Subproject commit a3c7061800f31db179ba34e1369725841ec8cb0d diff --git a/third_party/cuda/CMakeLists.txt b/third_party/cuda/CMakeLists.txt deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/third_party/nvidia/CMakeLists.txt b/third_party/nvidia/CMakeLists.txt new file mode 100644 index 0000000000..f6a91676b5 --- /dev/null +++ b/third_party/nvidia/CMakeLists.txt @@ -0,0 +1 @@ +add_triton_plugin(TritonNVIDIA ${CMAKE_CURRENT_SOURCE_DIR}/triton_nvidia.cc) \ No newline at end of file diff --git a/third_party/cuda/backend/__init__.py b/third_party/nvidia/backend/__init__.py similarity index 100% rename from third_party/cuda/backend/__init__.py rename to third_party/nvidia/backend/__init__.py diff --git a/third_party/cuda/backend/compiler.py b/third_party/nvidia/backend/compiler.py similarity index 100% rename from third_party/cuda/backend/compiler.py rename to third_party/nvidia/backend/compiler.py diff --git a/third_party/cuda/backend/driver.c b/third_party/nvidia/backend/driver.c similarity index 100% rename from third_party/cuda/backend/driver.c rename to third_party/nvidia/backend/driver.c diff --git a/third_party/cuda/backend/driver.py b/third_party/nvidia/backend/driver.py similarity index 100% rename from third_party/cuda/backend/driver.py rename to third_party/nvidia/backend/driver.py diff --git a/third_party/cuda/backend/include/cuda.h b/third_party/nvidia/backend/include/cuda.h similarity index 100% rename from third_party/cuda/backend/include/cuda.h rename to third_party/nvidia/backend/include/cuda.h diff --git a/third_party/cuda/backend/lib/libdevice.10.bc b/third_party/nvidia/backend/lib/libdevice.10.bc similarity index 100% rename from third_party/cuda/backend/lib/libdevice.10.bc rename to third_party/nvidia/backend/lib/libdevice.10.bc diff --git a/third_party/cuda/triton_nvidia.cc b/third_party/nvidia/triton_nvidia.cc similarity index 100% rename from third_party/cuda/triton_nvidia.cc rename to third_party/nvidia/triton_nvidia.cc diff --git a/third_party/xpu/CMakeLists.txt b/third_party/xpu/CMakeLists.txt index e69de29bb2..72ab20fac6 100644 --- a/third_party/xpu/CMakeLists.txt +++ b/third_party/xpu/CMakeLists.txt @@ -0,0 +1 @@ +add_triton_plugin(TritonXPU ${CMAKE_CURRENT_SOURCE_DIR}/triton_xpu.cc) \ No newline at end of file diff --git a/unittest/Analysis/CMakeLists.txt b/unittest/Analysis/CMakeLists.txt index af11f1d807..e94696bf5f 100644 --- a/unittest/Analysis/CMakeLists.txt +++ b/unittest/Analysis/CMakeLists.txt @@ -6,4 +6,5 @@ add_triton_ut( TritonIR TritonGPUIR ${dialect_libs} + ${triton_libs} ) diff 
--git a/unittest/Conversion/TritonGPUToLLVM/CMakeLists.txt b/unittest/Conversion/TritonGPUToLLVM/CMakeLists.txt index 0ba2be07f2..592d1b7c23 100644 --- a/unittest/Conversion/TritonGPUToLLVM/CMakeLists.txt +++ b/unittest/Conversion/TritonGPUToLLVM/CMakeLists.txt @@ -7,5 +7,5 @@ add_triton_ut( add_triton_ut( NAME TestEmitIndices SRCS EmitIndicesTest.cpp DumpLayout.cpp - LIBS TritonGPUIR TritonNvidiaGPUIR ${dialect_libs} ${conversion_libs} + LIBS TritonGPUIR TritonNvidiaGPUIR ${dialect_libs} ${conversion_libs} ${triton_libs} ) diff --git a/unittest/Dialect/TritonGPU/CMakeLists.txt b/unittest/Dialect/TritonGPU/CMakeLists.txt index 3dfa69701f..28576d7fd4 100644 --- a/unittest/Dialect/TritonGPU/CMakeLists.txt +++ b/unittest/Dialect/TritonGPU/CMakeLists.txt @@ -1,5 +1,5 @@ add_triton_ut( NAME TestSwizzling SRCS SwizzleTest.cpp - LIBS TritonGPUIR TritonNvidiaGPUIR TritonTransforms ${dialect_libs} ${conversion_libs} + LIBS TritonGPUIR TritonNvidiaGPUIR TritonTransforms ${dialect_libs} ${conversion_libs} ${triton_libs} )
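Usage sketch (illustrative, not part of the patch): the snippet below shows how the tl.clamp builtin and tl.PropagateNan flag introduced in this diff are meant to be called from a Triton kernel. Only tl.clamp, tl.PropagateNan, and the propagate_nan keyword come from the patch; the kernel name, sizes, grid, and device string are hypothetical.

import torch
import triton
import triton.language as tl


@triton.jit
def clamp_kernel(x_ptr, out_ptr, n_elements, low, high, BLOCK_SIZE: tl.constexpr):
    # Each program instance clamps one BLOCK_SIZE-wide slice of the input.
    offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    # tl.clamp(x, min, max): PropagateNan.ALL keeps NaNs from x in the result;
    # PropagateNan.NONE (the default) falls back to the non-NaN operand.
    y = tl.clamp(x, low, high, propagate_nan=tl.PropagateNan.ALL)
    tl.store(out_ptr + offsets, y, mask=mask)


x = torch.randn(1024, device="xpu")  # or "cuda", depending on the installed backend
out = torch.empty_like(x)
grid = (triton.cdiv(x.numel(), 256), )
clamp_kernel[grid](x, out, x.numel(), -1.0, 1.0, BLOCK_SIZE=256)

When the bounds form a symmetric pair, clamp(x, -limit, limit), the Hopper lowering added above may emit the fused min.xorsign.abs PTX sequence instead of the generic max/min pair, as exercised by the clamp case in tritongpu_to_llvm_hopper.mlir.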