Skip to content
This repository has been archived by the owner on Sep 12, 2024. It is now read-only.

implement pre-filtering feature to lancedb #187

Merged
merged 6 commits into from
Dec 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion autollm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
and vector databases, along with various utility functions.
"""

__version__ = '0.1.5'
__version__ = '0.1.6'
__author__ = 'safevideo'
__license__ = 'AGPL-3.0'

Expand Down
2 changes: 0 additions & 2 deletions autollm/auto/vector_store_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,6 @@ def from_defaults(
region=lancedb_region,
**kwargs)

vector_store = VectorStoreClass(uri=lancedb_uri, table_name=lancedb_table_name, **kwargs)

else:
vector_store = VectorStoreClass(**kwargs)

Expand Down
86 changes: 74 additions & 12 deletions autollm/utils/lancedb_vectorstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,19 @@
from typing import Any, Optional

from dotenv import load_dotenv
from llama_index.schema import NodeRelationship, RelatedNodeInfo, TextNode
from llama_index.vector_stores import LanceDBVectorStore as LanceDBVectorStoreBase
from llama_index.vector_stores.lancedb import _to_lance_filter, _to_llama_similarities
from llama_index.vector_stores.types import VectorStoreQuery, VectorStoreQueryResult
from pandas import DataFrame

load_dotenv()


class LanceDBVectorStore(LanceDBVectorStoreBase):
"""Advanced LanceDB Vector Store supporting cloud storage and prefiltering."""
from lancedb.query import LanceQueryBuilder
from lancedb.table import Table

def __init__(
self,
Expand All @@ -20,26 +27,81 @@ def __init__(
region: Optional[str] = None,
**kwargs: Any,
) -> None:
"""Init params."""
self._setup_connection(uri, api_key, region)
self.uri = uri
self.table_name = table_name
self.nprobes = nprobes
self.refine_factor = refine_factor
self.api_key = api_key
self.region = region

def _setup_connection(
    self,
    uri: str,
    api_key: Optional[str] = None,
    region: Optional[str] = None,
) -> None:
    """Open a LanceDB connection and store it on ``self.connection``.

    Args:
        uri: Location passed straight to ``lancedb.connect``.
        api_key: API key; when falsy, the ``LANCEDB_API_KEY`` environment
            variable is used instead.
        region: Region; when falsy, the ``LANCEDB_REGION`` environment
            variable is used instead.

    Raises:
        ImportError: If the ``lancedb`` package is not installed.
    """
    try:
        import lancedb
    except ImportError as err:
        # Chain the original error so the real import failure stays visible.
        raise ImportError("`lancedb` package not found, please run `pip install lancedb`") from err

    # Explicit arguments win; the environment is only a fallback.
    api_key = api_key or os.getenv('LANCEDB_API_KEY')
    region = region or os.getenv('LANCEDB_REGION')

    # Credentials are only forwarded when both are available; otherwise
    # connect with the URI alone.
    if api_key and region:
        self.connection = lancedb.connect(uri, api_key=api_key, region=region)
    else:
        self.connection = lancedb.connect(uri)

    # NOTE: attributes such as ``table_name``, ``nprobes`` and
    # ``refine_factor`` are not parameters of this method, so they are not
    # assigned here; ``__init__`` is responsible for storing them.

def query(
    self,
    query: VectorStoreQuery,
    **kwargs: Any,
) -> VectorStoreQueryResult:
    """Search the configured LanceDB table and wrap the hits.

    Opens ``self.table_name`` on the stored connection, builds the search
    through ``_prepare_lance_query`` (forwarding ``kwargs`` unchanged),
    materialises it as a DataFrame and converts that DataFrame with
    ``_construct_query_result``.

    Args:
        query: The vector store query to execute.
        **kwargs: Extra options handed to ``_prepare_lance_query``.

    Returns:
        The query result built from the returned rows.
    """
    lance_table = self.connection.open_table(self.table_name)
    frame = self._prepare_lance_query(query, lance_table, **kwargs).to_df()
    return self._construct_query_result(frame)

def _prepare_lance_query(self, query: VectorStoreQuery, table: Table, **kwargs) -> LanceQueryBuilder:
    """Build the LanceDB search for ``query``, including the filter options.

    Args:
        query: Supplies the embedding, the result limit and, optionally,
            generic filters.
        table: Table to search. If ``None`` is passed, the table named
            ``self.table_name`` is opened on the stored connection.
        **kwargs: LanceDB-specific options. ``where`` (a filter expression)
            and ``prefilter`` (default ``False``) are consumed here; ``where``
            may only be given when ``query.filters`` is ``None``.

    Returns:
        The configured query builder, not yet executed.

    Raises:
        ValueError: If a filter is supplied both through ``query.filters``
            and through the ``where`` keyword argument.
    """
    if query.filters is not None:
        if "where" in kwargs:
            raise ValueError(
                "Cannot specify filter via both query and kwargs. "
                "Use kwargs only for lancedb specific items that are "
                "not supported via the generic query interface.")
        where = _to_lance_filter(query.filters)
    else:
        where = kwargs.pop("where", None)
    prefilter = kwargs.pop("prefilter", False)

    # Use the table handed in by the caller instead of reopening it; only
    # fall back to opening it when no table was supplied.
    if table is None:
        table = self.connection.open_table(self.table_name)

    lance_query = (
        table.search(query.query_embedding)
        .limit(query.similarity_top_k)
        .where(where, prefilter=prefilter)
        .nprobes(self.nprobes)
    )

    if self.refine_factor is not None:
        # Keep the returned builder rather than relying on in-place mutation.
        lance_query = lance_query.refine_factor(self.refine_factor)

    return lance_query

def _construct_query_result(self, results: DataFrame) -> VectorStoreQueryResult:
    """Convert a LanceDB result DataFrame into a ``VectorStoreQueryResult``.

    Each row becomes a ``TextNode`` whose id comes from the ``id`` column and
    whose source relationship points at the row's ``doc_id``. Similarities are
    derived from the whole frame via ``_to_llama_similarities`` and the ids
    are the ``id`` column as a list.

    Args:
        results: Rows returned by the LanceDB query.

    Returns:
        The nodes, similarities and ids for the query.
    """

    def _row_to_node(row) -> TextNode:
        # Link the node back to the document it was derived from.
        source = RelatedNodeInfo(node_id=row['doc_id'])
        return TextNode(
            text=row.get('text', ''),  # empty string when the row has no 'text' field
            id_=row['id'],
            relationships={NodeRelationship.SOURCE: source},
        )

    nodes = [_row_to_node(row) for _, row in results.iterrows()]
    similarities = _to_llama_similarities(results)
    ids = results["id"].tolist()
    return VectorStoreQueryResult(nodes=nodes, similarities=similarities, ids=ids)