diff --git a/autollm/__init__.py b/autollm/__init__.py index 8c9a6a1a..c9a89195 100644 --- a/autollm/__init__.py +++ b/autollm/__init__.py @@ -4,7 +4,7 @@ and vector databases, along with various utility functions. """ -__version__ = '0.1.8' +__version__ = '0.1.9' __author__ = 'safevideo' __license__ = 'AGPL-3.0' diff --git a/autollm/auto/query_engine.py b/autollm/auto/query_engine.py index eb3235c1..86c345fd 100644 --- a/autollm/auto/query_engine.py +++ b/autollm/auto/query_engine.py @@ -45,6 +45,7 @@ def create_query_engine( vector_store_type: str = "LanceDBVectorStore", lancedb_uri: str = "./.lancedb", lancedb_table_name: str = "vectors", + use_async: bool = True, exist_ok: bool = False, overwrite_existing: bool = False, **vector_store_kwargs) -> BaseQueryEngine: @@ -105,6 +106,7 @@ def create_query_engine( vector_store_type=vector_store_type, lancedb_uri=lancedb_uri, lancedb_table_name=lancedb_table_name, + use_async=use_async, documents=documents, nodes=nodes, service_context=service_context, @@ -214,6 +216,7 @@ def from_defaults( vector_store_type: str = "LanceDBVectorStore", lancedb_uri: str = "./.lancedb", lancedb_table_name: str = "vectors", + use_async: bool = True, exist_ok: bool = False, overwrite_existing: bool = False, **vector_store_kwargs) -> BaseQueryEngine: @@ -277,6 +280,7 @@ def from_defaults( vector_store_type=vector_store_type, lancedb_uri=lancedb_uri, lancedb_table_name=lancedb_table_name, + use_async=use_async, exist_ok=exist_ok, overwrite_existing=overwrite_existing, **vector_store_kwargs) diff --git a/autollm/auto/vector_store_index.py b/autollm/auto/vector_store_index.py index 181a61ef..8bd28c5b 100644 --- a/autollm/auto/vector_store_index.py +++ b/autollm/auto/vector_store_index.py @@ -4,6 +4,7 @@ from llama_index import Document, ServiceContext, StorageContext, VectorStoreIndex from llama_index.schema import BaseNode +from llama_index.vector_stores.types import VectorStore from autollm.utils.env_utils import on_rm_error from autollm.utils.lancedb_vectorstore import LanceDBVectorStore @@ -35,6 +36,7 @@ def from_defaults( lancedb_table_name: str = "vectors", lancedb_api_key: Optional[str] = None, lancedb_region: Optional[str] = None, + use_async: bool = False, documents: Optional[Sequence[Document]] = None, nodes: Optional[Sequence[BaseNode]] = None, service_context: Optional[ServiceContext] = None, @@ -49,6 +51,9 @@ def from_defaults( vector_store_type (str): The class name of the vector store. lancedb_uri (str): The URI for the LanceDB vector store. lancedb_table_name (str): The table name for the LanceDB vector store. + lancedb_api_key (Optional[str]): The API key for the LanceDB CLOUD vector store. + lancedb_region (Optional[str]): The region for the LanceDB CLOUD vector store. + use_async (bool): Flag to use async embedding. (Only supported for SimpleVectorStore) documents (Optional[Sequence[Document]]): Documents to initialize the vector store index from. service_context (Optional[ServiceContext]): Service context for initialization. exist_ok (bool): If True, allows adding to an existing database. @@ -67,6 +72,9 @@ def from_defaults( if documents is not None and nodes is not None: raise ValueError("documents and nodes cannot be provided at the same time") + if use_async and vector_store_type != "SimpleVectorStore": + logger.warning("use_async is only supported for SimpleVectorStore. Ignoring use_async.") + # Initialize vector store VectorStoreClass = import_vector_store_class(vector_store_type) @@ -95,20 +103,13 @@ def from_defaults( return index # Initialize vector store index from documents or nodes - storage_context = StorageContext.from_defaults(vector_store=vector_store) - - if documents is not None: - index = VectorStoreIndex.from_documents( - documents=documents, - storage_context=storage_context, - service_context=service_context, - show_progress=True) - else: - index = VectorStoreIndex( - nodes=nodes, - storage_context=storage_context, - service_context=service_context, - show_progress=True) + index = AutoVectorStoreIndex._create_index( + documents=documents, + nodes=nodes, + vector_store=vector_store, + service_context=service_context, + use_async=use_async, + show_progress=True) return index @@ -154,7 +155,7 @@ def _validate_and_setup_lancedb_uri(lancedb_uri, documents, exist_ok, overwrite_ db_exists = os.path.exists(lancedb_uri) if exist_ok and overwrite_existing: if db_exists: - shutil.rmtree(lancedb_uri) + shutil.rmtree(lancedb_uri, onerror=on_rm_error) logger.info(f"Overwriting existing database at {lancedb_uri}.") elif not exist_ok and overwrite_existing: raise ValueError("Cannot overwrite existing database without exist_ok set to True.") @@ -174,6 +175,54 @@ def _validate_and_setup_lancedb_uri(lancedb_uri, documents, exist_ok, overwrite_ def _increment_lancedb_uri(base_uri: str) -> str: """Increment the lancedb uri to create a new database.""" i = 1 + while os.path.exists(f"{base_uri}_{i}"): + # check if the directory at path is empty. if so, use this path + if not os.listdir(f"{base_uri}_{i}"): + break i += 1 return f"{base_uri}_{i}" + + @staticmethod + def _create_index( + documents: Optional[Sequence[Document]] = None, + nodes: Optional[Sequence[BaseNode]] = None, + vector_store: Optional[VectorStore] = None, + service_context: Optional[ServiceContext] = None, + use_async: Optional[bool] = False, + show_progress: Optional[bool] = True): + """ + Sets up the index from documents or nodes. + + Parameters: + documents (Sequence[Document]): Documents to initialize the vector store index from. + nodes (Sequence[BaseNode]): Nodes to initialize the vector store index from. + vector_store (VectorStore): The vector store to initialize the index from. + service_context (ServiceContext): Service context for initialization. + use_async (bool): Flag to use async embedding. + show_progress (bool): Flag to show progress. + """ + if documents is None and nodes is None: + raise ValueError("documents or nodes must be provided") + + if documents and nodes: + raise ValueError("documents and nodes cannot be provided at the same time") + + storage_context = StorageContext.from_defaults(vector_store=vector_store) + + if documents is not None: + index = VectorStoreIndex.from_documents( + documents=documents, + storage_context=storage_context, + service_context=service_context, + use_async=use_async, + show_progress=show_progress) + else: + index = VectorStoreIndex( + nodes=nodes, + storage_context=storage_context, + service_context=service_context, + use_async=use_async, + show_progress=show_progress) + + return index