diff --git a/third_party/intel/backend/compiler.py b/third_party/intel/backend/compiler.py
index ebfe924075..e646c1aad1 100644
--- a/third_party/intel/backend/compiler.py
+++ b/third_party/intel/backend/compiler.py
@@ -13,6 +13,7 @@
 from packaging.version import Version
 
 
+@functools.lru_cache()
 def _path_to_binary(binary: str):
     paths = [
         os.environ.get(f"TRITON_{binary.upper()}_PATH", ""),
@@ -46,12 +47,14 @@ class XPUOptions:
     max_num_imprecise_acc_default: int = 0  # `max_num_imprecise_acc` only applies to fp8 -> fp32 dot on sm_90 for cuda
     extern_libs: dict = None
     debug: bool = False
+    isBlockPtrEnabled: bool = os.environ.get("TRITON_INTEL_ENABLE_BLOCK_PTR", "0") == "1"
 
     def __post_init__(self):
         default_libdir = Path(__file__).parent / 'lib'
         extern_libs = {} if self.extern_libs is None else dict(self.extern_libs)
         if not extern_libs.get('libdevice', None):
-            extern_libs['libdevice'] = str(default_libdir / 'libsycl-spir64-unknown-unknown.bc')
+            extern_libs['libdevice'] = os.getenv("TRITON_LIBDEVICE_PATH",
+                                                 str(default_libdir / 'libsycl-spir64-unknown-unknown.bc'))
         object.__setattr__(self, 'extern_libs', tuple(extern_libs.items()))
         assert self.num_warps > 0 and (self.num_warps & (self.num_warps - 1)) == 0, \
             "num_warps must be a power of 2"
@@ -63,6 +66,41 @@ def hash(self):
 
 
 class XPUBackend(BaseBackend):
+    # Experimental pass pipeline for kernels using block pointers.
+    class Experimental:
+
+        @staticmethod
+        def make_ttir(mod, metadata, opt):
+            pm = ir.pass_manager(mod.context)
+            pm.enable_debug()
+            passes.common.add_inliner(pm)
+            passes.ttir.add_combine(pm)
+            passes.common.add_canonicalizer(pm)
+            passes.ttir.add_reorder_broadcast(pm)
+            passes.common.add_cse(pm)
+            passes.common.add_licm(pm)
+            passes.common.add_symbol_dce(pm)
+            pm.run(mod)
+            return mod
+
+        @staticmethod
+        def make_ttgir(mod, metadata, opt, device_arch):
+            pm = ir.pass_manager(mod.context)
+            pm.enable_debug()
+
+            intel.passes.ttir.add_convert_to_ttgpuir_warp(pm, opt.num_warps)
+            # FIXME: Use a better way to check whether prefetch instructions are supported once one is available.
+            # The prefetch instruction is not available in older drivers.
+            if Version(metadata["target"].arch['driver_version']) > Version("1.3.28202"):
+                intel.passes.ttgpuir.add_prefetch_block(pm)
+            intel.passes.ttgpuir.add_distribute_to_warps(pm)
+            intel.passes.ttgpuir.add_match_target_size(pm)
+            passes.common.add_canonicalizer(pm)
+            passes.common.add_cse(pm)
+            passes.common.add_symbol_dce(pm)
+            pm.run(mod)
+            return mod
+
     @staticmethod
     def supports_target(target: tuple):
         return target.backend == 'xpu'
@@ -73,11 +111,11 @@ def __init__(self, target: tuple) -> None:
         dirname = os.path.dirname(os.path.realpath(__file__))
         mod = compile_module_from_src(Path(os.path.join(dirname, "arch_parser.c")).read_text(), "arch_utils")
         self.parse_device_arch = mod.parse_device_arch
-        self.properties = self._parse_target(target.arch)
+        self.properties = self.parse_target(target.arch)
         self.device_arch = self.properties["device_arch"]
         self.binary_ext = "spv"
 
-    def _parse_target(self, tgt_prop) -> dict:
+    def parse_target(self, tgt_prop) -> dict:
         dev_prop = {}
         dev_prop['name'] = tgt_prop.get('name', 'xpu')
         dev_prop['platform_name'] = tgt_prop.get('platform_name', None)
@@ -101,17 +139,24 @@ def parse_options(self, opts) -> Any:
     def pack_metadata(self, metadata):
         return metadata
 
+    def get_codegen_implementation(self):
+        from triton.language.extra.intel import convert_custom_float8
+        codegen_fns = {}
+        codegen_fns["convert_custom_types"] = convert_custom_float8
+        return codegen_fns
+
     def load_dialects(self, ctx):
         intel.load_dialects(ctx)
 
     @staticmethod
     def make_ttir(mod, metadata, opt):
+        if XPUOptions.isBlockPtrEnabled:
+            return XPUBackend.Experimental.make_ttir(mod, metadata, opt)
+
         pm = ir.pass_manager(mod.context)
         pm.enable_debug()
         passes.common.add_inliner(pm)
-        isBlockPtrEnabled = os.environ.get("TRITON_INTEL_ENABLE_BLOCK_PTR", "0") == "1"
-        if not isBlockPtrEnabled:
-            passes.ttir.add_rewrite_tensor_pointer(pm)
+        passes.ttir.add_rewrite_tensor_pointer(pm)
         passes.ttir.add_combine(pm)
         passes.common.add_canonicalizer(pm)
         passes.ttir.add_reorder_broadcast(pm)
@@ -123,33 +168,20 @@ def make_ttir(mod, metadata, opt):
 
     @staticmethod
     def make_ttgir(mod, metadata, opt, device_arch):
+        if XPUOptions.isBlockPtrEnabled:
+            return XPUBackend.Experimental.make_ttgir(mod, metadata, opt, device_arch)
+
         # TTIR -> TTGIR
         pm = ir.pass_manager(mod.context)
         pm.enable_debug()
-
-        isBlockPtrEnabled = os.environ.get("TRITON_INTEL_ENABLE_BLOCK_PTR", "0") == "1"
-        if isBlockPtrEnabled:
-            intel.passes.ttir.add_convert_to_ttgpuir_warp(pm, opt.num_warps)
-            # FIXME: Use a better way to check if prefetch instructions are supported once available.
-            # Prefetch instruction is not available in older drivers.
-            if Version(metadata["target"].arch['driver_version']) > Version("1.3.28202"):
-                intel.passes.ttgpuir.add_prefetch_block(pm)
-            intel.passes.ttgpuir.add_distribute_to_warps(pm)
-            intel.passes.ttgpuir.add_match_target_size(pm)
-            passes.common.add_canonicalizer(pm)
-            passes.common.add_cse(pm)
-            passes.common.add_symbol_dce(pm)
-            pm.run(mod)
-            return mod
         passes.ttir.add_convert_to_ttgpuir(pm, f"xpu:{device_arch}", opt.num_warps, opt.threads_per_warp, opt.num_ctas)
         # optimize TTGIR
         passes.ttgpuir.add_coalesce(pm)
         passes.ttgpuir.add_remove_layout_conversions(pm)
         passes.ttgpuir.add_optimize_thread_locality(pm)
+        intel.passes.ttgpuir.add_accelerate_matmul(pm, device_arch)
         passes.ttgpuir.add_remove_layout_conversions(pm)
-        if opt.optimize_epilogue:
-            passes.ttgpuir.add_optimize_epilogue(pm)
         passes.ttgpuir.add_optimize_dot_operands(pm, True)
         passes.common.add_cse(pm)
         passes.ttgpuir.add_prefetch(pm)
@@ -169,11 +201,12 @@ def make_llir(src, metadata, options, device_arch):
         num_warp_groups = src.get_int_attr("triton_gpu.num-warp-groups-per-cta")
         if num_warp_groups is not None:
             metadata["num_warps"] *= num_warp_groups
-        # FIXME: On the `TRITON_INTEL_ENABLE_BLOCK_PTR` path, get_threads_per_warp always return 1.
-        isBlockPtrEnabled = os.environ.get("TRITON_INTEL_ENABLE_BLOCK_PTR", "0") == "1"
-        if not isBlockPtrEnabled:
+
+        # FIXME: On the experimental path, get_threads_per_warp always returns 1.
+        if not XPUOptions.isBlockPtrEnabled:
             threads_per_warp = ir.ttgpuir.get_threads_per_warp(src)
             metadata["threads_per_warp"] = threads_per_warp
+
         mod = src
         # TritonGPU -> LLVM-IR (MLIR)
         pm = ir.pass_manager(mod.context)
@@ -222,9 +255,3 @@ def add_stages(self, stages, options):
     def hash(self):
         version = subprocess.check_output([_path_to_binary("spirv-dis")[0], "--version"], text=True).strip()
         return f'{version}-{self.properties}'
-
-    def get_codegen_implementation(self):
-        from triton.language.extra.intel import convert_custom_float8
-        codegen_fns = {}
-        codegen_fns["convert_custom_types"] = convert_custom_float8
-        return codegen_fns
diff --git a/third_party/intel/triton_xpu.cc b/third_party/intel/triton_xpu.cc
index 0a697129d4..d8e3153288 100644
--- a/third_party/intel/triton_xpu.cc
+++ b/third_party/intel/triton_xpu.cc
@@ -15,62 +15,53 @@
 namespace py = pybind11;
 
+using namespace mlir::triton;
+using namespace mlir::triton::gpu;
+
+// Macros to create passes that take pass options.
+#define ADD_PASS_WRAPPER_OPT_1(name, builder, ty0)                            \
+  m.def(name,                                                                 \
+        [](mlir::PassManager &pm, ty0 val0) { pm.addPass(builder({val0})); })
+
+#define ADD_PASS_WRAPPER_OPT_2(name, builder, ty0, ty1)                       \
+  m.def(name, [](mlir::PassManager &pm, ty0 val0, ty1 val1) {                 \
+    pm.addPass(builder({val0, val1}));                                        \
+  })
+
 void init_triton_intel_passes_ttir(py::module &&m) {
-  m.def("add_convert_to_ttgpuir_warp", [](mlir::PassManager &pm, int numWarps) {
-    pm.addPass(mlir::triton::createConvertTritonToTritonGPUWarpPass(numWarps));
-  });
+  ADD_PASS_WRAPPER_1("add_convert_to_ttgpuir_warp",
+                     mlir::triton::createConvertTritonToTritonGPUWarpPass,
+                     unsigned);
 }
 
 void init_triton_intel_passes_ttgpuir(py::module &&m) {
-  using namespace mlir::triton::gpu;
-
-  // Device arch
   py::enum_<intel::DeviceArch>(m, "DEVICE_ARCH", py::module_local())
       .value("UNKNOWN", intel::DeviceArch::UNKNOWN)
       .value("ATS", intel::DeviceArch::ATS)
       .value("PVC", intel::DeviceArch::PVC)
       .export_values();
 
-  m.def("add_to_llvmir", [](mlir::PassManager &pm) {
-    pm.addPass(intel::createConvertTritonIntelGPUToLLVM());
-  });
-  m.def(
-      "add_accelerate_matmul",
-      [](mlir::PassManager &pm, intel::DeviceArch arch) {
-        pm.addPass(intel::createTritonIntelGPUAccelerateMatmul({arch}));
-      },
-      py::arg("pm"), py::arg("arch") = intel::DeviceArch::UNKNOWN);
-  m.def("add_decompose_unsupported_conversions", [](mlir::PassManager &pm) {
-    pm.addPass(intel::createIntelDecomposeUnsupportedConversions());
-  });
-  m.def("add_allocate_shared_memory", [](mlir::PassManager &pm) {
-    pm.addPass(intel::createIntelAllocateSharedMemory());
-  });
-  m.def(
-      "add_pipe_line_pass",
-      [](mlir::PassManager &pm, int numStages, intel::DeviceArch arch) {
-        pm.addPass(intel::createTritonIntelGPUPipeline({numStages, arch}));
-      },
-      py::arg("pm"), py::arg("numStages"),
-      py::arg("arch") = intel::DeviceArch::UNKNOWN);
-  m.def("add_remove_layout_conversions", [](mlir::PassManager &pm) {
-    pm.addPass(intel::createTritonIntelGPURemoveLayoutConversions());
-  });
-  m.def(
-      "add_rewrite_tensor_pointer",
-      [](mlir::PassManager &pm, intel::DeviceArch arch) {
-        pm.addPass(intel::createTritonIntelGPURewriteTensorPointer({arch}));
-      },
-      py::arg("pm"), py::arg("arch") = intel::DeviceArch::UNKNOWN);
-  m.def("add_prefetch_block", [](mlir::PassManager &pm) {
-    pm.addPass(intel::createTritonIntelGPUPrefetchBlock());
-  });
-  m.def("add_distribute_to_warps", [](mlir::PassManager &pm) {
-    pm.addPass(intel::createTritonIntelGPUDistributeToWarps());
-  });
-  m.def("add_match_target_size", [](mlir::PassManager &pm) {
-    pm.addPass(intel::createTritonIntelGPUMatchTargetSize());
-  });
+  ADD_PASS_WRAPPER_0("add_to_llvmir", intel::createConvertTritonIntelGPUToLLVM);
+  ADD_PASS_WRAPPER_OPT_1("add_accelerate_matmul",
+                         intel::createTritonIntelGPUAccelerateMatmul,
+                         intel::DeviceArch);
+  ADD_PASS_WRAPPER_0("add_decompose_unsupported_conversions",
+                     intel::createIntelDecomposeUnsupportedConversions);
+  ADD_PASS_WRAPPER_0("add_allocate_shared_memory",
+                     intel::createIntelAllocateSharedMemory);
+  ADD_PASS_WRAPPER_OPT_2("add_pipeline", intel::createTritonIntelGPUPipeline,
+                         int, intel::DeviceArch);
+  ADD_PASS_WRAPPER_0("add_remove_layout_conversions",
+                     intel::createTritonIntelGPURemoveLayoutConversions);
+  ADD_PASS_WRAPPER_OPT_1("add_rewrite_tensor_pointer",
+                         intel::createTritonIntelGPURewriteTensorPointer,
+                         intel::DeviceArch);
+  ADD_PASS_WRAPPER_0("add_prefetch_block",
+                     intel::createTritonIntelGPUPrefetchBlock);
+  ADD_PASS_WRAPPER_0("add_distribute_to_warps",
+                     intel::createTritonIntelGPUDistributeToWarps);
+  ADD_PASS_WRAPPER_0("add_match_target_size",
+                     intel::createTritonIntelGPUMatchTargetSize);
 }
 
 void init_triton_intel(py::module &&m) {
@@ -81,8 +72,8 @@ void init_triton_intel(py::module &&m) {
 
   // load dialects
   m.def("load_dialects", [](mlir::MLIRContext &context) {
     mlir::DialectRegistry registry;
-    registry.insert();
+    registry
+        .insert();
     context.appendDialectRegistry(registry);
     context.loadAllAvailableDialects();
   });