diff --git a/cuda_core/cuda/core/experimental/_program.py b/cuda_core/cuda/core/experimental/_program.py
index 22b9fd81..28936a88 100644
--- a/cuda_core/cuda/core/experimental/_program.py
+++ b/cuda_core/cuda/core/experimental/_program.py
@@ -26,32 +26,41 @@ class Program:
 
     """
 
-    __slots__ = ("__weakref__", "_handle", "_backend")
+    class _MembersNeededForFinalize:
+        __slots__ = ("handle",)
+
+        def __init__(self, program_obj, handle):
+            self.handle = handle
+            weakref.finalize(program_obj, self.close)
+
+        def close(self):
+            if self.handle is not None:
+                handle_return(nvrtc.nvrtcDestroyProgram(self.handle))
+                self.handle = None
+
+    __slots__ = ("__weakref__", "_mnff", "_backend")
     _supported_code_type = ("c++",)
     _supported_target_type = ("ptx", "cubin", "ltoir")
 
     def __init__(self, code, code_type):
+        self._mnff = Program._MembersNeededForFinalize(self, None)
+
         if code_type not in self._supported_code_type:
             raise NotImplementedError
-        self._handle = None
-        weakref.finalize(self, self.close)
-
         if code_type.lower() == "c++":
             if not isinstance(code, str):
                 raise TypeError
             # TODO: support pre-loaded headers & include names
             # TODO: allow tuples once NVIDIA/cuda-python#72 is resolved
-            self._handle = handle_return(nvrtc.nvrtcCreateProgram(code.encode(), b"", 0, [], []))
+            self._mnff.handle = handle_return(nvrtc.nvrtcCreateProgram(code.encode(), b"", 0, [], []))
             self._backend = "nvrtc"
         else:
             raise NotImplementedError
 
     def close(self):
         """Destroy this program."""
-        if self._handle is not None:
-            handle_return(nvrtc.nvrtcDestroyProgram(self._handle))
-            self._handle = None
+        self._mnff.close()
 
     def compile(self, target_type, options=(), name_expressions=(), logs=None):
         """Compile the program with a specific compilation type.
 
@@ -84,29 +93,29 @@ def compile(self, target_type, options=(), name_expressions=(), logs=None):
         if self._backend == "nvrtc":
             if name_expressions:
                 for n in name_expressions:
-                    handle_return(nvrtc.nvrtcAddNameExpression(self._handle, n.encode()), handle=self._handle)
+                    handle_return(nvrtc.nvrtcAddNameExpression(self._mnff.handle, n.encode()), handle=self._mnff.handle)
             # TODO: allow tuples once NVIDIA/cuda-python#72 is resolved
             options = list(o.encode() for o in options)
-            handle_return(nvrtc.nvrtcCompileProgram(self._handle, len(options), options), handle=self._handle)
+            handle_return(nvrtc.nvrtcCompileProgram(self._mnff.handle, len(options), options), handle=self._mnff.handle)
 
             size_func = getattr(nvrtc, f"nvrtcGet{target_type.upper()}Size")
             comp_func = getattr(nvrtc, f"nvrtcGet{target_type.upper()}")
-            size = handle_return(size_func(self._handle), handle=self._handle)
+            size = handle_return(size_func(self._mnff.handle), handle=self._mnff.handle)
             data = b" " * size
-            handle_return(comp_func(self._handle, data), handle=self._handle)
+            handle_return(comp_func(self._mnff.handle, data), handle=self._mnff.handle)
 
             symbol_mapping = {}
             if name_expressions:
                 for n in name_expressions:
                     symbol_mapping[n] = handle_return(
-                        nvrtc.nvrtcGetLoweredName(self._handle, n.encode()), handle=self._handle
+                        nvrtc.nvrtcGetLoweredName(self._mnff.handle, n.encode()), handle=self._mnff.handle
                     )
 
             if logs is not None:
-                logsize = handle_return(nvrtc.nvrtcGetProgramLogSize(self._handle), handle=self._handle)
+                logsize = handle_return(nvrtc.nvrtcGetProgramLogSize(self._mnff.handle), handle=self._mnff.handle)
                 if logsize > 1:
                     log = b" " * logsize
-                    handle_return(nvrtc.nvrtcGetProgramLog(self._handle, log), handle=self._handle)
+                    handle_return(nvrtc.nvrtcGetProgramLog(self._mnff.handle, log), handle=self._mnff.handle)
                     logs.write(log.decode())
 
             # TODO: handle jit_options for ptx?
@@ -121,4 +130,4 @@ def backend(self):
     @property
     def handle(self):
         """Return the program handle object."""
-        return self._handle
+        return self._mnff.handle