Skip to content

Commit

Permalink
handle culink and nvjitlink differences in the backend and test
Browse files Browse the repository at this point in the history
  • Loading branch information
ksimpson-work committed Dec 4, 2024
1 parent d7bf4cb commit 702fbaa
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 46 deletions.
50 changes: 19 additions & 31 deletions cuda_core/cuda/core/experimental/_linker.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ def _lazy_init():
_driver_ver = handle_return(cuda.cuDriverGetVersion())
_driver_ver = (_driver_ver // 1000, (_driver_ver % 1000) // 10)
try:
raise ImportError
from cuda.bindings import nvjitlink
from cuda.bindings._internal import nvjitlink as inner_nvjitlink
except ImportError:
Expand Down Expand Up @@ -247,7 +246,7 @@ def _init_nvjitlink(self):
self.formatted_options.append(f"-split-compile={self.split_compile}")
if self.split_compile_extended is not None:
self.formatted_options.append(f"-split-compile-extended={self.split_compile_extended}")
if self.no_cache is not None:
if self.no_cache is True:
self.formatted_options.append("-no-cache")

def _init_driver(self):
Expand All @@ -272,57 +271,46 @@ def _init_driver(self):
self.formatted_options.append(self.max_register_count)
self.option_keys.append(_driver.CUjit_option.CU_JIT_MAX_REGISTERS)
if self.time is not None:
self.formatted_options.append(1) # ctypes.c_int32(1)
self.option_keys.append(_driver.CUjit_option.CU_JIT_WALL_TIME)
raise ValueError("time option is not supported by the driver API")
if self.verbose is not None:
self.formatted_options.append(1) # ctypes.c_int32(1)
self.formatted_options.append(1)
self.option_keys.append(_driver.CUjit_option.CU_JIT_LOG_VERBOSE)
if self.link_time_optimization is not None:
self.formatted_options.append(1) # ctypes.c_int32(1)
self.formatted_options.append(1)
self.option_keys.append(_driver.CUjit_option.CU_JIT_LTO)
if self.ptx is not None:
self.formatted_options.append(1) # ctypes.c_int32(1)
self.option_keys.append(_driver.CUjit_option.CU_JIT_GENERATE_LINE_INFO)
raise ValueError("ptx option is not supported by the driver API")
if self.optimization_level is not None:
self.formatted_options.append(self.optimization_level)
self.option_keys.append(_driver.CUjit_option.CU_JIT_OPTIMIZATION_LEVEL)
if self.debug is not None:
self.formatted_options.append(1) # ctypes.c_int32(1)
self.formatted_options.append(1)
self.option_keys.append(_driver.CUjit_option.CU_JIT_GENERATE_DEBUG_INFO)
if self.lineinfo is not None:
self.formatted_options.append(1) # ctypes.c_int32(1)
self.formatted_options.append(1)
self.option_keys.append(_driver.CUjit_option.CU_JIT_GENERATE_LINE_INFO)
if self.ftz is not None:
self.formatted_options.append(1 if self.ftz else 0)
self.option_keys.append(_driver.CUjit_option.CU_JIT_FTZ)
raise ValueError("ftz option is deprecated in the driver API")
if self.prec_div is not None:
self.formatted_options.append(1 if self.prec_div else 0)
self.option_keys.append(_driver.CUjit_option.CU_JIT_PREC_DIV)
raise ValueError("prec_div option is deprecated in the driver API")
if self.prec_sqrt is not None:
self.formatted_options.append(1 if self.prec_sqrt else 0)
self.option_keys.append(_driver.CUjit_option.CU_JIT_PREC_SQRT)
raise ValueError("prec_sqrt option is deprecated in the driver API")
if self.fma is not None:
self.formatted_options.append(1 if self.fma else 0)
self.option_keys.append(_driver.CUjit_option.CU_JIT_FMA)
raise ValueError("fma options is deprecated in the driver API")
if self.kernels_used is not None:
for kernel in self.kernels_used:
self.formatted_options.append(kernel.encode())
self.option_keys.append(_driver.CUjit_option.CU_JIT_REFERENCED_KERNEL_NAMES)
raise ValueError("kernels_used is deprecated in the driver API")
if self.variables_used is not None:
for variable in self.variables_used:
self.formatted_options.append(variable.encode())
self.option_keys.append(_driver.CUjit_option.CU_JIT_REFERENCED_VARIABLE_NAMES)
raise ValueError("variables_used is deprecated in the driver API")
if self.optimize_unused_variables is not None:
self.formatted_options.append(1) # ctypes.c_int32(1)
self.option_keys.append(_driver.CUjit_option.CU_JIT_OPTIMIZE_UNUSED_DEVICE_VARIABLES)
raise ValueError("optimize_unused_variables is deprecated in the driver API")
if self.xptxas is not None:
for opt in self.xptxas:
raise NotImplementedError("TODO: implement xptxas option")
raise ValueError("xptxas option is not supported by the driver API")
if self.split_compile is not None:
raise ValueError("split_compile option is not supported by the driver API")
if self.split_compile_extended is not None:
self.formatted_options.append(self.split_compile_extended)
self.option_keys.append(_driver.CUjit_option.CU_JIT_MIN_CTA_PER_SM)
raise ValueError("split_compile_extended option is not supported by the driver API")
if self.no_cache is not None:
self.formatted_options.append(1) # ctypes.c_int32(1)
self.formatted_options.append(_driver.CUjit_cacheMode.CU_JIT_CACHE_OPTION_NONE)
self.option_keys.append(_driver.CUjit_option.CU_JIT_CACHE_MODE)


Expand Down
51 changes: 36 additions & 15 deletions cuda_core/tests/test_linker.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,17 @@
basic_kernel = "__device__ int B() { return 0; }"
addition_kernel = "__device__ int C(int a, int b) { return a + b; }"

try:
from cuda.bindings import nvjitlink # noqa F401
from cuda.bindings._internal import nvjitlink as inner_nvjitlink
except ImportError:
# binding is not available
culink_backend = True
else:
if inner_nvjitlink._inspect_function_pointer("__nvJitLinkVersion") == 0:
# binding is available, but nvJitLink is not installed
culink_backend = True


@pytest.fixture(scope="function")
def compile_ptx_functions(init_cuda):
Expand All @@ -27,27 +38,36 @@ def compile_ltoir_functions(init_cuda):
return object_code_a_ltoir, object_code_b_ltoir, object_code_c_ltoir


culink_options = [
LinkerOptions(arch=ARCH),
LinkerOptions(arch=ARCH, max_register_count=32),
LinkerOptions(arch=ARCH, verbose=True),
LinkerOptions(arch=ARCH, optimization_level=3),
LinkerOptions(arch=ARCH, debug=True),
LinkerOptions(arch=ARCH, lineinfo=True),
LinkerOptions(arch=ARCH, no_cache=True),
]


@pytest.mark.parametrize(
"options",
[
LinkerOptions(arch=ARCH),
LinkerOptions(arch=ARCH, max_register_count=32),
culink_options
if culink_backend
else culink_options
+ [
LinkerOptions(arch=ARCH, time=True),
LinkerOptions(arch=ARCH, verbose=True),
LinkerOptions(arch=ARCH, optimization_level=3),
LinkerOptions(arch=ARCH, debug=True),
LinkerOptions(arch=ARCH, lineinfo=True),
LinkerOptions(arch=ARCH, ftz=True),
LinkerOptions(arch=ARCH, prec_div=True),
LinkerOptions(arch=ARCH, prec_sqrt=True),
LinkerOptions(arch=ARCH, fma=True),
LinkerOptions(arch=ARCH, kernels_used=["kernel1"]),
LinkerOptions(arch=ARCH, kernels_used=["kernel1", "kernel2"]),
LinkerOptions(arch=ARCH, variables_used=["var1"]),
LinkerOptions(arch=ARCH, variables_used=["var1", "var2"]),
LinkerOptions(arch=ARCH, optimize_unused_variables=True),
# LinkerOptions(arch=ARCH, xptxas=["-v"]),
# LinkerOptions(arch=ARCH, split_compile=0),
LinkerOptions(arch=ARCH, xptxas=["-v"]),
LinkerOptions(arch=ARCH, split_compile=0),
LinkerOptions(arch=ARCH, split_compile_extended=1),
# LinkerOptions(arch=ARCH, no_cache=True),
],
)
def test_linker_init(compile_ptx_functions, options):
Expand All @@ -62,11 +82,12 @@ def test_linker_init_invalid_arch():
Linker(options)


# def test_linker_link_ptx(compile_ltoir_functions):
# options = LinkerOptions(arch=ARCH, link_time_optimization=True, ptx=True)
# linker = Linker(*compile_ltoir_functions, options=options)
# linked_code = linker.link("ptx")
# assert isinstance(linked_code, ObjectCode)
@pytest.mark.skipif(culink_backend, reason="culink does not support ptx option")
def test_linker_link_ptx(compile_ltoir_functions):
options = LinkerOptions(arch=ARCH, link_time_optimization=True, ptx=True)
linker = Linker(*compile_ltoir_functions, options=options)
linked_code = linker.link("ptx")
assert isinstance(linked_code, ObjectCode)


def test_linker_link_cubin(compile_ptx_functions):
Expand Down

0 comments on commit 702fbaa

Please sign in to comment.