Skip to content

Commit

Permalink
Support native code binary representation for XPU backend (#2148)
Browse files Browse the repository at this point in the history
Adds a new environment variable, `TRITON_XPU_GEN_NATIVE_CODE`, which is
used to enable generating native device code and storing it in the
`.spv` file instead of spirv. To avoid having to access the sycl runtime
inside the compiler, we use `ocloc` (just like the nvidia backend uses
`ptxas` to generate `cubin` from `ptx`). But, because there is no textual
representation of `spirv`, we do not store the spirv. Originally, I had
changed the file extension but decided to stick with `spv` for now while
we evaluate if/when we want to enable this functionality.

In my testing this makes very little difference in back-to-back runs
because the driver caches the native code. But this feature was
requested for Inductor AOT mode where the model is exported into a
self-contained library.

```
spirv:
compile eval time 0: 10.648161754
compile eval time 1: 0.013945419
compile eval time 2: 0.012984403
compile eval time 3: 0.012636915
compile eval time 4: 0.0126077
compile eval time 5: 0.012621725
compile eval time 6: 0.012668987
compile eval time 7: 0.012654901
compile eval time 8: 0.01264564
compile eval time 9: 0.013606956

generated native code:
compile eval time 0: 15.920665989
compile eval time 1: 0.013856625
compile eval time 2: 0.012849391
compile eval time 3: 0.01316768
compile eval time 4: 0.012814131
compile eval time 5: 0.012739703
compile eval time 6: 0.012831038
compile eval time 7: 0.012756367
compile eval time 8: 0.012842068
compile eval time 9: 0.013376863
```

Close #1792
  • Loading branch information
alexbaden authored Sep 25, 2024
1 parent 542c9f8 commit 5b4d89a
Show file tree
Hide file tree
Showing 4 changed files with 139 additions and 55 deletions.
61 changes: 59 additions & 2 deletions third_party/intel/backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from types import ModuleType
import hashlib
import re
import tempfile
import signal
import os
import shutil
import subprocess
Expand Down Expand Up @@ -51,6 +53,7 @@ class XPUOptions:
max_num_imprecise_acc_default: int = 0 # `max_num_imprecise_acc` only applies to fp8 -> fp32 dot on sm_90 for cuda
extern_libs: dict = None
debug: bool = False
generate_native_code: bool = False
backend_name: str = 'intel'

def __post_init__(self):
Expand All @@ -62,6 +65,7 @@ def __post_init__(self):
object.__setattr__(self, 'extern_libs', tuple(extern_libs.items()))
if self.num_warps <= 0 or (self.num_warps & (self.num_warps - 1)) != 0:
raise AssertionError("num_warps must be a power of 2")
self.generate_native_code = bool(os.getenv("TRITON_XPU_GEN_NATIVE_CODE", self.generate_native_code))

def hash(self):
key = '_'.join([f'{name}-{val}' for name, val in self.__dict__.items()])
Expand Down Expand Up @@ -284,7 +288,7 @@ def make_llir(src, metadata, options):

@staticmethod
def make_spv(src, metadata, options):
ret, name = intel.translate_to_spirv(src)
spirv, name = intel.translate_to_spirv(src)
metadata["name"] = name
if options.grf_mode == 'small':
metadata["build_flags"] = "-cl-intel-128-GRF-per-thread"
Expand All @@ -297,7 +301,60 @@ def make_spv(src, metadata, options):
else:
metadata["build_flags"] = ""

return ret
if options.generate_native_code:
with tempfile.NamedTemporaryFile(delete=False, mode='wb', suffix='.spv') as fsrc, \
tempfile.NamedTemporaryFile(delete=False, mode='r', suffix='.log') as flog:
fsrc.write(spirv)
fsrc.flush()
fbin = fsrc.name + '.o'

ocloc_cmd = [
'ocloc', 'compile', '-file', fsrc.name, '-o', fbin, '-spirv_input', '-device', 'pvc', '-options',
metadata["build_flags"]
]

try:
subprocess.run(ocloc_cmd, check=True, close_fds=False, stdout=flog, stderr=subprocess.STDOUT)
if os.path.exists(flog.name):
with open(flog.name) as log_file:
log = log_file.read().strip()
if 'spilled' in log:
"""
The exact message is something like:
warning: kernel matmul_kernel compiled SIMD16 allocated 128 regs and spilled around 217
is "spilled" enough for now?
"""
metadata["build_flags"] += " -cl-intel-256-GRF-per-thread"
# re-run with new build flags
ocloc_cmd[-1] = metadata["build_flags"]
subprocess.run(ocloc_cmd, check=True, close_fds=False, stdout=flog,
stderr=subprocess.STDOUT)
os.remove(flog.name)
if os.path.exists(fsrc.name):
os.remove(fsrc.name)
except subprocess.CalledProcessError as e:
with open(flog.name) as log_file:
log = log_file.read()
if os.path.exists(flog.name):
os.remove(flog.name)

if e.returncode == 255:
error = 'Internal Triton ZEBIN codegen error'
elif e.returncode == 128 + signal.SIGSEGV:
error = '`ocloc` raised SIGSEGV'
else:
error = f'`ocloc` failed with error code {e.returncode}'

raise RuntimeError(f'{error}\n'
f'`ocloc` stderr:\n{log}\n'
f'Repro command: {ocloc_cmd}\n')

with open(fbin, 'rb') as f:
zebin = f.read()
if os.path.exists(fbin):
os.remove(fbin)
return zebin
return spirv

def add_stages(self, stages, options):
stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options)
Expand Down
70 changes: 39 additions & 31 deletions third_party/intel/backend/driver.c
Original file line number Diff line number Diff line change
Expand Up @@ -134,17 +134,21 @@ static PyObject *loadBinary(PyObject *self, PyObject *args) {
const sycl::device sycl_device = sycl_l0_device_pair.first;

std::string kernel_name = name;
size_t binary_size = PyBytes_Size(py_bytes);
binary_size = binary_size / sizeof(uint32_t);
const size_t binary_size = PyBytes_Size(py_bytes);

uint8_t *binary_ptr = (uint8_t *)PyBytes_AsString(py_bytes);
const auto ctx = sycl_device.get_platform().ext_oneapi_get_default_context();
const auto l0_device =
sycl::get_native<sycl::backend::ext_oneapi_level_zero>(sycl_device);
const auto l0_context =
sycl::get_native<sycl::backend::ext_oneapi_level_zero>(ctx);

const auto use_native_code =
isEnvValueBool(getStrEnv("TRITON_XPU_GEN_NATIVE_CODE"));
const bool is_spv = use_native_code ? !(*use_native_code) : true;

auto l0_module = checkSyclErrors(create_module(
l0_context, l0_device, binary_ptr, binary_size, build_flags));
l0_context, l0_device, binary_ptr, binary_size, build_flags, is_spv));

auto checkL0Errors = [&](auto l0_module) -> ze_kernel_handle_t {
if (PyErr_Occurred()) {
Expand All @@ -169,35 +173,39 @@ static PyObject *loadBinary(PyObject *self, PyObject *args) {

int32_t n_spills = props.spillMemSize;
const int32_t n_regs = 0;
constexpr int32_t max_reg_spill = 1000;
std::string build_flags_str(build_flags);
bool is_GRF_mode_specified = false;

// Check whether the GRF mode is specified by the build flags.
if (build_flags_str.find("-cl-intel-256-GRF-per-thread") !=
std::string::npos ||
build_flags_str.find("-cl-intel-128-GRF-per-thread") !=
std::string::npos ||
build_flags_str.find("-cl-intel-enable-auto-large-GRF-mode") !=
std::string::npos) {
is_GRF_mode_specified = true;
}

// If the register mode isn't set, and the number of spills is greater
// than the threshold, recompile the kernel using large GRF mode.
if (!is_GRF_mode_specified && n_spills > max_reg_spill) {
std::cout << "(I): Detected " << n_spills
<< " spills, recompiling the kernel using large GRF mode"
<< std::endl;
const std::string new_build_flags =
build_flags_str.append(" -cl-intel-256-GRF-per-thread");
l0_module =
checkSyclErrors(create_module(l0_context, l0_device, binary_ptr,
binary_size, new_build_flags.c_str()));
l0_kernel = checkL0Errors(l0_module);
gpuAssert(zeKernelGetProperties(l0_kernel, &props));
n_spills = props.spillMemSize;
std::cout << "(I): Kernel has now " << n_spills << " spills" << std::endl;
if (is_spv) {
constexpr int32_t max_reg_spill = 1000;
std::string build_flags_str(build_flags);
bool is_GRF_mode_specified = false;

// Check whether the GRF mode is specified by the build flags.
if (build_flags_str.find("-cl-intel-256-GRF-per-thread") !=
std::string::npos ||
build_flags_str.find("-cl-intel-128-GRF-per-thread") !=
std::string::npos ||
build_flags_str.find("-cl-intel-enable-auto-large-GRF-mode") !=
std::string::npos) {
is_GRF_mode_specified = true;
}

// If the register mode isn't set, and the number of spills is greater
// than the threshold, recompile the kernel using large GRF mode.
if (!is_GRF_mode_specified && n_spills > max_reg_spill) {
std::cout << "(I): Detected " << n_spills
<< " spills, recompiling the kernel using large GRF mode"
<< std::endl;
const std::string new_build_flags =
build_flags_str.append(" -cl-intel-256-GRF-per-thread");
l0_module = checkSyclErrors(
create_module(l0_context, l0_device, binary_ptr, binary_size,
new_build_flags.c_str(), is_spv));

l0_kernel = checkL0Errors(l0_module);
gpuAssert(zeKernelGetProperties(l0_kernel, &props));
n_spills = props.spillMemSize;
std::cout << "(I): Kernel has now " << n_spills << " spills" << std::endl;
}
}

auto mod = new sycl::kernel_bundle<sycl::bundle_state::executable>(
Expand Down
44 changes: 39 additions & 5 deletions third_party/intel/backend/include/sycl_functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <cassert>
#include <iostream>
#include <memory>
#include <optional>
#include <sstream>
#include <string>
#include <string_view>
Expand All @@ -23,6 +24,20 @@ typedef struct l0_resc_handles {

using SyclQueueMap = std::unordered_map<sycl::queue, l0_resc_handles>;

// Handler for asynchronous SYCL exceptions.
//
// Each exception in the list is rethrown; if it derives from std::exception
// the process is terminated (after printing "Failure" in _DEBUG builds).
// Exceptions that do not derive from std::exception propagate to the caller.
//
// Declared `inline` because this definition lives in a header: a plain
// namespace-scope variable has external linkage and would produce
// multiple-definition link errors as soon as the header is included from
// more than one translation unit of the same program.
inline auto exception_handler = [](sycl::exception_list e_list) {
  for (std::exception_ptr const &e : e_list) {
    try {
      std::rethrow_exception(e);
    } catch (std::exception const &) {
#if _DEBUG
      std::cout << "Failure" << std::endl;
#endif
      std::terminate();
    }
  }
};

inline std::string parseZeResultCode(const ze_result_t code) {
const std::string prefix = "Triton Error [ZE]: ";
std::stringstream ss;
Expand All @@ -37,6 +52,15 @@ inline std::string parseZeResultCode(const ze_result_t code) {
} \
}

// TODO: share Triton GetEnv.hpp impl
// Looks up the environment variable named `env` and returns its value as a
// std::string. An unset variable yields an empty string.
inline std::string getStrEnv(const std::string &env) {
  if (const char *value = std::getenv(env.c_str()))
    return std::string(value);
  return std::string();
}

bool getBoolEnv(const std::string &env) {
const char *s = std::getenv(env.c_str());
std::string str(s ? s : "");
Expand All @@ -45,19 +69,29 @@ bool getBoolEnv(const std::string &env) {
return (str == "on" || str == "true" || str == "1");
}

// Interprets a textual environment value as a tri-state boolean.
// Matching is case-insensitive: "on" / "true" / "1" map to true,
// "off" / "false" / "0" map to false, and every other input (including the
// empty string) maps to std::nullopt.
inline std::optional<bool> isEnvValueBool(std::string str) {
  // Lower-case in place; go through unsigned char so std::tolower never
  // receives a negative value.
  for (char &ch : str)
    ch = static_cast<char>(std::tolower(static_cast<unsigned char>(ch)));

  const bool truthy = (str == "1" || str == "true" || str == "on");
  if (truthy)
    return true;

  const bool falsy = (str == "0" || str == "false" || str == "off");
  if (falsy)
    return false;

  return std::nullopt;
}

std::tuple<ze_module_handle_t, ze_result_t>
create_module(ze_context_handle_t context, ze_device_handle_t device,
uint8_t *binary_ptr, size_t binary_size,
const char *build_flags) {
uint8_t *binary_ptr, size_t binary_size, const char *build_flags,
const bool is_spv = true) {
assert(binary_ptr != nullptr && "binary_ptr should not be NULL");
assert(build_flags != nullptr && "build_flags should not be NULL");

const ze_module_format_t format = ZE_MODULE_FORMAT_IL_SPIRV;
const ze_module_format_t format =
is_spv ? ZE_MODULE_FORMAT_IL_SPIRV : ZE_MODULE_FORMAT_NATIVE;
ze_module_desc_t module_description = {};
module_description.stype = ZE_STRUCTURE_TYPE_MODULE_DESC;
module_description.format = format;
module_description.inputSize =
static_cast<uint32_t>(binary_size * sizeof(uint32_t));
module_description.inputSize = static_cast<uint32_t>(binary_size);
module_description.pInputModule = binary_ptr;
module_description.pBuildFlags = build_flags;
ze_module_build_log_handle_t buildlog;
Expand Down
19 changes: 2 additions & 17 deletions utils/SPIRVRunner/SPIRVRunner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,6 @@

#include "sycl_functions.h"

// Create an exception handler for asynchronous SYCL exceptions
static auto exception_handler = [](sycl::exception_list e_list) {
for (std::exception_ptr const &e : e_list) {
try {
std::rethrow_exception(e);
} catch (std::exception const &e) {
#if _DEBUG
std::cout << "Failure" << std::endl;
#endif
std::terminate();
}
}
};

auto load_tensor(const std::string &filename) {
std::ifstream ins(filename, std::ios::binary);
if (!ins.is_open()) {
Expand Down Expand Up @@ -297,9 +283,8 @@ int main() {
auto spirv = read_spirv("add_kernel.spv");
std::cout << "Read " << spirv.size() << " byte kernel." << std::endl;

auto [kernel_bundle, kernel, n_regs, n_spills] =
loadBinary("add_kernel", reinterpret_cast<uint8_t *>(spirv.data()),
spirv.size() / sizeof(uint32_t), 0);
auto [kernel_bundle, kernel, n_regs, n_spills] = loadBinary(
"add_kernel", reinterpret_cast<uint8_t *>(spirv.data()), spirv.size(), 0);

// TODO: missing number of registers
std::cout << "Loaded kernel with " << n_regs << " registers and " << n_spills
Expand Down

0 comments on commit 5b4d89a

Please sign in to comment.