Skip to content
This repository has been archived by the owner on Sep 12, 2024. It is now read-only.

Commit

Permalink
refactor and enhance Lancedb URI Handling in AutoVectorStoreIndex (#182)
Browse files Browse the repository at this point in the history
* bump autollm

* add increment method to lancedb

* add validate and setup method for lancedb_uri

* some fixes

* fix functionality of exist_ok

* remove accidental lancedb dir

* more fixes

* remove deprecated parameters & methods
  • Loading branch information
SeeknnDestroy authored Dec 22, 2023
1 parent 0407ac6 commit 2ec1b21
Show file tree
Hide file tree
Showing 6 changed files with 123 additions and 100 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -170,3 +170,4 @@ storage/
# vscode settings
.vscode
.lancedb
lancedb/
2 changes: 1 addition & 1 deletion autollm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
and vector databases, along with various utility functions.
"""

__version__ = '0.1.4'
__version__ = '0.1.5'
__author__ = 'safevideo'
__license__ = 'AGPL-3.0'

Expand Down
85 changes: 10 additions & 75 deletions autollm/auto/query_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,8 @@ def create_query_engine(
vector_store_type: str = "LanceDBVectorStore",
lancedb_uri: str = "./.lancedb",
lancedb_table_name: str = "vectors",
# Deprecated parameters
llm_params: dict = None,
vector_store_params: dict = None,
service_context_params: dict = None,
query_engine_params: dict = None,
exist_ok: bool = False,
overwrite_existing: bool = False,
**vector_store_kwargs) -> BaseQueryEngine:
"""
Create a query engine from parameters.
Expand Down Expand Up @@ -84,27 +81,6 @@ def create_query_engine(
Returns:
A llama_index.BaseQueryEngine instance.
"""
# Check for deprecated parameters
if llm_params is not None:
raise ValueError(
"llm_params is deprecated. Instead of llm_params={'llm_model': 'model_name', ...}, "
"use llm_model='model_name', llm_api_base='api_base', llm_max_tokens=1028, llm_temperature=0.1 directly as arguments."
)
if vector_store_params is not None:
raise ValueError(
"vector_store_params is deprecated. Instead of vector_store_params={'vector_store_type': 'type', ...}, "
"use vector_store_type='type', lancedb_uri='uri', lancedb_table_name='table', enable_metadata_extraction=True directly as arguments."
)
if service_context_params is not None:
raise ValueError(
"service_context_params is deprecated. Use the explicit parameters like system_prompt='prompt', "
"query_wrapper_prompt='wrapper', enable_cost_calculator=True, embed_model='model', chunk_size=512, "
"chunk_overlap=..., context_window=... directly as arguments.")
if query_engine_params is not None:
raise ValueError(
"query_engine_params is deprecated. Instead of query_engine_params={'similarity_top_k': 5, ...}, "
"use similarity_top_k=5 directly as an argument.")

llm = AutoLiteLLM.from_defaults(
model=llm_model, api_base=llm_api_base, max_tokens=llm_max_tokens, temperature=llm_temperature)

Expand Down Expand Up @@ -132,6 +108,8 @@ def create_query_engine(
documents=documents,
nodes=nodes,
service_context=service_context,
exist_ok=exist_ok,
overwrite_existing=overwrite_existing,
**vector_store_kwargs)
if refine_prompt is not None:
refine_prompt_template = PromptTemplate(refine_prompt, prompt_type=PromptType.REFINE)
Expand Down Expand Up @@ -187,7 +165,6 @@ class AutoQueryEngine:
vector_store_type="LanceDBVectorStore",
lancedb_uri="./.lancedb",
lancedb_table_name="vectors",
enable_metadata_extraction=False,
**vector_store_kwargs)
)
```
Expand Down Expand Up @@ -237,12 +214,8 @@ def from_defaults(
vector_store_type: str = "LanceDBVectorStore",
lancedb_uri: str = "./.lancedb",
lancedb_table_name: str = "vectors",
enable_metadata_extraction: bool = False,
# Deprecated parameters
llm_params: dict = None,
vector_store_params: dict = None,
service_context_params: dict = None,
query_engine_params: dict = None,
exist_ok: bool = False,
overwrite_existing: bool = False,
**vector_store_kwargs) -> BaseQueryEngine:
"""
Create an AutoQueryEngine from default parameters.
Expand Down Expand Up @@ -272,6 +245,8 @@ def from_defaults(
vector_store_type (str): The vector store type to use for the query engine.
lancedb_uri (str): The URI to use for the LanceDB vector store.
lancedb_table_name (str): The table name to use for the LanceDB vector store.
exist_ok (bool): Flag to allow adding documents to an existing vector store.
overwrite_existing (bool): Flag to allow overwriting an existing vector store.
Returns:
A llama_index.BaseQueryEngine instance.
Expand Down Expand Up @@ -302,50 +277,10 @@ def from_defaults(
vector_store_type=vector_store_type,
lancedb_uri=lancedb_uri,
lancedb_table_name=lancedb_table_name,
enable_metadata_extraction=enable_metadata_extraction,
# Deprecated parameters
llm_params=llm_params,
vector_store_params=vector_store_params,
service_context_params=service_context_params,
query_engine_params=query_engine_params,
exist_ok=exist_ok,
overwrite_existing=overwrite_existing,
**vector_store_kwargs)

@staticmethod
def from_parameters(
documents: Sequence[Document] = None,
system_prompt: str = None,
query_wrapper_prompt: str = None,
enable_cost_calculator: bool = True,
embed_model: Union[str, EmbedType] = "default", # ["default", "local"]
llm_params: dict = None,
vector_store_params: dict = None,
service_context_params: dict = None,
query_engine_params: dict = None) -> BaseQueryEngine:
"""
DEPRECATED. Use AutoQueryEngine.from_defaults instead.
Create an AutoQueryEngine from parameters.
Parameters:
documents (Sequence[Document]): Sequence of llama_index.Document instances.
system_prompt (str): The system prompt to use for the query engine.
query_wrapper_prompt (str): The query wrapper prompt to use for the query engine.
enable_cost_calculator (bool): Flag to enable cost calculator logging.
embed_model (Union[str, EmbedType]): The embedding model to use for generating embeddings. "default" for OpenAI,
"local" for HuggingFace or use full identifier (e.g., local:intfloat/multilingual-e5-large)
llm_params (dict): Parameters for the LLM.
vector_store_params (dict): Parameters for the vector store.
service_context_params (dict): Parameters for the service context.
query_engine_params (dict): Parameters for the query engine.
Returns:
A llama_index.BaseQueryEngine instance.
"""

# TODO: Remove this method in the next release
raise ValueError(
"AutoQueryEngine.from_parameters is deprecated. Use AutoQueryEngine.from_defaults instead.")

@staticmethod
def from_config(
config_file_path: str,
Expand Down
100 changes: 93 additions & 7 deletions autollm/auto/vector_store_index.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
import os
import shutil
from typing import Optional, Sequence

from llama_index import Document, ServiceContext, StorageContext, VectorStoreIndex
from llama_index.schema import BaseNode

from autollm.utils.env_utils import on_rm_error
from autollm.utils.logging import logger


def import_vector_store_class(vector_store_class_name: str):
"""
Expand All @@ -25,24 +30,33 @@ class name and additional parameters.
@staticmethod
def from_defaults(
vector_store_type: str = "LanceDBVectorStore",
lancedb_uri: str = "./.lancedb",
lancedb_uri: str = None,
lancedb_table_name: str = "vectors",
documents: Optional[Sequence[Document]] = None,
nodes: Optional[Sequence[BaseNode]] = None,
service_context: Optional[ServiceContext] = None,
exist_ok: bool = False,
overwrite_existing: bool = False,
**kwargs) -> VectorStoreIndex:
"""
Initializes a Vector Store index from Vector Store type and additional parameters.
Initializes a Vector Store index from Vector Store type and additional parameters. Handles lancedb
path and document management according to specified behaviors.
Parameters:
vector_store_type (str): The class name of the vector store (e.g., 'LanceDBVectorStore', 'SimpleVectorStore'..)
vector_store_type (str): The class name of the vector store.
lancedb_uri (str): The URI for the LanceDB vector store.
lancedb_table_name (str): The table name for the LanceDB vector store.
documents (Optional[Sequence[Document]]): Documents to initialize the vector store index from.
nodes (Optional[Sequence[BaseNode]]): Nodes to initialize the vector store index from.
service_context (Optional[ServiceContext]): Service context to initialize the vector store index from.
**kwargs: Additional parameters for initializing the vector store
service_context (Optional[ServiceContext]): Service context for initialization.
exist_ok (bool): If True, allows adding to an existing database.
overwrite_existing (bool): If True, allows overwriting an existing database.
**kwargs: Additional parameters for initialization.
Returns:
index (VectorStoreIndex): The initialized Vector Store index instance for given vector store type and parameter set.
VectorStoreIndex: The initialized Vector Store index instance.
Raises:
ValueError: For invalid parameter combinations or missing information.
"""
if documents is None and nodes is None and vector_store_type == "SimpleVectorStore":
raise ValueError("documents or nodes must be provided for SimpleVectorStore")
Expand All @@ -55,6 +69,12 @@ def from_defaults(

# If LanceDBVectorStore, use lancedb_uri and lancedb_table_name
if vector_store_type == "LanceDBVectorStore":
lancedb_uri = AutoVectorStoreIndex._validate_and_setup_lancedb_uri(
lancedb_uri=lancedb_uri,
documents=documents,
exist_ok=exist_ok,
overwrite_existing=overwrite_existing)

vector_store = VectorStoreClass(uri=lancedb_uri, table_name=lancedb_table_name, **kwargs)
else:
vector_store = VectorStoreClass(**kwargs)
Expand Down Expand Up @@ -82,3 +102,69 @@ def from_defaults(
show_progress=True)

return index

@staticmethod
def _validate_and_setup_lancedb_uri(lancedb_uri, documents, exist_ok, overwrite_existing):
"""
Validates and sets up the lancedb_uri based on the given parameters.
Parameters:
lancedb_uri (str): The URI for the LanceDB vector store.
documents (Sequence[Document]): Documents to initialize the vector store index from.
exist_ok (bool): Flag to allow adding to an existing database.
overwrite_existing (bool): Flag to allow overwriting an existing database.
Returns:
str: The validated and potentially modified lancedb_uri.
"""
default_lancedb_uri = "./lancedb/db"

# Scenario 0: Handle no lancedb uri and no documents provided
if not documents and not lancedb_uri:
raise ValueError(
"A lancedb uri is required to connect to a database. Please provide a lancedb uri.")

# Scenario 1: Handle lancedb_uri given but no documents provided
if not documents and lancedb_uri:
# Check if the database exists
db_exists = os.path.exists(lancedb_uri)
if not db_exists:
raise ValueError(
f"No existing database found at {lancedb_uri}. Please provide a valid lancedb uri.")

# Scenario 2: Handle no lancedb uri but documents provided
if documents and not lancedb_uri:
lancedb_uri = default_lancedb_uri
lancedb_uri = AutoVectorStoreIndex._increment_lancedb_uri(lancedb_uri)
logger.info(
f"A new database is being created at {lancedb_uri}. Please provide a lancedb path to use an existing database."
)

# Scenario 3: Handle lancedb uri given and documents provided
if documents and lancedb_uri:
db_exists = os.path.exists(lancedb_uri)
if exist_ok and overwrite_existing:
if db_exists:
shutil.rmtree(lancedb_uri)
logger.info(f"Overwriting existing database at {lancedb_uri}.")
elif not exist_ok and overwrite_existing:
raise ValueError("Cannot overwrite existing database without exist_ok set to True.")
elif db_exists:
if not exist_ok:
lancedb_uri = AutoVectorStoreIndex._increment_lancedb_uri(lancedb_uri)
logger.info(f"Existing database found. Creating a new database at {lancedb_uri}.")
logger.info(
"Please use exist_ok=True to add to the existing database and overwrite_existing=True to overwrite the existing database."
)
else:
logger.info(f"Adding documents to existing database at {lancedb_uri}.")

return lancedb_uri

@staticmethod
def _increment_lancedb_uri(base_uri: str) -> str:
"""Increment the lancedb uri to create a new database."""
i = 1
while os.path.exists(f"{base_uri}_{i}"):
i += 1
return f"{base_uri}_{i}"
19 changes: 2 additions & 17 deletions autollm/utils/document_reading.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import os
import shutil
import stat
from pathlib import Path
from typing import Callable, List, Optional, Sequence, Tuple
from typing import List, Optional, Sequence

from llama_index.readers.file.base import SimpleDirectoryReader
from llama_index.schema import Document

from autollm.utils.env_utils import on_rm_error
from autollm.utils.git_utils import clone_or_pull_repository
from autollm.utils.logging import logger
from autollm.utils.markdown_reader import MarkdownReader
Expand Down Expand Up @@ -65,20 +64,6 @@ def read_files_as_documents(
return documents


# Adapted from http://stackoverflow.com/a/4829285/548792
def on_rm_error(func: Callable, path: str, exc_info: Tuple):
    """
    Error handler for `shutil.rmtree` to handle permission errors.

    Marks `path` as writable and then retries the removal. When the failing
    call was `os.rmdir` the retry uses `os.rmdir` as well, because `os.unlink`
    cannot remove a directory; in every other case the path is unlinked.

    Parameters:
        func (Callable): The function that raised the error.
        path (str): The path to the file or directory which couldn't be removed.
        exc_info (Tuple): Exception information returned by sys.exc_info(). Unused.

    Raises:
        OSError: If changing the mode or the retried removal fails.
    """
    os.chmod(path, stat.S_IWRITE)
    if func is os.rmdir:
        # The failure came from removing a directory; unlink would fail on it.
        os.rmdir(path)
    else:
        os.unlink(path)


def read_github_repo_as_documents(
git_repo_url: str,
relative_folder_path: Optional[str] = None,
Expand Down
16 changes: 16 additions & 0 deletions autollm/utils/env_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import os
import stat
from pathlib import Path
from typing import Callable, Tuple

import yaml
from dotenv import load_dotenv
Expand Down Expand Up @@ -55,3 +57,17 @@ def load_config_and_dotenv(config_file_path: str, env_file_path: str = None) ->
config = yaml.safe_load(f)

return config


# Adapted from http://stackoverflow.com/a/4829285/548792
def on_rm_error(func: Callable, path: str, exc_info: Tuple):
    """
    Error handler for `shutil.rmtree` to handle permission errors.

    Marks `path` as writable and then retries the removal. When the failing
    call was `os.rmdir` the retry uses `os.rmdir` as well, because `os.unlink`
    cannot remove a directory; in every other case the path is unlinked.

    Parameters:
        func (Callable): The function that raised the error.
        path (str): The path to the file or directory which couldn't be removed.
        exc_info (Tuple): Exception information returned by sys.exc_info(). Unused.

    Raises:
        OSError: If changing the mode or the retried removal fails.
    """
    os.chmod(path, stat.S_IWRITE)
    if func is os.rmdir:
        # The failure came from removing a directory; unlink would fail on it.
        os.rmdir(path)
    else:
        os.unlink(path)

0 comments on commit 2ec1b21

Please sign in to comment.