Skip to content

Commit

Permalink
python: implement close() and context manager interface (nomic-ai#2177)
Browse files Browse the repository at this point in the history
Signed-off-by: Jared Van Bortel <jared@nomic.ai>
  • Loading branch information
cebtenzzre authored Mar 28, 2024
1 parent dddaf49 commit 3313c7d
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 4 deletions.
32 changes: 30 additions & 2 deletions gpt4all-bindings/python/gpt4all/_pyllmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import threading
from enum import Enum
from queue import Queue
from typing import Any, Callable, Generic, Iterable, TypeVar, overload
from typing import Any, Callable, Generic, Iterable, NoReturn, TypeVar, overload

if sys.version_info >= (3, 9):
import importlib.resources as importlib_resources
Expand Down Expand Up @@ -200,20 +200,32 @@ def __init__(self, model_path: str, n_ctx: int, ngl: int):
if model is None:
s = err.value
raise RuntimeError(f"Unable to instantiate model: {'null' if s is None else s.decode()}")
self.model = model
self.model: ctypes.c_void_p | None = model

def __del__(self, llmodel=llmodel):
    # `llmodel` is bound as a default argument so the ctypes library handle
    # remains reachable even during interpreter shutdown, when module globals
    # may already have been cleared to None.
    # hasattr guard: __init__ may have raised before self.model was assigned,
    # in which case there is nothing to release.
    if hasattr(self, 'model'):
        self.close()

def close(self) -> None:
    """Destroy the native model handle and mark this instance closed.

    Idempotent: a second call (or a call after garbage collection has
    already run __del__) is a no-op.
    """
    if self.model is None:
        return  # already closed
    llmodel.llmodel_model_destroy(self.model)
    self.model = None

def _raise_closed(self) -> NoReturn:
    """Raise the standard error for use of a closed model handle."""
    msg = "Attempted operation on a closed LLModel"
    raise ValueError(msg)

def _list_gpu(self, mem_required: int) -> list[LLModelGPUDevice]:
    """Return the GPU devices the backend reports as usable with at least
    `mem_required` bytes of memory available."""
    assert self.model is not None
    count = ctypes.c_int32(0)
    devices = llmodel.llmodel_available_gpu_devices(self.model, mem_required, ctypes.byref(count))
    if not devices:
        raise ValueError("Unable to retrieve available GPU devices")
    return devices[:count.value]

def init_gpu(self, device: str):
if self.model is None:
self._raise_closed()

mem_required = llmodel.llmodel_required_mem(self.model, self.model_path, self.n_ctx, self.ngl)

if llmodel.llmodel_gpu_init_gpu_device_by_string(self.model, mem_required, device.encode()):
Expand Down Expand Up @@ -246,14 +258,21 @@ def load_model(self) -> bool:
-------
True if model loaded successfully, False otherwise
"""
if self.model is None:
self._raise_closed()

return llmodel.llmodel_loadModel(self.model, self.model_path, self.n_ctx, self.ngl)

def set_thread_count(self, n_threads: int) -> None:
    """Set the number of CPU threads the backend uses for inference.

    Raises:
        ValueError: If the model has been closed.
        RuntimeError: If the model is not loaded.
    """
    if self.model is None:
        self._raise_closed()
    if not llmodel.llmodel_isModelLoaded(self.model):
        # RuntimeError (not bare Exception) for consistency with __init__'s
        # instantiation failure; still caught by `except Exception` callers.
        raise RuntimeError("Model not loaded")
    llmodel.llmodel_setThreadCount(self.model, n_threads)

def thread_count(self) -> int:
    """Return the number of CPU threads the backend is using for inference.

    Raises:
        ValueError: If the model has been closed.
        RuntimeError: If the model is not loaded.
    """
    if self.model is None:
        self._raise_closed()
    if not llmodel.llmodel_isModelLoaded(self.model):
        # RuntimeError (not bare Exception) for consistency with __init__'s
        # instantiation failure; still caught by `except Exception` callers.
        raise RuntimeError("Model not loaded")
    return llmodel.llmodel_threadCount(self.model)
Expand Down Expand Up @@ -322,6 +341,9 @@ def generate_embeddings(
if not text:
raise ValueError("text must not be None or empty")

if self.model is None:
self._raise_closed()

if (single_text := isinstance(text, str)):
text = [text]

Expand Down Expand Up @@ -387,6 +409,9 @@ def prompt_model(
None
"""

if self.model is None:
self._raise_closed()

self.buffer.clear()
self.buff_expecting_cont_bytes = 0

Expand Down Expand Up @@ -419,6 +444,9 @@ def prompt_model(
def prompt_model_streaming(
self, prompt: str, prompt_template: str, callback: ResponseCallbackType = empty_response_callback, **kwargs
) -> Iterable[str]:
if self.model is None:
self._raise_closed()

output_queue: Queue[str | Sentinel] = Queue()

# Put response tokens into an output queue
Expand Down
27 changes: 26 additions & 1 deletion gpt4all-bindings/python/gpt4all/gpt4all.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import warnings
from contextlib import contextmanager
from pathlib import Path
from types import TracebackType
from typing import TYPE_CHECKING, Any, Iterable, Literal, Protocol, overload

import requests
Expand All @@ -22,7 +23,7 @@
from ._pyllmodel import EmbedResult as EmbedResult

if TYPE_CHECKING:
from typing_extensions import TypeAlias
from typing_extensions import Self, TypeAlias

if sys.platform == 'darwin':
import fcntl
Expand Down Expand Up @@ -54,6 +55,18 @@ def __init__(self, model_name: str | None = None, n_threads: int | None = None,
model_name = 'all-MiniLM-L6-v2.gguf2.f16.gguf'
self.gpt4all = GPT4All(model_name, n_threads=n_threads, **kwargs)

def __enter__(self) -> Self:
    """Enter a `with` block; the embedding model is closed on exit."""
    return self

def __exit__(
    self, typ: type[BaseException] | None, value: BaseException | None, tb: TracebackType | None,
) -> None:
    """Exit the `with` block, freeing the underlying model resources.

    Exceptions are never suppressed (implicitly returns None).
    """
    self.close()

def close(self) -> None:
    """Delete the model instance and free associated system resources."""
    # Delegates to the wrapped GPT4All instance, which in turn closes
    # its native LLModel handle.
    self.gpt4all.close()

# return_dict=False
@overload
def embed(
Expand Down Expand Up @@ -190,6 +203,18 @@ def __init__(
self._history: list[MessageType] | None = None
self._current_prompt_template: str = "{0}"

def __enter__(self) -> Self:
    """Enter a `with` block; the model is closed on exit."""
    return self

def __exit__(
    self, typ: type[BaseException] | None, value: BaseException | None, tb: TracebackType | None,
) -> None:
    """Exit the `with` block, freeing the underlying model resources.

    Exceptions are never suppressed (implicitly returns None).
    """
    self.close()

def close(self) -> None:
    """Delete the model instance and free associated system resources."""
    # Delegates to the underlying LLModel, which destroys the native handle.
    self.model.close()

@property
def current_chat_session(self) -> list[MessageType] | None:
return None if self._history is None else list(self._history)
Expand Down
2 changes: 1 addition & 1 deletion gpt4all-bindings/python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def get_long_description():

setup(
name=package_name,
version="2.3.2",
version="2.3.3",
description="Python bindings for GPT4All",
long_description=get_long_description(),
long_description_content_type="text/markdown",
Expand Down

0 comments on commit 3313c7d

Please sign in to comment.