Skip to content

Commit

Permalink
Refactor Intel pass pipeline & cleanup (#1049)
Browse files Browse the repository at this point in the history
Signed-off-by: Tiotto, Ettore <ettore.tiotto@intel.com>
  • Loading branch information
etiotto authored May 6, 2024
1 parent 26f7cf0 commit 9f408ff
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 80 deletions.
91 changes: 59 additions & 32 deletions third_party/intel/backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from packaging.version import Version


@functools.lru_cache()
def _path_to_binary(binary: str):
paths = [
os.environ.get(f"TRITON_{binary.upper()}_PATH", ""),
Expand Down Expand Up @@ -46,12 +47,14 @@ class XPUOptions:
max_num_imprecise_acc_default: int = 0 # `max_num_imprecise_acc` only applies to fp8 -> fp32 dot on sm_90 for cuda
extern_libs: dict = None
debug: bool = False
isBlockPtrEnabled: bool = os.environ.get("TRITON_INTEL_ENABLE_BLOCK_PTR", "0") == "1"

def __post_init__(self):
default_libdir = Path(__file__).parent / 'lib'
extern_libs = {} if self.extern_libs is None else dict(self.extern_libs)
if not extern_libs.get('libdevice', None):
extern_libs['libdevice'] = str(default_libdir / 'libsycl-spir64-unknown-unknown.bc')
extern_libs['libdevice'] = os.getenv("TRITON_LIBDEVICE_PATH",
str(default_libdir / 'libsycl-spir64-unknown-unknown.bc'))
object.__setattr__(self, 'extern_libs', tuple(extern_libs.items()))
assert self.num_warps > 0 and (self.num_warps & (self.num_warps - 1)) == 0, \
"num_warps must be a power of 2"
Expand All @@ -63,6 +66,41 @@ def hash(self):

class XPUBackend(BaseBackend):

# Experimental pass pipeline for kernels using block pointers.
class Experimental:

@staticmethod
def make_ttir(mod, metadata, opt):
    """Run the experimental (block-pointer) Triton-IR pipeline on *mod*.

    Applies the standard TTIR cleanup passes and returns the transformed
    module. *metadata* and *opt* are accepted for signature parity with
    the non-experimental pipeline but are not consulted here.
    """
    pipeline = ir.pass_manager(mod.context)
    pipeline.enable_debug()
    # Pass order matters: inline/combine first, symbol DCE last.
    for register_pass in (
            passes.common.add_inliner,
            passes.ttir.add_combine,
            passes.common.add_canonicalizer,
            passes.ttir.add_reorder_broadcast,
            passes.common.add_cse,
            passes.common.add_licm,
            passes.common.add_symbol_dce,
    ):
        register_pass(pipeline)
    pipeline.run(mod)
    return mod

@staticmethod
def make_ttgir(mod, metadata, opt, device_arch):
    """Lower TTIR in *mod* to TTGIR via the experimental warp pipeline."""
    pipeline = ir.pass_manager(mod.context)
    pipeline.enable_debug()

    intel.passes.ttir.add_convert_to_ttgpuir_warp(pipeline, opt.num_warps)
    # FIXME: Use a better way to check if prefetch instructions are supported once available.
    # Prefetch instruction is not available in older drivers.
    driver_version = Version(metadata["target"].arch['driver_version'])
    if driver_version > Version("1.3.28202"):
        intel.passes.ttgpuir.add_prefetch_block(pipeline)
    intel.passes.ttgpuir.add_distribute_to_warps(pipeline)
    intel.passes.ttgpuir.add_match_target_size(pipeline)
    # Final cleanup passes, in order.
    for register_pass in (passes.common.add_canonicalizer,
                          passes.common.add_cse,
                          passes.common.add_symbol_dce):
        register_pass(pipeline)
    pipeline.run(mod)
    return mod

@staticmethod
def supports_target(target: tuple):
return target.backend == 'xpu'
Expand All @@ -73,11 +111,11 @@ def __init__(self, target: tuple) -> None:
dirname = os.path.dirname(os.path.realpath(__file__))
mod = compile_module_from_src(Path(os.path.join(dirname, "arch_parser.c")).read_text(), "arch_utils")
self.parse_device_arch = mod.parse_device_arch
self.properties = self._parse_target(target.arch)
self.properties = self.parse_target(target.arch)
self.device_arch = self.properties["device_arch"]
self.binary_ext = "spv"

def _parse_target(self, tgt_prop) -> dict:
def parse_target(self, tgt_prop) -> dict:
dev_prop = {}
dev_prop['name'] = tgt_prop.get('name', 'xpu')
dev_prop['platform_name'] = tgt_prop.get('platform_name', None)
Expand All @@ -101,17 +139,24 @@ def parse_options(self, opts) -> Any:
def pack_metadata(self, metadata):
return metadata

def get_codegen_implementation(self):
    """Return the backend-specific codegen hooks used by the frontend.

    Currently exposes a single hook, "convert_custom_types", mapped to
    the Intel fp8 conversion helper.
    """
    # Imported lazily so the backend can be constructed without pulling
    # in the language extras.
    from triton.language.extra.intel import convert_custom_float8
    return {"convert_custom_types": convert_custom_float8}

def load_dialects(self, ctx):
    # Delegate to the Intel extension module to register/load its MLIR
    # dialects into the given context.
    intel.load_dialects(ctx)

@staticmethod
def make_ttir(mod, metadata, opt):
if XPUOptions.isBlockPtrEnabled:
return XPUBackend.Experimental.make_ttir(mod, metadata, opt)

pm = ir.pass_manager(mod.context)
pm.enable_debug()
passes.common.add_inliner(pm)
isBlockPtrEnabled = os.environ.get("TRITON_INTEL_ENABLE_BLOCK_PTR", "0") == "1"
if not isBlockPtrEnabled:
passes.ttir.add_rewrite_tensor_pointer(pm)
passes.ttir.add_rewrite_tensor_pointer(pm)
passes.ttir.add_combine(pm)
passes.common.add_canonicalizer(pm)
passes.ttir.add_reorder_broadcast(pm)
Expand All @@ -123,33 +168,20 @@ def make_ttir(mod, metadata, opt):

@staticmethod
def make_ttgir(mod, metadata, opt, device_arch):
if XPUOptions.isBlockPtrEnabled:
return XPUBackend.Experimental.make_ttgir(mod, metadata, opt, device_arch)

# TTIR -> TTGIR
pm = ir.pass_manager(mod.context)
pm.enable_debug()

isBlockPtrEnabled = os.environ.get("TRITON_INTEL_ENABLE_BLOCK_PTR", "0") == "1"
if isBlockPtrEnabled:
intel.passes.ttir.add_convert_to_ttgpuir_warp(pm, opt.num_warps)
# FIXME: Use a better way to check if prefetch instructions are supported once available.
# Prefetch instruction is not available in older drivers.
if Version(metadata["target"].arch['driver_version']) > Version("1.3.28202"):
intel.passes.ttgpuir.add_prefetch_block(pm)
intel.passes.ttgpuir.add_distribute_to_warps(pm)
intel.passes.ttgpuir.add_match_target_size(pm)
passes.common.add_canonicalizer(pm)
passes.common.add_cse(pm)
passes.common.add_symbol_dce(pm)
pm.run(mod)
return mod
passes.ttir.add_convert_to_ttgpuir(pm, f"xpu:{device_arch}", opt.num_warps, opt.threads_per_warp, opt.num_ctas)
# optimize TTGIR
passes.ttgpuir.add_coalesce(pm)
passes.ttgpuir.add_remove_layout_conversions(pm)
passes.ttgpuir.add_optimize_thread_locality(pm)

intel.passes.ttgpuir.add_accelerate_matmul(pm, device_arch)
passes.ttgpuir.add_remove_layout_conversions(pm)
if opt.optimize_epilogue:
passes.ttgpuir.add_optimize_epilogue(pm)
passes.ttgpuir.add_optimize_dot_operands(pm, True)
passes.common.add_cse(pm)
passes.ttgpuir.add_prefetch(pm)
Expand All @@ -169,11 +201,12 @@ def make_llir(src, metadata, options, device_arch):
num_warp_groups = src.get_int_attr("triton_gpu.num-warp-groups-per-cta")
if num_warp_groups is not None:
metadata["num_warps"] *= num_warp_groups
# FIXME: On the `TRITON_INTEL_ENABLE_BLOCK_PTR` path, get_threads_per_warp always return 1.
isBlockPtrEnabled = os.environ.get("TRITON_INTEL_ENABLE_BLOCK_PTR", "0") == "1"
if not isBlockPtrEnabled:

# FIXME: On the experimental path, get_threads_per_warp always return 1.
if not XPUOptions.isBlockPtrEnabled:
threads_per_warp = ir.ttgpuir.get_threads_per_warp(src)
metadata["threads_per_warp"] = threads_per_warp

mod = src
# TritonGPU -> LLVM-IR (MLIR)
pm = ir.pass_manager(mod.context)
Expand Down Expand Up @@ -222,9 +255,3 @@ def add_stages(self, stages, options):
def hash(self):
    """Return a cache key combining the spirv-dis tool version with the
    parsed device properties, so cached artifacts are invalidated when
    either changes."""
    spirv_dis = _path_to_binary("spirv-dis")[0]
    tool_version = subprocess.check_output([spirv_dis, "--version"], text=True).strip()
    return f'{tool_version}-{self.properties}'

def get_codegen_implementation(self):
from triton.language.extra.intel import convert_custom_float8
codegen_fns = {}
codegen_fns["convert_custom_types"] = convert_custom_float8
return codegen_fns
87 changes: 39 additions & 48 deletions third_party/intel/triton_xpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,62 +15,53 @@

namespace py = pybind11;

using namespace mlir::triton;
using namespace mlir::triton::gpu;

// Macros to create a pass that takes pass options.
// Bind `name` to a Python function that appends the pass produced by
// `builder`, constructed from a single option value of type `ty0`, to the
// given pass manager.
#define ADD_PASS_WRAPPER_OPT_1(name, builder, ty0)                             \
  m.def(name,                                                                  \
        [](mlir::PassManager &pm, ty0 val0) { pm.addPass(builder({val0})); })

// As ADD_PASS_WRAPPER_OPT_1, but for pass builders taking two option values.
#define ADD_PASS_WRAPPER_OPT_2(name, builder, ty0, ty1)                        \
  m.def(name, [](mlir::PassManager &pm, ty0 val0, ty1 val1) {                  \
    pm.addPass(builder({val0, val1}));                                         \
  })

void init_triton_intel_passes_ttir(py::module &&m) {
m.def("add_convert_to_ttgpuir_warp", [](mlir::PassManager &pm, int numWarps) {
pm.addPass(mlir::triton::createConvertTritonToTritonGPUWarpPass(numWarps));
});
ADD_PASS_WRAPPER_1("add_convert_to_ttgpuir_warp",
mlir::triton::createConvertTritonToTritonGPUWarpPass,
unsigned);
}

void init_triton_intel_passes_ttgpuir(py::module &&m) {
using namespace mlir::triton::gpu;

// Device arch
py::enum_<intel::DeviceArch>(m, "DEVICE_ARCH", py::module_local())
.value("UNKNOWN", intel::DeviceArch::UNKNOWN)
.value("ATS", intel::DeviceArch::ATS)
.value("PVC", intel::DeviceArch::PVC)
.export_values();

m.def("add_to_llvmir", [](mlir::PassManager &pm) {
pm.addPass(intel::createConvertTritonIntelGPUToLLVM());
});
m.def(
"add_accelerate_matmul",
[](mlir::PassManager &pm, intel::DeviceArch arch) {
pm.addPass(intel::createTritonIntelGPUAccelerateMatmul({arch}));
},
py::arg("pm"), py::arg("arch") = intel::DeviceArch::UNKNOWN);
m.def("add_decompose_unsupported_conversions", [](mlir::PassManager &pm) {
pm.addPass(intel::createIntelDecomposeUnsupportedConversions());
});
m.def("add_allocate_shared_memory", [](mlir::PassManager &pm) {
pm.addPass(intel::createIntelAllocateSharedMemory());
});
m.def(
"add_pipe_line_pass",
[](mlir::PassManager &pm, int numStages, intel::DeviceArch arch) {
pm.addPass(intel::createTritonIntelGPUPipeline({numStages, arch}));
},
py::arg("pm"), py::arg("numStages"),
py::arg("arch") = intel::DeviceArch::UNKNOWN);
m.def("add_remove_layout_conversions", [](mlir::PassManager &pm) {
pm.addPass(intel::createTritonIntelGPURemoveLayoutConversions());
});
m.def(
"add_rewrite_tensor_pointer",
[](mlir::PassManager &pm, intel::DeviceArch arch) {
pm.addPass(intel::createTritonIntelGPURewriteTensorPointer({arch}));
},
py::arg("pm"), py::arg("arch") = intel::DeviceArch::UNKNOWN);
m.def("add_prefetch_block", [](mlir::PassManager &pm) {
pm.addPass(intel::createTritonIntelGPUPrefetchBlock());
});
m.def("add_distribute_to_warps", [](mlir::PassManager &pm) {
pm.addPass(intel::createTritonIntelGPUDistributeToWarps());
});
m.def("add_match_target_size", [](mlir::PassManager &pm) {
pm.addPass(intel::createTritonIntelGPUMatchTargetSize());
});
ADD_PASS_WRAPPER_0("add_to_llvmir", intel::createConvertTritonIntelGPUToLLVM);
ADD_PASS_WRAPPER_OPT_1("add_accelerate_matmul",
intel::createTritonIntelGPUAccelerateMatmul,
intel::DeviceArch);
ADD_PASS_WRAPPER_0("add_decompose_unsupported_conversions",
intel::createIntelDecomposeUnsupportedConversions);
ADD_PASS_WRAPPER_0("add_allocate_shared_memory",
intel::createIntelAllocateSharedMemory);
ADD_PASS_WRAPPER_OPT_2("add_pipeline", intel::createTritonIntelGPUPipeline,
int, intel::DeviceArch);
ADD_PASS_WRAPPER_0("add_remove_layout_conversions",
intel::createTritonIntelGPURemoveLayoutConversions);
ADD_PASS_WRAPPER_OPT_1("add_rewrite_tensor_pointer",
intel::createTritonIntelGPURewriteTensorPointer,
intel::DeviceArch);
ADD_PASS_WRAPPER_0("add_prefetch_block",
intel::createTritonIntelGPUPrefetchBlock);
ADD_PASS_WRAPPER_0("add_distribute_to_warps",
intel::createTritonIntelGPUDistributeToWarps);
ADD_PASS_WRAPPER_0("add_match_target_size",
intel::createTritonIntelGPUMatchTargetSize);
}

void init_triton_intel(py::module &&m) {
Expand All @@ -81,8 +72,8 @@ void init_triton_intel(py::module &&m) {
// load dialects
m.def("load_dialects", [](mlir::MLIRContext &context) {
mlir::DialectRegistry registry;
registry.insert<mlir::triton::TritonGEN::TritonGENDialect,
mlir::triton::gpu::intel::TritonIntelGPUDialect>();
registry
.insert<TritonGEN::TritonGENDialect, intel::TritonIntelGPUDialect>();
context.appendDialectRegistry(registry);
context.loadAllAvailableDialects();
});
Expand Down

0 comments on commit 9f408ff

Please sign in to comment.