Skip to content

Commit

Permalink
Apply _MembersNeededForFinalize pattern to _program.py
Browse files: browse the repository at this point in the history
  • Loading branch information
rwgk committed Nov 30, 2024
1 parent 08aa6ec commit b872767
Showing 1 changed file with 25 additions and 16 deletions.
41 changes: 25 additions & 16 deletions cuda_core/cuda/core/experimental/_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,32 +26,41 @@ class Program:
"""

class _MembersNeededForFinalize:
    """Hold the members that the owning object's finalizer has to reach.

    The NVRTC program handle is stored here rather than on the owning
    object, so that the cleanup callback can be a bound method of this
    helper instead of a bound method of the owner.
    """

    __slots__ = ("handle",)

    def __init__(self, program_obj, handle):
        """Store ``handle`` and register ``close`` as finalizer of ``program_obj``.

        Args:
            program_obj: Object whose garbage collection triggers ``close``.
            handle: Program handle to destroy later, or ``None`` if none exists yet.
        """
        self.handle = handle
        # The callback is bound to this helper, not to program_obj, so the
        # registered finalizer does not itself reference program_obj.
        weakref.finalize(program_obj, self.close)

    def close(self):
        """Destroy the program handle if one is held, then forget it.

        Calling this again after a successful destroy does nothing, because
        the handle has been reset to ``None``.
        """
        if self.handle is not None:
            handle_return(nvrtc.nvrtcDestroyProgram(self.handle))
            self.handle = None


# The program handle is kept on ``_mnff`` (see ``__init__``), so there is no
# ``_handle`` slot; ``__weakref__`` is required for ``weakref.finalize``.
__slots__ = ("__weakref__", "_mnff", "_backend")
_supported_code_type = ("c++",)
_supported_target_type = ("ptx", "cubin", "ltoir")

def __init__(self, code, code_type):
    """Create a program from source code.

    Args:
        code: Source text. Must be a ``str`` when ``code_type`` is ``"c++"``.
        code_type: Kind of source; must be listed in ``_supported_code_type``.

    Raises:
        NotImplementedError: If ``code_type`` is not supported.
        TypeError: If ``code`` is not a ``str``.
    """
    # Register the finalizer helper before anything can raise; it starts
    # with no handle, so its close() is a no-op until a program is created.
    self._mnff = Program._MembersNeededForFinalize(self, None)

    if code_type not in self._supported_code_type:
        raise NotImplementedError

    if code_type.lower() == "c++":
        if not isinstance(code, str):
            raise TypeError
        # TODO: support pre-loaded headers & include names
        # TODO: allow tuples once NVIDIA/cuda-python#72 is resolved
        self._mnff.handle = handle_return(nvrtc.nvrtcCreateProgram(code.encode(), b"", 0, [], []))
        self._backend = "nvrtc"
    else:
        raise NotImplementedError

def close(self):
    """Destroy this program.

    Delegates to the helper stored in ``_mnff``, which owns the program
    handle and resets it to ``None`` once it has been destroyed.
    """
    self._mnff.close()

def compile(self, target_type, options=(), name_expressions=(), logs=None):
"""Compile the program with a specific compilation type.
Expand Down Expand Up @@ -84,29 +93,29 @@ def compile(self, target_type, options=(), name_expressions=(), logs=None):
if self._backend == "nvrtc":
if name_expressions:
for n in name_expressions:
handle_return(nvrtc.nvrtcAddNameExpression(self._handle, n.encode()), handle=self._handle)
handle_return(nvrtc.nvrtcAddNameExpression(self._mnff.handle, n.encode()), handle=self._mnff.handle)
# TODO: allow tuples once NVIDIA/cuda-python#72 is resolved
options = list(o.encode() for o in options)
handle_return(nvrtc.nvrtcCompileProgram(self._handle, len(options), options), handle=self._handle)
handle_return(nvrtc.nvrtcCompileProgram(self._mnff.handle, len(options), options), handle=self._mnff.handle)

size_func = getattr(nvrtc, f"nvrtcGet{target_type.upper()}Size")
comp_func = getattr(nvrtc, f"nvrtcGet{target_type.upper()}")
size = handle_return(size_func(self._handle), handle=self._handle)
size = handle_return(size_func(self._mnff.handle), handle=self._mnff.handle)
data = b" " * size
handle_return(comp_func(self._handle, data), handle=self._handle)
handle_return(comp_func(self._mnff.handle, data), handle=self._mnff.handle)

symbol_mapping = {}
if name_expressions:
for n in name_expressions:
symbol_mapping[n] = handle_return(
nvrtc.nvrtcGetLoweredName(self._handle, n.encode()), handle=self._handle
nvrtc.nvrtcGetLoweredName(self._mnff.handle, n.encode()), handle=self._mnff.handle
)

if logs is not None:
logsize = handle_return(nvrtc.nvrtcGetProgramLogSize(self._handle), handle=self._handle)
logsize = handle_return(nvrtc.nvrtcGetProgramLogSize(self._mnff.handle), handle=self._mnff.handle)
if logsize > 1:
log = b" " * logsize
handle_return(nvrtc.nvrtcGetProgramLog(self._handle, log), handle=self._handle)
handle_return(nvrtc.nvrtcGetProgramLog(self._mnff.handle, log), handle=self._mnff.handle)
logs.write(log.decode())

# TODO: handle jit_options for ptx?
Expand All @@ -121,4 +130,4 @@ def backend(self):
@property
def handle(self):
    """Return the program handle object.

    The handle is read from the ``_mnff`` helper, which is where
    ``__init__`` stores it.
    """
    return self._mnff.handle

0 comments on commit b872767

Please sign in to comment.