Skip to content
This repository has been archived by the owner on Sep 12, 2024. It is now read-only.

Refactor and enhance LanceDB URI handling in AutoVectorStoreIndex #182

Merged
merged 8 commits into from
Dec 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -170,3 +170,4 @@ storage/
# vscode settings
.vscode
.lancedb
lancedb/
2 changes: 1 addition & 1 deletion autollm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
and vector databases, along with various utility functions.
"""

__version__ = '0.1.4'
__version__ = '0.1.5'
__author__ = 'safevideo'
__license__ = 'AGPL-3.0'

Expand Down
85 changes: 10 additions & 75 deletions autollm/auto/query_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,8 @@ def create_query_engine(
vector_store_type: str = "LanceDBVectorStore",
lancedb_uri: str = "./.lancedb",
lancedb_table_name: str = "vectors",
# Deprecated parameters
llm_params: dict = None,
vector_store_params: dict = None,
service_context_params: dict = None,
query_engine_params: dict = None,
exist_ok: bool = False,
overwrite_existing: bool = False,
**vector_store_kwargs) -> BaseQueryEngine:
"""
Create a query engine from parameters.
Expand Down Expand Up @@ -84,27 +81,6 @@ def create_query_engine(
Returns:
A llama_index.BaseQueryEngine instance.
"""
# Check for deprecated parameters
if llm_params is not None:
raise ValueError(
"llm_params is deprecated. Instead of llm_params={'llm_model': 'model_name', ...}, "
"use llm_model='model_name', llm_api_base='api_base', llm_max_tokens=1028, llm_temperature=0.1 directly as arguments."
)
if vector_store_params is not None:
raise ValueError(
"vector_store_params is deprecated. Instead of vector_store_params={'vector_store_type': 'type', ...}, "
"use vector_store_type='type', lancedb_uri='uri', lancedb_table_name='table', enable_metadata_extraction=True directly as arguments."
)
if service_context_params is not None:
raise ValueError(
"service_context_params is deprecated. Use the explicit parameters like system_prompt='prompt', "
"query_wrapper_prompt='wrapper', enable_cost_calculator=True, embed_model='model', chunk_size=512, "
"chunk_overlap=..., context_window=... directly as arguments.")
if query_engine_params is not None:
raise ValueError(
"query_engine_params is deprecated. Instead of query_engine_params={'similarity_top_k': 5, ...}, "
"use similarity_top_k=5 directly as an argument.")

llm = AutoLiteLLM.from_defaults(
model=llm_model, api_base=llm_api_base, max_tokens=llm_max_tokens, temperature=llm_temperature)

Expand Down Expand Up @@ -132,6 +108,8 @@ def create_query_engine(
documents=documents,
nodes=nodes,
service_context=service_context,
exist_ok=exist_ok,
overwrite_existing=overwrite_existing,
**vector_store_kwargs)
if refine_prompt is not None:
refine_prompt_template = PromptTemplate(refine_prompt, prompt_type=PromptType.REFINE)
Expand Down Expand Up @@ -187,7 +165,6 @@ class AutoQueryEngine:
vector_store_type="LanceDBVectorStore",
lancedb_uri="./.lancedb",
lancedb_table_name="vectors",
enable_metadata_extraction=False,
**vector_store_kwargs)
)
```
Expand Down Expand Up @@ -237,12 +214,8 @@ def from_defaults(
vector_store_type: str = "LanceDBVectorStore",
lancedb_uri: str = "./.lancedb",
lancedb_table_name: str = "vectors",
enable_metadata_extraction: bool = False,
# Deprecated parameters
llm_params: dict = None,
vector_store_params: dict = None,
service_context_params: dict = None,
query_engine_params: dict = None,
exist_ok: bool = False,
overwrite_existing: bool = False,
**vector_store_kwargs) -> BaseQueryEngine:
"""
Create an AutoQueryEngine from default parameters.
Expand Down Expand Up @@ -272,6 +245,8 @@ def from_defaults(
vector_store_type (str): The vector store type to use for the query engine.
lancedb_uri (str): The URI to use for the LanceDB vector store.
lancedb_table_name (str): The table name to use for the LanceDB vector store.
exist_ok (bool): Flag to allow overwriting an existing vector store.
overwrite_existing (bool): Flag to allow overwriting an existing vector store.

Returns:
A llama_index.BaseQueryEngine instance.
Expand Down Expand Up @@ -302,50 +277,10 @@ def from_defaults(
vector_store_type=vector_store_type,
lancedb_uri=lancedb_uri,
lancedb_table_name=lancedb_table_name,
enable_metadata_extraction=enable_metadata_extraction,
# Deprecated parameters
llm_params=llm_params,
vector_store_params=vector_store_params,
service_context_params=service_context_params,
query_engine_params=query_engine_params,
exist_ok=exist_ok,
overwrite_existing=overwrite_existing,
**vector_store_kwargs)

@staticmethod
def from_parameters(
        documents: Sequence[Document] = None,
        system_prompt: str = None,
        query_wrapper_prompt: str = None,
        enable_cost_calculator: bool = True,
        embed_model: Union[str, EmbedType] = "default",  # ["default", "local"]
        llm_params: dict = None,
        vector_store_params: dict = None,
        service_context_params: dict = None,
        query_engine_params: dict = None) -> BaseQueryEngine:
    """
    DEPRECATED. Use AutoQueryEngine.from_defaults instead.

    This legacy entry point always raises: the dict-style configuration
    parameters it accepted were replaced by the explicit keyword arguments of
    ``from_defaults``. The old parameters are kept in the signature only so
    that existing call sites fail with a clear, actionable message instead of
    a TypeError.

    Raises:
        ValueError: Always, directing callers to AutoQueryEngine.from_defaults.
    """
    # TODO: Remove this method in the next release
    deprecation_message = (
        "AutoQueryEngine.from_parameters is deprecated. Use AutoQueryEngine.from_defaults instead.")
    raise ValueError(deprecation_message)

@staticmethod
def from_config(
config_file_path: str,
Expand Down
100 changes: 93 additions & 7 deletions autollm/auto/vector_store_index.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
import os
import shutil
from typing import Optional, Sequence

from llama_index import Document, ServiceContext, StorageContext, VectorStoreIndex
from llama_index.schema import BaseNode

from autollm.utils.env_utils import on_rm_error
from autollm.utils.logging import logger


def import_vector_store_class(vector_store_class_name: str):
"""
Expand All @@ -25,24 +30,33 @@ class name and additional parameters.
@staticmethod
def from_defaults(
vector_store_type: str = "LanceDBVectorStore",
lancedb_uri: str = "./.lancedb",
lancedb_uri: str = None,
lancedb_table_name: str = "vectors",
documents: Optional[Sequence[Document]] = None,
nodes: Optional[Sequence[BaseNode]] = None,
service_context: Optional[ServiceContext] = None,
exist_ok: bool = False,
overwrite_existing: bool = False,
**kwargs) -> VectorStoreIndex:
"""
Initializes a Vector Store index from Vector Store type and additional parameters.
Initializes a Vector Store index from Vector Store type and additional parameters. Handles lancedb
path and document management according to specified behaviors.

Parameters:
vector_store_type (str): The class name of the vector store (e.g., 'LanceDBVectorStore', 'SimpleVectorStore'..)
vector_store_type (str): The class name of the vector store.
lancedb_uri (str): The URI for the LanceDB vector store.
lancedb_table_name (str): The table name for the LanceDB vector store.
documents (Optional[Sequence[Document]]): Documents to initialize the vector store index from.
nodes (Optional[Sequence[BaseNode]]): Nodes to initialize the vector store index from.
service_context (Optional[ServiceContext]): Service context to initialize the vector store index from.
**kwargs: Additional parameters for initializing the vector store
service_context (Optional[ServiceContext]): Service context for initialization.
exist_ok (bool): If True, allows adding to an existing database.
overwrite_existing (bool): If True, allows overwriting an existing database.
**kwargs: Additional parameters for initialization.

Returns:
index (VectorStoreIndex): The initialized Vector Store index instance for given vector store type and parameter set.
VectorStoreIndex: The initialized Vector Store index instance.

Raises:
ValueError: For invalid parameter combinations or missing information.
"""
if documents is None and nodes is None and vector_store_type == "SimpleVectorStore":
raise ValueError("documents or nodes must be provided for SimpleVectorStore")
Expand All @@ -55,6 +69,12 @@ def from_defaults(

# If LanceDBVectorStore, use lancedb_uri and lancedb_table_name
if vector_store_type == "LanceDBVectorStore":
lancedb_uri = AutoVectorStoreIndex._validate_and_setup_lancedb_uri(
lancedb_uri=lancedb_uri,
documents=documents,
exist_ok=exist_ok,
overwrite_existing=overwrite_existing)

vector_store = VectorStoreClass(uri=lancedb_uri, table_name=lancedb_table_name, **kwargs)
else:
vector_store = VectorStoreClass(**kwargs)
Expand Down Expand Up @@ -82,3 +102,69 @@ def from_defaults(
show_progress=True)

return index

@staticmethod
def _validate_and_setup_lancedb_uri(lancedb_uri, documents, exist_ok, overwrite_existing):
    """
    Validate and set up the lancedb_uri based on the given parameters.

    Four scenarios are handled:
        0. No uri and no documents -> error (nothing to connect to or create).
        1. Uri but no documents    -> the database must already exist on disk.
        2. Documents but no uri    -> create a fresh database under a default,
           auto-incremented path.
        3. Both documents and uri  -> add to / overwrite / fork the database
           depending on ``exist_ok`` and ``overwrite_existing``.

    Parameters:
        lancedb_uri (str): The URI for the LanceDB vector store.
        documents (Sequence[Document]): Documents to initialize the vector store index from.
        exist_ok (bool): Flag to allow adding to an existing database.
        overwrite_existing (bool): Flag to allow overwriting an existing database
            (only honored together with exist_ok=True).

    Returns:
        str: The validated and potentially modified lancedb_uri.

    Raises:
        ValueError: If neither uri nor documents are given, if the uri points to
            a non-existent database when no documents are given, or if
            overwrite_existing is set without exist_ok.
    """
    default_lancedb_uri = "./lancedb/db"

    # Scenario 0: no uri and no documents -- nothing to connect to, nothing to build.
    if not documents and not lancedb_uri:
        raise ValueError(
            "A lancedb uri is required to connect to a database. Please provide a lancedb uri.")

    # Scenario 1: uri given but no documents -- connect-only, so the db must exist.
    if not documents and lancedb_uri:
        db_exists = os.path.exists(lancedb_uri)
        if not db_exists:
            raise ValueError(
                f"No existing database found at {lancedb_uri}. Please provide a valid lancedb uri.")

    # Scenario 2: documents but no uri -- create a brand-new database under the
    # default path, incrementing the suffix so earlier runs are never clobbered.
    if documents and not lancedb_uri:
        lancedb_uri = AutoVectorStoreIndex._increment_lancedb_uri(default_lancedb_uri)
        logger.info(
            f"A new database is being created at {lancedb_uri}. Please provide a lancedb path to use an existing database."
        )

    # Scenario 3: both documents and uri -- behavior is controlled by the flags.
    if documents and lancedb_uri:
        db_exists = os.path.exists(lancedb_uri)
        if exist_ok and overwrite_existing:
            if db_exists:
                # Fix: pass onerror=on_rm_error (imported at module top but
                # previously unused) so read-only files -- common on Windows --
                # are made writable and re-deleted instead of aborting rmtree.
                shutil.rmtree(lancedb_uri, onerror=on_rm_error)
                logger.info(f"Overwriting existing database at {lancedb_uri}.")
        elif not exist_ok and overwrite_existing:
            raise ValueError("Cannot overwrite existing database without exist_ok set to True.")
        elif db_exists:
            if not exist_ok:
                # Fork: keep the existing db untouched; write to a new incremented path.
                lancedb_uri = AutoVectorStoreIndex._increment_lancedb_uri(lancedb_uri)
                logger.info(f"Existing database found. Creating a new database at {lancedb_uri}.")
                logger.info(
                    "Please use exist_ok=True to add to the existing database and overwrite_existing=True to overwrite the existing database."
                )
            else:
                logger.info(f"Adding documents to existing database at {lancedb_uri}.")

    return lancedb_uri

@staticmethod
def _increment_lancedb_uri(base_uri: str) -> str:
"""Increment the lancedb uri to create a new database."""
i = 1
while os.path.exists(f"{base_uri}_{i}"):
i += 1
return f"{base_uri}_{i}"
19 changes: 2 additions & 17 deletions autollm/utils/document_reading.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import os
import shutil
import stat
from pathlib import Path
from typing import Callable, List, Optional, Sequence, Tuple
from typing import List, Optional, Sequence

from llama_index.readers.file.base import SimpleDirectoryReader
from llama_index.schema import Document

from autollm.utils.env_utils import on_rm_error
from autollm.utils.git_utils import clone_or_pull_repository
from autollm.utils.logging import logger
from autollm.utils.markdown_reader import MarkdownReader
Expand Down Expand Up @@ -65,20 +64,6 @@ def read_files_as_documents(
return documents


# From http://stackoverflow.com/a/4829285/548792
# NOTE(review): an identical helper is defined in autollm.utils.env_utils and
# imported above; this module-local copy is redundant.
def on_rm_error(func: Callable, path: str, exc_info: Tuple):
    """
    Error handler for `shutil.rmtree` to handle permission errors.

    Parameters:
        func (Callable): The function that raised the error.
        path (str): The path to the file or directory which couldn't be removed.
        exc_info (Tuple): Exception information returned by sys.exc_info().
    """
    # Read-only files (common in Windows checkouts) make the delete fail inside
    # shutil.rmtree; mark the path writable, then retry the removal once.
    os.chmod(path, stat.S_IWRITE)
    os.unlink(path)


def read_github_repo_as_documents(
git_repo_url: str,
relative_folder_path: Optional[str] = None,
Expand Down
16 changes: 16 additions & 0 deletions autollm/utils/env_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import os
import stat
from pathlib import Path
from typing import Callable, Tuple

import yaml
from dotenv import load_dotenv
Expand Down Expand Up @@ -55,3 +57,17 @@ def load_config_and_dotenv(config_file_path: str, env_file_path: str = None) ->
config = yaml.safe_load(f)

return config


# From http://stackoverflow.com/a/4829285/548792
def on_rm_error(func: Callable, path: str, exc_info: Tuple):
    """
    Error handler for `shutil.rmtree` to handle permission errors.

    Parameters:
        func (Callable): The function that raised the error.
        path (str): The path to the file or directory which couldn't be removed.
        exc_info (Tuple): Exception information returned by sys.exc_info().
    """
    blocked = Path(path)
    # Clear the read-only attribute that blocked the delete, then retry it.
    blocked.chmod(stat.S_IWRITE)
    blocked.unlink()